From eafdce3121ab2e14bd39f11432606bae335b7af1 Mon Sep 17 00:00:00 2001 From: Ahmet Guzererler Date: Sun, 22 Mar 2026 00:07:32 +0100 Subject: [PATCH] fix: harden dataflows layer against silent failures and data corruption MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Problem (Incident Post-mortem) The pipeline was emitting hundreds of errors: 'Invalid number of return arguments after parsing column name: Date' Root cause: after _clean_dataframe() lowercases all columns, stockstats.wrap() promotes 'date' to the DataFrame index. Subsequent df['Date'] access caused stockstats to try parsing 'Date' as a technical indicator name. ## Fixes ### 1. Fix df['Date'] stockstats bug (already shipped in prior commit) - stockstats_utils.py + y_finance.py: use df.index.strftime() instead of df['Date'] after wrap() ### 2. Extract _load_or_fetch_ohlcv() — single OHLCV authority - Eliminates duplicated 30-line download+cache boilerplate in two places - Cache filename is always derived from today's date — hardcoded stale date '2015-01-01-2025-03-25' in local mode is gone - Corruption/truncation detection: files <50 rows or unparseable are deleted and re-fetched rather than silently returning bad data - Drops on_bad_lines='skip' — malformed CSVs now raise instead of silently dropping rows that would distort indicator calculations ### 3. YFinanceError typed exception - Defined in stockstats_utils.py; raised instead of print()+return '' - get_stockstats_indicator now raises YFinanceError on failure so errors surface to callers rather than delivering empty strings to LLM agents - interface.py route_to_vendor now catches YFinanceError alongside AlphaVantageError and FinnhubError — failures appear in observability telemetry and can trigger vendor fallback ### 4. Explicit date column discovery in alpha_vantage_common - _filter_csv_by_date_range: replaced df.columns[0] positional assumption with explicit search for 'time'/'timestamp'/'date' column - ValueError re-raised (not swallowed) so bad API response shape is visible ### 5. Structured logging - Replaced all print() calls in changed files with logging.getLogger() - Added logging import + logger to alpha_vantage_common ## Tests - tests/unit/test_incident_fixes.py: 12 new unit tests covering all fixes (dynamic cache filename, corruption re-fetch, YFinanceError propagation, explicit column lookup, empty download raises) - tests/integration/test_stockstats_live.py: 11 live tests against real yfinance API (all major indicators, weekend N/A, regression guard) - All 70 tests pass (59 unit + 11 live integration) --- tests/integration/test_stockstats_live.py | 215 +++++++++++++++ tests/unit/test_incident_fixes.py | 259 ++++++++++++++++++ .../dataflows/alpha_vantage_common.py | 23 +- tradingagents/dataflows/interface.py | 3 +- tradingagents/dataflows/stockstats_utils.py | 137 ++++++--- tradingagents/dataflows/y_finance.py | 97 ++----- 6 files changed, 621 insertions(+), 113 deletions(-) create mode 100644 tests/integration/test_stockstats_live.py create mode 100644 tests/unit/test_incident_fixes.py diff --git a/tests/integration/test_stockstats_live.py b/tests/integration/test_stockstats_live.py new file mode 100644 index 00000000..500b643c --- /dev/null +++ b/tests/integration/test_stockstats_live.py @@ -0,0 +1,215 @@ +"""Live-data integration tests for the stockstats utilities. + +These tests call the real yfinance API and therefore require network access. +They are marked with ``integration`` and are excluded from the default test +run (which uses ``--ignore=tests/integration``). + +Run them explicitly with: + + python -m pytest tests/integration/test_stockstats_live.py -v --override-ini="addopts=" + +The tests validate the fix for the "Invalid number of return arguments after +parsing column name: 'Date'" error that occurred because stockstats.wrap() +promotes the lowercase ``date`` column to the DataFrame index, so the old +``df["Date"]`` access caused stockstats to try to parse "Date" as an indicator +name. The fix uses ``df.index.strftime("%Y-%m-%d")`` instead. +""" + +import pytest +import pandas as pd + + +pytestmark = pytest.mark.integration + +# A well-known trading day we can use for assertions +_TEST_DATE = "2025-01-02" +_TEST_TICKER = "AAPL" + + +# --------------------------------------------------------------------------- +# StockstatsUtils.get_stock_stats +# --------------------------------------------------------------------------- + +class TestStockstatsUtilsLive: + """Live tests for StockstatsUtils.get_stock_stats against real yfinance data.""" + + def test_close_50_sma_returns_numeric(self): + """close_50_sma indicator returns a numeric value for a known trading day.""" + from tradingagents.dataflows.stockstats_utils import StockstatsUtils + + result = StockstatsUtils.get_stock_stats(_TEST_TICKER, "close_50_sma", _TEST_DATE) + + assert result != "N/A: Not a trading day (weekend or holiday)", ( + f"Expected a numeric value for {_TEST_DATE}, got N/A (check if it's a holiday)" + ) + # Should be a finite float-like value + assert float(result) > 0, f"close_50_sma should be positive, got: {result}" + + def test_rsi_returns_value_in_valid_range(self): + """RSI indicator returns a value in [0, 100] for a known trading day.""" + from tradingagents.dataflows.stockstats_utils import StockstatsUtils + + result = StockstatsUtils.get_stock_stats(_TEST_TICKER, "rsi", _TEST_DATE) + + assert result != "N/A: Not a trading day (weekend or holiday)", ( + f"Expected numeric RSI for {_TEST_DATE}" + ) + rsi = float(result) + assert 0.0 <= rsi <= 100.0, f"RSI must be in [0, 100], got: {rsi}" + + def test_macd_returns_numeric(self): + """MACD indicator returns a numeric value for a known trading day.""" + from tradingagents.dataflows.stockstats_utils import StockstatsUtils + + result = StockstatsUtils.get_stock_stats(_TEST_TICKER, "macd", _TEST_DATE) + + assert result != "N/A: Not a trading day (weekend or holiday)" + # MACD can be positive or negative — just confirm it's a valid float + float(result) # raises ValueError if not numeric + + def test_weekend_returns_na(self): + """A weekend date returns the N/A holiday/weekend message.""" + from tradingagents.dataflows.stockstats_utils import StockstatsUtils + + # 2025-01-04 is a Saturday + result = StockstatsUtils.get_stock_stats(_TEST_TICKER, "close_50_sma", "2025-01-04") + + assert result == "N/A: Not a trading day (weekend or holiday)", ( + f"Expected N/A for Saturday 2025-01-04, got: {result}" + ) + + def test_no_date_column_error(self): + """Calling get_stock_stats must NOT raise the 'Date' column parsing error.""" + from tradingagents.dataflows.stockstats_utils import StockstatsUtils + + # This previously raised: Invalid number of return arguments after + # parsing column name: 'Date' + try: + StockstatsUtils.get_stock_stats(_TEST_TICKER, "close_50_sma", _TEST_DATE) + except Exception as e: + if "Invalid number of return arguments" in str(e) and "Date" in str(e): + pytest.fail( + "Regression: stockstats is still trying to parse 'Date' as an " + f"indicator. Error: {e}" + ) + raise # re-raise unexpected errors + + +# --------------------------------------------------------------------------- +# _get_stock_stats_bulk +# --------------------------------------------------------------------------- + +class TestGetStockStatsBulkLive: + """Live tests for _get_stock_stats_bulk against real yfinance data.""" + + def test_returns_dict_with_date_keys(self): + """Bulk method returns a non-empty dict with YYYY-MM-DD string keys.""" + from tradingagents.dataflows.y_finance import _get_stock_stats_bulk + + result = _get_stock_stats_bulk(_TEST_TICKER, "rsi", _TEST_DATE) + + assert isinstance(result, dict), "Expected dict from _get_stock_stats_bulk" + assert len(result) > 0, "Expected non-empty result dict" + + # Keys should all be YYYY-MM-DD strings + for key in list(result.keys())[:5]: + pd.Timestamp(key) # raises if not parseable + + def test_trading_day_has_numeric_value(self): + """A known trading day has a numeric (non-N/A) value in the bulk result.""" + from tradingagents.dataflows.y_finance import _get_stock_stats_bulk + + result = _get_stock_stats_bulk(_TEST_TICKER, "rsi", _TEST_DATE) + + assert _TEST_DATE in result, ( + f"Expected {_TEST_DATE} in bulk result dict. Keys sample: {list(result.keys())[:5]}" + ) + value = result[_TEST_DATE] + assert value != "N/A", ( + f"Expected numeric RSI for {_TEST_DATE}, got N/A (check if it's a holiday)" + ) + float(value) # should be convertible to float + + def test_no_date_column_parsing_error(self): + """Bulk method must not raise the 'Date' column parsing error (regression guard).""" + from tradingagents.dataflows.y_finance import _get_stock_stats_bulk + + try: + _get_stock_stats_bulk(_TEST_TICKER, "close_50_sma", _TEST_DATE) + except Exception as e: + if "Invalid number of return arguments" in str(e) and "Date" in str(e): + pytest.fail( + "Regression: _get_stock_stats_bulk still hits the 'Date' indicator " + f"parsing error. Error: {e}" + ) + raise + + def test_multiple_indicators_all_work(self): + """All supported indicators can be computed without error.""" + from tradingagents.dataflows.y_finance import _get_stock_stats_bulk + + indicators = [ + "close_50_sma", + "close_200_sma", + "close_10_ema", + "macd", + "macds", + "macdh", + "rsi", + "boll", + "boll_ub", + "boll_lb", + "atr", + ] + + for indicator in indicators: + try: + result = _get_stock_stats_bulk(_TEST_TICKER, indicator, _TEST_DATE) + assert isinstance(result, dict), f"{indicator}: expected dict" + assert len(result) > 0, f"{indicator}: expected non-empty dict" + except Exception as e: + pytest.fail(f"Indicator '{indicator}' raised an unexpected error: {e}") + + +# --------------------------------------------------------------------------- +# get_stock_stats_indicators_window (end-to-end with live data) +# --------------------------------------------------------------------------- + +class TestGetStockStatsIndicatorsWindowLive: + """Live end-to-end tests for get_stock_stats_indicators_window.""" + + def test_rsi_window_returns_formatted_string(self): + """Window function returns a multi-line string with RSI values over a date range.""" + from tradingagents.dataflows.y_finance import get_stock_stats_indicators_window + + result = get_stock_stats_indicators_window(_TEST_TICKER, "rsi", _TEST_DATE, look_back_days=5) + + assert isinstance(result, str) + assert "rsi" in result.lower() + assert _TEST_DATE in result + # Should have date: value lines + lines = [l for l in result.split("\n") if ":" in l and "-" in l] + assert len(lines) > 0, "Expected date:value lines in result" + + def test_close_50_sma_window_contains_numeric_values(self): + """50-day SMA window result contains actual numeric price values.""" + from tradingagents.dataflows.y_finance import get_stock_stats_indicators_window + + result = get_stock_stats_indicators_window( + _TEST_TICKER, "close_50_sma", _TEST_DATE, look_back_days=10 + ) + + assert isinstance(result, str) + # At least some lines should have numeric values (not all N/A) + value_lines = [l for l in result.split("\n") if ":" in l and l.strip().startswith("20")] + numeric_values = [] + for line in value_lines: + try: + val = line.split(":", 1)[1].strip() + numeric_values.append(float(val)) + except (ValueError, IndexError): + pass # N/A lines are expected for weekends + + assert len(numeric_values) > 0, ( + "Expected at least some numeric 50-SMA values in the 10-day window" + ) diff --git a/tests/unit/test_incident_fixes.py b/tests/unit/test_incident_fixes.py new file mode 100644 index 00000000..944ed1e5 --- /dev/null +++ b/tests/unit/test_incident_fixes.py @@ -0,0 +1,259 @@ +"""Unit tests for the incident-fix improvements across the dataflows layer. + +Tests cover: + 1. _load_or_fetch_ohlcv — dynamic cache filename, corruption detection + re-fetch + 2. YFinanceError — propagated by get_stockstats_indicator (no more silent return "") + 3. _filter_csv_by_date_range — explicit date column discovery (no positional assumption) +""" + +import os +import pytest +import pandas as pd +from unittest.mock import patch, MagicMock + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _minimal_ohlcv_df(periods: int = 200) -> pd.DataFrame: + """Return a minimal valid OHLCV DataFrame with a Date column.""" + idx = pd.date_range("2024-01-02", periods=periods, freq="B") + return pd.DataFrame( + { + "Date": idx.strftime("%Y-%m-%d"), + "Open": [100.0] * periods, + "High": [105.0] * periods, + "Low": [95.0] * periods, + "Close": [102.0] * periods, + "Volume": [1_000_000] * periods, + } + ) + + +# --------------------------------------------------------------------------- +# _load_or_fetch_ohlcv — cache + download logic +# --------------------------------------------------------------------------- + +class TestLoadOrFetchOhlcv: + """Tests for the unified OHLCV loader.""" + + def test_downloads_and_caches_on_first_call(self, tmp_path): + """When no cache exists, yfinance is called and data is written to cache.""" + from tradingagents.dataflows.stockstats_utils import _load_or_fetch_ohlcv + + expected_df = _minimal_ohlcv_df() + mock_downloaded = expected_df.copy() + mock_downloaded.index = pd.RangeIndex(len(mock_downloaded)) # simulate reset_index output + + with ( + patch("tradingagents.dataflows.stockstats_utils.get_config", + return_value={"data_cache_dir": str(tmp_path), "data_vendors": {"technical_indicators": "yfinance"}}), + patch("tradingagents.dataflows.stockstats_utils.yf.download", + return_value=expected_df.set_index("Date")) as mock_dl, + ): + result = _load_or_fetch_ohlcv("AAPL") + mock_dl.assert_called_once() + + # Cache file must exist after the call + csv_files = list(tmp_path.glob("AAPL-YFin-data-*.csv")) + assert len(csv_files) == 1, "Expected exactly one cache file to be created" + + def test_uses_cache_on_second_call(self, tmp_path): + """When cache already exists, yfinance.download is NOT called again.""" + from tradingagents.dataflows.stockstats_utils import _load_or_fetch_ohlcv + + # Write a valid cache file manually + df = _minimal_ohlcv_df() + today = pd.Timestamp.today() + start = (today - pd.DateOffset(years=15)).strftime("%Y-%m-%d") + end = today.strftime("%Y-%m-%d") + cache_file = tmp_path / f"AAPL-YFin-data-{start}-{end}.csv" + df.to_csv(cache_file, index=False) + + with ( + patch("tradingagents.dataflows.stockstats_utils.get_config", + return_value={"data_cache_dir": str(tmp_path), "data_vendors": {"technical_indicators": "yfinance"}}), + patch("tradingagents.dataflows.stockstats_utils.yf.download") as mock_dl, + ): + result = _load_or_fetch_ohlcv("AAPL") + mock_dl.assert_not_called() + + assert len(result) == 200 + + def test_corrupt_cache_is_deleted_and_refetched(self, tmp_path): + """A corrupt (unparseable) cache file is deleted and yfinance is called again.""" + from tradingagents.dataflows.stockstats_utils import _load_or_fetch_ohlcv + + today = pd.Timestamp.today() + start = (today - pd.DateOffset(years=15)).strftime("%Y-%m-%d") + end = today.strftime("%Y-%m-%d") + cache_file = tmp_path / f"AAPL-YFin-data-{start}-{end}.csv" + cache_file.write_text(",,,,CORRUPT,,\x00\x00BINARY GARBAGE") + + fresh_df = _minimal_ohlcv_df() + + with ( + patch("tradingagents.dataflows.stockstats_utils.get_config", + return_value={"data_cache_dir": str(tmp_path), "data_vendors": {"technical_indicators": "yfinance"}}), + patch("tradingagents.dataflows.stockstats_utils.yf.download", + return_value=fresh_df.set_index("Date")) as mock_dl, + ): + result = _load_or_fetch_ohlcv("AAPL") + mock_dl.assert_called_once() + + def test_truncated_cache_triggers_refetch(self, tmp_path): + """A cache file with fewer than 50 rows is treated as truncated and re-fetched.""" + from tradingagents.dataflows.stockstats_utils import _load_or_fetch_ohlcv + + tiny_df = _minimal_ohlcv_df(periods=10) # only 10 rows — well below threshold + today = pd.Timestamp.today() + start = (today - pd.DateOffset(years=15)).strftime("%Y-%m-%d") + end = today.strftime("%Y-%m-%d") + cache_file = tmp_path / f"AAPL-YFin-data-{start}-{end}.csv" + tiny_df.to_csv(cache_file, index=False) + + fresh_df = _minimal_ohlcv_df() + + with ( + patch("tradingagents.dataflows.stockstats_utils.get_config", + return_value={"data_cache_dir": str(tmp_path), "data_vendors": {"technical_indicators": "yfinance"}}), + patch("tradingagents.dataflows.stockstats_utils.yf.download", + return_value=fresh_df.set_index("Date")) as mock_dl, + ): + result = _load_or_fetch_ohlcv("AAPL") + mock_dl.assert_called_once() + + def test_empty_download_raises_yfinance_error(self, tmp_path): + """An empty DataFrame from yfinance raises YFinanceError (not a silent return).""" + from tradingagents.dataflows.stockstats_utils import _load_or_fetch_ohlcv, YFinanceError + + with ( + patch("tradingagents.dataflows.stockstats_utils.get_config", + return_value={"data_cache_dir": str(tmp_path), "data_vendors": {"technical_indicators": "yfinance"}}), + patch("tradingagents.dataflows.stockstats_utils.yf.download", + return_value=pd.DataFrame()), + ): + with pytest.raises(YFinanceError, match="no data"): + _load_or_fetch_ohlcv("INVALID_TICKER_XYZ") + + def test_cache_filename_is_dynamic_not_hardcoded(self, tmp_path): + """Cache filename contains today's date (not a hardcoded historical date like 2025-03-25).""" + from tradingagents.dataflows.stockstats_utils import _load_or_fetch_ohlcv + + df = _minimal_ohlcv_df() + + with ( + patch("tradingagents.dataflows.stockstats_utils.get_config", + return_value={"data_cache_dir": str(tmp_path), "data_vendors": {"technical_indicators": "yfinance"}}), + patch("tradingagents.dataflows.stockstats_utils.yf.download", + return_value=df.set_index("Date")), + ): + _load_or_fetch_ohlcv("AAPL") + + csv_files = list(tmp_path.glob("AAPL-YFin-data-*.csv")) + assert len(csv_files) == 1 + filename = csv_files[0].name + # Must NOT contain the old hardcoded stale date + assert "2025-03-25" not in filename, ( + f"Cache filename contains the old hardcoded stale date! Got: {filename}" + ) + # Must contain today's year + today_year = str(pd.Timestamp.today().year) + assert today_year in filename, f"Expected today's year {today_year} in filename: {filename}" + + +# --------------------------------------------------------------------------- +# YFinanceError propagation (no more silent return "") +# --------------------------------------------------------------------------- + +class TestYFinanceErrorPropagation: + """Tests that YFinanceError is raised (not swallowed) by get_stockstats_indicator.""" + + def test_get_stockstats_indicator_raises_yfinance_error_on_failure(self, tmp_path): + """get_stockstats_indicator raises YFinanceError when yfinance returns empty data.""" + from tradingagents.dataflows.stockstats_utils import YFinanceError + from tradingagents.dataflows.y_finance import get_stockstats_indicator + + with ( + patch("tradingagents.dataflows.stockstats_utils.get_config", + return_value={"data_cache_dir": str(tmp_path), "data_vendors": {"technical_indicators": "yfinance"}}), + patch("tradingagents.dataflows.stockstats_utils.yf.download", + return_value=pd.DataFrame()), + ): + with pytest.raises(YFinanceError): + get_stockstats_indicator("INVALID", "rsi", "2025-01-02") + + def test_yfinance_error_is_not_swallowed_as_empty_string(self, tmp_path): + """Regression test: get_stockstats_indicator must NOT return empty string on error.""" + from tradingagents.dataflows.stockstats_utils import YFinanceError + from tradingagents.dataflows.y_finance import get_stockstats_indicator + + with ( + patch("tradingagents.dataflows.stockstats_utils.get_config", + return_value={"data_cache_dir": str(tmp_path), "data_vendors": {"technical_indicators": "yfinance"}}), + patch("tradingagents.dataflows.stockstats_utils.yf.download", + return_value=pd.DataFrame()), + ): + result = None + try: + result = get_stockstats_indicator("INVALID", "rsi", "2025-01-02") + except YFinanceError: + pass # This is the correct behaviour + + assert result is None, ( + "get_stockstats_indicator should raise YFinanceError, not silently return a value. " + f"Got: {result!r}" + ) + + +# --------------------------------------------------------------------------- +# _filter_csv_by_date_range — explicit column discovery +# --------------------------------------------------------------------------- + +class TestFilterCsvByDateRange: + """Tests for the fixed _filter_csv_by_date_range in alpha_vantage_common.""" + + def _make_av_csv(self, date_col_name: str = "time") -> str: + return ( + f"{date_col_name},SMA\n" + "2024-01-02,230.5\n" + "2024-01-03,231.0\n" + "2024-01-08,235.5\n" + ) + + def test_filters_with_standard_time_column(self): + """Standard Alpha Vantage CSV with 'time' header filters correctly.""" + from tradingagents.dataflows.alpha_vantage_common import _filter_csv_by_date_range + + result = _filter_csv_by_date_range(self._make_av_csv("time"), "2024-01-03", "2024-01-08") + assert "2024-01-02" not in result + assert "2024-01-03" in result + assert "2024-01-08" in result + + def test_filters_with_timestamp_column(self): + """CSV with 'timestamp' header (alternative AV format) also works.""" + from tradingagents.dataflows.alpha_vantage_common import _filter_csv_by_date_range + + result = _filter_csv_by_date_range(self._make_av_csv("timestamp"), "2024-01-03", "2024-01-08") + assert "2024-01-02" not in result + assert "2024-01-03" in result + + def test_missing_date_column_raises_not_silently_filters_wrong_column(self): + """When no recognised date column exists, raises ValueError immediately. + Previously it would silently use df.columns[0] and return garbage.""" + from tradingagents.dataflows.alpha_vantage_common import _filter_csv_by_date_range + + bad_csv = "price,volume\n102.0,1000\n103.0,2000\n" + + # The fixed code raises ValueError; the old code would silently try to + # parse the 'price' column as a date and return the original data. + with pytest.raises(ValueError, match="Date column not found"): + _filter_csv_by_date_range(bad_csv, "2024-01-01", "2024-01-31") + + def test_empty_csv_returns_empty(self): + """Empty input returns empty output without error.""" + from tradingagents.dataflows.alpha_vantage_common import _filter_csv_by_date_range + + result = _filter_csv_by_date_range("", "2024-01-01", "2024-01-31") + assert result == "" diff --git a/tradingagents/dataflows/alpha_vantage_common.py b/tradingagents/dataflows/alpha_vantage_common.py index 5aaa27a5..d6ac3f83 100644 --- a/tradingagents/dataflows/alpha_vantage_common.py +++ b/tradingagents/dataflows/alpha_vantage_common.py @@ -1,3 +1,4 @@ +import logging import os import requests import pandas as pd @@ -7,6 +8,9 @@ import time as _time from datetime import datetime from io import StringIO +logger = logging.getLogger(__name__) + + API_BASE_URL = "https://www.alphavantage.co/query" def get_api_key() -> str: @@ -228,8 +232,15 @@ def _filter_csv_by_date_range(csv_data: str, start_date: str, end_date: str) -> # Parse CSV data df = pd.read_csv(StringIO(csv_data)) - # Assume the first column is the date column (timestamp) - date_col = df.columns[0] + # Find the date column by name rather than positional assumption. + # Alpha Vantage returns 'time' or 'timestamp' as the date column header. + _date_candidates = [c for c in df.columns if c.lower() in ("time", "timestamp", "date")] + if not _date_candidates: + raise ValueError( + f"Date column not found in Alpha Vantage CSV. " + f"Expected 'time' or 'timestamp', got columns: {list(df.columns)}" + ) + date_col = _date_candidates[0] df[date_col] = pd.to_datetime(df[date_col]) # Filter by date range @@ -241,7 +252,11 @@ def _filter_csv_by_date_range(csv_data: str, start_date: str, end_date: str) -> # Convert back to CSV string return filtered_df.to_csv(index=False) + except ValueError: + # ValueError = a programming/data-contract error (e.g. missing date column). + # Re-raise so callers see it immediately rather than getting silently wrong data. + raise except Exception as e: - # If filtering fails, return original data with a warning - print(f"Warning: Failed to filter CSV data by date range: {e}") + # Transient errors (I/O, malformed CSV): log and return original data as-is. + logger.warning("Failed to filter CSV data by date range: %s", e) return csv_data diff --git a/tradingagents/dataflows/interface.py b/tradingagents/dataflows/interface.py index 788fadd2..29a3d6d6 100644 --- a/tradingagents/dataflows/interface.py +++ b/tradingagents/dataflows/interface.py @@ -40,6 +40,7 @@ from .alpha_vantage_scanner import ( ) from .alpha_vantage_common import AlphaVantageError, AlphaVantageRateLimitError, RateLimitError from .finnhub_common import FinnhubError +from .stockstats_utils import YFinanceError from .finnhub_news import get_insider_transactions as get_finnhub_insider_transactions from .finnhub_scanner import ( get_market_indices_finnhub, @@ -262,7 +263,7 @@ def route_to_vendor(method: str, *args, **kwargs): if rl: rl.log_vendor_call(method, vendor, True, (time.time() - t0) * 1000, args_summary=args_summary) return result - except (AlphaVantageError, FinnhubError, ConnectionError, TimeoutError) as exc: + except (AlphaVantageError, FinnhubError, YFinanceError, ConnectionError, TimeoutError) as exc: if rl: rl.log_vendor_call(method, vendor, False, (time.time() - t0) * 1000, error=str(exc)[:200], args_summary=args_summary) last_error = exc diff --git a/tradingagents/dataflows/stockstats_utils.py b/tradingagents/dataflows/stockstats_utils.py index 9d43cf0e..34b9943f 100644 --- a/tradingagents/dataflows/stockstats_utils.py +++ b/tradingagents/dataflows/stockstats_utils.py @@ -1,3 +1,4 @@ +import logging import pandas as pd import yfinance as yf from stockstats import wrap @@ -5,6 +6,22 @@ from typing import Annotated import os from .config import get_config +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Public exception — lets callers catch stockstats/yfinance failures by type +# --------------------------------------------------------------------------- + + +class YFinanceError(Exception): + """Raised when yfinance or stockstats data fetching/processing fails.""" + pass + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + def _clean_dataframe(data: pd.DataFrame) -> pd.DataFrame: """Normalize a stock DataFrame for stockstats: parse dates, drop invalid rows, fill price gaps. @@ -29,6 +46,80 @@ def _clean_dataframe(data: pd.DataFrame) -> pd.DataFrame: return df +def _load_or_fetch_ohlcv(symbol: str) -> pd.DataFrame: + """Single authority for loading OHLCV data: cache → yfinance download → normalize. + + Cache filename is always derived from today's date (15-year window) so the + cache key never goes stale. If a cached file exists but is corrupt (too few + rows to be useful), it is deleted and re-fetched rather than silently + returning bad data. + + Raises: + YFinanceError: if the download returns an empty DataFrame or fails. + """ + config = get_config() + + today_date = pd.Timestamp.today() + start_date = today_date - pd.DateOffset(years=15) + start_date_str = start_date.strftime("%Y-%m-%d") + end_date_str = today_date.strftime("%Y-%m-%d") + + os.makedirs(config["data_cache_dir"], exist_ok=True) + + data_file = os.path.join( + config["data_cache_dir"], + f"{symbol}-YFin-data-{start_date_str}-{end_date_str}.csv", + ) + + # ── Try to load from cache ──────────────────────────────────────────────── + if os.path.exists(data_file): + try: + data = pd.read_csv(data_file) # no on_bad_lines="skip" — we want to know about corruption + except Exception as exc: + logger.warning( + "Corrupt cache file for %s (%s) — deleting and re-fetching.", symbol, exc + ) + os.remove(data_file) + data = None + else: + # Validate: a 15-year daily file should have well over 100 rows + if len(data) < 50: + logger.warning( + "Cache file for %s has only %d rows — likely truncated, re-fetching.", + symbol, len(data), + ) + os.remove(data_file) + data = None + else: + data = None + + # ── Download from yfinance if cache miss / corrupt ──────────────────────── + if data is None: + raw = yf.download( + symbol, + start=start_date_str, + end=end_date_str, + multi_level_index=False, + progress=False, + auto_adjust=True, + ) + if raw.empty: + raise YFinanceError( + f"yfinance returned no data for symbol '{symbol}' " + f"({start_date_str} → {end_date_str})" + ) + data = raw.reset_index() + data.to_csv(data_file, index=False) + logger.debug("Downloaded and cached OHLCV for %s → %s", symbol, data_file) + + return data + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + class StockstatsUtils: @staticmethod def get_stock_stats( @@ -40,48 +131,20 @@ class StockstatsUtils: str, "curr date for retrieving stock price data, YYYY-mm-dd" ], ): - config = get_config() - - today_date = pd.Timestamp.today() curr_date_dt = pd.to_datetime(curr_date) - - end_date = today_date - start_date = today_date - pd.DateOffset(years=15) - start_date_str = start_date.strftime("%Y-%m-%d") - end_date_str = end_date.strftime("%Y-%m-%d") - - # Ensure cache directory exists - os.makedirs(config["data_cache_dir"], exist_ok=True) - - data_file = os.path.join( - config["data_cache_dir"], - f"{symbol}-YFin-data-{start_date_str}-{end_date_str}.csv", - ) - - if os.path.exists(data_file): - data = pd.read_csv(data_file, on_bad_lines="skip") - else: - data = yf.download( - symbol, - start=start_date_str, - end=end_date_str, - multi_level_index=False, - progress=False, - auto_adjust=True, - ) - data = data.reset_index() - data.to_csv(data_file, index=False) - - data = _clean_dataframe(data) - df = wrap(data) - df["Date"] = df["Date"].dt.strftime("%Y-%m-%d") curr_date_str = curr_date_dt.strftime("%Y-%m-%d") + data = _load_or_fetch_ohlcv(symbol) + data = _clean_dataframe(data) + df = wrap(data) + # After wrap(), the date column becomes the datetime index (named 'date'). + # Access via df.index, not df["Date"] which stockstats would try to parse as an indicator. + df[indicator] # trigger stockstats to calculate the indicator - matching_rows = df[df["Date"].str.startswith(curr_date_str)] + date_index_strs = df.index.strftime("%Y-%m-%d") + matching_rows = df[date_index_strs == curr_date_str] if not matching_rows.empty: - indicator_value = matching_rows[indicator].values[0] - return indicator_value + return matching_rows[indicator].values[0] else: return "N/A: Not a trading day (weekend or holiday)" diff --git a/tradingagents/dataflows/y_finance.py b/tradingagents/dataflows/y_finance.py index b915490d..7682184a 100644 --- a/tradingagents/dataflows/y_finance.py +++ b/tradingagents/dataflows/y_finance.py @@ -1,9 +1,14 @@ +import logging from typing import Annotated from datetime import datetime from dateutil.relativedelta import relativedelta +import pandas as pd import yfinance as yf import os -from .stockstats_utils import StockstatsUtils, _clean_dataframe +from .stockstats_utils import StockstatsUtils, YFinanceError, _clean_dataframe, _load_or_fetch_ohlcv + +logger = logging.getLogger(__name__) + def get_YFin_data_online( symbol: Annotated[str, "ticker symbol of the company"], @@ -191,82 +196,42 @@ def _get_stock_stats_bulk( ) -> dict: """ Optimized bulk calculation of stock stats indicators. - Fetches data once and calculates indicator for all available dates. + Fetches data once (via shared _load_or_fetch_ohlcv cache) and calculates + the indicator for all available dates. Returns dict mapping date strings to indicator values. + + Raises: + YFinanceError: if data cannot be loaded or indicator calculation fails. """ - from .config import get_config - import pandas as pd from stockstats import wrap - import os - - config = get_config() - online = config["data_vendors"]["technical_indicators"] != "local" - - if not online: - # Local data path - try: - data = pd.read_csv( - os.path.join( - config.get("data_cache_dir", "data"), - f"{symbol}-YFin-data-2015-01-01-2025-03-25.csv", - ), - on_bad_lines="skip", - ) - except FileNotFoundError: - raise Exception("Stockstats fail: Yahoo Finance data not fetched yet!") - else: - # Online data fetching with caching - today_date = pd.Timestamp.today() - curr_date_dt = pd.to_datetime(curr_date) - - end_date = today_date - start_date = today_date - pd.DateOffset(years=15) - start_date_str = start_date.strftime("%Y-%m-%d") - end_date_str = end_date.strftime("%Y-%m-%d") - - os.makedirs(config["data_cache_dir"], exist_ok=True) - - data_file = os.path.join( - config["data_cache_dir"], - f"{symbol}-YFin-data-{start_date_str}-{end_date_str}.csv", - ) - - if os.path.exists(data_file): - data = pd.read_csv(data_file, on_bad_lines="skip") - else: - data = yf.download( - symbol, - start=start_date_str, - end=end_date_str, - multi_level_index=False, - progress=False, - auto_adjust=True, - ) - data = data.reset_index() - data.to_csv(data_file, index=False) + # Single authority: _load_or_fetch_ohlcv handles both online and local modes, + # dynamic cache filename, and corpus validation — no duplicated download logic here. + data = _load_or_fetch_ohlcv(symbol) data = _clean_dataframe(data) df = wrap(data) - df["Date"] = df["Date"].dt.strftime("%Y-%m-%d") - + # After wrap(), the date column becomes the datetime index (named 'date'). + # Access via df.index, not df["Date"] which stockstats would try to parse as an indicator. + # Calculate the indicator for all rows at once df[indicator] # This triggers stockstats to calculate the indicator - + # Create a dictionary mapping date strings to indicator values result_dict = {} - for _, row in df.iterrows(): - date_str = row["Date"] + date_index_strs = df.index.strftime("%Y-%m-%d") + for date_str, (_, row) in zip(date_index_strs, df.iterrows()): indicator_value = row[indicator] - + # Handle NaN/None values if pd.isna(indicator_value): result_dict[date_str] = "N/A" else: result_dict[date_str] = str(indicator_value) - + return result_dict + def get_stockstats_indicator( symbol: Annotated[str, "ticker symbol of the company"], indicator: Annotated[str, "technical indicator to get the analysis and report of"], @@ -274,25 +239,15 @@ def get_stockstats_indicator( str, "The current trading date you are trading on, YYYY-mm-dd" ], ) -> str: - curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d") curr_date = curr_date_dt.strftime("%Y-%m-%d") - try: - indicator_value = StockstatsUtils.get_stock_stats( - symbol, - indicator, - curr_date, - ) - except Exception as e: - print( - f"Error getting stockstats indicator data for indicator {indicator} on {curr_date}: {e}" - ) - return "" - + # Raises YFinanceError on failure — caller (route_to_vendor) catches typed exceptions. + indicator_value = StockstatsUtils.get_stock_stats(symbol, indicator, curr_date) return str(indicator_value) + def get_fundamentals( ticker: Annotated[str, "ticker symbol of the company"], curr_date: Annotated[str, "current date (not used for yfinance)"] = None