TradingAgents/orchestrator/contracts/config_schema.py

175 lines
5.8 KiB
Python

from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Mapping, Optional, TypedDict, cast
from tradingagents.default_config import get_default_config
# Version tag for this config contract; used as the default value of
# OrchestratorConfigSchema.contract_version.
CONTRACT_VERSION = "v1alpha1"
class TradingAgentsConfigPayload(TypedDict, total=False):
    """Static type for the TradingAgents configuration mapping.

    ``total=False`` makes every key optional as far as type checkers are
    concerned.  The runtime guarantees are narrower and are enforced by
    ``build_trading_agents_config``: each key listed in
    ``REQUIRED_TRADING_CONFIG_KEYS`` must be a non-empty string,
    ``data_vendors`` / ``tool_vendors`` must be string-to-string mappings, and
    a ``selected_analysts`` override must be a list of strings.  All other
    keys are passed through by that function without validation.
    """

    # Keys that are also listed in REQUIRED_TRADING_CONFIG_KEYS and therefore
    # must be non-empty strings once the config has been built.
    project_dir: str
    results_dir: str
    data_cache_dir: str
    llm_provider: str
    deep_think_llm: str
    quick_think_llm: str

    # NOTE(review): presumably the base URL of the LLM backend - confirm
    # against the consumer of this payload.
    backend_url: str

    # Provider-specific settings; typed Optional, so None is an accepted value.
    google_thinking_level: Optional[str]
    openai_reasoning_effort: Optional[str]
    anthropic_effort: Optional[str]

    # Free-form string settings; their meaning is defined by the consumer and
    # is not established anywhere in this module.
    output_language: str
    portfolio_context: str
    peer_context: str
    peer_context_mode: str

    # Integer limits; no range check is applied to them in this module.
    max_debate_rounds: int
    max_risk_discuss_rounds: int
    max_recur_limit: int

    # NOTE(review): the "_secs" suffix suggests seconds - confirm.
    analyst_node_timeout_secs: float

    # Validated as str -> str mappings by build_trading_agents_config.
    data_vendors: dict[str, str]
    tool_vendors: dict[str, str]

    # An override for this key must be a list of str
    # (checked in build_trading_agents_config).
    selected_analysts: list[str]

    # Timeout / retry tuning; units and valid ranges are not established in
    # this module.
    llm_timeout: float
    llm_max_retries: int
    minimax_retry_attempts: int
    minimax_retry_base_delay: float
    timeout: float
    max_retries: int

    use_responses_api: bool
# Keys that build_trading_agents_config requires to be present as non-empty
# (non-whitespace) strings after overrides have been merged onto the defaults.
REQUIRED_TRADING_CONFIG_KEYS = (
    "project_dir",
    "results_dir",
    "data_cache_dir",
    "llm_provider",
    "deep_think_llm",
    "quick_think_llm",
)
def _validate_probability(name: str, value: Any) -> float:
if not isinstance(value, (int, float)):
raise TypeError(f"{name} must be a number")
if not 0.0 <= float(value) <= 1.0:
raise ValueError(f"{name} must be between 0.0 and 1.0")
return float(value)
def _validate_positive_int(name: str, value: Any) -> int:
if not isinstance(value, int):
raise TypeError(f"{name} must be an int")
if value <= 0:
raise ValueError(f"{name} must be > 0")
return value
def _validate_string_map(name: str, value: Any) -> dict[str, str]:
if not isinstance(value, Mapping):
raise TypeError(f"{name} must be a mapping")
normalized = {}
for key, item in value.items():
if not isinstance(key, str) or not isinstance(item, str):
raise TypeError(f"{name} keys and values must be strings")
normalized[key] = item
return normalized
def build_trading_agents_config(
    overrides: Optional[Mapping[str, Any]],
) -> TradingAgentsConfigPayload:
    """Merge *overrides* onto the default config and validate the result.

    The starting point is whatever ``get_default_config()`` returns.  Each
    override replaces the default for its key outright; in particular an
    overriding ``data_vendors`` / ``tool_vendors`` mapping replaces the default
    mapping rather than being merged into it.

    Args:
        overrides: Mapping of config keys to override, or ``None`` for no
            overrides.  A ``None`` value for a key is stored as-is; for
            ``data_vendors`` / ``tool_vendors`` that is then rejected by the
            final mapping validation.

    Returns:
        The merged config dict, typed as ``TradingAgentsConfigPayload``.

    Raises:
        TypeError: If *overrides* is neither ``None`` nor a ``Mapping``, if
            ``selected_analysts`` is overridden with anything other than a
            list of strings, or if ``data_vendors`` / ``tool_vendors`` is not
            a str-to-str mapping.
        ValueError: If any key in ``REQUIRED_TRADING_CONFIG_KEYS`` is missing
            from the merged config or is not a non-empty string.
        KeyError: If the merged config has no ``data_vendors`` or
            ``tool_vendors`` entry at all.
    """
    merged: dict[str, Any] = get_default_config()

    # Test against None, not truthiness: a falsy non-mapping such as "" or []
    # must be rejected rather than silently treated as "no overrides".
    if overrides is not None:
        if not isinstance(overrides, Mapping):
            raise TypeError("trading_agents_config must be a mapping")
        for key, value in overrides.items():
            if value is None:
                # None overrides bypass the per-key checks and are stored
                # verbatim; the post-merge validation below still applies.
                merged[key] = value
            elif key in ("data_vendors", "tool_vendors"):
                merged[key] = _validate_string_map(key, value)
            elif key == "selected_analysts":
                if not isinstance(value, list) or any(
                    not isinstance(item, str) for item in value
                ):
                    raise TypeError("selected_analysts must be a list of strings")
                # Copy so later mutation of the caller's list cannot leak in.
                merged[key] = list(value)
            else:
                merged[key] = value

    for key in REQUIRED_TRADING_CONFIG_KEYS:
        value = merged.get(key)
        if not isinstance(value, str) or not value.strip():
            raise ValueError(f"trading_agents_config.{key} must be a non-empty string")

    # Re-validate unconditionally so that default values (and None overrides)
    # go through the same check as explicit overrides.
    merged["data_vendors"] = _validate_string_map("data_vendors", merged["data_vendors"])
    merged["tool_vendors"] = _validate_string_map("tool_vendors", merged["tool_vendors"])
    return cast(TradingAgentsConfigPayload, merged)
