# TradingAgents/tests/integration/test_workflow_e2e.py
#
# Integration tests for workflow state transitions, final-state shape, and ticker validation.

from unittest.mock import MagicMock, patch
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from tradingagents.agents.utils.agent_states import InvestDebateState, RiskDebateState
from tradingagents.graph.conditional_logic import ConditionalLogic
from tradingagents.graph.propagation import Propagator
from tradingagents.graph.trading_graph import TradingAgentsGraph
class TestWorkflowStateTransitions:
    """Tests for how workflow state is initialised and routed between agents."""

    def test_initial_state_structure(self):
        """The initial state built by ``Propagator`` exposes every top-level key."""
        propagator = Propagator()
        state = propagator.create_initial_state("AAPL", "2024-01-15")
        expected_keys = (
            "messages",
            "company_of_interest",
            "trade_date",
            "investment_debate_state",
            "risk_debate_state",
            "market_report",
            "sentiment_report",
            "news_report",
            "fundamentals_report",
        )
        # Collect every missing key so a single failure reports all of them.
        missing = [key for key in expected_keys if key not in state]
        assert not missing, f"initial state is missing keys: {missing}"

    def test_market_analyst_state_update(self):
        """A market-analyst step fills ``market_report`` and appends an AI message.

        NOTE(review): this only checks a hand-built dict; it never invokes the
        market analyst node, so it cannot catch regressions in that node.
        """
        initial_state = {
            "messages": [HumanMessage(content="AAPL")],
            "company_of_interest": "AAPL",
            "trade_date": "2024-01-15",
            "market_report": "",
        }
        updated_state = {
            **initial_state,
            "market_report": "Technical analysis shows bullish trend",
            "messages": [
                HumanMessage(content="AAPL"),
                AIMessage(content="Technical analysis shows bullish trend"),
            ],
        }
        assert updated_state["market_report"] != ""
        assert len(updated_state["messages"]) == 2

    def test_debate_state_progression(self):
        """Bull and Bear researchers take turns while the debate is in progress."""
        logic = ConditionalLogic(max_debate_rounds=2)
        state_round_0 = {
            "investment_debate_state": InvestDebateState(
                bull_history="",
                bear_history="",
                history="",
                current_response="",
                judge_decision="",
                count=0,
            )
        }
        state_round_1 = {
            "investment_debate_state": InvestDebateState(
                bull_history="Bull: I see growth potential",
                bear_history="",
                history="Bull: I see growth potential",
                current_response="Bull: I see growth potential",
                judge_decision="",
                count=1,
            )
        }
        state_round_2 = {
            "investment_debate_state": InvestDebateState(
                bull_history="Bull: I see growth potential",
                bear_history="Bear: Market risks are high",
                history="Bull: I see growth potential\nBear: Market risks are high",
                current_response="Bear: Market risks are high",
                judge_decision="",
                count=2,
            )
        }
        # Nobody has spoken yet (empty current_response), so the Bull opens.
        assert logic.should_continue_debate(state_round_0) == "Bull Researcher"
        assert logic.should_continue_debate(state_round_1) == "Bear Researcher"
        assert logic.should_continue_debate(state_round_2) == "Bull Researcher"

    def test_risk_analysis_state_progression(self):
        """Risk analysts rotate Risky -> Safe -> Neutral, then hand off to the judge."""
        logic = ConditionalLogic(max_risk_discuss_rounds=1)
        state_risky = {
            "risk_debate_state": RiskDebateState(
                risky_history="Go for it!",
                safe_history="",
                neutral_history="",
                history="Risky: Go for it!",
                latest_speaker="Risky Analyst",
                current_risky_response="Go for it!",
                current_safe_response="",
                current_neutral_response="",
                judge_decision="",
                count=1,
            )
        }
        state_safe = {
            "risk_debate_state": RiskDebateState(
                risky_history="Go for it!",
                safe_history="Be cautious",
                neutral_history="",
                history="Risky: Go for it!\nSafe: Be cautious",
                latest_speaker="Safe Analyst",
                current_risky_response="Go for it!",
                current_safe_response="Be cautious",
                current_neutral_response="",
                judge_decision="",
                count=2,
            )
        }
        state_neutral = {
            "risk_debate_state": RiskDebateState(
                risky_history="Go for it!",
                safe_history="Be cautious",
                neutral_history="Balance both views",
                history="Full discussion",
                latest_speaker="Neutral Analyst",
                current_risky_response="Go for it!",
                current_safe_response="Be cautious",
                current_neutral_response="Balance both views",
                judge_decision="",
                count=3,
            )
        }
        assert logic.should_continue_risk_analysis(state_risky) == "Safe Analyst"
        assert logic.should_continue_risk_analysis(state_safe) == "Neutral Analyst"
        assert logic.should_continue_risk_analysis(state_neutral) == "Risk Judge"
