From 5a0606b59f2811feb82f970f4136fa5077151d52 Mon Sep 17 00:00:00 2001 From: Andrew Kaszubski Date: Fri, 26 Dec 2025 20:11:28 +1100 Subject: [PATCH] feat(graph): integrate new analysts into workflow - Fixes #17 --- tests/unit/graph/test_analyst_integration.py | 506 +++++++++++++++++++ tradingagents/agents/__init__.py | 8 + tradingagents/agents/analysts/__init__.py | 29 ++ tradingagents/agents/managers/__init__.py | 18 +- tradingagents/agents/utils/agent_states.py | 3 + tradingagents/graph/conditional_logic.py | 24 + tradingagents/graph/setup.py | 24 + tradingagents/graph/trading_graph.py | 56 +- 8 files changed, 663 insertions(+), 5 deletions(-) create mode 100644 tests/unit/graph/test_analyst_integration.py create mode 100644 tradingagents/agents/analysts/__init__.py diff --git a/tests/unit/graph/test_analyst_integration.py b/tests/unit/graph/test_analyst_integration.py new file mode 100644 index 00000000..302558e4 --- /dev/null +++ b/tests/unit/graph/test_analyst_integration.py @@ -0,0 +1,506 @@ +"""Tests for Issue #17: Analyst integration into graph/setup.py workflow. + +This module tests the integration of new analysts (momentum, macro, correlation) +into the TradingAgents graph workflow. +""" + +import pytest +from unittest.mock import MagicMock, patch +from typing import get_type_hints +import sys + +# Check if langchain dependencies are available +try: + import langchain_core + LANGCHAIN_AVAILABLE = True +except ImportError: + LANGCHAIN_AVAILABLE = False + +# Skip all tests if langchain not available +pytestmark = pytest.mark.skipif( + not LANGCHAIN_AVAILABLE, + reason="langchain_core not installed" +) + + +class TestAgentStateReports: + """Test that AgentState has the new report fields.""" + + def test_agent_state_has_momentum_report(self): + """AgentState should have momentum_report field.""" + from tradingagents.agents.utils.agent_states import AgentState + hints = get_type_hints(AgentState) + assert "momentum_report" in hints + + def test_agent_state_has_macro_report(self): + """AgentState should have macro_report field.""" + from tradingagents.agents.utils.agent_states import AgentState + hints = get_type_hints(AgentState) + assert "macro_report" in hints + + def test_agent_state_has_correlation_report(self): + """AgentState should have correlation_report field.""" + from tradingagents.agents.utils.agent_states import AgentState + hints = get_type_hints(AgentState) + assert "correlation_report" in hints + + +class TestConditionalLogicMethods: + """Test that ConditionalLogic has methods for new analysts.""" + + def test_should_continue_momentum_exists(self): + """ConditionalLogic should have should_continue_momentum method.""" + from tradingagents.graph.conditional_logic import ConditionalLogic + cl = ConditionalLogic() + assert hasattr(cl, "should_continue_momentum") + assert callable(cl.should_continue_momentum) + + def test_should_continue_macro_exists(self): + """ConditionalLogic should have should_continue_macro method.""" + from tradingagents.graph.conditional_logic import ConditionalLogic + cl = ConditionalLogic() + assert hasattr(cl, "should_continue_macro") + assert callable(cl.should_continue_macro) + + def test_should_continue_correlation_exists(self): + """ConditionalLogic should have should_continue_correlation method.""" + from tradingagents.graph.conditional_logic import ConditionalLogic + cl = ConditionalLogic() + assert hasattr(cl, "should_continue_correlation") + assert callable(cl.should_continue_correlation) + + def test_momentum_conditional_returns_tools(self): + """should_continue_momentum should return 'tools_momentum' when tool_calls exist.""" + from tradingagents.graph.conditional_logic import ConditionalLogic + cl = ConditionalLogic() + + mock_message = MagicMock() + mock_message.tool_calls = [{"name": "test"}] + state = {"messages": [mock_message]} + + result = cl.should_continue_momentum(state) + assert result == "tools_momentum" + + def test_momentum_conditional_returns_clear(self): + """should_continue_momentum should return 'Msg Clear Momentum' when no tool_calls.""" + from tradingagents.graph.conditional_logic import ConditionalLogic + cl = ConditionalLogic() + + mock_message = MagicMock() + mock_message.tool_calls = [] + state = {"messages": [mock_message]} + + result = cl.should_continue_momentum(state) + assert result == "Msg Clear Momentum" + + def test_macro_conditional_returns_tools(self): + """should_continue_macro should return 'tools_macro' when tool_calls exist.""" + from tradingagents.graph.conditional_logic import ConditionalLogic + cl = ConditionalLogic() + + mock_message = MagicMock() + mock_message.tool_calls = [{"name": "test"}] + state = {"messages": [mock_message]} + + result = cl.should_continue_macro(state) + assert result == "tools_macro" + + def test_macro_conditional_returns_clear(self): + """should_continue_macro should return 'Msg Clear Macro' when no tool_calls.""" + from tradingagents.graph.conditional_logic import ConditionalLogic + cl = ConditionalLogic() + + mock_message = MagicMock() + mock_message.tool_calls = [] + state = {"messages": [mock_message]} + + result = cl.should_continue_macro(state) + assert result == "Msg Clear Macro" + + def test_correlation_conditional_returns_tools(self): + """should_continue_correlation should return 'tools_correlation' when tool_calls exist.""" + from tradingagents.graph.conditional_logic import ConditionalLogic + cl = ConditionalLogic() + + mock_message = MagicMock() + mock_message.tool_calls = [{"name": "test"}] + state = {"messages": [mock_message]} + + result = cl.should_continue_correlation(state) + assert result == "tools_correlation" + + def test_correlation_conditional_returns_clear(self): + """should_continue_correlation should return 'Msg Clear Correlation' when no tool_calls.""" + from tradingagents.graph.conditional_logic import ConditionalLogic + cl = ConditionalLogic() + + mock_message = MagicMock() + mock_message.tool_calls = [] + state = {"messages": [mock_message]} + + result = cl.should_continue_correlation(state) + assert result == "Msg Clear Correlation" + + +class TestAgentImports: + """Test that new analysts are properly exported from agents module.""" + + def test_create_momentum_analyst_import(self): + """create_momentum_analyst should be importable from agents.""" + from tradingagents.agents import create_momentum_analyst + assert callable(create_momentum_analyst) + + def test_create_macro_analyst_import(self): + """create_macro_analyst should be importable from agents.""" + from tradingagents.agents import create_macro_analyst + assert callable(create_macro_analyst) + + def test_create_correlation_analyst_import(self): + """create_correlation_analyst should be importable from agents.""" + from tradingagents.agents import create_correlation_analyst + assert callable(create_correlation_analyst) + + def test_create_position_sizing_manager_import(self): + """create_position_sizing_manager should be importable from agents.""" + from tradingagents.agents import create_position_sizing_manager + assert callable(create_position_sizing_manager) + + +class TestAnalystsModuleExports: + """Test analysts submodule exports.""" + + def test_analysts_module_exports_momentum(self): + """analysts module should export create_momentum_analyst.""" + from tradingagents.agents.analysts import create_momentum_analyst + assert callable(create_momentum_analyst) + + def test_analysts_module_exports_macro(self): + """analysts module should export create_macro_analyst.""" + from tradingagents.agents.analysts import create_macro_analyst + assert callable(create_macro_analyst) + + def test_analysts_module_exports_correlation(self): + """analysts module should export create_correlation_analyst.""" + from tradingagents.agents.analysts import create_correlation_analyst + assert callable(create_correlation_analyst) + + def test_analysts_module_all_exports(self): + """analysts __all__ should include all analyst creators.""" + from tradingagents.agents import analysts + assert "create_momentum_analyst" in analysts.__all__ + assert "create_macro_analyst" in analysts.__all__ + assert "create_correlation_analyst" in analysts.__all__ + + +class TestManagersModuleExports: + """Test managers submodule exports.""" + + def test_managers_module_exports_position_sizing(self): + """managers module should export create_position_sizing_manager.""" + from tradingagents.agents.managers import create_position_sizing_manager + assert callable(create_position_sizing_manager) + + def test_managers_module_all_exports(self): + """managers __all__ should include position_sizing_manager.""" + from tradingagents.agents import managers + assert "create_position_sizing_manager" in managers.__all__ + + +class TestToolImports: + """Test that tool functions are importable from analyst modules.""" + + def test_momentum_tools_importable(self): + """Momentum analyst tools should be importable.""" + from tradingagents.agents.analysts.momentum_analyst import ( + get_multi_timeframe_momentum, + get_adx_analysis, + get_momentum_divergence, + ) + assert callable(get_multi_timeframe_momentum) + assert callable(get_adx_analysis) + assert callable(get_momentum_divergence) + + def test_macro_tools_importable(self): + """Macro analyst tools should be importable.""" + from tradingagents.agents.analysts.macro_analyst import ( + get_economic_regime_analysis, + get_yield_curve_analysis, + get_monetary_policy_analysis, + get_inflation_regime_analysis, + ) + assert callable(get_economic_regime_analysis) + assert callable(get_yield_curve_analysis) + assert callable(get_monetary_policy_analysis) + assert callable(get_inflation_regime_analysis) + + def test_correlation_tools_importable(self): + """Correlation analyst tools should be importable.""" + from tradingagents.agents.analysts.correlation_analyst import ( + get_cross_asset_correlation_analysis, + get_sector_rotation_analysis, + get_correlation_matrix, + get_rolling_correlation_trend, + ) + assert callable(get_cross_asset_correlation_analysis) + assert callable(get_sector_rotation_analysis) + assert callable(get_correlation_matrix) + assert callable(get_rolling_correlation_trend) + + +class TestTradingGraphToolNodes: + """Test that trading_graph.py has tool nodes for new analysts.""" + + @patch("tradingagents.graph.trading_graph.ChatOpenAI") + @patch("tradingagents.graph.trading_graph.set_config") + @patch("os.makedirs") + def test_tool_nodes_include_momentum(self, mock_makedirs, mock_set_config, mock_llm): + """_create_tool_nodes should include momentum tools.""" + from tradingagents.graph.trading_graph import TradingAgentsGraph + + mock_llm.return_value = MagicMock() + + with patch.object(TradingAgentsGraph, "__init__", lambda self: None): + graph = TradingAgentsGraph.__new__(TradingAgentsGraph) + tool_nodes = graph._create_tool_nodes() + + assert "momentum" in tool_nodes + + @patch("tradingagents.graph.trading_graph.ChatOpenAI") + @patch("tradingagents.graph.trading_graph.set_config") + @patch("os.makedirs") + def test_tool_nodes_include_macro(self, mock_makedirs, mock_set_config, mock_llm): + """_create_tool_nodes should include macro tools.""" + from tradingagents.graph.trading_graph import TradingAgentsGraph + + mock_llm.return_value = MagicMock() + + with patch.object(TradingAgentsGraph, "__init__", lambda self: None): + graph = TradingAgentsGraph.__new__(TradingAgentsGraph) + tool_nodes = graph._create_tool_nodes() + + assert "macro" in tool_nodes + + @patch("tradingagents.graph.trading_graph.ChatOpenAI") + @patch("tradingagents.graph.trading_graph.set_config") + @patch("os.makedirs") + def test_tool_nodes_include_correlation(self, mock_makedirs, mock_set_config, mock_llm): + """_create_tool_nodes should include correlation tools.""" + from tradingagents.graph.trading_graph import TradingAgentsGraph + + mock_llm.return_value = MagicMock() + + with patch.object(TradingAgentsGraph, "__init__", lambda self: None): + graph = TradingAgentsGraph.__new__(TradingAgentsGraph) + tool_nodes = graph._create_tool_nodes() + + assert "correlation" in tool_nodes + + +class TestGraphSetupDocstring: + """Test that setup_graph docstring documents new analysts.""" + + def test_docstring_mentions_momentum(self): + """setup_graph docstring should mention momentum analyst.""" + from tradingagents.graph.setup import GraphSetup + docstring = GraphSetup.setup_graph.__doc__ + assert "momentum" in docstring.lower() + + def test_docstring_mentions_macro(self): + """setup_graph docstring should mention macro analyst.""" + from tradingagents.graph.setup import GraphSetup + docstring = GraphSetup.setup_graph.__doc__ + assert "macro" in docstring.lower() + + def test_docstring_mentions_correlation(self): + """setup_graph docstring should mention correlation analyst.""" + from tradingagents.graph.setup import GraphSetup + docstring = GraphSetup.setup_graph.__doc__ + assert "correlation" in docstring.lower() + + +class TestLogStateNewReports: + """Test that _log_state includes new report fields.""" + + def test_log_state_includes_momentum_report(self): + """_log_state should log momentum_report if present.""" + from tradingagents.graph.trading_graph import TradingAgentsGraph + + with patch.object(TradingAgentsGraph, "__init__", lambda self: None): + graph = TradingAgentsGraph.__new__(TradingAgentsGraph) + graph.log_states_dict = {} + graph.ticker = "TEST" + + final_state = { + "company_of_interest": "TEST", + "trade_date": "2024-01-01", + "market_report": "", + "sentiment_report": "", + "news_report": "", + "fundamentals_report": "", + "momentum_report": "Momentum analysis result", + "macro_report": "", + "correlation_report": "", + "investment_debate_state": { + "bull_history": "", + "bear_history": "", + "history": "", + "current_response": "", + "judge_decision": "", + }, + "trader_investment_plan": "", + "risk_debate_state": { + "risky_history": "", + "safe_history": "", + "neutral_history": "", + "history": "", + "judge_decision": "", + }, + "investment_plan": "", + "final_trade_decision": "", + } + + with patch("builtins.open", MagicMock()): + with patch("json.dump"): + with patch("pathlib.Path.mkdir"): + graph._log_state("2024-01-01", final_state) + + logged = graph.log_states_dict["2024-01-01"] + assert "momentum_report" in logged + assert logged["momentum_report"] == "Momentum analysis result" + + def test_log_state_includes_macro_report(self): + """_log_state should log macro_report if present.""" + from tradingagents.graph.trading_graph import TradingAgentsGraph + + with patch.object(TradingAgentsGraph, "__init__", lambda self: None): + graph = TradingAgentsGraph.__new__(TradingAgentsGraph) + graph.log_states_dict = {} + graph.ticker = "TEST" + + final_state = { + "company_of_interest": "TEST", + "trade_date": "2024-01-01", + "market_report": "", + "sentiment_report": "", + "news_report": "", + "fundamentals_report": "", + "momentum_report": "", + "macro_report": "Macro analysis result", + "correlation_report": "", + "investment_debate_state": { + "bull_history": "", + "bear_history": "", + "history": "", + "current_response": "", + "judge_decision": "", + }, + "trader_investment_plan": "", + "risk_debate_state": { + "risky_history": "", + "safe_history": "", + "neutral_history": "", + "history": "", + "judge_decision": "", + }, + "investment_plan": "", + "final_trade_decision": "", + } + + with patch("builtins.open", MagicMock()): + with patch("json.dump"): + with patch("pathlib.Path.mkdir"): + graph._log_state("2024-01-01", final_state) + + logged = graph.log_states_dict["2024-01-01"] + assert "macro_report" in logged + assert logged["macro_report"] == "Macro analysis result" + + def test_log_state_includes_correlation_report(self): + """_log_state should log correlation_report if present.""" + from tradingagents.graph.trading_graph import TradingAgentsGraph + + with patch.object(TradingAgentsGraph, "__init__", lambda self: None): + graph = TradingAgentsGraph.__new__(TradingAgentsGraph) + graph.log_states_dict = {} + graph.ticker = "TEST" + + final_state = { + "company_of_interest": "TEST", + "trade_date": "2024-01-01", + "market_report": "", + "sentiment_report": "", + "news_report": "", + "fundamentals_report": "", + "momentum_report": "", + "macro_report": "", + "correlation_report": "Correlation analysis result", + "investment_debate_state": { + "bull_history": "", + "bear_history": "", + "history": "", + "current_response": "", + "judge_decision": "", + }, + "trader_investment_plan": "", + "risk_debate_state": { + "risky_history": "", + "safe_history": "", + "neutral_history": "", + "history": "", + "judge_decision": "", + }, + "investment_plan": "", + "final_trade_decision": "", + } + + with patch("builtins.open", MagicMock()): + with patch("json.dump"): + with patch("pathlib.Path.mkdir"): + graph._log_state("2024-01-01", final_state) + + logged = graph.log_states_dict["2024-01-01"] + assert "correlation_report" in logged + assert logged["correlation_report"] == "Correlation analysis result" + + def test_log_state_handles_missing_new_reports(self): + """_log_state should handle missing new report fields gracefully.""" + from tradingagents.graph.trading_graph import TradingAgentsGraph + + with patch.object(TradingAgentsGraph, "__init__", lambda self: None): + graph = TradingAgentsGraph.__new__(TradingAgentsGraph) + graph.log_states_dict = {} + graph.ticker = "TEST" + + # State without new report fields (backward compatibility) + final_state = { + "company_of_interest": "TEST", + "trade_date": "2024-01-01", + "investment_debate_state": { + "bull_history": "", + "bear_history": "", + "history": "", + "current_response": "", + "judge_decision": "", + }, + "trader_investment_plan": "", + "risk_debate_state": { + "risky_history": "", + "safe_history": "", + "neutral_history": "", + "history": "", + "judge_decision": "", + }, + "investment_plan": "", + "final_trade_decision": "", + } + + with patch("builtins.open", MagicMock()): + with patch("json.dump"): + with patch("pathlib.Path.mkdir"): + graph._log_state("2024-01-01", final_state) + + logged = graph.log_states_dict["2024-01-01"] + # Should default to empty string + assert logged["momentum_report"] == "" + assert logged["macro_report"] == "" + assert logged["correlation_report"] == "" diff --git a/tradingagents/agents/__init__.py b/tradingagents/agents/__init__.py index d84d9eb1..2030d8ae 100644 --- a/tradingagents/agents/__init__.py +++ b/tradingagents/agents/__init__.py @@ -6,6 +6,9 @@ from .analysts.fundamentals_analyst import create_fundamentals_analyst from .analysts.market_analyst import create_market_analyst from .analysts.news_analyst import create_news_analyst from .analysts.social_media_analyst import create_social_media_analyst +from .analysts.momentum_analyst import create_momentum_analyst +from .analysts.macro_analyst import create_macro_analyst +from .analysts.correlation_analyst import create_correlation_analyst from .researchers.bear_researcher import create_bear_researcher from .researchers.bull_researcher import create_bull_researcher @@ -16,6 +19,7 @@ from .risk_mgmt.neutral_debator import create_neutral_debator from .managers.research_manager import create_research_manager from .managers.risk_manager import create_risk_manager +from .managers.position_sizing_manager import create_position_sizing_manager from .trader.trader import create_trader @@ -30,10 +34,14 @@ __all__ = [ "create_research_manager", "create_fundamentals_analyst", "create_market_analyst", + "create_momentum_analyst", + "create_macro_analyst", + "create_correlation_analyst", "create_neutral_debator", "create_news_analyst", "create_risky_debator", "create_risk_manager", + "create_position_sizing_manager", "create_safe_debator", "create_social_media_analyst", "create_trader", diff --git a/tradingagents/agents/analysts/__init__.py b/tradingagents/agents/analysts/__init__.py new file mode 100644 index 00000000..203a00da --- /dev/null +++ b/tradingagents/agents/analysts/__init__.py @@ -0,0 +1,29 @@ +"""Analyst agents for market analysis. + +This module provides specialized analyst agents for different types of analysis: +- Fundamentals Analyst: Company financial analysis +- Market Analyst: Technical and market structure analysis +- News Analyst: News sentiment and event analysis +- Social Media Analyst: Social sentiment analysis +- Momentum Analyst: Multi-timeframe momentum analysis (Issue #13) +- Macro Analyst: Macroeconomic and FRED data analysis (Issue #14) +- Correlation Analyst: Cross-asset correlation and sector rotation (Issue #15) +""" + +from .fundamentals_analyst import create_fundamentals_analyst +from .market_analyst import create_market_analyst +from .news_analyst import create_news_analyst +from .social_media_analyst import create_social_media_analyst +from .momentum_analyst import create_momentum_analyst +from .macro_analyst import create_macro_analyst +from .correlation_analyst import create_correlation_analyst + +__all__ = [ + "create_fundamentals_analyst", + "create_market_analyst", + "create_news_analyst", + "create_social_media_analyst", + "create_momentum_analyst", + "create_macro_analyst", + "create_correlation_analyst", +] diff --git a/tradingagents/agents/managers/__init__.py b/tradingagents/agents/managers/__init__.py index 8ea3f0cd..e9692a85 100644 --- a/tradingagents/agents/managers/__init__.py +++ b/tradingagents/agents/managers/__init__.py @@ -1 +1,17 @@ -"""Manager agents for portfolio and risk management.""" +"""Manager agents for portfolio and risk management. + +This module provides manager agents for orchestrating research and risk: +- Research Manager: Coordinates research activities +- Risk Manager: Manages risk assessment +- Position Sizing Manager: Optimal position sizing (Issue #16) +""" + +from .research_manager import create_research_manager +from .risk_manager import create_risk_manager +from .position_sizing_manager import create_position_sizing_manager + +__all__ = [ + "create_research_manager", + "create_risk_manager", + "create_position_sizing_manager", +] diff --git a/tradingagents/agents/utils/agent_states.py b/tradingagents/agents/utils/agent_states.py index 3a859ea1..eeb53303 100644 --- a/tradingagents/agents/utils/agent_states.py +++ b/tradingagents/agents/utils/agent_states.py @@ -60,6 +60,9 @@ class AgentState(MessagesState): str, "Report from the News Researcher of current world affairs" ] fundamentals_report: Annotated[str, "Report from the Fundamentals Researcher"] + momentum_report: Annotated[str, "Report from the Momentum Analyst"] + macro_report: Annotated[str, "Report from the Macro Analyst"] + correlation_report: Annotated[str, "Report from the Correlation Analyst"] # researcher team discussion step investment_debate_state: Annotated[ diff --git a/tradingagents/graph/conditional_logic.py b/tradingagents/graph/conditional_logic.py index e7c87859..bd2717c6 100644 --- a/tradingagents/graph/conditional_logic.py +++ b/tradingagents/graph/conditional_logic.py @@ -43,6 +43,30 @@ class ConditionalLogic: return "tools_fundamentals" return "Msg Clear Fundamentals" + def should_continue_momentum(self, state: AgentState): + """Determine if momentum analysis should continue.""" + messages = state["messages"] + last_message = messages[-1] + if last_message.tool_calls: + return "tools_momentum" + return "Msg Clear Momentum" + + def should_continue_macro(self, state: AgentState): + """Determine if macro analysis should continue.""" + messages = state["messages"] + last_message = messages[-1] + if last_message.tool_calls: + return "tools_macro" + return "Msg Clear Macro" + + def should_continue_correlation(self, state: AgentState): + """Determine if correlation analysis should continue.""" + messages = state["messages"] + last_message = messages[-1] + if last_message.tool_calls: + return "tools_correlation" + return "Msg Clear Correlation" + def should_continue_debate(self, state: AgentState) -> str: """Determine if debate should continue.""" diff --git a/tradingagents/graph/setup.py b/tradingagents/graph/setup.py index b270ffc0..6920aa1c 100644 --- a/tradingagents/graph/setup.py +++ b/tradingagents/graph/setup.py @@ -48,6 +48,9 @@ class GraphSetup: - "social": Social media analyst - "news": News analyst - "fundamentals": Fundamentals analyst + - "momentum": Momentum analyst (multi-TF momentum, ROC, ADX) + - "macro": Macro analyst (FRED data, economic regimes) + - "correlation": Correlation analyst (cross-asset, sector rotation) """ if len(selected_analysts) == 0: raise ValueError("Trading Agents Graph Setup Error: no analysts selected!") @@ -85,6 +88,27 @@ class GraphSetup: delete_nodes["fundamentals"] = create_msg_delete() tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"] + if "momentum" in selected_analysts: + analyst_nodes["momentum"] = create_momentum_analyst( + self.quick_thinking_llm + ) + delete_nodes["momentum"] = create_msg_delete() + tool_nodes["momentum"] = self.tool_nodes["momentum"] + + if "macro" in selected_analysts: + analyst_nodes["macro"] = create_macro_analyst( + self.quick_thinking_llm + ) + delete_nodes["macro"] = create_msg_delete() + tool_nodes["macro"] = self.tool_nodes["macro"] + + if "correlation" in selected_analysts: + analyst_nodes["correlation"] = create_correlation_analyst( + self.quick_thinking_llm + ) + delete_nodes["correlation"] = create_msg_delete() + tool_nodes["correlation"] = self.tool_nodes["correlation"] + # Create researcher and manager nodes bull_researcher_node = create_bull_researcher( self.quick_thinking_llm, self.bull_memory diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 8b334c7b..8cdae96c 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -36,6 +36,25 @@ from tradingagents.agents.utils.agent_utils import ( get_global_news ) +# Import tools from new analysts +from tradingagents.agents.analysts.momentum_analyst import ( + get_multi_timeframe_momentum, + get_adx_analysis, + get_momentum_divergence, +) +from tradingagents.agents.analysts.macro_analyst import ( + get_economic_regime_analysis, + get_yield_curve_analysis, + get_monetary_policy_analysis, + get_inflation_regime_analysis, +) +from tradingagents.agents.analysts.correlation_analyst import ( + get_cross_asset_correlation_analysis, + get_sector_rotation_analysis, + get_correlation_matrix, + get_rolling_correlation_trend, +) + from .conditional_logic import ConditionalLogic from .setup import GraphSetup from .propagation import Propagator @@ -223,6 +242,32 @@ class TradingAgentsGraph: get_income_statement, ] ), + "momentum": ToolNode( + [ + # Momentum analysis tools + get_multi_timeframe_momentum, + get_adx_analysis, + get_momentum_divergence, + ] + ), + "macro": ToolNode( + [ + # Macroeconomic analysis tools + get_economic_regime_analysis, + get_yield_curve_analysis, + get_monetary_policy_analysis, + get_inflation_regime_analysis, + ] + ), + "correlation": ToolNode( + [ + # Correlation analysis tools + get_cross_asset_correlation_analysis, + get_sector_rotation_analysis, + get_correlation_matrix, + get_rolling_correlation_trend, + ] + ), } def propagate(self, company_name, trade_date): @@ -265,10 +310,13 @@ class TradingAgentsGraph: self.log_states_dict[str(trade_date)] = { "company_of_interest": final_state["company_of_interest"], "trade_date": final_state["trade_date"], - "market_report": final_state["market_report"], - "sentiment_report": final_state["sentiment_report"], - "news_report": final_state["news_report"], - "fundamentals_report": final_state["fundamentals_report"], + "market_report": final_state.get("market_report", ""), + "sentiment_report": final_state.get("sentiment_report", ""), + "news_report": final_state.get("news_report", ""), + "fundamentals_report": final_state.get("fundamentals_report", ""), + "momentum_report": final_state.get("momentum_report", ""), + "macro_report": final_state.get("macro_report", ""), + "correlation_report": final_state.get("correlation_report", ""), "investment_debate_state": { "bull_history": final_state["investment_debate_state"]["bull_history"], "bear_history": final_state["investment_debate_state"]["bear_history"],