TradingAgents/tests/portfolio/test_selection_reflector.py

107 lines
4.4 KiB
Python

import pytest
import pandas as pd
from unittest.mock import MagicMock
from langchain_core.messages import AIMessage
from tradingagents.portfolio.selection_reflector import (
fetch_price_trend, fetch_news_summary, generate_lesson, reflect_on_scan
)
@pytest.fixture
def mock_yf_download(monkeypatch):
    """Patch ``yfinance.download`` with a deterministic five-day price frame.

    The fake ignores ``tickers``, ``end`` and any extra keyword arguments. It
    always returns five consecutive calendar days starting at ``start``:

    * ``AAPL`` rises to a high of 110 on the second day, then falls to 85.
    * ``SPY`` climbs steadily from 400 to 410.

    The fixture value is the fake ``download`` function itself.
    """
    aapl_path = [100.0, 110.0, 105.0, 90.0, 85.0]  # up, then down
    spy_path = [400.0, 402.0, 405.0, 407.0, 410.0]  # steady climb

    def fake_download(tickers, start, end, **kwargs):
        trading_days = pd.date_range(start, periods=len(aapl_path))
        return pd.DataFrame(
            {"AAPL": aapl_path, "SPY": spy_path},
            index=trading_days,
        )

    monkeypatch.setattr("yfinance.download", fake_download)
    return fake_download


def test_fetch_price_data_normal(mock_yf_download):
    """Metrics for the rally-then-selloff fixture window match hand-computed values."""
    metrics = fetch_price_trend("AAPL", "2025-01-01", "2025-01-05")
    terminal_return, spy_return, mfe_pct, mae_pct, days_to_peak, top_move_dates = metrics

    # AAPL first close 100, last close 85 -> -15%.
    assert terminal_return == pytest.approx(-15.0)
    # SPY benchmark 400 -> 410 -> +2.5%.
    assert spy_return == pytest.approx(2.5)
    # Highest close 110 -> +10% favourable excursion.
    assert mfe_pct == pytest.approx(10.0)
    # Lowest close 85 -> -15% adverse excursion.
    assert mae_pct == pytest.approx(-15.0)
    # The 110 high sits in the second row (2025-01-02), i.e. offset 1.
    assert days_to_peak == 1
    # Five closes give four daily moves; three dates are expected back
    # (presumably the largest moves).
    assert len(top_move_dates) == 3


def test_fetch_price_data_single_day(monkeypatch):
    """A window containing a single trading day yields no metrics at all."""

    def fake_download(*args, **kwargs):
        # Use a DatetimeIndex, like real yfinance output and the
        # ``mock_yf_download`` fixture, rather than the default RangeIndex.
        # The None results below then come from the one-row window, not from
        # an unrealistic index type. A fresh frame is built on every call.
        return pd.DataFrame(
            {"AAPL": [100.0], "SPY": [400.0]},
            index=pd.DatetimeIndex(["2025-01-01"]),
        )

    monkeypatch.setattr("yfinance.download", fake_download)

    terminal_return, spy_return, mfe_pct, mae_pct, days_to_peak, top_move_dates = fetch_price_trend("AAPL", "2025-01-01", "2025-01-01")

    # Separate asserts so a failure names the offending metric.
    assert terminal_return is None
    assert spy_return is None
    assert mfe_pct is None
    assert mae_pct is None
    assert days_to_peak is None
    assert top_move_dates == []


def test_fetch_news_summary_weighted(monkeypatch):
    """Start-of-window news and top-move-date news are sampled with separate caps."""
    headlines_by_start = {
        "2025-01-01": "- Start news 1\n- Start news 2\n- Start news 3",
        "2025-01-02": "- Top move 1\n- Top move 1b",
        "2025-01-04": "- Top move 2",
    }

    def fake_company_news(ticker, start, end):
        # Keyed on the window start only; any other date has no news.
        return headlines_by_start.get(start, "")

    monkeypatch.setattr(
        "tradingagents.portfolio.selection_reflector.get_company_news",
        fake_company_news,
    )

    summary = fetch_news_summary(
        "AAPL", "2025-01-01", "2025-01-05", ["2025-01-02", "2025-01-04"]
    )

    # Two headlines are kept from the start of the window; the third is dropped.
    for kept in ("- Start news 1", "- Start news 2"):
        assert kept in summary
    assert "- Start news 3" not in summary
    # One headline per top-move date; the second one on 2025-01-02 is dropped.
    assert "- Top move 1" in summary
    assert "- Top move 1b" not in summary
    assert "- Top move 2" in summary


def test_generate_lesson_valid():
    """A ```json-fenced LLM reply is parsed into the four lesson fields."""
    expected_fields = {
        "situation": "test sit",
        "screening_advice": "test screen",
        "exit_advice": "test exit",
        "sentiment": "negative",
    }
    stub_llm = MagicMock()
    stub_llm.invoke.return_value = AIMessage(content='```json\n{"situation": "test sit", "screening_advice": "test screen", "exit_advice": "test exit", "sentiment": "negative"}\n```')
    candidate = {
        "ticker": "AAPL",
        "sector": "Tech",
        "thesis_angle": "growth",
        "rationale": "good",
        "conviction": "high",
    }

    lesson = generate_lesson(stub_llm, candidate, -10.0, 2.0, 5.0, -12.0, 5, "news", 30)

    assert lesson is not None
    # Check each field individually; extra keys in the lesson are tolerated.
    for field, value in expected_fields.items():
        assert lesson[field] == value


def test_generate_lesson_mfe_mae_in_prompt():
    """Excursion figures and the peak day are rendered into the LLM prompt."""
    stub_llm = MagicMock()
    stub_llm.invoke.return_value = AIMessage(
        content='{"situation": "a", "screening_advice": "b", "exit_advice": "c", "sentiment": "neutral"}'
    )

    generate_lesson(stub_llm, {"ticker": "AAPL"}, -10.0, 2.0, 5.1, -12.2, 5, "news", 30)

    # invoke() receives a message list as its first positional argument;
    # the first message carries the prompt text.
    messages = stub_llm.invoke.call_args.args[0]
    prompt_text = messages[0].content

    # The positive MFE is shown with an explicit '+'; the MAE keeps its '-'.
    for fragment in ("MFE): +5.1%", "MAE): -12.2%", "Day 5"):
        assert fragment in prompt_text


def test_generate_lesson_bad_json():
    """A reply that is not JSON makes generate_lesson return None."""
    stub_llm = MagicMock()
    stub_llm.invoke.return_value = AIMessage(content='Not a JSON')

    result = generate_lesson(stub_llm, {}, -10.0, 2.0, 5.0, -12.0, 5, "news", 30)

    assert result is None


def test_reflect_on_scan_no_file(monkeypatch):
    """With no scan candidates, reflection yields no lessons and makes no LLM calls."""
    llm = MagicMock()
    # Stand-in for a missing scan file (per the test name): the patched
    # loader reports no candidates.
    monkeypatch.setattr(
        "tradingagents.portfolio.selection_reflector.load_scan_candidates",
        lambda date: [],
    )

    lessons = reflect_on_scan("2025-01-01", "2025-01-31", llm, 30)

    assert lessons == []
    # With nothing to reflect on, the LLM should not be consulted at all.
    llm.invoke.assert_not_called()