TradingAgents/tradingagents/schemas/decision.py

183 lines
6.0 KiB
Python

from __future__ import annotations
import json
from dataclasses import dataclass
from enum import Enum
from json import JSONDecodeError
from typing import Any, Mapping
class StructuredDecisionValidationError(ValueError):
    """Signal that a decision payload failed schema validation.

    Derives from ``ValueError``, so callers that already catch ``ValueError``
    also catch this error. The message names the offending field or problem.
    """
class DecisionRating(str, Enum):
    """Closed set of ratings a structured decision may carry.

    The ``str`` mixin makes every member compare equal to its own value
    (``DecisionRating.BUY == "BUY"``). Values are upper-case;
    ``parse_structured_decision`` upper-cases incoming text before the lookup.
    """

    BUY = "BUY"
    OVERWEIGHT = "OVERWEIGHT"
    HOLD = "HOLD"
    UNDERWEIGHT = "UNDERWEIGHT"
    SELL = "SELL"
    NO_TRADE = "NO_TRADE"
class TimeHorizon(str, Enum):
    """Closed set of holding horizons a structured decision may carry.

    The ``str`` mixin makes every member compare equal to its own value.
    Values are lower-case; ``parse_structured_decision`` lower-cases incoming
    text before the lookup.
    """

    SHORT = "short"
    MEDIUM = "medium"
    LONG = "long"
