refactor: share factor rules clear node name
This commit is contained in:
parent
48ef90741f
commit
a17fc55fb4
|
|
@ -7,6 +7,7 @@ from pathlib import Path
|
||||||
|
|
||||||
MODULE_PATH = Path(__file__).resolve().parents[1] / "tradingagents" / "agents" / "utils" / "factor_rules.py"
|
MODULE_PATH = Path(__file__).resolve().parents[1] / "tradingagents" / "agents" / "utils" / "factor_rules.py"
|
||||||
GRAPH_SETUP_PATH = Path(__file__).resolve().parents[1] / "tradingagents" / "graph" / "setup.py"
|
GRAPH_SETUP_PATH = Path(__file__).resolve().parents[1] / "tradingagents" / "graph" / "setup.py"
|
||||||
|
CONDITIONAL_LOGIC_PATH = Path(__file__).resolve().parents[1] / "tradingagents" / "graph" / "conditional_logic.py"
|
||||||
DEFAULT_CONFIG_PATH = Path(__file__).resolve().parents[1] / "tradingagents" / "default_config.py"
|
DEFAULT_CONFIG_PATH = Path(__file__).resolve().parents[1] / "tradingagents" / "default_config.py"
|
||||||
FACTOR_RULE_ANALYST_PATH = Path(__file__).resolve().parents[1] / "tradingagents" / "agents" / "analysts" / "factor_rule_analyst.py"
|
FACTOR_RULE_ANALYST_PATH = Path(__file__).resolve().parents[1] / "tradingagents" / "agents" / "analysts" / "factor_rule_analyst.py"
|
||||||
SPEC = importlib.util.spec_from_file_location("factor_rules", MODULE_PATH)
|
SPEC = importlib.util.spec_from_file_location("factor_rules", MODULE_PATH)
|
||||||
|
|
@ -399,6 +400,17 @@ class GraphSetupSourceTests(unittest.TestCase):
|
||||||
self.assertIn('selected_analysts = ["market", "social", "news", "fundamentals", "factor_rules"]', source)
|
self.assertIn('selected_analysts = ["market", "social", "news", "fundamentals", "factor_rules"]', source)
|
||||||
|
|
||||||
|
|
||||||
|
class ConditionalLogicSourceTests(unittest.TestCase):
|
||||||
|
def test_factor_rules_clear_node_uses_shared_constant(self):
|
||||||
|
conditional_source = CONDITIONAL_LOGIC_PATH.read_text(encoding="utf-8")
|
||||||
|
setup_source = GRAPH_SETUP_PATH.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
self.assertIn('FACTOR_RULES_CLEAR_NODE = "Msg Clear Factor_rules"', conditional_source)
|
||||||
|
self.assertIn('return FACTOR_RULES_CLEAR_NODE', conditional_source)
|
||||||
|
self.assertIn('from .conditional_logic import ConditionalLogic, FACTOR_RULES_CLEAR_NODE', setup_source)
|
||||||
|
self.assertIn('FACTOR_RULES_CLEAR_NODE', setup_source)
|
||||||
|
|
||||||
|
|
||||||
class DefaultConfigSourceTests(unittest.TestCase):
|
class DefaultConfigSourceTests(unittest.TestCase):
|
||||||
def test_default_headers_is_opt_in_none(self):
|
def test_default_headers_is_opt_in_none(self):
|
||||||
source = DEFAULT_CONFIG_PATH.read_text(encoding="utf-8")
|
source = DEFAULT_CONFIG_PATH.read_text(encoding="utf-8")
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,9 @@
|
||||||
from tradingagents.agents.utils.agent_states import AgentState
|
from tradingagents.agents.utils.agent_states import AgentState
|
||||||
|
|
||||||
|
|
||||||
|
FACTOR_RULES_CLEAR_NODE = "Msg Clear Factor_rules"
|
||||||
|
|
||||||
|
|
||||||
class ConditionalLogic:
|
class ConditionalLogic:
|
||||||
"""Handles conditional logic for determining graph flow."""
|
"""Handles conditional logic for determining graph flow."""
|
||||||
|
|
||||||
|
|
@ -45,7 +48,7 @@ class ConditionalLogic:
|
||||||
|
|
||||||
def should_continue_factor_rules(self, state: AgentState):
|
def should_continue_factor_rules(self, state: AgentState):
|
||||||
"""Factor rule analyst is a pure context node with no tool loop."""
|
"""Factor rule analyst is a pure context node with no tool loop."""
|
||||||
return "Msg Clear Factor_rules"
|
return FACTOR_RULES_CLEAR_NODE
|
||||||
|
|
||||||
def should_continue_debate(self, state: AgentState) -> str:
|
def should_continue_debate(self, state: AgentState) -> str:
|
||||||
"""Determine if debate should continue."""
|
"""Determine if debate should continue."""
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ from langgraph.prebuilt import ToolNode
|
||||||
from tradingagents.agents import *
|
from tradingagents.agents import *
|
||||||
from tradingagents.agents.utils.agent_states import AgentState
|
from tradingagents.agents.utils.agent_states import AgentState
|
||||||
|
|
||||||
from .conditional_logic import ConditionalLogic
|
from .conditional_logic import ConditionalLogic, FACTOR_RULES_CLEAR_NODE
|
||||||
|
|
||||||
|
|
||||||
class GraphSetup:
|
class GraphSetup:
|
||||||
|
|
@ -146,7 +146,11 @@ class GraphSetup:
|
||||||
for i, analyst_type in enumerate(selected_analysts):
|
for i, analyst_type in enumerate(selected_analysts):
|
||||||
current_analyst = f"{analyst_type.capitalize()} Analyst"
|
current_analyst = f"{analyst_type.capitalize()} Analyst"
|
||||||
current_tools = f"tools_{analyst_type}"
|
current_tools = f"tools_{analyst_type}"
|
||||||
current_clear = f"Msg Clear {analyst_type.capitalize()}"
|
current_clear = (
|
||||||
|
FACTOR_RULES_CLEAR_NODE
|
||||||
|
if analyst_type == "factor_rules"
|
||||||
|
else f"Msg Clear {analyst_type.capitalize()}"
|
||||||
|
)
|
||||||
|
|
||||||
# Add conditional edges for current analyst
|
# Add conditional edges for current analyst
|
||||||
continue_fn = getattr(self.conditional_logic, f"should_continue_{analyst_type}")
|
continue_fn = getattr(self.conditional_logic, f"should_continue_{analyst_type}")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue