# TradingAgents/tests/test_chief_analyst.py
#
# Repository-viewer listing metadata: 149 lines, 6.1 KiB, Python.

import pytest
from pydantic import ValidationError
def test_chief_analyst_report_valid_buy():
    """A BUY report validates and ``model_dump`` returns exactly the inputs."""
    from tradingagents.agents.utils.agent_states import ChiefAnalystReport

    fields = {
        "verdict": "BUY",
        "catalyst": "Strong Q4",
        "execution": "Enter at market",
        "tail_risk": "Rate hike",
    }
    report = ChiefAnalystReport(**fields)

    assert report.verdict == "BUY"
    # The dump must contain the four input fields and nothing else.
    assert report.model_dump() == fields
def test_chief_analyst_report_valid_sell():
    """A report built with verdict ``"SELL"`` validates and keeps that verdict."""
    from tradingagents.agents.utils.agent_states import ChiefAnalystReport

    report = ChiefAnalystReport(
        verdict="SELL",
        catalyst="Weak guidance",
        execution="Exit position",
        tail_risk="Liquidity crunch",
    )

    assert report.verdict == "SELL"
def test_chief_analyst_report_valid_hold():
    """A report built with verdict ``"HOLD"`` validates and keeps that verdict."""
    from tradingagents.agents.utils.agent_states import ChiefAnalystReport

    report = ChiefAnalystReport(
        verdict="HOLD",
        catalyst="Mixed signals",
        execution="No change",
        tail_risk="FX exposure",
    )

    assert report.verdict == "HOLD"
def test_chief_analyst_report_rejects_invalid_verdict():
    """Constructing a report with verdict ``"MAYBE"`` raises ``ValidationError``."""
    from tradingagents.agents.utils.agent_states import ChiefAnalystReport

    bad_fields = {
        "verdict": "MAYBE",
        "catalyst": "x",
        "execution": "x",
        "tail_risk": "x",
    }

    with pytest.raises(ValidationError):
        ChiefAnalystReport(**bad_fields)
def test_agent_state_does_not_require_chief_analyst_report():
    """AgentState declares ``chief_analyst_report`` as a ``NotRequired`` key.

    The hints are resolved with ``include_extras=True`` so that the
    ``NotRequired`` wrapper is preserved and can be inspected. Checking only
    that the key exists would also pass for a mandatory key, so the wrapper
    itself is asserted.
    """
    from tradingagents.agents.utils.agent_states import AgentState
    from typing_extensions import (
        Annotated,
        NotRequired,
        get_args,
        get_origin,
        get_type_hints,
    )

    hints = get_type_hints(AgentState, include_extras=True)
    assert "chief_analyst_report" in hints

    hint = hints["chief_analyst_report"]
    # PEP 655 allows both NotRequired[Annotated[T, ...]] and
    # Annotated[NotRequired[T], ...]; peel an outer Annotated if present.
    if get_origin(hint) is Annotated:
        hint = get_args(hint)[0]
    assert get_origin(hint) is NotRequired
# --- Task 2: Chief Analyst agent factory ---
from unittest.mock import MagicMock
def _make_mock_llm(
    verdict="BUY",
    catalyst="Strong earnings",
    execution="Enter at market",
    tail_risk="Rate risk",
):
    """Build a mock LLM whose structured-output variant yields a fixed report.

    Returns:
        A ``(mock_llm, structured_llm)`` tuple. ``mock_llm.with_structured_output``
        returns ``structured_llm``, and ``structured_llm.invoke`` returns a
        ``ChiefAnalystReport`` built from the given field values.
    """
    from tradingagents.agents.utils.agent_states import ChiefAnalystReport

    canned_report = ChiefAnalystReport(
        verdict=verdict,
        catalyst=catalyst,
        execution=execution,
        tail_risk=tail_risk,
    )

    structured_llm = MagicMock()
    structured_llm.invoke.return_value = canned_report

    mock_llm = MagicMock()
    mock_llm.with_structured_output.return_value = structured_llm

    return mock_llm, structured_llm
