def _clean_dataframe(data: pd.DataFrame) -> pd.DataFrame:
    """Normalize a stock DataFrame for stockstats.

    Steps, in order:

    1. Parse ``Date`` to datetimes; unparseable values become ``NaT``.
    2. Coerce whichever of Open/High/Low/Close/Volume are present to
       numeric; unparseable values become ``NaN``.
    3. Drop rows that have no valid ``Date`` or no valid ``Close``.
    4. Forward-fill, then back-fill, the remaining gaps in the price columns.

    The input frame is left untouched; a cleaned copy is returned.

    Args:
        data: Frame with at least ``Date`` and ``Close`` columns.

    Returns:
        A new DataFrame containing only the rows that survived cleaning.

    Raises:
        KeyError: If ``Date`` or ``Close`` is missing from ``data``.
    """
    required = ("Date", "Close")
    missing = [col for col in required if col not in data.columns]
    if missing:
        # Fail before doing any work, with a message that names the columns,
        # rather than surfacing an opaque KeyError from deep inside pandas.
        raise KeyError(f"Stock data is missing required column(s): {missing}")

    # Work on a copy so the caller's frame is never mutated.
    cleaned = data.copy()
    cleaned["Date"] = pd.to_datetime(cleaned["Date"], errors="coerce")

    price_cols = [
        col
        for col in ("Open", "High", "Low", "Close", "Volume")
        if col in cleaned.columns
    ]
    # Column-by-column conversion avoids DataFrame.apply's special-cased
    # handling of empty frames.
    for col in price_cols:
        cleaned[col] = pd.to_numeric(cleaned[col], errors="coerce")

    # Row filtering is independent of the numeric coercion above, so both
    # drops can happen at once. The explicit copy keeps the assignment below
    # from being treated as a write to a slice of another frame.
    cleaned = cleaned.dropna(subset=["Date", "Close"]).copy()
    cleaned[price_cols] = cleaned[price_cols].ffill().bfill()

    return cleaned
config.get("data_cache_dir", "data"), f"{symbol}-YFin-data-2015-01-01-2025-03-25.csv", - ) + ), + on_bad_lines="skip", ) - df = wrap(data) except FileNotFoundError: raise Exception("Stockstats fail: Yahoo Finance data not fetched yet!") else: # Online data fetching with caching today_date = pd.Timestamp.today() curr_date_dt = pd.to_datetime(curr_date) - + end_date = today_date start_date = today_date - pd.DateOffset(years=15) start_date_str = start_date.strftime("%Y-%m-%d") end_date_str = end_date.strftime("%Y-%m-%d") - + os.makedirs(config["data_cache_dir"], exist_ok=True) - + data_file = os.path.join( config["data_cache_dir"], f"{symbol}-YFin-data-{start_date_str}-{end_date_str}.csv", ) - + if os.path.exists(data_file): - data = pd.read_csv(data_file) - data["Date"] = pd.to_datetime(data["Date"]) + data = pd.read_csv(data_file, on_bad_lines="skip") else: data = yf.download( symbol, @@ -245,9 +244,10 @@ def _get_stock_stats_bulk( ) data = data.reset_index() data.to_csv(data_file, index=False) - - df = wrap(data) - df["Date"] = df["Date"].dt.strftime("%Y-%m-%d") + + data = _clean_dataframe(data) + df = wrap(data) + df["Date"] = df["Date"].dt.strftime("%Y-%m-%d") # Calculate the indicator for all rows at once df[indicator] # This triggers stockstats to calculate the indicator