fix: avoid mutable analyst defaults
This commit is contained in:
parent
80f03f2a13
commit
9f62a305b1
|
|
@ -1,3 +1,4 @@
|
||||||
|
import ast
|
||||||
import importlib.util
|
import importlib.util
|
||||||
import json
|
import json
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
@ -5,6 +6,7 @@ import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
MODULE_PATH = Path(__file__).resolve().parents[1] / "tradingagents" / "agents" / "utils" / "factor_rules.py"
|
MODULE_PATH = Path(__file__).resolve().parents[1] / "tradingagents" / "agents" / "utils" / "factor_rules.py"
|
||||||
|
GRAPH_SETUP_PATH = Path(__file__).resolve().parents[1] / "tradingagents" / "graph" / "setup.py"
|
||||||
SPEC = importlib.util.spec_from_file_location("factor_rules", MODULE_PATH)
|
SPEC = importlib.util.spec_from_file_location("factor_rules", MODULE_PATH)
|
||||||
factor_rules = importlib.util.module_from_spec(SPEC)
|
factor_rules = importlib.util.module_from_spec(SPEC)
|
||||||
SPEC.loader.exec_module(factor_rules)
|
SPEC.loader.exec_module(factor_rules)
|
||||||
|
|
@ -375,5 +377,25 @@ class FactorRulesPathTests(unittest.TestCase):
|
||||||
self.assertIn("- Conditions: ", summary)
|
self.assertIn("- Conditions: ", summary)
|
||||||
|
|
||||||
|
|
||||||
|
class GraphSetupSourceTests(unittest.TestCase):
|
||||||
|
def test_setup_graph_avoids_mutable_default_selected_analysts(self):
|
||||||
|
source = GRAPH_SETUP_PATH.read_text(encoding="utf-8")
|
||||||
|
module = ast.parse(source)
|
||||||
|
|
||||||
|
setup_graph = None
|
||||||
|
for node in module.body:
|
||||||
|
if isinstance(node, ast.ClassDef) and node.name == "GraphSetup":
|
||||||
|
for item in node.body:
|
||||||
|
if isinstance(item, ast.FunctionDef) and item.name == "setup_graph":
|
||||||
|
setup_graph = item
|
||||||
|
break
|
||||||
|
|
||||||
|
self.assertIsNotNone(setup_graph)
|
||||||
|
self.assertEqual(len(setup_graph.args.defaults), 1)
|
||||||
|
self.assertIsInstance(setup_graph.args.defaults[0], ast.Constant)
|
||||||
|
self.assertIsNone(setup_graph.args.defaults[0].value)
|
||||||
|
self.assertIn('selected_analysts = ["market", "social", "news", "fundamentals", "factor_rules"]', source)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
# TradingAgents/graph/setup.py
|
# TradingAgents/graph/setup.py
|
||||||
|
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any, Optional, List
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
from langgraph.graph import END, StateGraph, START
|
from langgraph.graph import END, StateGraph, START
|
||||||
from langgraph.prebuilt import ToolNode
|
from langgraph.prebuilt import ToolNode
|
||||||
|
|
@ -38,7 +38,7 @@ class GraphSetup:
|
||||||
self.conditional_logic = conditional_logic
|
self.conditional_logic = conditional_logic
|
||||||
|
|
||||||
def setup_graph(
|
def setup_graph(
|
||||||
self, selected_analysts=["market", "social", "news", "fundamentals", "factor_rules"]
|
self, selected_analysts: Optional[List[str]] = None
|
||||||
):
|
):
|
||||||
"""Set up and compile the agent workflow graph.
|
"""Set up and compile the agent workflow graph.
|
||||||
|
|
||||||
|
|
@ -50,6 +50,9 @@ class GraphSetup:
|
||||||
- "fundamentals": Fundamentals analyst
|
- "fundamentals": Fundamentals analyst
|
||||||
- "factor_rules": Factor rule analyst
|
- "factor_rules": Factor rule analyst
|
||||||
"""
|
"""
|
||||||
|
if selected_analysts is None:
|
||||||
|
selected_analysts = ["market", "social", "news", "fundamentals", "factor_rules"]
|
||||||
|
|
||||||
if len(selected_analysts) == 0:
|
if len(selected_analysts) == 0:
|
||||||
raise ValueError("Trading Agents Graph Setup Error: no analysts selected!")
|
raise ValueError("Trading Agents Graph Setup Error: no analysts selected!")
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue