# tradingagents/observability.py
"""Structured observability logging for TradingAgents.
Emits JSON-lines logs capturing:
- LLM calls (model, agent/node, token counts, latency)
- Tool calls (tool name, args summary, success/failure, latency)
- Data vendor calls (method, vendor, success/failure, fallback chain)
- Report saves
Usage:
from tradingagents.observability import RunLogger, get_run_logger
logger = RunLogger() # one per run
# pass logger.callback to LangChain as a callback handler
# pass logger to route_to_vendor / run_tool_loop via context
All events are written as JSON lines to a file and also to Python's
``logging`` module at DEBUG level for console visibility.
"""
from __future__ import annotations
import json
import logging
import threading
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import AIMessage
from langchain_core.outputs import LLMResult
# Module-level logger that mirrors structured events to Python's ``logging``
# framework (events at DEBUG, log-file writes at INFO).
_py_logger = logging.getLogger("tradingagents.observability")
# ──────────────────────────────────────────────────────────────────────────────
# Event dataclass — each logged event becomes one JSON line
# ──────────────────────────────────────────────────────────────────────────────
@dataclass
class _Event:
kind: str # "llm", "tool", "vendor", "report"
ts: float # time.time()
data: dict # kind-specific payload
def to_dict(self) -> dict:
return {"kind": self.kind, "ts": self.ts, **self.data}
# ──────────────────────────────────────────────────────────────────────────────
# RunLogger — accumulates events and writes a JSON-lines log file
# ──────────────────────────────────────────────────────────────────────────────
class RunLogger:
    """Accumulates structured events for a single run (analyze / scan / pipeline).

    Attributes:
        callback: A LangChain ``BaseCallbackHandler`` that can be passed to
            LLM constructors or graph invocations.
        events: All recorded events; access is guarded by an internal lock.
    """

    def __init__(self) -> None:
        # Non-reentrant lock: methods must never call another locking method
        # while holding it (see write_log / summary).
        self._lock = threading.Lock()
        self.events: list[_Event] = []
        self.callback = _LLMCallbackHandler(self)
        self._start = time.time()  # run start; used for elapsed_s in summary()

    # -- public helpers to record events from non-callback code ----------------
    def log_vendor_call(
        self,
        method: str,
        vendor: str,
        success: bool,
        duration_ms: float,
        error: str | None = None,
        args_summary: str = "",
    ) -> None:
        """Record a data-vendor call (called from ``route_to_vendor``).

        Args:
            method: Vendor-interface method name that was routed.
            vendor: Concrete vendor that handled (or failed) the call.
            success: Whether the vendor call completed without error.
            duration_ms: Wall-clock duration of the call in milliseconds.
            error: Error description when ``success`` is False.
            args_summary: Short human-readable summary of the call arguments.
        """
        self._append(
            _Event(
                kind="vendor",
                ts=time.time(),
                data={
                    "method": method,
                    "vendor": vendor,
                    "success": success,
                    "duration_ms": round(duration_ms, 1),
                    "error": error,
                    "args": args_summary,
                },
            )
        )

    def log_tool_call(
        self,
        tool_name: str,
        args_summary: str,
        success: bool,
        duration_ms: float,
        error: str | None = None,
    ) -> None:
        """Record a tool invocation (called from ``run_tool_loop``).

        Args:
            tool_name: Name of the invoked tool.
            args_summary: Short human-readable summary of the tool arguments.
            success: Whether the tool call completed without error.
            duration_ms: Wall-clock duration of the call in milliseconds.
            error: Error description when ``success`` is False.
        """
        self._append(
            _Event(
                kind="tool",
                ts=time.time(),
                data={
                    "tool": tool_name,
                    "args": args_summary,
                    "success": success,
                    "duration_ms": round(duration_ms, 1),
                    "error": error,
                },
            )
        )

    def log_report_save(self, path: str) -> None:
        """Record that a report file was written to *path*."""
        self._append(_Event(kind="report", ts=time.time(), data={"path": path}))

    # -- summary ---------------------------------------------------------------
    def summary(self) -> dict:
        """Return an aggregated summary suitable for ``run_summary.json``."""
        with self._lock:
            events = list(self.events)  # snapshot; aggregate outside the lock
        llm_events = [e for e in events if e.kind == "llm"]
        tool_events = [e for e in events if e.kind == "tool"]
        vendor_events = [e for e in events if e.kind == "vendor"]
        total_tokens_in = sum(e.data.get("tokens_in", 0) for e in llm_events)
        total_tokens_out = sum(e.data.get("tokens_out", 0) for e in llm_events)
        # Group LLM calls by model name.
        model_counts: dict[str, int] = {}
        for e in llm_events:
            m = e.data.get("model", "unknown")
            model_counts[m] = model_counts.get(m, 0) + 1
        # Group vendor calls by vendor, split into ok/fail tallies.
        vendor_counts: dict[str, dict] = {}
        for e in vendor_events:
            tally = vendor_counts.setdefault(e.data["vendor"], {"ok": 0, "fail": 0})
            tally["ok" if e.data["success"] else "fail"] += 1
        return {
            "elapsed_s": round(time.time() - self._start, 1),
            "llm_calls": len(llm_events),
            "tokens_in": total_tokens_in,
            "tokens_out": total_tokens_out,
            "tokens_total": total_tokens_in + total_tokens_out,
            "models_used": model_counts,
            "tool_calls": len(tool_events),
            "tool_success": sum(1 for e in tool_events if e.data["success"]),
            "tool_fail": sum(1 for e in tool_events if not e.data["success"]),
            "vendor_calls": len(vendor_events),
            "vendor_success": sum(1 for e in vendor_events if e.data["success"]),
            "vendor_fail": sum(1 for e in vendor_events if not e.data["success"]),
            "vendors_used": vendor_counts,
        }

    def write_log(self, path: Path) -> None:
        """Write all events as JSON lines + a summary block to *path*.

        The events snapshot is taken under the lock, but ``summary()`` (which
        acquires the same non-reentrant lock) is only called after it is
        released, so the two methods can never deadlock on each other.
        """
        path.parent.mkdir(parents=True, exist_ok=True)
        with self._lock:
            events = list(self.events)
        lines = [json.dumps(e.to_dict()) for e in events]
        lines.append(json.dumps({"kind": "summary", **self.summary()}))
        # Explicit UTF-8: the default encoding is platform-dependent and
        # event payloads may contain non-ASCII text.
        path.write_text("\n".join(lines) + "\n", encoding="utf-8")
        _py_logger.info("Run log written to %s", path)

    # -- internals -------------------------------------------------------------
    def _append(self, evt: _Event) -> None:
        """Thread-safely record *evt* and mirror it to the Python logger."""
        with self._lock:
            self.events.append(evt)
        _py_logger.debug("%s | %s", evt.kind, json.dumps(evt.data))
