refactor: share factor rules clear node name
This commit is contained in:
parent
48ef90741f
commit
a17fc55fb4
|
|
@ -7,6 +7,7 @@ from pathlib import Path
|
|||
|
||||
MODULE_PATH = Path(__file__).resolve().parents[1] / "tradingagents" / "agents" / "utils" / "factor_rules.py"
|
||||
GRAPH_SETUP_PATH = Path(__file__).resolve().parents[1] / "tradingagents" / "graph" / "setup.py"
|
||||
CONDITIONAL_LOGIC_PATH = Path(__file__).resolve().parents[1] / "tradingagents" / "graph" / "conditional_logic.py"
|
||||
DEFAULT_CONFIG_PATH = Path(__file__).resolve().parents[1] / "tradingagents" / "default_config.py"
|
||||
FACTOR_RULE_ANALYST_PATH = Path(__file__).resolve().parents[1] / "tradingagents" / "agents" / "analysts" / "factor_rule_analyst.py"
|
||||
SPEC = importlib.util.spec_from_file_location("factor_rules", MODULE_PATH)
|
||||
|
|
@ -399,6 +400,17 @@ class GraphSetupSourceTests(unittest.TestCase):
|
|||
self.assertIn('selected_analysts = ["market", "social", "news", "fundamentals", "factor_rules"]', source)
|
||||
|
||||
|
||||
class ConditionalLogicSourceTests(unittest.TestCase):
|
||||
def test_factor_rules_clear_node_uses_shared_constant(self):
|
||||
conditional_source = CONDITIONAL_LOGIC_PATH.read_text(encoding="utf-8")
|
||||
setup_source = GRAPH_SETUP_PATH.read_text(encoding="utf-8")
|
||||
|
||||
self.assertIn('FACTOR_RULES_CLEAR_NODE = "Msg Clear Factor_rules"', conditional_source)
|
||||
self.assertIn('return FACTOR_RULES_CLEAR_NODE', conditional_source)
|
||||
self.assertIn('from .conditional_logic import ConditionalLogic, FACTOR_RULES_CLEAR_NODE', setup_source)
|
||||
self.assertIn('FACTOR_RULES_CLEAR_NODE', setup_source)
|
||||
|
||||
|
||||
class DefaultConfigSourceTests(unittest.TestCase):
|
||||
def test_default_headers_is_opt_in_none(self):
|
||||
source = DEFAULT_CONFIG_PATH.read_text(encoding="utf-8")
|
||||
|
|
|
|||
|
|
@ -3,6 +3,9 @@
|
|||
from tradingagents.agents.utils.agent_states import AgentState
|
||||
|
||||
|
||||
FACTOR_RULES_CLEAR_NODE = "Msg Clear Factor_rules"
|
||||
|
||||
|
||||
class ConditionalLogic:
|
||||
"""Handles conditional logic for determining graph flow."""
|
||||
|
||||
|
|
@ -45,7 +48,7 @@ class ConditionalLogic:
|
|||
|
||||
def should_continue_factor_rules(self, state: AgentState):
|
||||
"""Factor rule analyst is a pure context node with no tool loop."""
|
||||
return "Msg Clear Factor_rules"
|
||||
return FACTOR_RULES_CLEAR_NODE
|
||||
|
||||
def should_continue_debate(self, state: AgentState) -> str:
|
||||
"""Determine if debate should continue."""
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from langgraph.prebuilt import ToolNode
|
|||
from tradingagents.agents import *
|
||||
from tradingagents.agents.utils.agent_states import AgentState
|
||||
|
||||
from .conditional_logic import ConditionalLogic
|
||||
from .conditional_logic import ConditionalLogic, FACTOR_RULES_CLEAR_NODE
|
||||
|
||||
|
||||
class GraphSetup:
|
||||
|
|
@ -146,7 +146,11 @@ class GraphSetup:
|
|||
for i, analyst_type in enumerate(selected_analysts):
|
||||
current_analyst = f"{analyst_type.capitalize()} Analyst"
|
||||
current_tools = f"tools_{analyst_type}"
|
||||
current_clear = f"Msg Clear {analyst_type.capitalize()}"
|
||||
current_clear = (
|
||||
FACTOR_RULES_CLEAR_NODE
|
||||
if analyst_type == "factor_rules"
|
||||
else f"Msg Clear {analyst_type.capitalize()}"
|
||||
)
|
||||
|
||||
# Add conditional edges for current analyst
|
||||
continue_fn = getattr(self.conditional_logic, f"should_continue_{analyst_type}")
|
||||
|
|
|
|||
Loading…
Reference in New Issue