241 lines
8.7 KiB
Python
241 lines
8.7 KiB
Python
"""Unit tests for conditional logic."""
|
|
|
|
import pytest
|
|
from unittest.mock import MagicMock
|
|
|
|
from tradingagents.graph.conditional_logic import ConditionalLogic
|
|
|
|
|
|
class TestConditionalLogic:
    """Shared fixtures for the ConditionalLogic routing tests.

    The ``TestShouldContinue*`` classes in this module subclass this class to
    inherit the fixtures below. It defines no test methods of its own.
    """

    @staticmethod
    def _message(tool_calls):
        """Return a mock message whose ``tool_calls`` attribute is *tool_calls*."""
        msg = MagicMock()
        msg.tool_calls = tool_calls
        return msg

    @pytest.fixture
    def logic(self):
        """Create a ConditionalLogic instance with 1 debate round and 1 risk round."""
        return ConditionalLogic(max_debate_rounds=1, max_risk_discuss_rounds=1)

    @pytest.fixture
    def logic_extended(self):
        """Create a ConditionalLogic instance with 3 debate rounds and 2 risk rounds."""
        return ConditionalLogic(max_debate_rounds=3, max_risk_discuss_rounds=2)

    @pytest.fixture
    def state_with_tool_call(self):
        """Create a state whose last message carries a tool call.

        A message WITHOUT tool calls is placed before it, so a router that
        inspected the first message instead of the last would fail.
        """
        return {
            "messages": [
                self._message([]),
                self._message([{"name": "get_stock_data"}]),
            ]
        }

    @pytest.fixture
    def state_without_tool_call(self):
        """Create a state whose last message carries no tool calls.

        A message WITH a tool call is placed before it, so a router that
        looked at the first message (or at any message) would fail.
        """
        return {
            "messages": [
                self._message([{"name": "get_stock_data"}]),
                self._message([]),
            ]
        }
|
|
|
|
|
|
class TestShouldContinueMarket(TestConditionalLogic):
    """Routing checks for ``ConditionalLogic.should_continue_market``."""

    @pytest.mark.unit
    def test_returns_tools_market_with_tool_call(self, logic, state_with_tool_call):
        """A pending tool call sends the market analyst to ``tools_market``."""
        assert logic.should_continue_market(state_with_tool_call) == "tools_market"

    @pytest.mark.unit
    def test_returns_msg_clear_without_tool_call(self, logic, state_without_tool_call):
        """With no tool call pending, routing continues to ``Msg Clear Market``."""
        expected = "Msg Clear Market"
        assert logic.should_continue_market(state_without_tool_call) == expected
|
|
|
|
|
|
class TestShouldContinueSocial(TestConditionalLogic):
    """Routing checks for ``ConditionalLogic.should_continue_social``."""

    @pytest.mark.unit
    def test_returns_tools_social_with_tool_call(self, logic, state_with_tool_call):
        """A pending tool call sends the social analyst to ``tools_social``."""
        assert logic.should_continue_social(state_with_tool_call) == "tools_social"

    @pytest.mark.unit
    def test_returns_msg_clear_without_tool_call(self, logic, state_without_tool_call):
        """With no tool call pending, routing continues to ``Msg Clear Social``."""
        expected = "Msg Clear Social"
        assert logic.should_continue_social(state_without_tool_call) == expected
|
|
|
|
|
|
class TestShouldContinueNews(TestConditionalLogic):
    """Routing checks for ``ConditionalLogic.should_continue_news``."""

    @pytest.mark.unit
    def test_returns_tools_news_with_tool_call(self, logic, state_with_tool_call):
        """A pending tool call sends the news analyst to ``tools_news``."""
        assert logic.should_continue_news(state_with_tool_call) == "tools_news"

    @pytest.mark.unit
    def test_returns_msg_clear_without_tool_call(self, logic, state_without_tool_call):
        """With no tool call pending, routing continues to ``Msg Clear News``."""
        expected = "Msg Clear News"
        assert logic.should_continue_news(state_without_tool_call) == expected
|
|
|
|
|
|
class TestShouldContinueFundamentals(TestConditionalLogic):
    """Routing checks for ``ConditionalLogic.should_continue_fundamentals``."""

    @pytest.mark.unit
    def test_returns_tools_fundamentals_with_tool_call(self, logic, state_with_tool_call):
        """A pending tool call sends the fundamentals analyst to ``tools_fundamentals``."""
        next_node = logic.should_continue_fundamentals(state_with_tool_call)
        assert next_node == "tools_fundamentals"

    @pytest.mark.unit
    def test_returns_msg_clear_without_tool_call(self, logic, state_without_tool_call):
        """With no tool call pending, routing continues to ``Msg Clear Fundamentals``."""
        next_node = logic.should_continue_fundamentals(state_without_tool_call)
        assert next_node == "Msg Clear Fundamentals"
