From fbefcc62ea98d69f6a1d18d639247757605bf4d4 Mon Sep 17 00:00:00 2001 From: Garrick Date: Tue, 24 Mar 2026 16:39:20 -0700 Subject: [PATCH] refactor: rely on graph merge for stock state --- tests/test_stock_role_wiring.py | 85 +++++++++++++++++-- .../agents/managers/portfolio_manager.py | 24 ------ .../agents/managers/research_manager.py | 24 ------ 3 files changed, 76 insertions(+), 57 deletions(-) diff --git a/tests/test_stock_role_wiring.py b/tests/test_stock_role_wiring.py index 777e26df..4f49598c 100644 --- a/tests/test_stock_role_wiring.py +++ b/tests/test_stock_role_wiring.py @@ -1,5 +1,8 @@ from copy import deepcopy +from langgraph.graph import END, START, StateGraph + +from tradingagents.agents.utils.agent_states import AgentState from tradingagents.agents.managers.portfolio_manager import create_portfolio_manager from tradingagents.agents.managers.research_manager import create_research_manager from tradingagents.graph.propagation import Propagator @@ -67,25 +70,27 @@ def assert_structured_stock_fields(payload): assert payload["chief_analyst_data"] == EXPECTED_CHIEF_ANALYST_DATA +def compile_single_node_graph(node_name, node): + workflow = StateGraph(AgentState) + workflow.add_node(node_name, node) + workflow.add_edge(START, node_name) + workflow.add_edge(node_name, END) + return workflow.compile() + + def test_propagator_initializes_structured_stock_underwriting_fields(): initial_state = Propagator().create_initial_state("NVDA", "2026-03-24") assert_structured_stock_fields(initial_state) -def test_manager_nodes_preserve_structured_stock_underwriting_fields(monkeypatch): +def test_research_manager_update_omits_structured_stock_passthrough_fields(monkeypatch): monkeypatch.setattr( "tradingagents.agents.managers.research_manager.build_instrument_context", lambda _ticker: "instrument context", ) - monkeypatch.setattr( - "tradingagents.agents.managers.portfolio_manager.build_instrument_context", - lambda _ticker: "instrument context", - ) state = Propagator().create_initial_state("NVDA", "2026-03-24") - state["investment_plan"] = "Existing investment plan" - research_manager = create_research_manager( DummyLLM("Research manager output"), DummyMemory(), @@ -96,8 +101,44 @@ def test_manager_nodes_preserve_structured_stock_underwriting_fields(monkeypatch assert research_result["investment_debate_state"]["judge_decision"] == ( "Research manager output" ) - assert_structured_stock_fields(research_result) + assert set(research_result) == {"investment_debate_state", "investment_plan"} + +def test_research_manager_graph_preserves_structured_stock_underwriting_fields( + monkeypatch, +): + monkeypatch.setattr( + "tradingagents.agents.managers.research_manager.build_instrument_context", + lambda _ticker: "instrument context", + ) + + research_manager = create_research_manager( + DummyLLM("Research manager output"), + DummyMemory(), + ) + state = Propagator().create_initial_state("NVDA", "2026-03-24") + + final_state = compile_single_node_graph("Research Manager", research_manager).invoke( + state + ) + + assert final_state["investment_plan"] == "Research manager output" + assert final_state["investment_debate_state"]["judge_decision"] == ( + "Research manager output" + ) + assert_structured_stock_fields(final_state) + + +def test_portfolio_manager_update_omits_structured_stock_passthrough_fields( + monkeypatch, +): + monkeypatch.setattr( + "tradingagents.agents.managers.portfolio_manager.build_instrument_context", + lambda _ticker: "instrument context", + ) + + state = Propagator().create_initial_state("NVDA", "2026-03-24") + state["investment_plan"] = "Existing investment plan" portfolio_manager = create_portfolio_manager( DummyLLM("Portfolio manager output"), DummyMemory(), @@ -108,4 +149,30 @@ def test_manager_nodes_preserve_structured_stock_underwriting_fields(monkeypatch assert portfolio_result["risk_debate_state"]["judge_decision"] == ( "Portfolio manager output" ) - assert_structured_stock_fields(portfolio_result) + assert set(portfolio_result) == {"risk_debate_state", "final_trade_decision"} + + +def test_portfolio_manager_graph_preserves_structured_stock_underwriting_fields( + monkeypatch, +): + monkeypatch.setattr( + "tradingagents.agents.managers.portfolio_manager.build_instrument_context", + lambda _ticker: "instrument context", + ) + + portfolio_manager = create_portfolio_manager( + DummyLLM("Portfolio manager output"), + DummyMemory(), + ) + state = Propagator().create_initial_state("NVDA", "2026-03-24") + state["investment_plan"] = "Existing investment plan" + + final_state = compile_single_node_graph( + "Portfolio Manager", portfolio_manager + ).invoke(state) + + assert final_state["final_trade_decision"] == "Portfolio manager output" + assert final_state["risk_debate_state"]["judge_decision"] == ( + "Portfolio manager output" + ) + assert_structured_stock_fields(final_state) diff --git a/tradingagents/agents/managers/portfolio_manager.py b/tradingagents/agents/managers/portfolio_manager.py index 69da4d56..acdf940b 100644 --- a/tradingagents/agents/managers/portfolio_manager.py +++ b/tradingagents/agents/managers/portfolio_manager.py @@ -1,12 +1,8 @@ from tradingagents.agents.utils.agent_utils import build_instrument_context -from tradingagents.agents.utils.agent_states import ( - make_default_structured_stock_underwriting_state, -) def create_portfolio_manager(llm, memory): def portfolio_manager_node(state) -> dict: - structured_stock_defaults = make_default_structured_stock_underwriting_state() instrument_context = build_instrument_context(state["company_of_interest"]) @@ -74,26 +70,6 @@ Be decisive and ground every conclusion in specific evidence from the analysts." return { "risk_debate_state": new_risk_debate_state, "final_trade_decision": response.content, - "valuation_data": state.get( - "valuation_data", - structured_stock_defaults["valuation_data"], - ), - "segment_data": state.get( - "segment_data", - structured_stock_defaults["segment_data"], - ), - "scenario_catalyst_data": state.get( - "scenario_catalyst_data", - structured_stock_defaults["scenario_catalyst_data"], - ), - "position_sizing_data": state.get( - "position_sizing_data", - structured_stock_defaults["position_sizing_data"], - ), - "chief_analyst_data": state.get( - "chief_analyst_data", - structured_stock_defaults["chief_analyst_data"], - ), } return portfolio_manager_node diff --git a/tradingagents/agents/managers/research_manager.py b/tradingagents/agents/managers/research_manager.py index 884ce196..2251bd50 100644 --- a/tradingagents/agents/managers/research_manager.py +++ b/tradingagents/agents/managers/research_manager.py @@ -1,12 +1,8 @@ from tradingagents.agents.utils.agent_utils import build_instrument_context -from tradingagents.agents.utils.agent_states import ( - make_default_structured_stock_underwriting_state, -) def create_research_manager(llm, memory): def research_manager_node(state) -> dict: - structured_stock_defaults = make_default_structured_stock_underwriting_state() instrument_context = build_instrument_context(state["company_of_interest"]) history = state["investment_debate_state"].get("history", "") market_research_report = state["market_report"] @@ -56,26 +52,6 @@ Debate History: return { "investment_debate_state": new_investment_debate_state, "investment_plan": response.content, - "valuation_data": state.get( - "valuation_data", - structured_stock_defaults["valuation_data"], - ), - "segment_data": state.get( - "segment_data", - structured_stock_defaults["segment_data"], - ), - "scenario_catalyst_data": state.get( - "scenario_catalyst_data", - structured_stock_defaults["scenario_catalyst_data"], - ), - "position_sizing_data": state.get( - "position_sizing_data", - structured_stock_defaults["position_sizing_data"], - ), - "chief_analyst_data": state.get( - "chief_analyst_data", - structured_stock_defaults["chief_analyst_data"], - ), } return research_manager_node