diff --git a/cli/main.py b/cli/main.py index 33fb10d4..9f4549e3 100644 --- a/cli/main.py +++ b/cli/main.py @@ -1579,17 +1579,62 @@ def run_pipeline( console.print( f"\n[cyan]Running TradingAgents for {len(candidates)} tickers...[/cyan]" + f" [dim](up to 2 concurrent)[/dim]\n" ) - try: - with Live( - Spinner("dots", text="Analyzing..."), console=console, transient=True - ): + for c in candidates: + console.print( + f" [dim]▷ Queued:[/dim] [bold cyan]{c.ticker}[/bold cyan]" + f" [dim]{c.sector} · {c.conviction.upper()} conviction[/dim]" + ) + console.print() + import time as _time + from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeElapsedColumn + + pipeline_start = _time.monotonic() + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TextColumn("[cyan]{task.completed}/{task.total}[/cyan]"), + TimeElapsedColumn(), + console=console, + transient=False, + ) as progress: + overall = progress.add_task("[bold]Pipeline progress[/bold]", total=len(candidates)) + + def on_done(result, done_count, total_count): + ticker_elapsed = _time.monotonic() - pipeline_start + if result.error: + console.print( + f" [red]✗ {result.ticker}[/red]" + f" [dim]failed ({ticker_elapsed:.0f}s elapsed) — {result.error[:80]}[/dim]" + ) + else: + decision_preview = str(result.final_trade_decision)[:70].replace("\n", " ") + console.print( + f" [green]✓ {result.ticker}[/green]" + f" [dim]({done_count}/{total_count}, {ticker_elapsed:.0f}s elapsed)[/dim]" + f" → {decision_preview}" + ) + progress.advance(overall) + + try: results = asyncio.run( - run_all_tickers(candidates, macro_context, config, analysis_date) + run_all_tickers( + candidates, macro_context, config, analysis_date, + on_ticker_done=on_done, + ) ) - except Exception as e: - console.print(f"[red]Pipeline failed: {e}[/red]") - raise typer.Exit(1) + except Exception as e: + console.print(f"[red]Pipeline failed: {e}[/red]") + raise typer.Exit(1) + + elapsed_total = _time.monotonic() - pipeline_start + console.print( + f"\n[bold green]All {len(candidates)} ticker(s) finished in {elapsed_total:.0f}s[/bold green]\n" + ) + save_results(results, macro_context, output_dir) diff --git a/tests/integration/test_stockstats_live.py b/tests/integration/test_stockstats_live.py new file mode 100644 index 00000000..500b643c --- /dev/null +++ b/tests/integration/test_stockstats_live.py @@ -0,0 +1,215 @@ +"""Live-data integration tests for the stockstats utilities. + +These tests call the real yfinance API and therefore require network access. +They are marked with ``integration`` and are excluded from the default test +run (which uses ``--ignore=tests/integration``). + +Run them explicitly with: + + python -m pytest tests/integration/test_stockstats_live.py -v --override-ini="addopts=" + +The tests validate the fix for the "Invalid number of return arguments after +parsing column name: 'Date'" error that occurred because stockstats.wrap() +promotes the lowercase ``date`` column to the DataFrame index, so the old +``df["Date"]`` access caused stockstats to try to parse "Date" as an indicator +name. The fix uses ``df.index.strftime("%Y-%m-%d")`` instead. +""" + +import pytest +import pandas as pd + + +pytestmark = pytest.mark.integration + +# A well-known trading day we can use for assertions +_TEST_DATE = "2025-01-02" +_TEST_TICKER = "AAPL" + + +# --------------------------------------------------------------------------- +# StockstatsUtils.get_stock_stats +# --------------------------------------------------------------------------- + +class TestStockstatsUtilsLive: + """Live tests for StockstatsUtils.get_stock_stats against real yfinance data.""" + + def test_close_50_sma_returns_numeric(self): + """close_50_sma indicator returns a numeric value for a known trading day.""" + from tradingagents.dataflows.stockstats_utils import StockstatsUtils + + result = StockstatsUtils.get_stock_stats(_TEST_TICKER, "close_50_sma", _TEST_DATE) + + assert result != "N/A: Not a trading day (weekend or holiday)", ( + f"Expected a numeric value for {_TEST_DATE}, got N/A (check if it's a holiday)" + ) + # Should be a finite float-like value + assert float(result) > 0, f"close_50_sma should be positive, got: {result}" + + def test_rsi_returns_value_in_valid_range(self): + """RSI indicator returns a value in [0, 100] for a known trading day.""" + from tradingagents.dataflows.stockstats_utils import StockstatsUtils + + result = StockstatsUtils.get_stock_stats(_TEST_TICKER, "rsi", _TEST_DATE) + + assert result != "N/A: Not a trading day (weekend or holiday)", ( + f"Expected numeric RSI for {_TEST_DATE}" + ) + rsi = float(result) + assert 0.0 <= rsi <= 100.0, f"RSI must be in [0, 100], got: {rsi}" + + def test_macd_returns_numeric(self): + """MACD indicator returns a numeric value for a known trading day.""" + from tradingagents.dataflows.stockstats_utils import StockstatsUtils + + result = StockstatsUtils.get_stock_stats(_TEST_TICKER, "macd", _TEST_DATE) + + assert result != "N/A: Not a trading day (weekend or holiday)" + # MACD can be positive or negative — just confirm it's a valid float + float(result) # raises ValueError if not numeric + + def test_weekend_returns_na(self): + """A weekend date returns the N/A holiday/weekend message.""" + from tradingagents.dataflows.stockstats_utils import StockstatsUtils + + # 2025-01-04 is a Saturday + result = StockstatsUtils.get_stock_stats(_TEST_TICKER, "close_50_sma", "2025-01-04") + + assert result == "N/A: Not a trading day (weekend or holiday)", ( + f"Expected N/A for Saturday 2025-01-04, got: {result}" + ) + + def test_no_date_column_error(self): + """Calling get_stock_stats must NOT raise the 'Date' column parsing error.""" + from tradingagents.dataflows.stockstats_utils import StockstatsUtils + + # This previously raised: Invalid number of return arguments after + # parsing column name: 'Date' + try: + StockstatsUtils.get_stock_stats(_TEST_TICKER, "close_50_sma", _TEST_DATE) + except Exception as e: + if "Invalid number of return arguments" in str(e) and "Date" in str(e): + pytest.fail( + "Regression: stockstats is still trying to parse 'Date' as an " + f"indicator. Error: {e}" + ) + raise # re-raise unexpected errors + + +# --------------------------------------------------------------------------- +# _get_stock_stats_bulk +# --------------------------------------------------------------------------- + +class TestGetStockStatsBulkLive: + """Live tests for _get_stock_stats_bulk against real yfinance data.""" + + def test_returns_dict_with_date_keys(self): + """Bulk method returns a non-empty dict with YYYY-MM-DD string keys.""" + from tradingagents.dataflows.y_finance import _get_stock_stats_bulk + + result = _get_stock_stats_bulk(_TEST_TICKER, "rsi", _TEST_DATE) + + assert isinstance(result, dict), "Expected dict from _get_stock_stats_bulk" + assert len(result) > 0, "Expected non-empty result dict" + + # Keys should all be YYYY-MM-DD strings + for key in list(result.keys())[:5]: + pd.Timestamp(key) # raises if not parseable + + def test_trading_day_has_numeric_value(self): + """A known trading day has a numeric (non-N/A) value in the bulk result.""" + from tradingagents.dataflows.y_finance import _get_stock_stats_bulk + + result = _get_stock_stats_bulk(_TEST_TICKER, "rsi", _TEST_DATE) + + assert _TEST_DATE in result, ( + f"Expected {_TEST_DATE} in bulk result dict. Keys sample: {list(result.keys())[:5]}" + ) + value = result[_TEST_DATE] + assert value != "N/A", ( + f"Expected numeric RSI for {_TEST_DATE}, got N/A (check if it's a holiday)" + ) + float(value) # should be convertible to float + + def test_no_date_column_parsing_error(self): + """Bulk method must not raise the 'Date' column parsing error (regression guard).""" + from tradingagents.dataflows.y_finance import _get_stock_stats_bulk + + try: + _get_stock_stats_bulk(_TEST_TICKER, "close_50_sma", _TEST_DATE) + except Exception as e: + if "Invalid number of return arguments" in str(e) and "Date" in str(e): + pytest.fail( + "Regression: _get_stock_stats_bulk still hits the 'Date' indicator " + f"parsing error. Error: {e}" + ) + raise + + def test_multiple_indicators_all_work(self): + """All supported indicators can be computed without error.""" + from tradingagents.dataflows.y_finance import _get_stock_stats_bulk + + indicators = [ + "close_50_sma", + "close_200_sma", + "close_10_ema", + "macd", + "macds", + "macdh", + "rsi", + "boll", + "boll_ub", + "boll_lb", + "atr", + ] + + for indicator in indicators: + try: + result = _get_stock_stats_bulk(_TEST_TICKER, indicator, _TEST_DATE) + assert isinstance(result, dict), f"{indicator}: expected dict" + assert len(result) > 0, f"{indicator}: expected non-empty dict" + except Exception as e: + pytest.fail(f"Indicator '{indicator}' raised an unexpected error: {e}") + + +# --------------------------------------------------------------------------- +# get_stock_stats_indicators_window (end-to-end with live data) +# --------------------------------------------------------------------------- + +class TestGetStockStatsIndicatorsWindowLive: + """Live end-to-end tests for get_stock_stats_indicators_window.""" + + def test_rsi_window_returns_formatted_string(self): + """Window function returns a multi-line string with RSI values over a date range.""" + from tradingagents.dataflows.y_finance import get_stock_stats_indicators_window + + result = get_stock_stats_indicators_window(_TEST_TICKER, "rsi", _TEST_DATE, look_back_days=5) + + assert isinstance(result, str) + assert "rsi" in result.lower() + assert _TEST_DATE in result + # Should have date: value lines + lines = [l for l in result.split("\n") if ":" in l and "-" in l] + assert len(lines) > 0, "Expected date:value lines in result" + + def test_close_50_sma_window_contains_numeric_values(self): + """50-day SMA window result contains actual numeric price values.""" + from tradingagents.dataflows.y_finance import get_stock_stats_indicators_window + + result = get_stock_stats_indicators_window( + _TEST_TICKER, "close_50_sma", _TEST_DATE, look_back_days=10 + ) + + assert isinstance(result, str) + # At least some lines should have numeric values (not all N/A) + value_lines = [l for l in result.split("\n") if ":" in l and l.strip().startswith("20")] + numeric_values = [] + for line in value_lines: + try: + val = line.split(":", 1)[1].strip() + numeric_values.append(float(val)) + except (ValueError, IndexError): + pass # N/A lines are expected for weekends + + assert len(numeric_values) > 0, ( + "Expected at least some numeric 50-SMA values in the 10-day window" + ) diff --git a/tests/unit/test_incident_fixes.py b/tests/unit/test_incident_fixes.py new file mode 100644 index 00000000..944ed1e5 --- /dev/null +++ b/tests/unit/test_incident_fixes.py @@ -0,0 +1,259 @@ +"""Unit tests for the incident-fix improvements across the dataflows layer. + +Tests cover: + 1. _load_or_fetch_ohlcv — dynamic cache filename, corruption detection + re-fetch + 2. YFinanceError — propagated by get_stockstats_indicator (no more silent return "") + 3. _filter_csv_by_date_range — explicit date column discovery (no positional assumption) +""" + +import os +import pytest +import pandas as pd +from unittest.mock import patch, MagicMock + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _minimal_ohlcv_df(periods: int = 200) -> pd.DataFrame: + """Return a minimal valid OHLCV DataFrame with a Date column.""" + idx = pd.date_range("2024-01-02", periods=periods, freq="B") + return pd.DataFrame( + { + "Date": idx.strftime("%Y-%m-%d"), + "Open": [100.0] * periods, + "High": [105.0] * periods, + "Low": [95.0] * periods, + "Close": [102.0] * periods, + "Volume": [1_000_000] * periods, + } + ) + + +# --------------------------------------------------------------------------- +# _load_or_fetch_ohlcv — cache + download logic +# --------------------------------------------------------------------------- + +class TestLoadOrFetchOhlcv: + """Tests for the unified OHLCV loader.""" + + def test_downloads_and_caches_on_first_call(self, tmp_path): + """When no cache exists, yfinance is called and data is written to cache.""" + from tradingagents.dataflows.stockstats_utils import _load_or_fetch_ohlcv + + expected_df = _minimal_ohlcv_df() + mock_downloaded = expected_df.copy() + mock_downloaded.index = pd.RangeIndex(len(mock_downloaded)) # simulate reset_index output + + with ( + patch("tradingagents.dataflows.stockstats_utils.get_config", + return_value={"data_cache_dir": str(tmp_path), "data_vendors": {"technical_indicators": "yfinance"}}), + patch("tradingagents.dataflows.stockstats_utils.yf.download", + return_value=expected_df.set_index("Date")) as mock_dl, + ): + result = _load_or_fetch_ohlcv("AAPL") + mock_dl.assert_called_once() + + # Cache file must exist after the call + csv_files = list(tmp_path.glob("AAPL-YFin-data-*.csv")) + assert len(csv_files) == 1, "Expected exactly one cache file to be created" + + def test_uses_cache_on_second_call(self, tmp_path): + """When cache already exists, yfinance.download is NOT called again.""" + from tradingagents.dataflows.stockstats_utils import _load_or_fetch_ohlcv + + # Write a valid cache file manually + df = _minimal_ohlcv_df() + today = pd.Timestamp.today() + start = (today - pd.DateOffset(years=15)).strftime("%Y-%m-%d") + end = today.strftime("%Y-%m-%d") + cache_file = tmp_path / f"AAPL-YFin-data-{start}-{end}.csv" + df.to_csv(cache_file, index=False) + + with ( + patch("tradingagents.dataflows.stockstats_utils.get_config", + return_value={"data_cache_dir": str(tmp_path), "data_vendors": {"technical_indicators": "yfinance"}}), + patch("tradingagents.dataflows.stockstats_utils.yf.download") as mock_dl, + ): + result = _load_or_fetch_ohlcv("AAPL") + mock_dl.assert_not_called() + + assert len(result) == 200 + + def test_corrupt_cache_is_deleted_and_refetched(self, tmp_path): + """A corrupt (unparseable) cache file is deleted and yfinance is called again.""" + from tradingagents.dataflows.stockstats_utils import _load_or_fetch_ohlcv + + today = pd.Timestamp.today() + start = (today - pd.DateOffset(years=15)).strftime("%Y-%m-%d") + end = today.strftime("%Y-%m-%d") + cache_file = tmp_path / f"AAPL-YFin-data-{start}-{end}.csv" + cache_file.write_text(",,,,CORRUPT,,\x00\x00BINARY GARBAGE") + + fresh_df = _minimal_ohlcv_df() + + with ( + patch("tradingagents.dataflows.stockstats_utils.get_config", + return_value={"data_cache_dir": str(tmp_path), "data_vendors": {"technical_indicators": "yfinance"}}), + patch("tradingagents.dataflows.stockstats_utils.yf.download", + return_value=fresh_df.set_index("Date")) as mock_dl, + ): + result = _load_or_fetch_ohlcv("AAPL") + mock_dl.assert_called_once() + + def test_truncated_cache_triggers_refetch(self, tmp_path): + """A cache file with fewer than 50 rows is treated as truncated and re-fetched.""" + from tradingagents.dataflows.stockstats_utils import _load_or_fetch_ohlcv + + tiny_df = _minimal_ohlcv_df(periods=10) # only 10 rows — well below threshold + today = pd.Timestamp.today() + start = (today - pd.DateOffset(years=15)).strftime("%Y-%m-%d") + end = today.strftime("%Y-%m-%d") + cache_file = tmp_path / f"AAPL-YFin-data-{start}-{end}.csv" + tiny_df.to_csv(cache_file, index=False) + + fresh_df = _minimal_ohlcv_df() + + with ( + patch("tradingagents.dataflows.stockstats_utils.get_config", + return_value={"data_cache_dir": str(tmp_path), "data_vendors": {"technical_indicators": "yfinance"}}), + patch("tradingagents.dataflows.stockstats_utils.yf.download", + return_value=fresh_df.set_index("Date")) as mock_dl, + ): + result = _load_or_fetch_ohlcv("AAPL") + mock_dl.assert_called_once() + + def test_empty_download_raises_yfinance_error(self, tmp_path): + """An empty DataFrame from yfinance raises YFinanceError (not a silent return).""" + from tradingagents.dataflows.stockstats_utils import _load_or_fetch_ohlcv, YFinanceError + + with ( + patch("tradingagents.dataflows.stockstats_utils.get_config", + return_value={"data_cache_dir": str(tmp_path), "data_vendors": {"technical_indicators": "yfinance"}}), + patch("tradingagents.dataflows.stockstats_utils.yf.download", + return_value=pd.DataFrame()), + ): + with pytest.raises(YFinanceError, match="no data"): + _load_or_fetch_ohlcv("INVALID_TICKER_XYZ") + + def test_cache_filename_is_dynamic_not_hardcoded(self, tmp_path): + """Cache filename contains today's date (not a hardcoded historical date like 2025-03-25).""" + from tradingagents.dataflows.stockstats_utils import _load_or_fetch_ohlcv + + df = _minimal_ohlcv_df() + + with ( + patch("tradingagents.dataflows.stockstats_utils.get_config", + return_value={"data_cache_dir": str(tmp_path), "data_vendors": {"technical_indicators": "yfinance"}}), + patch("tradingagents.dataflows.stockstats_utils.yf.download", + return_value=df.set_index("Date")), + ): + _load_or_fetch_ohlcv("AAPL") + + csv_files = list(tmp_path.glob("AAPL-YFin-data-*.csv")) + assert len(csv_files) == 1 + filename = csv_files[0].name + # Must NOT contain the old hardcoded stale date + assert "2025-03-25" not in filename, ( + f"Cache filename contains the old hardcoded stale date! Got: {filename}" + ) + # Must contain today's year + today_year = str(pd.Timestamp.today().year) + assert today_year in filename, f"Expected today's year {today_year} in filename: {filename}" + + +# --------------------------------------------------------------------------- +# YFinanceError propagation (no more silent return "") +# --------------------------------------------------------------------------- + +class TestYFinanceErrorPropagation: + """Tests that YFinanceError is raised (not swallowed) by get_stockstats_indicator.""" + + def test_get_stockstats_indicator_raises_yfinance_error_on_failure(self, tmp_path): + """get_stockstats_indicator raises YFinanceError when yfinance returns empty data.""" + from tradingagents.dataflows.stockstats_utils import YFinanceError + from tradingagents.dataflows.y_finance import get_stockstats_indicator + + with ( + patch("tradingagents.dataflows.stockstats_utils.get_config", + return_value={"data_cache_dir": str(tmp_path), "data_vendors": {"technical_indicators": "yfinance"}}), + patch("tradingagents.dataflows.stockstats_utils.yf.download", + return_value=pd.DataFrame()), + ): + with pytest.raises(YFinanceError): + get_stockstats_indicator("INVALID", "rsi", "2025-01-02") + + def test_yfinance_error_is_not_swallowed_as_empty_string(self, tmp_path): + """Regression test: get_stockstats_indicator must NOT return empty string on error.""" + from tradingagents.dataflows.stockstats_utils import YFinanceError + from tradingagents.dataflows.y_finance import get_stockstats_indicator + + with ( + patch("tradingagents.dataflows.stockstats_utils.get_config", + return_value={"data_cache_dir": str(tmp_path), "data_vendors": {"technical_indicators": "yfinance"}}), + patch("tradingagents.dataflows.stockstats_utils.yf.download", + return_value=pd.DataFrame()), + ): + result = None + try: + result = get_stockstats_indicator("INVALID", "rsi", "2025-01-02") + except YFinanceError: + pass # This is the correct behaviour + + assert result is None, ( + "get_stockstats_indicator should raise YFinanceError, not silently return a value. " + f"Got: {result!r}" + ) + + +# --------------------------------------------------------------------------- +# _filter_csv_by_date_range — explicit column discovery +# --------------------------------------------------------------------------- + +class TestFilterCsvByDateRange: + """Tests for the fixed _filter_csv_by_date_range in alpha_vantage_common.""" + + def _make_av_csv(self, date_col_name: str = "time") -> str: + return ( + f"{date_col_name},SMA\n" + "2024-01-02,230.5\n" + "2024-01-03,231.0\n" + "2024-01-08,235.5\n" + ) + + def test_filters_with_standard_time_column(self): + """Standard Alpha Vantage CSV with 'time' header filters correctly.""" + from tradingagents.dataflows.alpha_vantage_common import _filter_csv_by_date_range + + result = _filter_csv_by_date_range(self._make_av_csv("time"), "2024-01-03", "2024-01-08") + assert "2024-01-02" not in result + assert "2024-01-03" in result + assert "2024-01-08" in result + + def test_filters_with_timestamp_column(self): + """CSV with 'timestamp' header (alternative AV format) also works.""" + from tradingagents.dataflows.alpha_vantage_common import _filter_csv_by_date_range + + result = _filter_csv_by_date_range(self._make_av_csv("timestamp"), "2024-01-03", "2024-01-08") + assert "2024-01-02" not in result + assert "2024-01-03" in result + + def test_missing_date_column_raises_not_silently_filters_wrong_column(self): + """When no recognised date column exists, raises ValueError immediately. + Previously it would silently use df.columns[0] and return garbage.""" + from tradingagents.dataflows.alpha_vantage_common import _filter_csv_by_date_range + + bad_csv = "price,volume\n102.0,1000\n103.0,2000\n" + + # The fixed code raises ValueError; the old code would silently try to + # parse the 'price' column as a date and return the original data. + with pytest.raises(ValueError, match="Date column not found"): + _filter_csv_by_date_range(bad_csv, "2024-01-01", "2024-01-31") + + def test_empty_csv_returns_empty(self): + """Empty input returns empty output without error.""" + from tradingagents.dataflows.alpha_vantage_common import _filter_csv_by_date_range + + result = _filter_csv_by_date_range("", "2024-01-01", "2024-01-31") + assert result == "" diff --git a/tradingagents/dataflows/alpha_vantage_common.py b/tradingagents/dataflows/alpha_vantage_common.py index 5aaa27a5..d6ac3f83 100644 --- a/tradingagents/dataflows/alpha_vantage_common.py +++ b/tradingagents/dataflows/alpha_vantage_common.py @@ -1,3 +1,4 @@ +import logging import os import requests import pandas as pd @@ -7,6 +8,9 @@ import time as _time from datetime import datetime from io import StringIO +logger = logging.getLogger(__name__) + + API_BASE_URL = "https://www.alphavantage.co/query" def get_api_key() -> str: @@ -228,8 +232,15 @@ def _filter_csv_by_date_range(csv_data: str, start_date: str, end_date: str) -> # Parse CSV data df = pd.read_csv(StringIO(csv_data)) - # Assume the first column is the date column (timestamp) - date_col = df.columns[0] + # Find the date column by name rather than positional assumption. + # Alpha Vantage returns 'time' or 'timestamp' as the date column header. + _date_candidates = [c for c in df.columns if c.lower() in ("time", "timestamp", "date")] + if not _date_candidates: + raise ValueError( + f"Date column not found in Alpha Vantage CSV. " + f"Expected 'time' or 'timestamp', got columns: {list(df.columns)}" + ) + date_col = _date_candidates[0] df[date_col] = pd.to_datetime(df[date_col]) # Filter by date range @@ -241,7 +252,11 @@ def _filter_csv_by_date_range(csv_data: str, start_date: str, end_date: str) -> # Convert back to CSV string return filtered_df.to_csv(index=False) + except ValueError: + # ValueError = a programming/data-contract error (e.g. missing date column). + # Re-raise so callers see it immediately rather than getting silently wrong data. + raise except Exception as e: - # If filtering fails, return original data with a warning - print(f"Warning: Failed to filter CSV data by date range: {e}") + # Transient errors (I/O, malformed CSV): log and return original data as-is. + logger.warning("Failed to filter CSV data by date range: %s", e) return csv_data diff --git a/tradingagents/dataflows/interface.py b/tradingagents/dataflows/interface.py index 788fadd2..29a3d6d6 100644 --- a/tradingagents/dataflows/interface.py +++ b/tradingagents/dataflows/interface.py @@ -40,6 +40,7 @@ from .alpha_vantage_scanner import ( ) from .alpha_vantage_common import AlphaVantageError, AlphaVantageRateLimitError, RateLimitError from .finnhub_common import FinnhubError +from .stockstats_utils import YFinanceError from .finnhub_news import get_insider_transactions as get_finnhub_insider_transactions from .finnhub_scanner import ( get_market_indices_finnhub, @@ -262,7 +263,7 @@ def route_to_vendor(method: str, *args, **kwargs): if rl: rl.log_vendor_call(method, vendor, True, (time.time() - t0) * 1000, args_summary=args_summary) return result - except (AlphaVantageError, FinnhubError, ConnectionError, TimeoutError) as exc: + except (AlphaVantageError, FinnhubError, YFinanceError, ConnectionError, TimeoutError) as exc: if rl: rl.log_vendor_call(method, vendor, False, (time.time() - t0) * 1000, error=str(exc)[:200], args_summary=args_summary) last_error = exc diff --git a/tradingagents/dataflows/stockstats_utils.py b/tradingagents/dataflows/stockstats_utils.py index 9d43cf0e..34b9943f 100644 --- a/tradingagents/dataflows/stockstats_utils.py +++ b/tradingagents/dataflows/stockstats_utils.py @@ -1,3 +1,4 @@ +import logging import pandas as pd import yfinance as yf from stockstats import wrap @@ -5,6 +6,22 @@ from typing import Annotated import os from .config import get_config +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Public exception — lets callers catch stockstats/yfinance failures by type +# --------------------------------------------------------------------------- + + +class YFinanceError(Exception): + """Raised when yfinance or stockstats data fetching/processing fails.""" + pass + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + def _clean_dataframe(data: pd.DataFrame) -> pd.DataFrame: """Normalize a stock DataFrame for stockstats: parse dates, drop invalid rows, fill price gaps. @@ -29,6 +46,80 @@ def _clean_dataframe(data: pd.DataFrame) -> pd.DataFrame: return df +def _load_or_fetch_ohlcv(symbol: str) -> pd.DataFrame: + """Single authority for loading OHLCV data: cache → yfinance download → normalize. + + Cache filename is always derived from today's date (15-year window) so the + cache key never goes stale. If a cached file exists but is corrupt (too few + rows to be useful), it is deleted and re-fetched rather than silently + returning bad data. + + Raises: + YFinanceError: if the download returns an empty DataFrame or fails. + """ + config = get_config() + + today_date = pd.Timestamp.today() + start_date = today_date - pd.DateOffset(years=15) + start_date_str = start_date.strftime("%Y-%m-%d") + end_date_str = today_date.strftime("%Y-%m-%d") + + os.makedirs(config["data_cache_dir"], exist_ok=True) + + data_file = os.path.join( + config["data_cache_dir"], + f"{symbol}-YFin-data-{start_date_str}-{end_date_str}.csv", + ) + + # ── Try to load from cache ──────────────────────────────────────────────── + if os.path.exists(data_file): + try: + data = pd.read_csv(data_file) # no on_bad_lines="skip" — we want to know about corruption + except Exception as exc: + logger.warning( + "Corrupt cache file for %s (%s) — deleting and re-fetching.", symbol, exc + ) + os.remove(data_file) + data = None + else: + # Validate: a 15-year daily file should have well over 100 rows + if len(data) < 50: + logger.warning( + "Cache file for %s has only %d rows — likely truncated, re-fetching.", + symbol, len(data), + ) + os.remove(data_file) + data = None + else: + data = None + + # ── Download from yfinance if cache miss / corrupt ──────────────────────── + if data is None: + raw = yf.download( + symbol, + start=start_date_str, + end=end_date_str, + multi_level_index=False, + progress=False, + auto_adjust=True, + ) + if raw.empty: + raise YFinanceError( + f"yfinance returned no data for symbol '{symbol}' " + f"({start_date_str} → {end_date_str})" + ) + data = raw.reset_index() + data.to_csv(data_file, index=False) + logger.debug("Downloaded and cached OHLCV for %s → %s", symbol, data_file) + + return data + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + class StockstatsUtils: @staticmethod def get_stock_stats( @@ -40,48 +131,20 @@ class StockstatsUtils: str, "curr date for retrieving stock price data, YYYY-mm-dd" ], ): - config = get_config() - - today_date = pd.Timestamp.today() curr_date_dt = pd.to_datetime(curr_date) - - end_date = today_date - start_date = today_date - pd.DateOffset(years=15) - start_date_str = start_date.strftime("%Y-%m-%d") - end_date_str = end_date.strftime("%Y-%m-%d") - - # Ensure cache directory exists - os.makedirs(config["data_cache_dir"], exist_ok=True) - - data_file = os.path.join( - config["data_cache_dir"], - f"{symbol}-YFin-data-{start_date_str}-{end_date_str}.csv", - ) - - if os.path.exists(data_file): - data = pd.read_csv(data_file, on_bad_lines="skip") - else: - data = yf.download( - symbol, - start=start_date_str, - end=end_date_str, - multi_level_index=False, - progress=False, - auto_adjust=True, - ) - data = data.reset_index() - data.to_csv(data_file, index=False) - - data = _clean_dataframe(data) - df = wrap(data) - df["Date"] = df["Date"].dt.strftime("%Y-%m-%d") curr_date_str = curr_date_dt.strftime("%Y-%m-%d") + data = _load_or_fetch_ohlcv(symbol) + data = _clean_dataframe(data) + df = wrap(data) + # After wrap(), the date column becomes the datetime index (named 'date'). + # Access via df.index, not df["Date"] which stockstats would try to parse as an indicator. + df[indicator] # trigger stockstats to calculate the indicator - matching_rows = df[df["Date"].str.startswith(curr_date_str)] + date_index_strs = df.index.strftime("%Y-%m-%d") + matching_rows = df[date_index_strs == curr_date_str] if not matching_rows.empty: - indicator_value = matching_rows[indicator].values[0] - return indicator_value + return matching_rows[indicator].values[0] else: return "N/A: Not a trading day (weekend or holiday)" diff --git a/tradingagents/dataflows/y_finance.py b/tradingagents/dataflows/y_finance.py index fd58bf39..c837594a 100644 --- a/tradingagents/dataflows/y_finance.py +++ b/tradingagents/dataflows/y_finance.py @@ -1,9 +1,14 @@ +import logging from typing import Annotated from datetime import datetime from dateutil.relativedelta import relativedelta +import pandas as pd import yfinance as yf import os -from .stockstats_utils import StockstatsUtils, _clean_dataframe +from .stockstats_utils import StockstatsUtils, YFinanceError, _clean_dataframe, _load_or_fetch_ohlcv + +logger = logging.getLogger(__name__) + def get_YFin_data_online( symbol: Annotated[str, "ticker symbol of the company"], @@ -191,70 +196,32 @@ def _get_stock_stats_bulk( ) -> dict: """ Optimized bulk calculation of stock stats indicators. - Fetches data once and calculates indicator for all available dates. + Fetches data once (via shared _load_or_fetch_ohlcv cache) and calculates + the indicator for all available dates. Returns dict mapping date strings to indicator values. + + Raises: + YFinanceError: if data cannot be loaded or indicator calculation fails. """ - from .config import get_config - import pandas as pd from stockstats import wrap - import os - - config = get_config() - online = config["data_vendors"]["technical_indicators"] != "local" - - if not online: - # Local data path - try: - data = pd.read_csv( - os.path.join( - config.get("data_cache_dir", "data"), - f"{symbol}-YFin-data-2015-01-01-2025-03-25.csv", - ), - on_bad_lines="skip", - ) - except FileNotFoundError: - raise Exception("Stockstats fail: Yahoo Finance data not fetched yet!") - else: - # Online data fetching with caching - today_date = pd.Timestamp.today() - curr_date_dt = pd.to_datetime(curr_date) - - end_date = today_date - start_date = today_date - pd.DateOffset(years=15) - start_date_str = start_date.strftime("%Y-%m-%d") - end_date_str = end_date.strftime("%Y-%m-%d") - - os.makedirs(config["data_cache_dir"], exist_ok=True) - - data_file = os.path.join( - config["data_cache_dir"], - f"{symbol}-YFin-data-{start_date_str}-{end_date_str}.csv", - ) - - if os.path.exists(data_file): - data = pd.read_csv(data_file, on_bad_lines="skip") - else: - data = yf.download( - symbol, - start=start_date_str, - end=end_date_str, - multi_level_index=False, - progress=False, - auto_adjust=True, - ) - data = data.reset_index() - data.to_csv(data_file, index=False) + # Single authority: _load_or_fetch_ohlcv handles both online and local modes, + # dynamic cache filename, and corpus validation — no duplicated download logic here. + data = _load_or_fetch_ohlcv(symbol) data = _clean_dataframe(data) df = wrap(data) - df["Date"] = df["Date"].dt.strftime("%Y-%m-%d") - + # After wrap(), the date column becomes the datetime index (named 'date'). + # Access via df.index, not df["Date"] which stockstats would try to parse as an indicator. + # Calculate the indicator for all rows at once df[indicator] # This triggers stockstats to calculate the indicator - + # Create a dictionary mapping date strings to indicator values - # Optimized: replaced iterrows() with vectorized operations for performance - return df.set_index("Date")[indicator].fillna("N/A").astype(str).to_dict() + # Optimized: vectorized operations for performance using correct DatetimeIndex + series = df[indicator].copy() + series.index = series.index.strftime("%Y-%m-%d") + return series.fillna("N/A").astype(str).to_dict() + def get_stockstats_indicator( @@ -264,25 +231,15 @@ def get_stockstats_indicator( str, "The current trading date you are trading on, YYYY-mm-dd" ], ) -> str: - curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d") curr_date = curr_date_dt.strftime("%Y-%m-%d") - try: - indicator_value = StockstatsUtils.get_stock_stats( - symbol, - indicator, - curr_date, - ) - except Exception as e: - print( - f"Error getting stockstats indicator data for indicator {indicator} on {curr_date}: {e}" - ) - return "" - + # Raises YFinanceError on failure — caller (route_to_vendor) catches typed exceptions. + indicator_value = StockstatsUtils.get_stock_stats(symbol, indicator, curr_date) return str(indicator_value) + def get_fundamentals( ticker: Annotated[str, "ticker symbol of the company"], curr_date: Annotated[str, "current date (not used for yfinance)"] = None diff --git a/tradingagents/pipeline/macro_bridge.py b/tradingagents/pipeline/macro_bridge.py index 42759c63..e18c6ef9 100644 --- a/tradingagents/pipeline/macro_bridge.py +++ b/tradingagents/pipeline/macro_bridge.py @@ -5,12 +5,14 @@ from __future__ import annotations import asyncio import json import logging +import time from concurrent.futures import ThreadPoolExecutor from tradingagents.agents.utils.json_utils import extract_json from dataclasses import dataclass from datetime import datetime from pathlib import Path -from typing import Literal +from typing import Callable, Literal + logger = logging.getLogger(__name__) @@ -172,15 +174,6 @@ def run_ticker_analysis( NOTE: TradingAgentsGraph is synchronous — call this from a thread pool when running multiple tickers concurrently (see run_all_tickers). - - Args: - candidate: The stock candidate to analyse. - macro_context: Macro context to embed in the result. - config: TradingAgents configuration dict. - analysis_date: Date string in YYYY-MM-DD format. - - Returns: - TickerResult with all report fields populated, or error set on failure. """ result = TickerResult( ticker=candidate.ticker, @@ -189,11 +182,14 @@ def run_ticker_analysis( analysis_date=analysis_date, ) - logger.info("Starting analysis for %s on %s", candidate.ticker, analysis_date) + t0 = time.monotonic() + logger.info( + "[%s] ▶ Starting analysis (%s, %s conviction)", + candidate.ticker, candidate.sector, candidate.conviction, + ) try: from tradingagents.graph.trading_graph import TradingAgentsGraph - from tradingagents.observability import get_run_logger rl = get_run_logger() @@ -210,23 +206,31 @@ def run_ticker_analysis( result.risk_debate = str(final_state.get("risk_debate_state", "")) result.final_trade_decision = decision + elapsed = time.monotonic() - t0 logger.info( - "Analysis complete for %s: %s", candidate.ticker, str(decision)[:120] + "[%s] ✓ Analysis complete in %.0fs — decision: %s", + candidate.ticker, elapsed, str(decision)[:80], ) except Exception as exc: - logger.error("Analysis failed for %s: %s", candidate.ticker, exc, exc_info=True) + elapsed = time.monotonic() - t0 + logger.error( + "[%s] ✗ Analysis FAILED after %.0fs: %s", + candidate.ticker, elapsed, exc, exc_info=True, + ) result.error = str(exc) return result + async def run_all_tickers( candidates: list[StockCandidate], macro_context: MacroContext, config: dict, analysis_date: str, max_concurrent: int = 2, + on_ticker_done: Callable[[TickerResult, int, int], None] | None = None, ) -> list[TickerResult]: """Run TradingAgents for every candidate with controlled concurrency. @@ -239,28 +243,45 @@ async def run_all_tickers( config: TradingAgents configuration dict. analysis_date: Date string in YYYY-MM-DD format. max_concurrent: Maximum number of tickers to process in parallel. + on_ticker_done: Optional callback(result, done_count, total_count) fired + after each ticker finishes — use this to drive a progress bar. Returns: List of TickerResult in completion order. """ loop = asyncio.get_running_loop() - executor = ThreadPoolExecutor(max_workers=max_concurrent) - try: - tasks = [ - loop.run_in_executor( - executor, + total = len(candidates) + results: list[TickerResult] = [] + + # Use a semaphore so at most max_concurrent tickers run simultaneously, + # but we still get individual completion callbacks via as_completed. + semaphore = asyncio.Semaphore(max_concurrent) + + async def _run_one(candidate: StockCandidate) -> TickerResult: + async with semaphore: + return await loop.run_in_executor( + None, # use default ThreadPoolExecutor run_ticker_analysis, - c, + candidate, macro_context, config, analysis_date, ) - for c in candidates - ] - results = await asyncio.gather(*tasks) - return list(results) - finally: - executor.shutdown(wait=False) + + tasks = [asyncio.create_task(_run_one(c)) for c in candidates] + done_count = 0 + for coro in asyncio.as_completed(tasks): + result = await coro + done_count += 1 + results.append(result) + if on_ticker_done is not None: + try: + on_ticker_done(result, done_count, total) + except Exception: # never let a callback crash the pipeline + pass + + return results + # ─── Reporting ────────────────────────────────────────────────────────────────