|
|
|
|
|
|
class TestShouldContinueDebate(TestConditionalLogic):
    """Tests for should_continue_debate method.

    These cases treat ``2 * max_debate_rounds`` as the ``count`` at which the
    debate hands off to the Research Manager. Below that limit, a response
    starting with "Bull" routes to the Bear Researcher, and any other
    response routes to the Bull Researcher.
    """

    @staticmethod
    def _state(count, current_response):
        """Build a state holding only the investment_debate_state fields these tests set."""
        return {
            "investment_debate_state": {
                "count": count,
                "current_response": current_response,
            }
        }

    @pytest.mark.unit
    def test_returns_research_manager_at_max_rounds(self, logic):
        """Test that debate ends exactly at the limit (2 * max_debate_rounds = 2)."""
        state = self._state(2, "Bull Analyst: Buy signal")
        assert logic.should_continue_debate(state) == "Research Manager"

    @pytest.mark.unit
    def test_returns_research_manager_past_max_rounds(self, logic):
        """Test that debate also ends when count overshoots the limit (4 > 2)."""
        state = self._state(4, "Bull Analyst: Buy signal")
        assert logic.should_continue_debate(state) == "Research Manager"

    @pytest.mark.unit
    def test_returns_bear_when_bull_speaks(self, logic):
        """Test that Bull speaker routes to Bear."""
        state = self._state(1, "Bull Analyst: Strong buy opportunity")
        assert logic.should_continue_debate(state) == "Bear Researcher"

    @pytest.mark.unit
    def test_returns_bull_when_not_bull(self, logic):
        """Test that a non-Bull (here: Bear) speaker routes to Bull."""
        state = self._state(1, "Bear Analyst: High risk warning")
        assert logic.should_continue_debate(state) == "Bull Researcher"

    @pytest.mark.unit
    def test_extended_debate_rounds(self, logic_extended):
        """Test that debate continues one step below the extended limit (2 * 3 = 6)."""
        state = self._state(5, "Bull Analyst: Buy")
        assert logic_extended.should_continue_debate(state) == "Bear Researcher"

    @pytest.mark.unit
    def test_extended_debate_ends_at_max(self, logic_extended):
        """Test that extended debate ends exactly at the limit (2 * 3 = 6)."""
        state = self._state(6, "Bull Analyst: Buy")
        assert logic_extended.should_continue_debate(state) == "Research Manager"
|
|
|
|
|
|
class TestShouldContinueRiskAnalysis(TestConditionalLogic):
    """Tests for should_continue_risk_analysis method.

    These cases treat ``3 * max_risk_discuss_rounds`` as the ``count`` at
    which the discussion hands off to the Risk Judge. Below that limit,
    speakers rotate Aggressive -> Conservative -> Neutral -> Aggressive,
    keyed on the prefix of ``latest_speaker``.
    """

    @staticmethod
    def _state(count, latest_speaker):
        """Build a state holding only the risk_debate_state fields these tests set."""
        return {
            "risk_debate_state": {
                "count": count,
                "latest_speaker": latest_speaker,
            }
        }

    @pytest.mark.unit
    def test_returns_risk_judge_at_max_rounds(self, logic):
        """Test that risk analysis ends exactly at the limit (3 * max_risk_discuss_rounds = 3)."""
        state = self._state(3, "Aggressive Analyst")
        assert logic.should_continue_risk_analysis(state) == "Risk Judge"

    @pytest.mark.unit
    def test_returns_risk_judge_past_max_rounds(self, logic):
        """Test that risk analysis also ends when count overshoots the limit (6 > 3)."""
        state = self._state(6, "Aggressive Analyst")
        assert logic.should_continue_risk_analysis(state) == "Risk Judge"

    @pytest.mark.unit
    def test_returns_conservative_after_aggressive(self, logic):
        """Test that Aggressive speaker routes to Conservative."""
        state = self._state(1, "Aggressive Analyst: Go all in!")
        assert logic.should_continue_risk_analysis(state) == "Conservative Analyst"

    @pytest.mark.unit
    def test_returns_neutral_after_conservative(self, logic):
        """Test that Conservative speaker routes to Neutral."""
        state = self._state(1, "Conservative Analyst: Stay cautious")
        assert logic.should_continue_risk_analysis(state) == "Neutral Analyst"

    @pytest.mark.unit
    def test_returns_aggressive_after_neutral(self, logic):
        """Test that Neutral speaker routes to Aggressive."""
        state = self._state(1, "Neutral Analyst: Balanced view")
        assert logic.should_continue_risk_analysis(state) == "Aggressive Analyst"

    @pytest.mark.unit
    def test_extended_risk_rounds(self, logic_extended):
        """Test that risk analysis continues one step below the extended limit (3 * 2 = 6)."""
        state = self._state(5, "Aggressive Analyst")
        assert logic_extended.should_continue_risk_analysis(state) == "Conservative Analyst"

    @pytest.mark.unit
    def test_extended_risk_ends_at_max(self, logic_extended):
        """Test that extended risk analysis ends exactly at the limit (3 * 2 = 6)."""
        state = self._state(6, "Aggressive Analyst")
        assert logic_extended.should_continue_risk_analysis(state) == "Risk Judge"
|