@dataclass(frozen=True)
class OrchestratorConfigSchema:
    """Frozen container for orchestrator settings.

    ``build_orchestrator_schema`` creates instances from a raw mapping after
    validating each field; constructing the class directly performs no
    validation of its own.  When ``trading_agents_config`` is not supplied it
    is produced by ``build_trading_agents_config(None)``.
    """

    quant_backtest_path: str = ""
    trading_agents_config: TradingAgentsConfigPayload = field(
        default_factory=lambda: build_trading_agents_config(None)
    )
    # The four cap/penalty fields below are validated as values in [0.0, 1.0]
    # when the instance comes from build_orchestrator_schema.
    quant_weight_cap: float = 0.8
    llm_weight_cap: float = 0.9
    llm_batch_days: int = 7
    cache_dir: str = "orchestrator/cache"
    llm_solo_penalty: float = 0.7
    quant_solo_penalty: float = 0.8
    contract_version: str = CONTRACT_VERSION

    def to_runtime_fields(self) -> dict[str, Any]:
        """Return the settings as a plain dict, without ``contract_version``.

        ``trading_agents_config`` is shallow-copied into a new ``dict``; every
        other value is returned as stored on the instance.
        """
        runtime_names = (
            "quant_backtest_path",
            "trading_agents_config",
            "quant_weight_cap",
            "llm_weight_cap",
            "llm_batch_days",
            "cache_dir",
            "llm_solo_penalty",
            "quant_solo_penalty",
        )
        exported: dict[str, Any] = {
            attr: getattr(self, attr) for attr in runtime_names
        }
        # Replace the stored payload with a shallow copy (key order is kept).
        exported["trading_agents_config"] = dict(self.trading_agents_config)
        return exported
def build_orchestrator_schema(raw: Mapping[str, Any]) -> OrchestratorConfigSchema:
    """Validate a raw orchestrator config mapping and build the schema object.

    Keys missing from *raw* fall back to the literal defaults written in this
    function.  Keys not read here are ignored.

    Args:
        raw: Mapping holding the orchestrator settings.

    Returns:
        An ``OrchestratorConfigSchema`` populated with the validated values.

    Raises:
        TypeError: If *raw* is not a ``Mapping``, if ``quant_backtest_path``
            is not a string, or if a nested validator rejects a value's type.
        ValueError: If ``cache_dir`` is not a non-empty string, or if a nested
            validator rejects a value's range.
    """
    if not isinstance(raw, Mapping):
        raise TypeError("orchestrator config must be a mapping")

    backtest_path = raw.get("quant_backtest_path", "")
    if not isinstance(backtest_path, str):
        raise TypeError("quant_backtest_path must be a string")

    cache_location = raw.get("cache_dir", "orchestrator/cache")
    cache_ok = isinstance(cache_location, str) and bool(cache_location.strip())
    if not cache_ok:
        raise ValueError("cache_dir must be a non-empty string")

    # The remaining values are validated in the same order as the schema's
    # fields, so the first invalid one determines which error is raised.
    agent_overrides = cast(
        Optional[Mapping[str, Any]], raw.get("trading_agents_config")
    )
    agents_config = build_trading_agents_config(agent_overrides)
    quant_cap = _validate_probability(
        "quant_weight_cap", raw.get("quant_weight_cap", 0.8)
    )
    llm_cap = _validate_probability(
        "llm_weight_cap", raw.get("llm_weight_cap", 0.9)
    )
    batch_days = _validate_positive_int(
        "llm_batch_days", raw.get("llm_batch_days", 7)
    )
    llm_penalty = _validate_probability(
        "llm_solo_penalty", raw.get("llm_solo_penalty", 0.7)
    )
    quant_penalty = _validate_probability(
        "quant_solo_penalty", raw.get("quant_solo_penalty", 0.8)
    )

    return OrchestratorConfigSchema(
        quant_backtest_path=backtest_path,
        trading_agents_config=agents_config,
        quant_weight_cap=quant_cap,
        llm_weight_cap=llm_cap,
        llm_batch_days=batch_days,
        cache_dir=cache_location,
        llm_solo_penalty=llm_penalty,
        quant_solo_penalty=quant_penalty,
    )