565 lines
23 KiB
Python
565 lines
23 KiB
Python
"""Comprehensive unit tests for Fast-Reject [CRITICAL ABORT] feature.

This module tests the critical abort mechanism that short-circuits the trading agent
workflow when catastrophic conditions are detected in market or fundamentals reports.
"""
|
|
|
|
import pytest
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from langchain_core.messages import AIMessage
|
|
|
|
from tradingagents.agents.analysts.fundamentals_analyst import create_fundamentals_analyst
|
|
from tradingagents.agents.analysts.market_analyst import create_market_analyst
|
|
from tradingagents.agents.managers.portfolio_manager import create_portfolio_manager
|
|
from tradingagents.graph.conditional_logic import ConditionalLogic
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Mock Data
# ---------------------------------------------------------------------------

# Reports carrying the "[CRITICAL ABORT]" marker that the tests below expect
# to trigger the fast-reject path.
market_report_abort = "[CRITICAL ABORT] Reason: Trading halted pending SEC investigation"
fundamentals_report_abort = "[CRITICAL ABORT] Reason: Negative gross margin with bankruptcy filing"

# Ordinary reports without the marker, used for the non-abort paths.
normal_market_report = "Market analysis shows strong bullish trend with positive momentum..."
normal_fundamentals_report = "Company fundamentals are strong with healthy margins and growth prospects..."

# Macro regime report included in the portfolio-manager and integration states.
macro_regime_report = "Current macro environment shows stable interest rates and moderate inflation."
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# ConditionalLogic Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestConditionalLogicAbortDetection:
    """Exercise ``ConditionalLogic._check_critical_abort`` on single report fields."""

    @staticmethod
    def _abort_flag(reports, field):
        """Return the abort check result for *field* of *reports* on a fresh ConditionalLogic."""
        return ConditionalLogic()._check_critical_abort(reports, field)

    def test_check_critical_abort_detected_in_market_report(self):
        """A market_report carrying the marker is flagged."""
        reports = {
            "market_report": market_report_abort,
            "fundamentals_report": normal_fundamentals_report,
        }
        assert self._abort_flag(reports, "market_report") is True

    def test_check_critical_abort_detected_in_fundamentals_report(self):
        """A fundamentals_report carrying the marker is flagged."""
        reports = {
            "market_report": normal_market_report,
            "fundamentals_report": fundamentals_report_abort,
        }
        assert self._abort_flag(reports, "fundamentals_report") is True

    def test_check_critical_abort_not_detected(self):
        """Reports without the marker are not flagged."""
        reports = {
            "market_report": normal_market_report,
            "fundamentals_report": normal_fundamentals_report,
        }
        assert self._abort_flag(reports, "market_report") is False

    def test_check_critical_abort_empty_report(self):
        """An empty report string is not flagged."""
        reports = {
            "market_report": "",
            "fundamentals_report": normal_fundamentals_report,
        }
        assert self._abort_flag(reports, "market_report") is False

    def test_check_critical_abort_missing_report_field(self):
        """A state that lacks the requested report key is not flagged."""
        reports = {"fundamentals_report": normal_fundamentals_report}
        assert self._abort_flag(reports, "market_report") is False

    def test_check_critical_abort_partial_match(self):
        """The marker is flagged even when it is preceded by other text."""
        reports = {
            "market_report": "Some text [CRITICAL ABORT] Reason: Test",
            "fundamentals_report": normal_fundamentals_report,
        }
        assert self._abort_flag(reports, "market_report") is True
