Merge pull request #85 from aguzererler/fix/dataflows-incident-hardened-error-handling
fix: harden dataflows layer against silent failures and data corruption
This commit is contained in:
commit
cef65d922d
61
cli/main.py
61
cli/main.py
|
|
@ -1579,17 +1579,62 @@ def run_pipeline(
|
||||||
|
|
||||||
console.print(
|
console.print(
|
||||||
f"\n[cyan]Running TradingAgents for {len(candidates)} tickers...[/cyan]"
|
f"\n[cyan]Running TradingAgents for {len(candidates)} tickers...[/cyan]"
|
||||||
|
f" [dim](up to 2 concurrent)[/dim]\n"
|
||||||
)
|
)
|
||||||
try:
|
for c in candidates:
|
||||||
with Live(
|
console.print(
|
||||||
Spinner("dots", text="Analyzing..."), console=console, transient=True
|
f" [dim]▷ Queued:[/dim] [bold cyan]{c.ticker}[/bold cyan]"
|
||||||
):
|
f" [dim]{c.sector} · {c.conviction.upper()} conviction[/dim]"
|
||||||
|
)
|
||||||
|
console.print()
|
||||||
|
import time as _time
|
||||||
|
from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeElapsedColumn
|
||||||
|
|
||||||
|
pipeline_start = _time.monotonic()
|
||||||
|
|
||||||
|
with Progress(
|
||||||
|
SpinnerColumn(),
|
||||||
|
TextColumn("[progress.description]{task.description}"),
|
||||||
|
BarColumn(),
|
||||||
|
TextColumn("[cyan]{task.completed}/{task.total}[/cyan]"),
|
||||||
|
TimeElapsedColumn(),
|
||||||
|
console=console,
|
||||||
|
transient=False,
|
||||||
|
) as progress:
|
||||||
|
overall = progress.add_task("[bold]Pipeline progress[/bold]", total=len(candidates))
|
||||||
|
|
||||||
|
def on_done(result, done_count, total_count):
|
||||||
|
ticker_elapsed = _time.monotonic() - pipeline_start
|
||||||
|
if result.error:
|
||||||
|
console.print(
|
||||||
|
f" [red]✗ {result.ticker}[/red]"
|
||||||
|
f" [dim]failed ({ticker_elapsed:.0f}s elapsed) — {result.error[:80]}[/dim]"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
decision_preview = str(result.final_trade_decision)[:70].replace("\n", " ")
|
||||||
|
console.print(
|
||||||
|
f" [green]✓ {result.ticker}[/green]"
|
||||||
|
f" [dim]({done_count}/{total_count}, {ticker_elapsed:.0f}s elapsed)[/dim]"
|
||||||
|
f" → {decision_preview}"
|
||||||
|
)
|
||||||
|
progress.advance(overall)
|
||||||
|
|
||||||
|
try:
|
||||||
results = asyncio.run(
|
results = asyncio.run(
|
||||||
run_all_tickers(candidates, macro_context, config, analysis_date)
|
run_all_tickers(
|
||||||
|
candidates, macro_context, config, analysis_date,
|
||||||
|
on_ticker_done=on_done,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
console.print(f"[red]Pipeline failed: {e}[/red]")
|
console.print(f"[red]Pipeline failed: {e}[/red]")
|
||||||
raise typer.Exit(1)
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
elapsed_total = _time.monotonic() - pipeline_start
|
||||||
|
console.print(
|
||||||
|
f"\n[bold green]All {len(candidates)} ticker(s) finished in {elapsed_total:.0f}s[/bold green]\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
save_results(results, macro_context, output_dir)
|
save_results(results, macro_context, output_dir)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,215 @@
|
||||||
|
"""Live-data integration tests for the stockstats utilities.
|
||||||
|
|
||||||
|
These tests call the real yfinance API and therefore require network access.
|
||||||
|
They are marked with ``integration`` and are excluded from the default test
|
||||||
|
run (which uses ``--ignore=tests/integration``).
|
||||||
|
|
||||||
|
Run them explicitly with:
|
||||||
|
|
||||||
|
python -m pytest tests/integration/test_stockstats_live.py -v --override-ini="addopts="
|
||||||
|
|
||||||
|
The tests validate the fix for the "Invalid number of return arguments after
|
||||||
|
parsing column name: 'Date'" error that occurred because stockstats.wrap()
|
||||||
|
promotes the lowercase ``date`` column to the DataFrame index, so the old
|
||||||
|
``df["Date"]`` access caused stockstats to try to parse "Date" as an indicator
|
||||||
|
name. The fix uses ``df.index.strftime("%Y-%m-%d")`` instead.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
|
# A well-known trading day we can use for assertions
|
||||||
|
_TEST_DATE = "2025-01-02"
|
||||||
|
_TEST_TICKER = "AAPL"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# StockstatsUtils.get_stock_stats
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestStockstatsUtilsLive:
|
||||||
|
"""Live tests for StockstatsUtils.get_stock_stats against real yfinance data."""
|
||||||
|
|
||||||
|
def test_close_50_sma_returns_numeric(self):
|
||||||
|
"""close_50_sma indicator returns a numeric value for a known trading day."""
|
||||||
|
from tradingagents.dataflows.stockstats_utils import StockstatsUtils
|
||||||
|
|
||||||
|
result = StockstatsUtils.get_stock_stats(_TEST_TICKER, "close_50_sma", _TEST_DATE)
|
||||||
|
|
||||||
|
assert result != "N/A: Not a trading day (weekend or holiday)", (
|
||||||
|
f"Expected a numeric value for {_TEST_DATE}, got N/A (check if it's a holiday)"
|
||||||
|
)
|
||||||
|
# Should be a finite float-like value
|
||||||
|
assert float(result) > 0, f"close_50_sma should be positive, got: {result}"
|
||||||
|
|
||||||
|
def test_rsi_returns_value_in_valid_range(self):
|
||||||
|
"""RSI indicator returns a value in [0, 100] for a known trading day."""
|
||||||
|
from tradingagents.dataflows.stockstats_utils import StockstatsUtils
|
||||||
|
|
||||||
|
result = StockstatsUtils.get_stock_stats(_TEST_TICKER, "rsi", _TEST_DATE)
|
||||||
|
|
||||||
|
assert result != "N/A: Not a trading day (weekend or holiday)", (
|
||||||
|
f"Expected numeric RSI for {_TEST_DATE}"
|
||||||
|
)
|
||||||
|
rsi = float(result)
|
||||||
|
assert 0.0 <= rsi <= 100.0, f"RSI must be in [0, 100], got: {rsi}"
|
||||||
|
|
||||||
|
def test_macd_returns_numeric(self):
|
||||||
|
"""MACD indicator returns a numeric value for a known trading day."""
|
||||||
|
from tradingagents.dataflows.stockstats_utils import StockstatsUtils
|
||||||
|
|
||||||
|
result = StockstatsUtils.get_stock_stats(_TEST_TICKER, "macd", _TEST_DATE)
|
||||||
|
|
||||||
|
assert result != "N/A: Not a trading day (weekend or holiday)"
|
||||||
|
# MACD can be positive or negative — just confirm it's a valid float
|
||||||
|
float(result) # raises ValueError if not numeric
|
||||||
|
|
||||||
|
def test_weekend_returns_na(self):
|
||||||
|
"""A weekend date returns the N/A holiday/weekend message."""
|
||||||
|
from tradingagents.dataflows.stockstats_utils import StockstatsUtils
|
||||||
|
|
||||||
|
# 2025-01-04 is a Saturday
|
||||||
|
result = StockstatsUtils.get_stock_stats(_TEST_TICKER, "close_50_sma", "2025-01-04")
|
||||||
|
|
||||||
|
assert result == "N/A: Not a trading day (weekend or holiday)", (
|
||||||
|
f"Expected N/A for Saturday 2025-01-04, got: {result}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_no_date_column_error(self):
|
||||||
|
"""Calling get_stock_stats must NOT raise the 'Date' column parsing error."""
|
||||||
|
from tradingagents.dataflows.stockstats_utils import StockstatsUtils
|
||||||
|
|
||||||
|
# This previously raised: Invalid number of return arguments after
|
||||||
|
# parsing column name: 'Date'
|
||||||
|
try:
|
||||||
|
StockstatsUtils.get_stock_stats(_TEST_TICKER, "close_50_sma", _TEST_DATE)
|
||||||
|
except Exception as e:
|
||||||
|
if "Invalid number of return arguments" in str(e) and "Date" in str(e):
|
||||||
|
pytest.fail(
|
||||||
|
"Regression: stockstats is still trying to parse 'Date' as an "
|
||||||
|
f"indicator. Error: {e}"
|
||||||
|
)
|
||||||
|
raise # re-raise unexpected errors
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _get_stock_stats_bulk
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestGetStockStatsBulkLive:
|
||||||
|
"""Live tests for _get_stock_stats_bulk against real yfinance data."""
|
||||||
|
|
||||||
|
def test_returns_dict_with_date_keys(self):
|
||||||
|
"""Bulk method returns a non-empty dict with YYYY-MM-DD string keys."""
|
||||||
|
from tradingagents.dataflows.y_finance import _get_stock_stats_bulk
|
||||||
|
|
||||||
|
result = _get_stock_stats_bulk(_TEST_TICKER, "rsi", _TEST_DATE)
|
||||||
|
|
||||||
|
assert isinstance(result, dict), "Expected dict from _get_stock_stats_bulk"
|
||||||
|
assert len(result) > 0, "Expected non-empty result dict"
|
||||||
|
|
||||||
|
# Keys should all be YYYY-MM-DD strings
|
||||||
|
for key in list(result.keys())[:5]:
|
||||||
|
pd.Timestamp(key) # raises if not parseable
|
||||||
|
|
||||||
|
def test_trading_day_has_numeric_value(self):
|
||||||
|
"""A known trading day has a numeric (non-N/A) value in the bulk result."""
|
||||||
|
from tradingagents.dataflows.y_finance import _get_stock_stats_bulk
|
||||||
|
|
||||||
|
result = _get_stock_stats_bulk(_TEST_TICKER, "rsi", _TEST_DATE)
|
||||||
|
|
||||||
|
assert _TEST_DATE in result, (
|
||||||
|
f"Expected {_TEST_DATE} in bulk result dict. Keys sample: {list(result.keys())[:5]}"
|
||||||
|
)
|
||||||
|
value = result[_TEST_DATE]
|
||||||
|
assert value != "N/A", (
|
||||||
|
f"Expected numeric RSI for {_TEST_DATE}, got N/A (check if it's a holiday)"
|
||||||
|
)
|
||||||
|
float(value) # should be convertible to float
|
||||||
|
|
||||||
|
def test_no_date_column_parsing_error(self):
|
||||||
|
"""Bulk method must not raise the 'Date' column parsing error (regression guard)."""
|
||||||
|
from tradingagents.dataflows.y_finance import _get_stock_stats_bulk
|
||||||
|
|
||||||
|
try:
|
||||||
|
_get_stock_stats_bulk(_TEST_TICKER, "close_50_sma", _TEST_DATE)
|
||||||
|
except Exception as e:
|
||||||
|
if "Invalid number of return arguments" in str(e) and "Date" in str(e):
|
||||||
|
pytest.fail(
|
||||||
|
"Regression: _get_stock_stats_bulk still hits the 'Date' indicator "
|
||||||
|
f"parsing error. Error: {e}"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def test_multiple_indicators_all_work(self):
|
||||||
|
"""All supported indicators can be computed without error."""
|
||||||
|
from tradingagents.dataflows.y_finance import _get_stock_stats_bulk
|
||||||
|
|
||||||
|
indicators = [
|
||||||
|
"close_50_sma",
|
||||||
|
"close_200_sma",
|
||||||
|
"close_10_ema",
|
||||||
|
"macd",
|
||||||
|
"macds",
|
||||||
|
"macdh",
|
||||||
|
"rsi",
|
||||||
|
"boll",
|
||||||
|
"boll_ub",
|
||||||
|
"boll_lb",
|
||||||
|
"atr",
|
||||||
|
]
|
||||||
|
|
||||||
|
for indicator in indicators:
|
||||||
|
try:
|
||||||
|
result = _get_stock_stats_bulk(_TEST_TICKER, indicator, _TEST_DATE)
|
||||||
|
assert isinstance(result, dict), f"{indicator}: expected dict"
|
||||||
|
assert len(result) > 0, f"{indicator}: expected non-empty dict"
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Indicator '{indicator}' raised an unexpected error: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# get_stock_stats_indicators_window (end-to-end with live data)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestGetStockStatsIndicatorsWindowLive:
|
||||||
|
"""Live end-to-end tests for get_stock_stats_indicators_window."""
|
||||||
|
|
||||||
|
def test_rsi_window_returns_formatted_string(self):
|
||||||
|
"""Window function returns a multi-line string with RSI values over a date range."""
|
||||||
|
from tradingagents.dataflows.y_finance import get_stock_stats_indicators_window
|
||||||
|
|
||||||
|
result = get_stock_stats_indicators_window(_TEST_TICKER, "rsi", _TEST_DATE, look_back_days=5)
|
||||||
|
|
||||||
|
assert isinstance(result, str)
|
||||||
|
assert "rsi" in result.lower()
|
||||||
|
assert _TEST_DATE in result
|
||||||
|
# Should have date: value lines
|
||||||
|
lines = [l for l in result.split("\n") if ":" in l and "-" in l]
|
||||||
|
assert len(lines) > 0, "Expected date:value lines in result"
|
||||||
|
|
||||||
|
def test_close_50_sma_window_contains_numeric_values(self):
|
||||||
|
"""50-day SMA window result contains actual numeric price values."""
|
||||||
|
from tradingagents.dataflows.y_finance import get_stock_stats_indicators_window
|
||||||
|
|
||||||
|
result = get_stock_stats_indicators_window(
|
||||||
|
_TEST_TICKER, "close_50_sma", _TEST_DATE, look_back_days=10
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, str)
|
||||||
|
# At least some lines should have numeric values (not all N/A)
|
||||||
|
value_lines = [l for l in result.split("\n") if ":" in l and l.strip().startswith("20")]
|
||||||
|
numeric_values = []
|
||||||
|
for line in value_lines:
|
||||||
|
try:
|
||||||
|
val = line.split(":", 1)[1].strip()
|
||||||
|
numeric_values.append(float(val))
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
pass # N/A lines are expected for weekends
|
||||||
|
|
||||||
|
assert len(numeric_values) > 0, (
|
||||||
|
"Expected at least some numeric 50-SMA values in the 10-day window"
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,259 @@
|
||||||
|
"""Unit tests for the incident-fix improvements across the dataflows layer.
|
||||||
|
|
||||||
|
Tests cover:
|
||||||
|
1. _load_or_fetch_ohlcv — dynamic cache filename, corruption detection + re-fetch
|
||||||
|
2. YFinanceError — propagated by get_stockstats_indicator (no more silent return "")
|
||||||
|
3. _filter_csv_by_date_range — explicit date column discovery (no positional assumption)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
import pandas as pd
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _minimal_ohlcv_df(periods: int = 200) -> pd.DataFrame:
|
||||||
|
"""Return a minimal valid OHLCV DataFrame with a Date column."""
|
||||||
|
idx = pd.date_range("2024-01-02", periods=periods, freq="B")
|
||||||
|
return pd.DataFrame(
|
||||||
|
{
|
||||||
|
"Date": idx.strftime("%Y-%m-%d"),
|
||||||
|
"Open": [100.0] * periods,
|
||||||
|
"High": [105.0] * periods,
|
||||||
|
"Low": [95.0] * periods,
|
||||||
|
"Close": [102.0] * periods,
|
||||||
|
"Volume": [1_000_000] * periods,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _load_or_fetch_ohlcv — cache + download logic
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestLoadOrFetchOhlcv:
|
||||||
|
"""Tests for the unified OHLCV loader."""
|
||||||
|
|
||||||
|
def test_downloads_and_caches_on_first_call(self, tmp_path):
|
||||||
|
"""When no cache exists, yfinance is called and data is written to cache."""
|
||||||
|
from tradingagents.dataflows.stockstats_utils import _load_or_fetch_ohlcv
|
||||||
|
|
||||||
|
expected_df = _minimal_ohlcv_df()
|
||||||
|
mock_downloaded = expected_df.copy()
|
||||||
|
mock_downloaded.index = pd.RangeIndex(len(mock_downloaded)) # simulate reset_index output
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("tradingagents.dataflows.stockstats_utils.get_config",
|
||||||
|
return_value={"data_cache_dir": str(tmp_path), "data_vendors": {"technical_indicators": "yfinance"}}),
|
||||||
|
patch("tradingagents.dataflows.stockstats_utils.yf.download",
|
||||||
|
return_value=expected_df.set_index("Date")) as mock_dl,
|
||||||
|
):
|
||||||
|
result = _load_or_fetch_ohlcv("AAPL")
|
||||||
|
mock_dl.assert_called_once()
|
||||||
|
|
||||||
|
# Cache file must exist after the call
|
||||||
|
csv_files = list(tmp_path.glob("AAPL-YFin-data-*.csv"))
|
||||||
|
assert len(csv_files) == 1, "Expected exactly one cache file to be created"
|
||||||
|
|
||||||
|
def test_uses_cache_on_second_call(self, tmp_path):
|
||||||
|
"""When cache already exists, yfinance.download is NOT called again."""
|
||||||
|
from tradingagents.dataflows.stockstats_utils import _load_or_fetch_ohlcv
|
||||||
|
|
||||||
|
# Write a valid cache file manually
|
||||||
|
df = _minimal_ohlcv_df()
|
||||||
|
today = pd.Timestamp.today()
|
||||||
|
start = (today - pd.DateOffset(years=15)).strftime("%Y-%m-%d")
|
||||||
|
end = today.strftime("%Y-%m-%d")
|
||||||
|
cache_file = tmp_path / f"AAPL-YFin-data-{start}-{end}.csv"
|
||||||
|
df.to_csv(cache_file, index=False)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("tradingagents.dataflows.stockstats_utils.get_config",
|
||||||
|
return_value={"data_cache_dir": str(tmp_path), "data_vendors": {"technical_indicators": "yfinance"}}),
|
||||||
|
patch("tradingagents.dataflows.stockstats_utils.yf.download") as mock_dl,
|
||||||
|
):
|
||||||
|
result = _load_or_fetch_ohlcv("AAPL")
|
||||||
|
mock_dl.assert_not_called()
|
||||||
|
|
||||||
|
assert len(result) == 200
|
||||||
|
|
||||||
|
def test_corrupt_cache_is_deleted_and_refetched(self, tmp_path):
|
||||||
|
"""A corrupt (unparseable) cache file is deleted and yfinance is called again."""
|
||||||
|
from tradingagents.dataflows.stockstats_utils import _load_or_fetch_ohlcv
|
||||||
|
|
||||||
|
today = pd.Timestamp.today()
|
||||||
|
start = (today - pd.DateOffset(years=15)).strftime("%Y-%m-%d")
|
||||||
|
end = today.strftime("%Y-%m-%d")
|
||||||
|
cache_file = tmp_path / f"AAPL-YFin-data-{start}-{end}.csv"
|
||||||
|
cache_file.write_text(",,,,CORRUPT,,\x00\x00BINARY GARBAGE")
|
||||||
|
|
||||||
|
fresh_df = _minimal_ohlcv_df()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("tradingagents.dataflows.stockstats_utils.get_config",
|
||||||
|
return_value={"data_cache_dir": str(tmp_path), "data_vendors": {"technical_indicators": "yfinance"}}),
|
||||||
|
patch("tradingagents.dataflows.stockstats_utils.yf.download",
|
||||||
|
return_value=fresh_df.set_index("Date")) as mock_dl,
|
||||||
|
):
|
||||||
|
result = _load_or_fetch_ohlcv("AAPL")
|
||||||
|
mock_dl.assert_called_once()
|
||||||
|
|
||||||
|
def test_truncated_cache_triggers_refetch(self, tmp_path):
|
||||||
|
"""A cache file with fewer than 50 rows is treated as truncated and re-fetched."""
|
||||||
|
from tradingagents.dataflows.stockstats_utils import _load_or_fetch_ohlcv
|
||||||
|
|
||||||
|
tiny_df = _minimal_ohlcv_df(periods=10) # only 10 rows — well below threshold
|
||||||
|
today = pd.Timestamp.today()
|
||||||
|
start = (today - pd.DateOffset(years=15)).strftime("%Y-%m-%d")
|
||||||
|
end = today.strftime("%Y-%m-%d")
|
||||||
|
cache_file = tmp_path / f"AAPL-YFin-data-{start}-{end}.csv"
|
||||||
|
tiny_df.to_csv(cache_file, index=False)
|
||||||
|
|
||||||
|
fresh_df = _minimal_ohlcv_df()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("tradingagents.dataflows.stockstats_utils.get_config",
|
||||||
|
return_value={"data_cache_dir": str(tmp_path), "data_vendors": {"technical_indicators": "yfinance"}}),
|
||||||
|
patch("tradingagents.dataflows.stockstats_utils.yf.download",
|
||||||
|
return_value=fresh_df.set_index("Date")) as mock_dl,
|
||||||
|
):
|
||||||
|
result = _load_or_fetch_ohlcv("AAPL")
|
||||||
|
mock_dl.assert_called_once()
|
||||||
|
|
||||||
|
def test_empty_download_raises_yfinance_error(self, tmp_path):
|
||||||
|
"""An empty DataFrame from yfinance raises YFinanceError (not a silent return)."""
|
||||||
|
from tradingagents.dataflows.stockstats_utils import _load_or_fetch_ohlcv, YFinanceError
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("tradingagents.dataflows.stockstats_utils.get_config",
|
||||||
|
return_value={"data_cache_dir": str(tmp_path), "data_vendors": {"technical_indicators": "yfinance"}}),
|
||||||
|
patch("tradingagents.dataflows.stockstats_utils.yf.download",
|
||||||
|
return_value=pd.DataFrame()),
|
||||||
|
):
|
||||||
|
with pytest.raises(YFinanceError, match="no data"):
|
||||||
|
_load_or_fetch_ohlcv("INVALID_TICKER_XYZ")
|
||||||
|
|
||||||
|
def test_cache_filename_is_dynamic_not_hardcoded(self, tmp_path):
|
||||||
|
"""Cache filename contains today's date (not a hardcoded historical date like 2025-03-25)."""
|
||||||
|
from tradingagents.dataflows.stockstats_utils import _load_or_fetch_ohlcv
|
||||||
|
|
||||||
|
df = _minimal_ohlcv_df()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("tradingagents.dataflows.stockstats_utils.get_config",
|
||||||
|
return_value={"data_cache_dir": str(tmp_path), "data_vendors": {"technical_indicators": "yfinance"}}),
|
||||||
|
patch("tradingagents.dataflows.stockstats_utils.yf.download",
|
||||||
|
return_value=df.set_index("Date")),
|
||||||
|
):
|
||||||
|
_load_or_fetch_ohlcv("AAPL")
|
||||||
|
|
||||||
|
csv_files = list(tmp_path.glob("AAPL-YFin-data-*.csv"))
|
||||||
|
assert len(csv_files) == 1
|
||||||
|
filename = csv_files[0].name
|
||||||
|
# Must NOT contain the old hardcoded stale date
|
||||||
|
assert "2025-03-25" not in filename, (
|
||||||
|
f"Cache filename contains the old hardcoded stale date! Got: {filename}"
|
||||||
|
)
|
||||||
|
# Must contain today's year
|
||||||
|
today_year = str(pd.Timestamp.today().year)
|
||||||
|
assert today_year in filename, f"Expected today's year {today_year} in filename: {filename}"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# YFinanceError propagation (no more silent return "")
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestYFinanceErrorPropagation:
|
||||||
|
"""Tests that YFinanceError is raised (not swallowed) by get_stockstats_indicator."""
|
||||||
|
|
||||||
|
def test_get_stockstats_indicator_raises_yfinance_error_on_failure(self, tmp_path):
|
||||||
|
"""get_stockstats_indicator raises YFinanceError when yfinance returns empty data."""
|
||||||
|
from tradingagents.dataflows.stockstats_utils import YFinanceError
|
||||||
|
from tradingagents.dataflows.y_finance import get_stockstats_indicator
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("tradingagents.dataflows.stockstats_utils.get_config",
|
||||||
|
return_value={"data_cache_dir": str(tmp_path), "data_vendors": {"technical_indicators": "yfinance"}}),
|
||||||
|
patch("tradingagents.dataflows.stockstats_utils.yf.download",
|
||||||
|
return_value=pd.DataFrame()),
|
||||||
|
):
|
||||||
|
with pytest.raises(YFinanceError):
|
||||||
|
get_stockstats_indicator("INVALID", "rsi", "2025-01-02")
|
||||||
|
|
||||||
|
def test_yfinance_error_is_not_swallowed_as_empty_string(self, tmp_path):
|
||||||
|
"""Regression test: get_stockstats_indicator must NOT return empty string on error."""
|
||||||
|
from tradingagents.dataflows.stockstats_utils import YFinanceError
|
||||||
|
from tradingagents.dataflows.y_finance import get_stockstats_indicator
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("tradingagents.dataflows.stockstats_utils.get_config",
|
||||||
|
return_value={"data_cache_dir": str(tmp_path), "data_vendors": {"technical_indicators": "yfinance"}}),
|
||||||
|
patch("tradingagents.dataflows.stockstats_utils.yf.download",
|
||||||
|
return_value=pd.DataFrame()),
|
||||||
|
):
|
||||||
|
result = None
|
||||||
|
try:
|
||||||
|
result = get_stockstats_indicator("INVALID", "rsi", "2025-01-02")
|
||||||
|
except YFinanceError:
|
||||||
|
pass # This is the correct behaviour
|
||||||
|
|
||||||
|
assert result is None, (
|
||||||
|
"get_stockstats_indicator should raise YFinanceError, not silently return a value. "
|
||||||
|
f"Got: {result!r}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _filter_csv_by_date_range — explicit column discovery
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestFilterCsvByDateRange:
|
||||||
|
"""Tests for the fixed _filter_csv_by_date_range in alpha_vantage_common."""
|
||||||
|
|
||||||
|
def _make_av_csv(self, date_col_name: str = "time") -> str:
|
||||||
|
return (
|
||||||
|
f"{date_col_name},SMA\n"
|
||||||
|
"2024-01-02,230.5\n"
|
||||||
|
"2024-01-03,231.0\n"
|
||||||
|
"2024-01-08,235.5\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_filters_with_standard_time_column(self):
|
||||||
|
"""Standard Alpha Vantage CSV with 'time' header filters correctly."""
|
||||||
|
from tradingagents.dataflows.alpha_vantage_common import _filter_csv_by_date_range
|
||||||
|
|
||||||
|
result = _filter_csv_by_date_range(self._make_av_csv("time"), "2024-01-03", "2024-01-08")
|
||||||
|
assert "2024-01-02" not in result
|
||||||
|
assert "2024-01-03" in result
|
||||||
|
assert "2024-01-08" in result
|
||||||
|
|
||||||
|
def test_filters_with_timestamp_column(self):
|
||||||
|
"""CSV with 'timestamp' header (alternative AV format) also works."""
|
||||||
|
from tradingagents.dataflows.alpha_vantage_common import _filter_csv_by_date_range
|
||||||
|
|
||||||
|
result = _filter_csv_by_date_range(self._make_av_csv("timestamp"), "2024-01-03", "2024-01-08")
|
||||||
|
assert "2024-01-02" not in result
|
||||||
|
assert "2024-01-03" in result
|
||||||
|
|
||||||
|
def test_missing_date_column_raises_not_silently_filters_wrong_column(self):
|
||||||
|
"""When no recognised date column exists, raises ValueError immediately.
|
||||||
|
Previously it would silently use df.columns[0] and return garbage."""
|
||||||
|
from tradingagents.dataflows.alpha_vantage_common import _filter_csv_by_date_range
|
||||||
|
|
||||||
|
bad_csv = "price,volume\n102.0,1000\n103.0,2000\n"
|
||||||
|
|
||||||
|
# The fixed code raises ValueError; the old code would silently try to
|
||||||
|
# parse the 'price' column as a date and return the original data.
|
||||||
|
with pytest.raises(ValueError, match="Date column not found"):
|
||||||
|
_filter_csv_by_date_range(bad_csv, "2024-01-01", "2024-01-31")
|
||||||
|
|
||||||
|
def test_empty_csv_returns_empty(self):
|
||||||
|
"""Empty input returns empty output without error."""
|
||||||
|
from tradingagents.dataflows.alpha_vantage_common import _filter_csv_by_date_range
|
||||||
|
|
||||||
|
result = _filter_csv_by_date_range("", "2024-01-01", "2024-01-31")
|
||||||
|
assert result == ""
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import requests
|
import requests
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
@ -7,6 +8,9 @@ import time as _time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
API_BASE_URL = "https://www.alphavantage.co/query"
|
API_BASE_URL = "https://www.alphavantage.co/query"
|
||||||
|
|
||||||
def get_api_key() -> str:
|
def get_api_key() -> str:
|
||||||
|
|
@ -228,8 +232,15 @@ def _filter_csv_by_date_range(csv_data: str, start_date: str, end_date: str) ->
|
||||||
# Parse CSV data
|
# Parse CSV data
|
||||||
df = pd.read_csv(StringIO(csv_data))
|
df = pd.read_csv(StringIO(csv_data))
|
||||||
|
|
||||||
# Assume the first column is the date column (timestamp)
|
# Find the date column by name rather than positional assumption.
|
||||||
date_col = df.columns[0]
|
# Alpha Vantage returns 'time' or 'timestamp' as the date column header.
|
||||||
|
_date_candidates = [c for c in df.columns if c.lower() in ("time", "timestamp", "date")]
|
||||||
|
if not _date_candidates:
|
||||||
|
raise ValueError(
|
||||||
|
f"Date column not found in Alpha Vantage CSV. "
|
||||||
|
f"Expected 'time' or 'timestamp', got columns: {list(df.columns)}"
|
||||||
|
)
|
||||||
|
date_col = _date_candidates[0]
|
||||||
df[date_col] = pd.to_datetime(df[date_col])
|
df[date_col] = pd.to_datetime(df[date_col])
|
||||||
|
|
||||||
# Filter by date range
|
# Filter by date range
|
||||||
|
|
@ -241,7 +252,11 @@ def _filter_csv_by_date_range(csv_data: str, start_date: str, end_date: str) ->
|
||||||
# Convert back to CSV string
|
# Convert back to CSV string
|
||||||
return filtered_df.to_csv(index=False)
|
return filtered_df.to_csv(index=False)
|
||||||
|
|
||||||
|
except ValueError:
|
||||||
|
# ValueError = a programming/data-contract error (e.g. missing date column).
|
||||||
|
# Re-raise so callers see it immediately rather than getting silently wrong data.
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# If filtering fails, return original data with a warning
|
# Transient errors (I/O, malformed CSV): log and return original data as-is.
|
||||||
print(f"Warning: Failed to filter CSV data by date range: {e}")
|
logger.warning("Failed to filter CSV data by date range: %s", e)
|
||||||
return csv_data
|
return csv_data
|
||||||
|
|
|
||||||
|
|
@ -40,6 +40,7 @@ from .alpha_vantage_scanner import (
|
||||||
)
|
)
|
||||||
from .alpha_vantage_common import AlphaVantageError, AlphaVantageRateLimitError, RateLimitError
|
from .alpha_vantage_common import AlphaVantageError, AlphaVantageRateLimitError, RateLimitError
|
||||||
from .finnhub_common import FinnhubError
|
from .finnhub_common import FinnhubError
|
||||||
|
from .stockstats_utils import YFinanceError
|
||||||
from .finnhub_news import get_insider_transactions as get_finnhub_insider_transactions
|
from .finnhub_news import get_insider_transactions as get_finnhub_insider_transactions
|
||||||
from .finnhub_scanner import (
|
from .finnhub_scanner import (
|
||||||
get_market_indices_finnhub,
|
get_market_indices_finnhub,
|
||||||
|
|
@ -262,7 +263,7 @@ def route_to_vendor(method: str, *args, **kwargs):
|
||||||
if rl:
|
if rl:
|
||||||
rl.log_vendor_call(method, vendor, True, (time.time() - t0) * 1000, args_summary=args_summary)
|
rl.log_vendor_call(method, vendor, True, (time.time() - t0) * 1000, args_summary=args_summary)
|
||||||
return result
|
return result
|
||||||
except (AlphaVantageError, FinnhubError, ConnectionError, TimeoutError) as exc:
|
except (AlphaVantageError, FinnhubError, YFinanceError, ConnectionError, TimeoutError) as exc:
|
||||||
if rl:
|
if rl:
|
||||||
rl.log_vendor_call(method, vendor, False, (time.time() - t0) * 1000, error=str(exc)[:200], args_summary=args_summary)
|
rl.log_vendor_call(method, vendor, False, (time.time() - t0) * 1000, error=str(exc)[:200], args_summary=args_summary)
|
||||||
last_error = exc
|
last_error = exc
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import logging
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import yfinance as yf
|
import yfinance as yf
|
||||||
from stockstats import wrap
|
from stockstats import wrap
|
||||||
|
|
@ -5,6 +6,22 @@ from typing import Annotated
|
||||||
import os
|
import os
|
||||||
from .config import get_config
|
from .config import get_config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Public exception — lets callers catch stockstats/yfinance failures by type
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class YFinanceError(Exception):
|
||||||
|
"""Raised when yfinance or stockstats data fetching/processing fails."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Internal helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _clean_dataframe(data: pd.DataFrame) -> pd.DataFrame:
|
def _clean_dataframe(data: pd.DataFrame) -> pd.DataFrame:
|
||||||
"""Normalize a stock DataFrame for stockstats: parse dates, drop invalid rows, fill price gaps.
|
"""Normalize a stock DataFrame for stockstats: parse dates, drop invalid rows, fill price gaps.
|
||||||
|
|
@ -29,6 +46,80 @@ def _clean_dataframe(data: pd.DataFrame) -> pd.DataFrame:
|
||||||
return df
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def _load_or_fetch_ohlcv(symbol: str) -> pd.DataFrame:
|
||||||
|
"""Single authority for loading OHLCV data: cache → yfinance download → normalize.
|
||||||
|
|
||||||
|
Cache filename is always derived from today's date (15-year window) so the
|
||||||
|
cache key never goes stale. If a cached file exists but is corrupt (too few
|
||||||
|
rows to be useful), it is deleted and re-fetched rather than silently
|
||||||
|
returning bad data.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
YFinanceError: if the download returns an empty DataFrame or fails.
|
||||||
|
"""
|
||||||
|
config = get_config()
|
||||||
|
|
||||||
|
today_date = pd.Timestamp.today()
|
||||||
|
start_date = today_date - pd.DateOffset(years=15)
|
||||||
|
start_date_str = start_date.strftime("%Y-%m-%d")
|
||||||
|
end_date_str = today_date.strftime("%Y-%m-%d")
|
||||||
|
|
||||||
|
os.makedirs(config["data_cache_dir"], exist_ok=True)
|
||||||
|
|
||||||
|
data_file = os.path.join(
|
||||||
|
config["data_cache_dir"],
|
||||||
|
f"{symbol}-YFin-data-{start_date_str}-{end_date_str}.csv",
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Try to load from cache ────────────────────────────────────────────────
|
||||||
|
if os.path.exists(data_file):
|
||||||
|
try:
|
||||||
|
data = pd.read_csv(data_file) # no on_bad_lines="skip" — we want to know about corruption
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"Corrupt cache file for %s (%s) — deleting and re-fetching.", symbol, exc
|
||||||
|
)
|
||||||
|
os.remove(data_file)
|
||||||
|
data = None
|
||||||
|
else:
|
||||||
|
# Validate: a 15-year daily file should have well over 100 rows
|
||||||
|
if len(data) < 50:
|
||||||
|
logger.warning(
|
||||||
|
"Cache file for %s has only %d rows — likely truncated, re-fetching.",
|
||||||
|
symbol, len(data),
|
||||||
|
)
|
||||||
|
os.remove(data_file)
|
||||||
|
data = None
|
||||||
|
else:
|
||||||
|
data = None
|
||||||
|
|
||||||
|
# ── Download from yfinance if cache miss / corrupt ────────────────────────
|
||||||
|
if data is None:
|
||||||
|
raw = yf.download(
|
||||||
|
symbol,
|
||||||
|
start=start_date_str,
|
||||||
|
end=end_date_str,
|
||||||
|
multi_level_index=False,
|
||||||
|
progress=False,
|
||||||
|
auto_adjust=True,
|
||||||
|
)
|
||||||
|
if raw.empty:
|
||||||
|
raise YFinanceError(
|
||||||
|
f"yfinance returned no data for symbol '{symbol}' "
|
||||||
|
f"({start_date_str} → {end_date_str})"
|
||||||
|
)
|
||||||
|
data = raw.reset_index()
|
||||||
|
data.to_csv(data_file, index=False)
|
||||||
|
logger.debug("Downloaded and cached OHLCV for %s → %s", symbol, data_file)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Public API
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class StockstatsUtils:
|
class StockstatsUtils:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_stock_stats(
|
def get_stock_stats(
|
||||||
|
|
@ -40,48 +131,20 @@ class StockstatsUtils:
|
||||||
str, "curr date for retrieving stock price data, YYYY-mm-dd"
|
str, "curr date for retrieving stock price data, YYYY-mm-dd"
|
||||||
],
|
],
|
||||||
):
|
):
|
||||||
config = get_config()
|
|
||||||
|
|
||||||
today_date = pd.Timestamp.today()
|
|
||||||
curr_date_dt = pd.to_datetime(curr_date)
|
curr_date_dt = pd.to_datetime(curr_date)
|
||||||
|
|
||||||
end_date = today_date
|
|
||||||
start_date = today_date - pd.DateOffset(years=15)
|
|
||||||
start_date_str = start_date.strftime("%Y-%m-%d")
|
|
||||||
end_date_str = end_date.strftime("%Y-%m-%d")
|
|
||||||
|
|
||||||
# Ensure cache directory exists
|
|
||||||
os.makedirs(config["data_cache_dir"], exist_ok=True)
|
|
||||||
|
|
||||||
data_file = os.path.join(
|
|
||||||
config["data_cache_dir"],
|
|
||||||
f"{symbol}-YFin-data-{start_date_str}-{end_date_str}.csv",
|
|
||||||
)
|
|
||||||
|
|
||||||
if os.path.exists(data_file):
|
|
||||||
data = pd.read_csv(data_file, on_bad_lines="skip")
|
|
||||||
else:
|
|
||||||
data = yf.download(
|
|
||||||
symbol,
|
|
||||||
start=start_date_str,
|
|
||||||
end=end_date_str,
|
|
||||||
multi_level_index=False,
|
|
||||||
progress=False,
|
|
||||||
auto_adjust=True,
|
|
||||||
)
|
|
||||||
data = data.reset_index()
|
|
||||||
data.to_csv(data_file, index=False)
|
|
||||||
|
|
||||||
data = _clean_dataframe(data)
|
|
||||||
df = wrap(data)
|
|
||||||
df["Date"] = df["Date"].dt.strftime("%Y-%m-%d")
|
|
||||||
curr_date_str = curr_date_dt.strftime("%Y-%m-%d")
|
curr_date_str = curr_date_dt.strftime("%Y-%m-%d")
|
||||||
|
|
||||||
|
data = _load_or_fetch_ohlcv(symbol)
|
||||||
|
data = _clean_dataframe(data)
|
||||||
|
df = wrap(data)
|
||||||
|
# After wrap(), the date column becomes the datetime index (named 'date').
|
||||||
|
# Access via df.index, not df["Date"] which stockstats would try to parse as an indicator.
|
||||||
|
|
||||||
df[indicator] # trigger stockstats to calculate the indicator
|
df[indicator] # trigger stockstats to calculate the indicator
|
||||||
matching_rows = df[df["Date"].str.startswith(curr_date_str)]
|
date_index_strs = df.index.strftime("%Y-%m-%d")
|
||||||
|
matching_rows = df[date_index_strs == curr_date_str]
|
||||||
|
|
||||||
if not matching_rows.empty:
|
if not matching_rows.empty:
|
||||||
indicator_value = matching_rows[indicator].values[0]
|
return matching_rows[indicator].values[0]
|
||||||
return indicator_value
|
|
||||||
else:
|
else:
|
||||||
return "N/A: Not a trading day (weekend or holiday)"
|
return "N/A: Not a trading day (weekend or holiday)"
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,14 @@
|
||||||
|
import logging
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from dateutil.relativedelta import relativedelta
|
from dateutil.relativedelta import relativedelta
|
||||||
|
import pandas as pd
|
||||||
import yfinance as yf
|
import yfinance as yf
|
||||||
import os
|
import os
|
||||||
from .stockstats_utils import StockstatsUtils, _clean_dataframe
|
from .stockstats_utils import StockstatsUtils, YFinanceError, _clean_dataframe, _load_or_fetch_ohlcv
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_YFin_data_online(
|
def get_YFin_data_online(
|
||||||
symbol: Annotated[str, "ticker symbol of the company"],
|
symbol: Annotated[str, "ticker symbol of the company"],
|
||||||
|
|
@ -191,70 +196,32 @@ def _get_stock_stats_bulk(
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
Optimized bulk calculation of stock stats indicators.
|
Optimized bulk calculation of stock stats indicators.
|
||||||
Fetches data once and calculates indicator for all available dates.
|
Fetches data once (via shared _load_or_fetch_ohlcv cache) and calculates
|
||||||
|
the indicator for all available dates.
|
||||||
Returns dict mapping date strings to indicator values.
|
Returns dict mapping date strings to indicator values.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
YFinanceError: if data cannot be loaded or indicator calculation fails.
|
||||||
"""
|
"""
|
||||||
from .config import get_config
|
|
||||||
import pandas as pd
|
|
||||||
from stockstats import wrap
|
from stockstats import wrap
|
||||||
import os
|
|
||||||
|
|
||||||
config = get_config()
|
|
||||||
online = config["data_vendors"]["technical_indicators"] != "local"
|
|
||||||
|
|
||||||
if not online:
|
|
||||||
# Local data path
|
|
||||||
try:
|
|
||||||
data = pd.read_csv(
|
|
||||||
os.path.join(
|
|
||||||
config.get("data_cache_dir", "data"),
|
|
||||||
f"{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
|
|
||||||
),
|
|
||||||
on_bad_lines="skip",
|
|
||||||
)
|
|
||||||
except FileNotFoundError:
|
|
||||||
raise Exception("Stockstats fail: Yahoo Finance data not fetched yet!")
|
|
||||||
else:
|
|
||||||
# Online data fetching with caching
|
|
||||||
today_date = pd.Timestamp.today()
|
|
||||||
curr_date_dt = pd.to_datetime(curr_date)
|
|
||||||
|
|
||||||
end_date = today_date
|
|
||||||
start_date = today_date - pd.DateOffset(years=15)
|
|
||||||
start_date_str = start_date.strftime("%Y-%m-%d")
|
|
||||||
end_date_str = end_date.strftime("%Y-%m-%d")
|
|
||||||
|
|
||||||
os.makedirs(config["data_cache_dir"], exist_ok=True)
|
|
||||||
|
|
||||||
data_file = os.path.join(
|
|
||||||
config["data_cache_dir"],
|
|
||||||
f"{symbol}-YFin-data-{start_date_str}-{end_date_str}.csv",
|
|
||||||
)
|
|
||||||
|
|
||||||
if os.path.exists(data_file):
|
|
||||||
data = pd.read_csv(data_file, on_bad_lines="skip")
|
|
||||||
else:
|
|
||||||
data = yf.download(
|
|
||||||
symbol,
|
|
||||||
start=start_date_str,
|
|
||||||
end=end_date_str,
|
|
||||||
multi_level_index=False,
|
|
||||||
progress=False,
|
|
||||||
auto_adjust=True,
|
|
||||||
)
|
|
||||||
data = data.reset_index()
|
|
||||||
data.to_csv(data_file, index=False)
|
|
||||||
|
|
||||||
|
# Single authority: _load_or_fetch_ohlcv handles both online and local modes,
|
||||||
|
# dynamic cache filename, and corpus validation — no duplicated download logic here.
|
||||||
|
data = _load_or_fetch_ohlcv(symbol)
|
||||||
data = _clean_dataframe(data)
|
data = _clean_dataframe(data)
|
||||||
df = wrap(data)
|
df = wrap(data)
|
||||||
df["Date"] = df["Date"].dt.strftime("%Y-%m-%d")
|
# After wrap(), the date column becomes the datetime index (named 'date').
|
||||||
|
# Access via df.index, not df["Date"] which stockstats would try to parse as an indicator.
|
||||||
|
|
||||||
# Calculate the indicator for all rows at once
|
# Calculate the indicator for all rows at once
|
||||||
df[indicator] # This triggers stockstats to calculate the indicator
|
df[indicator] # This triggers stockstats to calculate the indicator
|
||||||
|
|
||||||
# Create a dictionary mapping date strings to indicator values
|
# Create a dictionary mapping date strings to indicator values
|
||||||
# Optimized: replaced iterrows() with vectorized operations for performance
|
# Optimized: vectorized operations for performance using correct DatetimeIndex
|
||||||
return df.set_index("Date")[indicator].fillna("N/A").astype(str).to_dict()
|
series = df[indicator].copy()
|
||||||
|
series.index = series.index.strftime("%Y-%m-%d")
|
||||||
|
return series.fillna("N/A").astype(str).to_dict()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_stockstats_indicator(
|
def get_stockstats_indicator(
|
||||||
|
|
@ -264,25 +231,15 @@ def get_stockstats_indicator(
|
||||||
str, "The current trading date you are trading on, YYYY-mm-dd"
|
str, "The current trading date you are trading on, YYYY-mm-dd"
|
||||||
],
|
],
|
||||||
) -> str:
|
) -> str:
|
||||||
|
|
||||||
curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d")
|
curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d")
|
||||||
curr_date = curr_date_dt.strftime("%Y-%m-%d")
|
curr_date = curr_date_dt.strftime("%Y-%m-%d")
|
||||||
|
|
||||||
try:
|
# Raises YFinanceError on failure — caller (route_to_vendor) catches typed exceptions.
|
||||||
indicator_value = StockstatsUtils.get_stock_stats(
|
indicator_value = StockstatsUtils.get_stock_stats(symbol, indicator, curr_date)
|
||||||
symbol,
|
|
||||||
indicator,
|
|
||||||
curr_date,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print(
|
|
||||||
f"Error getting stockstats indicator data for indicator {indicator} on {curr_date}: {e}"
|
|
||||||
)
|
|
||||||
return ""
|
|
||||||
|
|
||||||
return str(indicator_value)
|
return str(indicator_value)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_fundamentals(
|
def get_fundamentals(
|
||||||
ticker: Annotated[str, "ticker symbol of the company"],
|
ticker: Annotated[str, "ticker symbol of the company"],
|
||||||
curr_date: Annotated[str, "current date (not used for yfinance)"] = None
|
curr_date: Annotated[str, "current date (not used for yfinance)"] = None
|
||||||
|
|
|
||||||
|
|
@ -5,12 +5,14 @@ from __future__ import annotations
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from tradingagents.agents.utils.json_utils import extract_json
|
from tradingagents.agents.utils.json_utils import extract_json
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal
|
from typing import Callable, Literal
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -172,15 +174,6 @@ def run_ticker_analysis(
|
||||||
|
|
||||||
NOTE: TradingAgentsGraph is synchronous — call this from a thread pool
|
NOTE: TradingAgentsGraph is synchronous — call this from a thread pool
|
||||||
when running multiple tickers concurrently (see run_all_tickers).
|
when running multiple tickers concurrently (see run_all_tickers).
|
||||||
|
|
||||||
Args:
|
|
||||||
candidate: The stock candidate to analyse.
|
|
||||||
macro_context: Macro context to embed in the result.
|
|
||||||
config: TradingAgents configuration dict.
|
|
||||||
analysis_date: Date string in YYYY-MM-DD format.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
TickerResult with all report fields populated, or error set on failure.
|
|
||||||
"""
|
"""
|
||||||
result = TickerResult(
|
result = TickerResult(
|
||||||
ticker=candidate.ticker,
|
ticker=candidate.ticker,
|
||||||
|
|
@ -189,11 +182,14 @@ def run_ticker_analysis(
|
||||||
analysis_date=analysis_date,
|
analysis_date=analysis_date,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("Starting analysis for %s on %s", candidate.ticker, analysis_date)
|
t0 = time.monotonic()
|
||||||
|
logger.info(
|
||||||
|
"[%s] ▶ Starting analysis (%s, %s conviction)",
|
||||||
|
candidate.ticker, candidate.sector, candidate.conviction,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||||
|
|
||||||
from tradingagents.observability import get_run_logger
|
from tradingagents.observability import get_run_logger
|
||||||
|
|
||||||
rl = get_run_logger()
|
rl = get_run_logger()
|
||||||
|
|
@ -210,23 +206,31 @@ def run_ticker_analysis(
|
||||||
result.risk_debate = str(final_state.get("risk_debate_state", ""))
|
result.risk_debate = str(final_state.get("risk_debate_state", ""))
|
||||||
result.final_trade_decision = decision
|
result.final_trade_decision = decision
|
||||||
|
|
||||||
|
elapsed = time.monotonic() - t0
|
||||||
logger.info(
|
logger.info(
|
||||||
"Analysis complete for %s: %s", candidate.ticker, str(decision)[:120]
|
"[%s] ✓ Analysis complete in %.0fs — decision: %s",
|
||||||
|
candidate.ticker, elapsed, str(decision)[:80],
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("Analysis failed for %s: %s", candidate.ticker, exc, exc_info=True)
|
elapsed = time.monotonic() - t0
|
||||||
|
logger.error(
|
||||||
|
"[%s] ✗ Analysis FAILED after %.0fs: %s",
|
||||||
|
candidate.ticker, elapsed, exc, exc_info=True,
|
||||||
|
)
|
||||||
result.error = str(exc)
|
result.error = str(exc)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def run_all_tickers(
|
async def run_all_tickers(
|
||||||
candidates: list[StockCandidate],
|
candidates: list[StockCandidate],
|
||||||
macro_context: MacroContext,
|
macro_context: MacroContext,
|
||||||
config: dict,
|
config: dict,
|
||||||
analysis_date: str,
|
analysis_date: str,
|
||||||
max_concurrent: int = 2,
|
max_concurrent: int = 2,
|
||||||
|
on_ticker_done: Callable[[TickerResult, int, int], None] | None = None,
|
||||||
) -> list[TickerResult]:
|
) -> list[TickerResult]:
|
||||||
"""Run TradingAgents for every candidate with controlled concurrency.
|
"""Run TradingAgents for every candidate with controlled concurrency.
|
||||||
|
|
||||||
|
|
@ -239,28 +243,45 @@ async def run_all_tickers(
|
||||||
config: TradingAgents configuration dict.
|
config: TradingAgents configuration dict.
|
||||||
analysis_date: Date string in YYYY-MM-DD format.
|
analysis_date: Date string in YYYY-MM-DD format.
|
||||||
max_concurrent: Maximum number of tickers to process in parallel.
|
max_concurrent: Maximum number of tickers to process in parallel.
|
||||||
|
on_ticker_done: Optional callback(result, done_count, total_count) fired
|
||||||
|
after each ticker finishes — use this to drive a progress bar.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of TickerResult in completion order.
|
List of TickerResult in completion order.
|
||||||
"""
|
"""
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
executor = ThreadPoolExecutor(max_workers=max_concurrent)
|
total = len(candidates)
|
||||||
try:
|
results: list[TickerResult] = []
|
||||||
tasks = [
|
|
||||||
loop.run_in_executor(
|
# Use a semaphore so at most max_concurrent tickers run simultaneously,
|
||||||
executor,
|
# but we still get individual completion callbacks via as_completed.
|
||||||
|
semaphore = asyncio.Semaphore(max_concurrent)
|
||||||
|
|
||||||
|
async def _run_one(candidate: StockCandidate) -> TickerResult:
|
||||||
|
async with semaphore:
|
||||||
|
return await loop.run_in_executor(
|
||||||
|
None, # use default ThreadPoolExecutor
|
||||||
run_ticker_analysis,
|
run_ticker_analysis,
|
||||||
c,
|
candidate,
|
||||||
macro_context,
|
macro_context,
|
||||||
config,
|
config,
|
||||||
analysis_date,
|
analysis_date,
|
||||||
)
|
)
|
||||||
for c in candidates
|
|
||||||
]
|
tasks = [asyncio.create_task(_run_one(c)) for c in candidates]
|
||||||
results = await asyncio.gather(*tasks)
|
done_count = 0
|
||||||
return list(results)
|
for coro in asyncio.as_completed(tasks):
|
||||||
finally:
|
result = await coro
|
||||||
executor.shutdown(wait=False)
|
done_count += 1
|
||||||
|
results.append(result)
|
||||||
|
if on_ticker_done is not None:
|
||||||
|
try:
|
||||||
|
on_ticker_done(result, done_count, total)
|
||||||
|
except Exception: # never let a callback crash the pipeline
|
||||||
|
pass
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# ─── Reporting ────────────────────────────────────────────────────────────────
|
# ─── Reporting ────────────────────────────────────────────────────────────────
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue