Merge pull request #70 from aguzererler/fix/resolve-prs-56-58-60-4650038896702716891
🧪 Resolve PRs #56, #58, and #60
This commit is contained in:
commit
a7b8c996f2
|
|
@ -1,22 +1,153 @@
|
|||
import pandas as pd
|
||||
import numpy as np
|
||||
import pytest
|
||||
from tradingagents.dataflows.stockstats_utils import _clean_dataframe
|
||||
|
||||
def test_clean_dataframe_lowercases_columns():
|
||||
def test_clean_dataframe_valid_data():
|
||||
"""Test _clean_dataframe with valid data where no rows should be dropped."""
|
||||
df = pd.DataFrame({
|
||||
"Date": ["2023-01-01", "2023-01-02", "2023-01-03"],
|
||||
"Open": [10.0, 11.0, 12.0],
|
||||
"HIGH": [10.5, 11.5, 12.5],
|
||||
"low": [9.5, 10.5, 11.5],
|
||||
"ClOsE": [10.2, 11.2, 12.2],
|
||||
"Volume": [1000, 1100, 1200]
|
||||
"High": [10.5, 11.5, 12.5],
|
||||
"Low": [9.5, 10.5, 11.5],
|
||||
"Close": [10.2, 11.2, 12.2],
|
||||
"Volume": [100, 200, 300]
|
||||
})
|
||||
|
||||
cleaned = _clean_dataframe(df)
|
||||
cleaned_df = _clean_dataframe(df.copy())
|
||||
|
||||
assert list(cleaned.columns) == ["date", "open", "high", "low", "close", "volume"]
|
||||
assert len(cleaned) == 3
|
||||
assert len(cleaned_df) == 3
|
||||
assert "date" in cleaned_df.columns
|
||||
assert pd.api.types.is_datetime64_any_dtype(cleaned_df["date"])
|
||||
|
||||
def test_clean_dataframe_handles_non_string_columns():
|
||||
# Check if price columns are correctly parsed as float/numeric
|
||||
for col in ["open", "high", "low", "close", "volume"]:
|
||||
assert pd.api.types.is_numeric_dtype(cleaned_df[col])
|
||||
assert (cleaned_df[col] == df[col.capitalize()]).all()
|
||||
|
||||
def test_clean_dataframe_invalid_dates():
|
||||
"""Test _clean_dataframe drops rows with invalid or missing dates."""
|
||||
df = pd.DataFrame({
|
||||
"Date": ["2023-01-01", "invalid_date", None],
|
||||
"Open": [10.0, 11.0, 12.0],
|
||||
"Close": [10.2, 11.2, 12.2]
|
||||
})
|
||||
|
||||
cleaned_df = _clean_dataframe(df.copy())
|
||||
|
||||
assert len(cleaned_df) == 1
|
||||
assert cleaned_df.iloc[0]["date"] == pd.to_datetime("2023-01-01")
|
||||
|
||||
def test_clean_dataframe_missing_close():
|
||||
"""Test _clean_dataframe drops rows where Close price is missing."""
|
||||
df = pd.DataFrame({
|
||||
"Date": ["2023-01-01", "2023-01-02", "2023-01-03"],
|
||||
"Open": [10.0, 11.0, 12.0],
|
||||
"Close": [10.2, np.nan, 12.2]
|
||||
})
|
||||
|
||||
cleaned_df = _clean_dataframe(df.copy())
|
||||
|
||||
assert len(cleaned_df) == 2
|
||||
assert cleaned_df.iloc[0]["date"] == pd.to_datetime("2023-01-01")
|
||||
assert cleaned_df.iloc[1]["date"] == pd.to_datetime("2023-01-03")
|
||||
|
||||
def test_clean_dataframe_numeric_coercion():
|
||||
"""Test _clean_dataframe coerces non-numeric strings to NaN in price columns,
|
||||
but handles ffill/bfill for them."""
|
||||
df = pd.DataFrame({
|
||||
"Date": ["2023-01-01", "2023-01-02", "2023-01-03", "2023-01-04"],
|
||||
"Open": [10.0, "invalid", 12.0, 13.0],
|
||||
"Close": [10.2, 11.2, 12.2, 13.2]
|
||||
})
|
||||
|
||||
cleaned_df = _clean_dataframe(df.copy())
|
||||
|
||||
assert len(cleaned_df) == 4
|
||||
# "invalid" is coerced to NaN, then ffill will fill it with 10.0 (from previous row)
|
||||
assert cleaned_df.iloc[1]["open"] == 10.0
|
||||
|
||||
def test_clean_dataframe_ffill_bfill():
|
||||
"""Test _clean_dataframe forward and backward fills missing values in price columns."""
|
||||
df = pd.DataFrame({
|
||||
"Date": ["2023-01-01", "2023-01-02", "2023-01-03"],
|
||||
"Open": [np.nan, 11.0, np.nan],
|
||||
"Close": [10.2, 11.2, 12.2]
|
||||
})
|
||||
|
||||
cleaned_df = _clean_dataframe(df.copy())
|
||||
|
||||
assert len(cleaned_df) == 3
|
||||
# The first row Open is NaN -> bfill uses the next valid value (11.0)
|
||||
assert cleaned_df.iloc[0]["open"] == 11.0
|
||||
# The last row Open is NaN -> ffill uses the previous valid value (11.0)
|
||||
assert cleaned_df.iloc[2]["open"] == 11.0
|
||||
|
||||
def test_clean_dataframe_empty():
|
||||
"""Test _clean_dataframe with an empty DataFrame."""
|
||||
df = pd.DataFrame(columns=["Date", "Open", "Close"])
|
||||
|
||||
cleaned_df = _clean_dataframe(df.copy())
|
||||
|
||||
assert len(cleaned_df) == 0
|
||||
assert "date" in cleaned_df.columns
|
||||
assert "open" in cleaned_df.columns
|
||||
assert "close" in cleaned_df.columns
|
||||
|
||||
def test_clean_dataframe_missing_columns():
|
||||
"""Test _clean_dataframe when some optional price columns are missing."""
|
||||
df = pd.DataFrame({
|
||||
"Date": ["2023-01-01", "2023-01-02"],
|
||||
"Close": [10.2, 11.2]
|
||||
})
|
||||
|
||||
cleaned_df = _clean_dataframe(df.copy())
|
||||
|
||||
assert len(cleaned_df) == 2
|
||||
assert "close" in cleaned_df.columns
|
||||
assert "open" not in cleaned_df.columns
|
||||
|
||||
def test_clean_dataframe_lowercase_columns():
|
||||
"""Test _clean_dataframe successfully lowercases all column names."""
|
||||
# Given a DataFrame with mixed case and uppercase columns
|
||||
df = pd.DataFrame({
|
||||
"Date": ["2023-01-01"],
|
||||
"OPEN": [10.0],
|
||||
"High": [10.5],
|
||||
"loW": [9.5],
|
||||
"Close": [10.2],
|
||||
"Volume": [100]
|
||||
})
|
||||
|
||||
# When _clean_dataframe is called
|
||||
cleaned_df = _clean_dataframe(df)
|
||||
|
||||
# Then all columns should be lowercase
|
||||
expected_columns = ["date", "open", "high", "low", "close", "volume"]
|
||||
assert list(cleaned_df.columns) == expected_columns
|
||||
|
||||
# And the original DataFrame should not be mutated
|
||||
assert list(df.columns) == ["Date", "OPEN", "High", "loW", "Close", "Volume"]
|
||||
|
||||
def test_clean_dataframe_non_string_columns():
|
||||
"""Test _clean_dataframe successfully handles non-string column names by converting them to string then lowercase."""
|
||||
# Given a DataFrame with integer columns (which won't match Date or Close processing but will be lowercased)
|
||||
df = pd.DataFrame({
|
||||
"Date": ["2023-01-01"],
|
||||
"Close": [10.0],
|
||||
0: [100.0],
|
||||
1: [200.0]
|
||||
})
|
||||
|
||||
# When _clean_dataframe is called
|
||||
cleaned_df = _clean_dataframe(df)
|
||||
|
||||
# Then all columns should be strings and lowercase
|
||||
expected_columns = ["date", "close", "0", "1"]
|
||||
assert list(cleaned_df.columns) == expected_columns
|
||||
|
||||
def test_clean_dataframe_handles_no_date_or_close():
|
||||
"""Test _clean_dataframe correctly formats column names if there's no date or close"""
|
||||
df = pd.DataFrame({
|
||||
1: [10.0, 11.0],
|
||||
"Open": [10.0, 11.0]
|
||||
|
|
@ -25,13 +156,4 @@ def test_clean_dataframe_handles_non_string_columns():
|
|||
cleaned = _clean_dataframe(df)
|
||||
|
||||
assert list(cleaned.columns) == ["1", "open"]
|
||||
|
||||
def test_clean_dataframe_does_not_mutate_original():
|
||||
df = pd.DataFrame({
|
||||
"Date": ["2023-01-01", "2023-01-02", "2023-01-03"],
|
||||
"Open": [10.0, 11.0, 12.0]
|
||||
})
|
||||
|
||||
_clean_dataframe(df)
|
||||
|
||||
assert list(df.columns) == ["Date", "Open"]
|
||||
assert len(cleaned) == 2
|
||||
|
|
|
|||
|
|
@ -60,15 +60,21 @@ class TestFindCol:
|
|||
from tradingagents.dataflows.ttm_analysis import _find_col
|
||||
self.find_col = _find_col
|
||||
|
||||
def test_find_col_success(self):
|
||||
df = pd.DataFrame({"Total Revenue": [100], "Gross Profit": [40]})
|
||||
candidates = ["Revenue", "Total Revenue", "totalRevenue"]
|
||||
assert self.find_col(df, candidates) == "Total Revenue"
|
||||
def test_find_col_match(self):
|
||||
"""Should return the matching column name."""
|
||||
df = pd.DataFrame({"Revenue": [1, 2, 3], "Cost": [4, 5, 6]})
|
||||
assert self.find_col(df, ["Revenue", "Total Revenue"]) == "Revenue"
|
||||
|
||||
def test_find_col_no_match(self):
|
||||
"""Should return None if no candidate matches."""
|
||||
df = pd.DataFrame({"Cost": [4, 5, 6], "Profit": [7, 8, 9]})
|
||||
assert self.find_col(df, ["Revenue", "Total Revenue"]) is None
|
||||
|
||||
def test_find_col_empty_df(self):
|
||||
"""Should return None for empty DataFrame."""
|
||||
df = pd.DataFrame()
|
||||
assert self.find_col(df, ["Revenue", "Total Revenue"]) is None
|
||||
|
||||
def test_find_col_not_found(self):
|
||||
df = pd.DataFrame({"Unrelated": [100], "Another": [40]})
|
||||
candidates = ["Revenue", "Total Revenue", "totalRevenue"]
|
||||
assert self.find_col(df, candidates) is None
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit tests for compute_ttm_metrics
|
||||
|
|
|
|||
|
|
@ -7,9 +7,25 @@ from .config import get_config
|
|||
|
||||
|
||||
def _clean_dataframe(data: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Ensure DataFrame has lowercase columns for stockstats."""
|
||||
"""Normalize a stock DataFrame for stockstats: parse dates, drop invalid rows, fill price gaps.
|
||||
Ensure DataFrame has lowercase columns for stockstats."""
|
||||
df = data.copy()
|
||||
df.columns = [str(c).lower() for c in df.columns]
|
||||
|
||||
if "date" in df.columns:
|
||||
df["date"] = pd.to_datetime(df["date"], errors="coerce")
|
||||
df = df.dropna(subset=["date"])
|
||||
|
||||
price_cols = [c for c in ["open", "high", "low", "close", "volume"] if c in df.columns]
|
||||
if price_cols:
|
||||
df[price_cols] = df[price_cols].apply(pd.to_numeric, errors="coerce")
|
||||
|
||||
if "close" in df.columns:
|
||||
df = df.dropna(subset=["close"])
|
||||
|
||||
if price_cols:
|
||||
df[price_cols] = df[price_cols].ffill().bfill()
|
||||
|
||||
return df
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue