From 54cdb146d0d6aaf32adadca8414bed36020c57e7 Mon Sep 17 00:00:00 2001 From: Yijia Xiao Date: Mon, 2 Feb 2026 22:00:37 +0000 Subject: [PATCH] feat: add footer statistics tracking with LangChain callbacks - Add StatsCallbackHandler for tracking LLM calls, tool calls, and tokens - Integrate callbacks into TradingAgentsGraph and all LLM clients - Dynamic agent/report counts based on selected analysts - Fix report completion counting (tied to agent completion) --- cli/main.py | 255 ++++++++++++------ cli/stats_handler.py | 76 ++++++ tradingagents/graph/propagation.py | 16 +- tradingagents/graph/trading_graph.py | 8 + tradingagents/llm_clients/anthropic_client.py | 2 +- tradingagents/llm_clients/factory.py | 6 +- tradingagents/llm_clients/google_client.py | 2 +- tradingagents/llm_clients/openai_client.py | 2 +- tradingagents/llm_clients/validators.py | 4 +- tradingagents/llm_clients/vllm_client.py | 18 -- 10 files changed, 277 insertions(+), 112 deletions(-) create mode 100644 cli/stats_handler.py delete mode 100644 tradingagents/llm_clients/vllm_client.py diff --git a/cli/main.py b/cli/main.py index f555a81f..614b43f2 100644 --- a/cli/main.py +++ b/cli/main.py @@ -15,7 +15,6 @@ from rich.columns import Columns from rich.markdown import Markdown from rich.layout import Layout from rich.text import Text -from rich.live import Live from rich.table import Table from collections import deque import time @@ -29,6 +28,7 @@ from tradingagents.default_config import DEFAULT_CONFIG from cli.models import AnalystType from cli.utils import * from cli.announcements import fetch_announcements, display_announcements +from cli.stats_handler import StatsCallbackHandler console = Console() @@ -41,40 +41,99 @@ app = typer.Typer( # Create a deque to store recent messages with a maximum length class MessageBuffer: + # Fixed teams that always run (not user-selectable) + FIXED_AGENTS = { + "Research Team": ["Bull Researcher", "Bear Researcher", "Research Manager"], + "Trading Team": ["Trader"], + "Risk Management": ["Aggressive Analyst", "Neutral Analyst", "Conservative Analyst"], + "Portfolio Management": ["Portfolio Manager"], + } + + # Analyst name mapping + ANALYST_MAPPING = { + "market": "Market Analyst", + "social": "Social Analyst", + "news": "News Analyst", + "fundamentals": "Fundamentals Analyst", + } + + # Report section mapping: section -> (analyst_key for filtering, finalizing_agent) + # analyst_key: which analyst selection controls this section (None = always included) + # finalizing_agent: which agent must be "completed" for this report to count as done + REPORT_SECTIONS = { + "market_report": ("market", "Market Analyst"), + "sentiment_report": ("social", "Social Analyst"), + "news_report": ("news", "News Analyst"), + "fundamentals_report": ("fundamentals", "Fundamentals Analyst"), + "investment_plan": (None, "Research Manager"), + "trader_investment_plan": (None, "Trader"), + "final_trade_decision": (None, "Portfolio Manager"), + } + def __init__(self, max_length=100): self.messages = deque(maxlen=max_length) self.tool_calls = deque(maxlen=max_length) self.current_report = None self.final_report = None # Store the complete final report - self.agent_status = { - # Analyst Team - "Market Analyst": "pending", - "Social Analyst": "pending", - "News Analyst": "pending", - "Fundamentals Analyst": "pending", - # Research Team - "Bull Researcher": "pending", - "Bear Researcher": "pending", - "Research Manager": "pending", - # Trading Team - "Trader": "pending", - # Risk Management Team - "Aggressive Analyst": "pending", - "Neutral Analyst": "pending", - "Conservative Analyst": "pending", - # Portfolio Management Team - "Portfolio Manager": "pending", - } + self.agent_status = {} self.current_agent = None - self.report_sections = { - "market_report": None, - "sentiment_report": None, - "news_report": None, - "fundamentals_report": None, - "investment_plan": None, - "trader_investment_plan": None, - "final_trade_decision": None, - } + self.report_sections = {} + self.selected_analysts = [] + + def init_for_analysis(self, selected_analysts): + """Initialize agent status and report sections based on selected analysts. + + Args: + selected_analysts: List of analyst type strings (e.g., ["market", "news"]) + """ + self.selected_analysts = [a.lower() for a in selected_analysts] + + # Build agent_status dynamically + self.agent_status = {} + + # Add selected analysts + for analyst_key in self.selected_analysts: + if analyst_key in self.ANALYST_MAPPING: + self.agent_status[self.ANALYST_MAPPING[analyst_key]] = "pending" + + # Add fixed teams + for team_agents in self.FIXED_AGENTS.values(): + for agent in team_agents: + self.agent_status[agent] = "pending" + + # Build report_sections dynamically + self.report_sections = {} + for section, (analyst_key, _) in self.REPORT_SECTIONS.items(): + if analyst_key is None or analyst_key in self.selected_analysts: + self.report_sections[section] = None + + # Reset other state + self.current_report = None + self.final_report = None + self.current_agent = None + self.messages.clear() + self.tool_calls.clear() + + def get_completed_reports_count(self): + """Count reports that are finalized (their finalizing agent is completed). + + A report is considered complete when: + 1. The report section has content (not None), AND + 2. The agent responsible for finalizing that report has status "completed" + + This prevents interim updates (like debate rounds) from counting as completed. + """ + count = 0 + for section in self.report_sections: + if section not in self.REPORT_SECTIONS: + continue + _, finalizing_agent = self.REPORT_SECTIONS[section] + # Report is complete if it has content AND its finalizing agent is done + has_content = self.report_sections.get(section) is not None + agent_done = self.agent_status.get(finalizing_agent) == "completed" + if has_content and agent_done: + count += 1 + return count def add_message(self, message_type, content): timestamp = datetime.datetime.now().strftime("%H:%M:%S") @@ -126,46 +185,39 @@ class MessageBuffer: def _update_final_report(self): report_parts = [] - # Analyst Team Reports - if any( - self.report_sections[section] - for section in [ - "market_report", - "sentiment_report", - "news_report", - "fundamentals_report", - ] - ): + # Analyst Team Reports - use .get() to handle missing sections + analyst_sections = ["market_report", "sentiment_report", "news_report", "fundamentals_report"] + if any(self.report_sections.get(section) for section in analyst_sections): report_parts.append("## Analyst Team Reports") - if self.report_sections["market_report"]: + if self.report_sections.get("market_report"): report_parts.append( f"### Market Analysis\n{self.report_sections['market_report']}" ) - if self.report_sections["sentiment_report"]: + if self.report_sections.get("sentiment_report"): report_parts.append( f"### Social Sentiment\n{self.report_sections['sentiment_report']}" ) - if self.report_sections["news_report"]: + if self.report_sections.get("news_report"): report_parts.append( f"### News Analysis\n{self.report_sections['news_report']}" ) - if self.report_sections["fundamentals_report"]: + if self.report_sections.get("fundamentals_report"): report_parts.append( f"### Fundamentals Analysis\n{self.report_sections['fundamentals_report']}" ) # Research Team Reports - if self.report_sections["investment_plan"]: + if self.report_sections.get("investment_plan"): report_parts.append("## Research Team Decision") report_parts.append(f"{self.report_sections['investment_plan']}") # Trading Team Reports - if self.report_sections["trader_investment_plan"]: + if self.report_sections.get("trader_investment_plan"): report_parts.append("## Trading Team Plan") report_parts.append(f"{self.report_sections['trader_investment_plan']}") # Portfolio Management Decision - if self.report_sections["final_trade_decision"]: + if self.report_sections.get("final_trade_decision"): report_parts.append("## Portfolio Management Decision") report_parts.append(f"{self.report_sections['final_trade_decision']}") @@ -191,7 +243,14 @@ def create_layout(): return layout -def update_display(layout, spinner_text=None): +def format_tokens(n): + """Format token count for display.""" + if n >= 1000: + return f"{n/1000:.1f}k" + return str(n) + + +def update_display(layout, spinner_text=None, stats_handler=None, start_time=None): # Header with welcome message layout["header"].update( Panel( @@ -218,8 +277,8 @@ def update_display(layout, spinner_text=None): progress_table.add_column("Agent", style="green", justify="center", width=20) progress_table.add_column("Status", style="yellow", justify="center", width=20) - # Group agents by team - teams = { + # Group agents by team - filter to only include agents in agent_status + all_teams = { "Analyst Team": [ "Market Analyst", "Social Analyst", @@ -232,10 +291,17 @@ def update_display(layout, spinner_text=None): "Portfolio Management": ["Portfolio Manager"], } + # Filter teams to only include agents that are in agent_status + teams = {} + for team, agents in all_teams.items(): + active_agents = [a for a in agents if a in message_buffer.agent_status] + if active_agents: + teams[team] = active_agents + for team, agents in teams.items(): # Add first agent with team name first_agent = agents[0] - status = message_buffer.agent_status[first_agent] + status = message_buffer.agent_status.get(first_agent, "pending") if status == "in_progress": spinner = Spinner( "dots", text="[blue]in_progress[/blue]", style="bold cyan" @@ -252,7 +318,7 @@ def update_display(layout, spinner_text=None): # Add remaining agents in team for agent in agents[1:]: - status = message_buffer.agent_status[agent] + status = message_buffer.agent_status.get(agent, "pending") if status == "in_progress": spinner = Spinner( "dots", text="[blue]in_progress[/blue]", style="bold cyan" @@ -379,19 +445,43 @@ def update_display(layout, spinner_text=None): ) # Footer with statistics - tool_calls_count = len(message_buffer.tool_calls) - llm_calls_count = sum( - 1 for _, msg_type, _ in message_buffer.messages if msg_type == "Reasoning" - ) - reports_count = sum( - 1 for content in message_buffer.report_sections.values() if content is not None + # Agent progress - derived from agent_status dict + agents_completed = sum( + 1 for status in message_buffer.agent_status.values() if status == "completed" ) + agents_total = len(message_buffer.agent_status) + + # Report progress - based on agent completion (not just content existence) + reports_completed = message_buffer.get_completed_reports_count() + reports_total = len(message_buffer.report_sections) + + # Build stats parts + stats_parts = [f"Agents: {agents_completed}/{agents_total}"] + + # LLM and tool stats from callback handler + if stats_handler: + stats = stats_handler.get_stats() + stats_parts.append(f"LLM: {stats['llm_calls']}") + stats_parts.append(f"Tools: {stats['tool_calls']}") + + # Token display with graceful fallback + if stats["tokens_in"] > 0 or stats["tokens_out"] > 0: + tokens_str = f"Tokens: {format_tokens(stats['tokens_in'])}\u2191 {format_tokens(stats['tokens_out'])}\u2193" + else: + tokens_str = "Tokens: --" + stats_parts.append(tokens_str) + + stats_parts.append(f"Reports: {reports_completed}/{reports_total}") + + # Elapsed time + if start_time: + elapsed = time.time() - start_time + elapsed_str = f"\u23f1 {int(elapsed // 60):02d}:{int(elapsed % 60):02d}" + stats_parts.append(elapsed_str) stats_table = Table(show_header=False, box=None, padding=(0, 2), expand=True) stats_table.add_column("Stats", justify="center") - stats_table.add_row( - f"Tool Calls: {tool_calls_count} | LLM Calls: {llm_calls_count} | Generated Reports: {reports_count}" - ) + stats_table.add_row(" | ".join(stats_parts)) layout["footer"].update(Panel(stats_table, border_style="grey50")) @@ -803,11 +893,24 @@ def run_analysis(): config["google_thinking_level"] = selections.get("google_thinking_level") config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort") - # Initialize the graph + # Create stats callback handler for tracking LLM/tool calls + stats_handler = StatsCallbackHandler() + + # Initialize the graph with callbacks bound to LLMs graph = TradingAgentsGraph( - [analyst.value for analyst in selections["analysts"]], config=config, debug=True + [analyst.value for analyst in selections["analysts"]], + config=config, + debug=True, + callbacks=[stats_handler], ) + # Initialize message buffer with selected analysts + selected_analyst_keys = [analyst.value for analyst in selections["analysts"]] + message_buffer.init_for_analysis(selected_analyst_keys) + + # Track start time for elapsed display + start_time = time.time() + # Create result directory results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"] results_dir.mkdir(parents=True, exist_ok=True) @@ -860,7 +963,7 @@ def run_analysis(): with Live(layout, refresh_per_second=4) as live: # Initial display - update_display(layout) + update_display(layout, stats_handler=stats_handler, start_time=start_time) # Add initial messages message_buffer.add_message("System", f"Selected ticker: {selections['ticker']}") @@ -871,34 +974,26 @@ def run_analysis(): "System", f"Selected analysts: {', '.join(analyst.value for analyst in selections['analysts'])}", ) - update_display(layout) - - # Reset agent statuses - for agent in message_buffer.agent_status: - message_buffer.update_agent_status(agent, "pending") - - # Reset report sections - for section in message_buffer.report_sections: - message_buffer.report_sections[section] = None - message_buffer.current_report = None - message_buffer.final_report = None + update_display(layout, stats_handler=stats_handler, start_time=start_time) # Update agent status to in_progress for the first analyst first_analyst = f"{selections['analysts'][0].value.capitalize()} Analyst" message_buffer.update_agent_status(first_analyst, "in_progress") - update_display(layout) + update_display(layout, stats_handler=stats_handler, start_time=start_time) # Create spinner text spinner_text = ( f"Analyzing {selections['ticker']} on {selections['analysis_date']}..." ) - update_display(layout, spinner_text) + update_display(layout, spinner_text, stats_handler=stats_handler, start_time=start_time) - # Initialize state and get graph args + # Initialize state and get graph args with callbacks init_agent_state = graph.propagator.create_initial_state( selections["ticker"], selections["analysis_date"] ) - args = graph.propagator.get_graph_args() + # Pass callbacks to graph config for tool execution tracking + # (LLM tracking is handled separately via LLM constructor) + args = graph.propagator.get_graph_args(callbacks=[stats_handler]) # Stream the analysis trace = [] @@ -1112,7 +1207,7 @@ def run_analysis(): ) # Update the display - update_display(layout) + update_display(layout, stats_handler=stats_handler, start_time=start_time) trace.append(chunk) @@ -1136,7 +1231,7 @@ def run_analysis(): # Display the complete final report display_complete_report(final_state) - update_display(layout) + update_display(layout, stats_handler=stats_handler, start_time=start_time) @app.command() diff --git a/cli/stats_handler.py b/cli/stats_handler.py new file mode 100644 index 00000000..10734cc3 --- /dev/null +++ b/cli/stats_handler.py @@ -0,0 +1,76 @@ +import threading +from typing import Any, Dict, List, Union + +from langchain_core.callbacks import BaseCallbackHandler +from langchain_core.outputs import LLMResult +from langchain_core.messages import AIMessage + + +class StatsCallbackHandler(BaseCallbackHandler): + """Callback handler that tracks LLM calls, tool calls, and token usage.""" + + def __init__(self) -> None: + super().__init__() + self._lock = threading.Lock() + self.llm_calls = 0 + self.tool_calls = 0 + self.tokens_in = 0 + self.tokens_out = 0 + + def on_llm_start( + self, + serialized: Dict[str, Any], + prompts: List[str], + **kwargs: Any, + ) -> None: + """Increment LLM call counter when an LLM starts.""" + with self._lock: + self.llm_calls += 1 + + def on_chat_model_start( + self, + serialized: Dict[str, Any], + messages: List[List[Any]], + **kwargs: Any, + ) -> None: + """Increment LLM call counter when a chat model starts.""" + with self._lock: + self.llm_calls += 1 + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Extract token usage from LLM response.""" + try: + generation = response.generations[0][0] + except (IndexError, TypeError): + return + + usage_metadata = None + if hasattr(generation, "message"): + message = generation.message + if isinstance(message, AIMessage) and hasattr(message, "usage_metadata"): + usage_metadata = message.usage_metadata + + if usage_metadata: + with self._lock: + self.tokens_in += usage_metadata.get("input_tokens", 0) + self.tokens_out += usage_metadata.get("output_tokens", 0) + + def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + **kwargs: Any, + ) -> None: + """Increment tool call counter when a tool starts.""" + with self._lock: + self.tool_calls += 1 + + def get_stats(self) -> Dict[str, Any]: + """Return current statistics.""" + with self._lock: + return { + "llm_calls": self.llm_calls, + "tool_calls": self.tool_calls, + "tokens_in": self.tokens_in, + "tokens_out": self.tokens_out, + } diff --git a/tradingagents/graph/propagation.py b/tradingagents/graph/propagation.py index dcc1a5aa..7aba5258 100644 --- a/tradingagents/graph/propagation.py +++ b/tradingagents/graph/propagation.py @@ -1,6 +1,6 @@ # TradingAgents/graph/propagation.py -from typing import Dict, Any +from typing import Dict, Any, List, Optional from tradingagents.agents.utils.agent_states import ( AgentState, InvestDebateState, @@ -41,9 +41,17 @@ class Propagator: "news_report": "", } - def get_graph_args(self) -> Dict[str, Any]: - """Get arguments for the graph invocation.""" + def get_graph_args(self, callbacks: Optional[List] = None) -> Dict[str, Any]: + """Get arguments for the graph invocation. + + Args: + callbacks: Optional list of callback handlers for tool execution tracking. + Note: LLM callbacks are handled separately via LLM constructor. + """ + config = {"recursion_limit": self.max_recur_limit} + if callbacks: + config["callbacks"] = callbacks return { "stream_mode": "values", - "config": {"recursion_limit": self.max_recur_limit}, + "config": config, } diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index d8dff204..44ecca0c 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -48,6 +48,7 @@ class TradingAgentsGraph: selected_analysts=["market", "social", "news", "fundamentals"], debug=False, config: Dict[str, Any] = None, + callbacks: Optional[List] = None, ): """Initialize the trading agents graph and components. @@ -55,9 +56,11 @@ class TradingAgentsGraph: selected_analysts: List of analyst types to include debug: Whether to run in debug mode config: Configuration dictionary. If None, uses default config + callbacks: Optional list of callback handlers (e.g., for tracking LLM/tool stats) """ self.debug = debug self.config = config or DEFAULT_CONFIG + self.callbacks = callbacks or [] # Update the interface's config set_config(self.config) @@ -71,6 +74,10 @@ class TradingAgentsGraph: # Initialize LLMs with provider-specific thinking configuration llm_kwargs = self._get_provider_kwargs() + # Add callbacks to kwargs if provided (passed to LLM constructor) + if self.callbacks: + llm_kwargs["callbacks"] = self.callbacks + deep_client = create_llm_client( provider=self.config["llm_provider"], model=self.config["deep_think_llm"], @@ -83,6 +90,7 @@ class TradingAgentsGraph: base_url=self.config.get("backend_url"), **llm_kwargs, ) + self.deep_thinking_llm = deep_client.get_llm() self.quick_thinking_llm = quick_client.get_llm() diff --git a/tradingagents/llm_clients/anthropic_client.py b/tradingagents/llm_clients/anthropic_client.py index 5fdd9ac2..e2f1abba 100644 --- a/tradingagents/llm_clients/anthropic_client.py +++ b/tradingagents/llm_clients/anthropic_client.py @@ -16,7 +16,7 @@ class AnthropicClient(BaseLLMClient): """Return configured ChatAnthropic instance.""" llm_kwargs = {"model": self.model} - for key in ("timeout", "max_retries", "api_key", "max_tokens"): + for key in ("timeout", "max_retries", "api_key", "max_tokens", "callbacks"): if key in self.kwargs: llm_kwargs[key] = self.kwargs[key] diff --git a/tradingagents/llm_clients/factory.py b/tradingagents/llm_clients/factory.py index e10e83da..028c88a2 100644 --- a/tradingagents/llm_clients/factory.py +++ b/tradingagents/llm_clients/factory.py @@ -4,7 +4,6 @@ from .base_client import BaseLLMClient from .openai_client import OpenAIClient from .anthropic_client import AnthropicClient from .google_client import GoogleClient -from .vllm_client import VLLMClient def create_llm_client( @@ -16,7 +15,7 @@ def create_llm_client( """Create an LLM client for the specified provider. Args: - provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter, vllm) + provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter) model: Model name/identifier base_url: Optional base URL for API endpoint **kwargs: Additional provider-specific arguments @@ -41,7 +40,4 @@ def create_llm_client( if provider_lower == "google": return GoogleClient(model, base_url, **kwargs) - if provider_lower == "vllm": - return VLLMClient(model, base_url, **kwargs) - raise ValueError(f"Unsupported LLM provider: {provider}") diff --git a/tradingagents/llm_clients/google_client.py b/tradingagents/llm_clients/google_client.py index 99f2285c..a1bd386b 100644 --- a/tradingagents/llm_clients/google_client.py +++ b/tradingagents/llm_clients/google_client.py @@ -38,7 +38,7 @@ class GoogleClient(BaseLLMClient): """Return configured ChatGoogleGenerativeAI instance.""" llm_kwargs = {"model": self.model} - for key in ("timeout", "max_retries", "google_api_key"): + for key in ("timeout", "max_retries", "google_api_key", "callbacks"): if key in self.kwargs: llm_kwargs[key] = self.kwargs[key] diff --git a/tradingagents/llm_clients/openai_client.py b/tradingagents/llm_clients/openai_client.py index d957d5c3..1a87f8b5 100644 --- a/tradingagents/llm_clients/openai_client.py +++ b/tradingagents/llm_clients/openai_client.py @@ -60,7 +60,7 @@ class OpenAIClient(BaseLLMClient): elif self.base_url: llm_kwargs["base_url"] = self.base_url - for key in ("timeout", "max_retries", "reasoning_effort", "api_key"): + for key in ("timeout", "max_retries", "reasoning_effort", "api_key", "callbacks"): if key in self.kwargs: llm_kwargs[key] = self.kwargs[key] diff --git a/tradingagents/llm_clients/validators.py b/tradingagents/llm_clients/validators.py index b1d769b0..3c0f2290 100644 --- a/tradingagents/llm_clients/validators.py +++ b/tradingagents/llm_clients/validators.py @@ -69,11 +69,11 @@ VALID_MODELS = { def validate_model(provider: str, model: str) -> bool: """Check if model name is valid for the given provider. - For ollama, openrouter, vllm - any model is accepted. + For ollama, openrouter - any model is accepted. """ provider_lower = provider.lower() - if provider_lower in ("ollama", "openrouter", "vllm"): + if provider_lower in ("ollama", "openrouter"): return True if provider_lower not in VALID_MODELS: diff --git a/tradingagents/llm_clients/vllm_client.py b/tradingagents/llm_clients/vllm_client.py deleted file mode 100644 index a1ebfebf..00000000 --- a/tradingagents/llm_clients/vllm_client.py +++ /dev/null @@ -1,18 +0,0 @@ -from typing import Any, Optional - -from .base_client import BaseLLMClient - - -class VLLMClient(BaseLLMClient): - """Client for vLLM (placeholder for future implementation).""" - - def __init__(self, model: str, base_url: Optional[str] = None, **kwargs): - super().__init__(model, base_url, **kwargs) - - def get_llm(self) -> Any: - """Return configured vLLM instance.""" - raise NotImplementedError("vLLM client not yet implemented") - - def validate_model(self) -> bool: - """Validate model for vLLM.""" - return True