From a17fc55fb49f64cdaad020a085b09d408eacebbf Mon Sep 17 00:00:00 2001 From: 69049ed6x <69049ed6x@users.noreply.github.com> Date: Sun, 8 Mar 2026 14:45:57 +0800 Subject: [PATCH] refactor: share factor rules clear node name --- tests/test_factor_rules.py | 12 ++++++++++++ tradingagents/graph/conditional_logic.py | 5 ++++- tradingagents/graph/setup.py | 8 ++++++-- 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/tests/test_factor_rules.py b/tests/test_factor_rules.py index 46bdc31a..c64190db 100644 --- a/tests/test_factor_rules.py +++ b/tests/test_factor_rules.py @@ -7,6 +7,7 @@ from pathlib import Path MODULE_PATH = Path(__file__).resolve().parents[1] / "tradingagents" / "agents" / "utils" / "factor_rules.py" GRAPH_SETUP_PATH = Path(__file__).resolve().parents[1] / "tradingagents" / "graph" / "setup.py" +CONDITIONAL_LOGIC_PATH = Path(__file__).resolve().parents[1] / "tradingagents" / "graph" / "conditional_logic.py" DEFAULT_CONFIG_PATH = Path(__file__).resolve().parents[1] / "tradingagents" / "default_config.py" FACTOR_RULE_ANALYST_PATH = Path(__file__).resolve().parents[1] / "tradingagents" / "agents" / "analysts" / "factor_rule_analyst.py" SPEC = importlib.util.spec_from_file_location("factor_rules", MODULE_PATH) @@ -399,6 +400,17 @@ class GraphSetupSourceTests(unittest.TestCase): self.assertIn('selected_analysts = ["market", "social", "news", "fundamentals", "factor_rules"]', source) +class ConditionalLogicSourceTests(unittest.TestCase): + def test_factor_rules_clear_node_uses_shared_constant(self): + conditional_source = CONDITIONAL_LOGIC_PATH.read_text(encoding="utf-8") + setup_source = GRAPH_SETUP_PATH.read_text(encoding="utf-8") + + self.assertIn('FACTOR_RULES_CLEAR_NODE = "Msg Clear Factor_rules"', conditional_source) + self.assertIn('return FACTOR_RULES_CLEAR_NODE', conditional_source) + self.assertIn('from .conditional_logic import ConditionalLogic, FACTOR_RULES_CLEAR_NODE', setup_source) + self.assertIn('FACTOR_RULES_CLEAR_NODE', setup_source) + + class DefaultConfigSourceTests(unittest.TestCase): def test_default_headers_is_opt_in_none(self): source = DEFAULT_CONFIG_PATH.read_text(encoding="utf-8") diff --git a/tradingagents/graph/conditional_logic.py b/tradingagents/graph/conditional_logic.py index 5b6c4587..b704fcfc 100644 --- a/tradingagents/graph/conditional_logic.py +++ b/tradingagents/graph/conditional_logic.py @@ -3,6 +3,9 @@ from tradingagents.agents.utils.agent_states import AgentState +FACTOR_RULES_CLEAR_NODE = "Msg Clear Factor_rules" + + class ConditionalLogic: """Handles conditional logic for determining graph flow.""" @@ -45,7 +48,7 @@ class ConditionalLogic: def should_continue_factor_rules(self, state: AgentState): """Factor rule analyst is a pure context node with no tool loop.""" - return "Msg Clear Factor_rules" + return FACTOR_RULES_CLEAR_NODE def should_continue_debate(self, state: AgentState) -> str: """Determine if debate should continue.""" diff --git a/tradingagents/graph/setup.py b/tradingagents/graph/setup.py index dd3e1926..2eee5238 100644 --- a/tradingagents/graph/setup.py +++ b/tradingagents/graph/setup.py @@ -8,7 +8,7 @@ from langgraph.prebuilt import ToolNode from tradingagents.agents import * from tradingagents.agents.utils.agent_states import AgentState -from .conditional_logic import ConditionalLogic +from .conditional_logic import ConditionalLogic, FACTOR_RULES_CLEAR_NODE class GraphSetup: @@ -146,7 +146,11 @@ class GraphSetup: for i, analyst_type in enumerate(selected_analysts): current_analyst = f"{analyst_type.capitalize()} Analyst" current_tools = f"tools_{analyst_type}" - current_clear = f"Msg Clear {analyst_type.capitalize()}" + current_clear = ( + FACTOR_RULES_CLEAR_NODE + if analyst_type == "factor_rules" + else f"Msg Clear {analyst_type.capitalize()}" + ) # Add conditional edges for current analyst continue_fn = getattr(self.conditional_logic, f"should_continue_{analyst_type}")