Merge pull request #85 from aguzererler/fix/dataflows-incident-hardened-error-handling

fix: harden dataflows layer against silent failures and data corruption
This commit is contained in:
ahmet guzererler 2026-03-22 07:43:33 +01:00 committed by GitHub
commit cef65d922d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 721 additions and 145 deletions

View File

@ -1579,18 +1579,63 @@ def run_pipeline(
console.print(
f"\n[cyan]Running TradingAgents for {len(candidates)} tickers...[/cyan]"
f" [dim](up to 2 concurrent)[/dim]\n"
)
for c in candidates:
console.print(
f" [dim]▷ Queued:[/dim] [bold cyan]{c.ticker}[/bold cyan]"
f" [dim]{c.sector} · {c.conviction.upper()} conviction[/dim]"
)
console.print()
import time as _time
from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeElapsedColumn
pipeline_start = _time.monotonic()
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TextColumn("[cyan]{task.completed}/{task.total}[/cyan]"),
TimeElapsedColumn(),
console=console,
transient=False,
) as progress:
overall = progress.add_task("[bold]Pipeline progress[/bold]", total=len(candidates))
def on_done(result, done_count, total_count):
ticker_elapsed = _time.monotonic() - pipeline_start
if result.error:
console.print(
f" [red]✗ {result.ticker}[/red]"
f" [dim]failed ({ticker_elapsed:.0f}s elapsed) — {result.error[:80]}[/dim]"
)
else:
decision_preview = str(result.final_trade_decision)[:70].replace("\n", " ")
console.print(
f" [green]✓ {result.ticker}[/green]"
f" [dim]({done_count}/{total_count}, {ticker_elapsed:.0f}s elapsed)[/dim]"
f"{decision_preview}"
)
progress.advance(overall)
try:
with Live(
Spinner("dots", text="Analyzing..."), console=console, transient=True
):
results = asyncio.run(
run_all_tickers(candidates, macro_context, config, analysis_date)
run_all_tickers(
candidates, macro_context, config, analysis_date,
on_ticker_done=on_done,
)
)
except Exception as e:
console.print(f"[red]Pipeline failed: {e}[/red]")
raise typer.Exit(1)
elapsed_total = _time.monotonic() - pipeline_start
console.print(
f"\n[bold green]All {len(candidates)} ticker(s) finished in {elapsed_total:.0f}s[/bold green]\n"
)
save_results(results, macro_context, output_dir)
# Write observability log

View File

