feat: add token usage tracking and reporting to run service
- Introduced TokenCallbackHandler to track input and output token usage during LLM operations.
- Updated RunResult model to include token usage data.
- Enhanced RunsStore to support token usage persistence in the database.
- Modified RunService to yield token usage information during event streaming.
- Implemented UI components to display token statistics in the run detail view.
- Added tests for token handling and reporting functionality.

Made-with: Cursor
This commit is contained in:
commit
29a338d957
|
|
@ -0,0 +1,41 @@
|
|||
import threading
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
|
||||
class TokenCallbackHandler(BaseCallbackHandler):
    """Tracks LLM token usage. Call snapshot_and_reset() after each agent step
    to get the delta tokens consumed by that step, then zero the counters.

    Counters are guarded by a lock because LangChain may invoke callbacks
    from multiple threads.
    """

    def __init__(self) -> None:
        super().__init__()
        self._lock = threading.Lock()
        self._tokens_in = 0
        self._tokens_out = 0

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Accumulate usage_metadata from the first generation, if present."""
        try:
            first = response.generations[0][0]
        except (IndexError, TypeError):
            # Empty outer/inner generations list, or a non-indexable payload.
            return
        # Only chat generations carry an AIMessage; anything else is skipped.
        message = getattr(first, "message", None)
        if not isinstance(message, AIMessage):
            return
        usage = getattr(message, "usage_metadata", None)
        if usage is None:
            return
        with self._lock:
            self._tokens_in += usage.get("input_tokens", 0)
            self._tokens_out += usage.get("output_tokens", 0)

    def snapshot_and_reset(self) -> dict[str, int]:
        """Return {"in": N, "out": M} for the current period and zero counters."""
        with self._lock:
            snapshot = {"in": self._tokens_in, "out": self._tokens_out}
            self._tokens_in = self._tokens_out = 0
        return snapshot
|
||||
|
|
@ -32,7 +32,13 @@ class RunSummary(BaseModel):
|
|||
created_at: str
|
||||
|
||||
|
||||
class TokenUsage(BaseModel):
    """Token counts for a single agent step, keyed elsewhere by "step:turn"."""

    # Input (prompt) tokens consumed by the step.
    tokens_in: int = 0
    # Output (completion) tokens produced by the step.
    tokens_out: int = 0
|
||||
|
||||
|
||||
class RunResult(RunSummary):
    """Full run record: summary fields plus config, per-step reports,
    error information, and per-step token usage.

    `reports` and `token_usage` are keyed by "step:turn" strings.
    """

    config: Optional[RunConfig] = None
    # default_factory avoids sharing one mutable dict across instances;
    # the duplicate `reports: dict[str, str] = {}` declaration was removed.
    reports: dict[str, str] = Field(default_factory=dict)
    error: Optional[str] = None
    token_usage: dict[str, TokenUsage] = Field(default_factory=dict)
|
||||
|
|
|
|||
|
|
@ -1,12 +1,19 @@
|
|||
import json
|
||||
import pathlib
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from api.models.run import RunConfig, RunResult, RunSummary
|
||||
from api.services.run_service import RunService
|
||||
from api.store.runs_store import RunsStore
|
||||
|
||||
try:
    from tradingagents.default_config import DEFAULT_CONFIG
except ImportError:
    # Fallback so the API can start without the tradingagents package installed.
    DEFAULT_CONFIG = {"results_dir": "./results"}

router = APIRouter()
# Persist runs to SQLite under the configured results directory.
# NOTE: the old parameterless `_store = RunsStore()` was removed — RunsStore
# now requires a db_path and the stale call would raise TypeError at import.
_db_path = pathlib.Path(DEFAULT_CONFIG["results_dir"]) / "runs.sqlite"
_store = RunsStore(_db_path)
_service = RunService(_store)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from collections import defaultdict
|
|||
from typing import Generator
|
||||
from api.store.runs_store import RunsStore
|
||||
from api.models.run import RunConfig, RunStatus
|
||||
from api.callbacks.token_handler import TokenCallbackHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -24,6 +25,37 @@ class RunService:
|
|||
yield {"event": "run:error", "data": {"message": "Run not found"}}
|
||||
return
|
||||
|
||||
if run.status == RunStatus.COMPLETE:
|
||||
token_usage = {k: v.model_dump() for k, v in (run.token_usage or {}).items()}
|
||||
for key, report in run.reports.items():
|
||||
if ":" not in key:
|
||||
logger.warning(
|
||||
"Skipping malformed report key %r for run %s", key, run_id
|
||||
)
|
||||
continue
|
||||
step_key, turn_str = key.rsplit(":", 1)
|
||||
if not turn_str.isdigit():
|
||||
logger.warning(
|
||||
"Skipping report key with non-numeric turn %r for run %s", key, run_id
|
||||
)
|
||||
continue
|
||||
turn = int(turn_str)
|
||||
raw = token_usage.get(key, {"tokens_in": 0, "tokens_out": 0})
|
||||
yield {"event": "agent:start", "data": {"step": step_key, "turn": turn}}
|
||||
yield {"event": "agent:complete", "data": {
|
||||
"step": step_key, "turn": turn, "report": report,
|
||||
"tokens_in": raw.get("tokens_in", 0),
|
||||
"tokens_out": raw.get("tokens_out", 0),
|
||||
}}
|
||||
yield {"event": "run:complete", "data": {"decision": run.decision or "HOLD", "run_id": run_id}}
|
||||
return
|
||||
|
||||
if run.status == RunStatus.RUNNING:
|
||||
yield {"event": "run:error", "data": {"message": "Run is already in progress"}}
|
||||
return
|
||||
|
||||
self._store.clear_reports(run_id)
|
||||
self._store.clear_token_usage(run_id)
|
||||
self._store.update_status(run_id, RunStatus.RUNNING)
|
||||
config = run.config
|
||||
|
||||
|
|
@ -35,20 +67,31 @@ class RunService:
|
|||
ta_config["max_risk_discuss_rounds"] = config.max_risk_discuss_rounds
|
||||
|
||||
try:
|
||||
token_handler = TokenCallbackHandler()
|
||||
ta = TradingAgentsGraph(
|
||||
debug=False,
|
||||
config=ta_config,
|
||||
selected_analysts=config.enabled_analysts or
|
||||
["market", "news", "fundamentals", "social"],
|
||||
callbacks=[token_handler],
|
||||
)
|
||||
|
||||
turn_counts: defaultdict[str, int] = defaultdict(int)
|
||||
|
||||
for step_key, report in ta.stream_propagate(config.ticker, config.date):
|
||||
tokens = token_handler.snapshot_and_reset()
|
||||
turn = turn_counts[step_key]
|
||||
yield {"event": "agent:start", "data": {"step": step_key, "turn": turn}}
|
||||
yield {"event": "agent:complete", "data": {"step": step_key, "turn": turn, "report": report}}
|
||||
yield {"event": "agent:complete", "data": {
|
||||
"step": step_key, "turn": turn, "report": report,
|
||||
"tokens_in": tokens["in"], "tokens_out": tokens["out"],
|
||||
}}
|
||||
self._store.add_report(run_id, f"{step_key}:{turn}", report)
|
||||
# Normalize to TokenUsage field names before persisting
|
||||
self._store.add_token_usage(
|
||||
run_id, f"{step_key}:{turn}",
|
||||
{"tokens_in": tokens["in"], "tokens_out": tokens["out"]},
|
||||
)
|
||||
turn_counts[step_key] += 1
|
||||
|
||||
decision = ta._last_decision or "HOLD"
|
||||
|
|
|
|||
|
|
@ -1,63 +1,165 @@
|
|||
import json
|
||||
import sqlite3
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from api.models.run import RunConfig, RunResult, RunStatus
|
||||
from typing import Optional, Literal
|
||||
from typing import Literal, Optional
|
||||
|
||||
from api.models.run import RunConfig, RunResult, RunStatus, TokenUsage
|
||||
|
||||
_CREATE_TABLE_SQL = """
|
||||
CREATE TABLE IF NOT EXISTS runs (
|
||||
id TEXT PRIMARY KEY,
|
||||
ticker TEXT NOT NULL,
|
||||
date TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'queued',
|
||||
decision TEXT,
|
||||
created_at TEXT NOT NULL,
|
||||
config TEXT,
|
||||
reports TEXT NOT NULL DEFAULT '{}',
|
||||
error TEXT,
|
||||
token_usage TEXT NOT NULL DEFAULT '{}'
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
class RunsStore:
    """SQLite-backed persistence for runs.

    A single connection is shared across threads (check_same_thread=False);
    every database access is serialized through `self._lock`.
    """

    def __init__(self, db_path: Path) -> None:
        db_path.parent.mkdir(parents=True, exist_ok=True)
        self._conn = sqlite3.connect(str(db_path), check_same_thread=False)
        self._conn.row_factory = sqlite3.Row
        self._conn.execute("PRAGMA journal_mode=WAL")
        self._conn.execute(_CREATE_TABLE_SQL)
        # Migration: add token_usage column if not present (handles existing DBs)
        existing_cols = {
            row["name"]
            for row in self._conn.execute("PRAGMA table_info(runs)")
        }
        if "token_usage" not in existing_cols:
            self._conn.execute(
                "ALTER TABLE runs ADD COLUMN token_usage TEXT NOT NULL DEFAULT '{}'"
            )
        self._conn.commit()
        self._lock = Lock()

    def _row_to_run(self, row: sqlite3.Row) -> RunResult:
        """Deserialize a DB row into a RunResult, decoding the JSON columns."""
        return RunResult(
            id=row["id"],
            ticker=row["ticker"],
            date=row["date"],
            status=RunStatus(row["status"]),
            decision=row["decision"],
            created_at=row["created_at"],
            config=RunConfig(**json.loads(row["config"])) if row["config"] else None,
            reports=json.loads(row["reports"]),
            error=row["error"],
            # "or '{}'" tolerates NULL token_usage from pre-migration rows.
            token_usage={
                k: TokenUsage(**v)
                for k, v in json.loads(row["token_usage"] or "{}").items()
            },
        )

    def create(self, config: RunConfig) -> RunResult:
        """Insert a new QUEUED run and return its in-memory representation."""
        # Short 8-char id derived from a UUID4.
        run_id = str(uuid.uuid4())[:8]
        now = datetime.now(timezone.utc).isoformat()
        with self._lock:
            self._conn.execute(
                "INSERT INTO runs (id, ticker, date, status, created_at, config)"
                " VALUES (?, ?, ?, ?, ?, ?)",
                (run_id, config.ticker, config.date, RunStatus.QUEUED.value,
                 now, config.model_dump_json()),
            )
            self._conn.commit()
        return RunResult(
            id=run_id,
            ticker=config.ticker,
            date=config.date,
            status=RunStatus.QUEUED,
            created_at=now,
            config=config,
        )

    def get(self, run_id: str) -> Optional[RunResult]:
        """Fetch one run by id, or None if it does not exist."""
        with self._lock:
            row = self._conn.execute(
                "SELECT * FROM runs WHERE id = ?", (run_id,)
            ).fetchone()
        return self._row_to_run(row) if row else None

    def list_all(self) -> list[RunResult]:
        """Return all runs, newest first (by created_at)."""
        with self._lock:
            rows = self._conn.execute(
                "SELECT * FROM runs ORDER BY created_at DESC"
            ).fetchall()
        return [self._row_to_run(row) for row in rows]

    def update_status(self, run_id: str, status: RunStatus) -> None:
        """Persist a new lifecycle status for the run."""
        with self._lock:
            self._conn.execute(
                "UPDATE runs SET status = ? WHERE id = ?",
                (status.value, run_id),
            )
            self._conn.commit()

    def update_decision(
        self, run_id: str, decision: Literal["BUY", "SELL", "HOLD"]
    ) -> None:
        """Persist the final trading decision for the run."""
        with self._lock:
            self._conn.execute(
                "UPDATE runs SET decision = ? WHERE id = ?",
                (decision, run_id),
            )
            self._conn.commit()

    def add_report(self, run_id: str, step: str, report: str) -> None:
        """Merge one step report into the run's JSON reports column.

        Read-modify-write is safe because the lock is held for both halves.
        No-op if the run does not exist.
        """
        with self._lock:
            row = self._conn.execute(
                "SELECT reports FROM runs WHERE id = ?", (run_id,)
            ).fetchone()
            if row:
                reports = json.loads(row[0])
                reports[step] = report
                self._conn.execute(
                    "UPDATE runs SET reports = ? WHERE id = ?",
                    (json.dumps(reports), run_id),
                )
                self._conn.commit()

    def set_error(self, run_id: str, error: str) -> None:
        """Mark the run as ERROR and record the error message."""
        with self._lock:
            self._conn.execute(
                "UPDATE runs SET status = ?, error = ? WHERE id = ?",
                (RunStatus.ERROR.value, error, run_id),
            )
            self._conn.commit()

    def clear_reports(self, run_id: str) -> None:
        """Reset the run's reports to an empty JSON object (used on retry)."""
        with self._lock:
            self._conn.execute(
                "UPDATE runs SET reports = '{}' WHERE id = ?",
                (run_id,),
            )
            self._conn.commit()

    def add_token_usage(self, run_id: str, key: str, tokens: dict) -> None:
        """Merge one "step:turn" token record into the token_usage column.

        `tokens` is expected to carry TokenUsage field names
        ({"tokens_in": N, "tokens_out": M}). No-op if the run is missing.
        """
        with self._lock:
            row = self._conn.execute(
                "SELECT token_usage FROM runs WHERE id = ?", (run_id,)
            ).fetchone()
            if row:
                # "or '{}'" tolerates NULL from pre-migration rows.
                usage = json.loads(row[0] or "{}")
                usage[key] = tokens
                self._conn.execute(
                    "UPDATE runs SET token_usage = ? WHERE id = ?",
                    (json.dumps(usage), run_id),
                )
                self._conn.commit()

    def clear_token_usage(self, run_id: str) -> None:
        """Reset the run's token_usage to an empty JSON object (used on retry)."""
        with self._lock:
            self._conn.execute(
                "UPDATE runs SET token_usage = '{}' WHERE id = ?",
                (run_id,),
            )
            self._conn.commit()
|
||||
|
|
|
|||
|
|
@ -7,8 +7,8 @@ from api.models.run import RunConfig, RunStatus
|
|||
|
||||
|
||||
@pytest.fixture
def store(tmp_path):
    """Isolated on-disk SQLite store per test (stale parameterless
    `RunsStore()` fixture removed — the store now requires a db_path)."""
    return RunsStore(tmp_path / "test.sqlite")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -116,3 +116,124 @@ def test_selected_analysts_passed_to_graph(service, store):
|
|||
|
||||
call_kwargs = MockGraph.call_args.kwargs
|
||||
assert call_kwargs.get("selected_analysts") == ["market", "news"]
|
||||
|
||||
|
||||
def test_completed_run_replays_without_re_running_graph(service, store):
    """Replaying a COMPLETE run streams stored reports without rebuilding the graph."""
    # Pre-populate a completed run in the store
    run = store.create(RunConfig(ticker="NVDA", date="2026-03-23"))
    store.update_status(run.id, RunStatus.RUNNING)
    store.add_report(run.id, "market_analyst:0", "bullish")
    store.update_status(run.id, RunStatus.COMPLETE)
    store.update_decision(run.id, "BUY")

    with patch("api.services.run_service.TradingAgentsGraph") as MockGraph:
        events = list(service.stream_events(run.id))

    assert MockGraph.call_count == 0  # no agent execution on replay
    agent_completes = [e for e in events if e["event"] == "agent:complete"]
    assert len(agent_completes) == 1
    assert agent_completes[0]["data"]["report"] == "bullish"
    run_complete = next(e for e in events if e["event"] == "run:complete")
    assert run_complete["data"]["decision"] == "BUY"


def test_running_run_returns_in_progress_error(service, store):
    """Streaming a RUNNING run yields a single run:error and no graph execution."""
    run = store.create(RunConfig(ticker="NVDA", date="2026-03-23"))
    store.update_status(run.id, RunStatus.RUNNING)

    with patch("api.services.run_service.TradingAgentsGraph") as MockGraph:
        events = list(service.stream_events(run.id))

    assert MockGraph.call_count == 0  # no agent execution
    assert len(events) == 1
    assert events[0]["event"] == "run:error"
    assert "already in progress" in events[0]["data"]["message"]


def test_error_run_retries_and_clears_reports(service, store):
    """An ERROR run is re-executed on stream and stale reports are discarded."""
    # Simulate a run that failed partway through with a stale report
    run = store.create(RunConfig(ticker="NVDA", date="2026-03-23"))
    store.update_status(run.id, RunStatus.RUNNING)
    store.add_report(run.id, "market_analyst:0", "stale data")
    store.set_error(run.id, "timeout")

    assert store.get(run.id).status == RunStatus.ERROR
    assert store.get(run.id).reports == {"market_analyst:0": "stale data"}

    with patch("api.services.run_service.TradingAgentsGraph") as MockGraph:
        MockGraph.return_value = _mock_graph([("news_analyst", "fresh")], decision="HOLD")
        events = list(service.stream_events(run.id))

    assert MockGraph.call_count == 1  # graph was executed on retry
    final_reports = store.get(run.id).reports
    assert "news_analyst:0" in final_reports
    assert "market_analyst:0" not in final_reports  # stale data cleared on retry


def test_live_run_agent_complete_includes_token_fields(service, store):
    """agent:complete events must include tokens_in and tokens_out."""
    config = RunConfig(ticker="NVDA", date="2026-03-23")
    run = store.create(config)
    with patch("api.services.run_service.TradingAgentsGraph") as MockGraph:
        MockGraph.return_value = _mock_graph([("market_analyst", "bullish")])
        events = list(service.stream_events(run.id))

    complete_events = [e for e in events if e["event"] == "agent:complete"]
    assert len(complete_events) == 1
    data = complete_events[0]["data"]
    assert "tokens_in" in data
    assert "tokens_out" in data
    assert isinstance(data["tokens_in"], int)
    assert isinstance(data["tokens_out"], int)


def test_retry_clears_stale_token_usage(service, store):
    """After an errored run is retried, stale token keys are absent."""
    run = store.create(RunConfig(ticker="NVDA", date="2026-03-23"))
    store.update_status(run.id, RunStatus.RUNNING)
    store.add_token_usage(run.id, "market_analyst:0", {"tokens_in": 500, "tokens_out": 200})
    store.set_error(run.id, "timeout")

    with patch("api.services.run_service.TradingAgentsGraph") as MockGraph:
        MockGraph.return_value = _mock_graph([("news_analyst", "fresh")])
        list(service.stream_events(run.id))

    final = store.get(run.id)
    assert "market_analyst:0" not in final.token_usage  # stale key gone
    assert "news_analyst:0" in final.token_usage  # new key present


def test_replay_attaches_token_data_to_agent_complete(service, store):
    """Replaying a completed run emits agent:complete with token data."""
    run = store.create(RunConfig(ticker="NVDA", date="2026-03-23"))
    store.update_status(run.id, RunStatus.RUNNING)
    store.add_report(run.id, "market_analyst:0", "bullish")
    store.add_token_usage(run.id, "market_analyst:0", {"tokens_in": 1200, "tokens_out": 400})
    store.update_status(run.id, RunStatus.COMPLETE)
    store.update_decision(run.id, "BUY")

    with patch("api.services.run_service.TradingAgentsGraph") as MockGraph:
        events = list(service.stream_events(run.id))

    assert MockGraph.call_count == 0
    complete_events = [e for e in events if e["event"] == "agent:complete"]
    assert len(complete_events) == 1
    assert complete_events[0]["data"]["tokens_in"] == 1200
    assert complete_events[0]["data"]["tokens_out"] == 400


def test_replay_defaults_to_zero_tokens_when_missing(service, store):
    """Replay of a run with no token_usage emits tokens_in=0, tokens_out=0."""
    run = store.create(RunConfig(ticker="NVDA", date="2026-03-23"))
    store.update_status(run.id, RunStatus.RUNNING)
    store.add_report(run.id, "market_analyst:0", "bearish")
    store.update_status(run.id, RunStatus.COMPLETE)
    store.update_decision(run.id, "SELL")
    # No add_token_usage call — simulates old run with no token data

    with patch("api.services.run_service.TradingAgentsGraph"):
        events = list(service.stream_events(run.id))

    complete = next(e for e in events if e["event"] == "agent:complete")
    assert complete["data"]["tokens_in"] == 0
    assert complete["data"]["tokens_out"] == 0
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
import pytest
|
||||
import sqlite3 as _sqlite3
|
||||
from api.store.runs_store import RunsStore
|
||||
from api.models.run import RunConfig, RunStatus
|
||||
|
||||
|
||||
def test_create_and_get_run():
|
||||
store = RunsStore()
|
||||
def test_create_and_get_run(tmp_path):
|
||||
store = RunsStore(tmp_path / "test.sqlite")
|
||||
config = RunConfig(ticker="NVDA", date="2024-05-10")
|
||||
run = store.create(config)
|
||||
assert run.id is not None
|
||||
|
|
@ -13,16 +14,108 @@ def test_create_and_get_run():
|
|||
assert fetched.ticker == "NVDA"
|
||||
|
||||
|
||||
def test_list_runs(tmp_path):
    """list_all returns every run persisted in the store."""
    store = RunsStore(tmp_path / "test.sqlite")
    store.create(RunConfig(ticker="NVDA", date="2024-05-10"))
    store.create(RunConfig(ticker="AAPL", date="2024-05-09"))
    runs = store.list_all()
    assert len(runs) == 2


def test_update_run_status(tmp_path):
    """update_status persists the new lifecycle status."""
    store = RunsStore(tmp_path / "test.sqlite")
    run = store.create(RunConfig(ticker="NVDA", date="2024-05-10"))
    store.update_status(run.id, RunStatus.RUNNING)
    assert store.get(run.id).status == RunStatus.RUNNING


def test_add_report(tmp_path):
    """add_report stores a report under its 'step:turn' key."""
    store = RunsStore(tmp_path / "test.sqlite")
    run = store.create(RunConfig(ticker="NVDA", date="2024-05-10"))
    store.add_report(run.id, "market_analyst:0", "bullish")
    assert store.get(run.id).reports == {"market_analyst:0": "bullish"}


def test_set_error(tmp_path):
    """set_error flips the status to ERROR and records the message."""
    store = RunsStore(tmp_path / "test.sqlite")
    run = store.create(RunConfig(ticker="NVDA", date="2024-05-10"))
    store.set_error(run.id, "timeout")
    fetched = store.get(run.id)
    assert fetched.status == RunStatus.ERROR
    assert fetched.error == "timeout"


def test_clear_reports(tmp_path):
    """clear_reports empties the reports mapping."""
    store = RunsStore(tmp_path / "test.sqlite")
    run = store.create(RunConfig(ticker="NVDA", date="2024-05-10"))
    store.add_report(run.id, "market_analyst:0", "bullish")
    store.clear_reports(run.id)
    assert store.get(run.id).reports == {}


def test_update_decision(tmp_path):
    """update_decision persists the final BUY/SELL/HOLD verdict."""
    store = RunsStore(tmp_path / "test.sqlite")
    run = store.create(RunConfig(ticker="NVDA", date="2024-05-10"))
    store.update_decision(run.id, "BUY")
    assert store.get(run.id).decision == "BUY"


def test_migration_adds_token_usage_column_to_existing_db(tmp_path):
    """DB created without token_usage gets the column added on RunsStore.__init__."""
    db_path = tmp_path / "old.sqlite"
    # Create a DB without the token_usage column (simulate pre-migration state)
    conn = _sqlite3.connect(str(db_path))
    conn.execute("""
        CREATE TABLE runs (
            id TEXT PRIMARY KEY, ticker TEXT NOT NULL, date TEXT NOT NULL,
            status TEXT NOT NULL DEFAULT 'queued', decision TEXT,
            created_at TEXT NOT NULL, config TEXT,
            reports TEXT NOT NULL DEFAULT '{}', error TEXT
        )
    """)
    conn.commit()
    conn.close()

    # Initialising the store should migrate the column
    store = RunsStore(db_path)
    cols = {
        row["name"]
        for row in store._conn.execute("PRAGMA table_info(runs)")
    }
    assert "token_usage" in cols


def test_migration_is_idempotent(tmp_path):
    """Re-initialising the store after migration does not crash."""
    db_path = tmp_path / "test.sqlite"
    RunsStore(db_path)  # first init — creates table + column
    RunsStore(db_path)  # second init — column already present, should not crash


def test_add_token_usage(tmp_path):
    """add_token_usage round-trips through JSON into TokenUsage models."""
    store = RunsStore(tmp_path / "test.sqlite")
    run = store.create(RunConfig(ticker="NVDA", date="2026-03-23"))
    store.add_token_usage(run.id, "market_analyst:0", {"tokens_in": 1200, "tokens_out": 400})
    result = store.get(run.id)
    assert "market_analyst:0" in result.token_usage
    assert result.token_usage["market_analyst:0"].tokens_in == 1200
    assert result.token_usage["market_analyst:0"].tokens_out == 400


def test_clear_token_usage(tmp_path):
    """clear_token_usage empties the token_usage mapping."""
    store = RunsStore(tmp_path / "test.sqlite")
    run = store.create(RunConfig(ticker="NVDA", date="2026-03-23"))
    store.add_token_usage(run.id, "market_analyst:0", {"tokens_in": 1200, "tokens_out": 400})
    store.clear_token_usage(run.id)
    assert store.get(run.id).token_usage == {}


def test_clear_token_usage_does_not_affect_reports(tmp_path):
    """Clearing token usage must leave stored reports untouched."""
    store = RunsStore(tmp_path / "test.sqlite")
    run = store.create(RunConfig(ticker="NVDA", date="2026-03-23"))
    store.add_report(run.id, "market_analyst:0", "bullish")
    store.add_token_usage(run.id, "market_analyst:0", {"tokens_in": 100, "tokens_out": 50})
    store.clear_token_usage(run.id)
    fetched = store.get(run.id)
    assert fetched.reports == {"market_analyst:0": "bullish"}
    assert fetched.token_usage == {}
||||
|
|
|
|||
|
|
@ -0,0 +1,74 @@
|
|||
import threading
|
||||
from unittest.mock import MagicMock
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.outputs import LLMResult, ChatGeneration
|
||||
from api.callbacks.token_handler import TokenCallbackHandler
|
||||
|
||||
|
||||
def _make_response(input_tokens: int, output_tokens: int) -> LLMResult:
    """Build a minimal LLMResult that on_llm_end will parse."""
    msg = AIMessage(content="ok")
    msg.usage_metadata = {"input_tokens": input_tokens, "output_tokens": output_tokens}
    gen = ChatGeneration(message=msg)
    return LLMResult(generations=[[gen]])


def test_snapshot_and_reset_returns_delta():
    """snapshot_and_reset reports tokens accumulated since the last reset."""
    handler = TokenCallbackHandler()
    handler.on_llm_end(_make_response(100, 40))
    result = handler.snapshot_and_reset()
    assert result == {"in": 100, "out": 40}


def test_snapshot_and_reset_zeroes_counters():
    """A second snapshot with no intervening LLM calls reads all zeros."""
    handler = TokenCallbackHandler()
    handler.on_llm_end(_make_response(100, 40))
    handler.snapshot_and_reset()
    second = handler.snapshot_and_reset()
    assert second == {"in": 0, "out": 0}


def test_multiple_llm_calls_accumulate():
    """Usage from successive on_llm_end calls is summed."""
    handler = TokenCallbackHandler()
    handler.on_llm_end(_make_response(100, 40))
    handler.on_llm_end(_make_response(200, 60))
    result = handler.snapshot_and_reset()
    assert result == {"in": 300, "out": 100}


def test_concurrent_on_llm_end_does_not_corrupt():
    """20 threads x (10, 5) tokens must total exactly (200, 100)."""
    handler = TokenCallbackHandler()
    threads = [
        threading.Thread(target=handler.on_llm_end, args=(_make_response(10, 5),))
        for _ in range(20)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    result = handler.snapshot_and_reset()
    assert result == {"in": 200, "out": 100}


def test_missing_usage_metadata_does_not_crash():
    """A message without usage_metadata is silently ignored."""
    handler = TokenCallbackHandler()
    msg = AIMessage(content="no metadata")
    # No usage_metadata attribute
    gen = ChatGeneration(message=msg)
    response = LLMResult(generations=[[gen]])
    handler.on_llm_end(response)  # should not raise
    assert handler.snapshot_and_reset() == {"in": 0, "out": 0}


def test_empty_outer_generations_does_not_crash():
    """An empty generations list exercises the IndexError guard."""
    handler = TokenCallbackHandler()
    response = LLMResult(generations=[])
    handler.on_llm_end(response)  # IndexError guard
    assert handler.snapshot_and_reset() == {"in": 0, "out": 0}


def test_empty_inner_generations_does_not_crash():
    """An empty inner generations list exercises the IndexError guard."""
    handler = TokenCallbackHandler()
    response = LLMResult(generations=[[]])
    handler.on_llm_end(response)  # IndexError guard
    assert handler.snapshot_and_reset() == {"in": 0, "out": 0}
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
import { render, screen } from '@testing-library/react'
|
||||
import TokenStatsBar from '@/features/run-detail/components/TokenStatsBar'
|
||||
|
||||
// TokenStatsBar hides itself entirely when there is nothing to show.
test('renders nothing when both token counts are zero', () => {
  const { container } = render(
    <TokenStatsBar tokensTotal={{ in: 0, out: 0 }} status="running" />
  )
  expect(container.firstChild).toBeNull()
})

test('renders total, input, output when tokens are non-zero', () => {
  render(
    <TokenStatsBar tokensTotal={{ in: 3100, out: 1100 }} status="running" />
  )
  // Total = 3100 + 1100 = 4200 → "4.2k"
  expect(screen.getByText('4.2k')).toBeInTheDocument()
  // Input → "3.1k"
  expect(screen.getByText('3.1k')).toBeInTheDocument()
  // Output → "1.1k"
  expect(screen.getByText('1.1k')).toBeInTheDocument()
})

test('formats values below 1000 without k suffix', () => {
  render(
    <TokenStatsBar tokensTotal={{ in: 800, out: 150 }} status="complete" />
  )
  expect(screen.getByText('950')).toBeInTheDocument() // total
  expect(screen.getByText('800')).toBeInTheDocument() // input
  expect(screen.getByText('150')).toBeInTheDocument() // output
})
|
||||
|
|
@ -16,8 +16,21 @@ jest.mock('@/lib/sse', () => ({
|
|||
}),
|
||||
}))
|
||||
|
||||
// getRun defaults to 'queued' so existing SSE-path tests still exercise the SSE branch.
// Tests that need a different status use mockResolvedValueOnce to override.
jest.mock('@/lib/api-client', () => ({
  getRunStreamUrl: (id: string) => `/api/runs/${id}/stream`,
  getRun: jest.fn().mockResolvedValue({
    id: 'abc',
    ticker: 'NVDA',
    date: '2026-03-23',
    status: 'queued',
    decision: null,
    created_at: '2026-03-23T00:00:00Z',
    config: null,
    reports: {},
    error: null,
  }),
}))
|
||||
|
||||
test('appends multiple turns for same step', async () => {
|
||||
|
|
@ -46,3 +59,118 @@ test('initial reports are empty arrays', () => {
|
|||
const { result } = renderHook(() => useRunStream('abc'))
|
||||
expect(result.current.reports['market_analyst']).toEqual([])
|
||||
})
|
||||
|
||||
test('hydrates from reports when run is complete, skipping SSE', async () => {
  const { getRun } = jest.requireMock('@/lib/api-client')
  const { createSSEConnection } = jest.requireMock('@/lib/sse')

  // Reset call history so we can assert createSSEConnection was NOT called for this run
  jest.clearAllMocks()

  getRun.mockResolvedValueOnce({
    id: 'xyz',
    ticker: 'AAPL',
    date: '2026-03-23',
    status: 'complete',
    decision: 'SELL',
    created_at: '2026-03-23T00:00:00Z',
    config: null,
    reports: { 'market_analyst:0': 'bearish signal' },
    error: null,
  })

  const { result } = renderHook(() => useRunStream('xyz'))
  await act(async () => { await new Promise((r) => setTimeout(r, 10)) })

  expect(result.current.status).toBe('complete')
  expect(result.current.verdict).toBe('SELL')
  expect(result.current.reports['market_analyst']).toEqual(['bearish signal'])
  expect(createSSEConnection).not.toHaveBeenCalled() // SSE skipped for completed run
})

test('AGENT_COMPLETE accumulates tokensByStep and tokensTotal', async () => {
  const { getRun } = jest.requireMock('@/lib/api-client')
  const { createSSEConnection } = jest.requireMock('@/lib/sse')
  jest.clearAllMocks()

  // getRun returns queued so SSE path runs
  getRun.mockResolvedValueOnce({
    id: 'abc', ticker: 'NVDA', date: '2026-03-23', status: 'queued',
    decision: null, created_at: '2026-03-23T00:00:00Z',
    config: null, reports: {}, error: null, token_usage: null,
  })

  // SSE mock emits one agent:complete with token data
  createSSEConnection.mockImplementationOnce(
    (_url: string, handlers: Record<string, (d: unknown) => void>) => {
      setTimeout(() => {
        handlers.onAgentStart?.({ step: 'market_analyst', turn: 0 })
        handlers.onAgentComplete?.({
          step: 'market_analyst', turn: 0, report: 'bullish',
          tokens_in: 1200, tokens_out: 400,
        })
        handlers.onRunComplete?.({ decision: 'BUY', run_id: 'abc' })
      }, 0)
      return jest.fn()
    }
  )

  const { result } = renderHook(() => useRunStream('abc'))
  await act(async () => { await new Promise((r) => setTimeout(r, 10)) })

  expect(result.current.tokensByStep['market_analyst']).toEqual({ in: 1200, out: 400 })
  expect(result.current.tokensTotal).toEqual({ in: 1200, out: 400 })
})

test('missing tokens_in/out in AGENT_COMPLETE defaults to 0', async () => {
  const { getRun } = jest.requireMock('@/lib/api-client')
  const { createSSEConnection } = jest.requireMock('@/lib/sse')
  jest.clearAllMocks()

  getRun.mockResolvedValueOnce({
    id: 'abc', ticker: 'NVDA', date: '2026-03-23', status: 'queued',
    decision: null, created_at: '2026-03-23T00:00:00Z',
    config: null, reports: {}, error: null, token_usage: null,
  })

  createSSEConnection.mockImplementationOnce(
    (_url: string, handlers: Record<string, (d: unknown) => void>) => {
      setTimeout(() => {
        handlers.onAgentStart?.({ step: 'news_analyst', turn: 0 })
        // No tokens_in/tokens_out in payload
        handlers.onAgentComplete?.({ step: 'news_analyst', turn: 0, report: 'ok' })
        handlers.onRunComplete?.({ decision: 'HOLD', run_id: 'abc' })
      }, 0)
      return jest.fn()
    }
  )

  const { result } = renderHook(() => useRunStream('abc'))
  await act(async () => { await new Promise((r) => setTimeout(r, 10)) })

  expect(result.current.tokensByStep['news_analyst']).toEqual({ in: 0, out: 0 })
  expect(result.current.tokensTotal).toEqual({ in: 0, out: 0 })
})
|
||||
|
||||
test('completed-run hydration populates tokens from getRun().token_usage without SSE', async () => {
|
||||
const { getRun } = jest.requireMock('@/lib/api-client')
|
||||
const { createSSEConnection } = jest.requireMock('@/lib/sse')
|
||||
jest.clearAllMocks()
|
||||
|
||||
getRun.mockResolvedValueOnce({
|
||||
id: 'tok', ticker: 'AAPL', date: '2026-03-23', status: 'complete',
|
||||
decision: 'BUY', created_at: '2026-03-23T00:00:00Z',
|
||||
config: null,
|
||||
reports: { 'market_analyst:0': 'bullish' },
|
||||
error: null,
|
||||
token_usage: { 'market_analyst:0': { tokens_in: 1200, tokens_out: 400 } },
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useRunStream('tok'))
|
||||
await act(async () => { await new Promise((r) => setTimeout(r, 10)) })
|
||||
|
||||
expect(createSSEConnection).not.toHaveBeenCalled()
|
||||
expect(result.current.status).toBe('complete')
|
||||
expect(result.current.tokensByStep['market_analyst']).toEqual({ in: 1200, out: 400 })
|
||||
expect(result.current.tokensTotal).toEqual({ in: 1200, out: 400 })
|
||||
})
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import VerdictBanner from '@/features/run-detail/components/VerdictBanner'
|
|||
import PhaseTabs from '@/features/run-detail/components/PhaseTabs'
|
||||
import { getRun } from '@/lib/api-client'
|
||||
import type { RunSummary } from '@/lib/types/run'
|
||||
import TokenStatsBar from '@/features/run-detail/components/TokenStatsBar'
|
||||
|
||||
const STATUS_CONFIG: Record<string, {
|
||||
bg: string; color: string; dot: string; label: string; pulse: boolean
|
||||
|
|
@ -18,7 +19,7 @@ const STATUS_CONFIG: Record<string, {
|
|||
|
||||
export default function RunDetailPage({ params }: { params: Promise<{ id: string }> }) {
|
||||
const { id } = use(params)
|
||||
const { steps, reports, verdict, status, error } = useRunStream(id)
|
||||
const { steps, reports, verdict, status, error, tokensTotal, tokensByStep } = useRunStream(id)
|
||||
const [run, setRun] = useState<RunSummary | null>(null)
|
||||
|
||||
useEffect(() => {
|
||||
|
|
@ -88,6 +89,9 @@ export default function RunDetailPage({ params }: { params: Promise<{ id: string
|
|||
</div>
|
||||
</div>
|
||||
|
||||
{/* Token stats bar */}
|
||||
<TokenStatsBar tokensTotal={tokensTotal} status={status} />
|
||||
|
||||
{/* Pipeline */}
|
||||
<PipelineStepper steps={steps} />
|
||||
|
||||
|
|
@ -111,7 +115,7 @@ export default function RunDetailPage({ params }: { params: Promise<{ id: string
|
|||
)}
|
||||
|
||||
{/* Phase tabs + reports */}
|
||||
<PhaseTabs steps={steps} reports={reports} />
|
||||
<PhaseTabs steps={steps} reports={reports} tokensByStep={tokensByStep} />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,6 +2,10 @@ import { AGENT_STEPS, AGENT_STEP_LABELS, STEP_PHASE } from '@/lib/types/run'
|
|||
import type { AgentStep } from '@/lib/types/run'
|
||||
import type { StepStatus } from '@/lib/types/agents'
|
||||
|
||||
function formatTokens(n: number): string {
|
||||
return n >= 1000 ? `${(n / 1000).toFixed(1)}k` : String(n)
|
||||
}
|
||||
|
||||
type Phase = 'analysts' | 'researchers' | 'trader' | 'risk'
|
||||
|
||||
const MULTI_TURN_STEPS = new Set<AgentStep>([
|
||||
|
|
@ -43,9 +47,10 @@ type Props = {
|
|||
phase: Phase
|
||||
steps: Record<AgentStep, StepStatus>
|
||||
reports: Record<AgentStep, string[]>
|
||||
tokensByStep: Record<AgentStep, { in: number; out: number }>
|
||||
}
|
||||
|
||||
export default function AnalystReports({ phase, steps, reports }: Props) {
|
||||
export default function AnalystReports({ phase, steps, reports, tokensByStep }: Props) {
|
||||
const phaseSteps = AGENT_STEPS.filter((s) => STEP_PHASE[s] === phase)
|
||||
|
||||
return (
|
||||
|
|
@ -112,6 +117,22 @@ export default function AnalystReports({ phase, steps, reports }: Props) {
|
|||
T{i + 1}
|
||||
</span>
|
||||
)}
|
||||
{(() => {
|
||||
const tok = tokensByStep[step]
|
||||
return tok && (tok.in > 0 || tok.out > 0) ? (
|
||||
<span
|
||||
style={{
|
||||
fontFamily: 'var(--font-mono)',
|
||||
fontSize: '9px',
|
||||
color: 'var(--text-mid)',
|
||||
letterSpacing: '0.02em',
|
||||
}}
|
||||
>
|
||||
<span style={{ color: 'var(--accent)' }}>{formatTokens(tok.in)}</span>↑{' '}
|
||||
{formatTokens(tok.out)}↓
|
||||
</span>
|
||||
) : null
|
||||
})()}
|
||||
</div>
|
||||
|
||||
<div
|
||||
|
|
|
|||
|
|
@ -15,8 +15,9 @@ const TABS: { label: string; phase: Phase; count: string }[] = [
|
|||
]
|
||||
|
||||
type Props = {
|
||||
steps: Record<AgentStep, StepStatus>
|
||||
reports: Record<AgentStep, string[]>
|
||||
steps: Record<AgentStep, StepStatus>
|
||||
reports: Record<AgentStep, string[]>
|
||||
tokensByStep: Record<AgentStep, { in: number; out: number }>
|
||||
}
|
||||
|
||||
function getPhaseCompletion(phase: Phase, steps: Record<AgentStep, StepStatus>): number {
|
||||
|
|
@ -32,7 +33,7 @@ function getPhaseStatus(phase: Phase, steps: Record<AgentStep, StepStatus>): 'do
|
|||
return 'pending'
|
||||
}
|
||||
|
||||
export default function PhaseTabs({ steps, reports }: Props) {
|
||||
export default function PhaseTabs({ steps, reports, tokensByStep }: Props) {
|
||||
const [active, setActive] = useState<Phase>('analysts')
|
||||
|
||||
return (
|
||||
|
|
@ -118,7 +119,7 @@ export default function PhaseTabs({ steps, reports }: Props) {
|
|||
</div>
|
||||
|
||||
{/* Reports */}
|
||||
<AnalystReports phase={active} steps={steps} reports={reports} />
|
||||
<AnalystReports phase={active} steps={steps} reports={reports} tokensByStep={tokensByStep} />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,73 @@
|
|||
import type { TokenCount } from '../types'
|
||||
|
||||
function formatTokens(n: number): string {
|
||||
return n >= 1000 ? `${(n / 1000).toFixed(1)}k` : String(n)
|
||||
}
|
||||
|
||||
type Props = {
|
||||
tokensTotal: TokenCount
|
||||
status: string
|
||||
}
|
||||
|
||||
export default function TokenStatsBar({ tokensTotal }: Props) {
|
||||
if (tokensTotal.in === 0 && tokensTotal.out === 0) return null
|
||||
|
||||
const total = tokensTotal.in + tokensTotal.out
|
||||
|
||||
return (
|
||||
<div
|
||||
className="flex items-center gap-4 px-4 py-2.5 rounded-xl"
|
||||
style={{
|
||||
background: 'var(--bg-elevated)',
|
||||
border: '1px solid var(--border-raised)',
|
||||
}}
|
||||
>
|
||||
<div className="flex flex-col gap-0.5">
|
||||
<span
|
||||
className="text-[8px] uppercase tracking-widest"
|
||||
style={{ color: 'var(--text-mid)', fontFamily: 'var(--font-mono)' }}
|
||||
>
|
||||
Total
|
||||
</span>
|
||||
<span
|
||||
className="text-[11px] font-bold"
|
||||
style={{ color: 'var(--text-high)', fontFamily: 'var(--font-mono)' }}
|
||||
>
|
||||
{formatTokens(total)}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div className="w-px h-6 shrink-0" style={{ background: 'var(--text-low)', opacity: 0.35 }} />
|
||||
|
||||
<div className="flex flex-col gap-0.5">
|
||||
<span
|
||||
className="text-[8px] uppercase tracking-widest"
|
||||
style={{ color: 'var(--text-mid)', fontFamily: 'var(--font-mono)' }}
|
||||
>
|
||||
Input ↑
|
||||
</span>
|
||||
<span
|
||||
className="text-[11px] font-bold"
|
||||
style={{ color: 'var(--accent-light)', fontFamily: 'var(--font-mono)' }}
|
||||
>
|
||||
{formatTokens(tokensTotal.in)}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div className="flex flex-col gap-0.5">
|
||||
<span
|
||||
className="text-[8px] uppercase tracking-widest"
|
||||
style={{ color: 'var(--text-mid)', fontFamily: 'var(--font-mono)' }}
|
||||
>
|
||||
Output ↓
|
||||
</span>
|
||||
<span
|
||||
className="text-[11px] font-bold"
|
||||
style={{ color: 'var(--text-mid)', fontFamily: 'var(--font-mono)' }}
|
||||
>
|
||||
{formatTokens(tokensTotal.out)}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
|
@ -1,24 +1,28 @@
|
|||
'use client'
|
||||
import { useEffect, useReducer } from 'react'
|
||||
import { createSSEConnection } from '@/lib/sse'
|
||||
import { getRunStreamUrl } from '@/lib/api-client'
|
||||
import { getRun, getRunStreamUrl } from '@/lib/api-client'
|
||||
import { AGENT_STEPS } from '@/lib/types/run'
|
||||
import type { AgentStep } from '@/lib/types/run'
|
||||
import type { RunStreamState } from '../types'
|
||||
import type { RunStreamState, TokenCount } from '../types'
|
||||
|
||||
const zeroTokens = (): TokenCount => ({ in: 0, out: 0 })
|
||||
|
||||
const initialState: RunStreamState = {
|
||||
status: 'connecting',
|
||||
steps: Object.fromEntries(AGENT_STEPS.map((s) => [s, 'pending'])) as RunStreamState['steps'],
|
||||
reports: Object.fromEntries(AGENT_STEPS.map((s) => [s, []])) as RunStreamState['reports'],
|
||||
steps: Object.fromEntries(AGENT_STEPS.map((s) => [s, 'pending'])) as RunStreamState['steps'],
|
||||
reports: Object.fromEntries(AGENT_STEPS.map((s) => [s, []])) as RunStreamState['reports'],
|
||||
tokensByStep: Object.fromEntries(AGENT_STEPS.map((s) => [s, zeroTokens()])) as RunStreamState['tokensByStep'],
|
||||
tokensTotal: zeroTokens(),
|
||||
verdict: null,
|
||||
error: null,
|
||||
}
|
||||
|
||||
type Action =
|
||||
| { type: 'AGENT_START'; step: AgentStep; turn: number }
|
||||
| { type: 'AGENT_COMPLETE'; step: AgentStep; turn: number; report: string }
|
||||
| { type: 'RUN_COMPLETE'; decision: string }
|
||||
| { type: 'RUN_ERROR'; message: string }
|
||||
| { type: 'AGENT_START'; step: AgentStep; turn: number }
|
||||
| { type: 'AGENT_COMPLETE'; step: AgentStep; turn: number; report: string; tokens_in?: number; tokens_out?: number }
|
||||
| { type: 'RUN_COMPLETE'; decision: string }
|
||||
| { type: 'RUN_ERROR'; message: string }
|
||||
| { type: 'CONNECTED' }
|
||||
|
||||
function reducer(state: RunStreamState, action: Action): RunStreamState {
|
||||
|
|
@ -27,19 +31,30 @@ function reducer(state: RunStreamState, action: Action): RunStreamState {
|
|||
return { ...state, status: 'running' }
|
||||
|
||||
case 'AGENT_START':
|
||||
// Only transition to 'running' on first turn (don't regress from 'done')
|
||||
if (state.steps[action.step] !== 'pending') return state
|
||||
return { ...state, steps: { ...state.steps, [action.step]: 'running' } }
|
||||
|
||||
case 'AGENT_COMPLETE':
|
||||
case 'AGENT_COMPLETE': {
|
||||
const dIn = action.tokens_in ?? 0
|
||||
const dOut = action.tokens_out ?? 0
|
||||
const prev = state.tokensByStep[action.step] ?? zeroTokens()
|
||||
return {
|
||||
...state,
|
||||
steps: { ...state.steps, [action.step]: 'done' },
|
||||
steps: { ...state.steps, [action.step]: 'done' },
|
||||
reports: {
|
||||
...state.reports,
|
||||
[action.step]: [...(state.reports[action.step] ?? []), action.report],
|
||||
},
|
||||
tokensByStep: {
|
||||
...state.tokensByStep,
|
||||
[action.step]: { in: prev.in + dIn, out: prev.out + dOut },
|
||||
},
|
||||
tokensTotal: {
|
||||
in: state.tokensTotal.in + dIn,
|
||||
out: state.tokensTotal.out + dOut,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
case 'RUN_COMPLETE':
|
||||
return { ...state, status: 'complete', verdict: action.decision as RunStreamState['verdict'] }
|
||||
|
|
@ -56,17 +71,43 @@ export function useRunStream(runId: string): RunStreamState {
|
|||
const [state, dispatch] = useReducer(reducer, initialState)
|
||||
|
||||
useEffect(() => {
|
||||
const url = getRunStreamUrl(runId)
|
||||
const close = createSSEConnection(url, {
|
||||
onOpen: () => dispatch({ type: 'CONNECTED' }),
|
||||
onAgentStart: ({ step, turn }) =>
|
||||
dispatch({ type: 'AGENT_START', step: step as AgentStep, turn }),
|
||||
onAgentComplete: ({ step, turn, report }) =>
|
||||
dispatch({ type: 'AGENT_COMPLETE', step: step as AgentStep, turn, report }),
|
||||
onRunComplete: ({ decision }) => dispatch({ type: 'RUN_COMPLETE', decision }),
|
||||
onRunError: ({ message }) => dispatch({ type: 'RUN_ERROR', message }),
|
||||
let close: (() => void) | undefined
|
||||
let aborted = false
|
||||
|
||||
getRun(runId).then((run) => {
|
||||
if (aborted) return
|
||||
|
||||
if (run.status === 'complete' && run.reports) {
|
||||
dispatch({ type: 'CONNECTED' })
|
||||
for (const [key, report] of Object.entries(run.reports)) {
|
||||
const lastColon = key.lastIndexOf(':')
|
||||
const step = key.slice(0, lastColon) as AgentStep
|
||||
const turn = parseInt(key.slice(lastColon + 1), 10)
|
||||
const tok = run.token_usage?.[key] ?? { tokens_in: 0, tokens_out: 0 }
|
||||
dispatch({ type: 'AGENT_START', step, turn })
|
||||
dispatch({ type: 'AGENT_COMPLETE', step, turn, report,
|
||||
tokens_in: tok.tokens_in, tokens_out: tok.tokens_out })
|
||||
}
|
||||
dispatch({ type: 'RUN_COMPLETE', decision: run.decision ?? 'HOLD' })
|
||||
return
|
||||
}
|
||||
|
||||
const url = getRunStreamUrl(runId)
|
||||
close = createSSEConnection(url, {
|
||||
onOpen: () => dispatch({ type: 'CONNECTED' }),
|
||||
onAgentStart: ({ step, turn }) =>
|
||||
dispatch({ type: 'AGENT_START', step: step as AgentStep, turn }),
|
||||
onAgentComplete: ({ step, turn, report, tokens_in, tokens_out }) =>
|
||||
dispatch({ type: 'AGENT_COMPLETE', step: step as AgentStep, turn, report,
|
||||
tokens_in, tokens_out }),
|
||||
onRunComplete: ({ decision }) => dispatch({ type: 'RUN_COMPLETE', decision }),
|
||||
onRunError: ({ message }) => dispatch({ type: 'RUN_ERROR', message }),
|
||||
})
|
||||
}).catch(() => {
|
||||
if (!aborted) dispatch({ type: 'RUN_ERROR', message: 'Failed to load run' })
|
||||
})
|
||||
return close
|
||||
|
||||
return () => { aborted = true; close?.() }
|
||||
}, [runId])
|
||||
|
||||
return state
|
||||
|
|
|
|||
|
|
@ -1,10 +1,14 @@
|
|||
import type { AgentStep, RunStatus } from '@/lib/types/run'
|
||||
import type { Decision, StepStatus } from '@/lib/types/agents'
|
||||
|
||||
export type TokenCount = { in: number; out: number }
|
||||
|
||||
export type RunStreamState = {
|
||||
status: RunStatus | 'connecting'
|
||||
steps: Record<AgentStep, StepStatus>
|
||||
reports: Record<AgentStep, string[]>
|
||||
verdict: Decision | null
|
||||
error: string | null
|
||||
tokensByStep: Record<AgentStep, TokenCount>
|
||||
tokensTotal: TokenCount
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,14 +14,21 @@ async function apiFetch<T>(path: string, init?: RequestInit): Promise<T> {
|
|||
return res.json() as Promise<T>
|
||||
}
|
||||
|
||||
export type RunResult = RunSummary & {
|
||||
config: RunConfig | null
|
||||
reports: Record<string, string>
|
||||
error: string | null
|
||||
token_usage: Record<string, { tokens_in: number; tokens_out: number }> | null
|
||||
}
|
||||
|
||||
export const createRun = (config: RunConfig): Promise<RunSummary> =>
|
||||
apiFetch('/api/runs', { method: 'POST', body: JSON.stringify(config) })
|
||||
|
||||
export const listRuns = (): Promise<RunSummary[]> =>
|
||||
apiFetch('/api/runs')
|
||||
|
||||
export const getRun = (id: string): Promise<RunSummary> =>
|
||||
apiFetch(`/api/runs/${id}`)
|
||||
export const getRun = (id: string): Promise<RunResult> =>
|
||||
apiFetch<RunResult>(`/api/runs/${id}`)
|
||||
|
||||
export const getSettings = (): Promise<Settings> =>
|
||||
apiFetch('/api/settings')
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
export type SSEHandlers = {
|
||||
onAgentStart?: (data: { step: string; turn: number }) => void
|
||||
onAgentComplete?: (data: { step: string; turn: number; report: string }) => void
|
||||
onAgentComplete?: (data: {
|
||||
step: string; turn: number; report: string;
|
||||
tokens_in?: number; tokens_out?: number
|
||||
}) => void
|
||||
onRunComplete?: (data: { decision: string; run_id: string }) => void
|
||||
onRunError?: (data: { message: string }) => void
|
||||
onOpen?: () => void
|
||||
|
|
|
|||
Loading…
Reference in New Issue