105 lines
4.5 KiB
Python
105 lines
4.5 KiB
Python
import logging
|
|
from collections import defaultdict
|
|
from typing import Generator
|
|
from api.store.runs_store import RunsStore
|
|
from api.models.run import RunConfig, RunStatus
|
|
from api.callbacks.token_handler import TokenCallbackHandler
|
|
|
|
logger = logging.getLogger(__name__)

# `tradingagents` is an optional dependency: if either import fails the module
# still loads, with no graph class and an empty base config. Both names are
# only used on the path that actually executes a run.
try:
    from tradingagents.graph.trading_graph import TradingAgentsGraph
    from tradingagents.default_config import DEFAULT_CONFIG
except ImportError:
    TradingAgentsGraph, DEFAULT_CONFIG = None, {}  # type: ignore
|
|
|
|
|
|
class RunService:
    """Execute TradingAgents runs and expose their progress as event dicts.

    Every event yielded by :meth:`stream_events` has the shape
    ``{"event": <name>, "data": <payload dict>}``; the names produced here are
    ``agent:start``, ``agent:complete``, ``run:complete`` and ``run:error``.
    """

    # Analysts used when the run config does not enable any explicitly.
    _DEFAULT_ANALYSTS = ("market", "news", "fundamentals", "social")

    # Error recorded when the consumer stops iterating before the run finished.
    _INTERRUPTED_MESSAGE = "Run interrupted before completion"

    def __init__(self, store: RunsStore):
        """Create the service on top of *store*, which persists run state."""
        self._store = store

    def stream_events(self, run_id: str) -> Generator[dict, None, None]:
        """Stream the events of run *run_id*, executing it if necessary.

        - Unknown run, or run without a config: a single ``run:error``.
        - ``COMPLETE`` run: its persisted reports are replayed (no agent work
          is done), followed by ``run:complete``.
        - ``RUNNING`` run: a single ``run:error``.
        - Any other status: previous reports and token usage are cleared, the
          run is marked ``RUNNING`` and executed, with events streamed as each
          step finishes.

        Args:
            run_id: Identifier of the run in the store.

        Yields:
            Event dicts, as described on the class.
        """
        run = self._store.get(run_id)
        if not run or not run.config:
            yield {"event": "run:error", "data": {"message": "Run not found"}}
            return

        if run.status == RunStatus.COMPLETE:
            yield from self._replay(run_id, run)
            return

        if run.status == RunStatus.RUNNING:
            # NOTE(review): this check-then-set is not atomic; two concurrent
            # requests could both start the same run. Confirm whether the
            # store or the caller serializes access.
            yield {"event": "run:error", "data": {"message": "Run is already in progress"}}
            return

        yield from self._execute(run_id, run.config)

    def _replay(self, run_id: str, run) -> Generator[dict, None, None]:
        """Re-emit the persisted reports of a completed run, then ``run:complete``.

        Args:
            run_id: Identifier of the run (used for logging and the final event).
            run: The stored run record returned by ``RunsStore.get``.
        """
        token_usage = {k: v.model_dump() for k, v in (run.token_usage or {}).items()}

        for key, report in (run.reports or {}).items():
            # Keys are written by _execute as "<step_key>:<turn>".
            if ":" not in key:
                logger.warning(
                    "Skipping malformed report key %r for run %s", key, run_id
                )
                continue
            step_key, turn_str = key.rsplit(":", 1)
            # isdecimal() rather than isdigit(): isdigit() also accepts
            # characters such as "²" that int() rejects with ValueError.
            if not turn_str.isdecimal():
                logger.warning(
                    "Skipping report key with non-numeric turn %r for run %s", key, run_id
                )
                continue
            turn = int(turn_str)
            usage = token_usage.get(key, {"tokens_in": 0, "tokens_out": 0})

            yield {"event": "agent:start", "data": {"step": step_key, "turn": turn}}
            yield {"event": "agent:complete", "data": {
                "step": step_key, "turn": turn, "report": report,
                "tokens_in": usage.get("tokens_in", 0),
                "tokens_out": usage.get("tokens_out", 0),
            }}

        yield {"event": "run:complete", "data": {"decision": run.decision or "HOLD", "run_id": run_id}}

    @staticmethod
    def _build_ta_config(config: RunConfig) -> dict:
        """Overlay the run's LLM and debate settings on a copy of DEFAULT_CONFIG.

        The copy is shallow: nested values are still shared with DEFAULT_CONFIG.
        """
        ta_config = DEFAULT_CONFIG.copy()
        ta_config.update(
            llm_provider=config.llm_provider,
            deep_think_llm=config.deep_think_llm,
            quick_think_llm=config.quick_think_llm,
            max_debate_rounds=config.max_debate_rounds,
            max_risk_discuss_rounds=config.max_risk_discuss_rounds,
        )
        return ta_config

    def _execute(self, run_id: str, config: RunConfig) -> Generator[dict, None, None]:
        """Run the trading graph from scratch, persisting and streaming each step.

        Any failure is recorded with ``set_error`` and reported as a
        ``run:error`` event. If the consumer closes the stream early, the failure
        is recorded and the generator exits without yielding.
        """
        self._store.clear_reports(run_id)
        self._store.clear_token_usage(run_id)
        self._store.update_status(run_id, RunStatus.RUNNING)

        # From here on every exit path must record an outcome. Otherwise the
        # run would stay RUNNING and later calls would be refused as
        # "already in progress". Presumably set_error moves the run out of
        # RUNNING; TODO confirm in RunsStore.
        try:
            if TradingAgentsGraph is None:
                raise RuntimeError("tradingagents is not installed; cannot execute run")

            token_handler = TokenCallbackHandler()
            ta = TradingAgentsGraph(
                debug=False,
                config=self._build_ta_config(config),
                selected_analysts=config.enabled_analysts or list(self._DEFAULT_ANALYSTS),
                callbacks=[token_handler],
            )

            # A step key emitted more than once gets turns 0, 1, 2, ...
            turn_counts: defaultdict[str, int] = defaultdict(int)

            for step_key, report in ta.stream_propagate(config.ticker, config.date):
                # Presumably the tokens counted since the previous snapshot.
                tokens = token_handler.snapshot_and_reset()
                turn = turn_counts[step_key]
                report_key = f"{step_key}:{turn}"

                yield {"event": "agent:start", "data": {"step": step_key, "turn": turn}}
                yield {"event": "agent:complete", "data": {
                    "step": step_key, "turn": turn, "report": report,
                    "tokens_in": tokens["in"], "tokens_out": tokens["out"],
                }}

                self._store.add_report(run_id, report_key, report)
                # Normalize to TokenUsage field names before persisting
                self._store.add_token_usage(
                    run_id, report_key,
                    {"tokens_in": tokens["in"], "tokens_out": tokens["out"]},
                )
                turn_counts[step_key] += 1

            # NOTE(review): relies on a private attribute of TradingAgentsGraph.
            decision = ta._last_decision or "HOLD"
            self._store.update_decision(run_id, decision)
            self._store.update_status(run_id, RunStatus.COMPLETE)
        except GeneratorExit:
            # The consumer closed the stream mid-run. GeneratorExit is not an
            # Exception subclass, so the handler below would miss it. Nothing
            # may be yielded here; record the failure and propagate.
            logger.warning("Run %s interrupted before completion", run_id)
            self._store.set_error(run_id, self._INTERRUPTED_MESSAGE)
            raise
        except Exception as exc:
            logger.exception("Run %s failed", run_id)
            # Some exceptions stringify to ""; fall back to the type name.
            message = str(exc) or type(exc).__name__
            self._store.set_error(run_id, message)
            yield {"event": "run:error", "data": {"message": message}}
        else:
            # Outside the try: closing the stream after the run has completed
            # must not mark it as errored.
            yield {"event": "run:complete", "data": {"decision": decision, "run_id": run_id}}