class TestWorkflowEndToEnd:
    """Shape checks for the final state of a workflow run (BUY / SELL / HOLD).

    NOTE(review): each test asserts on a ``final_state`` dict it builds itself;
    none of them executes the graph, so they document the expected shape
    rather than verify the workflow. Consider driving ``TradingAgentsGraph``
    with mocked LLMs to turn these into real end-to-end tests.
    """

    def test_final_state_has_all_reports(self):
        """A BUY run carries all four analyst reports and BUY/APPROVED verdicts."""
        investment_debate = {
            "bull_history": "Bull arguments",
            "bear_history": "Bear arguments",
            "history": "Full debate",
            "current_response": "Final bull response",
            "judge_decision": "BUY recommendation",
            "count": 4,
        }
        risk_debate = {
            "risky_history": "High conviction",
            "safe_history": "Moderate position size",
            "neutral_history": "Balanced view",
            "history": "Risk discussion",
            "latest_speaker": "Risk Judge",
            "judge_decision": "APPROVED with position limits",
            "count": 3,
        }
        final_state = {
            "company_of_interest": "AAPL",
            "trade_date": "2024-01-15",
            "market_report": "Bullish technical indicators",
            "sentiment_report": "Positive social sentiment",
            "news_report": "Favorable news coverage",
            "fundamentals_report": "Strong financials",
            "investment_debate_state": investment_debate,
            "trader_investment_plan": "Buy 100 shares at market open",
            "risk_debate_state": risk_debate,
            "final_trade_decision": "BUY 100 shares AAPL at market open",
        }

        # Every analyst must have produced a non-empty report.
        for report_key in (
            "market_report",
            "sentiment_report",
            "news_report",
            "fundamentals_report",
        ):
            assert final_state[report_key] != ""

        assert "BUY" in final_state["investment_debate_state"]["judge_decision"]
        assert "APPROVED" in final_state["risk_debate_state"]["judge_decision"]
        assert "BUY" in final_state["final_trade_decision"]

    def test_workflow_handles_sell_decision(self):
        """A SELL outcome surfaces "SELL" in the final trade decision."""
        final_state = dict(
            company_of_interest="AAPL",
            trade_date="2024-01-15",
            market_report="Bearish technical indicators",
            sentiment_report="Negative sentiment",
            news_report="Bad news",
            fundamentals_report="Weak financials",
            investment_debate_state=dict(judge_decision="SELL recommendation", count=4),
            risk_debate_state=dict(judge_decision="APPROVED", count=3),
            final_trade_decision="SELL position in AAPL",
        )
        decision = final_state["final_trade_decision"]
        assert "SELL" in decision

    def test_workflow_handles_hold_decision(self):
        """A HOLD outcome surfaces "HOLD" in the final trade decision."""
        final_state = dict(
            company_of_interest="AAPL",
            trade_date="2024-01-15",
            investment_debate_state=dict(
                judge_decision="HOLD - insufficient conviction", count=4
            ),
            risk_debate_state=dict(judge_decision="No action required", count=3),
            final_trade_decision="HOLD current position",
        )
        decision = final_state["final_trade_decision"]
        assert "HOLD" in decision
class TestTradingAgentsGraphValidation:
    """Ticker validation as exercised around ``TradingAgentsGraph``."""

    @patch("tradingagents.graph.trading_graph.ChatOpenAI")
    @patch("tradingagents.graph.trading_graph.set_config")
    def test_graph_validates_ticker_on_propagate(self, mock_set_config, mock_llm):
        """An invalid ticker is rejected with ``TickerValidationError``.

        NOTE(review): despite its name, this test never calls
        ``graph.propagate``; the hand-built ``graph`` below is not used by the
        assertion, which calls ``validate_ticker`` directly. Also,
        ``TradingAgentsGraph.__new__`` does not run ``__init__``, so patching
        ``__init__`` has no effect here. TODO: route the invalid ticker through
        ``propagate`` once the order of validation inside it is confirmed.
        """
        from tradingagents.validation import TickerValidationError, validate_ticker

        mock_llm.return_value = MagicMock()

        with patch.object(TradingAgentsGraph, "__init__", lambda x, **kwargs: None):
            # Bare instance with only the attributes set by hand.
            graph = TradingAgentsGraph.__new__(TradingAgentsGraph)
            graph.config = {"llm_provider": "openai"}
            graph.debug = False
            graph.ticker = None
            graph.log_states_dict = {}

            with pytest.raises(TickerValidationError):
                validate_ticker("INVALID123TICKER")

    def test_valid_ticker_formats(self):
        """Valid tickers are upper-cased/stripped; class-share separators survive."""
        from tradingagents.validation import validate_ticker

        cases = {
            "AAPL": "AAPL",
            "aapl": "AAPL",
            "BRK-B": "BRK-B",
            "BRK.A": "BRK.A",
            " MSFT ": "MSFT",
        }
        for raw, normalized in cases.items():
            assert validate_ticker(raw) == normalized

    def test_invalid_ticker_formats(self):
        """Empty, over-long and all-digit tickers raise ``TickerValidationError``."""
        from tradingagents.validation import TickerValidationError, validate_ticker

        for bad_ticker in ("", "TOOLONGTICKER", "123"):
            with pytest.raises(TickerValidationError):
                validate_ticker(bad_ticker)