175 lines
5.8 KiB
Python
175 lines
5.8 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Mapping, Optional, TypedDict, cast
|
|
|
|
from tradingagents.default_config import get_default_config
|
|
|
|
|
|
# Version tag of this config contract; it is the default value of
# OrchestratorConfigSchema.contract_version.
CONTRACT_VERSION: str = "v1alpha1"
|
|
|
|
|
|
class TradingAgentsConfigPayload(TypedDict, total=False):
    """Static shape of the TradingAgents config dict.

    ``total=False`` makes every key optional for type checkers. At runtime
    this is a plain ``dict`` (TypedDict performs no checks); some keys are
    validated by ``build_trading_agents_config`` and the rest are passed
    through unchecked.
    """

    # The following six keys are listed in REQUIRED_TRADING_CONFIG_KEYS and
    # must be non-empty strings after defaults and overrides are merged.
    project_dir: str
    results_dir: str
    data_cache_dir: str
    llm_provider: str
    deep_think_llm: str
    quick_think_llm: str

    backend_url: str

    # Optional[...]: None is an accepted value for these provider knobs.
    google_thinking_level: Optional[str]
    openai_reasoning_effort: Optional[str]
    anthropic_effort: Optional[str]

    output_language: str
    portfolio_context: str
    peer_context: str
    peer_context_mode: str

    max_debate_rounds: int
    max_risk_discuss_rounds: int
    max_recur_limit: int
    analyst_node_timeout_secs: float

    # str -> str mappings, validated by _validate_string_map.
    data_vendors: dict[str, str]
    tool_vendors: dict[str, str]

    selected_analysts: list[str]

    # NOTE(review): both llm_timeout/timeout and llm_max_retries/max_retries
    # exist; their relationship is not visible in this module — confirm.
    llm_timeout: float
    llm_max_retries: int
    minimax_retry_attempts: int
    minimax_retry_base_delay: float
    timeout: float
    max_retries: int
    use_responses_api: bool
|
|
|
|
|
|
# Keys that build_trading_agents_config requires to hold non-empty strings
# once defaults and overrides have been merged.
REQUIRED_TRADING_CONFIG_KEYS = (
    "project_dir", "results_dir", "data_cache_dir",
    "llm_provider", "deep_think_llm", "quick_think_llm",
)
|
|
|
|
|
|
def _validate_probability(name: str, value: Any) -> float:
|
|
if not isinstance(value, (int, float)):
|
|
raise TypeError(f"{name} must be a number")
|
|
if not 0.0 <= float(value) <= 1.0:
|
|
raise ValueError(f"{name} must be between 0.0 and 1.0")
|
|
return float(value)
|
|
|
|
|
|
def _validate_positive_int(name: str, value: Any) -> int:
|
|
if not isinstance(value, int):
|
|
raise TypeError(f"{name} must be an int")
|
|
if value <= 0:
|
|
raise ValueError(f"{name} must be > 0")
|
|
return value
|
|
|
|
|
|
def _validate_string_map(name: str, value: Any) -> dict[str, str]:
|
|
if not isinstance(value, Mapping):
|
|
raise TypeError(f"{name} must be a mapping")
|
|
normalized = {}
|
|
for key, item in value.items():
|
|
if not isinstance(key, str) or not isinstance(item, str):
|
|
raise TypeError(f"{name} keys and values must be strings")
|
|
normalized[key] = item
|
|
return normalized
|
|
|
|
|
|
def _validate_string_list(name: str, value: Any) -> list[str]:
    """Return a fresh ``list`` copy of *value* after checking it holds only strings.

    Raises:
        TypeError: If ``value`` is not a list/tuple or contains a non-string.
    """
    if not isinstance(value, (list, tuple)) or any(
        not isinstance(item, str) for item in value
    ):
        raise TypeError(f"{name} must be a list of strings")
    return list(value)


def build_trading_agents_config(
    overrides: Optional[Mapping[str, Any]],
) -> TradingAgentsConfigPayload:
    """Merge *overrides* onto the TradingAgents defaults and validate the result.

    Every key in ``overrides`` replaces the corresponding default, including
    keys not declared on ``TradingAgentsConfigPayload`` (those are passed
    through unchecked).

    Args:
        overrides: Mapping of config keys to override, or ``None`` to use
            the defaults unchanged.

    Returns:
        The merged config as a plain ``dict``, typed as
        ``TradingAgentsConfigPayload``. ``data_vendors``, ``tool_vendors`` and
        ``selected_analysts`` are fresh copies.

    Raises:
        TypeError: If ``overrides`` is neither ``None`` nor a mapping; if
            ``data_vendors``/``tool_vendors`` is not a str -> str mapping; or
            if ``selected_analysts`` is present but not a list/tuple of strings
            (an explicit ``None`` is rejected like a ``None`` vendor map).
        ValueError: If any key in ``REQUIRED_TRADING_CONFIG_KEYS`` is missing
            or not a non-empty string.
        KeyError: If ``data_vendors`` or ``tool_vendors`` is absent after
            merging.
    """
    # Shallow-copy so the assignments below never write into the object
    # get_default_config() returns. NOTE(review): it may already return a
    # fresh copy; the extra copy is cheap either way.
    merged: dict[str, Any] = dict(get_default_config())

    # Test for None rather than truthiness. Otherwise a falsy non-mapping such
    # as [] or "" would be silently ignored instead of rejected.
    if overrides is not None:
        if not isinstance(overrides, Mapping):
            raise TypeError("trading_agents_config must be a mapping")
        merged.update(overrides)

    # Validate (and copy) structured values after merging, so defaults and
    # overrides go through the same checks.
    for key in ("data_vendors", "tool_vendors"):
        merged[key] = _validate_string_map(key, merged[key])
    if "selected_analysts" in merged:
        merged["selected_analysts"] = _validate_string_list(
            "selected_analysts", merged["selected_analysts"]
        )

    for key in REQUIRED_TRADING_CONFIG_KEYS:
        value = merged.get(key)
        if not isinstance(value, str) or not value.strip():
            raise ValueError(f"trading_agents_config.{key} must be a non-empty string")

    return cast(TradingAgentsConfigPayload, merged)