# ──────────────────────────────────────────────────────────────────────────────
# LangChain callback handler — captures LLM call details
# ──────────────────────────────────────────────────────────────────────────────
class _LLMCallbackHandler(BaseCallbackHandler):
    """LangChain callback that feeds LLM events into a ``RunLogger``."""

    def __init__(self, run_logger: RunLogger) -> None:
        super().__init__()
        self._rl = run_logger
        self._lock = threading.Lock()
        # Track in-flight calls: run_id -> {"model", "agent", "t0"}
        self._inflight: dict[str, dict] = {}

    # -- chat model start (preferred path for ChatOpenAI / ChatAnthropic) ------
    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[Any]],
        *,
        run_id: Any = None,
        **kwargs: Any,
    ) -> None:
        """Stash metadata so ``on_llm_end`` can compute tokens and latency."""
        self._record_start(serialized, kwargs, run_id, fallback=messages)

    # -- legacy LLM start (completion-style) -----------------------------------
    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        *,
        run_id: Any = None,
        **kwargs: Any,
    ) -> None:
        """Completion-style analogue of ``on_chat_model_start``."""
        self._record_start(serialized, kwargs, run_id, fallback=prompts)

    def _record_start(
        self,
        serialized: Dict[str, Any],
        kwargs: Dict[str, Any],
        run_id: Any,
        fallback: Any,
    ) -> None:
        """Record model/agent/start-time for an in-flight call.

        Keyed by ``run_id``; when LangChain supplies none, the payload's
        ``id()`` is used as a best-effort key.
        """
        model = _extract_model(serialized, kwargs)
        agent = kwargs.get("name") or serialized.get("name") or _extract_graph_node(kwargs)
        key = str(run_id) if run_id else str(id(fallback))
        with self._lock:
            self._inflight[key] = {
                "model": model,
                "agent": agent or "",
                "t0": time.time(),
            }

    # -- LLM error -------------------------------------------------------------
    def on_llm_error(self, error: BaseException, *, run_id: Any = None, **kwargs: Any) -> None:
        """Discard the in-flight entry for a failed call.

        Without this, entries for errored calls are never popped and
        ``_inflight`` grows without bound over a long run.
        """
        if run_id is not None:
            with self._lock:
                self._inflight.pop(str(run_id), None)

    # -- LLM end ---------------------------------------------------------------
    def on_llm_end(self, response: LLMResult, *, run_id: Any = None, **kwargs: Any) -> None:
        """Emit an "llm" event combining start metadata with usage stats."""
        key = str(run_id) if run_id else None
        with self._lock:
            meta = self._inflight.pop(key, None) if key else None
        tokens_in = 0
        tokens_out = 0
        model_from_response = ""
        try:
            generation = response.generations[0][0]
            if hasattr(generation, "message"):
                msg = generation.message
                # usage_metadata is the provider-normalized token-count field.
                if isinstance(msg, AIMessage) and getattr(msg, "usage_metadata", None):
                    tokens_in = msg.usage_metadata.get("input_tokens", 0)
                    tokens_out = msg.usage_metadata.get("output_tokens", 0)
                if hasattr(msg, "response_metadata"):
                    model_from_response = (
                        msg.response_metadata.get("model_name", "")
                        or msg.response_metadata.get("model", "")
                    )
        except (IndexError, TypeError, AttributeError):
            # Providers vary in response shape; missing usage info is not fatal.
            pass
        model = model_from_response or (meta["model"] if meta else "unknown")
        agent = meta["agent"] if meta else ""
        duration_ms = (time.time() - meta["t0"]) * 1000 if meta else 0
        self._rl._append(
            _Event(
                kind="llm",
                ts=time.time(),
                data={
                    "model": model,
                    "agent": agent,
                    "tokens_in": tokens_in,
                    "tokens_out": tokens_out,
                    "duration_ms": round(duration_ms, 1),
                },
            )
        )