@ -0,0 +1,215 @@
"""Live-data integration tests for the stockstats utilities.
These tests call the real yfinance API and therefore require network access.
They are marked with ``integration`` and are excluded from the default test
run (which uses ``--ignore=tests/integration``).
Run them explicitly with:
python -m pytest tests/integration/test_stockstats_live.py -v --override-ini="addopts="
The tests validate the fix for the "Invalid number of return arguments after
parsing column name: 'Date'" error that occurred because stockstats.wrap()
promotes the lowercase ``date`` column to the DataFrame index, so the old
``df["Date"]`` access caused stockstats to try to parse "Date" as an indicator
name. The fix uses ``df.index.strftime("%Y-%m-%d")`` instead.
"""
import pytest
import pandas as pd
pytestmark = pytest.mark.integration
# A well-known trading day we can use for assertions
_TEST_DATE = "2025-01-02"
_TEST_TICKER = "AAPL"
# ---------------------------------------------------------------------------
# StockstatsUtils.get_stock_stats
# ---------------------------------------------------------------------------
class TestStockstatsUtilsLive:
"""Live tests for StockstatsUtils.get_stock_stats against real yfinance data."""
def test_close_50_sma_returns_numeric(self):
"""close_50_sma indicator returns a numeric value for a known trading day."""
from tradingagents.dataflows.stockstats_utils import StockstatsUtils
result = StockstatsUtils.get_stock_stats(_TEST_TICKER, "close_50_sma", _TEST_DATE)
assert result != "N/A: Not a trading day (weekend or holiday)", (
f"Expected a numeric value for {_TEST_DATE}, got N/A (check if it's a holiday)"
)
# Should be a finite float-like value
assert float(result) > 0, f"close_50_sma should be positive, got: {result}"
def test_rsi_returns_value_in_valid_range(self):
"""RSI indicator returns a value in [0, 100] for a known trading day."""
from tradingagents.dataflows.stockstats_utils import StockstatsUtils
result = StockstatsUtils.get_stock_stats(_TEST_TICKER, "rsi", _TEST_DATE)
assert result != "N/A: Not a trading day (weekend or holiday)", (
f"Expected numeric RSI for {_TEST_DATE}"
)
rsi = float(result)
assert 0.0 <= rsi <= 100.0, f"RSI must be in [0, 100], got: {rsi}"
def test_macd_returns_numeric(self):
"""MACD indicator returns a numeric value for a known trading day."""
from tradingagents.dataflows.stockstats_utils import StockstatsUtils
result = StockstatsUtils.get_stock_stats(_TEST_TICKER, "macd", _TEST_DATE)
assert result != "N/A: Not a trading day (weekend or holiday)"
# MACD can be positive or negative — just confirm it's a valid float
float(result) # raises ValueError if not numeric
def test_weekend_returns_na(self):
"""A weekend date returns the N/A holiday/weekend message."""
from tradingagents.dataflows.stockstats_utils import StockstatsUtils
# 2025-01-04 is a Saturday
result = StockstatsUtils.get_stock_stats(_TEST_TICKER, "close_50_sma", "2025-01-04")
assert result == "N/A: Not a trading day (weekend or holiday)", (
f"Expected N/A for Saturday 2025-01-04, got: {result}"
)
def test_no_date_column_error(self):
"""Calling get_stock_stats must NOT raise the 'Date' column parsing error."""
from tradingagents.dataflows.stockstats_utils import StockstatsUtils
# This previously raised: Invalid number of return arguments after
# parsing column name: 'Date'
try:
StockstatsUtils.get_stock_stats(_TEST_TICKER, "close_50_sma", _TEST_DATE)
except Exception as e:
if "Invalid number of return arguments" in str(e) and "Date" in str(e):
pytest.fail(
"Regression: stockstats is still trying to parse 'Date' as an "
f"indicator. Error: {e}"
)
raise # re-raise unexpected errors
# ---------------------------------------------------------------------------
# _get_stock_stats_bulk
# ---------------------------------------------------------------------------
class TestGetStockStatsBulkLive:
"""Live tests for _get_stock_stats_bulk against real yfinance data."""
def test_returns_dict_with_date_keys(self):
"""Bulk method returns a non-empty dict with YYYY-MM-DD string keys."""
from tradingagents.dataflows.y_finance import _get_stock_stats_bulk
result = _get_stock_stats_bulk(_TEST_TICKER, "rsi", _TEST_DATE)
assert isinstance(result, dict), "Expected dict from _get_stock_stats_bulk"
assert len(result) > 0, "Expected non-empty result dict"
# Keys should all be YYYY-MM-DD strings
for key in list(result.keys())[:5]:
pd.Timestamp(key) # raises if not parseable
def test_trading_day_has_numeric_value(self):
"""A known trading day has a numeric (non-N/A) value in the bulk result."""
from tradingagents.dataflows.y_finance import _get_stock_stats_bulk
result = _get_stock_stats_bulk(_TEST_TICKER, "rsi", _TEST_DATE)
assert _TEST_DATE in result, (
f"Expected {_TEST_DATE} in bulk result dict. Keys sample: {list(result.keys())[:5]}"
)
value = result[_TEST_DATE]
assert value != "N/A", (
f"Expected numeric RSI for {_TEST_DATE}, got N/A (check if it's a holiday)"
)
float(value) # should be convertible to float
def test_no_date_column_parsing_error(self):
"""Bulk method must not raise the 'Date' column parsing error (regression guard)."""
from tradingagents.dataflows.y_finance import _get_stock_stats_bulk
try:
_get_stock_stats_bulk(_TEST_TICKER, "close_50_sma", _TEST_DATE)
except Exception as e:
if "Invalid number of return arguments" in str(e) and "Date" in str(e):
pytest.fail(
"Regression: _get_stock_stats_bulk still hits the 'Date' indicator "
f"parsing error. Error: {e}"
)
raise
def test_multiple_indicators_all_work(self):
"""All supported indicators can be computed without error."""
from tradingagents.dataflows.y_finance import _get_stock_stats_bulk
indicators = [
"close_50_sma",
"close_200_sma",
"close_10_ema",
"macd",
"macds",
"macdh",
"rsi",
"boll",
"boll_ub",
"boll_lb",
"atr",
]
for indicator in indicators:
try:
result = _get_stock_stats_bulk(_TEST_TICKER, indicator, _TEST_DATE)
assert isinstance(result, dict), f"{indicator}: expected dict"
assert len(result) > 0, f"{indicator}: expected non-empty dict"
except Exception as e:
pytest.fail(f"Indicator '{indicator}' raised an unexpected error: {e}")
# ---------------------------------------------------------------------------
# get_stock_stats_indicators_window (end-to-end with live data)
# ---------------------------------------------------------------------------
class TestGetStockStatsIndicatorsWindowLive:
"""Live end-to-end tests for get_stock_stats_indicators_window."""
def test_rsi_window_returns_formatted_string(self):
"""Window function returns a multi-line string with RSI values over a date range."""
from tradingagents.dataflows.y_finance import get_stock_stats_indicators_window
result = get_stock_stats_indicators_window(_TEST_TICKER, "rsi", _TEST_DATE, look_back_days=5)
assert isinstance(result, str)
assert "rsi" in result.lower()
assert _TEST_DATE in result
# Should have date: value lines
lines = [l for l in result.split("\n") if ":" in l and "-" in l]
assert len(lines) > 0, "Expected date:value lines in result"
def test_close_50_sma_window_contains_numeric_values(self):
"""50-day SMA window result contains actual numeric price values."""
from tradingagents.dataflows.y_finance import get_stock_stats_indicators_window
result = get_stock_stats_indicators_window(
_TEST_TICKER, "close_50_sma", _TEST_DATE, look_back_days=10
)
assert isinstance(result, str)
# At least some lines should have numeric values (not all N/A)
value_lines = [l for l in result.split("\n") if ":" in l and l.strip().startswith("20")]
numeric_values = []
for line in value_lines:
try:
val = line.split(":", 1)[1].strip()
numeric_values.append(float(val))
except (ValueError, IndexError):
pass # N/A lines are expected for weekends
assert len(numeric_values) > 0, (
"Expected at least some numeric 50-SMA values in the 10-day window"
)