|
|
|
|
|
|
@dataclass(frozen=True)
class OrchestratorConfigSchema:
    """Immutable orchestrator settings tagged with a contract version.

    The dataclass constructor performs no validation itself;
    ``build_orchestrator_schema`` validates a raw mapping before
    constructing an instance.
    """

    quant_backtest_path: str = ""
    # Default: TradingAgents defaults merged and validated with no overrides.
    trading_agents_config: TradingAgentsConfigPayload = field(
        default_factory=lambda: build_trading_agents_config(None)
    )
    quant_weight_cap: float = 0.8
    llm_weight_cap: float = 0.9
    llm_batch_days: int = 7
    cache_dir: str = "orchestrator/cache"
    llm_solo_penalty: float = 0.7
    quant_solo_penalty: float = 0.8
    contract_version: str = CONTRACT_VERSION

    def to_runtime_fields(self) -> dict[str, Any]:
        """Return the runtime settings as a plain ``dict``.

        ``contract_version`` is not included. ``trading_agents_config`` is
        copied, along with any top-level ``dict``/``list`` values inside it,
        so mutating the result does not affect this instance.
        """
        # frozen=True only blocks attribute reassignment; the config dict and
        # its nested containers (e.g. data_vendors, selected_analysts) stay
        # mutable, so copy one level deeper than a plain dict() would.
        trading_config: dict[str, Any] = {}
        for key, value in self.trading_agents_config.items():
            if isinstance(value, dict):
                value = dict(value)
            elif isinstance(value, list):
                value = list(value)
            trading_config[key] = value

        return {
            "quant_backtest_path": self.quant_backtest_path,
            "trading_agents_config": trading_config,
            "quant_weight_cap": self.quant_weight_cap,
            "llm_weight_cap": self.llm_weight_cap,
            "llm_batch_days": self.llm_batch_days,
            "cache_dir": self.cache_dir,
            "llm_solo_penalty": self.llm_solo_penalty,
            "quant_solo_penalty": self.quant_solo_penalty,
        }
|
|
|
|
|
|
def build_orchestrator_schema(raw: Mapping[str, Any]) -> OrchestratorConfigSchema:
    """Validate a raw orchestrator config mapping and build the frozen schema.

    Missing keys fall back to the same defaults that
    ``OrchestratorConfigSchema`` declares. A ``contract_version`` key in
    ``raw`` is not read.

    Args:
        raw: Raw orchestrator settings, e.g. parsed from a config file.

    Returns:
        A validated ``OrchestratorConfigSchema``.

    Raises:
        TypeError: If ``raw`` is not a mapping or ``quant_backtest_path`` is
            not a string; also propagated from the nested validators.
        ValueError: If ``cache_dir`` is not a non-empty string (including
            non-string values); also propagated from the nested validators.
    """
    if not isinstance(raw, Mapping):
        raise TypeError("orchestrator config must be a mapping")

    backtest_path = raw.get("quant_backtest_path", "")
    if not isinstance(backtest_path, str):
        raise TypeError("quant_backtest_path must be a string")

    cache_location = raw.get("cache_dir", "orchestrator/cache")
    cache_ok = isinstance(cache_location, str) and bool(cache_location.strip())
    if not cache_ok:
        raise ValueError("cache_dir must be a non-empty string")

    def _probability(key: str, fallback: float) -> float:
        # Read one [0, 1] field from raw, applying its default.
        return _validate_probability(key, raw.get(key, fallback))

    # The order below matters: the first failing field is the one reported.
    ta_overrides = cast(Optional[Mapping[str, Any]], raw.get("trading_agents_config"))
    ta_config = build_trading_agents_config(ta_overrides)
    quant_cap = _probability("quant_weight_cap", 0.8)
    llm_cap = _probability("llm_weight_cap", 0.9)
    batch_days = _validate_positive_int("llm_batch_days", raw.get("llm_batch_days", 7))
    llm_penalty = _probability("llm_solo_penalty", 0.7)
    quant_penalty = _probability("quant_solo_penalty", 0.8)

    return OrchestratorConfigSchema(
        quant_backtest_path=backtest_path,
        trading_agents_config=ta_config,
        quant_weight_cap=quant_cap,
        llm_weight_cap=llm_cap,
        llm_batch_days=batch_days,
        cache_dir=cache_location,
        llm_solo_penalty=llm_penalty,
        quant_solo_penalty=quant_penalty,
    )
|