feat(graph): integrate new analysts into workflow - Fixes #17
This commit is contained in:
parent
a17fc1f029
commit
5a0606b59f
|
|
@ -0,0 +1,506 @@
|
|||
"""Tests for Issue #17: Analyst integration into graph/setup.py workflow.
|
||||
|
||||
This module tests the integration of new analysts (momentum, macro, correlation)
|
||||
into the TradingAgents graph workflow.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from typing import get_type_hints
|
||||
import sys
|
||||
|
||||
# Check if langchain dependencies are available
|
||||
try:
|
||||
import langchain_core
|
||||
LANGCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
LANGCHAIN_AVAILABLE = False
|
||||
|
||||
# Skip all tests if langchain not available
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not LANGCHAIN_AVAILABLE,
|
||||
reason="langchain_core not installed"
|
||||
)
|
||||
|
||||
|
||||
class TestAgentStateReports:
|
||||
"""Test that AgentState has the new report fields."""
|
||||
|
||||
def test_agent_state_has_momentum_report(self):
|
||||
"""AgentState should have momentum_report field."""
|
||||
from tradingagents.agents.utils.agent_states import AgentState
|
||||
hints = get_type_hints(AgentState)
|
||||
assert "momentum_report" in hints
|
||||
|
||||
def test_agent_state_has_macro_report(self):
|
||||
"""AgentState should have macro_report field."""
|
||||
from tradingagents.agents.utils.agent_states import AgentState
|
||||
hints = get_type_hints(AgentState)
|
||||
assert "macro_report" in hints
|
||||
|
||||
def test_agent_state_has_correlation_report(self):
|
||||
"""AgentState should have correlation_report field."""
|
||||
from tradingagents.agents.utils.agent_states import AgentState
|
||||
hints = get_type_hints(AgentState)
|
||||
assert "correlation_report" in hints
|
||||
|
||||
|
||||
class TestConditionalLogicMethods:
|
||||
"""Test that ConditionalLogic has methods for new analysts."""
|
||||
|
||||
def test_should_continue_momentum_exists(self):
|
||||
"""ConditionalLogic should have should_continue_momentum method."""
|
||||
from tradingagents.graph.conditional_logic import ConditionalLogic
|
||||
cl = ConditionalLogic()
|
||||
assert hasattr(cl, "should_continue_momentum")
|
||||
assert callable(cl.should_continue_momentum)
|
||||
|
||||
def test_should_continue_macro_exists(self):
|
||||
"""ConditionalLogic should have should_continue_macro method."""
|
||||
from tradingagents.graph.conditional_logic import ConditionalLogic
|
||||
cl = ConditionalLogic()
|
||||
assert hasattr(cl, "should_continue_macro")
|
||||
assert callable(cl.should_continue_macro)
|
||||
|
||||
def test_should_continue_correlation_exists(self):
|
||||
"""ConditionalLogic should have should_continue_correlation method."""
|
||||
from tradingagents.graph.conditional_logic import ConditionalLogic
|
||||
cl = ConditionalLogic()
|
||||
assert hasattr(cl, "should_continue_correlation")
|
||||
assert callable(cl.should_continue_correlation)
|
||||
|
||||
def test_momentum_conditional_returns_tools(self):
|
||||
"""should_continue_momentum should return 'tools_momentum' when tool_calls exist."""
|
||||
from tradingagents.graph.conditional_logic import ConditionalLogic
|
||||
cl = ConditionalLogic()
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.tool_calls = [{"name": "test"}]
|
||||
state = {"messages": [mock_message]}
|
||||
|
||||
result = cl.should_continue_momentum(state)
|
||||
assert result == "tools_momentum"
|
||||
|
||||
def test_momentum_conditional_returns_clear(self):
|
||||
"""should_continue_momentum should return 'Msg Clear Momentum' when no tool_calls."""
|
||||
from tradingagents.graph.conditional_logic import ConditionalLogic
|
||||
cl = ConditionalLogic()
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.tool_calls = []
|
||||
state = {"messages": [mock_message]}
|
||||
|
||||
result = cl.should_continue_momentum(state)
|
||||
assert result == "Msg Clear Momentum"
|
||||
|
||||
def test_macro_conditional_returns_tools(self):
|
||||
"""should_continue_macro should return 'tools_macro' when tool_calls exist."""
|
||||
from tradingagents.graph.conditional_logic import ConditionalLogic
|
||||
cl = ConditionalLogic()
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.tool_calls = [{"name": "test"}]
|
||||
state = {"messages": [mock_message]}
|
||||
|
||||
result = cl.should_continue_macro(state)
|
||||
assert result == "tools_macro"
|
||||
|
||||
def test_macro_conditional_returns_clear(self):
|
||||
"""should_continue_macro should return 'Msg Clear Macro' when no tool_calls."""
|
||||
from tradingagents.graph.conditional_logic import ConditionalLogic
|
||||
cl = ConditionalLogic()
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.tool_calls = []
|
||||
state = {"messages": [mock_message]}
|
||||
|
||||
result = cl.should_continue_macro(state)
|
||||
assert result == "Msg Clear Macro"
|
||||
|
||||
def test_correlation_conditional_returns_tools(self):
|
||||
"""should_continue_correlation should return 'tools_correlation' when tool_calls exist."""
|
||||
from tradingagents.graph.conditional_logic import ConditionalLogic
|
||||
cl = ConditionalLogic()
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.tool_calls = [{"name": "test"}]
|
||||
state = {"messages": [mock_message]}
|
||||
|
||||
result = cl.should_continue_correlation(state)
|
||||
assert result == "tools_correlation"
|
||||
|
||||
def test_correlation_conditional_returns_clear(self):
|
||||
"""should_continue_correlation should return 'Msg Clear Correlation' when no tool_calls."""
|
||||
from tradingagents.graph.conditional_logic import ConditionalLogic
|
||||
cl = ConditionalLogic()
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.tool_calls = []
|
||||
state = {"messages": [mock_message]}
|
||||
|
||||
result = cl.should_continue_correlation(state)
|
||||
assert result == "Msg Clear Correlation"
|
||||
|
||||
|
||||
class TestAgentImports:
|
||||
"""Test that new analysts are properly exported from agents module."""
|
||||
|
||||
def test_create_momentum_analyst_import(self):
|
||||
"""create_momentum_analyst should be importable from agents."""
|
||||
from tradingagents.agents import create_momentum_analyst
|
||||
assert callable(create_momentum_analyst)
|
||||
|
||||
def test_create_macro_analyst_import(self):
|
||||
"""create_macro_analyst should be importable from agents."""
|
||||
from tradingagents.agents import create_macro_analyst
|
||||
assert callable(create_macro_analyst)
|
||||
|
||||
def test_create_correlation_analyst_import(self):
|
||||
"""create_correlation_analyst should be importable from agents."""
|
||||
from tradingagents.agents import create_correlation_analyst
|
||||
assert callable(create_correlation_analyst)
|
||||
|
||||
def test_create_position_sizing_manager_import(self):
|
||||
"""create_position_sizing_manager should be importable from agents."""
|
||||
from tradingagents.agents import create_position_sizing_manager
|
||||
assert callable(create_position_sizing_manager)
|
||||
|
||||
|
||||
class TestAnalystsModuleExports:
|
||||
"""Test analysts submodule exports."""
|
||||
|
||||
def test_analysts_module_exports_momentum(self):
|
||||
"""analysts module should export create_momentum_analyst."""
|
||||
from tradingagents.agents.analysts import create_momentum_analyst
|
||||
assert callable(create_momentum_analyst)
|
||||
|
||||
def test_analysts_module_exports_macro(self):
|
||||
"""analysts module should export create_macro_analyst."""
|
||||
from tradingagents.agents.analysts import create_macro_analyst
|
||||
assert callable(create_macro_analyst)
|
||||
|
||||
def test_analysts_module_exports_correlation(self):
|
||||
"""analysts module should export create_correlation_analyst."""
|
||||
from tradingagents.agents.analysts import create_correlation_analyst
|
||||
assert callable(create_correlation_analyst)
|
||||
|
||||
def test_analysts_module_all_exports(self):
|
||||
"""analysts __all__ should include all analyst creators."""
|
||||
from tradingagents.agents import analysts
|
||||
assert "create_momentum_analyst" in analysts.__all__
|
||||
assert "create_macro_analyst" in analysts.__all__
|
||||
assert "create_correlation_analyst" in analysts.__all__
|
||||
|
||||
|
||||
class TestManagersModuleExports:
|
||||
"""Test managers submodule exports."""
|
||||
|
||||
def test_managers_module_exports_position_sizing(self):
|
||||
"""managers module should export create_position_sizing_manager."""
|
||||
from tradingagents.agents.managers import create_position_sizing_manager
|
||||
assert callable(create_position_sizing_manager)
|
||||
|
||||
def test_managers_module_all_exports(self):
|
||||
"""managers __all__ should include position_sizing_manager."""
|
||||
from tradingagents.agents import managers
|
||||
assert "create_position_sizing_manager" in managers.__all__
|
||||
|
||||
|
||||
class TestToolImports:
|
||||
"""Test that tool functions are importable from analyst modules."""
|
||||
|
||||
def test_momentum_tools_importable(self):
|
||||
"""Momentum analyst tools should be importable."""
|
||||
from tradingagents.agents.analysts.momentum_analyst import (
|
||||
get_multi_timeframe_momentum,
|
||||
get_adx_analysis,
|
||||
get_momentum_divergence,
|
||||
)
|
||||
assert callable(get_multi_timeframe_momentum)
|
||||
assert callable(get_adx_analysis)
|
||||
assert callable(get_momentum_divergence)
|
||||
|
||||
def test_macro_tools_importable(self):
|
||||
"""Macro analyst tools should be importable."""
|
||||
from tradingagents.agents.analysts.macro_analyst import (
|
||||
get_economic_regime_analysis,
|
||||
get_yield_curve_analysis,
|
||||
get_monetary_policy_analysis,
|
||||
get_inflation_regime_analysis,
|
||||
)
|
||||
assert callable(get_economic_regime_analysis)
|
||||
assert callable(get_yield_curve_analysis)
|
||||
assert callable(get_monetary_policy_analysis)
|
||||
assert callable(get_inflation_regime_analysis)
|
||||
|
||||
def test_correlation_tools_importable(self):
|
||||
"""Correlation analyst tools should be importable."""
|
||||
from tradingagents.agents.analysts.correlation_analyst import (
|
||||
get_cross_asset_correlation_analysis,
|
||||
get_sector_rotation_analysis,
|
||||
get_correlation_matrix,
|
||||
get_rolling_correlation_trend,
|
||||
)
|
||||
assert callable(get_cross_asset_correlation_analysis)
|
||||
assert callable(get_sector_rotation_analysis)
|
||||
assert callable(get_correlation_matrix)
|
||||
assert callable(get_rolling_correlation_trend)
|
||||
|
||||
|
||||
class TestTradingGraphToolNodes:
|
||||
"""Test that trading_graph.py has tool nodes for new analysts."""
|
||||
|
||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.set_config")
|
||||
@patch("os.makedirs")
|
||||
def test_tool_nodes_include_momentum(self, mock_makedirs, mock_set_config, mock_llm):
|
||||
"""_create_tool_nodes should include momentum tools."""
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
|
||||
mock_llm.return_value = MagicMock()
|
||||
|
||||
with patch.object(TradingAgentsGraph, "__init__", lambda self: None):
|
||||
graph = TradingAgentsGraph.__new__(TradingAgentsGraph)
|
||||
tool_nodes = graph._create_tool_nodes()
|
||||
|
||||
assert "momentum" in tool_nodes
|
||||
|
||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.set_config")
|
||||
@patch("os.makedirs")
|
||||
def test_tool_nodes_include_macro(self, mock_makedirs, mock_set_config, mock_llm):
|
||||
"""_create_tool_nodes should include macro tools."""
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
|
||||
mock_llm.return_value = MagicMock()
|
||||
|
||||
with patch.object(TradingAgentsGraph, "__init__", lambda self: None):
|
||||
graph = TradingAgentsGraph.__new__(TradingAgentsGraph)
|
||||
tool_nodes = graph._create_tool_nodes()
|
||||
|
||||
assert "macro" in tool_nodes
|
||||
|
||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.set_config")
|
||||
@patch("os.makedirs")
|
||||
def test_tool_nodes_include_correlation(self, mock_makedirs, mock_set_config, mock_llm):
|
||||
"""_create_tool_nodes should include correlation tools."""
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
|
||||
mock_llm.return_value = MagicMock()
|
||||
|
||||
with patch.object(TradingAgentsGraph, "__init__", lambda self: None):
|
||||
graph = TradingAgentsGraph.__new__(TradingAgentsGraph)
|
||||
tool_nodes = graph._create_tool_nodes()
|
||||
|
||||
assert "correlation" in tool_nodes
|
||||
|
||||
|
||||
class TestGraphSetupDocstring:
|
||||
"""Test that setup_graph docstring documents new analysts."""
|
||||
|
||||
def test_docstring_mentions_momentum(self):
|
||||
"""setup_graph docstring should mention momentum analyst."""
|
||||
from tradingagents.graph.setup import GraphSetup
|
||||
docstring = GraphSetup.setup_graph.__doc__
|
||||
assert "momentum" in docstring.lower()
|
||||
|
||||
def test_docstring_mentions_macro(self):
|
||||
"""setup_graph docstring should mention macro analyst."""
|
||||
from tradingagents.graph.setup import GraphSetup
|
||||
docstring = GraphSetup.setup_graph.__doc__
|
||||
assert "macro" in docstring.lower()
|
||||
|
||||
def test_docstring_mentions_correlation(self):
|
||||
"""setup_graph docstring should mention correlation analyst."""
|
||||
from tradingagents.graph.setup import GraphSetup
|
||||
docstring = GraphSetup.setup_graph.__doc__
|
||||
assert "correlation" in docstring.lower()
|
||||
|
||||
|
||||
class TestLogStateNewReports:
|
||||
"""Test that _log_state includes new report fields."""
|
||||
|
||||
def test_log_state_includes_momentum_report(self):
|
||||
"""_log_state should log momentum_report if present."""
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
|
||||
with patch.object(TradingAgentsGraph, "__init__", lambda self: None):
|
||||
graph = TradingAgentsGraph.__new__(TradingAgentsGraph)
|
||||
graph.log_states_dict = {}
|
||||
graph.ticker = "TEST"
|
||||
|
||||
final_state = {
|
||||
"company_of_interest": "TEST",
|
||||
"trade_date": "2024-01-01",
|
||||
"market_report": "",
|
||||
"sentiment_report": "",
|
||||
"news_report": "",
|
||||
"fundamentals_report": "",
|
||||
"momentum_report": "Momentum analysis result",
|
||||
"macro_report": "",
|
||||
"correlation_report": "",
|
||||
"investment_debate_state": {
|
||||
"bull_history": "",
|
||||
"bear_history": "",
|
||||
"history": "",
|
||||
"current_response": "",
|
||||
"judge_decision": "",
|
||||
},
|
||||
"trader_investment_plan": "",
|
||||
"risk_debate_state": {
|
||||
"risky_history": "",
|
||||
"safe_history": "",
|
||||
"neutral_history": "",
|
||||
"history": "",
|
||||
"judge_decision": "",
|
||||
},
|
||||
"investment_plan": "",
|
||||
"final_trade_decision": "",
|
||||
}
|
||||
|
||||
with patch("builtins.open", MagicMock()):
|
||||
with patch("json.dump"):
|
||||
with patch("pathlib.Path.mkdir"):
|
||||
graph._log_state("2024-01-01", final_state)
|
||||
|
||||
logged = graph.log_states_dict["2024-01-01"]
|
||||
assert "momentum_report" in logged
|
||||
assert logged["momentum_report"] == "Momentum analysis result"
|
||||
|
||||
def test_log_state_includes_macro_report(self):
|
||||
"""_log_state should log macro_report if present."""
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
|
||||
with patch.object(TradingAgentsGraph, "__init__", lambda self: None):
|
||||
graph = TradingAgentsGraph.__new__(TradingAgentsGraph)
|
||||
graph.log_states_dict = {}
|
||||
graph.ticker = "TEST"
|
||||
|
||||
final_state = {
|
||||
"company_of_interest": "TEST",
|
||||
"trade_date": "2024-01-01",
|
||||
"market_report": "",
|
||||
"sentiment_report": "",
|
||||
"news_report": "",
|
||||
"fundamentals_report": "",
|
||||
"momentum_report": "",
|
||||
"macro_report": "Macro analysis result",
|
||||
"correlation_report": "",
|
||||
"investment_debate_state": {
|
||||
"bull_history": "",
|
||||
"bear_history": "",
|
||||
"history": "",
|
||||
"current_response": "",
|
||||
"judge_decision": "",
|
||||
},
|
||||
"trader_investment_plan": "",
|
||||
"risk_debate_state": {
|
||||
"risky_history": "",
|
||||
"safe_history": "",
|
||||
"neutral_history": "",
|
||||
"history": "",
|
||||
"judge_decision": "",
|
||||
},
|
||||
"investment_plan": "",
|
||||
"final_trade_decision": "",
|
||||
}
|
||||
|
||||
with patch("builtins.open", MagicMock()):
|
||||
with patch("json.dump"):
|
||||
with patch("pathlib.Path.mkdir"):
|
||||
graph._log_state("2024-01-01", final_state)
|
||||
|
||||
logged = graph.log_states_dict["2024-01-01"]
|
||||
assert "macro_report" in logged
|
||||
assert logged["macro_report"] == "Macro analysis result"
|
||||
|
||||
def test_log_state_includes_correlation_report(self):
|
||||
"""_log_state should log correlation_report if present."""
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
|
||||
with patch.object(TradingAgentsGraph, "__init__", lambda self: None):
|
||||
graph = TradingAgentsGraph.__new__(TradingAgentsGraph)
|
||||
graph.log_states_dict = {}
|
||||
graph.ticker = "TEST"
|
||||
|
||||
final_state = {
|
||||
"company_of_interest": "TEST",
|
||||
"trade_date": "2024-01-01",
|
||||
"market_report": "",
|
||||
"sentiment_report": "",
|
||||
"news_report": "",
|
||||
"fundamentals_report": "",
|
||||
"momentum_report": "",
|
||||
"macro_report": "",
|
||||
"correlation_report": "Correlation analysis result",
|
||||
"investment_debate_state": {
|
||||
"bull_history": "",
|
||||
"bear_history": "",
|
||||
"history": "",
|
||||
"current_response": "",
|
||||
"judge_decision": "",
|
||||
},
|
||||
"trader_investment_plan": "",
|
||||
"risk_debate_state": {
|
||||
"risky_history": "",
|
||||
"safe_history": "",
|
||||
"neutral_history": "",
|
||||
"history": "",
|
||||
"judge_decision": "",
|
||||
},
|
||||
"investment_plan": "",
|
||||
"final_trade_decision": "",
|
||||
}
|
||||
|
||||
with patch("builtins.open", MagicMock()):
|
||||
with patch("json.dump"):
|
||||
with patch("pathlib.Path.mkdir"):
|
||||
graph._log_state("2024-01-01", final_state)
|
||||
|
||||
logged = graph.log_states_dict["2024-01-01"]
|
||||
assert "correlation_report" in logged
|
||||
assert logged["correlation_report"] == "Correlation analysis result"
|
||||
|
||||
def test_log_state_handles_missing_new_reports(self):
|
||||
"""_log_state should handle missing new report fields gracefully."""
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
|
||||
with patch.object(TradingAgentsGraph, "__init__", lambda self: None):
|
||||
graph = TradingAgentsGraph.__new__(TradingAgentsGraph)
|
||||
graph.log_states_dict = {}
|
||||
graph.ticker = "TEST"
|
||||
|
||||
# State without new report fields (backward compatibility)
|
||||
final_state = {
|
||||
"company_of_interest": "TEST",
|
||||
"trade_date": "2024-01-01",
|
||||
"investment_debate_state": {
|
||||
"bull_history": "",
|
||||
"bear_history": "",
|
||||
"history": "",
|
||||
"current_response": "",
|
||||
"judge_decision": "",
|
||||
},
|
||||
"trader_investment_plan": "",
|
||||
"risk_debate_state": {
|
||||
"risky_history": "",
|
||||
"safe_history": "",
|
||||
"neutral_history": "",
|
||||
"history": "",
|
||||
"judge_decision": "",
|
||||
},
|
||||
"investment_plan": "",
|
||||
"final_trade_decision": "",
|
||||
}
|
||||
|
||||
with patch("builtins.open", MagicMock()):
|
||||
with patch("json.dump"):
|
||||
with patch("pathlib.Path.mkdir"):
|
||||
graph._log_state("2024-01-01", final_state)
|
||||
|
||||
logged = graph.log_states_dict["2024-01-01"]
|
||||
# Should default to empty string
|
||||
assert logged["momentum_report"] == ""
|
||||
assert logged["macro_report"] == ""
|
||||
assert logged["correlation_report"] == ""
|
||||
|
|
@ -6,6 +6,9 @@ from .analysts.fundamentals_analyst import create_fundamentals_analyst
|
|||
from .analysts.market_analyst import create_market_analyst
|
||||
from .analysts.news_analyst import create_news_analyst
|
||||
from .analysts.social_media_analyst import create_social_media_analyst
|
||||
from .analysts.momentum_analyst import create_momentum_analyst
|
||||
from .analysts.macro_analyst import create_macro_analyst
|
||||
from .analysts.correlation_analyst import create_correlation_analyst
|
||||
|
||||
from .researchers.bear_researcher import create_bear_researcher
|
||||
from .researchers.bull_researcher import create_bull_researcher
|
||||
|
|
@ -16,6 +19,7 @@ from .risk_mgmt.neutral_debator import create_neutral_debator
|
|||
|
||||
from .managers.research_manager import create_research_manager
|
||||
from .managers.risk_manager import create_risk_manager
|
||||
from .managers.position_sizing_manager import create_position_sizing_manager
|
||||
|
||||
from .trader.trader import create_trader
|
||||
|
||||
|
|
@ -30,10 +34,14 @@ __all__ = [
|
|||
"create_research_manager",
|
||||
"create_fundamentals_analyst",
|
||||
"create_market_analyst",
|
||||
"create_momentum_analyst",
|
||||
"create_macro_analyst",
|
||||
"create_correlation_analyst",
|
||||
"create_neutral_debator",
|
||||
"create_news_analyst",
|
||||
"create_risky_debator",
|
||||
"create_risk_manager",
|
||||
"create_position_sizing_manager",
|
||||
"create_safe_debator",
|
||||
"create_social_media_analyst",
|
||||
"create_trader",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,29 @@
|
|||
"""Analyst agents for market analysis.
|
||||
|
||||
This module provides specialized analyst agents for different types of analysis:
|
||||
- Fundamentals Analyst: Company financial analysis
|
||||
- Market Analyst: Technical and market structure analysis
|
||||
- News Analyst: News sentiment and event analysis
|
||||
- Social Media Analyst: Social sentiment analysis
|
||||
- Momentum Analyst: Multi-timeframe momentum analysis (Issue #13)
|
||||
- Macro Analyst: Macroeconomic and FRED data analysis (Issue #14)
|
||||
- Correlation Analyst: Cross-asset correlation and sector rotation (Issue #15)
|
||||
"""
|
||||
|
||||
from .fundamentals_analyst import create_fundamentals_analyst
|
||||
from .market_analyst import create_market_analyst
|
||||
from .news_analyst import create_news_analyst
|
||||
from .social_media_analyst import create_social_media_analyst
|
||||
from .momentum_analyst import create_momentum_analyst
|
||||
from .macro_analyst import create_macro_analyst
|
||||
from .correlation_analyst import create_correlation_analyst
|
||||
|
||||
__all__ = [
|
||||
"create_fundamentals_analyst",
|
||||
"create_market_analyst",
|
||||
"create_news_analyst",
|
||||
"create_social_media_analyst",
|
||||
"create_momentum_analyst",
|
||||
"create_macro_analyst",
|
||||
"create_correlation_analyst",
|
||||
]
|
||||
|
|
@ -1 +1,17 @@
|
|||
"""Manager agents for portfolio and risk management."""
|
||||
"""Manager agents for portfolio and risk management.
|
||||
|
||||
This module provides manager agents for orchestrating research and risk:
|
||||
- Research Manager: Coordinates research activities
|
||||
- Risk Manager: Manages risk assessment
|
||||
- Position Sizing Manager: Optimal position sizing (Issue #16)
|
||||
"""
|
||||
|
||||
from .research_manager import create_research_manager
|
||||
from .risk_manager import create_risk_manager
|
||||
from .position_sizing_manager import create_position_sizing_manager
|
||||
|
||||
__all__ = [
|
||||
"create_research_manager",
|
||||
"create_risk_manager",
|
||||
"create_position_sizing_manager",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -60,6 +60,9 @@ class AgentState(MessagesState):
|
|||
str, "Report from the News Researcher of current world affairs"
|
||||
]
|
||||
fundamentals_report: Annotated[str, "Report from the Fundamentals Researcher"]
|
||||
momentum_report: Annotated[str, "Report from the Momentum Analyst"]
|
||||
macro_report: Annotated[str, "Report from the Macro Analyst"]
|
||||
correlation_report: Annotated[str, "Report from the Correlation Analyst"]
|
||||
|
||||
# researcher team discussion step
|
||||
investment_debate_state: Annotated[
|
||||
|
|
|
|||
|
|
@ -43,6 +43,30 @@ class ConditionalLogic:
|
|||
return "tools_fundamentals"
|
||||
return "Msg Clear Fundamentals"
|
||||
|
||||
def should_continue_momentum(self, state: AgentState):
|
||||
"""Determine if momentum analysis should continue."""
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1]
|
||||
if last_message.tool_calls:
|
||||
return "tools_momentum"
|
||||
return "Msg Clear Momentum"
|
||||
|
||||
def should_continue_macro(self, state: AgentState):
|
||||
"""Determine if macro analysis should continue."""
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1]
|
||||
if last_message.tool_calls:
|
||||
return "tools_macro"
|
||||
return "Msg Clear Macro"
|
||||
|
||||
def should_continue_correlation(self, state: AgentState):
|
||||
"""Determine if correlation analysis should continue."""
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1]
|
||||
if last_message.tool_calls:
|
||||
return "tools_correlation"
|
||||
return "Msg Clear Correlation"
|
||||
|
||||
def should_continue_debate(self, state: AgentState) -> str:
|
||||
"""Determine if debate should continue."""
|
||||
|
||||
|
|
|
|||
|
|
@ -48,6 +48,9 @@ class GraphSetup:
|
|||
- "social": Social media analyst
|
||||
- "news": News analyst
|
||||
- "fundamentals": Fundamentals analyst
|
||||
- "momentum": Momentum analyst (multi-TF momentum, ROC, ADX)
|
||||
- "macro": Macro analyst (FRED data, economic regimes)
|
||||
- "correlation": Correlation analyst (cross-asset, sector rotation)
|
||||
"""
|
||||
if len(selected_analysts) == 0:
|
||||
raise ValueError("Trading Agents Graph Setup Error: no analysts selected!")
|
||||
|
|
@ -85,6 +88,27 @@ class GraphSetup:
|
|||
delete_nodes["fundamentals"] = create_msg_delete()
|
||||
tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"]
|
||||
|
||||
if "momentum" in selected_analysts:
|
||||
analyst_nodes["momentum"] = create_momentum_analyst(
|
||||
self.quick_thinking_llm
|
||||
)
|
||||
delete_nodes["momentum"] = create_msg_delete()
|
||||
tool_nodes["momentum"] = self.tool_nodes["momentum"]
|
||||
|
||||
if "macro" in selected_analysts:
|
||||
analyst_nodes["macro"] = create_macro_analyst(
|
||||
self.quick_thinking_llm
|
||||
)
|
||||
delete_nodes["macro"] = create_msg_delete()
|
||||
tool_nodes["macro"] = self.tool_nodes["macro"]
|
||||
|
||||
if "correlation" in selected_analysts:
|
||||
analyst_nodes["correlation"] = create_correlation_analyst(
|
||||
self.quick_thinking_llm
|
||||
)
|
||||
delete_nodes["correlation"] = create_msg_delete()
|
||||
tool_nodes["correlation"] = self.tool_nodes["correlation"]
|
||||
|
||||
# Create researcher and manager nodes
|
||||
bull_researcher_node = create_bull_researcher(
|
||||
self.quick_thinking_llm, self.bull_memory
|
||||
|
|
|
|||
|
|
@ -36,6 +36,25 @@ from tradingagents.agents.utils.agent_utils import (
|
|||
get_global_news
|
||||
)
|
||||
|
||||
# Import tools from new analysts
|
||||
from tradingagents.agents.analysts.momentum_analyst import (
|
||||
get_multi_timeframe_momentum,
|
||||
get_adx_analysis,
|
||||
get_momentum_divergence,
|
||||
)
|
||||
from tradingagents.agents.analysts.macro_analyst import (
|
||||
get_economic_regime_analysis,
|
||||
get_yield_curve_analysis,
|
||||
get_monetary_policy_analysis,
|
||||
get_inflation_regime_analysis,
|
||||
)
|
||||
from tradingagents.agents.analysts.correlation_analyst import (
|
||||
get_cross_asset_correlation_analysis,
|
||||
get_sector_rotation_analysis,
|
||||
get_correlation_matrix,
|
||||
get_rolling_correlation_trend,
|
||||
)
|
||||
|
||||
from .conditional_logic import ConditionalLogic
|
||||
from .setup import GraphSetup
|
||||
from .propagation import Propagator
|
||||
|
|
@ -223,6 +242,32 @@ class TradingAgentsGraph:
|
|||
get_income_statement,
|
||||
]
|
||||
),
|
||||
"momentum": ToolNode(
|
||||
[
|
||||
# Momentum analysis tools
|
||||
get_multi_timeframe_momentum,
|
||||
get_adx_analysis,
|
||||
get_momentum_divergence,
|
||||
]
|
||||
),
|
||||
"macro": ToolNode(
|
||||
[
|
||||
# Macroeconomic analysis tools
|
||||
get_economic_regime_analysis,
|
||||
get_yield_curve_analysis,
|
||||
get_monetary_policy_analysis,
|
||||
get_inflation_regime_analysis,
|
||||
]
|
||||
),
|
||||
"correlation": ToolNode(
|
||||
[
|
||||
# Correlation analysis tools
|
||||
get_cross_asset_correlation_analysis,
|
||||
get_sector_rotation_analysis,
|
||||
get_correlation_matrix,
|
||||
get_rolling_correlation_trend,
|
||||
]
|
||||
),
|
||||
}
|
||||
|
||||
def propagate(self, company_name, trade_date):
|
||||
|
|
@ -265,10 +310,13 @@ class TradingAgentsGraph:
|
|||
self.log_states_dict[str(trade_date)] = {
|
||||
"company_of_interest": final_state["company_of_interest"],
|
||||
"trade_date": final_state["trade_date"],
|
||||
"market_report": final_state["market_report"],
|
||||
"sentiment_report": final_state["sentiment_report"],
|
||||
"news_report": final_state["news_report"],
|
||||
"fundamentals_report": final_state["fundamentals_report"],
|
||||
"market_report": final_state.get("market_report", ""),
|
||||
"sentiment_report": final_state.get("sentiment_report", ""),
|
||||
"news_report": final_state.get("news_report", ""),
|
||||
"fundamentals_report": final_state.get("fundamentals_report", ""),
|
||||
"momentum_report": final_state.get("momentum_report", ""),
|
||||
"macro_report": final_state.get("macro_report", ""),
|
||||
"correlation_report": final_state.get("correlation_report", ""),
|
||||
"investment_debate_state": {
|
||||
"bull_history": final_state["investment_debate_state"]["bull_history"],
|
||||
"bear_history": final_state["investment_debate_state"]["bear_history"],
|
||||
|
|
|
|||
Loading…
Reference in New Issue