feat(graph): integrate new analysts into workflow - Fixes #17

This commit is contained in:
Andrew Kaszubski 2025-12-26 20:11:28 +11:00
parent a17fc1f029
commit 5a0606b59f
8 changed files with 663 additions and 5 deletions

View File

@ -0,0 +1,506 @@
"""Tests for Issue #17: Analyst integration into graph/setup.py workflow.
This module tests the integration of new analysts (momentum, macro, correlation)
into the TradingAgents graph workflow.
"""
import pytest
from unittest.mock import MagicMock, patch
from typing import get_type_hints
import sys
# Check if langchain dependencies are available
try:
import langchain_core
LANGCHAIN_AVAILABLE = True
except ImportError:
LANGCHAIN_AVAILABLE = False
# Skip all tests if langchain not available
pytestmark = pytest.mark.skipif(
not LANGCHAIN_AVAILABLE,
reason="langchain_core not installed"
)
class TestAgentStateReports:
"""Test that AgentState has the new report fields."""
def test_agent_state_has_momentum_report(self):
"""AgentState should have momentum_report field."""
from tradingagents.agents.utils.agent_states import AgentState
hints = get_type_hints(AgentState)
assert "momentum_report" in hints
def test_agent_state_has_macro_report(self):
"""AgentState should have macro_report field."""
from tradingagents.agents.utils.agent_states import AgentState
hints = get_type_hints(AgentState)
assert "macro_report" in hints
def test_agent_state_has_correlation_report(self):
"""AgentState should have correlation_report field."""
from tradingagents.agents.utils.agent_states import AgentState
hints = get_type_hints(AgentState)
assert "correlation_report" in hints
class TestConditionalLogicMethods:
"""Test that ConditionalLogic has methods for new analysts."""
def test_should_continue_momentum_exists(self):
"""ConditionalLogic should have should_continue_momentum method."""
from tradingagents.graph.conditional_logic import ConditionalLogic
cl = ConditionalLogic()
assert hasattr(cl, "should_continue_momentum")
assert callable(cl.should_continue_momentum)
def test_should_continue_macro_exists(self):
"""ConditionalLogic should have should_continue_macro method."""
from tradingagents.graph.conditional_logic import ConditionalLogic
cl = ConditionalLogic()
assert hasattr(cl, "should_continue_macro")
assert callable(cl.should_continue_macro)
def test_should_continue_correlation_exists(self):
"""ConditionalLogic should have should_continue_correlation method."""
from tradingagents.graph.conditional_logic import ConditionalLogic
cl = ConditionalLogic()
assert hasattr(cl, "should_continue_correlation")
assert callable(cl.should_continue_correlation)
def test_momentum_conditional_returns_tools(self):
"""should_continue_momentum should return 'tools_momentum' when tool_calls exist."""
from tradingagents.graph.conditional_logic import ConditionalLogic
cl = ConditionalLogic()
mock_message = MagicMock()
mock_message.tool_calls = [{"name": "test"}]
state = {"messages": [mock_message]}
result = cl.should_continue_momentum(state)
assert result == "tools_momentum"
def test_momentum_conditional_returns_clear(self):
"""should_continue_momentum should return 'Msg Clear Momentum' when no tool_calls."""
from tradingagents.graph.conditional_logic import ConditionalLogic
cl = ConditionalLogic()
mock_message = MagicMock()
mock_message.tool_calls = []
state = {"messages": [mock_message]}
result = cl.should_continue_momentum(state)
assert result == "Msg Clear Momentum"
def test_macro_conditional_returns_tools(self):
"""should_continue_macro should return 'tools_macro' when tool_calls exist."""
from tradingagents.graph.conditional_logic import ConditionalLogic
cl = ConditionalLogic()
mock_message = MagicMock()
mock_message.tool_calls = [{"name": "test"}]
state = {"messages": [mock_message]}
result = cl.should_continue_macro(state)
assert result == "tools_macro"
def test_macro_conditional_returns_clear(self):
"""should_continue_macro should return 'Msg Clear Macro' when no tool_calls."""
from tradingagents.graph.conditional_logic import ConditionalLogic
cl = ConditionalLogic()
mock_message = MagicMock()
mock_message.tool_calls = []
state = {"messages": [mock_message]}
result = cl.should_continue_macro(state)
assert result == "Msg Clear Macro"
def test_correlation_conditional_returns_tools(self):
"""should_continue_correlation should return 'tools_correlation' when tool_calls exist."""
from tradingagents.graph.conditional_logic import ConditionalLogic
cl = ConditionalLogic()
mock_message = MagicMock()
mock_message.tool_calls = [{"name": "test"}]
state = {"messages": [mock_message]}
result = cl.should_continue_correlation(state)
assert result == "tools_correlation"
def test_correlation_conditional_returns_clear(self):
"""should_continue_correlation should return 'Msg Clear Correlation' when no tool_calls."""
from tradingagents.graph.conditional_logic import ConditionalLogic
cl = ConditionalLogic()
mock_message = MagicMock()
mock_message.tool_calls = []
state = {"messages": [mock_message]}
result = cl.should_continue_correlation(state)
assert result == "Msg Clear Correlation"
class TestAgentImports:
"""Test that new analysts are properly exported from agents module."""
def test_create_momentum_analyst_import(self):
"""create_momentum_analyst should be importable from agents."""
from tradingagents.agents import create_momentum_analyst
assert callable(create_momentum_analyst)
def test_create_macro_analyst_import(self):
"""create_macro_analyst should be importable from agents."""
from tradingagents.agents import create_macro_analyst
assert callable(create_macro_analyst)
def test_create_correlation_analyst_import(self):
"""create_correlation_analyst should be importable from agents."""
from tradingagents.agents import create_correlation_analyst
assert callable(create_correlation_analyst)
def test_create_position_sizing_manager_import(self):
"""create_position_sizing_manager should be importable from agents."""
from tradingagents.agents import create_position_sizing_manager
assert callable(create_position_sizing_manager)
class TestAnalystsModuleExports:
"""Test analysts submodule exports."""
def test_analysts_module_exports_momentum(self):
"""analysts module should export create_momentum_analyst."""
from tradingagents.agents.analysts import create_momentum_analyst
assert callable(create_momentum_analyst)
def test_analysts_module_exports_macro(self):
"""analysts module should export create_macro_analyst."""
from tradingagents.agents.analysts import create_macro_analyst
assert callable(create_macro_analyst)
def test_analysts_module_exports_correlation(self):
"""analysts module should export create_correlation_analyst."""
from tradingagents.agents.analysts import create_correlation_analyst
assert callable(create_correlation_analyst)
def test_analysts_module_all_exports(self):
"""analysts __all__ should include all analyst creators."""
from tradingagents.agents import analysts
assert "create_momentum_analyst" in analysts.__all__
assert "create_macro_analyst" in analysts.__all__
assert "create_correlation_analyst" in analysts.__all__
class TestManagersModuleExports:
"""Test managers submodule exports."""
def test_managers_module_exports_position_sizing(self):
"""managers module should export create_position_sizing_manager."""
from tradingagents.agents.managers import create_position_sizing_manager
assert callable(create_position_sizing_manager)
def test_managers_module_all_exports(self):
"""managers __all__ should include position_sizing_manager."""
from tradingagents.agents import managers
assert "create_position_sizing_manager" in managers.__all__
class TestToolImports:
"""Test that tool functions are importable from analyst modules."""
def test_momentum_tools_importable(self):
"""Momentum analyst tools should be importable."""
from tradingagents.agents.analysts.momentum_analyst import (
get_multi_timeframe_momentum,
get_adx_analysis,
get_momentum_divergence,
)
assert callable(get_multi_timeframe_momentum)
assert callable(get_adx_analysis)
assert callable(get_momentum_divergence)
def test_macro_tools_importable(self):
"""Macro analyst tools should be importable."""
from tradingagents.agents.analysts.macro_analyst import (
get_economic_regime_analysis,
get_yield_curve_analysis,
get_monetary_policy_analysis,
get_inflation_regime_analysis,
)
assert callable(get_economic_regime_analysis)
assert callable(get_yield_curve_analysis)
assert callable(get_monetary_policy_analysis)
assert callable(get_inflation_regime_analysis)
def test_correlation_tools_importable(self):
"""Correlation analyst tools should be importable."""
from tradingagents.agents.analysts.correlation_analyst import (
get_cross_asset_correlation_analysis,
get_sector_rotation_analysis,
get_correlation_matrix,
get_rolling_correlation_trend,
)
assert callable(get_cross_asset_correlation_analysis)
assert callable(get_sector_rotation_analysis)
assert callable(get_correlation_matrix)
assert callable(get_rolling_correlation_trend)
class TestTradingGraphToolNodes:
"""Test that trading_graph.py has tool nodes for new analysts."""
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.set_config")
@patch("os.makedirs")
def test_tool_nodes_include_momentum(self, mock_makedirs, mock_set_config, mock_llm):
"""_create_tool_nodes should include momentum tools."""
from tradingagents.graph.trading_graph import TradingAgentsGraph
mock_llm.return_value = MagicMock()
with patch.object(TradingAgentsGraph, "__init__", lambda self: None):
graph = TradingAgentsGraph.__new__(TradingAgentsGraph)
tool_nodes = graph._create_tool_nodes()
assert "momentum" in tool_nodes
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.set_config")
@patch("os.makedirs")
def test_tool_nodes_include_macro(self, mock_makedirs, mock_set_config, mock_llm):
"""_create_tool_nodes should include macro tools."""
from tradingagents.graph.trading_graph import TradingAgentsGraph
mock_llm.return_value = MagicMock()
with patch.object(TradingAgentsGraph, "__init__", lambda self: None):
graph = TradingAgentsGraph.__new__(TradingAgentsGraph)
tool_nodes = graph._create_tool_nodes()
assert "macro" in tool_nodes
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.set_config")
@patch("os.makedirs")
def test_tool_nodes_include_correlation(self, mock_makedirs, mock_set_config, mock_llm):
"""_create_tool_nodes should include correlation tools."""
from tradingagents.graph.trading_graph import TradingAgentsGraph
mock_llm.return_value = MagicMock()
with patch.object(TradingAgentsGraph, "__init__", lambda self: None):
graph = TradingAgentsGraph.__new__(TradingAgentsGraph)
tool_nodes = graph._create_tool_nodes()
assert "correlation" in tool_nodes
class TestGraphSetupDocstring:
"""Test that setup_graph docstring documents new analysts."""
def test_docstring_mentions_momentum(self):
"""setup_graph docstring should mention momentum analyst."""
from tradingagents.graph.setup import GraphSetup
docstring = GraphSetup.setup_graph.__doc__
assert "momentum" in docstring.lower()
def test_docstring_mentions_macro(self):
"""setup_graph docstring should mention macro analyst."""
from tradingagents.graph.setup import GraphSetup
docstring = GraphSetup.setup_graph.__doc__
assert "macro" in docstring.lower()
def test_docstring_mentions_correlation(self):
"""setup_graph docstring should mention correlation analyst."""
from tradingagents.graph.setup import GraphSetup
docstring = GraphSetup.setup_graph.__doc__
assert "correlation" in docstring.lower()
class TestLogStateNewReports:
"""Test that _log_state includes new report fields."""
def test_log_state_includes_momentum_report(self):
"""_log_state should log momentum_report if present."""
from tradingagents.graph.trading_graph import TradingAgentsGraph
with patch.object(TradingAgentsGraph, "__init__", lambda self: None):
graph = TradingAgentsGraph.__new__(TradingAgentsGraph)
graph.log_states_dict = {}
graph.ticker = "TEST"
final_state = {
"company_of_interest": "TEST",
"trade_date": "2024-01-01",
"market_report": "",
"sentiment_report": "",
"news_report": "",
"fundamentals_report": "",
"momentum_report": "Momentum analysis result",
"macro_report": "",
"correlation_report": "",
"investment_debate_state": {
"bull_history": "",
"bear_history": "",
"history": "",
"current_response": "",
"judge_decision": "",
},
"trader_investment_plan": "",
"risk_debate_state": {
"risky_history": "",
"safe_history": "",
"neutral_history": "",
"history": "",
"judge_decision": "",
},
"investment_plan": "",
"final_trade_decision": "",
}
with patch("builtins.open", MagicMock()):
with patch("json.dump"):
with patch("pathlib.Path.mkdir"):
graph._log_state("2024-01-01", final_state)
logged = graph.log_states_dict["2024-01-01"]
assert "momentum_report" in logged
assert logged["momentum_report"] == "Momentum analysis result"
def test_log_state_includes_macro_report(self):
"""_log_state should log macro_report if present."""
from tradingagents.graph.trading_graph import TradingAgentsGraph
with patch.object(TradingAgentsGraph, "__init__", lambda self: None):
graph = TradingAgentsGraph.__new__(TradingAgentsGraph)
graph.log_states_dict = {}
graph.ticker = "TEST"
final_state = {
"company_of_interest": "TEST",
"trade_date": "2024-01-01",
"market_report": "",
"sentiment_report": "",
"news_report": "",
"fundamentals_report": "",
"momentum_report": "",
"macro_report": "Macro analysis result",
"correlation_report": "",
"investment_debate_state": {
"bull_history": "",
"bear_history": "",
"history": "",
"current_response": "",
"judge_decision": "",
},
"trader_investment_plan": "",
"risk_debate_state": {
"risky_history": "",
"safe_history": "",
"neutral_history": "",
"history": "",
"judge_decision": "",
},
"investment_plan": "",
"final_trade_decision": "",
}
with patch("builtins.open", MagicMock()):
with patch("json.dump"):
with patch("pathlib.Path.mkdir"):
graph._log_state("2024-01-01", final_state)
logged = graph.log_states_dict["2024-01-01"]
assert "macro_report" in logged
assert logged["macro_report"] == "Macro analysis result"
def test_log_state_includes_correlation_report(self):
"""_log_state should log correlation_report if present."""
from tradingagents.graph.trading_graph import TradingAgentsGraph
with patch.object(TradingAgentsGraph, "__init__", lambda self: None):
graph = TradingAgentsGraph.__new__(TradingAgentsGraph)
graph.log_states_dict = {}
graph.ticker = "TEST"
final_state = {
"company_of_interest": "TEST",
"trade_date": "2024-01-01",
"market_report": "",
"sentiment_report": "",
"news_report": "",
"fundamentals_report": "",
"momentum_report": "",
"macro_report": "",
"correlation_report": "Correlation analysis result",
"investment_debate_state": {
"bull_history": "",
"bear_history": "",
"history": "",
"current_response": "",
"judge_decision": "",
},
"trader_investment_plan": "",
"risk_debate_state": {
"risky_history": "",
"safe_history": "",
"neutral_history": "",
"history": "",
"judge_decision": "",
},
"investment_plan": "",
"final_trade_decision": "",
}
with patch("builtins.open", MagicMock()):
with patch("json.dump"):
with patch("pathlib.Path.mkdir"):
graph._log_state("2024-01-01", final_state)
logged = graph.log_states_dict["2024-01-01"]
assert "correlation_report" in logged
assert logged["correlation_report"] == "Correlation analysis result"
def test_log_state_handles_missing_new_reports(self):
"""_log_state should handle missing new report fields gracefully."""
from tradingagents.graph.trading_graph import TradingAgentsGraph
with patch.object(TradingAgentsGraph, "__init__", lambda self: None):
graph = TradingAgentsGraph.__new__(TradingAgentsGraph)
graph.log_states_dict = {}
graph.ticker = "TEST"
# State without new report fields (backward compatibility)
final_state = {
"company_of_interest": "TEST",
"trade_date": "2024-01-01",
"investment_debate_state": {
"bull_history": "",
"bear_history": "",
"history": "",
"current_response": "",
"judge_decision": "",
},
"trader_investment_plan": "",
"risk_debate_state": {
"risky_history": "",
"safe_history": "",
"neutral_history": "",
"history": "",
"judge_decision": "",
},
"investment_plan": "",
"final_trade_decision": "",
}
with patch("builtins.open", MagicMock()):
with patch("json.dump"):
with patch("pathlib.Path.mkdir"):
graph._log_state("2024-01-01", final_state)
logged = graph.log_states_dict["2024-01-01"]
# Should default to empty string
assert logged["momentum_report"] == ""
assert logged["macro_report"] == ""
assert logged["correlation_report"] == ""

