TradingAgents/orchestrator/llm_runner.py

271 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import logging
import os
import re
from datetime import datetime, timezone
from orchestrator.config import OrchestratorConfig
from orchestrator.contracts.error_taxonomy import ReasonCode
from orchestrator.contracts.result_contract import Signal, build_error_signal
from tradingagents.agents.utils.agent_states import extract_research_provenance
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Provider × base_url validation matrix
# Maps an llm_provider name to regex fragments; at least one fragment must
# match the configured backend_url for the pair to be considered consistent.
# Note: ollama/openrouter share openai's canonical provider but have different URL patterns
_PROVIDER_BASE_URL_PATTERNS = {
"anthropic": [r"api\.anthropic\.com", r"api\.minimaxi\.com/anthropic"],
"openai": [r"api\.openai\.com"],
"google": [r"generativelanguage\.googleapis\.com"],
"xai": [r"api\.x\.ai"],
"ollama": [r"localhost:\d+", r"127\.0\.0\.1:\d+", r"ollama"],
"openrouter": [r"openrouter\.ai"],
}
# Precompile regex patterns for efficiency
_COMPILED_PATTERNS = {
provider: [re.compile(pattern) for pattern in patterns]
for provider, patterns in _PROVIDER_BASE_URL_PATTERNS.items()
}
# Recommended timeout thresholds by analyst count
# Keys: number of selected analysts; values: per-node timeouts in seconds.
# Counts above 4 fall back to the 4-analyst row (see _validate_timeout_config).
_RECOMMENDED_TIMEOUTS = {
1: {"analyst": 75.0, "research": 30.0},
2: {"analyst": 90.0, "research": 45.0},
3: {"analyst": 105.0, "research": 60.0},
4: {"analyst": 120.0, "research": 75.0},
}
def _build_data_quality(state: str, **details):
payload = {"state": state}
payload.update({key: value for key, value in details.items() if value is not None})
return payload
def _extract_research_metadata(final_state: dict | None) -> dict | None:
if not isinstance(final_state, dict):
return None
debate_state = final_state.get("investment_debate_state") or {}
return extract_research_provenance(debate_state)
def _looks_like_provider_auth_failure(exc: Exception) -> bool:
text = str(exc).lower()
markers = (
"authentication_error",
"login fail",
"please carry the api secret key",
"invalid api key",
"unauthorized",
"error code: 401",
)
return any(marker in text for marker in markers)
class LLMRunner:
    """Run the TradingAgents LLM graph for a (ticker, date) and cache signals.

    Provider/base_url consistency is validated before the cache lookup so a
    misconfiguration is never masked by a stale cache entry. Successful
    signals are persisted as JSON files under ``config.cache_dir``.
    """

    def __init__(self, config: OrchestratorConfig):
        self._config = config
        self._graph = None  # Lazy-initialized on first get_signal() call (requires API key)
        self.cache_dir = config.cache_dir
        os.makedirs(self.cache_dir, exist_ok=True)
        self._validate_timeout_config()

    def _validate_timeout_config(self):
        """Warn if timeout configuration may be insufficient for selected analysts."""
        trading_cfg = self._config.trading_agents_config or {}
        selected_analysts = trading_cfg.get(
            "selected_analysts", ["market", "social", "news", "fundamentals"]
        )
        # An explicitly-empty analyst list is treated like the 4-analyst default.
        analyst_count = len(selected_analysts) if selected_analysts else 4
        analyst_timeout = float(trading_cfg.get("analyst_node_timeout_secs", 75.0))
        research_timeout = float(trading_cfg.get("research_node_timeout_secs", 30.0))
        # Get recommended thresholds (counts above 4 use the 4-analyst row)
        recommended = _RECOMMENDED_TIMEOUTS.get(analyst_count, _RECOMMENDED_TIMEOUTS[4])
        warnings = []
        if analyst_timeout < recommended["analyst"]:
            warnings.append(
                f"analyst_node_timeout_secs={analyst_timeout:.1f}s may be insufficient "
                f"for {analyst_count} analyst(s) (recommended: {recommended['analyst']:.1f}s)"
            )
        if research_timeout < recommended["research"]:
            warnings.append(
                f"research_node_timeout_secs={research_timeout:.1f}s may be insufficient "
                f"for {analyst_count} analyst(s) (recommended: {recommended['research']:.1f}s)"
            )
        for warning in warnings:
            logger.warning("LLMRunner: %s", warning)

    def _get_graph(self):
        """Lazy-initialize TradingAgentsGraph (heavy, requires API key at init time)."""
        if self._graph is None:
            from tradingagents.graph.trading_graph import TradingAgentsGraph
            trading_cfg = self._config.trading_agents_config if self._config.trading_agents_config else None
            graph_kwargs = {"config": trading_cfg}
            if trading_cfg and "selected_analysts" in trading_cfg:
                graph_kwargs["selected_analysts"] = trading_cfg["selected_analysts"]
            self._graph = TradingAgentsGraph(**graph_kwargs)
        return self._graph

    def _detect_provider_mismatch(self):
        """Validate provider × base_url compatibility using the pattern matrix.

        Uses the original provider name (not canonical) for validation since
        ollama/openrouter share openai's canonical provider but have different URLs.

        Returns:
            None when the configuration is consistent (or cannot be validated),
            otherwise a dict with provider, backend_url and expected_patterns.
        """
        trading_cfg = self._config.trading_agents_config or {}
        provider = str(trading_cfg.get("llm_provider", "")).lower()
        base_url = str(trading_cfg.get("backend_url", "") or "").lower()
        if not provider or not base_url:
            return None
        # Use original provider name for pattern matching (not canonical)
        compiled_patterns = _COMPILED_PATTERNS.get(provider, [])
        if not compiled_patterns:
            # No validation rules defined for this provider
            return None
        if any(pattern.search(base_url) for pattern in compiled_patterns):
            return None  # Match found, no mismatch
        # No pattern matched - return raw patterns for error message
        return {
            "provider": provider,
            "backend_url": trading_cfg.get("backend_url"),
            "expected_patterns": _PROVIDER_BASE_URL_PATTERNS[provider],
        }

    def _cache_path(self, ticker: str, date: str) -> str:
        """Filesystem-safe cache file path for a (ticker, date) pair."""
        safe_ticker = ticker.replace("/", "_")
        return os.path.join(self.cache_dir, f"{safe_ticker}_{date}.json")

    def _load_cached_signal(self, ticker: str, date: str) -> "Signal | None":
        """Return the cached Signal for (ticker, date), or None on a cache miss.

        Fix: a corrupted, truncated, or schema-incomplete cache file previously
        raised out of get_signal() (JSONDecodeError / KeyError / ValueError);
        it is now logged and treated as a miss so the signal is regenerated.
        """
        cache_path = self._cache_path(ticker, date)
        try:
            with open(cache_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            signal = Signal(
                ticker=ticker,
                direction=data["direction"],
                confidence=data["confidence"],
                source="llm",
                timestamp=datetime.fromisoformat(data["timestamp"]),
                metadata=data,
            )
        except FileNotFoundError:
            return None  # Normal cache miss
        except (json.JSONDecodeError, KeyError, ValueError, TypeError, OSError) as e:
            logger.warning(
                "LLMRunner: discarding unreadable cache entry %s (%s); regenerating",
                cache_path, e,
            )
            return None
        logger.info("LLMRunner: cache hit for %s %s", ticker, date)
        return signal

    def get_signal(self, ticker: str, date: str) -> Signal:
        """Return the LLM signal for ``ticker`` on ``date``, with caching.

        On provider misconfiguration or graph failure a structured error
        Signal is returned instead of raising.
        """
        # Validate configuration first (lightweight, prevents returning stale cache on config errors)
        mismatch = self._detect_provider_mismatch()
        if mismatch is not None:
            return build_error_signal(
                ticker=ticker,
                source="llm",
                reason_code=ReasonCode.PROVIDER_MISMATCH.value,
                message=(
                    f"provider '{mismatch['provider']}' does not match backend_url "
                    f"'{mismatch['backend_url']}'"
                ),
                metadata={
                    "data_quality": _build_data_quality("provider_mismatch", **mismatch),
                },
            )
        # Check cache after validation
        cached = self._load_cached_signal(ticker, date)
        if cached is not None:
            return cached
        try:
            _final_state, processed_signal = self._get_graph().propagate(ticker, date)
            rating = processed_signal if isinstance(processed_signal, str) else str(processed_signal)
            direction, confidence = self._map_rating(rating)
            now = datetime.now(timezone.utc)
            research_metadata = _extract_research_metadata(_final_state)
            # Anything other than a "full" research status counts as degraded.
            degraded = bool(
                research_metadata and research_metadata.get("research_status") != "full"
            )
            if degraded:
                data_quality = _build_data_quality(
                    "research_degraded",
                    research_status=research_metadata.get("research_status"),
                    research_mode=research_metadata.get("research_mode"),
                    degraded_reason=research_metadata.get("degraded_reason"),
                    timed_out_nodes=research_metadata.get("timed_out_nodes"),
                )
            else:
                data_quality = _build_data_quality("ok")
            cache_data = {
                "rating": rating,
                "direction": direction,
                "confidence": confidence,
                "timestamp": now.isoformat(),
                "ticker": ticker,
                "date": date,
                "decision_structured": (
                    (_final_state or {}).get("final_trade_decision_structured")
                    if isinstance(_final_state, dict)
                    else None
                ),
                "data_quality": data_quality,
                "research": research_metadata,
                "sample_quality": "degraded_research" if degraded else "full_research",
            }
            with open(self._cache_path(ticker, date), "w", encoding="utf-8") as f:
                json.dump(cache_data, f, ensure_ascii=False, indent=2)
            return Signal(
                ticker=ticker,
                direction=direction,
                confidence=confidence,
                source="llm",
                timestamp=now,
                metadata=cache_data,
            )
        except Exception as e:
            # logger.exception preserves the traceback (was logger.error).
            logger.exception("LLMRunner: propagate failed for %s %s: %s", ticker, date, e)
            reason_code = ReasonCode.LLM_SIGNAL_FAILED.value
            if "Unsupported LLM provider" in str(e):
                reason_code = ReasonCode.PROVIDER_MISMATCH.value
            elif _looks_like_provider_auth_failure(e):
                reason_code = ReasonCode.PROVIDER_AUTH_FAILED.value
            # Map reason code to data quality state
            state_map = {
                ReasonCode.PROVIDER_MISMATCH.value: "provider_mismatch",
                ReasonCode.PROVIDER_AUTH_FAILED.value: "provider_auth_failed",
            }
            state = state_map.get(reason_code, "unknown")
            trading_cfg = self._config.trading_agents_config or {}
            return build_error_signal(
                ticker=ticker,
                source="llm",
                reason_code=reason_code,
                message=str(e),
                metadata={
                    "data_quality": _build_data_quality(
                        state,
                        provider=trading_cfg.get("llm_provider"),
                        backend_url=trading_cfg.get("backend_url"),
                    ),
                },
            )

    def _map_rating(self, rating: str) -> tuple[int, float]:
        """Map a 5-level rating string to (direction, confidence).

        direction is 1 for buy-side ratings, -1 for sell-side, 0 for hold;
        unknown or empty ratings fall back to HOLD with a warning.
        """
        mapping = {
            "BUY": (1, 0.9),
            "OVERWEIGHT": (1, 0.6),
            "HOLD": (0, 0.5),
            "UNDERWEIGHT": (-1, 0.6),
            "SELL": (-1, 0.9),
        }
        result = mapping.get(rating.upper()) if rating else None
        if result is None:
            logger.warning("LLMRunner: unknown rating %r, falling back to HOLD", rating)
            return (0, 0.5)
        return result