# ──────────────────────────────────────────────────────────────────────────────
# Helpers
# ──────────────────────────────────────────────────────────────────────────────
def _extract_model(serialized: dict, kwargs: dict) -> str:
"""Best-effort model name from LangChain callback metadata."""
# kwargs.invocation_params often has the model name
inv = kwargs.get("invocation_params") or {}
model = inv.get("model_name") or inv.get("model") or ""
if model:
return model
# serialized might have it nested
kw = serialized.get("kwargs", {})
return kw.get("model_name") or kw.get("model") or serialized.get("id", [""])[-1]
def _extract_graph_node(kwargs: dict) -> str:
"""Try to get the current graph node name from LangGraph metadata."""
tags = kwargs.get("tags") or []
for tag in tags:
if isinstance(tag, str) and tag.startswith("graph:node:"):
return tag.split(":", 2)[-1]
metadata = kwargs.get("metadata") or {}
return metadata.get("langgraph_node", "")
# ──────────────────────────────────────────────────────────────────────────────
# Thread-local context for passing RunLogger to vendor/tool layers
# ──────────────────────────────────────────────────────────────────────────────
# Thread-local slot holding the active RunLogger (one per worker thread).
_current_run_logger: threading.local = threading.local()


def set_run_logger(rl: RunLogger | None) -> None:
    """Set the active RunLogger for the current thread."""
    _current_run_logger.instance = rl


def get_run_logger() -> RunLogger | None:
    """Get the active RunLogger (or None if not set)."""
    try:
        return _current_run_logger.instance
    except AttributeError:
        # Nothing registered on this thread yet.
        return None