271 lines
11 KiB
Python
271 lines
11 KiB
Python
import json
|
||
import logging
|
||
import os
|
||
import re
|
||
from datetime import datetime, timezone
|
||
|
||
from orchestrator.config import OrchestratorConfig
|
||
from orchestrator.contracts.error_taxonomy import ReasonCode
|
||
from orchestrator.contracts.result_contract import Signal, build_error_signal
|
||
from tradingagents.agents.utils.agent_states import extract_research_provenance
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# Provider × base_url validation matrix
|
||
# Note: ollama/openrouter share openai's canonical provider but have different URL patterns
|
||
_PROVIDER_BASE_URL_PATTERNS = {
|
||
"anthropic": [r"api\.anthropic\.com", r"api\.minimaxi\.com/anthropic"],
|
||
"openai": [r"api\.openai\.com"],
|
||
"google": [r"generativelanguage\.googleapis\.com"],
|
||
"xai": [r"api\.x\.ai"],
|
||
"ollama": [r"localhost:\d+", r"127\.0\.0\.1:\d+", r"ollama"],
|
||
"openrouter": [r"openrouter\.ai"],
|
||
}
|
||
|
||
# Precompile regex patterns for efficiency
|
||
_COMPILED_PATTERNS = {
|
||
provider: [re.compile(pattern) for pattern in patterns]
|
||
for provider, patterns in _PROVIDER_BASE_URL_PATTERNS.items()
|
||
}
|
||
|
||
# Recommended timeout thresholds by analyst count
|
||
_RECOMMENDED_TIMEOUTS = {
|
||
1: {"analyst": 75.0, "research": 30.0},
|
||
2: {"analyst": 90.0, "research": 45.0},
|
||
3: {"analyst": 105.0, "research": 60.0},
|
||
4: {"analyst": 120.0, "research": 75.0},
|
||
}
|
||
|
||
|
||
def _build_data_quality(state: str, **details):
|
||
payload = {"state": state}
|
||
payload.update({key: value for key, value in details.items() if value is not None})
|
||
return payload
|
||
|
||
|
||
def _extract_research_metadata(final_state: dict | None) -> dict | None:
|
||
if not isinstance(final_state, dict):
|
||
return None
|
||
debate_state = final_state.get("investment_debate_state") or {}
|
||
return extract_research_provenance(debate_state)
|
||
|
||
|
||
def _looks_like_provider_auth_failure(exc: Exception) -> bool:
|
||
text = str(exc).lower()
|
||
markers = (
|
||
"authentication_error",
|
||
"login fail",
|
||
"please carry the api secret key",
|
||
"invalid api key",
|
||
"unauthorized",
|
||
"error code: 401",
|
||
)
|
||
return any(marker in text for marker in markers)
|
||
|
||
|
||
class LLMRunner:
    """Run the TradingAgents graph for a ticker/date and return a ``Signal``.

    Results are cached as one JSON file per ``(ticker, date)`` under
    ``config.cache_dir``. The graph itself is built lazily on the first call
    that actually needs it.
    """

    def __init__(self, config: "OrchestratorConfig"):
        """Store the config, ensure the cache directory exists, and check timeouts.

        Args:
            config: Orchestrator configuration. ``cache_dir`` and
                ``trading_agents_config`` are read from it.
        """
        self._config = config
        self._graph = None  # Lazy-initialized on first get_signal() call (requires API key)
        self.cache_dir = config.cache_dir
        os.makedirs(self.cache_dir, exist_ok=True)
        self._validate_timeout_config()

    def _validate_timeout_config(self):
        """Warn if timeout configuration may be insufficient for selected analysts."""
        trading_cfg = self._config.trading_agents_config or {}
        selected_analysts = trading_cfg.get(
            "selected_analysts", ["market", "social", "news", "fundamentals"]
        )
        # An empty/None selection is treated like the default four analysts.
        analyst_count = len(selected_analysts) if selected_analysts else 4

        analyst_timeout = float(trading_cfg.get("analyst_node_timeout_secs", 75.0))
        research_timeout = float(trading_cfg.get("research_node_timeout_secs", 30.0))

        # Counts without their own row (e.g. more than 4) use the 4-analyst row.
        recommended = _RECOMMENDED_TIMEOUTS.get(analyst_count, _RECOMMENDED_TIMEOUTS[4])

        checks = (
            ("analyst_node_timeout_secs", analyst_timeout, recommended["analyst"]),
            ("research_node_timeout_secs", research_timeout, recommended["research"]),
        )
        for setting, configured, minimum in checks:
            if configured < minimum:
                logger.warning(
                    "LLMRunner: %s",
                    f"{setting}={configured:.1f}s may be insufficient "
                    f"for {analyst_count} analyst(s) (recommended: {minimum:.1f}s)",
                )

    def _get_graph(self):
        """Lazy-initialize TradingAgentsGraph (heavy, requires API key at init time)."""
        if self._graph is None:
            from tradingagents.graph.trading_graph import TradingAgentsGraph

            # An empty config is passed on as None rather than as an empty dict.
            trading_cfg = self._config.trading_agents_config or None
            graph_kwargs = {"config": trading_cfg}
            if trading_cfg and "selected_analysts" in trading_cfg:
                graph_kwargs["selected_analysts"] = trading_cfg["selected_analysts"]
            self._graph = TradingAgentsGraph(**graph_kwargs)
        return self._graph

    def _detect_provider_mismatch(self):
        """Validate provider x base_url compatibility using the pattern matrix.

        Uses the original provider name (not canonical) for validation since
        ollama/openrouter share openai's canonical provider but have different URLs.

        Returns:
            ``None`` when the provider or URL is unset, when no patterns are
            defined for the provider, or when the URL matches one of them.
            Otherwise a dict with ``provider``, ``backend_url`` (as configured)
            and ``expected_patterns`` (the raw regex strings).
        """
        trading_cfg = self._config.trading_agents_config or {}
        provider = str(trading_cfg.get("llm_provider", "")).lower()
        base_url = str(trading_cfg.get("backend_url", "") or "").lower()

        if not provider or not base_url:
            return None

        compiled_patterns = _COMPILED_PATTERNS.get(provider, [])
        if not compiled_patterns:
            # No validation rules defined for this provider
            return None

        if any(pattern.search(base_url) for pattern in compiled_patterns):
            return None  # Match found, no mismatch

        # No pattern matched - return raw patterns for error message
        return {
            "provider": provider,
            "backend_url": trading_cfg.get("backend_url"),
            "expected_patterns": _PROVIDER_BASE_URL_PATTERNS[provider],
        }

    def _cache_path_for(self, ticker: str, date: str) -> str:
        """Return the cache file path for ``(ticker, date)``; ``/`` in the ticker becomes ``_``."""
        safe_ticker = ticker.replace("/", "_")
        return os.path.join(self.cache_dir, f"{safe_ticker}_{date}.json")

    def _load_cached_signal(self, cache_path: str, ticker: str, date: str) -> "Signal | None":
        """Return the cached signal at ``cache_path``, or ``None`` on a miss.

        A missing file is a normal miss. A file that exists but cannot be
        parsed, is not a JSON object, lacks a required key, or carries an
        invalid timestamp is logged and also treated as a miss, so the caller
        recomputes and overwrites it instead of failing on every call.
        """
        try:
            with open(cache_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            direction = data["direction"]
            confidence = data["confidence"]
            timestamp = datetime.fromisoformat(data["timestamp"])
        except FileNotFoundError:
            return None
        except (ValueError, KeyError, TypeError) as exc:
            # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
            logger.warning(
                "LLMRunner: ignoring unreadable cache %s: %s", cache_path, exc
            )
            return None

        logger.info("LLMRunner: cache hit for %s %s", ticker, date)
        return Signal(
            ticker=ticker,
            direction=direction,
            confidence=confidence,
            source="llm",
            timestamp=timestamp,
            metadata=data,
        )

    @staticmethod
    def _write_cache_atomically(cache_path: str, cache_data: dict) -> None:
        """Write ``cache_data`` as JSON to ``cache_path`` without leaving partial files.

        The payload is serialized before any file is opened, then written to a
        sibling temp file and moved into place with ``os.replace``.

        Raises:
            TypeError, ValueError: ``cache_data`` is not JSON-serializable.
            OSError: The file could not be written or moved.
        """
        serialized = json.dumps(cache_data, ensure_ascii=False, indent=2)
        tmp_path = f"{cache_path}.{os.getpid()}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                f.write(serialized)
            os.replace(tmp_path, cache_path)
        except OSError:
            # Best-effort cleanup of the temp file before re-raising.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def _build_failure_signal(self, ticker: str, exc: Exception) -> "Signal":
        """Turn an exception from the LLM run into an error ``Signal``."""
        reason_code = ReasonCode.LLM_SIGNAL_FAILED.value
        if "Unsupported LLM provider" in str(exc):
            reason_code = ReasonCode.PROVIDER_MISMATCH.value
        elif _looks_like_provider_auth_failure(exc):
            reason_code = ReasonCode.PROVIDER_AUTH_FAILED.value

        # Map reason code to data quality state
        state_map = {
            ReasonCode.PROVIDER_MISMATCH.value: "provider_mismatch",
            ReasonCode.PROVIDER_AUTH_FAILED.value: "provider_auth_failed",
        }
        state = state_map.get(reason_code, "unknown")

        trading_cfg = self._config.trading_agents_config or {}
        return build_error_signal(
            ticker=ticker,
            source="llm",
            reason_code=reason_code,
            message=str(exc),
            metadata={
                "data_quality": _build_data_quality(
                    state,
                    provider=trading_cfg.get("llm_provider"),
                    backend_url=trading_cfg.get("backend_url"),
                ),
            },
        )

    def get_signal(self, ticker: str, date: str) -> "Signal":
        """Return the LLM signal for ``ticker`` on ``date``, using the cache.

        Order of operations: validate provider/backend_url, try the cache,
        then run the graph and cache the result.

        Args:
            ticker: Instrument symbol.
            date: Date string; used verbatim in the cache file name and
                passed to the graph.

        Returns:
            A ``Signal``. Provider mismatches and any exception raised while
            running the graph are returned as error signals built with
            ``build_error_signal`` rather than raised.
        """
        # Validate configuration first (lightweight, prevents returning stale cache on config errors)
        mismatch = self._detect_provider_mismatch()
        if mismatch is not None:
            return build_error_signal(
                ticker=ticker,
                source="llm",
                reason_code=ReasonCode.PROVIDER_MISMATCH.value,
                message=(
                    f"provider '{mismatch['provider']}' does not match backend_url "
                    f"'{mismatch['backend_url']}'"
                ),
                metadata={
                    "data_quality": _build_data_quality("provider_mismatch", **mismatch),
                },
            )

        # Check cache after validation
        cache_path = self._cache_path_for(ticker, date)
        cached = self._load_cached_signal(cache_path, ticker, date)
        if cached is not None:
            return cached

        try:
            final_state, processed_signal = self._get_graph().propagate(ticker, date)
            rating = processed_signal if isinstance(processed_signal, str) else str(processed_signal)
            direction, confidence = self._map_rating(rating)
            now = datetime.now(timezone.utc)

            research_metadata = _extract_research_metadata(final_state)
            research_degraded = bool(
                research_metadata and research_metadata.get("research_status") != "full"
            )
            if research_degraded:
                data_quality = _build_data_quality(
                    "research_degraded",
                    research_status=research_metadata.get("research_status"),
                    research_mode=research_metadata.get("research_mode"),
                    degraded_reason=research_metadata.get("degraded_reason"),
                    timed_out_nodes=research_metadata.get("timed_out_nodes"),
                )
            else:
                data_quality = _build_data_quality("ok")

            cache_data = {
                "rating": rating,
                "direction": direction,
                "confidence": confidence,
                "timestamp": now.isoformat(),
                "ticker": ticker,
                "date": date,
                "decision_structured": (
                    final_state.get("final_trade_decision_structured")
                    if isinstance(final_state, dict)
                    else None
                ),
                "data_quality": data_quality,
                "research": research_metadata,
                "sample_quality": "degraded_research" if research_degraded else "full_research",
            }

            # Caching is an optimization: a failed write must not throw away
            # a result the LLM run already produced.
            try:
                self._write_cache_atomically(cache_path, cache_data)
            except (OSError, TypeError, ValueError) as cache_exc:
                logger.warning(
                    "LLMRunner: could not write cache %s: %s", cache_path, cache_exc
                )

            return Signal(
                ticker=ticker,
                direction=direction,
                confidence=confidence,
                source="llm",
                timestamp=now,
                metadata=cache_data,
            )
        except Exception as e:
            logger.error("LLMRunner: propagate failed for %s %s: %s", ticker, date, e)
            return self._build_failure_signal(ticker, e)

    def _map_rating(self, rating: str) -> tuple[int, float]:
        """Map a 5-level rating to ``(direction, confidence)``.

        Matching ignores case and surrounding whitespace. Unknown or empty
        ratings are logged and mapped to HOLD, i.e. ``(0, 0.5)``.
        """
        mapping = {
            "BUY": (1, 0.9),
            "OVERWEIGHT": (1, 0.6),
            "HOLD": (0, 0.5),
            "UNDERWEIGHT": (-1, 0.6),
            "SELL": (-1, 0.9),
        }
        key = rating.strip().upper() if rating else ""
        result = mapping.get(key)
        if result is None:
            logger.warning("LLMRunner: unknown rating %r, falling back to HOLD", rating)
            return (0, 0.5)
        return result
|