# File viewer metadata: Python, 180 lines, 6.0 KiB
"""Unit tests for propagation module."""
from unittest.mock import Mock, patch
import pytest


class TestPropagator:
    """Test suite for the Propagator interface.

    NOTE(review): every test below builds a ``Mock`` stand-in instead of
    importing a real ``Propagator``. The assertions therefore pin the
    expected interface (method names, call arguments, state keys) rather
    than production behaviour -- consider wiring in the real class.
    """

    def test_propagator_initialization(self):
        """Check that the propagator exposes its two callable entry points."""
        propagator = Mock()
        propagator.create_initial_state = Mock()
        propagator.get_graph_args = Mock()

        assert hasattr(propagator, 'create_initial_state')
        assert hasattr(propagator, 'get_graph_args')
        assert callable(propagator.create_initial_state)
        assert callable(propagator.get_graph_args)

    def test_create_initial_state(self):
        """Check the shape of the initial state returned for a ticker/date."""
        # Full state layout the mocked method is configured to hand back.
        expected_state = {
            "company_of_interest": "AAPL",
            "trade_date": "2024-05-10",
            "messages": [],
            "market_report": "",
            "sentiment_report": "",
            "news_report": "",
            "fundamentals_report": "",
            "investment_debate_state": {
                "bull_history": [],
                "bear_history": [],
                "history": [],
                "current_response": "",
                "judge_decision": "",
            },
            "trader_investment_plan": "",
            "risk_debate_state": {
                "risky_history": [],
                "safe_history": [],
                "neutral_history": [],
                "history": [],
                "judge_decision": "",
            },
            "investment_plan": "",
            "final_trade_decision": "",
        }
        propagator = Mock()
        propagator.create_initial_state = Mock(return_value=expected_state)

        state = propagator.create_initial_state("AAPL", "2024-05-10")

        assert state["company_of_interest"] == "AAPL"
        assert state["trade_date"] == "2024-05-10"
        assert state["messages"] == []
        assert "investment_debate_state" in state
        assert "risk_debate_state" in state
        propagator.create_initial_state.assert_called_once_with("AAPL", "2024-05-10")

    def test_get_graph_args(self):
        """Check the default graph arguments."""
        expected_args = {
            "recursion_limit": 100,
            "config": {"tags": ["tradingagents"]},
        }
        propagator = Mock()
        propagator.get_graph_args = Mock(return_value=expected_args)

        args = propagator.get_graph_args()

        assert "recursion_limit" in args
        assert "config" in args
        assert args["recursion_limit"] == 100
        propagator.get_graph_args.assert_called_once()

    def test_propagate_with_different_tickers(self):
        """Check that the requested ticker is carried into the state."""
        propagator = Mock()
        tickers = ["AAPL", "GOOGL", "MSFT", "TSLA", "NVDA"]

        for ticker in tickers:
            state = {
                "company_of_interest": ticker,
                "trade_date": "2024-05-10",
                "messages": [],
            }
            # A fresh mock per ticker keeps the call history isolated.
            propagator.create_initial_state = Mock(return_value=state)

            result = propagator.create_initial_state(ticker, "2024-05-10")

            assert result["company_of_interest"] == ticker
            # Verify the arguments too, not only the canned return value.
            propagator.create_initial_state.assert_called_once_with(
                ticker, "2024-05-10"
            )

    def test_propagate_with_different_dates(self):
        """Check that the requested trade date is carried into the state."""
        propagator = Mock()
        dates = ["2024-01-01", "2024-06-15", "2024-12-31"]

        for trade_date in dates:
            state = {
                "company_of_interest": "AAPL",
                "trade_date": trade_date,
                "messages": [],
            }
            # A fresh mock per date keeps the call history isolated.
            propagator.create_initial_state = Mock(return_value=state)

            result = propagator.create_initial_state("AAPL", trade_date)

            assert result["trade_date"] == trade_date
            # Verify the arguments too, not only the canned return value.
            propagator.create_initial_state.assert_called_once_with(
                "AAPL", trade_date
            )

    def test_propagate_error_handling(self):
        """Check that a failure in state creation surfaces as ValueError."""
        propagator = Mock()
        # Simulate the error raised for an unknown ticker.
        propagator.create_initial_state = Mock(
            side_effect=ValueError("Invalid ticker")
        )

        # ``match`` pins the message so an unrelated ValueError cannot pass.
        with pytest.raises(ValueError, match="Invalid ticker"):
            propagator.create_initial_state("INVALID", "2024-05-10")

        propagator.create_initial_state.assert_called_once_with(
            "INVALID", "2024-05-10"
        )

    def test_graph_args_with_custom_config(self):
        """Check graph arguments when a custom configuration is supplied."""
        custom_config = {
            "recursion_limit": 200,
            "config": {
                "tags": ["custom", "test"],
                "metadata": {"version": "1.0"},
            },
        }
        propagator = Mock()
        propagator.get_graph_args = Mock(return_value=custom_config)

        args = propagator.get_graph_args()

        assert args["recursion_limit"] == 200
        assert "custom" in args["config"]["tags"]
        assert args["config"]["metadata"]["version"] == "1.0"
        propagator.get_graph_args.assert_called_once()

    def test_initial_state_completeness(self):
        """Check that the initial state contains every required field."""
        required_fields = [
            "company_of_interest",
            "trade_date",
            "messages",
            "market_report",
            "sentiment_report",
            "news_report",
            "fundamentals_report",
            "investment_debate_state",
            "trader_investment_plan",
            "risk_debate_state",
            "investment_plan",
            "final_trade_decision",
        ]

        # Start every field as an empty string, then give the container
        # fields their proper empty container type.
        state = {field: "" for field in required_fields}
        state["messages"] = []
        state["investment_debate_state"] = {}
        state["risk_debate_state"] = {}

        propagator = Mock()
        propagator.create_initial_state = Mock(return_value=state)

        result = propagator.create_initial_state("AAPL", "2024-05-10")

        for field in required_fields:
            assert field in result, f"Missing required field: {field}"
        propagator.create_initial_state.assert_called_once_with("AAPL", "2024-05-10")