diff --git a/tests/test_stock_role_wiring.py b/tests/test_stock_role_wiring.py index 4f49598c..17e020c5 100644 --- a/tests/test_stock_role_wiring.py +++ b/tests/test_stock_role_wiring.py @@ -43,6 +43,14 @@ EXPECTED_CHIEF_ANALYST_DATA = { "confidence": "", } +STRUCTURED_PASSTHROUGH_KEYS = { + "valuation_data", + "segment_data", + "scenario_catalyst_data", + "position_sizing_data", + "chief_analyst_data", +} + class DummyMemory: def get_memories(self, _situation, n_matches=2): @@ -70,6 +78,14 @@ def assert_structured_stock_fields(payload): assert payload["chief_analyst_data"] == EXPECTED_CHIEF_ANALYST_DATA +def assert_manager_update_omits_structured_passthrough( + payload, expected_present_keys +): + for key in expected_present_keys: + assert key in payload + assert STRUCTURED_PASSTHROUGH_KEYS.isdisjoint(payload) + + def compile_single_node_graph(node_name, node): workflow = StateGraph(AgentState) workflow.add_node(node_name, node) @@ -101,7 +117,10 @@ def test_research_manager_update_omits_structured_stock_passthrough_fields(monke assert research_result["investment_debate_state"]["judge_decision"] == ( "Research manager output" ) - assert set(research_result) == {"investment_debate_state", "investment_plan"} + assert_manager_update_omits_structured_passthrough( + research_result, + {"investment_debate_state", "investment_plan"}, + ) def test_research_manager_graph_preserves_structured_stock_underwriting_fields( @@ -149,7 +168,10 @@ def test_portfolio_manager_update_omits_structured_stock_passthrough_fields( assert portfolio_result["risk_debate_state"]["judge_decision"] == ( "Portfolio manager output" ) - assert set(portfolio_result) == {"risk_debate_state", "final_trade_decision"} + assert_manager_update_omits_structured_passthrough( + portfolio_result, + {"risk_debate_state", "final_trade_decision"}, + ) def test_portfolio_manager_graph_preserves_structured_stock_underwriting_fields(