diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/test_agent_states.py b/tests/integration/test_agent_states.py new file mode 100644 index 00000000..34db8dfb --- /dev/null +++ b/tests/integration/test_agent_states.py @@ -0,0 +1,117 @@ +import pytest +from tradingagents.agents.utils.agent_states import ( + AgentState, + InvestDebateState, + RiskDebateState, +) + + +class TestInvestDebateState: + def test_create_empty_state(self): + state = InvestDebateState( + bull_history="", + bear_history="", + history="", + current_response="", + judge_decision="", + count=0, + ) + assert state["count"] == 0 + assert state["history"] == "" + + def test_create_state_with_history(self): + state = InvestDebateState( + bull_history="Bull argues for buying", + bear_history="Bear argues for selling", + history="Bull: buy\nBear: sell", + current_response="Bull: I maintain my position", + judge_decision="", + count=2, + ) + assert state["count"] == 2 + assert "Bull" in state["bull_history"] + assert "Bear" in state["bear_history"] + + def test_state_as_dict(self): + state = InvestDebateState( + bull_history="test", + bear_history="test", + history="test", + current_response="test", + judge_decision="BUY", + count=1, + ) + assert isinstance(state, dict) + assert state["judge_decision"] == "BUY" + + +class TestRiskDebateState: + def test_create_empty_state(self): + state = RiskDebateState( + risky_history="", + safe_history="", + neutral_history="", + history="", + latest_speaker="", + current_risky_response="", + current_safe_response="", + current_neutral_response="", + judge_decision="", + count=0, + ) + assert state["count"] == 0 + assert state["latest_speaker"] == "" + + def test_create_state_with_speakers(self): + state = RiskDebateState( + risky_history="Risky: Go all in!", + safe_history="Safe: Be cautious", + neutral_history="Neutral: Consider both", + history="Discussion ongoing", + latest_speaker="Risky Analyst", + current_risky_response="Go all in!", + current_safe_response="", + current_neutral_response="", + judge_decision="", + count=1, + ) + assert state["latest_speaker"] == "Risky Analyst" + assert "Go all in" in state["risky_history"] + + def test_state_tracks_all_speakers(self): + state = RiskDebateState( + risky_history="Risky view", + safe_history="Safe view", + neutral_history="Neutral view", + history="Full debate", + latest_speaker="Neutral Analyst", + current_risky_response="risky", + current_safe_response="safe", + current_neutral_response="neutral", + judge_decision="APPROVED", + count=3, + ) + assert state["count"] == 3 + assert state["judge_decision"] == "APPROVED" + + +class TestAgentStateStructure: + def test_agent_state_has_required_fields(self): + required_fields = [ + "company_of_interest", + "trade_date", + "sender", + "market_report", + "sentiment_report", + "news_report", + "fundamentals_report", + "investment_debate_state", + "investment_plan", + "trader_investment_plan", + "risk_debate_state", + "final_trade_decision", + ] + + annotations = AgentState.__annotations__ + for field in required_fields: + assert field in annotations, f"Missing field: {field}" diff --git a/tests/integration/test_conditional_logic.py b/tests/integration/test_conditional_logic.py new file mode 100644 index 00000000..47c996c4 --- /dev/null +++ b/tests/integration/test_conditional_logic.py @@ -0,0 +1,240 @@ +import pytest +from unittest.mock import MagicMock +from tradingagents.graph.conditional_logic import ConditionalLogic +from tradingagents.agents.utils.agent_states import InvestDebateState, RiskDebateState + + +class TestConditionalLogicAnalysts: + def setup_method(self): + self.logic = ConditionalLogic(max_debate_rounds=2, max_risk_discuss_rounds=2) + + def test_should_continue_market_with_tool_calls(self): + mock_message = MagicMock() + mock_message.tool_calls = [{"name": "get_stock_data"}] + state = {"messages": [mock_message]} + + result = self.logic.should_continue_market(state) + assert result == "tools_market" + + def test_should_continue_market_without_tool_calls(self): + mock_message = MagicMock() + mock_message.tool_calls = [] + state = {"messages": [mock_message]} + + result = self.logic.should_continue_market(state) + assert result == "Msg Clear Market" + + def test_should_continue_social_with_tool_calls(self): + mock_message = MagicMock() + mock_message.tool_calls = [{"name": "get_news"}] + state = {"messages": [mock_message]} + + result = self.logic.should_continue_social(state) + assert result == "tools_social" + + def test_should_continue_social_without_tool_calls(self): + mock_message = MagicMock() + mock_message.tool_calls = [] + state = {"messages": [mock_message]} + + result = self.logic.should_continue_social(state) + assert result == "Msg Clear Social" + + def test_should_continue_news_with_tool_calls(self): + mock_message = MagicMock() + mock_message.tool_calls = [{"name": "get_global_news"}] + state = {"messages": [mock_message]} + + result = self.logic.should_continue_news(state) + assert result == "tools_news" + + def test_should_continue_news_without_tool_calls(self): + mock_message = MagicMock() + mock_message.tool_calls = [] + state = {"messages": [mock_message]} + + result = self.logic.should_continue_news(state) + assert result == "Msg Clear News" + + def test_should_continue_fundamentals_with_tool_calls(self): + mock_message = MagicMock() + mock_message.tool_calls = [{"name": "get_balance_sheet"}] + state = {"messages": [mock_message]} + + result = self.logic.should_continue_fundamentals(state) + assert result == "tools_fundamentals" + + def test_should_continue_fundamentals_without_tool_calls(self): + mock_message = MagicMock() + mock_message.tool_calls = [] + state = {"messages": [mock_message]} + + result = self.logic.should_continue_fundamentals(state) + assert result == "Msg Clear Fundamentals" + + +class TestConditionalLogicDebate: + def setup_method(self): + self.logic = ConditionalLogic(max_debate_rounds=2, max_risk_discuss_rounds=2) + + def test_should_continue_debate_to_bear(self): + state = { + "investment_debate_state": InvestDebateState( + bull_history="", + bear_history="", + history="", + current_response="Bull: I think we should buy", + judge_decision="", + count=1, + ) + } + + result = self.logic.should_continue_debate(state) + assert result == "Bear Researcher" + + def test_should_continue_debate_to_bull(self): + state = { + "investment_debate_state": InvestDebateState( + bull_history="", + bear_history="", + history="", + current_response="Bear: I disagree", + judge_decision="", + count=2, + ) + } + + result = self.logic.should_continue_debate(state) + assert result == "Bull Researcher" + + def test_should_continue_debate_to_manager_max_rounds(self): + state = { + "investment_debate_state": InvestDebateState( + bull_history="", + bear_history="", + history="", + current_response="Bull: Final argument", + judge_decision="", + count=4, + ) + } + + result = self.logic.should_continue_debate(state) + assert result == "Research Manager" + + def test_debate_rounds_configurable(self): + logic_one_round = ConditionalLogic(max_debate_rounds=1) + state = { + "investment_debate_state": InvestDebateState( + bull_history="", + bear_history="", + history="", + current_response="Bull: argument", + judge_decision="", + count=2, + ) + } + + result = logic_one_round.should_continue_debate(state) + assert result == "Research Manager" + + +class TestConditionalLogicRiskAnalysis: + def setup_method(self): + self.logic = ConditionalLogic(max_debate_rounds=2, max_risk_discuss_rounds=2) + + def test_should_continue_risk_to_safe(self): + state = { + "risk_debate_state": RiskDebateState( + risky_history="", + safe_history="", + neutral_history="", + history="", + latest_speaker="Risky Analyst", + current_risky_response="", + current_safe_response="", + current_neutral_response="", + judge_decision="", + count=1, + ) + } + + result = self.logic.should_continue_risk_analysis(state) + assert result == "Safe Analyst" + + def test_should_continue_risk_to_neutral(self): + state = { + "risk_debate_state": RiskDebateState( + risky_history="", + safe_history="", + neutral_history="", + history="", + latest_speaker="Safe Analyst", + current_risky_response="", + current_safe_response="", + current_neutral_response="", + judge_decision="", + count=2, + ) + } + + result = self.logic.should_continue_risk_analysis(state) + assert result == "Neutral Analyst" + + def test_should_continue_risk_to_risky(self): + state = { + "risk_debate_state": RiskDebateState( + risky_history="", + safe_history="", + neutral_history="", + history="", + latest_speaker="Neutral Analyst", + current_risky_response="", + current_safe_response="", + current_neutral_response="", + judge_decision="", + count=3, + ) + } + + result = self.logic.should_continue_risk_analysis(state) + assert result == "Risky Analyst" + + def test_should_continue_risk_to_judge_max_rounds(self): + state = { + "risk_debate_state": RiskDebateState( + risky_history="", + safe_history="", + neutral_history="", + history="", + latest_speaker="Risky Analyst", + current_risky_response="", + current_safe_response="", + current_neutral_response="", + judge_decision="", + count=6, + ) + } + + result = self.logic.should_continue_risk_analysis(state) + assert result == "Risk Judge" + + def test_risk_rounds_configurable(self): + logic_one_round = ConditionalLogic(max_risk_discuss_rounds=1) + state = { + "risk_debate_state": RiskDebateState( + risky_history="", + safe_history="", + neutral_history="", + history="", + latest_speaker="Neutral Analyst", + current_risky_response="", + current_safe_response="", + current_neutral_response="", + judge_decision="", + count=3, + ) + } + + result = logic_one_round.should_continue_risk_analysis(state) + assert result == "Risk Judge" diff --git a/tests/integration/test_graph_setup.py b/tests/integration/test_graph_setup.py new file mode 100644 index 00000000..2c3a0b4f --- /dev/null +++ b/tests/integration/test_graph_setup.py @@ -0,0 +1,145 @@ +import pytest +from unittest.mock import MagicMock, patch +from tradingagents.graph.setup import GraphSetup +from tradingagents.graph.conditional_logic import ConditionalLogic + + +class TestGraphSetup: + def setup_method(self): + self.mock_llm = MagicMock() + self.mock_tool_nodes = { + "market": MagicMock(), + "social": MagicMock(), + "news": MagicMock(), + "fundamentals": MagicMock(), + } + self.mock_memory = MagicMock() + self.conditional_logic = ConditionalLogic() + + def create_graph_setup(self): + return GraphSetup( + quick_thinking_llm=self.mock_llm, + deep_thinking_llm=self.mock_llm, + tool_nodes=self.mock_tool_nodes, + bull_memory=self.mock_memory, + bear_memory=self.mock_memory, + trader_memory=self.mock_memory, + invest_judge_memory=self.mock_memory, + risk_manager_memory=self.mock_memory, + conditional_logic=self.conditional_logic, + ) + + def test_graph_setup_initialization(self): + setup = self.create_graph_setup() + + assert setup.quick_thinking_llm == self.mock_llm + assert setup.deep_thinking_llm == self.mock_llm + assert setup.tool_nodes == self.mock_tool_nodes + assert setup.conditional_logic == self.conditional_logic + + def test_setup_graph_with_all_analysts(self): + setup = self.create_graph_setup() + + with patch("tradingagents.graph.setup.create_market_analyst") as mock_market, \ + patch("tradingagents.graph.setup.create_social_media_analyst") as mock_social, \ + patch("tradingagents.graph.setup.create_news_analyst") as mock_news, \ + patch("tradingagents.graph.setup.create_fundamentals_analyst") as mock_fund, \ + patch("tradingagents.graph.setup.create_bull_researcher") as mock_bull, \ + patch("tradingagents.graph.setup.create_bear_researcher") as mock_bear, \ + patch("tradingagents.graph.setup.create_research_manager") as mock_rm, \ + patch("tradingagents.graph.setup.create_trader") as mock_trader, \ + patch("tradingagents.graph.setup.create_risky_debator") as mock_risky, \ + patch("tradingagents.graph.setup.create_neutral_debator") as mock_neutral, \ + patch("tradingagents.graph.setup.create_safe_debator") as mock_safe, \ + patch("tradingagents.graph.setup.create_risk_manager") as mock_risk_mgr: + + mock_market.return_value = MagicMock() + mock_social.return_value = MagicMock() + mock_news.return_value = MagicMock() + mock_fund.return_value = MagicMock() + mock_bull.return_value = MagicMock() + mock_bear.return_value = MagicMock() + mock_rm.return_value = MagicMock() + mock_trader.return_value = MagicMock() + mock_risky.return_value = MagicMock() + mock_neutral.return_value = MagicMock() + mock_safe.return_value = MagicMock() + mock_risk_mgr.return_value = MagicMock() + + graph = setup.setup_graph(["market", "social", "news", "fundamentals"]) + + mock_market.assert_called_once() + mock_social.assert_called_once() + mock_news.assert_called_once() + mock_fund.assert_called_once() + mock_bull.assert_called_once() + mock_bear.assert_called_once() + mock_rm.assert_called_once() + mock_trader.assert_called_once() + + def test_setup_graph_with_single_analyst(self): + setup = self.create_graph_setup() + + with patch("tradingagents.graph.setup.create_market_analyst") as mock_market, \ + patch("tradingagents.graph.setup.create_social_media_analyst") as mock_social, \ + patch("tradingagents.graph.setup.create_news_analyst") as mock_news, \ + patch("tradingagents.graph.setup.create_fundamentals_analyst") as mock_fund, \ + patch("tradingagents.graph.setup.create_bull_researcher") as mock_bull, \ + patch("tradingagents.graph.setup.create_bear_researcher") as mock_bear, \ + patch("tradingagents.graph.setup.create_research_manager") as mock_rm, \ + patch("tradingagents.graph.setup.create_trader") as mock_trader, \ + patch("tradingagents.graph.setup.create_risky_debator") as mock_risky, \ + patch("tradingagents.graph.setup.create_neutral_debator") as mock_neutral, \ + patch("tradingagents.graph.setup.create_safe_debator") as mock_safe, \ + patch("tradingagents.graph.setup.create_risk_manager") as mock_risk_mgr: + + mock_market.return_value = MagicMock() + mock_bull.return_value = MagicMock() + mock_bear.return_value = MagicMock() + mock_rm.return_value = MagicMock() + mock_trader.return_value = MagicMock() + mock_risky.return_value = MagicMock() + mock_neutral.return_value = MagicMock() + mock_safe.return_value = MagicMock() + mock_risk_mgr.return_value = MagicMock() + + graph = setup.setup_graph(["market"]) + + mock_market.assert_called_once() + mock_social.assert_not_called() + mock_news.assert_not_called() + mock_fund.assert_not_called() + + def test_setup_graph_empty_analysts_raises(self): + setup = self.create_graph_setup() + + with pytest.raises(ValueError, match="no analysts selected"): + setup.setup_graph([]) + + def test_setup_graph_returns_compiled_graph(self): + setup = self.create_graph_setup() + + with patch("tradingagents.graph.setup.create_market_analyst") as mock_market, \ + patch("tradingagents.graph.setup.create_bull_researcher") as mock_bull, \ + patch("tradingagents.graph.setup.create_bear_researcher") as mock_bear, \ + patch("tradingagents.graph.setup.create_research_manager") as mock_rm, \ + patch("tradingagents.graph.setup.create_trader") as mock_trader, \ + patch("tradingagents.graph.setup.create_risky_debator") as mock_risky, \ + patch("tradingagents.graph.setup.create_neutral_debator") as mock_neutral, \ + patch("tradingagents.graph.setup.create_safe_debator") as mock_safe, \ + patch("tradingagents.graph.setup.create_risk_manager") as mock_risk_mgr: + + mock_market.return_value = MagicMock() + mock_bull.return_value = MagicMock() + mock_bear.return_value = MagicMock() + mock_rm.return_value = MagicMock() + mock_trader.return_value = MagicMock() + mock_risky.return_value = MagicMock() + mock_neutral.return_value = MagicMock() + mock_safe.return_value = MagicMock() + mock_risk_mgr.return_value = MagicMock() + + graph = setup.setup_graph(["market"]) + + assert graph is not None + assert hasattr(graph, "invoke") or hasattr(graph, "stream") diff --git a/tests/integration/test_propagation.py b/tests/integration/test_propagation.py new file mode 100644 index 00000000..8f1c6a83 --- /dev/null +++ b/tests/integration/test_propagation.py @@ -0,0 +1,75 @@ +import pytest +from datetime import date +from tradingagents.graph.propagation import Propagator +from tradingagents.agents.utils.agent_states import InvestDebateState, RiskDebateState + + +class TestPropagator: + def setup_method(self): + self.propagator = Propagator(max_recur_limit=50) + + def test_create_initial_state_basic(self): + state = self.propagator.create_initial_state("AAPL", "2024-01-15") + + assert state["company_of_interest"] == "AAPL" + assert state["trade_date"] == "2024-01-15" + assert state["market_report"] == "" + assert state["fundamentals_report"] == "" + assert state["sentiment_report"] == "" + assert state["news_report"] == "" + + def test_create_initial_state_messages(self): + state = self.propagator.create_initial_state("MSFT", "2024-01-15") + + assert "messages" in state + assert len(state["messages"]) == 1 + assert state["messages"][0] == ("human", "MSFT") + + def test_create_initial_state_debate_states(self): + state = self.propagator.create_initial_state("GOOGL", "2024-01-15") + + assert "investment_debate_state" in state + invest_state = state["investment_debate_state"] + assert invest_state["history"] == "" + assert invest_state["current_response"] == "" + assert invest_state["count"] == 0 + + assert "risk_debate_state" in state + risk_state = state["risk_debate_state"] + assert risk_state["history"] == "" + assert risk_state["count"] == 0 + + def test_create_initial_state_with_date_object(self): + trade_date = date(2024, 1, 15) + state = self.propagator.create_initial_state("TSLA", trade_date) + + assert state["trade_date"] == "2024-01-15" + + def test_get_graph_args(self): + args = self.propagator.get_graph_args() + + assert "stream_mode" in args + assert args["stream_mode"] == "values" + assert "config" in args + assert "recursion_limit" in args["config"] + assert args["config"]["recursion_limit"] == 50 + + def test_custom_recursion_limit(self): + custom_propagator = Propagator(max_recur_limit=200) + args = custom_propagator.get_graph_args() + + assert args["config"]["recursion_limit"] == 200 + + def test_state_is_dict(self): + state = self.propagator.create_initial_state("NVDA", "2024-01-15") + assert isinstance(state, dict) + + def test_multiple_states_independent(self): + state1 = self.propagator.create_initial_state("AAPL", "2024-01-15") + state2 = self.propagator.create_initial_state("MSFT", "2024-01-16") + + assert state1["company_of_interest"] != state2["company_of_interest"] + assert state1["trade_date"] != state2["trade_date"] + + state1["market_report"] = "Modified" + assert state2["market_report"] == "" diff --git a/tests/integration/test_workflow_e2e.py b/tests/integration/test_workflow_e2e.py new file mode 100644 index 00000000..5b0c4e96 --- /dev/null +++ b/tests/integration/test_workflow_e2e.py @@ -0,0 +1,258 @@ +import pytest +from unittest.mock import MagicMock, patch, PropertyMock +from datetime import date +from langchain_core.messages import AIMessage, HumanMessage + +from tradingagents.graph.trading_graph import TradingAgentsGraph +from tradingagents.graph.propagation import Propagator +from tradingagents.graph.conditional_logic import ConditionalLogic +from tradingagents.agents.utils.agent_states import InvestDebateState, RiskDebateState + + +class TestWorkflowStateTransitions: + + def test_initial_state_structure(self): + propagator = Propagator() + state = propagator.create_initial_state("AAPL", "2024-01-15") + + assert "messages" in state + assert "company_of_interest" in state + assert "trade_date" in state + assert "investment_debate_state" in state + assert "risk_debate_state" in state + assert "market_report" in state + assert "sentiment_report" in state + assert "news_report" in state + assert "fundamentals_report" in state + + def test_market_analyst_state_update(self): + initial_state = { + "messages": [HumanMessage(content="AAPL")], + "company_of_interest": "AAPL", + "trade_date": "2024-01-15", + "market_report": "", + } + + updated_state = { + **initial_state, + "market_report": "Technical analysis shows bullish trend", + "messages": [ + HumanMessage(content="AAPL"), + AIMessage(content="Technical analysis shows bullish trend"), + ], + } + + assert updated_state["market_report"] != "" + assert len(updated_state["messages"]) == 2 + + def test_debate_state_progression(self): + logic = ConditionalLogic(max_debate_rounds=2) + + state_round_0 = { + "investment_debate_state": InvestDebateState( + bull_history="", + bear_history="", + history="", + current_response="", + judge_decision="", + count=0, + ) + } + + state_round_1 = { + "investment_debate_state": InvestDebateState( + bull_history="Bull: I see growth potential", + bear_history="", + history="Bull: I see growth potential", + current_response="Bull: I see growth potential", + judge_decision="", + count=1, + ) + } + + state_round_2 = { + "investment_debate_state": InvestDebateState( + bull_history="Bull: I see growth potential", + bear_history="Bear: Market risks are high", + history="Bull: I see growth potential\nBear: Market risks are high", + current_response="Bear: Market risks are high", + judge_decision="", + count=2, + ) + } + + assert logic.should_continue_debate(state_round_1) == "Bear Researcher" + assert logic.should_continue_debate(state_round_2) == "Bull Researcher" + + def test_risk_analysis_state_progression(self): + logic = ConditionalLogic(max_risk_discuss_rounds=1) + + state_risky = { + "risk_debate_state": RiskDebateState( + risky_history="Go for it!", + safe_history="", + neutral_history="", + history="Risky: Go for it!", + latest_speaker="Risky Analyst", + current_risky_response="Go for it!", + current_safe_response="", + current_neutral_response="", + judge_decision="", + count=1, + ) + } + + state_safe = { + "risk_debate_state": RiskDebateState( + risky_history="Go for it!", + safe_history="Be cautious", + neutral_history="", + history="Risky: Go for it!\nSafe: Be cautious", + latest_speaker="Safe Analyst", + current_risky_response="Go for it!", + current_safe_response="Be cautious", + current_neutral_response="", + judge_decision="", + count=2, + ) + } + + state_neutral = { + "risk_debate_state": RiskDebateState( + risky_history="Go for it!", + safe_history="Be cautious", + neutral_history="Balance both views", + history="Full discussion", + latest_speaker="Neutral Analyst", + current_risky_response="Go for it!", + current_safe_response="Be cautious", + current_neutral_response="Balance both views", + judge_decision="", + count=3, + ) + } + + assert logic.should_continue_risk_analysis(state_risky) == "Safe Analyst" + assert logic.should_continue_risk_analysis(state_safe) == "Neutral Analyst" + assert logic.should_continue_risk_analysis(state_neutral) == "Risk Judge" + + +class TestWorkflowEndToEnd: + + def test_final_state_has_all_reports(self): + final_state = { + "company_of_interest": "AAPL", + "trade_date": "2024-01-15", + "market_report": "Bullish technical indicators", + "sentiment_report": "Positive social sentiment", + "news_report": "Favorable news coverage", + "fundamentals_report": "Strong financials", + "investment_debate_state": { + "bull_history": "Bull arguments", + "bear_history": "Bear arguments", + "history": "Full debate", + "current_response": "Final bull response", + "judge_decision": "BUY recommendation", + "count": 4, + }, + "trader_investment_plan": "Buy 100 shares at market open", + "risk_debate_state": { + "risky_history": "High conviction", + "safe_history": "Moderate position size", + "neutral_history": "Balanced view", + "history": "Risk discussion", + "latest_speaker": "Risk Judge", + "judge_decision": "APPROVED with position limits", + "count": 3, + }, + "final_trade_decision": "BUY 100 shares AAPL at market open", + } + + assert final_state["market_report"] != "" + assert final_state["sentiment_report"] != "" + assert final_state["news_report"] != "" + assert final_state["fundamentals_report"] != "" + assert "BUY" in final_state["investment_debate_state"]["judge_decision"] + assert "APPROVED" in final_state["risk_debate_state"]["judge_decision"] + assert "BUY" in final_state["final_trade_decision"] + + def test_workflow_handles_sell_decision(self): + final_state = { + "company_of_interest": "AAPL", + "trade_date": "2024-01-15", + "market_report": "Bearish technical indicators", + "sentiment_report": "Negative sentiment", + "news_report": "Bad news", + "fundamentals_report": "Weak financials", + "investment_debate_state": { + "judge_decision": "SELL recommendation", + "count": 4, + }, + "risk_debate_state": { + "judge_decision": "APPROVED", + "count": 3, + }, + "final_trade_decision": "SELL position in AAPL", + } + + assert "SELL" in final_state["final_trade_decision"] + + def test_workflow_handles_hold_decision(self): + final_state = { + "company_of_interest": "AAPL", + "trade_date": "2024-01-15", + "investment_debate_state": { + "judge_decision": "HOLD - insufficient conviction", + "count": 4, + }, + "risk_debate_state": { + "judge_decision": "No action required", + "count": 3, + }, + "final_trade_decision": "HOLD current position", + } + + assert "HOLD" in final_state["final_trade_decision"] + + +class TestTradingAgentsGraphValidation: + + @patch("tradingagents.graph.trading_graph.ChatOpenAI") + @patch("tradingagents.graph.trading_graph.set_config") + def test_graph_validates_ticker_on_propagate(self, mock_set_config, mock_llm): + from tradingagents.validation import TickerValidationError + + mock_llm_instance = MagicMock() + mock_llm.return_value = mock_llm_instance + + with patch.object(TradingAgentsGraph, "__init__", lambda x, **kwargs: None): + graph = TradingAgentsGraph.__new__(TradingAgentsGraph) + graph.config = {"llm_provider": "openai"} + graph.debug = False + graph.ticker = None + graph.log_states_dict = {} + + from tradingagents.validation import validate_ticker + with pytest.raises(TickerValidationError): + validate_ticker("INVALID123TICKER") + + def test_valid_ticker_formats(self): + from tradingagents.validation import validate_ticker + + assert validate_ticker("AAPL") == "AAPL" + assert validate_ticker("aapl") == "AAPL" + assert validate_ticker("BRK-B") == "BRK-B" + assert validate_ticker("BRK.A") == "BRK.A" + assert validate_ticker(" MSFT ") == "MSFT" + + def test_invalid_ticker_formats(self): + from tradingagents.validation import validate_ticker, TickerValidationError + + with pytest.raises(TickerValidationError): + validate_ticker("") + + with pytest.raises(TickerValidationError): + validate_ticker("TOOLONGTICKER") + + with pytest.raises(TickerValidationError): + validate_ticker("123")