fix: make default llm routing shallow-copy safe
This commit is contained in:
parent
73e8974182
commit
27d60e1300
|
|
@ -300,6 +300,18 @@ def test_google_client_passes_base_url_to_langchain(monkeypatch):
|
|||
assert captured_kwargs["base_url"] == "https://google.example/v1beta"
|
||||
|
||||
|
||||
def test_default_config_copy_does_not_share_mutable_llm_routing_state():
|
||||
config = default_config.DEFAULT_CONFIG.copy()
|
||||
|
||||
assert config["llm_routing"] is None
|
||||
config["llm_routing"] = {
|
||||
"default": {},
|
||||
"roles": {"portfolio_manager": {"model": "gpt-5.2"}},
|
||||
}
|
||||
|
||||
assert default_config.DEFAULT_CONFIG["llm_routing"] is None
|
||||
|
||||
|
||||
def test_dataflow_config_returns_isolated_nested_routing(monkeypatch):
|
||||
monkeypatch.setattr(dataflow_config, "_config", None)
|
||||
dataflow_config.initialize_config()
|
||||
|
|
@ -311,7 +323,7 @@ def test_dataflow_config_returns_isolated_nested_routing(monkeypatch):
|
|||
}
|
||||
|
||||
assert dataflow_config.get_config()["llm_routing"]["roles"] == {}
|
||||
assert default_config.DEFAULT_CONFIG["llm_routing"]["roles"] == {}
|
||||
assert default_config.DEFAULT_CONFIG["llm_routing"] is None
|
||||
|
||||
|
||||
def test_log_state_writes_json_snapshot(tmp_path, monkeypatch):
|
||||
|
|
|
|||
|
|
@ -20,20 +20,22 @@ def initialize_config():
|
|||
"""Initialize the configuration with default values."""
|
||||
global _config
|
||||
if _config is None:
|
||||
_config = deepcopy(default_config.DEFAULT_CONFIG)
|
||||
_config = default_config.normalize_llm_routing(default_config.DEFAULT_CONFIG)
|
||||
|
||||
|
||||
def set_config(config: Dict):
|
||||
"""Update the configuration with custom values."""
|
||||
global _config
|
||||
_config = _deep_merge_dicts(default_config.DEFAULT_CONFIG, config)
|
||||
_config = default_config.normalize_llm_routing(
|
||||
_deep_merge_dicts(default_config.DEFAULT_CONFIG, config)
|
||||
)
|
||||
|
||||
|
||||
def get_config() -> Dict:
|
||||
"""Get the current configuration."""
|
||||
if _config is None:
|
||||
initialize_config()
|
||||
return deepcopy(_config)
|
||||
return default_config.normalize_llm_routing(_config)
|
||||
|
||||
|
||||
# Initialize with default config
|
||||
|
|
|
|||
|
|
@ -1,5 +1,24 @@
|
|||
from copy import deepcopy
|
||||
import os
|
||||
|
||||
|
||||
def normalize_llm_routing(config):
|
||||
"""Return config with llm_routing normalized to the expected dict shape."""
|
||||
normalized = deepcopy(config)
|
||||
llm_routing = normalized.get("llm_routing")
|
||||
|
||||
if not isinstance(llm_routing, dict):
|
||||
normalized["llm_routing"] = {"default": {}, "roles": {}}
|
||||
return normalized
|
||||
|
||||
default_route = llm_routing.get("default")
|
||||
role_routes = llm_routing.get("roles")
|
||||
normalized["llm_routing"] = {
|
||||
"default": deepcopy(default_route) if isinstance(default_route, dict) else {},
|
||||
"roles": deepcopy(role_routes) if isinstance(role_routes, dict) else {},
|
||||
}
|
||||
return normalized
|
||||
|
||||
DEFAULT_CONFIG = {
|
||||
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
||||
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
|
||||
|
|
@ -12,10 +31,7 @@ DEFAULT_CONFIG = {
|
|||
"deep_think_llm": "gpt-5.2",
|
||||
"quick_think_llm": "gpt-5-mini",
|
||||
"backend_url": "https://api.openai.com/v1",
|
||||
"llm_routing": {
|
||||
"default": {},
|
||||
"roles": {},
|
||||
},
|
||||
"llm_routing": None,
|
||||
# Provider-specific thinking configuration
|
||||
"google_thinking_level": None, # "high", "minimal", etc.
|
||||
"openai_reasoning_effort": None, # "medium", "high", "low"
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from langgraph.prebuilt import ToolNode
|
|||
from tradingagents.llm_clients import create_llm_client
|
||||
|
||||
from tradingagents.agents import *
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
from tradingagents.default_config import DEFAULT_CONFIG, normalize_llm_routing
|
||||
from tradingagents.agents.utils.memory import FinancialSituationMemory
|
||||
from tradingagents.agents.utils.agent_states import (
|
||||
AgentState,
|
||||
|
|
@ -144,7 +144,7 @@ class TradingAgentsGraph:
|
|||
|
||||
def _build_config(self, config: Optional[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""Merge user config over defaults without mutating the shared defaults."""
|
||||
return self._deep_merge_dicts(DEFAULT_CONFIG, config or {})
|
||||
return normalize_llm_routing(self._deep_merge_dicts(DEFAULT_CONFIG, config or {}))
|
||||
|
||||
def _normalize_provider(self, provider: Optional[str]) -> str:
|
||||
return (provider or "").lower()
|
||||
|
|
|
|||
Loading…
Reference in New Issue