View File

@ -0,0 +1,259 @@
"""Unit tests for the incident-fix improvements across the dataflows layer.
Tests cover:
1. _load_or_fetch_ohlcv dynamic cache filename, corruption detection + re-fetch
2. YFinanceError propagated by get_stockstats_indicator (no more silent return "")
3. _filter_csv_by_date_range explicit date column discovery (no positional assumption)
"""
import os
import pytest
import pandas as pd
from unittest.mock import patch, MagicMock
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _minimal_ohlcv_df(periods: int = 200) -> pd.DataFrame:
"""Return a minimal valid OHLCV DataFrame with a Date column."""
idx = pd.date_range("2024-01-02", periods=periods, freq="B")
return pd.DataFrame(
{
"Date": idx.strftime("%Y-%m-%d"),
"Open": [100.0] * periods,
"High": [105.0] * periods,
"Low": [95.0] * periods,
"Close": [102.0] * periods,
"Volume": [1_000_000] * periods,
}
)
# ---------------------------------------------------------------------------
# _load_or_fetch_ohlcv — cache + download logic
# ---------------------------------------------------------------------------
class TestLoadOrFetchOhlcv:
"""Tests for the unified OHLCV loader."""
def test_downloads_and_caches_on_first_call(self, tmp_path):
"""When no cache exists, yfinance is called and data is written to cache."""
from tradingagents.dataflows.stockstats_utils import _load_or_fetch_ohlcv
expected_df = _minimal_ohlcv_df()
mock_downloaded = expected_df.copy()
mock_downloaded.index = pd.RangeIndex(len(mock_downloaded)) # simulate reset_index output
with (
patch("tradingagents.dataflows.stockstats_utils.get_config",
return_value={"data_cache_dir": str(tmp_path), "data_vendors": {"technical_indicators": "yfinance"}}),
patch("tradingagents.dataflows.stockstats_utils.yf.download",
return_value=expected_df.set_index("Date")) as mock_dl,
):
result = _load_or_fetch_ohlcv("AAPL")
mock_dl.assert_called_once()
# Cache file must exist after the call
csv_files = list(tmp_path.glob("AAPL-YFin-data-*.csv"))
assert len(csv_files) == 1, "Expected exactly one cache file to be created"
def test_uses_cache_on_second_call(self, tmp_path):
"""When cache already exists, yfinance.download is NOT called again."""
from tradingagents.dataflows.stockstats_utils import _load_or_fetch_ohlcv
# Write a valid cache file manually
df = _minimal_ohlcv_df()
today = pd.Timestamp.today()
start = (today - pd.DateOffset(years=15)).strftime("%Y-%m-%d")
end = today.strftime("%Y-%m-%d")
cache_file = tmp_path / f"AAPL-YFin-data-{start}-{end}.csv"
df.to_csv(cache_file, index=False)
with (
patch("tradingagents.dataflows.stockstats_utils.get_config",
return_value={"data_cache_dir": str(tmp_path), "data_vendors": {"technical_indicators": "yfinance"}}),
patch("tradingagents.dataflows.stockstats_utils.yf.download") as mock_dl,
):
result = _load_or_fetch_ohlcv("AAPL")
mock_dl.assert_not_called()
assert len(result) == 200
def test_corrupt_cache_is_deleted_and_refetched(self, tmp_path):
"""A corrupt (unparseable) cache file is deleted and yfinance is called again."""
from tradingagents.dataflows.stockstats_utils import _load_or_fetch_ohlcv
today = pd.Timestamp.today()
start = (today - pd.DateOffset(years=15)).strftime("%Y-%m-%d")
end = today.strftime("%Y-%m-%d")
cache_file = tmp_path / f"AAPL-YFin-data-{start}-{end}.csv"
cache_file.write_text(",,,,CORRUPT,,\x00\x00BINARY GARBAGE")
fresh_df = _minimal_ohlcv_df()
with (
patch("tradingagents.dataflows.stockstats_utils.get_config",
return_value={"data_cache_dir": str(tmp_path), "data_vendors": {"technical_indicators": "yfinance"}}),
patch("tradingagents.dataflows.stockstats_utils.yf.download",
return_value=fresh_df.set_index("Date")) as mock_dl,
):
result = _load_or_fetch_ohlcv("AAPL")
mock_dl.assert_called_once()
def test_truncated_cache_triggers_refetch(self, tmp_path):
"""A cache file with fewer than 50 rows is treated as truncated and re-fetched."""
from tradingagents.dataflows.stockstats_utils import _load_or_fetch_ohlcv
tiny_df = _minimal_ohlcv_df(periods=10) # only 10 rows — well below threshold
today = pd.Timestamp.today()
start = (today - pd.DateOffset(years=15)).strftime("%Y-%m-%d")
end = today.strftime("%Y-%m-%d")
cache_file = tmp_path / f"AAPL-YFin-data-{start}-{end}.csv"
tiny_df.to_csv(cache_file, index=False)
fresh_df = _minimal_ohlcv_df()
with (
patch("tradingagents.dataflows.stockstats_utils.get_config",
return_value={"data_cache_dir": str(tmp_path), "data_vendors": {"technical_indicators": "yfinance"}}),
patch("tradingagents.dataflows.stockstats_utils.yf.download",
return_value=fresh_df.set_index("Date")) as mock_dl,
):
result = _load_or_fetch_ohlcv("AAPL")
mock_dl.assert_called_once()
def test_empty_download_raises_yfinance_error(self, tmp_path):
"""An empty DataFrame from yfinance raises YFinanceError (not a silent return)."""
from tradingagents.dataflows.stockstats_utils import _load_or_fetch_ohlcv, YFinanceError
with (
patch("tradingagents.dataflows.stockstats_utils.get_config",
return_value={"data_cache_dir": str(tmp_path), "data_vendors": {"technical_indicators": "yfinance"}}),
patch("tradingagents.dataflows.stockstats_utils.yf.download",
return_value=pd.DataFrame()),
):
with pytest.raises(YFinanceError, match="no data"):
_load_or_fetch_ohlcv("INVALID_TICKER_XYZ")
def test_cache_filename_is_dynamic_not_hardcoded(self, tmp_path):
"""Cache filename contains today's date (not a hardcoded historical date like 2025-03-25)."""
from tradingagents.dataflows.stockstats_utils import _load_or_fetch_ohlcv
df = _minimal_ohlcv_df()
with (
patch("tradingagents.dataflows.stockstats_utils.get_config",
return_value={"data_cache_dir": str(tmp_path), "data_vendors": {"technical_indicators": "yfinance"}}),
patch("tradingagents.dataflows.stockstats_utils.yf.download",
return_value=df.set_index("Date")),
):
_load_or_fetch_ohlcv("AAPL")
csv_files = list(tmp_path.glob("AAPL-YFin-data-*.csv"))
assert len(csv_files) == 1
filename = csv_files[0].name
# Must NOT contain the old hardcoded stale date
assert "2025-03-25" not in filename, (
f"Cache filename contains the old hardcoded stale date! Got: {filename}"
)
# Must contain today's year
today_year = str(pd.Timestamp.today().year)
assert today_year in filename, f"Expected today's year {today_year} in filename: {filename}"
# ---------------------------------------------------------------------------
# YFinanceError propagation (no more silent return "")
# ---------------------------------------------------------------------------
class TestYFinanceErrorPropagation:
"""Tests that YFinanceError is raised (not swallowed) by get_stockstats_indicator."""
def test_get_stockstats_indicator_raises_yfinance_error_on_failure(self, tmp_path):
"""get_stockstats_indicator raises YFinanceError when yfinance returns empty data."""
from tradingagents.dataflows.stockstats_utils import YFinanceError
from tradingagents.dataflows.y_finance import get_stockstats_indicator
with (
patch("tradingagents.dataflows.stockstats_utils.get_config",
return_value={"data_cache_dir": str(tmp_path), "data_vendors": {"technical_indicators": "yfinance"}}),
patch("tradingagents.dataflows.stockstats_utils.yf.download",
return_value=pd.DataFrame()),
):
with pytest.raises(YFinanceError):
get_stockstats_indicator("INVALID", "rsi", "2025-01-02")
def test_yfinance_error_is_not_swallowed_as_empty_string(self, tmp_path):
"""Regression test: get_stockstats_indicator must NOT return empty string on error."""
from tradingagents.dataflows.stockstats_utils import YFinanceError
from tradingagents.dataflows.y_finance import get_stockstats_indicator
with (
patch("tradingagents.dataflows.stockstats_utils.get_config",
return_value={"data_cache_dir": str(tmp_path), "data_vendors": {"technical_indicators": "yfinance"}}),
patch("tradingagents.dataflows.stockstats_utils.yf.download",
return_value=pd.DataFrame()),
):
result = None
try:
result = get_stockstats_indicator("INVALID", "rsi", "2025-01-02")
except YFinanceError:
pass # This is the correct behaviour
assert result is None, (
"get_stockstats_indicator should raise YFinanceError, not silently return a value. "
f"Got: {result!r}"
)
# ---------------------------------------------------------------------------
# _filter_csv_by_date_range — explicit column discovery
# ---------------------------------------------------------------------------
class TestFilterCsvByDateRange:
"""Tests for the fixed _filter_csv_by_date_range in alpha_vantage_common."""
def _make_av_csv(self, date_col_name: str = "time") -> str:
return (
f"{date_col_name},SMA\n"
"2024-01-02,230.5\n"
"2024-01-03,231.0\n"
"2024-01-08,235.5\n"
)
def test_filters_with_standard_time_column(self):
"""Standard Alpha Vantage CSV with 'time' header filters correctly."""
from tradingagents.dataflows.alpha_vantage_common import _filter_csv_by_date_range
result = _filter_csv_by_date_range(self._make_av_csv("time"), "2024-01-03", "2024-01-08")
assert "2024-01-02" not in result
assert "2024-01-03" in result
assert "2024-01-08" in result
def test_filters_with_timestamp_column(self):
"""CSV with 'timestamp' header (alternative AV format) also works."""
from tradingagents.dataflows.alpha_vantage_common import _filter_csv_by_date_range
result = _filter_csv_by_date_range(self._make_av_csv("timestamp"), "2024-01-03", "2024-01-08")
assert "2024-01-02" not in result
assert "2024-01-03" in result
def test_missing_date_column_raises_not_silently_filters_wrong_column(self):
"""When no recognised date column exists, raises ValueError immediately.
Previously it would silently use df.columns[0] and return garbage."""
from tradingagents.dataflows.alpha_vantage_common import _filter_csv_by_date_range
bad_csv = "price,volume\n102.0,1000\n103.0,2000\n"
# The fixed code raises ValueError; the old code would silently try to
# parse the 'price' column as a date and return the original data.
with pytest.raises(ValueError, match="Date column not found"):
_filter_csv_by_date_range(bad_csv, "2024-01-01", "2024-01-31")
def test_empty_csv_returns_empty(self):
"""Empty input returns empty output without error."""
from tradingagents.dataflows.alpha_vantage_common import _filter_csv_by_date_range
result = _filter_csv_by_date_range("", "2024-01-01", "2024-01-31")
assert result == ""

View File

@ -1,3 +1,4 @@
import logging
import os
import requests
import pandas as pd
@ -7,6 +8,9 @@ import time as _time
from datetime import datetime
from io import StringIO
logger = logging.getLogger(__name__)
API_BASE_URL = "https://www.alphavantage.co/query"
def get_api_key() -> str:
@ -228,8 +232,15 @@ def _filter_csv_by_date_range(csv_data: str, start_date: str, end_date: str) ->
# Parse CSV data
df = pd.read_csv(StringIO(csv_data))
# Assume the first column is the date column (timestamp)
date_col = df.columns[0]
# Find the date column by name rather than positional assumption.
# Alpha Vantage returns 'time' or 'timestamp' as the date column header.
_date_candidates = [c for c in df.columns if c.lower() in ("time", "timestamp", "date")]
if not _date_candidates:
raise ValueError(
f"Date column not found in Alpha Vantage CSV. "
f"Expected 'time' or 'timestamp', got columns: {list(df.columns)}"
)
date_col = _date_candidates[0]
df[date_col] = pd.to_datetime(df[date_col])
# Filter by date range
@ -241,7 +252,11 @@ def _filter_csv_by_date_range(csv_data: str, start_date: str, end_date: str) ->
# Convert back to CSV string
return filtered_df.to_csv(index=False)
except ValueError:
# ValueError = a programming/data-contract error (e.g. missing date column).
# Re-raise so callers see it immediately rather than getting silently wrong data.
raise
except Exception as e:
# If filtering fails, return original data with a warning
print(f"Warning: Failed to filter CSV data by date range: {e}")
# Transient errors (I/O, malformed CSV): log and return original data as-is.
logger.warning("Failed to filter CSV data by date range: %s", e)
return csv_data

View File

@ -40,6 +40,7 @@ from .alpha_vantage_scanner import (
)
from .alpha_vantage_common import AlphaVantageError, AlphaVantageRateLimitError, RateLimitError
from .finnhub_common import FinnhubError
from .stockstats_utils import YFinanceError
from .finnhub_news import get_insider_transactions as get_finnhub_insider_transactions
from .finnhub_scanner import (
get_market_indices_finnhub,
@ -262,7 +263,7 @@ def route_to_vendor(method: str, *args, **kwargs):
if rl:
rl.log_vendor_call(method, vendor, True, (time.time() - t0) * 1000, args_summary=args_summary)
return result
except (AlphaVantageError, FinnhubError, ConnectionError, TimeoutError) as exc:
except (AlphaVantageError, FinnhubError, YFinanceError, ConnectionError, TimeoutError) as exc:
if rl:
rl.log_vendor_call(method, vendor, False, (time.time() - t0) * 1000, error=str(exc)[:200], args_summary=args_summary)
last_error = exc

View File

@ -1,3 +1,4 @@
import logging
import pandas as pd
import yfinance as yf
from stockstats import wrap
@ -5,6 +6,22 @@ from typing import Annotated
import os
from .config import get_config
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Public exception — lets callers catch stockstats/yfinance failures by type
# ---------------------------------------------------------------------------
class YFinanceError(Exception):
"""Raised when yfinance or stockstats data fetching/processing fails."""
pass
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _clean_dataframe(data: pd.DataFrame) -> pd.DataFrame:
"""Normalize a stock DataFrame for stockstats: parse dates, drop invalid rows, fill price gaps.
@ -29,6 +46,80 @@ def _clean_dataframe(data: pd.DataFrame) -> pd.DataFrame:
return df
def _load_or_fetch_ohlcv(symbol: str) -> pd.DataFrame:
"""Single authority for loading OHLCV data: cache → yfinance download → normalize.
Cache filename is always derived from today's date (15-year window) so the
cache key never goes stale. If a cached file exists but is corrupt (too few
rows to be useful), it is deleted and re-fetched rather than silently
returning bad data.
Raises:
YFinanceError: if the download returns an empty DataFrame or fails.
"""
config = get_config()
today_date = pd.Timestamp.today()
start_date = today_date - pd.DateOffset(years=15)
start_date_str = start_date.strftime("%Y-%m-%d")
end_date_str = today_date.strftime("%Y-%m-%d")
os.makedirs(config["data_cache_dir"], exist_ok=True)
data_file = os.path.join(
config["data_cache_dir"],
f"{symbol}-YFin-data-{start_date_str}-{end_date_str}.csv",
)
# ── Try to load from cache ────────────────────────────────────────────────
if os.path.exists(data_file):
try:
data = pd.read_csv(data_file) # no on_bad_lines="skip" — we want to know about corruption
except Exception as exc:
logger.warning(
"Corrupt cache file for %s (%s) — deleting and re-fetching.", symbol, exc
)
os.remove(data_file)
data = None
else:
# Validate: a 15-year daily file should have well over 100 rows
if len(data) < 50:
logger.warning(
"Cache file for %s has only %d rows — likely truncated, re-fetching.",
symbol, len(data),
)
os.remove(data_file)
data = None
else:
data = None
# ── Download from yfinance if cache miss / corrupt ────────────────────────
if data is None:
raw = yf.download(
symbol,
start=start_date_str,
end=end_date_str,
multi_level_index=False,
progress=False,
auto_adjust=True,
)
if raw.empty:
raise YFinanceError(
f"yfinance returned no data for symbol '{symbol}' "
f"({start_date_str}{end_date_str})"
)
data = raw.reset_index()
data.to_csv(data_file, index=False)
logger.debug("Downloaded and cached OHLCV for %s%s", symbol, data_file)
return data
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
class StockstatsUtils:
@staticmethod
def get_stock_stats(
@ -40,48 +131,20 @@ class StockstatsUtils:
str, "curr date for retrieving stock price data, YYYY-mm-dd"
],
):
config = get_config()
today_date = pd.Timestamp.today()
curr_date_dt = pd.to_datetime(curr_date)
end_date = today_date
start_date = today_date - pd.DateOffset(years=15)
start_date_str = start_date.strftime("%Y-%m-%d")
end_date_str = end_date.strftime("%Y-%m-%d")
# Ensure cache directory exists
os.makedirs(config["data_cache_dir"], exist_ok=True)
data_file = os.path.join(
config["data_cache_dir"],
f"{symbol}-YFin-data-{start_date_str}-{end_date_str}.csv",
)
if os.path.exists(data_file):
data = pd.read_csv(data_file, on_bad_lines="skip")
else:
data = yf.download(
symbol,
start=start_date_str,
end=end_date_str,
multi_level_index=False,
progress=False,
auto_adjust=True,
)
data = data.reset_index()
data.to_csv(data_file, index=False)
data = _clean_dataframe(data)
df = wrap(data)
df["Date"] = df["Date"].dt.strftime("%Y-%m-%d")
curr_date_str = curr_date_dt.strftime("%Y-%m-%d")
data = _load_or_fetch_ohlcv(symbol)
data = _clean_dataframe(data)
df = wrap(data)
# After wrap(), the date column becomes the datetime index (named 'date').
# Access via df.index, not df["Date"] which stockstats would try to parse as an indicator.
df[indicator] # trigger stockstats to calculate the indicator
matching_rows = df[df["Date"].str.startswith(curr_date_str)]
date_index_strs = df.index.strftime("%Y-%m-%d")
matching_rows = df[date_index_strs == curr_date_str]
if not matching_rows.empty:
indicator_value = matching_rows[indicator].values[0]
return indicator_value
return matching_rows[indicator].values[0]
else:
return "N/A: Not a trading day (weekend or holiday)"

View File

@ -1,9 +1,14 @@
import logging
from typing import Annotated
from datetime import datetime
from dateutil.relativedelta import relativedelta
import pandas as pd
import yfinance as yf
import os
from .stockstats_utils import StockstatsUtils, _clean_dataframe
from .stockstats_utils import StockstatsUtils, YFinanceError, _clean_dataframe, _load_or_fetch_ohlcv
logger = logging.getLogger(__name__)
def get_YFin_data_online(
symbol: Annotated[str, "ticker symbol of the company"],
@ -191,70 +196,32 @@ def _get_stock_stats_bulk(
) -> dict:
"""
Optimized bulk calculation of stock stats indicators.
Fetches data once and calculates indicator for all available dates.
Fetches data once (via shared _load_or_fetch_ohlcv cache) and calculates
the indicator for all available dates.
Returns dict mapping date strings to indicator values.
Raises:
YFinanceError: if data cannot be loaded or indicator calculation fails.
"""
from .config import get_config
import pandas as pd
from stockstats import wrap
import os
config = get_config()
online = config["data_vendors"]["technical_indicators"] != "local"
if not online:
# Local data path
try:
data = pd.read_csv(
os.path.join(
config.get("data_cache_dir", "data"),
f"{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
),
on_bad_lines="skip",
)
except FileNotFoundError:
raise Exception("Stockstats fail: Yahoo Finance data not fetched yet!")
else:
# Online data fetching with caching
today_date = pd.Timestamp.today()
curr_date_dt = pd.to_datetime(curr_date)
end_date = today_date
start_date = today_date - pd.DateOffset(years=15)
start_date_str = start_date.strftime("%Y-%m-%d")
end_date_str = end_date.strftime("%Y-%m-%d")
os.makedirs(config["data_cache_dir"], exist_ok=True)
data_file = os.path.join(
config["data_cache_dir"],
f"{symbol}-YFin-data-{start_date_str}-{end_date_str}.csv",
)
if os.path.exists(data_file):
data = pd.read_csv(data_file, on_bad_lines="skip")
else:
data = yf.download(
symbol,
start=start_date_str,
end=end_date_str,
multi_level_index=False,
progress=False,
auto_adjust=True,
)
data = data.reset_index()
data.to_csv(data_file, index=False)
# Single authority: _load_or_fetch_ohlcv handles both online and local modes,
# dynamic cache filename, and corpus validation — no duplicated download logic here.
data = _load_or_fetch_ohlcv(symbol)
data = _clean_dataframe(data)
df = wrap(data)
df["Date"] = df["Date"].dt.strftime("%Y-%m-%d")
# After wrap(), the date column becomes the datetime index (named 'date').
# Access via df.index, not df["Date"] which stockstats would try to parse as an indicator.
# Calculate the indicator for all rows at once
df[indicator] # This triggers stockstats to calculate the indicator
# Create a dictionary mapping date strings to indicator values
# Optimized: replaced iterrows() with vectorized operations for performance
return df.set_index("Date")[indicator].fillna("N/A").astype(str).to_dict()
# Optimized: vectorized operations for performance using correct DatetimeIndex
series = df[indicator].copy()
series.index = series.index.strftime("%Y-%m-%d")
return series.fillna("N/A").astype(str).to_dict()
def get_stockstats_indicator(
@ -264,25 +231,15 @@ def get_stockstats_indicator(
str, "The current trading date you are trading on, YYYY-mm-dd"
],
) -> str:
curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d")
curr_date = curr_date_dt.strftime("%Y-%m-%d")
try:
indicator_value = StockstatsUtils.get_stock_stats(
symbol,
indicator,
curr_date,
)
except Exception as e:
print(
f"Error getting stockstats indicator data for indicator {indicator} on {curr_date}: {e}"
)
return ""
# Raises YFinanceError on failure — caller (route_to_vendor) catches typed exceptions.
indicator_value = StockstatsUtils.get_stock_stats(symbol, indicator, curr_date)
return str(indicator_value)
def get_fundamentals(
ticker: Annotated[str, "ticker symbol of the company"],
curr_date: Annotated[str, "current date (not used for yfinance)"] = None

View File

@ -5,12 +5,14 @@ from __future__ import annotations
import asyncio
import json
import logging
import time
from concurrent.futures import ThreadPoolExecutor
from tradingagents.agents.utils.json_utils import extract_json
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Literal
from typing import Callable, Literal
logger = logging.getLogger(__name__)
@ -172,15 +174,6 @@ def run_ticker_analysis(
NOTE: TradingAgentsGraph is synchronous call this from a thread pool
when running multiple tickers concurrently (see run_all_tickers).
Args:
candidate: The stock candidate to analyse.
macro_context: Macro context to embed in the result.
config: TradingAgents configuration dict.
analysis_date: Date string in YYYY-MM-DD format.
Returns:
TickerResult with all report fields populated, or error set on failure.
"""
result = TickerResult(
ticker=candidate.ticker,
@ -189,11 +182,14 @@ def run_ticker_analysis(
analysis_date=analysis_date,
)
logger.info("Starting analysis for %s on %s", candidate.ticker, analysis_date)
t0 = time.monotonic()
logger.info(
"[%s] ▶ Starting analysis (%s, %s conviction)",
candidate.ticker, candidate.sector, candidate.conviction,
)
try:
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.observability import get_run_logger
rl = get_run_logger()
@ -210,23 +206,31 @@ def run_ticker_analysis(
result.risk_debate = str(final_state.get("risk_debate_state", ""))
result.final_trade_decision = decision
elapsed = time.monotonic() - t0
logger.info(
"Analysis complete for %s: %s", candidate.ticker, str(decision)[:120]
"[%s] ✓ Analysis complete in %.0fs — decision: %s",
candidate.ticker, elapsed, str(decision)[:80],
)
except Exception as exc:
logger.error("Analysis failed for %s: %s", candidate.ticker, exc, exc_info=True)
elapsed = time.monotonic() - t0
logger.error(
"[%s] ✗ Analysis FAILED after %.0fs: %s",
candidate.ticker, elapsed, exc, exc_info=True,
)
result.error = str(exc)
return result
async def run_all_tickers(
candidates: list[StockCandidate],
macro_context: MacroContext,
config: dict,
analysis_date: str,
max_concurrent: int = 2,
on_ticker_done: Callable[[TickerResult, int, int], None] | None = None,
) -> list[TickerResult]:
"""Run TradingAgents for every candidate with controlled concurrency.
@ -239,28 +243,45 @@ async def run_all_tickers(
config: TradingAgents configuration dict.
analysis_date: Date string in YYYY-MM-DD format.
max_concurrent: Maximum number of tickers to process in parallel.
on_ticker_done: Optional callback(result, done_count, total_count) fired
after each ticker finishes use this to drive a progress bar.
Returns:
List of TickerResult in completion order.
"""
loop = asyncio.get_running_loop()
executor = ThreadPoolExecutor(max_workers=max_concurrent)
try:
tasks = [
loop.run_in_executor(
executor,
total = len(candidates)
results: list[TickerResult] = []
# Use a semaphore so at most max_concurrent tickers run simultaneously,
# but we still get individual completion callbacks via as_completed.
semaphore = asyncio.Semaphore(max_concurrent)
async def _run_one(candidate: StockCandidate) -> TickerResult:
async with semaphore:
return await loop.run_in_executor(
None, # use default ThreadPoolExecutor
run_ticker_analysis,
c,
candidate,
macro_context,
config,
analysis_date,
)
for c in candidates
]
results = await asyncio.gather(*tasks)
return list(results)
finally:
executor.shutdown(wait=False)
tasks = [asyncio.create_task(_run_one(c)) for c in candidates]
done_count = 0
for coro in asyncio.as_completed(tasks):
result = await coro
done_count += 1
results.append(result)
if on_ticker_done is not None:
try:
on_ticker_done(result, done_count, total)
except Exception: # never let a callback crash the pipeline
pass
return results
# ─── Reporting ────────────────────────────────────────────────────────────────