feat: add footer statistics tracking with LangChain callbacks

- Add StatsCallbackHandler for tracking LLM calls, tool calls, and tokens
- Integrate callbacks into TradingAgentsGraph and all LLM clients
- Dynamic agent/report counts based on selected analysts
- Fix report completion counting (tied to agent completion)
This commit is contained in:
Yijia Xiao 2026-02-02 22:00:37 +00:00
parent b06936f420
commit 54cdb146d0
No known key found for this signature in database
10 changed files with 277 additions and 112 deletions

View File

@ -15,7 +15,6 @@ from rich.columns import Columns
from rich.markdown import Markdown
from rich.layout import Layout
from rich.text import Text
from rich.live import Live
from rich.table import Table
from collections import deque
import time
@ -29,6 +28,7 @@ from tradingagents.default_config import DEFAULT_CONFIG
from cli.models import AnalystType
from cli.utils import *
from cli.announcements import fetch_announcements, display_announcements
from cli.stats_handler import StatsCallbackHandler
console = Console()
@ -41,40 +41,99 @@ app = typer.Typer(
# Create a deque to store recent messages with a maximum length
class MessageBuffer:
# Fixed teams that always run (not user-selectable)
FIXED_AGENTS = {
"Research Team": ["Bull Researcher", "Bear Researcher", "Research Manager"],
"Trading Team": ["Trader"],
"Risk Management": ["Aggressive Analyst", "Neutral Analyst", "Conservative Analyst"],
"Portfolio Management": ["Portfolio Manager"],
}
# Analyst name mapping
ANALYST_MAPPING = {
"market": "Market Analyst",
"social": "Social Analyst",
"news": "News Analyst",
"fundamentals": "Fundamentals Analyst",
}
# Report section mapping: section -> (analyst_key for filtering, finalizing_agent)
# analyst_key: which analyst selection controls this section (None = always included)
# finalizing_agent: which agent must be "completed" for this report to count as done
REPORT_SECTIONS = {
"market_report": ("market", "Market Analyst"),
"sentiment_report": ("social", "Social Analyst"),
"news_report": ("news", "News Analyst"),
"fundamentals_report": ("fundamentals", "Fundamentals Analyst"),
"investment_plan": (None, "Research Manager"),
"trader_investment_plan": (None, "Trader"),
"final_trade_decision": (None, "Portfolio Manager"),
}
def __init__(self, max_length=100):
self.messages = deque(maxlen=max_length)
self.tool_calls = deque(maxlen=max_length)
self.current_report = None
self.final_report = None # Store the complete final report
self.agent_status = {
# Analyst Team
"Market Analyst": "pending",
"Social Analyst": "pending",
"News Analyst": "pending",
"Fundamentals Analyst": "pending",
# Research Team
"Bull Researcher": "pending",
"Bear Researcher": "pending",
"Research Manager": "pending",
# Trading Team
"Trader": "pending",
# Risk Management Team
"Aggressive Analyst": "pending",
"Neutral Analyst": "pending",
"Conservative Analyst": "pending",
# Portfolio Management Team
"Portfolio Manager": "pending",
}
self.agent_status = {}
self.current_agent = None
self.report_sections = {
"market_report": None,
"sentiment_report": None,
"news_report": None,
"fundamentals_report": None,
"investment_plan": None,
"trader_investment_plan": None,
"final_trade_decision": None,
}
self.report_sections = {}
self.selected_analysts = []
def init_for_analysis(self, selected_analysts):
"""Initialize agent status and report sections based on selected analysts.
Args:
selected_analysts: List of analyst type strings (e.g., ["market", "news"])
"""
self.selected_analysts = [a.lower() for a in selected_analysts]
# Build agent_status dynamically
self.agent_status = {}
# Add selected analysts
for analyst_key in self.selected_analysts:
if analyst_key in self.ANALYST_MAPPING:
self.agent_status[self.ANALYST_MAPPING[analyst_key]] = "pending"
# Add fixed teams
for team_agents in self.FIXED_AGENTS.values():
for agent in team_agents:
self.agent_status[agent] = "pending"
# Build report_sections dynamically
self.report_sections = {}
for section, (analyst_key, _) in self.REPORT_SECTIONS.items():
if analyst_key is None or analyst_key in self.selected_analysts:
self.report_sections[section] = None
# Reset other state
self.current_report = None
self.final_report = None
self.current_agent = None
self.messages.clear()
self.tool_calls.clear()
def get_completed_reports_count(self):
"""Count reports that are finalized (their finalizing agent is completed).
A report is considered complete when:
1. The report section has content (not None), AND
2. The agent responsible for finalizing that report has status "completed"
This prevents interim updates (like debate rounds) from counting as completed.
"""
count = 0
for section in self.report_sections:
if section not in self.REPORT_SECTIONS:
continue
_, finalizing_agent = self.REPORT_SECTIONS[section]
# Report is complete if it has content AND its finalizing agent is done
has_content = self.report_sections.get(section) is not None
agent_done = self.agent_status.get(finalizing_agent) == "completed"
if has_content and agent_done:
count += 1
return count
def add_message(self, message_type, content):
timestamp = datetime.datetime.now().strftime("%H:%M:%S")
@ -126,46 +185,39 @@ class MessageBuffer:
def _update_final_report(self):
report_parts = []
# Analyst Team Reports
if any(
self.report_sections[section]
for section in [
"market_report",
"sentiment_report",
"news_report",
"fundamentals_report",
]
):
# Analyst Team Reports - use .get() to handle missing sections
analyst_sections = ["market_report", "sentiment_report", "news_report", "fundamentals_report"]
if any(self.report_sections.get(section) for section in analyst_sections):
report_parts.append("## Analyst Team Reports")
if self.report_sections["market_report"]:
if self.report_sections.get("market_report"):
report_parts.append(
f"### Market Analysis\n{self.report_sections['market_report']}"
)
if self.report_sections["sentiment_report"]:
if self.report_sections.get("sentiment_report"):
report_parts.append(
f"### Social Sentiment\n{self.report_sections['sentiment_report']}"
)
if self.report_sections["news_report"]:
if self.report_sections.get("news_report"):
report_parts.append(
f"### News Analysis\n{self.report_sections['news_report']}"
)
if self.report_sections["fundamentals_report"]:
if self.report_sections.get("fundamentals_report"):
report_parts.append(
f"### Fundamentals Analysis\n{self.report_sections['fundamentals_report']}"
)
# Research Team Reports
if self.report_sections["investment_plan"]:
if self.report_sections.get("investment_plan"):
report_parts.append("## Research Team Decision")
report_parts.append(f"{self.report_sections['investment_plan']}")
# Trading Team Reports
if self.report_sections["trader_investment_plan"]:
if self.report_sections.get("trader_investment_plan"):
report_parts.append("## Trading Team Plan")
report_parts.append(f"{self.report_sections['trader_investment_plan']}")
# Portfolio Management Decision
if self.report_sections["final_trade_decision"]:
if self.report_sections.get("final_trade_decision"):
report_parts.append("## Portfolio Management Decision")
report_parts.append(f"{self.report_sections['final_trade_decision']}")
@ -191,7 +243,14 @@ def create_layout():
return layout
def update_display(layout, spinner_text=None):
def format_tokens(n):
"""Format token count for display."""
if n >= 1000:
return f"{n/1000:.1f}k"
return str(n)
def update_display(layout, spinner_text=None, stats_handler=None, start_time=None):
# Header with welcome message
layout["header"].update(
Panel(
@ -218,8 +277,8 @@ def update_display(layout, spinner_text=None):
progress_table.add_column("Agent", style="green", justify="center", width=20)
progress_table.add_column("Status", style="yellow", justify="center", width=20)
# Group agents by team
teams = {
# Group agents by team - filter to only include agents in agent_status
all_teams = {
"Analyst Team": [
"Market Analyst",
"Social Analyst",
@ -232,10 +291,17 @@ def update_display(layout, spinner_text=None):
"Portfolio Management": ["Portfolio Manager"],
}
# Filter teams to only include agents that are in agent_status
teams = {}
for team, agents in all_teams.items():
active_agents = [a for a in agents if a in message_buffer.agent_status]
if active_agents:
teams[team] = active_agents
for team, agents in teams.items():
# Add first agent with team name
first_agent = agents[0]
status = message_buffer.agent_status[first_agent]
status = message_buffer.agent_status.get(first_agent, "pending")
if status == "in_progress":
spinner = Spinner(
"dots", text="[blue]in_progress[/blue]", style="bold cyan"
@ -252,7 +318,7 @@ def update_display(layout, spinner_text=None):
# Add remaining agents in team
for agent in agents[1:]:
status = message_buffer.agent_status[agent]
status = message_buffer.agent_status.get(agent, "pending")
if status == "in_progress":
spinner = Spinner(
"dots", text="[blue]in_progress[/blue]", style="bold cyan"
@ -379,19 +445,43 @@ def update_display(layout, spinner_text=None):
)
# Footer with statistics
tool_calls_count = len(message_buffer.tool_calls)
llm_calls_count = sum(
1 for _, msg_type, _ in message_buffer.messages if msg_type == "Reasoning"
)
reports_count = sum(
1 for content in message_buffer.report_sections.values() if content is not None
# Agent progress - derived from agent_status dict
agents_completed = sum(
1 for status in message_buffer.agent_status.values() if status == "completed"
)
agents_total = len(message_buffer.agent_status)
# Report progress - based on agent completion (not just content existence)
reports_completed = message_buffer.get_completed_reports_count()
reports_total = len(message_buffer.report_sections)
# Build stats parts
stats_parts = [f"Agents: {agents_completed}/{agents_total}"]
# LLM and tool stats from callback handler
if stats_handler:
stats = stats_handler.get_stats()
stats_parts.append(f"LLM: {stats['llm_calls']}")
stats_parts.append(f"Tools: {stats['tool_calls']}")
# Token display with graceful fallback
if stats["tokens_in"] > 0 or stats["tokens_out"] > 0:
tokens_str = f"Tokens: {format_tokens(stats['tokens_in'])}\u2191 {format_tokens(stats['tokens_out'])}\u2193"
else:
tokens_str = "Tokens: --"
stats_parts.append(tokens_str)
stats_parts.append(f"Reports: {reports_completed}/{reports_total}")
# Elapsed time
if start_time:
elapsed = time.time() - start_time
elapsed_str = f"\u23f1 {int(elapsed // 60):02d}:{int(elapsed % 60):02d}"
stats_parts.append(elapsed_str)
stats_table = Table(show_header=False, box=None, padding=(0, 2), expand=True)
stats_table.add_column("Stats", justify="center")
stats_table.add_row(
f"Tool Calls: {tool_calls_count} | LLM Calls: {llm_calls_count} | Generated Reports: {reports_count}"
)
stats_table.add_row(" | ".join(stats_parts))
layout["footer"].update(Panel(stats_table, border_style="grey50"))
@ -803,11 +893,24 @@ def run_analysis():
config["google_thinking_level"] = selections.get("google_thinking_level")
config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort")
# Initialize the graph
# Create stats callback handler for tracking LLM/tool calls
stats_handler = StatsCallbackHandler()
# Initialize the graph with callbacks bound to LLMs
graph = TradingAgentsGraph(
[analyst.value for analyst in selections["analysts"]], config=config, debug=True
[analyst.value for analyst in selections["analysts"]],
config=config,
debug=True,
callbacks=[stats_handler],
)
# Initialize message buffer with selected analysts
selected_analyst_keys = [analyst.value for analyst in selections["analysts"]]
message_buffer.init_for_analysis(selected_analyst_keys)
# Track start time for elapsed display
start_time = time.time()
# Create result directory
results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"]
results_dir.mkdir(parents=True, exist_ok=True)
@ -860,7 +963,7 @@ def run_analysis():
with Live(layout, refresh_per_second=4) as live:
# Initial display
update_display(layout)
update_display(layout, stats_handler=stats_handler, start_time=start_time)
# Add initial messages
message_buffer.add_message("System", f"Selected ticker: {selections['ticker']}")
@ -871,34 +974,26 @@ def run_analysis():
"System",
f"Selected analysts: {', '.join(analyst.value for analyst in selections['analysts'])}",
)
update_display(layout)
# Reset agent statuses
for agent in message_buffer.agent_status:
message_buffer.update_agent_status(agent, "pending")
# Reset report sections
for section in message_buffer.report_sections:
message_buffer.report_sections[section] = None
message_buffer.current_report = None
message_buffer.final_report = None
update_display(layout, stats_handler=stats_handler, start_time=start_time)
# Update agent status to in_progress for the first analyst
first_analyst = f"{selections['analysts'][0].value.capitalize()} Analyst"
message_buffer.update_agent_status(first_analyst, "in_progress")
update_display(layout)
update_display(layout, stats_handler=stats_handler, start_time=start_time)
# Create spinner text
spinner_text = (
f"Analyzing {selections['ticker']} on {selections['analysis_date']}..."
)
update_display(layout, spinner_text)
update_display(layout, spinner_text, stats_handler=stats_handler, start_time=start_time)
# Initialize state and get graph args
# Initialize state and get graph args with callbacks
init_agent_state = graph.propagator.create_initial_state(
selections["ticker"], selections["analysis_date"]
)
args = graph.propagator.get_graph_args()
# Pass callbacks to graph config for tool execution tracking
# (LLM tracking is handled separately via LLM constructor)
args = graph.propagator.get_graph_args(callbacks=[stats_handler])
# Stream the analysis
trace = []
@ -1112,7 +1207,7 @@ def run_analysis():
)
# Update the display
update_display(layout)
update_display(layout, stats_handler=stats_handler, start_time=start_time)
trace.append(chunk)
@ -1136,7 +1231,7 @@ def run_analysis():
# Display the complete final report
display_complete_report(final_state)
update_display(layout)
update_display(layout, stats_handler=stats_handler, start_time=start_time)
@app.command()

76
cli/stats_handler.py Normal file
View File

@ -0,0 +1,76 @@
import threading
from typing import Any, Dict, List, Union
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from langchain_core.messages import AIMessage
class StatsCallbackHandler(BaseCallbackHandler):
"""Callback handler that tracks LLM calls, tool calls, and token usage."""
def __init__(self) -> None:
super().__init__()
self._lock = threading.Lock()
self.llm_calls = 0
self.tool_calls = 0
self.tokens_in = 0
self.tokens_out = 0
def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
**kwargs: Any,
) -> None:
"""Increment LLM call counter when an LLM starts."""
with self._lock:
self.llm_calls += 1
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[Any]],
**kwargs: Any,
) -> None:
"""Increment LLM call counter when a chat model starts."""
with self._lock:
self.llm_calls += 1
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Extract token usage from LLM response."""
try:
generation = response.generations[0][0]
except (IndexError, TypeError):
return
usage_metadata = None
if hasattr(generation, "message"):
message = generation.message
if isinstance(message, AIMessage) and hasattr(message, "usage_metadata"):
usage_metadata = message.usage_metadata
if usage_metadata:
with self._lock:
self.tokens_in += usage_metadata.get("input_tokens", 0)
self.tokens_out += usage_metadata.get("output_tokens", 0)
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
"""Increment tool call counter when a tool starts."""
with self._lock:
self.tool_calls += 1
def get_stats(self) -> Dict[str, Any]:
"""Return current statistics."""
with self._lock:
return {
"llm_calls": self.llm_calls,
"tool_calls": self.tool_calls,
"tokens_in": self.tokens_in,
"tokens_out": self.tokens_out,
}

View File

@ -1,6 +1,6 @@
# TradingAgents/graph/propagation.py
from typing import Dict, Any
from typing import Dict, Any, List, Optional
from tradingagents.agents.utils.agent_states import (
AgentState,
InvestDebateState,
@ -41,9 +41,17 @@ class Propagator:
"news_report": "",
}
def get_graph_args(self) -> Dict[str, Any]:
"""Get arguments for the graph invocation."""
def get_graph_args(self, callbacks: Optional[List] = None) -> Dict[str, Any]:
"""Get arguments for the graph invocation.
Args:
callbacks: Optional list of callback handlers for tool execution tracking.
Note: LLM callbacks are handled separately via LLM constructor.
"""
config = {"recursion_limit": self.max_recur_limit}
if callbacks:
config["callbacks"] = callbacks
return {
"stream_mode": "values",
"config": {"recursion_limit": self.max_recur_limit},
"config": config,
}

View File

@ -48,6 +48,7 @@ class TradingAgentsGraph:
selected_analysts=["market", "social", "news", "fundamentals"],
debug=False,
config: Dict[str, Any] = None,
callbacks: Optional[List] = None,
):
"""Initialize the trading agents graph and components.
@ -55,9 +56,11 @@ class TradingAgentsGraph:
selected_analysts: List of analyst types to include
debug: Whether to run in debug mode
config: Configuration dictionary. If None, uses default config
callbacks: Optional list of callback handlers (e.g., for tracking LLM/tool stats)
"""
self.debug = debug
self.config = config or DEFAULT_CONFIG
self.callbacks = callbacks or []
# Update the interface's config
set_config(self.config)
@ -71,6 +74,10 @@ class TradingAgentsGraph:
# Initialize LLMs with provider-specific thinking configuration
llm_kwargs = self._get_provider_kwargs()
# Add callbacks to kwargs if provided (passed to LLM constructor)
if self.callbacks:
llm_kwargs["callbacks"] = self.callbacks
deep_client = create_llm_client(
provider=self.config["llm_provider"],
model=self.config["deep_think_llm"],
@ -83,6 +90,7 @@ class TradingAgentsGraph:
base_url=self.config.get("backend_url"),
**llm_kwargs,
)
self.deep_thinking_llm = deep_client.get_llm()
self.quick_thinking_llm = quick_client.get_llm()

View File

@ -16,7 +16,7 @@ class AnthropicClient(BaseLLMClient):
"""Return configured ChatAnthropic instance."""
llm_kwargs = {"model": self.model}
for key in ("timeout", "max_retries", "api_key", "max_tokens"):
for key in ("timeout", "max_retries", "api_key", "max_tokens", "callbacks"):
if key in self.kwargs:
llm_kwargs[key] = self.kwargs[key]

View File

@ -4,7 +4,6 @@ from .base_client import BaseLLMClient
from .openai_client import OpenAIClient
from .anthropic_client import AnthropicClient
from .google_client import GoogleClient
from .vllm_client import VLLMClient
def create_llm_client(
@ -16,7 +15,7 @@ def create_llm_client(
"""Create an LLM client for the specified provider.
Args:
provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter, vllm)
provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter)
model: Model name/identifier
base_url: Optional base URL for API endpoint
**kwargs: Additional provider-specific arguments
@ -41,7 +40,4 @@ def create_llm_client(
if provider_lower == "google":
return GoogleClient(model, base_url, **kwargs)
if provider_lower == "vllm":
return VLLMClient(model, base_url, **kwargs)
raise ValueError(f"Unsupported LLM provider: {provider}")

View File

@ -38,7 +38,7 @@ class GoogleClient(BaseLLMClient):
"""Return configured ChatGoogleGenerativeAI instance."""
llm_kwargs = {"model": self.model}
for key in ("timeout", "max_retries", "google_api_key"):
for key in ("timeout", "max_retries", "google_api_key", "callbacks"):
if key in self.kwargs:
llm_kwargs[key] = self.kwargs[key]

View File

@ -60,7 +60,7 @@ class OpenAIClient(BaseLLMClient):
elif self.base_url:
llm_kwargs["base_url"] = self.base_url
for key in ("timeout", "max_retries", "reasoning_effort", "api_key"):
for key in ("timeout", "max_retries", "reasoning_effort", "api_key", "callbacks"):
if key in self.kwargs:
llm_kwargs[key] = self.kwargs[key]

View File

@ -69,11 +69,11 @@ VALID_MODELS = {
def validate_model(provider: str, model: str) -> bool:
"""Check if model name is valid for the given provider.
For ollama, openrouter, vllm - any model is accepted.
For ollama, openrouter - any model is accepted.
"""
provider_lower = provider.lower()
if provider_lower in ("ollama", "openrouter", "vllm"):
if provider_lower in ("ollama", "openrouter"):
return True
if provider_lower not in VALID_MODELS:

View File

@ -1,18 +0,0 @@
from typing import Any, Optional
from .base_client import BaseLLMClient
class VLLMClient(BaseLLMClient):
"""Client for vLLM (placeholder for future implementation)."""
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
super().__init__(model, base_url, **kwargs)
def get_llm(self) -> Any:
"""Return configured vLLM instance."""
raise NotImplementedError("vLLM client not yet implemented")
def validate_model(self) -> bool:
"""Validate model for vLLM."""
return True