def _make_state():
"""Minimal AgentState dict for testing the Chief Analyst node."""
return {
"company_of_interest": "AAPL",
"trade_date": "2024-01-15",
"market_report": "Bullish technicals.",
"sentiment_report": "Positive social sentiment.",
"news_report": "No major negative news.",
"fundamentals_report": "Strong balance sheet.",
"investment_plan": "Bull case: enter long.",
"trader_investment_plan": "Buy at market, SL at 180.",
"final_trade_decision": "BUY. Rationale: strong Q4 earnings.",
}
def test_create_chief_analyst_returns_callable():
    """``create_chief_analyst`` hands back something that can be called."""
    from tradingagents.agents.managers.chief_analyst import create_chief_analyst

    llm = _make_mock_llm()[0]

    assert callable(create_chief_analyst(llm))
def test_chief_analyst_node_calls_structured_llm():
    """Building the node requests structured output for ``ChiefAnalystReport``.

    The node itself is never invoked here; the assertion covers only what
    ``create_chief_analyst`` does with the LLM at construction time.
    """
    from tradingagents.agents.managers.chief_analyst import create_chief_analyst
    from tradingagents.agents.utils.agent_states import ChiefAnalystReport

    llm, _ = _make_mock_llm()

    create_chief_analyst(llm)

    llm.with_structured_output.assert_called_once_with(ChiefAnalystReport)
def test_chief_analyst_node_returns_report_dict():
    """The node's result carries the report fields under ``chief_analyst_report``."""
    from tradingagents.agents.managers.chief_analyst import create_chief_analyst

    expected = {
        "verdict": "BUY",
        "catalyst": "Strong earnings",
        "execution": "Enter at market",
        "tail_risk": "Rate risk",
    }
    llm, _ = _make_mock_llm(**expected)

    result = create_chief_analyst(llm)(_make_state())

    assert "chief_analyst_report" in result
    report = result["chief_analyst_report"]
    # Check field by field, in the order verdict, catalyst, execution, tail_risk.
    for field, value in expected.items():
        assert report[field] == value
def test_chief_analyst_node_result_is_json_serializable():
    """The ``chief_analyst_report`` payload must survive ``json.dumps``.

    Per the original test's rationale, this is what lets SqliteSaver
    checkpoint the state.
    """
    import json
    from tradingagents.agents.managers.chief_analyst import create_chief_analyst

    llm, _ = _make_mock_llm()
    chief_analyst = create_chief_analyst(llm)

    payload = chief_analyst(_make_state())["chief_analyst_report"]

    assert isinstance(json.dumps(payload), str)
def test_chief_analyst_node_prompt_includes_company_name():
    """The first positional argument given to ``invoke`` contains ``"AAPL"``."""
    from tradingagents.agents.managers.chief_analyst import create_chief_analyst

    llm, structured_llm = _make_mock_llm()
    chief_analyst = create_chief_analyst(llm)

    chief_analyst(_make_state())

    # ``call_args.args`` is the positional-argument tuple of the latest call.
    prompt = structured_llm.invoke.call_args.args[0]
    assert "AAPL" in prompt
# --- Task 4: trading_graph.py changes ---
def test_extract_report_chief_analyst_serializes_dict():
    """``_extract_report("chief_analyst", ...)`` returns JSON of the report dict."""
    import json
    from tradingagents.graph.trading_graph import TradingAgentsGraph

    expected = {
        "verdict": "BUY",
        "catalyst": "x",
        "execution": "y",
        "tail_risk": "z",
    }

    raw = TradingAgentsGraph._extract_report(
        "chief_analyst", {"chief_analyst_report": expected}
    )

    # Decoding the returned text must give back the same dict.
    assert json.loads(raw) == expected
def test_extract_report_chief_analyst_handles_missing():
    """With no ``chief_analyst_report`` key, the result decodes to an empty object."""
    import json
    from tradingagents.graph.trading_graph import TradingAgentsGraph

    empty_update = {}

    raw = TradingAgentsGraph._extract_report("chief_analyst", empty_update)

    assert json.loads(raw) == {}
def test_node_to_step_includes_chief_analyst():
    """``_NODE_TO_STEP`` maps the "Chief Analyst" node to the "chief_analyst" step."""
    from tradingagents.graph.trading_graph import _NODE_TO_STEP

    node_name = "Chief Analyst"

    assert node_name in _NODE_TO_STEP
    assert _NODE_TO_STEP[node_name] == "chief_analyst"