Merge pull request #70 from aguzererler/fix/resolve-prs-56-58-60-4650038896702716891

🧪 Resolve PRs #56, #58, and #60
This commit is contained in:
ahmet guzererler 2026-03-21 17:54:10 +01:00 committed by GitHub
commit a7b8c996f2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 172 additions and 28 deletions

View File

@ -1,22 +1,153 @@
import pandas as pd import pandas as pd
import numpy as np
import pytest
from tradingagents.dataflows.stockstats_utils import _clean_dataframe from tradingagents.dataflows.stockstats_utils import _clean_dataframe
def test_clean_dataframe_lowercases_columns(): def test_clean_dataframe_valid_data():
"""Test _clean_dataframe with valid data where no rows should be dropped."""
df = pd.DataFrame({ df = pd.DataFrame({
"Date": ["2023-01-01", "2023-01-02", "2023-01-03"], "Date": ["2023-01-01", "2023-01-02", "2023-01-03"],
"Open": [10.0, 11.0, 12.0], "Open": [10.0, 11.0, 12.0],
"HIGH": [10.5, 11.5, 12.5], "High": [10.5, 11.5, 12.5],
"low": [9.5, 10.5, 11.5], "Low": [9.5, 10.5, 11.5],
"ClOsE": [10.2, 11.2, 12.2], "Close": [10.2, 11.2, 12.2],
"Volume": [1000, 1100, 1200] "Volume": [100, 200, 300]
}) })
cleaned = _clean_dataframe(df) cleaned_df = _clean_dataframe(df.copy())
assert list(cleaned.columns) == ["date", "open", "high", "low", "close", "volume"] assert len(cleaned_df) == 3
assert len(cleaned) == 3 assert "date" in cleaned_df.columns
assert pd.api.types.is_datetime64_any_dtype(cleaned_df["date"])
def test_clean_dataframe_handles_non_string_columns(): # Check if price columns are correctly parsed as float/numeric
for col in ["open", "high", "low", "close", "volume"]:
assert pd.api.types.is_numeric_dtype(cleaned_df[col])
assert (cleaned_df[col] == df[col.capitalize()]).all()
def test_clean_dataframe_invalid_dates():
"""Test _clean_dataframe drops rows with invalid or missing dates."""
df = pd.DataFrame({
"Date": ["2023-01-01", "invalid_date", None],
"Open": [10.0, 11.0, 12.0],
"Close": [10.2, 11.2, 12.2]
})
cleaned_df = _clean_dataframe(df.copy())
assert len(cleaned_df) == 1
assert cleaned_df.iloc[0]["date"] == pd.to_datetime("2023-01-01")
def test_clean_dataframe_missing_close():
"""Test _clean_dataframe drops rows where Close price is missing."""
df = pd.DataFrame({
"Date": ["2023-01-01", "2023-01-02", "2023-01-03"],
"Open": [10.0, 11.0, 12.0],
"Close": [10.2, np.nan, 12.2]
})
cleaned_df = _clean_dataframe(df.copy())
assert len(cleaned_df) == 2
assert cleaned_df.iloc[0]["date"] == pd.to_datetime("2023-01-01")
assert cleaned_df.iloc[1]["date"] == pd.to_datetime("2023-01-03")
def test_clean_dataframe_numeric_coercion():
"""Test _clean_dataframe coerces non-numeric strings to NaN in price columns,
but handles ffill/bfill for them."""
df = pd.DataFrame({
"Date": ["2023-01-01", "2023-01-02", "2023-01-03", "2023-01-04"],
"Open": [10.0, "invalid", 12.0, 13.0],
"Close": [10.2, 11.2, 12.2, 13.2]
})
cleaned_df = _clean_dataframe(df.copy())
assert len(cleaned_df) == 4
# "invalid" is coerced to NaN, then ffill will fill it with 10.0 (from previous row)
assert cleaned_df.iloc[1]["open"] == 10.0
def test_clean_dataframe_ffill_bfill():
"""Test _clean_dataframe forward and backward fills missing values in price columns."""
df = pd.DataFrame({
"Date": ["2023-01-01", "2023-01-02", "2023-01-03"],
"Open": [np.nan, 11.0, np.nan],
"Close": [10.2, 11.2, 12.2]
})
cleaned_df = _clean_dataframe(df.copy())
assert len(cleaned_df) == 3
# The first row Open is NaN -> bfill uses the next valid value (11.0)
assert cleaned_df.iloc[0]["open"] == 11.0
# The last row Open is NaN -> ffill uses the previous valid value (11.0)
assert cleaned_df.iloc[2]["open"] == 11.0
def test_clean_dataframe_empty():
"""Test _clean_dataframe with an empty DataFrame."""
df = pd.DataFrame(columns=["Date", "Open", "Close"])
cleaned_df = _clean_dataframe(df.copy())
assert len(cleaned_df) == 0
assert "date" in cleaned_df.columns
assert "open" in cleaned_df.columns
assert "close" in cleaned_df.columns
def test_clean_dataframe_missing_columns():
"""Test _clean_dataframe when some optional price columns are missing."""
df = pd.DataFrame({
"Date": ["2023-01-01", "2023-01-02"],
"Close": [10.2, 11.2]
})
cleaned_df = _clean_dataframe(df.copy())
assert len(cleaned_df) == 2
assert "close" in cleaned_df.columns
assert "open" not in cleaned_df.columns
def test_clean_dataframe_lowercase_columns():
"""Test _clean_dataframe successfully lowercases all column names."""
# Given a DataFrame with mixed case and uppercase columns
df = pd.DataFrame({
"Date": ["2023-01-01"],
"OPEN": [10.0],
"High": [10.5],
"loW": [9.5],
"Close": [10.2],
"Volume": [100]
})
# When _clean_dataframe is called
cleaned_df = _clean_dataframe(df)
# Then all columns should be lowercase
expected_columns = ["date", "open", "high", "low", "close", "volume"]
assert list(cleaned_df.columns) == expected_columns
# And the original DataFrame should not be mutated
assert list(df.columns) == ["Date", "OPEN", "High", "loW", "Close", "Volume"]
def test_clean_dataframe_non_string_columns():
"""Test _clean_dataframe successfully handles non-string column names by converting them to string then lowercase."""
# Given a DataFrame with integer columns (which won't match Date or Close processing but will be lowercased)
df = pd.DataFrame({
"Date": ["2023-01-01"],
"Close": [10.0],
0: [100.0],
1: [200.0]
})
# When _clean_dataframe is called
cleaned_df = _clean_dataframe(df)
# Then all columns should be strings and lowercase
expected_columns = ["date", "close", "0", "1"]
assert list(cleaned_df.columns) == expected_columns
def test_clean_dataframe_handles_no_date_or_close():
"""Test _clean_dataframe correctly formats column names if there's no date or close"""
df = pd.DataFrame({ df = pd.DataFrame({
1: [10.0, 11.0], 1: [10.0, 11.0],
"Open": [10.0, 11.0] "Open": [10.0, 11.0]
@ -25,13 +156,4 @@ def test_clean_dataframe_handles_non_string_columns():
cleaned = _clean_dataframe(df) cleaned = _clean_dataframe(df)
assert list(cleaned.columns) == ["1", "open"] assert list(cleaned.columns) == ["1", "open"]
assert len(cleaned) == 2
def test_clean_dataframe_does_not_mutate_original():
df = pd.DataFrame({
"Date": ["2023-01-01", "2023-01-02", "2023-01-03"],
"Open": [10.0, 11.0, 12.0]
})
_clean_dataframe(df)
assert list(df.columns) == ["Date", "Open"]

View File

@ -60,15 +60,21 @@ class TestFindCol:
from tradingagents.dataflows.ttm_analysis import _find_col from tradingagents.dataflows.ttm_analysis import _find_col
self.find_col = _find_col self.find_col = _find_col
def test_find_col_success(self): def test_find_col_match(self):
df = pd.DataFrame({"Total Revenue": [100], "Gross Profit": [40]}) """Should return the matching column name."""
candidates = ["Revenue", "Total Revenue", "totalRevenue"] df = pd.DataFrame({"Revenue": [1, 2, 3], "Cost": [4, 5, 6]})
assert self.find_col(df, candidates) == "Total Revenue" assert self.find_col(df, ["Revenue", "Total Revenue"]) == "Revenue"
def test_find_col_no_match(self):
"""Should return None if no candidate matches."""
df = pd.DataFrame({"Cost": [4, 5, 6], "Profit": [7, 8, 9]})
assert self.find_col(df, ["Revenue", "Total Revenue"]) is None
def test_find_col_empty_df(self):
"""Should return None for empty DataFrame."""
df = pd.DataFrame()
assert self.find_col(df, ["Revenue", "Total Revenue"]) is None
def test_find_col_not_found(self):
df = pd.DataFrame({"Unrelated": [100], "Another": [40]})
candidates = ["Revenue", "Total Revenue", "totalRevenue"]
assert self.find_col(df, candidates) is None
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Unit tests for compute_ttm_metrics # Unit tests for compute_ttm_metrics

View File

@ -7,9 +7,25 @@ from .config import get_config
def _clean_dataframe(data: pd.DataFrame) -> pd.DataFrame: def _clean_dataframe(data: pd.DataFrame) -> pd.DataFrame:
"""Ensure DataFrame has lowercase columns for stockstats.""" """Normalize a stock DataFrame for stockstats: parse dates, drop invalid rows, fill price gaps.
Ensure DataFrame has lowercase columns for stockstats."""
df = data.copy() df = data.copy()
df.columns = [str(c).lower() for c in df.columns] df.columns = [str(c).lower() for c in df.columns]
if "date" in df.columns:
df["date"] = pd.to_datetime(df["date"], errors="coerce")
df = df.dropna(subset=["date"])
price_cols = [c for c in ["open", "high", "low", "close", "volume"] if c in df.columns]
if price_cols:
df[price_cols] = df[price_cols].apply(pd.to_numeric, errors="coerce")
if "close" in df.columns:
df = df.dropna(subset=["close"])
if price_cols:
df[price_cols] = df[price_cols].ffill().bfill()
return df return df