@dataclass(frozen=True)
class StructuredDecision:
    """Immutable record of one trading decision.

    ``parse_structured_decision`` builds instances after validating and
    normalizing a raw payload. Constructing this class directly performs no
    validation of its own.
    """

    rating: DecisionRating
    confidence: float
    time_horizon: TimeHorizon
    entry_logic: str
    exit_logic: str
    position_sizing: str
    risk_limits: str
    catalysts: tuple[str, ...]
    invalidators: tuple[str, ...]

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-ready ``dict`` of this decision.

        Enum fields are replaced by their ``.value`` and tuple fields are
        converted to lists. Keys appear in field-declaration order.
        """
        # Keyword arguments keep their call order, so the key order is stable.
        return dict(
            rating=self.rating.value,
            confidence=self.confidence,
            time_horizon=self.time_horizon.value,
            entry_logic=self.entry_logic,
            exit_logic=self.exit_logic,
            position_sizing=self.position_sizing,
            risk_limits=self.risk_limits,
            catalysts=list(self.catalysts),
            invalidators=list(self.invalidators),
        )

    def to_json(self, *, indent: int = 2) -> str:
        """Serialize :meth:`to_dict` to a JSON string.

        Args:
            indent: Passed straight to ``json.dumps``; defaults to 2.

        Non-ASCII characters are written as-is (``ensure_ascii=False``).
        """
        payload = self.to_dict()
        return json.dumps(payload, ensure_ascii=False, indent=indent)
def build_decision_output_instructions(context: str) -> str:
    """Build the instruction text that asks a model for a decision JSON object.

    Args:
        context: Interpolated into the first sentence to name what the JSON
            object is for.

    Returns:
        A single-line prompt. It names the JSON schema inline and constrains
        the rating casing and the confidence range.
    """
    schema_fields = (
        '"rating":"BUY | OVERWEIGHT | HOLD | UNDERWEIGHT | SELL | NO_TRADE"',
        '"confidence":0.0',
        '"time_horizon":"short | medium | long"',
        '"entry_logic":"..."',
        '"exit_logic":"..."',
        '"position_sizing":"..."',
        '"risk_limits":"..."',
        '"catalysts":["..."]',
        '"invalidators":["..."]',
    )
    schema = "{" + ",".join(schema_fields) + "}"
    sentences = (
        f"Return only one valid JSON object for the {context}.",
        "Do not wrap it in markdown fences.",
        f"The schema is: {schema}.",
        "Use an uppercase rating, confidence between 0 and 1 inclusive, and concise but specific strings.",
    )
    return " ".join(sentences)
def _extract_json_object(payload: str | Mapping[str, Any]) -> Mapping[str, Any]:
if isinstance(payload, Mapping):
return payload
if not isinstance(payload, str) or not payload.strip():
raise StructuredDecisionValidationError("Decision payload must be a non-empty JSON string or mapping.")
text = payload.strip()
if text.startswith("```"):
lines = text.splitlines()
if len(lines) >= 3 and lines[0].startswith("```") and lines[-1].startswith("```"):
text = "\n".join(lines[1:-1]).strip()
try:
parsed = json.loads(text)
if isinstance(parsed, Mapping):
return parsed
except JSONDecodeError:
pass
decoder = json.JSONDecoder()
for index, char in enumerate(text):
if char != "{":
continue
try:
parsed, _ = decoder.raw_decode(text[index:])
except JSONDecodeError:
continue
if isinstance(parsed, Mapping):
return parsed
raise StructuredDecisionValidationError("Could not locate a valid JSON object in the decision payload.")
def _require_string(data: Mapping[str, Any], field_name: str) -> str:
value = data.get(field_name)
if not isinstance(value, str) or not value.strip():
raise StructuredDecisionValidationError(f"Field '{field_name}' must be a non-empty string.")
return value.strip()
def _require_string_list(data: Mapping[str, Any], field_name: str) -> tuple[str, ...]:
value = data.get(field_name)
if not isinstance(value, list):
raise StructuredDecisionValidationError(f"Field '{field_name}' must be a list of strings.")
normalized: list[str] = []
for item in value:
if not isinstance(item, str) or not item.strip():
raise StructuredDecisionValidationError(
f"Field '{field_name}' must contain only non-empty strings."
)
normalized.append(item.strip())
return tuple(normalized)
def parse_structured_decision(payload: str | Mapping[str, Any]) -> StructuredDecision:
    """Validate a decision payload and build a :class:`StructuredDecision`.

    Args:
        payload: A mapping, or a string containing a JSON object. The string
            may be wrapped in a markdown code fence or surrounded by other
            text; ``_extract_json_object`` locates the object.

    Returns:
        The normalized decision:

        - ``rating`` is matched after stripping and upper-casing.
        - ``time_horizon`` is matched after stripping and lower-casing.
        - ``confidence`` is coerced with ``float()``, so numeric strings such
          as ``"0.8"`` are accepted.
        - Text fields and list items are stripped, and list fields are stored
          as tuples.

    Raises:
        StructuredDecisionValidationError: if any of the following hold:

        - no JSON object can be located;
        - a required field is missing;
        - the rating or time horizon is unsupported;
        - ``confidence`` is not a finite number in ``[0, 1]`` (booleans are
          rejected);
        - a text or list field is empty or has the wrong type.
    """
    data = _extract_json_object(payload)

    required_fields = {
        "rating",
        "confidence",
        "time_horizon",
        "entry_logic",
        "exit_logic",
        "position_sizing",
        "risk_limits",
        "catalysts",
        "invalidators",
    }
    missing_fields = required_fields - set(data.keys())
    if missing_fields:
        missing = ", ".join(sorted(missing_fields))
        raise StructuredDecisionValidationError(f"Decision payload is missing required fields: {missing}.")

    try:
        rating = DecisionRating(str(data["rating"]).strip().upper())
    except ValueError as exc:
        raise StructuredDecisionValidationError(f"Unsupported rating: {data.get('rating')!r}.") from exc

    raw_confidence = data["confidence"]
    # bool subclasses int, so float(True) would silently become 1.0; a boolean
    # is not a meaningful confidence value.
    if isinstance(raw_confidence, bool):
        raise StructuredDecisionValidationError("Field 'confidence' must be numeric.")
    try:
        confidence = float(raw_confidence)
    except (TypeError, ValueError, OverflowError) as exc:
        # OverflowError: json.loads happily yields ints too large for a double
        # (e.g. 10**400), and float() raises instead of returning inf.
        raise StructuredDecisionValidationError("Field 'confidence' must be numeric.") from exc
    # The chained comparison is False for NaN, so NaN and +/-inf are rejected too.
    if not 0.0 <= confidence <= 1.0:
        raise StructuredDecisionValidationError("Field 'confidence' must be between 0 and 1 inclusive.")

    try:
        time_horizon = TimeHorizon(str(data["time_horizon"]).strip().lower())
    except ValueError as exc:
        raise StructuredDecisionValidationError(
            f"Unsupported time horizon: {data.get('time_horizon')!r}."
        ) from exc

    return StructuredDecision(
        rating=rating,
        confidence=confidence,
        time_horizon=time_horizon,
        entry_logic=_require_string(data, "entry_logic"),
        exit_logic=_require_string(data, "exit_logic"),
        position_sizing=_require_string(data, "position_sizing"),
        risk_limits=_require_string(data, "risk_limits"),
        catalysts=_require_string_list(data, "catalysts"),
        invalidators=_require_string_list(data, "invalidators"),
    )
def ensure_structured_decision_json(payload: str | Mapping[str, Any]) -> str:
    """Validate *payload* and re-serialize it as normalized decision JSON.

    The payload is parsed with ``parse_structured_decision``, which raises
    ``StructuredDecisionValidationError`` on invalid input. The result is
    rendered with ``StructuredDecision.to_json`` using its default indent.
    """
    decision = parse_structured_decision(payload)
    return decision.to_json()