diff --git a/tests/test_scenario_catalyst_analyst.py b/tests/test_scenario_catalyst_analyst.py new file mode 100644 index 00000000..18e49bfe --- /dev/null +++ b/tests/test_scenario_catalyst_analyst.py @@ -0,0 +1,465 @@ +import json + +from tradingagents.graph.setup import GraphSetup +from tradingagents.graph.trading_graph import TradingAgentsGraph + + +class DummyStateGraph: + def __init__(self, _state_type): + self.nodes = {} + self.conditional_edges = {} + + def add_node(self, name, node): + self.nodes[name] = node + + def add_edge(self, *_args, **_kwargs): + return None + + def add_conditional_edges(self, source, condition, destinations): + self.conditional_edges[source] = { + "condition": condition, + "destinations": destinations, + } + + def compile(self): + return { + "nodes": self.nodes, + "conditional_edges": self.conditional_edges, + } + + +class DummyToolNode: + def __init__(self, tools): + self.tools = tools + + +def test_scenario_tools_route_to_vendor(monkeypatch): + import tradingagents.dataflows.interface as interface + from tradingagents.agents.utils.scenario_tools import ( + get_catalyst_calendar, + get_scenario_fundamentals, + get_scenario_news, + ) + + calls = [] + + def fake_route_to_vendor(method, *args, **kwargs): + calls.append((method, args, kwargs)) + return f"{method}-result" + + monkeypatch.setattr(interface, "route_to_vendor", fake_route_to_vendor) + + assert ( + get_scenario_fundamentals.invoke({"ticker": "AAPL", "curr_date": "2026-03-24"}) + == "get_fundamentals-result" + ) + assert ( + get_scenario_news.invoke( + {"query": "AAPL product launch catalyst", "start_date": "2026-03-01", "end_date": "2026-03-24"} + ) + == "get_news-result" + ) + assert ( + get_catalyst_calendar.invoke({"curr_date": "2026-03-24"}) + == "get_fed_calendar-result" + ) + assert calls == [ + ( + "get_fundamentals", + (), + {"ticker": "AAPL", "curr_date": "2026-03-24"}, + ), + ( + "get_news", + (), + { + "query": "AAPL product launch catalyst", + "start_date": "2026-03-01", + "end_date": "2026-03-24", + }, + ), + ("get_fed_calendar", (), {"curr_date": "2026-03-24"}), + ] + + +def test_graph_setup_wires_scenario_catalyst_analyst_and_tools(monkeypatch): + recorded_llms = {} + + monkeypatch.setattr("tradingagents.graph.setup.StateGraph", DummyStateGraph) + monkeypatch.setattr("tradingagents.graph.setup.create_msg_delete", lambda: "delete") + + def make_factory(node_name): + def factory(llm, *_args): + recorded_llms[node_name] = llm + return node_name + + return factory + + monkeypatch.setattr( + "tradingagents.graph.setup.create_market_analyst", + make_factory("Market Analyst"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_scenario_catalyst_analyst", + make_factory("Scenario Analyst"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_social_media_analyst", + make_factory("Social Analyst"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_news_analyst", + make_factory("News Analyst"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_fundamentals_analyst", + make_factory("Fundamentals Analyst"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_bull_researcher", + make_factory("Bull Researcher"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_bear_researcher", + make_factory("Bear Researcher"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_research_manager", + make_factory("Research Manager"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_trader", + make_factory("Trader"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_aggressive_debator", + make_factory("Aggressive Analyst"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_neutral_debator", + make_factory("Neutral Analyst"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_conservative_debator", + make_factory("Conservative Analyst"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_portfolio_manager", + make_factory("Portfolio Manager"), + ) + + class PartialConditionalLogic: + def should_continue_market(self, _state): + return "Msg Clear Market" + + def should_continue_scenario(self, _state): + return "Msg Clear Scenario" + + def should_continue_debate(self, _state): + return "Research Manager" + + def should_continue_risk_analysis(self, _state): + return "Portfolio Manager" + + setup = GraphSetup( + quick_thinking_llm="quick-llm", + deep_thinking_llm="deep-llm", + tool_nodes={"market": "market-tools", "scenario": "scenario-tools"}, + bull_memory=object(), + bear_memory=object(), + trader_memory=object(), + invest_judge_memory=object(), + portfolio_manager_memory=object(), + conditional_logic=PartialConditionalLogic(), + role_llms={"scenario": "scenario-llm"}, + ) + + graph = setup.setup_graph(selected_analysts=["market", "scenario"]) + + assert recorded_llms["Scenario Analyst"] == "scenario-llm" + assert graph["nodes"]["Scenario Analyst"] == "Scenario Analyst" + assert graph["nodes"]["tools_scenario"] == "scenario-tools" + assert "Scenario Analyst" in graph["conditional_edges"] + + +def test_trading_graph_creates_scenario_tool_node(monkeypatch): + monkeypatch.setattr("tradingagents.graph.trading_graph.ToolNode", DummyToolNode) + + graph = TradingAgentsGraph.__new__(TradingAgentsGraph) + tool_nodes = TradingAgentsGraph._create_tool_nodes(graph) + + assert [tool.name for tool in tool_nodes["scenario"].tools] == [ + "get_scenario_fundamentals", + "get_scenario_news", + "get_catalyst_calendar", + ] + + +class DummyPrompt: + def __init__(self, result): + self.result = result + + def partial(self, **_kwargs): + return self + + def __or__(self, _other): + return DummyChain(self.result) + + +class DummyChain: + def __init__(self, result): + self.result = result + + def invoke(self, _messages): + return self.result + + +class DummyResult: + def __init__(self, content, tool_calls): + self.content = content + self.tool_calls = tool_calls + + +class DummyLLM: + def __init__(self, result): + self.result = result + self.bound_tool_names = [] + + def bind_tools(self, tools): + self.bound_tool_names = [tool.name for tool in tools] + return object() + + +def test_scenario_catalyst_analyst_returns_structured_data(monkeypatch): + from tradingagents.agents.analysts.scenario_catalyst_analyst import ( + create_scenario_catalyst_analyst, + ) + + result = DummyResult( + content="""## Scenario and Catalyst Summary + +Bull case is an AI-led re-rating with operating leverage, base case is steady +execution with modest multiple expansion, and bear case is demand rollover. + +```json +{ + "scenario_map": [ + { + "name": "bull", + "probability_pct": 30, + "thesis": "AI-driven demand acceleration", + "valuation_implication": "multiple expansion toward upper historical range", + "signposts": ["order lead-times extend", "gross margin beats guidance"] + }, + { + "name": "base", + "probability_pct": 50, + "thesis": "stable demand and disciplined opex", + "valuation_implication": "range-bound multiple with EPS carry", + "signposts": ["in-line guidance", "services growth steady"] + }, + { + "name": "bear", + "probability_pct": 20, + "thesis": "weaker upgrade cycle and pricing pressure", + "valuation_implication": "derating to cycle-low valuation band", + "signposts": ["inventory builds", "discounting rises"] + } + ], + "dated_catalyst_map": [ + { + "catalyst": "FOMC rate decision", + "date_or_window": "2026-05-06", + "related_scenarios": ["bull", "base", "bear"], + "expected_impact": "changes discount-rate pressure on valuation", + "confidence": "medium" + } + ], + "invalidation_triggers": [ + { + "trigger": "two consecutive quarters of revenue miss versus guidance midpoint", + "affected_scenarios": ["bull", "base"], + "severity": "high", + "evidence_to_watch": "quarterly filings and management commentary" + } + ] +} +```""", + tool_calls=[], + ) + llm = DummyLLM(result) + monkeypatch.setattr( + "tradingagents.agents.analysts.scenario_catalyst_analyst.ChatPromptTemplate.from_messages", + lambda _messages: DummyPrompt(result), + ) + + node = create_scenario_catalyst_analyst(llm) + output = node( + { + "trade_date": "2026-03-24", + "company_of_interest": "AAPL", + "messages": ["analyze scenario tree"], + } + ) + + assert llm.bound_tool_names == [ + "get_scenario_fundamentals", + "get_scenario_news", + "get_catalyst_calendar", + ] + assert output["scenario_catalyst_report"] == result.content + assert output["scenario_catalyst_data"] == { + "ticker": "AAPL", + "analysis_date": "2026-03-24", + "scenario_map": [ + { + "name": "bull", + "probability_pct": 30, + "thesis": "AI-driven demand acceleration", + "valuation_implication": "multiple expansion toward upper historical range", + "signposts": [ + "order lead-times extend", + "gross margin beats guidance", + ], + }, + { + "name": "base", + "probability_pct": 50, + "thesis": "stable demand and disciplined opex", + "valuation_implication": "range-bound multiple with EPS carry", + "signposts": ["in-line guidance", "services growth steady"], + }, + { + "name": "bear", + "probability_pct": 20, + "thesis": "weaker upgrade cycle and pricing pressure", + "valuation_implication": "derating to cycle-low valuation band", + "signposts": ["inventory builds", "discounting rises"], + }, + ], + "dated_catalyst_map": [ + { + "catalyst": "FOMC rate decision", + "date_or_window": "2026-05-06", + "related_scenarios": ["bull", "base", "bear"], + "expected_impact": "changes discount-rate pressure on valuation", + "confidence": "medium", + } + ], + "invalidation_triggers": [ + { + "trigger": "two consecutive quarters of revenue miss versus guidance midpoint", + "affected_scenarios": ["bull", "base"], + "severity": "high", + "evidence_to_watch": "quarterly filings and management commentary", + } + ], + } + assert output["messages"] == [result] + + +def test_extract_scenario_payload_tolerates_common_model_json_variants(): + from tradingagents.agents.analysts.scenario_catalyst_analyst import ( + _extract_scenario_catalyst_payload, + ) + + expected = { + "scenario_map": [{"name": "base", "probability_pct": 60}], + "dated_catalyst_map": [{"catalyst": "earnings", "date_or_window": "Q2"}], + "invalidation_triggers": [{"trigger": "demand miss"}], + } + + uppercase_fence = """ +```JSON +{"scenario_map":[{"name":"base","probability_pct":60}],"dated_catalyst_map":[{"catalyst":"earnings","date_or_window":"Q2"}],"invalidation_triggers":[{"trigger":"demand miss"}]} +``` +""" + plain_fence = """ +``` +{"scenario_map":[{"name":"base","probability_pct":60}],"dated_catalyst_map":[{"catalyst":"earnings","date_or_window":"Q2"}],"invalidation_triggers":[{"trigger":"demand miss"}]} +``` +""" + raw_json = """ +Narrative intro before payload. +{"scenario_map":[{"name":"base","probability_pct":60}],"dated_catalyst_map":[{"catalyst":"earnings","date_or_window":"Q2"}],"invalidation_triggers":[{"trigger":"demand miss"}]} +Tail note after payload. +""" + + assert _extract_scenario_catalyst_payload(uppercase_fence) == expected + assert _extract_scenario_catalyst_payload(plain_fence) == expected + assert _extract_scenario_catalyst_payload(raw_json) == expected + + +def test_propagator_initial_state_seeds_scenario_defaults(): + from tradingagents.graph.propagation import Propagator + + state = Propagator().create_initial_state("AAPL", "2026-03-24") + + assert state["scenario_catalyst_report"] == "" + assert state["scenario_catalyst_data"] == { + "ticker": "", + "analysis_date": "", + "scenario_map": [], + "dated_catalyst_map": [], + "invalidation_triggers": [], + } + + +def test_log_state_persists_scenario_catalyst_report_and_data(tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + + graph = TradingAgentsGraph.__new__(TradingAgentsGraph) + graph.ticker = "AAPL" + graph.log_states_dict = {} + + final_state = { + "company_of_interest": "Apple", + "trade_date": "2026-03-24", + "market_report": "market", + "sentiment_report": "sentiment", + "news_report": "news", + "fundamentals_report": "fundamentals", + "segment_report": "segment report", + "segment_data": {"ticker": "AAPL"}, + "macro_report": "macro report", + "scenario_catalyst_report": "scenario report", + "scenario_catalyst_data": { + "ticker": "AAPL", + "analysis_date": "2026-03-24", + "scenario_map": [{"name": "base", "probability_pct": 60}], + "dated_catalyst_map": [{"catalyst": "earnings", "date_or_window": "Q2"}], + "invalidation_triggers": [{"trigger": "demand miss"}], + }, + "investment_debate_state": { + "bull_history": "bull", + "bear_history": "bear", + "history": "debate history", + "current_response": "current", + "judge_decision": "judge", + }, + "trader_investment_plan": "trader plan", + "risk_debate_state": { + "aggressive_history": "agg", + "conservative_history": "cons", + "neutral_history": "neutral", + "history": "risk history", + "judge_decision": "risk judge", + }, + "investment_plan": "investment plan", + "final_trade_decision": "buy", + } + + graph._log_state("2026-03-24", final_state) + + output_path = ( + tmp_path + / "eval_results" + / "AAPL" + / "TradingAgentsStrategy_logs" + / "full_states_log_2026-03-24.json" + ) + payload = json.loads(output_path.read_text()) + logged = payload["2026-03-24"] + + assert logged["scenario_catalyst_report"] == "scenario report" + assert logged["scenario_catalyst_data"] == final_state["scenario_catalyst_data"] diff --git a/tests/test_segment_analyst.py b/tests/test_segment_analyst.py new file mode 100644 index 00000000..81d59c26 --- /dev/null +++ b/tests/test_segment_analyst.py @@ -0,0 +1,410 @@ +from tradingagents.graph.setup import GraphSetup +from tradingagents.graph.trading_graph import TradingAgentsGraph + + +class DummyStateGraph: + def __init__(self, _state_type): + self.nodes = {} + self.conditional_edges = {} + + def add_node(self, name, node): + self.nodes[name] = node + + def add_edge(self, *_args, **_kwargs): + return None + + def add_conditional_edges(self, source, condition, destinations): + self.conditional_edges[source] = { + "condition": condition, + "destinations": destinations, + } + + def compile(self): + return { + "nodes": self.nodes, + "conditional_edges": self.conditional_edges, + } + + +class DummyToolNode: + def __init__(self, tools): + self.tools = tools + + +def test_segment_tools_route_to_vendor(monkeypatch): + import tradingagents.dataflows.interface as interface + from tradingagents.agents.utils.segment_tools import ( + get_segment_fundamentals, + get_segment_income_statement, + get_segment_news, + ) + + calls = [] + + def fake_route_to_vendor(method, *args, **kwargs): + calls.append((method, args, kwargs)) + return f"{method}-result" + + monkeypatch.setattr(interface, "route_to_vendor", fake_route_to_vendor) + + assert ( + get_segment_fundamentals.invoke({"ticker": "AAPL", "curr_date": "2026-03-24"}) + == "get_fundamentals-result" + ) + assert ( + get_segment_income_statement.invoke( + {"ticker": "AAPL", "curr_date": "2026-03-24", "freq": "quarterly"} + ) + == "get_income_statement-result" + ) + assert ( + get_segment_news.invoke( + {"query": "AAPL product segment demand", "start_date": "2026-03-01", "end_date": "2026-03-24"} + ) + == "get_news-result" + ) + assert calls == [ + ( + "get_fundamentals", + (), + {"ticker": "AAPL", "curr_date": "2026-03-24"}, + ), + ( + "get_income_statement", + (), + {"ticker": "AAPL", "freq": "quarterly", "curr_date": "2026-03-24"}, + ), + ( + "get_news", + (), + { + "query": "AAPL product segment demand", + "start_date": "2026-03-01", + "end_date": "2026-03-24", + }, + ), + ] + + +def test_graph_setup_wires_segment_analyst_and_segment_tools(monkeypatch): + recorded_llms = {} + + monkeypatch.setattr("tradingagents.graph.setup.StateGraph", DummyStateGraph) + monkeypatch.setattr("tradingagents.graph.setup.create_msg_delete", lambda: "delete") + + def make_factory(node_name): + def factory(llm, *_args): + recorded_llms[node_name] = llm + return node_name + + return factory + + monkeypatch.setattr( + "tradingagents.graph.setup.create_market_analyst", + make_factory("Market Analyst"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_segment_analyst", + make_factory("Segment Analyst"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_social_media_analyst", + make_factory("Social Analyst"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_news_analyst", + make_factory("News Analyst"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_fundamentals_analyst", + make_factory("Fundamentals Analyst"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_bull_researcher", + make_factory("Bull Researcher"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_bear_researcher", + make_factory("Bear Researcher"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_research_manager", + make_factory("Research Manager"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_trader", + make_factory("Trader"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_aggressive_debator", + make_factory("Aggressive Analyst"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_neutral_debator", + make_factory("Neutral Analyst"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_conservative_debator", + make_factory("Conservative Analyst"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_portfolio_manager", + make_factory("Portfolio Manager"), + ) + + class PartialConditionalLogic: + def should_continue_market(self, _state): + return "Msg Clear Market" + + def should_continue_segment(self, _state): + return "Msg Clear Segment" + + def should_continue_debate(self, _state): + return "Research Manager" + + def should_continue_risk_analysis(self, _state): + return "Portfolio Manager" + + setup = GraphSetup( + quick_thinking_llm="quick-llm", + deep_thinking_llm="deep-llm", + tool_nodes={"market": "market-tools", "segment": "segment-tools"}, + bull_memory=object(), + bear_memory=object(), + trader_memory=object(), + invest_judge_memory=object(), + portfolio_manager_memory=object(), + conditional_logic=PartialConditionalLogic(), + role_llms={"segment": "segment-llm"}, + ) + + graph = setup.setup_graph(selected_analysts=["market", "segment"]) + + assert recorded_llms["Segment Analyst"] == "segment-llm" + assert graph["nodes"]["Segment Analyst"] == "Segment Analyst" + assert graph["nodes"]["tools_segment"] == "segment-tools" + assert "Segment Analyst" in graph["conditional_edges"] + + +def test_trading_graph_creates_segment_tool_node(monkeypatch): + monkeypatch.setattr("tradingagents.graph.trading_graph.ToolNode", DummyToolNode) + + graph = TradingAgentsGraph.__new__(TradingAgentsGraph) + tool_nodes = TradingAgentsGraph._create_tool_nodes(graph) + + assert [tool.name for tool in tool_nodes["segment"].tools] == [ + "get_segment_fundamentals", + "get_segment_income_statement", + "get_segment_news", + ] + + +class DummyPrompt: + def __init__(self, result): + self.result = result + + def partial(self, **_kwargs): + return self + + def __or__(self, _other): + return DummyChain(self.result) + + +class DummyChain: + def __init__(self, result): + self.result = result + + def invoke(self, _messages): + return self.result + + +class DummyResult: + def __init__(self, content, tool_calls): + self.content = content + self.tool_calls = tool_calls + + +class DummyLLM: + def __init__(self, result): + self.result = result + self.bound_tool_names = [] + + def bind_tools(self, tools): + self.bound_tool_names = [tool.name for tool in tools] + return object() + + +def test_segment_analyst_returns_structured_segment_data(monkeypatch): + from tradingagents.agents.analysts.segment_analyst import create_segment_analyst + + result = DummyResult( + content="""## Segment Summary + +Apple's iPhone remains the primary demand engine, Services is the highest-quality +profit pool, and Wearables is a smaller but defensible ecosystem layer. + +```json +{ + "business_unit_decomposition": [ + { + "segment": "iPhone", + "revenue_share_pct": 52, + "growth_trend": "stable", + "strategic_role": "core hardware platform" + }, + { + "segment": "Services", + "revenue_share_pct": 23, + "growth_trend": "expanding", + "strategic_role": "high-margin recurring engine" + } + ], + "segment_economics": { + "margin_profile": { + "iPhone": "mid-margin, scale-driven", + "Services": "high-margin recurring" + }, + "capital_intensity": { + "iPhone": "high", + "Services": "low" + }, + "cyclicality": { + "iPhone": "medium", + "Services": "low" + } + }, + "value_driver_map": [ + { + "driver": "AI-enabled upgrade cycle", + "impacted_segments": ["iPhone"], + "direction": "upside", + "horizon": "6-12m", + "evidence": "on-device feature expansion supports ASP and volume" + }, + { + "driver": "App Store regulatory pressure", + "impacted_segments": ["Services"], + "direction": "downside", + "horizon": "12-24m", + "evidence": "potential take-rate compression in key regions" + } + ] +} +```""", + tool_calls=[], + ) + llm = DummyLLM(result) + monkeypatch.setattr( + "tradingagents.agents.analysts.segment_analyst.ChatPromptTemplate.from_messages", + lambda _messages: DummyPrompt(result), + ) + + node = create_segment_analyst(llm) + output = node( + { + "trade_date": "2026-03-24", + "company_of_interest": "AAPL", + "messages": ["analyze segment exposure"], + } + ) + + assert llm.bound_tool_names == [ + "get_segment_fundamentals", + "get_segment_income_statement", + "get_segment_news", + ] + assert output["segment_report"] == result.content + assert output["segment_data"] == { + "ticker": "AAPL", + "analysis_date": "2026-03-24", + "business_unit_decomposition": [ + { + "segment": "iPhone", + "revenue_share_pct": 52, + "growth_trend": "stable", + "strategic_role": "core hardware platform", + }, + { + "segment": "Services", + "revenue_share_pct": 23, + "growth_trend": "expanding", + "strategic_role": "high-margin recurring engine", + }, + ], + "segment_economics": { + "margin_profile": { + "iPhone": "mid-margin, scale-driven", + "Services": "high-margin recurring", + }, + "capital_intensity": { + "iPhone": "high", + "Services": "low", + }, + "cyclicality": { + "iPhone": "medium", + "Services": "low", + }, + }, + "value_driver_map": [ + { + "driver": "AI-enabled upgrade cycle", + "impacted_segments": ["iPhone"], + "direction": "upside", + "horizon": "6-12m", + "evidence": "on-device feature expansion supports ASP and volume", + }, + { + "driver": "App Store regulatory pressure", + "impacted_segments": ["Services"], + "direction": "downside", + "horizon": "12-24m", + "evidence": "potential take-rate compression in key regions", + }, + ], + } + assert output["messages"] == [result] + + +def test_extract_segment_payload_tolerates_common_model_json_variants(): + from tradingagents.agents.analysts.segment_analyst import _extract_segment_payload + + expected = { + "business_unit_decomposition": [{"segment": "Services"}], + "segment_economics": {"margin_profile": {"Services": "high"}}, + "value_driver_map": [{"driver": "pricing"}], + } + + uppercase_fence = """ +```JSON +{"business_unit_decomposition":[{"segment":"Services"}],"segment_economics":{"margin_profile":{"Services":"high"}},"value_driver_map":[{"driver":"pricing"}]} +``` +""" + plain_fence = """ +``` +{"business_unit_decomposition":[{"segment":"Services"}],"segment_economics":{"margin_profile":{"Services":"high"}},"value_driver_map":[{"driver":"pricing"}]} +``` +""" + raw_json = """ +Narrative intro before payload. +{"business_unit_decomposition":[{"segment":"Services"}],"segment_economics":{"margin_profile":{"Services":"high"}},"value_driver_map":[{"driver":"pricing"}]} +Tail note after payload. +""" + + assert _extract_segment_payload(uppercase_fence) == expected + assert _extract_segment_payload(plain_fence) == expected + assert _extract_segment_payload(raw_json) == expected + + +def test_propagator_initial_state_seeds_segment_defaults(): + from tradingagents.graph.propagation import Propagator + + state = Propagator().create_initial_state("AAPL", "2026-03-24") + + assert state["segment_report"] == "" + assert state["segment_data"] == { + "ticker": "", + "analysis_date": "", + "business_unit_decomposition": [], + "segment_economics": {}, + "value_driver_map": [], + } diff --git a/tests/test_stock_role_wiring.py b/tests/test_stock_role_wiring.py index 17e020c5..4ce213d6 100644 --- a/tests/test_stock_role_wiring.py +++ b/tests/test_stock_role_wiring.py @@ -16,16 +16,18 @@ EXPECTED_VALUATION_DATA = { } EXPECTED_SEGMENT_DATA = { - "segments": [], - "dominant_segment": "", - "thesis": "", + "ticker": "", + "analysis_date": "", + "business_unit_decomposition": [], + "segment_economics": {}, + "value_driver_map": [], } EXPECTED_SCENARIO_CATALYST_DATA = { - "bull_case": {"probability": None, "price_target": None, "thesis": ""}, - "base_case": {"probability": None, "price_target": None, "thesis": ""}, - "bear_case": {"probability": None, "price_target": None, "thesis": ""}, - "catalysts": [], + "ticker": "", + "analysis_date": "", + "scenario_map": [], + "dated_catalyst_map": [], "invalidation_triggers": [], } diff --git a/tradingagents/agents/__init__.py b/tradingagents/agents/__init__.py index 263d1ce7..8a4b68f6 100644 --- a/tradingagents/agents/__init__.py +++ b/tradingagents/agents/__init__.py @@ -7,6 +7,8 @@ from .analysts.factor_rule_analyst import create_factor_rule_analyst from .analysts.macro_analyst import create_macro_analyst from .analysts.market_analyst import create_market_analyst from .analysts.news_analyst import create_news_analyst +from .analysts.scenario_catalyst_analyst import create_scenario_catalyst_analyst +from .analysts.segment_analyst import create_segment_analyst from .analysts.social_media_analyst import create_social_media_analyst from .analysts.valuation_analyst import create_valuation_analyst @@ -37,6 +39,8 @@ __all__ = [ "create_market_analyst", "create_neutral_debator", "create_news_analyst", + "create_scenario_catalyst_analyst", + "create_segment_analyst", "create_valuation_analyst", "create_aggressive_debator", "create_portfolio_manager", diff --git a/tradingagents/agents/analysts/scenario_catalyst_analyst.py b/tradingagents/agents/analysts/scenario_catalyst_analyst.py new file mode 100644 index 00000000..98dd406d --- /dev/null +++ b/tradingagents/agents/analysts/scenario_catalyst_analyst.py @@ -0,0 +1,137 @@ +import json +import re + +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder + +from tradingagents.agents.utils.agent_utils import ( + build_instrument_context, + get_catalyst_calendar, + get_scenario_fundamentals, + get_scenario_news, +) + + +def _extract_scenario_catalyst_payload(report: str) -> dict: + if not report: + return {} + + # Prefer fenced JSON payloads (supports ```json, ```JSON, and unlabeled ```). + for match in re.finditer( + r"```(?:\s*([A-Za-z]+))?\s*(\{.*?\})\s*```", + report, + re.DOTALL, + ): + label = (match.group(1) or "").strip().lower() + if label and label != "json": + continue + try: + payload = json.loads(match.group(2)) + except json.JSONDecodeError: + continue + if isinstance(payload, dict): + return payload + + # Fallback: tolerate raw JSON object embedded in body text. + decoder = json.JSONDecoder() + for brace_match in re.finditer(r"\{", report): + candidate = report[brace_match.start() :].lstrip() + try: + payload, _ = decoder.raw_decode(candidate) + except json.JSONDecodeError: + continue + if isinstance(payload, dict): + return payload + + return {} + + +def _build_scenario_catalyst_data(ticker: str, analysis_date: str, report: str) -> dict: + payload = _extract_scenario_catalyst_payload(report) + scenario_map = payload.get("scenario_map", []) + dated_catalyst_map = payload.get("dated_catalyst_map", []) + invalidation_triggers = payload.get("invalidation_triggers", []) + + if not isinstance(scenario_map, list): + scenario_map = [] + if not isinstance(dated_catalyst_map, list): + dated_catalyst_map = [] + if not isinstance(invalidation_triggers, list): + invalidation_triggers = [] + + return { + "ticker": ticker, + "analysis_date": analysis_date, + "scenario_map": scenario_map, + "dated_catalyst_map": dated_catalyst_map, + "invalidation_triggers": invalidation_triggers, + } + + +def create_scenario_catalyst_analyst(llm): + def scenario_catalyst_analyst_node(state): + current_date = state["trade_date"] + ticker = state["company_of_interest"] + instrument_context = build_instrument_context(ticker) + + tools = [ + get_scenario_fundamentals, + get_scenario_news, + get_catalyst_calendar, + ] + + system_message = ( + "You are a scenario and catalyst analyst focused on bull/base/bear framing and " + "timed event risk for the instrument. Use `get_scenario_fundamentals` to anchor " + "fundamental sensitivity, `get_scenario_news` to identify company-specific drivers, " + "and `get_catalyst_calendar` to map date-based macro/policy events. Deliver a concise " + "Markdown narrative with bull, base, and bear case probabilities, key signposts, and " + "thesis invalidation logic. Your response must contain two parts: " + "(1) a Markdown summary and catalyst table, followed by " + "(2) a fenced JSON block (```json ... ```) with exactly these top-level keys: " + "`scenario_map` (list of objects with `name`, `probability_pct`, `thesis`, " + "`valuation_implication`, `signposts`), `dated_catalyst_map` (list of objects with " + "`catalyst`, `date_or_window`, `related_scenarios`, `expected_impact`, `confidence`), " + "and `invalidation_triggers` (list of objects with `trigger`, `affected_scenarios`, " + "`severity`, `evidence_to_watch`). If data is unavailable, still include all keys " + "using empty lists." + ) + + prompt = ChatPromptTemplate.from_messages( + [ + ( + "system", + "You are a helpful AI assistant, collaborating with other assistants." + " Use the provided tools to progress towards answering the question." + " If you are unable to fully answer, that's OK; another assistant with different tools" + " will help where you left off. Execute what you can to make progress." + " If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable," + " prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop." + " You have access to the following tools: {tool_names}.\n{system_message}" + "For your reference, the current date is {current_date}. {instrument_context}", + ), + MessagesPlaceholder(variable_name="messages"), + ] + ) + + prompt = prompt.partial(system_message=system_message) + prompt = prompt.partial(tool_names=", ".join(tool.name for tool in tools)) + prompt = prompt.partial(current_date=current_date) + prompt = prompt.partial(instrument_context=instrument_context) + + chain = prompt | llm.bind_tools(tools) + result = chain.invoke(state["messages"]) + + tool_calls = getattr(result, "tool_calls", None) or [] + report = result.content if len(tool_calls) == 0 else "" + + return { + "messages": [result], + "scenario_catalyst_report": report, + "scenario_catalyst_data": _build_scenario_catalyst_data( + ticker, + current_date, + report, + ), + } + + return scenario_catalyst_analyst_node diff --git a/tradingagents/agents/analysts/segment_analyst.py b/tradingagents/agents/analysts/segment_analyst.py new file mode 100644 index 00000000..0bd61be2 --- /dev/null +++ b/tradingagents/agents/analysts/segment_analyst.py @@ -0,0 +1,133 @@ +import json +import re + +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder + +from tradingagents.agents.utils.agent_utils import ( + build_instrument_context, + get_segment_fundamentals, + get_segment_income_statement, + get_segment_news, +) + + +def _extract_segment_payload(report: str) -> dict: + if not report: + return {} + + # Prefer fenced JSON payloads (supports ```json, ```JSON, and unlabeled ```). + for match in re.finditer( + r"```(?:\s*([A-Za-z]+))?\s*(\{.*?\})\s*```", + report, + re.DOTALL, + ): + label = (match.group(1) or "").strip().lower() + if label and label != "json": + continue + try: + payload = json.loads(match.group(2)) + except json.JSONDecodeError: + continue + if isinstance(payload, dict): + return payload + + # Fallback: tolerate raw JSON object embedded in body text. + decoder = json.JSONDecoder() + for brace_match in re.finditer(r"\{", report): + candidate = report[brace_match.start() :].lstrip() + try: + payload, _ = decoder.raw_decode(candidate) + except json.JSONDecodeError: + continue + if isinstance(payload, dict): + return payload + + return {} + + +def _build_segment_data(ticker: str, analysis_date: str, report: str) -> dict: + payload = _extract_segment_payload(report) + business_unit_decomposition = payload.get("business_unit_decomposition", []) + segment_economics = payload.get("segment_economics", {}) + value_driver_map = payload.get("value_driver_map", []) + + if not isinstance(business_unit_decomposition, list): + business_unit_decomposition = [] + if not isinstance(segment_economics, dict): + segment_economics = {} + if not isinstance(value_driver_map, list): + value_driver_map = [] + + return { + "ticker": ticker, + "analysis_date": analysis_date, + "business_unit_decomposition": business_unit_decomposition, + "segment_economics": segment_economics, + "value_driver_map": value_driver_map, + } + + +def create_segment_analyst(llm): + def segment_analyst_node(state): + current_date = state["trade_date"] + ticker = state["company_of_interest"] + instrument_context = build_instrument_context(ticker) + + tools = [ + get_segment_fundamentals, + get_segment_income_statement, + get_segment_news, + ] + + system_message = ( + "You are a segment analyst focused on business-mix quality and revenue durability. " + "Use `get_segment_fundamentals` to summarize business lines and strategic positioning, " + "`get_segment_income_statement` to infer segment-level margin direction from reported trends, " + "and `get_segment_news` to identify demand, pricing, and competitive catalysts for key segments. " + "Deliver a concise segment-by-segment view, highlight concentration risks, and append a Markdown " + "table that maps each major segment to growth outlook, margin trend, and trading implication. " + "Your response must contain two parts: " + "(1) a Markdown narrative summary and table, followed by " + "(2) a fenced JSON block (```json ... ```) with exactly these top-level keys: " + "`business_unit_decomposition` (list of objects with `segment`, `revenue_share_pct`, " + "`growth_trend`, `strategic_role`), `segment_economics` (object summarizing margin profile, " + "capital intensity, cyclicality), and `value_driver_map` (list of objects with `driver`, " + "`impacted_segments`, `direction`, `horizon`, `evidence`). " + "If data is unavailable, still include all keys using empty lists/objects." + ) + + prompt = ChatPromptTemplate.from_messages( + [ + ( + "system", + "You are a helpful AI assistant, collaborating with other assistants." + " Use the provided tools to progress towards answering the question." + " If you are unable to fully answer, that's OK; another assistant with different tools" + " will help where you left off. Execute what you can to make progress." + " If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable," + " prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop." + " You have access to the following tools: {tool_names}.\n{system_message}" + "For your reference, the current date is {current_date}. {instrument_context}", + ), + MessagesPlaceholder(variable_name="messages"), + ] + ) + + prompt = prompt.partial(system_message=system_message) + prompt = prompt.partial(tool_names=", ".join(tool.name for tool in tools)) + prompt = prompt.partial(current_date=current_date) + prompt = prompt.partial(instrument_context=instrument_context) + + chain = prompt | llm.bind_tools(tools) + result = chain.invoke(state["messages"]) + + tool_calls = getattr(result, "tool_calls", None) or [] + report = result.content if len(tool_calls) == 0 else "" + + return { + "messages": [result], + "segment_report": report, + "segment_data": _build_segment_data(ticker, current_date, report), + } + + return segment_analyst_node diff --git a/tradingagents/agents/utils/agent_states.py b/tradingagents/agents/utils/agent_states.py index 315e9780..a1a09292 100644 --- a/tradingagents/agents/utils/agent_states.py +++ b/tradingagents/agents/utils/agent_states.py @@ -60,9 +60,11 @@ class ValuationData(TypedDict): class SegmentData(TypedDict): - segments: list[dict[str, Any]] - dominant_segment: str - thesis: str + ticker: str + analysis_date: str + business_unit_decomposition: list[dict[str, Any]] + segment_economics: dict[str, Any] + value_driver_map: list[dict[str, Any]] class ScenarioCaseData(TypedDict): @@ -72,11 +74,11 @@ class ScenarioCaseData(TypedDict): class ScenarioCatalystData(TypedDict): - bull_case: ScenarioCaseData - base_case: ScenarioCaseData - bear_case: ScenarioCaseData - catalysts: list[dict[str, Any]] - invalidation_triggers: list[str] + ticker: str + analysis_date: str + scenario_map: list[dict[str, Any]] + dated_catalyst_map: list[dict[str, Any]] + invalidation_triggers: list[dict[str, Any]] class PositionSizingData(TypedDict): @@ -104,9 +106,11 @@ def make_default_valuation_data() -> ValuationData: def make_default_segment_data() -> SegmentData: return { - "segments": [], - "dominant_segment": "", - "thesis": "", + "ticker": "", + "analysis_date": "", + "business_unit_decomposition": [], + "segment_economics": {}, + "value_driver_map": [], } @@ -120,10 +124,10 @@ def make_default_scenario_case_data() -> ScenarioCaseData: def make_default_scenario_catalyst_data() -> ScenarioCatalystData: return { - "bull_case": make_default_scenario_case_data(), - "base_case": make_default_scenario_case_data(), - "bear_case": make_default_scenario_case_data(), - "catalysts": [], + "ticker": "", + "analysis_date": "", + "scenario_map": [], + "dated_catalyst_map": [], "invalidation_triggers": [], } @@ -173,6 +177,11 @@ class AgentState(MessagesState): factor_rules_report: Annotated[ str, "Summary from the optional factor rule analyst" ] + segment_report: Annotated[str, "Report from the Segment Analyst"] + scenario_catalyst_report: Annotated[ + str, + "Report from the Scenario and Catalyst Analyst", + ] valuation_data: Annotated[ ValuationData, "Structured valuation underwriting output" ] diff --git a/tradingagents/agents/utils/agent_utils.py b/tradingagents/agents/utils/agent_utils.py index b1df53ed..392d6bc9 100644 --- a/tradingagents/agents/utils/agent_utils.py +++ b/tradingagents/agents/utils/agent_utils.py @@ -26,6 +26,16 @@ from tradingagents.agents.utils.macro_data_tools import ( get_fed_calendar, get_yield_curve, ) +from tradingagents.agents.utils.scenario_tools import ( + get_catalyst_calendar, + get_scenario_fundamentals, + get_scenario_news, +) +from tradingagents.agents.utils.segment_tools import ( + get_segment_fundamentals, + get_segment_income_statement, + get_segment_news, +) from tradingagents.agents.utils.valuation_tools import ( get_valuation_inputs, ) @@ -45,6 +55,12 @@ __all__ = [ "get_indicators", "get_insider_transactions", "get_news", + "get_catalyst_calendar", + "get_scenario_fundamentals", + "get_scenario_news", + "get_segment_fundamentals", + "get_segment_income_statement", + "get_segment_news", "get_stock_data", "get_valuation_inputs", "get_yield_curve", diff --git a/tradingagents/agents/utils/scenario_tools.py b/tradingagents/agents/utils/scenario_tools.py new file mode 100644 index 00000000..3557ec05 --- /dev/null +++ b/tradingagents/agents/utils/scenario_tools.py @@ -0,0 +1,41 @@ +from typing import Annotated + +from langchain_core.tools import tool + + +@tool +def get_scenario_fundamentals( + ticker: Annotated[str, "company ticker symbol"], + curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"], +) -> str: + """Retrieve fundamentals context to support scenario probability framing.""" + from tradingagents.dataflows.interface import route_to_vendor + + return route_to_vendor("get_fundamentals", ticker=ticker, curr_date=curr_date) + + +@tool +def get_scenario_news( + query: Annotated[str, "scenario-specific catalyst query"], + start_date: Annotated[str, "start date for search window, YYYY-MM-DD"], + end_date: Annotated[str, "end date for search window, YYYY-MM-DD"], +) -> str: + """Retrieve company-specific news that can update bull/base/bear probabilities.""" + from tradingagents.dataflows.interface import route_to_vendor + + return route_to_vendor( + "get_news", + query=query, + start_date=start_date, + end_date=end_date, + ) + + +@tool +def get_catalyst_calendar( + curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"], +) -> str: + """Retrieve policy-calendar events that can act as dated catalysts.""" + from tradingagents.dataflows.interface import route_to_vendor + + return route_to_vendor("get_fed_calendar", curr_date=curr_date) diff --git a/tradingagents/agents/utils/segment_tools.py b/tradingagents/agents/utils/segment_tools.py new file mode 100644 index 00000000..9082bb08 --- /dev/null +++ b/tradingagents/agents/utils/segment_tools.py @@ -0,0 +1,48 @@ +from typing import Annotated + +from langchain_core.tools import tool + + +@tool +def get_segment_fundamentals( + ticker: Annotated[str, "company ticker symbol"], + curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"], +) -> str: + """Retrieve company fundamentals for segment-level business mix analysis.""" + from tradingagents.dataflows.interface import route_to_vendor + + return route_to_vendor("get_fundamentals", ticker=ticker, curr_date=curr_date) + + +@tool +def get_segment_income_statement( + ticker: Annotated[str, "company ticker symbol"], + curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"], + freq: Annotated[str, "financial statement frequency: quarterly or annual"] = "quarterly", +) -> str: + """Retrieve income statement details that support segment profitability analysis.""" + from tradingagents.dataflows.interface import route_to_vendor + + return route_to_vendor( + "get_income_statement", + ticker=ticker, + freq=freq, + curr_date=curr_date, + ) + + +@tool +def get_segment_news( + query: Annotated[str, "segment-specific search query, including company or product line"], + start_date: Annotated[str, "start date for search window, YYYY-MM-DD"], + end_date: Annotated[str, "end date for search window, YYYY-MM-DD"], +) -> str: + """Retrieve segment-relevant news that can explain demand and pricing trends.""" + from tradingagents.dataflows.interface import route_to_vendor + + return route_to_vendor( + "get_news", + query=query, + start_date=start_date, + end_date=end_date, + ) diff --git a/tradingagents/graph/propagation.py b/tradingagents/graph/propagation.py index 9bb35081..7741e973 100644 --- a/tradingagents/graph/propagation.py +++ b/tradingagents/graph/propagation.py @@ -54,6 +54,8 @@ class Propagator: "news_report": "", "factor_rules_report": "", "macro_report": "", + "segment_report": "", + "scenario_catalyst_report": "", **make_default_structured_stock_underwriting_state(), } diff --git a/tradingagents/graph/setup.py b/tradingagents/graph/setup.py index 15509e7a..e117b4e9 100644 --- a/tradingagents/graph/setup.py +++ b/tradingagents/graph/setup.py @@ -57,6 +57,8 @@ class GraphSetup: self.valuation_analyst_llm = self._get_role_llm( "valuation", self.quick_thinking_llm ) + self.segment_analyst_llm = self._get_role_llm("segment", self.quick_thinking_llm) + self.scenario_analyst_llm = self._get_role_llm("scenario", self.quick_thinking_llm) self.macro_analyst_llm = self._get_role_llm("macro", self.quick_thinking_llm) self.bull_researcher_llm = self._get_role_llm( "bull_researcher", self.quick_thinking_llm @@ -115,6 +117,8 @@ class GraphSetup: - "fundamentals": Fundamentals analyst - "factor_rules": Factor rule analyst - "valuation": Valuation analyst + - "segment": Segment analyst + - "scenario": Scenario and catalyst analyst - "macro": Macro analyst """ if len(selected_analysts) == 0: @@ -167,6 +171,18 @@ class GraphSetup: delete_nodes["valuation"] = create_msg_delete() tool_nodes["valuation"] = self.tool_nodes["valuation"] + if "segment" in selected_analysts: + analyst_nodes["segment"] = create_segment_analyst(self.segment_analyst_llm) + delete_nodes["segment"] = create_msg_delete() + tool_nodes["segment"] = self.tool_nodes["segment"] + + if "scenario" in selected_analysts: + analyst_nodes["scenario"] = create_scenario_catalyst_analyst( + self.scenario_analyst_llm + ) + delete_nodes["scenario"] = create_msg_delete() + tool_nodes["scenario"] = self.tool_nodes["scenario"] + if "macro" in selected_analysts: analyst_nodes["macro"] = create_macro_analyst(self.macro_analyst_llm) delete_nodes["macro"] = create_msg_delete() diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 05cbee23..c5f61e83 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -33,6 +33,12 @@ from tradingagents.agents.utils.agent_utils import ( get_news, get_insider_transactions, get_global_news, + get_catalyst_calendar, + get_scenario_fundamentals, + get_scenario_news, + get_segment_fundamentals, + get_segment_income_statement, + get_segment_news, get_valuation_inputs, get_yield_curve, ) @@ -64,6 +70,8 @@ class TradingAgentsGraph: "fundamentals", "factor_rules", "valuation", + "segment", + "scenario", "macro", "bull_researcher", "bear_researcher", @@ -312,6 +320,22 @@ class TradingAgentsGraph: get_valuation_inputs, ] ), + "segment": ToolNode( + [ + # Segment and business-mix analysis tools + get_segment_fundamentals, + get_segment_income_statement, + get_segment_news, + ] + ), + "scenario": ToolNode( + [ + # Scenario and catalyst mapping tools + get_scenario_fundamentals, + get_scenario_news, + get_catalyst_calendar, + ] + ), "macro": ToolNode( [ # Macroeconomic analysis tools @@ -367,7 +391,11 @@ class TradingAgentsGraph: "news_report": final_state["news_report"], "fundamentals_report": final_state["fundamentals_report"], "factor_rules_report": final_state.get("factor_rules_report", ""), + "segment_report": final_state.get("segment_report", ""), + "segment_data": final_state.get("segment_data", {}), "macro_report": final_state.get("macro_report", ""), + "scenario_catalyst_report": final_state.get("scenario_catalyst_report", ""), + "scenario_catalyst_data": final_state.get("scenario_catalyst_data", {}), "investment_debate_state": { "bull_history": final_state["investment_debate_state"]["bull_history"], "bear_history": final_state["investment_debate_state"]["bear_history"],