View File

@ -6,6 +6,9 @@ from .analysts.fundamentals_analyst import create_fundamentals_analyst
from .analysts.market_analyst import create_market_analyst
from .analysts.news_analyst import create_news_analyst
from .analysts.social_media_analyst import create_social_media_analyst
from .analysts.momentum_analyst import create_momentum_analyst
from .analysts.macro_analyst import create_macro_analyst
from .analysts.correlation_analyst import create_correlation_analyst
from .researchers.bear_researcher import create_bear_researcher
from .researchers.bull_researcher import create_bull_researcher
@ -16,6 +19,7 @@ from .risk_mgmt.neutral_debator import create_neutral_debator
from .managers.research_manager import create_research_manager
from .managers.risk_manager import create_risk_manager
from .managers.position_sizing_manager import create_position_sizing_manager
from .trader.trader import create_trader
@ -30,10 +34,14 @@ __all__ = [
"create_research_manager",
"create_fundamentals_analyst",
"create_market_analyst",
"create_momentum_analyst",
"create_macro_analyst",
"create_correlation_analyst",
"create_neutral_debator",
"create_news_analyst",
"create_risky_debator",
"create_risk_manager",
"create_position_sizing_manager",
"create_safe_debator",
"create_social_media_analyst",
"create_trader",

View File

@ -0,0 +1,29 @@
"""Analyst agents for market analysis.
This module provides specialized analyst agents for different types of analysis:
- Fundamentals Analyst: Company financial analysis
- Market Analyst: Technical and market structure analysis
- News Analyst: News sentiment and event analysis
- Social Media Analyst: Social sentiment analysis
- Momentum Analyst: Multi-timeframe momentum analysis (Issue #13)
- Macro Analyst: Macroeconomic and FRED data analysis (Issue #14)
- Correlation Analyst: Cross-asset correlation and sector rotation (Issue #15)
"""
from .fundamentals_analyst import create_fundamentals_analyst
from .market_analyst import create_market_analyst
from .news_analyst import create_news_analyst
from .social_media_analyst import create_social_media_analyst
from .momentum_analyst import create_momentum_analyst
from .macro_analyst import create_macro_analyst
from .correlation_analyst import create_correlation_analyst
__all__ = [
"create_fundamentals_analyst",
"create_market_analyst",
"create_news_analyst",
"create_social_media_analyst",
"create_momentum_analyst",
"create_macro_analyst",
"create_correlation_analyst",
]

View File

@ -1 +1,17 @@
"""Manager agents for portfolio and risk management."""
"""Manager agents for portfolio and risk management.
This module provides manager agents for orchestrating research and risk:
- Research Manager: Coordinates research activities
- Risk Manager: Manages risk assessment
- Position Sizing Manager: Optimal position sizing (Issue #16)
"""
from .research_manager import create_research_manager
from .risk_manager import create_risk_manager
from .position_sizing_manager import create_position_sizing_manager
__all__ = [
"create_research_manager",
"create_risk_manager",
"create_position_sizing_manager",
]

View File

@ -60,6 +60,9 @@ class AgentState(MessagesState):
str, "Report from the News Researcher of current world affairs"
]
fundamentals_report: Annotated[str, "Report from the Fundamentals Researcher"]
momentum_report: Annotated[str, "Report from the Momentum Analyst"]
macro_report: Annotated[str, "Report from the Macro Analyst"]
correlation_report: Annotated[str, "Report from the Correlation Analyst"]
# researcher team discussion step
investment_debate_state: Annotated[

View File

@ -43,6 +43,30 @@ class ConditionalLogic:
return "tools_fundamentals"
return "Msg Clear Fundamentals"
def should_continue_momentum(self, state: AgentState):
"""Determine if momentum analysis should continue."""
messages = state["messages"]
last_message = messages[-1]
if last_message.tool_calls:
return "tools_momentum"
return "Msg Clear Momentum"
def should_continue_macro(self, state: AgentState):
"""Determine if macro analysis should continue."""
messages = state["messages"]
last_message = messages[-1]
if last_message.tool_calls:
return "tools_macro"
return "Msg Clear Macro"
def should_continue_correlation(self, state: AgentState):
"""Determine if correlation analysis should continue."""
messages = state["messages"]
last_message = messages[-1]
if last_message.tool_calls:
return "tools_correlation"
return "Msg Clear Correlation"
def should_continue_debate(self, state: AgentState) -> str:
"""Determine if debate should continue."""

View File

@ -48,6 +48,9 @@ class GraphSetup:
- "social": Social media analyst
- "news": News analyst
- "fundamentals": Fundamentals analyst
- "momentum": Momentum analyst (multi-TF momentum, ROC, ADX)
- "macro": Macro analyst (FRED data, economic regimes)
- "correlation": Correlation analyst (cross-asset, sector rotation)
"""
if len(selected_analysts) == 0:
raise ValueError("Trading Agents Graph Setup Error: no analysts selected!")
@ -85,6 +88,27 @@ class GraphSetup:
delete_nodes["fundamentals"] = create_msg_delete()
tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"]
if "momentum" in selected_analysts:
analyst_nodes["momentum"] = create_momentum_analyst(
self.quick_thinking_llm
)
delete_nodes["momentum"] = create_msg_delete()
tool_nodes["momentum"] = self.tool_nodes["momentum"]
if "macro" in selected_analysts:
analyst_nodes["macro"] = create_macro_analyst(
self.quick_thinking_llm
)
delete_nodes["macro"] = create_msg_delete()
tool_nodes["macro"] = self.tool_nodes["macro"]
if "correlation" in selected_analysts:
analyst_nodes["correlation"] = create_correlation_analyst(
self.quick_thinking_llm
)
delete_nodes["correlation"] = create_msg_delete()
tool_nodes["correlation"] = self.tool_nodes["correlation"]
# Create researcher and manager nodes
bull_researcher_node = create_bull_researcher(
self.quick_thinking_llm, self.bull_memory

View File

@ -36,6 +36,25 @@ from tradingagents.agents.utils.agent_utils import (
get_global_news
)
# Import tools from new analysts
from tradingagents.agents.analysts.momentum_analyst import (
get_multi_timeframe_momentum,
get_adx_analysis,
get_momentum_divergence,
)
from tradingagents.agents.analysts.macro_analyst import (
get_economic_regime_analysis,
get_yield_curve_analysis,
get_monetary_policy_analysis,
get_inflation_regime_analysis,
)
from tradingagents.agents.analysts.correlation_analyst import (
get_cross_asset_correlation_analysis,
get_sector_rotation_analysis,
get_correlation_matrix,
get_rolling_correlation_trend,
)
from .conditional_logic import ConditionalLogic
from .setup import GraphSetup
from .propagation import Propagator
@ -223,6 +242,32 @@ class TradingAgentsGraph:
get_income_statement,
]
),
"momentum": ToolNode(
[
# Momentum analysis tools
get_multi_timeframe_momentum,
get_adx_analysis,
get_momentum_divergence,
]
),
"macro": ToolNode(
[
# Macroeconomic analysis tools
get_economic_regime_analysis,
get_yield_curve_analysis,
get_monetary_policy_analysis,
get_inflation_regime_analysis,
]
),
"correlation": ToolNode(
[
# Correlation analysis tools
get_cross_asset_correlation_analysis,
get_sector_rotation_analysis,
get_correlation_matrix,
get_rolling_correlation_trend,
]
),
}
def propagate(self, company_name, trade_date):
@ -265,10 +310,13 @@ class TradingAgentsGraph:
self.log_states_dict[str(trade_date)] = {
"company_of_interest": final_state["company_of_interest"],
"trade_date": final_state["trade_date"],
"market_report": final_state["market_report"],
"sentiment_report": final_state["sentiment_report"],
"news_report": final_state["news_report"],
"fundamentals_report": final_state["fundamentals_report"],
"market_report": final_state.get("market_report", ""),
"sentiment_report": final_state.get("sentiment_report", ""),
"news_report": final_state.get("news_report", ""),
"fundamentals_report": final_state.get("fundamentals_report", ""),
"momentum_report": final_state.get("momentum_report", ""),
"macro_report": final_state.get("macro_report", ""),
"correlation_report": final_state.get("correlation_report", ""),
"investment_debate_state": {
"bull_history": final_state["investment_debate_state"]["bull_history"],
"bear_history": final_state["investment_debate_state"]["bear_history"],