|
|
|
|
|
|
class TestConditionalLogicFlowControl:
    """Check the routing decisions of ConditionalLogic with and without an abort report."""

    @staticmethod
    def _fresh_investment_debate():
        """Build an untouched investment debate sub-state (count 0, nothing said yet)."""
        return {
            "history": [],
            "bull_history": [],
            "bear_history": [],
            "current_response": "",
            "judge_decision": "",
            "count": 0,
        }

    @staticmethod
    def _fresh_risk_debate(include_responses=True):
        """Build an untouched risk debate sub-state whose latest speaker is "Aggressive".

        With ``include_responses=False`` the per-analyst response fields and
        ``judge_decision`` are left out, matching the leaner fixture used by the
        normal-flow risk test.
        """
        debate = {
            "history": [],
            "aggressive_history": [],
            "conservative_history": [],
            "neutral_history": [],
            "latest_speaker": "Aggressive",
        }
        if include_responses:
            debate["current_aggressive_response"] = ""
            debate["current_conservative_response"] = ""
            debate["current_neutral_response"] = ""
            debate["judge_decision"] = ""
        debate["count"] = 0
        return debate

    def test_should_continue_debate_with_abort(self):
        """A market abort routes the debate step straight to the Portfolio Manager."""
        graph_state = {
            "market_report": market_report_abort,
            "fundamentals_report": normal_fundamentals_report,
            "investment_debate_state": self._fresh_investment_debate(),
        }
        next_node = ConditionalLogic().should_continue_debate(graph_state)
        assert next_node == "Portfolio Manager"

    def test_should_continue_risk_analysis_with_abort(self):
        """A market abort routes the risk step straight to the Portfolio Manager."""
        graph_state = {
            "market_report": market_report_abort,
            "fundamentals_report": normal_fundamentals_report,
            "risk_debate_state": self._fresh_risk_debate(),
        }
        next_node = ConditionalLogic().should_continue_risk_analysis(graph_state)
        assert next_node == "Portfolio Manager"

    def test_normal_flow_without_abort(self):
        """Without an abort the debate step proceeds to a researcher."""
        graph_state = {
            "market_report": normal_market_report,
            "fundamentals_report": normal_fundamentals_report,
            "investment_debate_state": self._fresh_investment_debate(),
        }
        next_node = ConditionalLogic().should_continue_debate(graph_state)
        # Bull speaks first when current_response is empty
        assert next_node == "Bull Researcher"

    def test_normal_flow_without_abort_risk_analysis(self):
        """Without an abort the risk step proceeds to the next risk analyst."""
        graph_state = {
            "market_report": normal_market_report,
            "fundamentals_report": normal_fundamentals_report,
            "risk_debate_state": self._fresh_risk_debate(include_responses=False),
        }
        next_node = ConditionalLogic().should_continue_risk_analysis(graph_state)
        assert next_node == "Conservative Analyst"

    def test_abort_in_fundamentals_bypasses_debate(self):
        """A fundamentals abort routes the debate step straight to the Portfolio Manager."""
        graph_state = {
            "market_report": normal_market_report,
            "fundamentals_report": fundamentals_report_abort,
            "investment_debate_state": self._fresh_investment_debate(),
        }
        next_node = ConditionalLogic().should_continue_debate(graph_state)
        assert next_node == "Portfolio Manager"

    def test_abort_in_fundamentals_bypasses_risk_analysis(self):
        """A fundamentals abort routes the risk step straight to the Portfolio Manager."""
        graph_state = {
            "market_report": normal_market_report,
            "fundamentals_report": fundamentals_report_abort,
            "risk_debate_state": self._fresh_risk_debate(),
        }
        next_node = ConditionalLogic().should_continue_risk_analysis(graph_state)
        assert next_node == "Portfolio Manager"

    def test_abort_in_market_bypasses_risk_analysis(self):
        """A market abort routes the risk step straight to the Portfolio Manager."""
        graph_state = {
            "market_report": market_report_abort,
            "fundamentals_report": normal_fundamentals_report,
            "risk_debate_state": self._fresh_risk_debate(),
        }
        next_node = ConditionalLogic().should_continue_risk_analysis(graph_state)
        assert next_node == "Portfolio Manager"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Analyst Report Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestMarketAnalystAbortInstructions:
    """Tests for the market analyst node: abort report pass-through and prompt text."""

    def test_market_analyst_includes_abort_instructions(self):
        """The node's market_report carries the marker when the tool loop returns an abort."""
        # run_tool_loop is the injectable boundary: it is patched to hand back the
        # abort message, and prefetch_tools_parallel is patched to an empty dict,
        # so no network calls are made.
        llm_reply = MagicMock()
        llm_reply.content = market_report_abort
        llm_reply.tool_calls = []

        analyst_input = {
            "trade_date": "2024-01-01",
            "company_of_interest": "AAPL",
            "messages": [],
        }

        prefetch_patch = patch(
            "tradingagents.agents.analysts.market_analyst.prefetch_tools_parallel",
            return_value={},
        )
        tool_loop_patch = patch(
            "tradingagents.agents.analysts.market_analyst.run_tool_loop",
            return_value=llm_reply,
        )
        with prefetch_patch, tool_loop_patch:
            node = create_market_analyst(MagicMock())
            output = node(analyst_input)

        # Verify the report contains abort
        assert "[CRITICAL ABORT]" in output.get("market_report", "")

    def test_market_analyst_abort_conditions(self):
        """The node's code constants mention the "CRITICAL ABORT TRIGGER" phrase."""
        node = create_market_analyst(MagicMock())

        # The system message is built from adjacent string literals, which the
        # compiler folds into a single constant stored in co_consts; at least
        # one constant must contain the trigger phrase as a substring.
        constants_as_text = [str(const) for const in node.__code__.co_consts]
        assert any("CRITICAL ABORT TRIGGER" in text for text in constants_as_text)
