fix: avoid mutable analyst defaults
This commit is contained in:
parent
80f03f2a13
commit
9f62a305b1
|
|
@ -1,3 +1,4 @@
|
|||
import ast
|
||||
import importlib.util
|
||||
import json
|
||||
import tempfile
|
||||
|
|
@ -5,6 +6,7 @@ import unittest
|
|||
from pathlib import Path
|
||||
|
||||
MODULE_PATH = Path(__file__).resolve().parents[1] / "tradingagents" / "agents" / "utils" / "factor_rules.py"
|
||||
GRAPH_SETUP_PATH = Path(__file__).resolve().parents[1] / "tradingagents" / "graph" / "setup.py"
|
||||
SPEC = importlib.util.spec_from_file_location("factor_rules", MODULE_PATH)
|
||||
factor_rules = importlib.util.module_from_spec(SPEC)
|
||||
SPEC.loader.exec_module(factor_rules)
|
||||
|
|
@ -375,5 +377,25 @@ class FactorRulesPathTests(unittest.TestCase):
|
|||
self.assertIn("- Conditions: ", summary)
|
||||
|
||||
|
||||
class GraphSetupSourceTests(unittest.TestCase):
|
||||
def test_setup_graph_avoids_mutable_default_selected_analysts(self):
|
||||
source = GRAPH_SETUP_PATH.read_text(encoding="utf-8")
|
||||
module = ast.parse(source)
|
||||
|
||||
setup_graph = None
|
||||
for node in module.body:
|
||||
if isinstance(node, ast.ClassDef) and node.name == "GraphSetup":
|
||||
for item in node.body:
|
||||
if isinstance(item, ast.FunctionDef) and item.name == "setup_graph":
|
||||
setup_graph = item
|
||||
break
|
||||
|
||||
self.assertIsNotNone(setup_graph)
|
||||
self.assertEqual(len(setup_graph.args.defaults), 1)
|
||||
self.assertIsInstance(setup_graph.args.defaults[0], ast.Constant)
|
||||
self.assertIsNone(setup_graph.args.defaults[0].value)
|
||||
self.assertIn('selected_analysts = ["market", "social", "news", "fundamentals", "factor_rules"]', source)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
# TradingAgents/graph/setup.py
|
||||
|
||||
from typing import Dict, Any
|
||||
from typing import Dict, Any, Optional, List
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langgraph.graph import END, StateGraph, START
|
||||
from langgraph.prebuilt import ToolNode
|
||||
|
|
@ -38,7 +38,7 @@ class GraphSetup:
|
|||
self.conditional_logic = conditional_logic
|
||||
|
||||
def setup_graph(
|
||||
self, selected_analysts=["market", "social", "news", "fundamentals", "factor_rules"]
|
||||
self, selected_analysts: Optional[List[str]] = None
|
||||
):
|
||||
"""Set up and compile the agent workflow graph.
|
||||
|
||||
|
|
@ -50,6 +50,9 @@ class GraphSetup:
|
|||
- "fundamentals": Fundamentals analyst
|
||||
- "factor_rules": Factor rule analyst
|
||||
"""
|
||||
if selected_analysts is None:
|
||||
selected_analysts = ["market", "social", "news", "fundamentals", "factor_rules"]
|
||||
|
||||
if len(selected_analysts) == 0:
|
||||
raise ValueError("Trading Agents Graph Setup Error: no analysts selected!")
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue