TradingAgents/api/services/run_service.py

105 lines
4.5 KiB
Python

import logging
from collections import defaultdict
from typing import Generator
from api.store.runs_store import RunsStore
from api.models.run import RunConfig, RunStatus
from api.callbacks.token_handler import TokenCallbackHandler
logger = logging.getLogger(__name__)

# The TradingAgents engine is an optional dependency: the module still imports
# without it (and completed runs can still be replayed), but new runs cannot be
# executed, so TradingAgentsGraph is None in that case.
try:
    from tradingagents.graph.trading_graph import TradingAgentsGraph
    from tradingagents.default_config import DEFAULT_CONFIG
except ImportError as exc:
    # Don't swallow the cause silently: the ImportError may come from a missing
    # transitive dependency inside tradingagents rather than from the package
    # itself being absent, and its message is the only clue to which.
    logger.warning(
        "TradingAgents engine unavailable (%s); new runs cannot be executed", exc
    )
    TradingAgentsGraph = None  # type: ignore
    DEFAULT_CONFIG = {}
class RunService:
    """Execute TradingAgents runs and stream their progress as event dicts.

    Every event is a dict of the form ``{"event": <name>, "data": <payload>}``;
    the names produced here are ``agent:start``, ``agent:complete``,
    ``run:complete`` and ``run:error``.
    """

    # Analysts passed to the graph when the run config enables none explicitly.
    _DEFAULT_ANALYSTS = ("market", "news", "fundamentals", "social")

    def __init__(self, store: RunsStore):
        self._store = store

    def stream_events(self, run_id: str) -> Generator[dict, None, None]:
        """Yield the events of run *run_id*, executing the run if needed.

        Behaviour depends on the stored run:

        * missing, or stored without a config -> a single ``run:error``;
        * ``RunStatus.COMPLETE`` -> the persisted reports are replayed as
          ``agent:start``/``agent:complete`` pairs followed by ``run:complete``;
        * ``RunStatus.RUNNING`` -> a single ``run:error`` (the run is not
          restarted);
        * any other status -> the run is executed from scratch, each agent
          step is streamed and persisted, and ``run:complete`` (or
          ``run:error``) ends the stream.

        If the consumer closes the generator before a live run finishes
        (e.g. a client disconnects), the run is recorded as failed via the
        store's ``set_error`` instead of being left in ``RUNNING`` forever.
        """
        run = self._store.get(run_id)
        if not run or not run.config:
            yield {"event": "run:error", "data": {"message": "Run not found"}}
            return
        if run.status == RunStatus.COMPLETE:
            yield from self._replay(run_id, run)
            return
        if run.status == RunStatus.RUNNING:
            yield {"event": "run:error", "data": {"message": "Run is already in progress"}}
            return
        if TradingAgentsGraph is None:
            # Fail with a clear message before clearing stored reports; otherwise
            # the graph construction dies with "'NoneType' object is not callable".
            message = "TradingAgents engine is not installed; cannot execute run"
            self._store.set_error(run_id, message)
            yield {"event": "run:error", "data": {"message": message}}
            return
        # When this generator is closed, PEP 380 delegation also closes
        # _execute, so its GeneratorExit handler runs.
        yield from self._execute(run_id, run.config)

    def _replay(self, run_id: str, run) -> Generator[dict, None, None]:
        """Re-emit the persisted reports of a completed run in stored order.

        Report keys are expected to look like ``"<step>:<turn>"`` (the format
        written by ``_execute``); keys that do not are logged and skipped.
        """
        usage = {key: value.model_dump() for key, value in (run.token_usage or {}).items()}
        for key, report in run.reports.items():
            step_key, sep, turn_str = key.rpartition(":")
            if not sep:
                logger.warning("Skipping malformed report key %r for run %s", key, run_id)
                continue
            # isdecimal(), not isdigit(): isdigit() also accepts characters such
            # as superscripts ("²") that int() rejects with ValueError.
            if not turn_str.isdecimal():
                logger.warning(
                    "Skipping report key with non-numeric turn %r for run %s", key, run_id
                )
                continue
            turn = int(turn_str)
            raw = usage.get(key, {"tokens_in": 0, "tokens_out": 0})
            yield {"event": "agent:start", "data": {"step": step_key, "turn": turn}}
            yield {"event": "agent:complete", "data": {
                "step": step_key, "turn": turn, "report": report,
                "tokens_in": raw.get("tokens_in", 0),
                "tokens_out": raw.get("tokens_out", 0),
            }}
        yield {"event": "run:complete", "data": {"decision": run.decision or "HOLD", "run_id": run_id}}

    @staticmethod
    def _build_ta_config(config) -> dict:
        """Return a copy of ``DEFAULT_CONFIG`` overlaid with the run's LLM/debate settings."""
        # Shallow copy is enough: only top-level keys are overridden here.
        ta_config = DEFAULT_CONFIG.copy()
        ta_config.update(
            llm_provider=config.llm_provider,
            deep_think_llm=config.deep_think_llm,
            quick_think_llm=config.quick_think_llm,
            max_debate_rounds=config.max_debate_rounds,
            max_risk_discuss_rounds=config.max_risk_discuss_rounds,
        )
        return ta_config

    def _execute(self, run_id: str, config) -> Generator[dict, None, None]:
        """Execute the graph for *config*, streaming and persisting every step.

        Previous reports and token usage are cleared first. Each step's report
        and token usage are stored under ``"<step>:<turn>"``, where ``turn``
        counts how many reports that step has already produced in this run,
        starting at 0.
        """
        self._store.clear_reports(run_id)
        self._store.clear_token_usage(run_id)
        self._store.update_status(run_id, RunStatus.RUNNING)
        try:
            # Inside the try so a bad config is reported instead of leaving the
            # run stuck in RUNNING.
            ta_config = self._build_ta_config(config)
            token_handler = TokenCallbackHandler()
            ta = TradingAgentsGraph(
                debug=False,
                config=ta_config,
                selected_analysts=config.enabled_analysts or list(self._DEFAULT_ANALYSTS),
                callbacks=[token_handler],
            )
            turn_counts: defaultdict[str, int] = defaultdict(int)
            for step_key, report in ta.stream_propagate(config.ticker, config.date):
                # Presumably the tokens counted since the previous step; verify
                # against TokenCallbackHandler.snapshot_and_reset.
                tokens = token_handler.snapshot_and_reset()
                turn = turn_counts[step_key]
                key = f"{step_key}:{turn}"
                yield {"event": "agent:start", "data": {"step": step_key, "turn": turn}}
                yield {"event": "agent:complete", "data": {
                    "step": step_key, "turn": turn, "report": report,
                    "tokens_in": tokens["in"], "tokens_out": tokens["out"],
                }}
                self._store.add_report(run_id, key, report)
                # Persist under the TokenUsage field names read by the replay path.
                self._store.add_token_usage(
                    run_id, key, {"tokens_in": tokens["in"], "tokens_out": tokens["out"]},
                )
                turn_counts[step_key] += 1
            # NOTE: reads a private attribute of TradingAgentsGraph.
            decision = ta._last_decision or "HOLD"
            self._store.update_decision(run_id, decision)
            self._store.update_status(run_id, RunStatus.COMPLETE)
        except GeneratorExit:
            # The consumer closed the stream mid-run. GeneratorExit is not an
            # Exception subclass, so without this handler the run would stay
            # RUNNING and every later attempt would be refused as "already in
            # progress". Must not yield here; re-raise so close() completes.
            self._store.set_error(
                run_id, "Run aborted: event stream closed before completion"
            )
            raise
        except Exception as exc:
            logger.exception("Run %s failed", run_id)
            # Some exceptions stringify to "" (e.g. TimeoutError()); keep the
            # stored/streamed message non-empty.
            message = str(exc) or type(exc).__name__
            self._store.set_error(run_id, message)
            yield {"event": "run:error", "data": {"message": message}}
            return
        # Outside the try: a close() at this final yield happens after the run
        # is already COMPLETE and must not be recorded as an abort.
        yield {"event": "run:complete", "data": {"decision": decision, "run_id": run_id}}