# TradingAgents/tests/integration/test_conditional_logic.py
# (source listing metadata: 241 lines, 7.9 KiB, Python)

import pytest
from unittest.mock import MagicMock
from tradingagents.graph.conditional_logic import ConditionalLogic
from tradingagents.agents.utils.agent_states import InvestDebateState, RiskDebateState
class TestConditionalLogicAnalysts:
    """Routing after each analyst node.

    Each ``should_continue_<analyst>`` router returns its ``tools_<analyst>``
    route when the message it inspects carries tool calls, and its
    ``Msg Clear <Analyst>`` route otherwise.
    """

    # (router method name, route when tool calls are pending, route otherwise)
    _ROUTES = (
        ("should_continue_market", "tools_market", "Msg Clear Market"),
        ("should_continue_social", "tools_social", "Msg Clear Social"),
        ("should_continue_news", "tools_news", "Msg Clear News"),
        ("should_continue_fundamentals", "tools_fundamentals", "Msg Clear Fundamentals"),
    )

    def setup_method(self):
        self.logic = ConditionalLogic(max_debate_rounds=2, max_risk_discuss_rounds=2)

    @staticmethod
    def _state(*tool_calls_per_message):
        """Build a graph state whose ``messages`` are mocks, oldest first.

        Each positional argument becomes the ``tool_calls`` attribute of one
        message; it is set explicitly because an unset MagicMock attribute
        would be a (truthy) child mock.
        """
        return {
            "messages": [MagicMock(tool_calls=calls) for calls in tool_calls_per_message]
        }

    def test_should_continue_market_with_tool_calls(self):
        state = self._state([{"name": "get_stock_data"}])
        assert self.logic.should_continue_market(state) == "tools_market"

    def test_should_continue_market_without_tool_calls(self):
        state = self._state([])
        assert self.logic.should_continue_market(state) == "Msg Clear Market"

    def test_should_continue_social_with_tool_calls(self):
        state = self._state([{"name": "get_news"}])
        assert self.logic.should_continue_social(state) == "tools_social"

    def test_should_continue_social_without_tool_calls(self):
        state = self._state([])
        assert self.logic.should_continue_social(state) == "Msg Clear Social"

    def test_should_continue_news_with_tool_calls(self):
        state = self._state([{"name": "get_global_news"}])
        assert self.logic.should_continue_news(state) == "tools_news"

    def test_should_continue_news_without_tool_calls(self):
        state = self._state([])
        assert self.logic.should_continue_news(state) == "Msg Clear News"

    def test_should_continue_fundamentals_with_tool_calls(self):
        state = self._state([{"name": "get_balance_sheet"}])
        assert self.logic.should_continue_fundamentals(state) == "tools_fundamentals"

    def test_should_continue_fundamentals_without_tool_calls(self):
        state = self._state([])
        assert self.logic.should_continue_fundamentals(state) == "Msg Clear Fundamentals"

    def test_routes_on_latest_message_only(self):
        """Only the newest message's ``tool_calls`` decide the route.

        The single-message fixtures above cannot distinguish "newest message"
        from "any message"; a two-message history with the tool calls on
        opposite ends does.
        """
        older_call = self._state([{"name": "get_stock_data"}], [])
        newer_call = self._state([], [{"name": "get_stock_data"}])
        for method_name, tools_route, clear_route in self._ROUTES:
            router = getattr(self.logic, method_name)
            assert router(older_call) == clear_route, method_name
            assert router(newer_call) == tools_route, method_name
class TestConditionalLogicDebate:
    """Routing of the bull/bear investment debate via ``should_continue_debate``.

    With ``max_debate_rounds=2`` these tests treat ``count == 4`` as the
    hand-off to the Research Manager. Below that, the next speaker is chosen
    from the prefix of ``current_response`` ("Bull..." -> Bear, otherwise Bull).
    """

    def setup_method(self):
        self.logic = ConditionalLogic(max_debate_rounds=2, max_risk_discuss_rounds=2)

    @staticmethod
    def _state(current_response, count):
        """Return a graph state holding an ``InvestDebateState`` with empty histories."""
        return {
            "investment_debate_state": InvestDebateState(
                bull_history="",
                bear_history="",
                history="",
                current_response=current_response,
                judge_decision="",
                count=count,
            )
        }

    def test_should_continue_debate_to_bear(self):
        state = self._state("Bull: I think we should buy", count=1)
        assert self.logic.should_continue_debate(state) == "Bear Researcher"

    def test_should_continue_debate_to_bull(self):
        state = self._state("Bear: I disagree", count=2)
        assert self.logic.should_continue_debate(state) == "Bull Researcher"

    def test_should_continue_debate_one_turn_before_max_rounds(self):
        # count=3 is one short of the limit pinned by the max-rounds test, so the
        # debate must keep alternating rather than end one turn early.
        state = self._state("Bull: one more point", count=3)
        assert self.logic.should_continue_debate(state) == "Bear Researcher"

    def test_should_continue_debate_to_manager_max_rounds(self):
        state = self._state("Bull: Final argument", count=4)
        assert self.logic.should_continue_debate(state) == "Research Manager"

    def test_should_continue_debate_past_max_rounds(self):
        # An overshooting count must still end the debate; this catches a limit
        # check written with ``==`` instead of ``>=``.
        state = self._state("Bear: still arguing", count=5)
        assert self.logic.should_continue_debate(state) == "Research Manager"

    def test_debate_rounds_configurable(self):
        logic_one_round = ConditionalLogic(max_debate_rounds=1)
        state = self._state("Bull: argument", count=2)
        assert logic_one_round.should_continue_debate(state) == "Research Manager"
class TestConditionalLogicRiskAnalysis:
    """Routing of the risky/safe/neutral discussion via ``should_continue_risk_analysis``.

    With ``max_risk_discuss_rounds=2`` these tests treat ``count == 6`` as the
    hand-off to the Risk Judge. Below that, the next speaker follows
    ``latest_speaker`` in the cycle Risky -> Safe -> Neutral -> Risky.
    """

    def setup_method(self):
        self.logic = ConditionalLogic(max_debate_rounds=2, max_risk_discuss_rounds=2)

    @staticmethod
    def _state(latest_speaker, count):
        """Return a graph state holding a ``RiskDebateState`` with empty text fields."""
        return {
            "risk_debate_state": RiskDebateState(
                risky_history="",
                safe_history="",
                neutral_history="",
                history="",
                latest_speaker=latest_speaker,
                current_risky_response="",
                current_safe_response="",
                current_neutral_response="",
                judge_decision="",
                count=count,
            )
        }

    def test_should_continue_risk_to_safe(self):
        state = self._state("Risky Analyst", count=1)
        assert self.logic.should_continue_risk_analysis(state) == "Safe Analyst"

    def test_should_continue_risk_to_neutral(self):
        state = self._state("Safe Analyst", count=2)
        assert self.logic.should_continue_risk_analysis(state) == "Neutral Analyst"

    def test_should_continue_risk_to_risky(self):
        state = self._state("Neutral Analyst", count=3)
        assert self.logic.should_continue_risk_analysis(state) == "Risky Analyst"

    def test_should_continue_risk_one_turn_before_max_rounds(self):
        # count=5 is one short of the limit pinned by the max-rounds test, so the
        # discussion must continue rather than end one turn early.
        state = self._state("Neutral Analyst", count=5)
        assert self.logic.should_continue_risk_analysis(state) == "Risky Analyst"

    def test_should_continue_risk_to_judge_max_rounds(self):
        state = self._state("Risky Analyst", count=6)
        assert self.logic.should_continue_risk_analysis(state) == "Risk Judge"

    def test_should_continue_risk_past_max_rounds(self):
        # An overshooting count must still hand off to the judge; this catches a
        # limit check written with ``==`` instead of ``>=``.
        state = self._state("Safe Analyst", count=7)
        assert self.logic.should_continue_risk_analysis(state) == "Risk Judge"

    def test_risk_rounds_configurable(self):
        logic_one_round = ConditionalLogic(max_risk_discuss_rounds=1)
        state = self._state("Neutral Analyst", count=3)
        assert logic_one_round.should_continue_risk_analysis(state) == "Risk Judge"