|
|
|
|
|
|
class TestFundamentalsAnalystAbortInstructions:
    """Tests for the fundamentals analyst node: abort report pass-through and prompt text."""

    def test_fundamentals_analyst_includes_abort_instructions(self):
        """The node's fundamentals_report carries the marker when the tool loop returns an abort."""
        llm_reply = MagicMock()
        llm_reply.content = fundamentals_report_abort
        llm_reply.tool_calls = []

        analyst_input = {
            "trade_date": "2024-01-01",
            "company_of_interest": "AAPL",
            "messages": [],
        }

        # Both helpers are patched so the node runs without network calls.
        prefetch_patch = patch(
            "tradingagents.agents.analysts.fundamentals_analyst.prefetch_tools_parallel",
            return_value={},
        )
        tool_loop_patch = patch(
            "tradingagents.agents.analysts.fundamentals_analyst.run_tool_loop",
            return_value=llm_reply,
        )
        with prefetch_patch, tool_loop_patch:
            node = create_fundamentals_analyst(MagicMock())
            output = node(analyst_input)

        # Verify the report contains abort
        assert "[CRITICAL ABORT]" in output.get("fundamentals_report", "")

    def test_fundamentals_analyst_abort_conditions(self):
        """The node's code constants mention the "CRITICAL ABORT TRIGGER" phrase."""
        node = create_fundamentals_analyst(MagicMock())

        # The system message is built from adjacent string literals compiled
        # into one constant, so a substring search over co_consts finds it.
        constants_as_text = [str(const) for const in node.__code__.co_consts]
        assert any("CRITICAL ABORT TRIGGER" in text for text in constants_as_text)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Portfolio Manager Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestPortfolioManagerAbortDetection:
    """Tests for the portfolio manager node when analyst reports do or do not abort."""

    def _make_abort_state(self, market_report, fundamentals_report):
        """Build a minimal state dict suitable for the portfolio manager node."""
        risk_debate = {
            "history": [],
            "aggressive_history": [],
            "conservative_history": [],
            "neutral_history": [],
            "current_aggressive_response": "",
            "current_conservative_response": "",
            "current_neutral_response": "",
            "count": 0,
        }
        return {
            "company_of_interest": "AAPL",
            "market_report": market_report,
            "fundamentals_report": fundamentals_report,
            "macro_regime_report": macro_regime_report,
            "risk_debate_state": risk_debate,
            "news_report": "",
            "sentiment_report": "",
            "investment_plan": "BUY AAPL",
        }

    def _decide(self, llm_text, market_report, fundamentals_report):
        """Run the portfolio manager with a stub LLM that answers *llm_text*.

        Returns a ``(node_result, llm_stub)`` pair so callers can inspect both
        the node output and whether the stub was invoked.
        """
        # The stub is created before the node so the node's closure captures it.
        llm_stub = MagicMock()
        llm_stub.invoke.return_value = MagicMock(content=llm_text)
        node = create_portfolio_manager(llm_stub, MagicMock())
        outcome = node(self._make_abort_state(market_report, fundamentals_report))
        return outcome, llm_stub

    def test_portfolio_manager_detects_abort(self):
        """With a market abort, the LLM is invoked and its SELL reply reaches the decision."""
        outcome, llm_stub = self._decide(
            "RECOMMENDATION: SELL - Trading halted pending SEC investigation",
            market_report_abort,
            normal_fundamentals_report,
        )

        # Verify the closure's LLM was actually called
        assert llm_stub.invoke.called
        # Verify the result contains SELL recommendation
        assert "SELL" in outcome.get("final_trade_decision", "").upper()

    def test_portfolio_manager_uses_aborting_analyst_report(self):
        """With a market abort, the decision text mentions the abort reason."""
        outcome, _ = self._decide(
            "RECOMMENDATION: SELL - Trading halted pending SEC investigation",
            market_report_abort,
            normal_fundamentals_report,
        )

        decision_text = outcome.get("final_trade_decision", "")
        assert "SEC investigation" in decision_text

    def test_portfolio_manager_normal_flow(self):
        """Without any abort, the LLM is invoked and its BUY reply reaches the decision."""
        outcome, llm_stub = self._decide(
            "RECOMMENDATION: BUY - Strong bullish trend with positive momentum",
            normal_market_report,
            normal_fundamentals_report,
        )

        # Verify the closure's LLM was actually called
        assert llm_stub.invoke.called
        # Verify the result contains BUY recommendation
        assert "BUY" in outcome.get("final_trade_decision", "").upper()

    def test_portfolio_manager_uses_fundamentals_abort_report(self):
        """With a fundamentals abort, the decision text mentions the bankruptcy reason."""
        outcome, _ = self._decide(
            "RECOMMENDATION: AVOID - Negative gross margin with bankruptcy filing",
            normal_market_report,
            fundamentals_report_abort,
        )

        decision_text = outcome.get("final_trade_decision", "")
        assert "bankruptcy" in decision_text.lower()

    def test_portfolio_manager_avoids_recommendation(self):
        """With a fundamentals abort, the AVOID reply reaches the decision."""
        outcome, _ = self._decide(
            "RECOMMENDATION: AVOID - Negative gross margin with bankruptcy filing",
            normal_market_report,
            fundamentals_report_abort,
        )

        assert "AVOID" in outcome.get("final_trade_decision", "").upper()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Integration Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestFastRejectFullFlow:
    """Integration tests for the complete fast-reject short-circuit flow.

    Each test drives the market analyst node and then the portfolio manager
    node with stubbed LLM output, merges their results into one state dict,
    and finally asks ConditionalLogic where the debate and risk steps route.
    """

    # Shared initial state template for integration tests.  It is never handed
    # to a node directly: _make_state() deep-copies it for every test.
    _base_state = {
        "ticker": "AAPL",
        "trade_date": "2024-01-01",
        "company_of_interest": "AAPL",
        "macro_regime_report": macro_regime_report,
        "risk_debate_state": {
            "history": [],
            "aggressive_history": [],
            "conservative_history": [],
            "neutral_history": [],
            # Mirrors the risk-routing fixtures used elsewhere in this file;
            # presumably the non-abort branch of should_continue_risk_analysis
            # reads it -- TODO confirm against ConditionalLogic.
            "latest_speaker": "Aggressive",
            "current_aggressive_response": "",
            "current_conservative_response": "",
            "current_neutral_response": "",
            "count": 0,
        },
        "investment_debate_state": {
            "history": [],
            "bull_history": [],
            "bear_history": [],
            "current_response": "",
            "judge_decision": "",
            "count": 0,
        },
        "news_report": "",
        "sentiment_report": "",
        "investment_plan": "BUY AAPL",
        "messages": [],
    }

    def _make_state(self, market_report, fundamentals_report):
        """Return an independent copy of the template with the two reports filled in."""
        import copy

        # A shallow copy would share the nested debate dicts and the messages
        # list between tests, so any in-place mutation by a node would leak
        # from one test into the next.
        state = copy.deepcopy(self._base_state)
        state["market_report"] = market_report
        state["fundamentals_report"] = fundamentals_report
        return state

    @staticmethod
    def _run_market_analyst(state, report_text):
        """Run the market analyst with its tool loop stubbed to return *report_text*.

        Returns a new dict: *state* overlaid with the node's output, so keys the
        node does not return (e.g. fundamentals_report) are preserved.
        """
        llm_reply = MagicMock()
        llm_reply.content = report_text
        llm_reply.tool_calls = []

        # Patch network-calling helpers; control analyst output via run_tool_loop mock
        prefetch_patch = patch(
            "tradingagents.agents.analysts.market_analyst.prefetch_tools_parallel",
            return_value={},
        )
        tool_loop_patch = patch(
            "tradingagents.agents.analysts.market_analyst.run_tool_loop",
            return_value=llm_reply,
        )
        with prefetch_patch, tool_loop_patch:
            node = create_market_analyst(MagicMock())
            analyst_result = node(state)
        return {**state, **analyst_result}

    @staticmethod
    def _run_portfolio_manager(state, decision_text):
        """Run the portfolio manager with a stub LLM answering *decision_text*.

        Returns a new dict: *state* overlaid with the node's output, so the
        analyst reports stay accessible for the routing checks.
        """
        llm_stub = MagicMock()
        llm_stub.invoke.return_value = MagicMock(content=decision_text)
        node = create_portfolio_manager(llm_stub, MagicMock())
        return {**state, **node(state)}

    def test_fast_reject_full_flow(self):
        """Test the complete short-circuit flow from analyst to portfolio manager."""
        state = self._make_state(market_report_abort, normal_fundamentals_report)
        state = self._run_market_analyst(state, market_report_abort)

        # Verify market report contains abort
        assert "[CRITICAL ABORT]" in state.get("market_report", "")

        state = self._run_portfolio_manager(
            state, "RECOMMENDATION: SELL - Trading halted pending SEC investigation"
        )

        # Verify the stubbed SELL decision reached the final decision
        assert "SELL" in state.get("final_trade_decision", "").upper()

        # Verify conditional logic would bypass debate and risk analysis
        cl = ConditionalLogic()
        assert cl.should_continue_debate(state) == "Portfolio Manager"
        assert cl.should_continue_risk_analysis(state) == "Portfolio Manager"

    def test_fast_reject_fundamentals_flow(self):
        """Test the complete short-circuit flow with fundamentals abort."""
        state = self._make_state(normal_market_report, fundamentals_report_abort)
        state = self._run_market_analyst(state, normal_market_report)

        # Market report should be normal (abort is in fundamentals)
        assert "[CRITICAL ABORT]" not in state.get("market_report", "")
        # Fundamentals abort must survive the merge
        assert "[CRITICAL ABORT]" in state.get("fundamentals_report", "")

        state = self._run_portfolio_manager(
            state, "RECOMMENDATION: AVOID - Negative gross margin with bankruptcy filing"
        )

        assert "AVOID" in state.get("final_trade_decision", "").upper()

        cl = ConditionalLogic()
        assert cl.should_continue_debate(state) == "Portfolio Manager"
        assert cl.should_continue_risk_analysis(state) == "Portfolio Manager"

    def test_fast_reject_normal_flow(self):
        """Test the complete flow without abort."""
        state = self._make_state(normal_market_report, normal_fundamentals_report)
        state = self._run_market_analyst(state, normal_market_report)

        assert "[CRITICAL ABORT]" not in state.get("market_report", "")

        state = self._run_portfolio_manager(
            state, "RECOMMENDATION: BUY - Strong bullish trend with positive momentum"
        )

        assert "BUY" in state.get("final_trade_decision", "").upper()

        # Normal flow: conditional logic must NOT route directly to Portfolio Manager
        cl = ConditionalLogic()
        debate_result = cl.should_continue_debate(state)
        risk_result = cl.should_continue_risk_analysis(state)

        assert debate_result != "Portfolio Manager"
        assert risk_result != "Portfolio Manager"