feat: add footer statistics tracking with LangChain callbacks
- Add StatsCallbackHandler for tracking LLM calls, tool calls, and tokens - Integrate callbacks into TradingAgentsGraph and all LLM clients - Dynamic agent/report counts based on selected analysts - Fix report completion counting (tied to agent completion)
This commit is contained in:
parent
b06936f420
commit
54cdb146d0
255
cli/main.py
255
cli/main.py
|
|
@ -15,7 +15,6 @@ from rich.columns import Columns
|
||||||
from rich.markdown import Markdown
|
from rich.markdown import Markdown
|
||||||
from rich.layout import Layout
|
from rich.layout import Layout
|
||||||
from rich.text import Text
|
from rich.text import Text
|
||||||
from rich.live import Live
|
|
||||||
from rich.table import Table
|
from rich.table import Table
|
||||||
from collections import deque
|
from collections import deque
|
||||||
import time
|
import time
|
||||||
|
|
@ -29,6 +28,7 @@ from tradingagents.default_config import DEFAULT_CONFIG
|
||||||
from cli.models import AnalystType
|
from cli.models import AnalystType
|
||||||
from cli.utils import *
|
from cli.utils import *
|
||||||
from cli.announcements import fetch_announcements, display_announcements
|
from cli.announcements import fetch_announcements, display_announcements
|
||||||
|
from cli.stats_handler import StatsCallbackHandler
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
|
|
@ -41,40 +41,99 @@ app = typer.Typer(
|
||||||
|
|
||||||
# Create a deque to store recent messages with a maximum length
|
# Create a deque to store recent messages with a maximum length
|
||||||
class MessageBuffer:
|
class MessageBuffer:
|
||||||
|
# Fixed teams that always run (not user-selectable)
|
||||||
|
FIXED_AGENTS = {
|
||||||
|
"Research Team": ["Bull Researcher", "Bear Researcher", "Research Manager"],
|
||||||
|
"Trading Team": ["Trader"],
|
||||||
|
"Risk Management": ["Aggressive Analyst", "Neutral Analyst", "Conservative Analyst"],
|
||||||
|
"Portfolio Management": ["Portfolio Manager"],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Analyst name mapping
|
||||||
|
ANALYST_MAPPING = {
|
||||||
|
"market": "Market Analyst",
|
||||||
|
"social": "Social Analyst",
|
||||||
|
"news": "News Analyst",
|
||||||
|
"fundamentals": "Fundamentals Analyst",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Report section mapping: section -> (analyst_key for filtering, finalizing_agent)
|
||||||
|
# analyst_key: which analyst selection controls this section (None = always included)
|
||||||
|
# finalizing_agent: which agent must be "completed" for this report to count as done
|
||||||
|
REPORT_SECTIONS = {
|
||||||
|
"market_report": ("market", "Market Analyst"),
|
||||||
|
"sentiment_report": ("social", "Social Analyst"),
|
||||||
|
"news_report": ("news", "News Analyst"),
|
||||||
|
"fundamentals_report": ("fundamentals", "Fundamentals Analyst"),
|
||||||
|
"investment_plan": (None, "Research Manager"),
|
||||||
|
"trader_investment_plan": (None, "Trader"),
|
||||||
|
"final_trade_decision": (None, "Portfolio Manager"),
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(self, max_length=100):
|
def __init__(self, max_length=100):
|
||||||
self.messages = deque(maxlen=max_length)
|
self.messages = deque(maxlen=max_length)
|
||||||
self.tool_calls = deque(maxlen=max_length)
|
self.tool_calls = deque(maxlen=max_length)
|
||||||
self.current_report = None
|
self.current_report = None
|
||||||
self.final_report = None # Store the complete final report
|
self.final_report = None # Store the complete final report
|
||||||
self.agent_status = {
|
self.agent_status = {}
|
||||||
# Analyst Team
|
|
||||||
"Market Analyst": "pending",
|
|
||||||
"Social Analyst": "pending",
|
|
||||||
"News Analyst": "pending",
|
|
||||||
"Fundamentals Analyst": "pending",
|
|
||||||
# Research Team
|
|
||||||
"Bull Researcher": "pending",
|
|
||||||
"Bear Researcher": "pending",
|
|
||||||
"Research Manager": "pending",
|
|
||||||
# Trading Team
|
|
||||||
"Trader": "pending",
|
|
||||||
# Risk Management Team
|
|
||||||
"Aggressive Analyst": "pending",
|
|
||||||
"Neutral Analyst": "pending",
|
|
||||||
"Conservative Analyst": "pending",
|
|
||||||
# Portfolio Management Team
|
|
||||||
"Portfolio Manager": "pending",
|
|
||||||
}
|
|
||||||
self.current_agent = None
|
self.current_agent = None
|
||||||
self.report_sections = {
|
self.report_sections = {}
|
||||||
"market_report": None,
|
self.selected_analysts = []
|
||||||
"sentiment_report": None,
|
|
||||||
"news_report": None,
|
def init_for_analysis(self, selected_analysts):
|
||||||
"fundamentals_report": None,
|
"""Initialize agent status and report sections based on selected analysts.
|
||||||
"investment_plan": None,
|
|
||||||
"trader_investment_plan": None,
|
Args:
|
||||||
"final_trade_decision": None,
|
selected_analysts: List of analyst type strings (e.g., ["market", "news"])
|
||||||
}
|
"""
|
||||||
|
self.selected_analysts = [a.lower() for a in selected_analysts]
|
||||||
|
|
||||||
|
# Build agent_status dynamically
|
||||||
|
self.agent_status = {}
|
||||||
|
|
||||||
|
# Add selected analysts
|
||||||
|
for analyst_key in self.selected_analysts:
|
||||||
|
if analyst_key in self.ANALYST_MAPPING:
|
||||||
|
self.agent_status[self.ANALYST_MAPPING[analyst_key]] = "pending"
|
||||||
|
|
||||||
|
# Add fixed teams
|
||||||
|
for team_agents in self.FIXED_AGENTS.values():
|
||||||
|
for agent in team_agents:
|
||||||
|
self.agent_status[agent] = "pending"
|
||||||
|
|
||||||
|
# Build report_sections dynamically
|
||||||
|
self.report_sections = {}
|
||||||
|
for section, (analyst_key, _) in self.REPORT_SECTIONS.items():
|
||||||
|
if analyst_key is None or analyst_key in self.selected_analysts:
|
||||||
|
self.report_sections[section] = None
|
||||||
|
|
||||||
|
# Reset other state
|
||||||
|
self.current_report = None
|
||||||
|
self.final_report = None
|
||||||
|
self.current_agent = None
|
||||||
|
self.messages.clear()
|
||||||
|
self.tool_calls.clear()
|
||||||
|
|
||||||
|
def get_completed_reports_count(self):
|
||||||
|
"""Count reports that are finalized (their finalizing agent is completed).
|
||||||
|
|
||||||
|
A report is considered complete when:
|
||||||
|
1. The report section has content (not None), AND
|
||||||
|
2. The agent responsible for finalizing that report has status "completed"
|
||||||
|
|
||||||
|
This prevents interim updates (like debate rounds) from counting as completed.
|
||||||
|
"""
|
||||||
|
count = 0
|
||||||
|
for section in self.report_sections:
|
||||||
|
if section not in self.REPORT_SECTIONS:
|
||||||
|
continue
|
||||||
|
_, finalizing_agent = self.REPORT_SECTIONS[section]
|
||||||
|
# Report is complete if it has content AND its finalizing agent is done
|
||||||
|
has_content = self.report_sections.get(section) is not None
|
||||||
|
agent_done = self.agent_status.get(finalizing_agent) == "completed"
|
||||||
|
if has_content and agent_done:
|
||||||
|
count += 1
|
||||||
|
return count
|
||||||
|
|
||||||
def add_message(self, message_type, content):
|
def add_message(self, message_type, content):
|
||||||
timestamp = datetime.datetime.now().strftime("%H:%M:%S")
|
timestamp = datetime.datetime.now().strftime("%H:%M:%S")
|
||||||
|
|
@ -126,46 +185,39 @@ class MessageBuffer:
|
||||||
def _update_final_report(self):
|
def _update_final_report(self):
|
||||||
report_parts = []
|
report_parts = []
|
||||||
|
|
||||||
# Analyst Team Reports
|
# Analyst Team Reports - use .get() to handle missing sections
|
||||||
if any(
|
analyst_sections = ["market_report", "sentiment_report", "news_report", "fundamentals_report"]
|
||||||
self.report_sections[section]
|
if any(self.report_sections.get(section) for section in analyst_sections):
|
||||||
for section in [
|
|
||||||
"market_report",
|
|
||||||
"sentiment_report",
|
|
||||||
"news_report",
|
|
||||||
"fundamentals_report",
|
|
||||||
]
|
|
||||||
):
|
|
||||||
report_parts.append("## Analyst Team Reports")
|
report_parts.append("## Analyst Team Reports")
|
||||||
if self.report_sections["market_report"]:
|
if self.report_sections.get("market_report"):
|
||||||
report_parts.append(
|
report_parts.append(
|
||||||
f"### Market Analysis\n{self.report_sections['market_report']}"
|
f"### Market Analysis\n{self.report_sections['market_report']}"
|
||||||
)
|
)
|
||||||
if self.report_sections["sentiment_report"]:
|
if self.report_sections.get("sentiment_report"):
|
||||||
report_parts.append(
|
report_parts.append(
|
||||||
f"### Social Sentiment\n{self.report_sections['sentiment_report']}"
|
f"### Social Sentiment\n{self.report_sections['sentiment_report']}"
|
||||||
)
|
)
|
||||||
if self.report_sections["news_report"]:
|
if self.report_sections.get("news_report"):
|
||||||
report_parts.append(
|
report_parts.append(
|
||||||
f"### News Analysis\n{self.report_sections['news_report']}"
|
f"### News Analysis\n{self.report_sections['news_report']}"
|
||||||
)
|
)
|
||||||
if self.report_sections["fundamentals_report"]:
|
if self.report_sections.get("fundamentals_report"):
|
||||||
report_parts.append(
|
report_parts.append(
|
||||||
f"### Fundamentals Analysis\n{self.report_sections['fundamentals_report']}"
|
f"### Fundamentals Analysis\n{self.report_sections['fundamentals_report']}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Research Team Reports
|
# Research Team Reports
|
||||||
if self.report_sections["investment_plan"]:
|
if self.report_sections.get("investment_plan"):
|
||||||
report_parts.append("## Research Team Decision")
|
report_parts.append("## Research Team Decision")
|
||||||
report_parts.append(f"{self.report_sections['investment_plan']}")
|
report_parts.append(f"{self.report_sections['investment_plan']}")
|
||||||
|
|
||||||
# Trading Team Reports
|
# Trading Team Reports
|
||||||
if self.report_sections["trader_investment_plan"]:
|
if self.report_sections.get("trader_investment_plan"):
|
||||||
report_parts.append("## Trading Team Plan")
|
report_parts.append("## Trading Team Plan")
|
||||||
report_parts.append(f"{self.report_sections['trader_investment_plan']}")
|
report_parts.append(f"{self.report_sections['trader_investment_plan']}")
|
||||||
|
|
||||||
# Portfolio Management Decision
|
# Portfolio Management Decision
|
||||||
if self.report_sections["final_trade_decision"]:
|
if self.report_sections.get("final_trade_decision"):
|
||||||
report_parts.append("## Portfolio Management Decision")
|
report_parts.append("## Portfolio Management Decision")
|
||||||
report_parts.append(f"{self.report_sections['final_trade_decision']}")
|
report_parts.append(f"{self.report_sections['final_trade_decision']}")
|
||||||
|
|
||||||
|
|
@ -191,7 +243,14 @@ def create_layout():
|
||||||
return layout
|
return layout
|
||||||
|
|
||||||
|
|
||||||
def update_display(layout, spinner_text=None):
|
def format_tokens(n):
|
||||||
|
"""Format token count for display."""
|
||||||
|
if n >= 1000:
|
||||||
|
return f"{n/1000:.1f}k"
|
||||||
|
return str(n)
|
||||||
|
|
||||||
|
|
||||||
|
def update_display(layout, spinner_text=None, stats_handler=None, start_time=None):
|
||||||
# Header with welcome message
|
# Header with welcome message
|
||||||
layout["header"].update(
|
layout["header"].update(
|
||||||
Panel(
|
Panel(
|
||||||
|
|
@ -218,8 +277,8 @@ def update_display(layout, spinner_text=None):
|
||||||
progress_table.add_column("Agent", style="green", justify="center", width=20)
|
progress_table.add_column("Agent", style="green", justify="center", width=20)
|
||||||
progress_table.add_column("Status", style="yellow", justify="center", width=20)
|
progress_table.add_column("Status", style="yellow", justify="center", width=20)
|
||||||
|
|
||||||
# Group agents by team
|
# Group agents by team - filter to only include agents in agent_status
|
||||||
teams = {
|
all_teams = {
|
||||||
"Analyst Team": [
|
"Analyst Team": [
|
||||||
"Market Analyst",
|
"Market Analyst",
|
||||||
"Social Analyst",
|
"Social Analyst",
|
||||||
|
|
@ -232,10 +291,17 @@ def update_display(layout, spinner_text=None):
|
||||||
"Portfolio Management": ["Portfolio Manager"],
|
"Portfolio Management": ["Portfolio Manager"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Filter teams to only include agents that are in agent_status
|
||||||
|
teams = {}
|
||||||
|
for team, agents in all_teams.items():
|
||||||
|
active_agents = [a for a in agents if a in message_buffer.agent_status]
|
||||||
|
if active_agents:
|
||||||
|
teams[team] = active_agents
|
||||||
|
|
||||||
for team, agents in teams.items():
|
for team, agents in teams.items():
|
||||||
# Add first agent with team name
|
# Add first agent with team name
|
||||||
first_agent = agents[0]
|
first_agent = agents[0]
|
||||||
status = message_buffer.agent_status[first_agent]
|
status = message_buffer.agent_status.get(first_agent, "pending")
|
||||||
if status == "in_progress":
|
if status == "in_progress":
|
||||||
spinner = Spinner(
|
spinner = Spinner(
|
||||||
"dots", text="[blue]in_progress[/blue]", style="bold cyan"
|
"dots", text="[blue]in_progress[/blue]", style="bold cyan"
|
||||||
|
|
@ -252,7 +318,7 @@ def update_display(layout, spinner_text=None):
|
||||||
|
|
||||||
# Add remaining agents in team
|
# Add remaining agents in team
|
||||||
for agent in agents[1:]:
|
for agent in agents[1:]:
|
||||||
status = message_buffer.agent_status[agent]
|
status = message_buffer.agent_status.get(agent, "pending")
|
||||||
if status == "in_progress":
|
if status == "in_progress":
|
||||||
spinner = Spinner(
|
spinner = Spinner(
|
||||||
"dots", text="[blue]in_progress[/blue]", style="bold cyan"
|
"dots", text="[blue]in_progress[/blue]", style="bold cyan"
|
||||||
|
|
@ -379,19 +445,43 @@ def update_display(layout, spinner_text=None):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Footer with statistics
|
# Footer with statistics
|
||||||
tool_calls_count = len(message_buffer.tool_calls)
|
# Agent progress - derived from agent_status dict
|
||||||
llm_calls_count = sum(
|
agents_completed = sum(
|
||||||
1 for _, msg_type, _ in message_buffer.messages if msg_type == "Reasoning"
|
1 for status in message_buffer.agent_status.values() if status == "completed"
|
||||||
)
|
|
||||||
reports_count = sum(
|
|
||||||
1 for content in message_buffer.report_sections.values() if content is not None
|
|
||||||
)
|
)
|
||||||
|
agents_total = len(message_buffer.agent_status)
|
||||||
|
|
||||||
|
# Report progress - based on agent completion (not just content existence)
|
||||||
|
reports_completed = message_buffer.get_completed_reports_count()
|
||||||
|
reports_total = len(message_buffer.report_sections)
|
||||||
|
|
||||||
|
# Build stats parts
|
||||||
|
stats_parts = [f"Agents: {agents_completed}/{agents_total}"]
|
||||||
|
|
||||||
|
# LLM and tool stats from callback handler
|
||||||
|
if stats_handler:
|
||||||
|
stats = stats_handler.get_stats()
|
||||||
|
stats_parts.append(f"LLM: {stats['llm_calls']}")
|
||||||
|
stats_parts.append(f"Tools: {stats['tool_calls']}")
|
||||||
|
|
||||||
|
# Token display with graceful fallback
|
||||||
|
if stats["tokens_in"] > 0 or stats["tokens_out"] > 0:
|
||||||
|
tokens_str = f"Tokens: {format_tokens(stats['tokens_in'])}\u2191 {format_tokens(stats['tokens_out'])}\u2193"
|
||||||
|
else:
|
||||||
|
tokens_str = "Tokens: --"
|
||||||
|
stats_parts.append(tokens_str)
|
||||||
|
|
||||||
|
stats_parts.append(f"Reports: {reports_completed}/{reports_total}")
|
||||||
|
|
||||||
|
# Elapsed time
|
||||||
|
if start_time:
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
elapsed_str = f"\u23f1 {int(elapsed // 60):02d}:{int(elapsed % 60):02d}"
|
||||||
|
stats_parts.append(elapsed_str)
|
||||||
|
|
||||||
stats_table = Table(show_header=False, box=None, padding=(0, 2), expand=True)
|
stats_table = Table(show_header=False, box=None, padding=(0, 2), expand=True)
|
||||||
stats_table.add_column("Stats", justify="center")
|
stats_table.add_column("Stats", justify="center")
|
||||||
stats_table.add_row(
|
stats_table.add_row(" | ".join(stats_parts))
|
||||||
f"Tool Calls: {tool_calls_count} | LLM Calls: {llm_calls_count} | Generated Reports: {reports_count}"
|
|
||||||
)
|
|
||||||
|
|
||||||
layout["footer"].update(Panel(stats_table, border_style="grey50"))
|
layout["footer"].update(Panel(stats_table, border_style="grey50"))
|
||||||
|
|
||||||
|
|
@ -803,11 +893,24 @@ def run_analysis():
|
||||||
config["google_thinking_level"] = selections.get("google_thinking_level")
|
config["google_thinking_level"] = selections.get("google_thinking_level")
|
||||||
config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort")
|
config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort")
|
||||||
|
|
||||||
# Initialize the graph
|
# Create stats callback handler for tracking LLM/tool calls
|
||||||
|
stats_handler = StatsCallbackHandler()
|
||||||
|
|
||||||
|
# Initialize the graph with callbacks bound to LLMs
|
||||||
graph = TradingAgentsGraph(
|
graph = TradingAgentsGraph(
|
||||||
[analyst.value for analyst in selections["analysts"]], config=config, debug=True
|
[analyst.value for analyst in selections["analysts"]],
|
||||||
|
config=config,
|
||||||
|
debug=True,
|
||||||
|
callbacks=[stats_handler],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Initialize message buffer with selected analysts
|
||||||
|
selected_analyst_keys = [analyst.value for analyst in selections["analysts"]]
|
||||||
|
message_buffer.init_for_analysis(selected_analyst_keys)
|
||||||
|
|
||||||
|
# Track start time for elapsed display
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
# Create result directory
|
# Create result directory
|
||||||
results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"]
|
results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"]
|
||||||
results_dir.mkdir(parents=True, exist_ok=True)
|
results_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
@ -860,7 +963,7 @@ def run_analysis():
|
||||||
|
|
||||||
with Live(layout, refresh_per_second=4) as live:
|
with Live(layout, refresh_per_second=4) as live:
|
||||||
# Initial display
|
# Initial display
|
||||||
update_display(layout)
|
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
||||||
|
|
||||||
# Add initial messages
|
# Add initial messages
|
||||||
message_buffer.add_message("System", f"Selected ticker: {selections['ticker']}")
|
message_buffer.add_message("System", f"Selected ticker: {selections['ticker']}")
|
||||||
|
|
@ -871,34 +974,26 @@ def run_analysis():
|
||||||
"System",
|
"System",
|
||||||
f"Selected analysts: {', '.join(analyst.value for analyst in selections['analysts'])}",
|
f"Selected analysts: {', '.join(analyst.value for analyst in selections['analysts'])}",
|
||||||
)
|
)
|
||||||
update_display(layout)
|
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
||||||
|
|
||||||
# Reset agent statuses
|
|
||||||
for agent in message_buffer.agent_status:
|
|
||||||
message_buffer.update_agent_status(agent, "pending")
|
|
||||||
|
|
||||||
# Reset report sections
|
|
||||||
for section in message_buffer.report_sections:
|
|
||||||
message_buffer.report_sections[section] = None
|
|
||||||
message_buffer.current_report = None
|
|
||||||
message_buffer.final_report = None
|
|
||||||
|
|
||||||
# Update agent status to in_progress for the first analyst
|
# Update agent status to in_progress for the first analyst
|
||||||
first_analyst = f"{selections['analysts'][0].value.capitalize()} Analyst"
|
first_analyst = f"{selections['analysts'][0].value.capitalize()} Analyst"
|
||||||
message_buffer.update_agent_status(first_analyst, "in_progress")
|
message_buffer.update_agent_status(first_analyst, "in_progress")
|
||||||
update_display(layout)
|
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
||||||
|
|
||||||
# Create spinner text
|
# Create spinner text
|
||||||
spinner_text = (
|
spinner_text = (
|
||||||
f"Analyzing {selections['ticker']} on {selections['analysis_date']}..."
|
f"Analyzing {selections['ticker']} on {selections['analysis_date']}..."
|
||||||
)
|
)
|
||||||
update_display(layout, spinner_text)
|
update_display(layout, spinner_text, stats_handler=stats_handler, start_time=start_time)
|
||||||
|
|
||||||
# Initialize state and get graph args
|
# Initialize state and get graph args with callbacks
|
||||||
init_agent_state = graph.propagator.create_initial_state(
|
init_agent_state = graph.propagator.create_initial_state(
|
||||||
selections["ticker"], selections["analysis_date"]
|
selections["ticker"], selections["analysis_date"]
|
||||||
)
|
)
|
||||||
args = graph.propagator.get_graph_args()
|
# Pass callbacks to graph config for tool execution tracking
|
||||||
|
# (LLM tracking is handled separately via LLM constructor)
|
||||||
|
args = graph.propagator.get_graph_args(callbacks=[stats_handler])
|
||||||
|
|
||||||
# Stream the analysis
|
# Stream the analysis
|
||||||
trace = []
|
trace = []
|
||||||
|
|
@ -1112,7 +1207,7 @@ def run_analysis():
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update the display
|
# Update the display
|
||||||
update_display(layout)
|
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
||||||
|
|
||||||
trace.append(chunk)
|
trace.append(chunk)
|
||||||
|
|
||||||
|
|
@ -1136,7 +1231,7 @@ def run_analysis():
|
||||||
# Display the complete final report
|
# Display the complete final report
|
||||||
display_complete_report(final_state)
|
display_complete_report(final_state)
|
||||||
|
|
||||||
update_display(layout)
|
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,76 @@
|
||||||
|
import threading
|
||||||
|
from typing import Any, Dict, List, Union
|
||||||
|
|
||||||
|
from langchain_core.callbacks import BaseCallbackHandler
|
||||||
|
from langchain_core.outputs import LLMResult
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
|
||||||
|
|
||||||
|
class StatsCallbackHandler(BaseCallbackHandler):
|
||||||
|
"""Callback handler that tracks LLM calls, tool calls, and token usage."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
self.llm_calls = 0
|
||||||
|
self.tool_calls = 0
|
||||||
|
self.tokens_in = 0
|
||||||
|
self.tokens_out = 0
|
||||||
|
|
||||||
|
def on_llm_start(
|
||||||
|
self,
|
||||||
|
serialized: Dict[str, Any],
|
||||||
|
prompts: List[str],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
"""Increment LLM call counter when an LLM starts."""
|
||||||
|
with self._lock:
|
||||||
|
self.llm_calls += 1
|
||||||
|
|
||||||
|
def on_chat_model_start(
|
||||||
|
self,
|
||||||
|
serialized: Dict[str, Any],
|
||||||
|
messages: List[List[Any]],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
"""Increment LLM call counter when a chat model starts."""
|
||||||
|
with self._lock:
|
||||||
|
self.llm_calls += 1
|
||||||
|
|
||||||
|
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||||
|
"""Extract token usage from LLM response."""
|
||||||
|
try:
|
||||||
|
generation = response.generations[0][0]
|
||||||
|
except (IndexError, TypeError):
|
||||||
|
return
|
||||||
|
|
||||||
|
usage_metadata = None
|
||||||
|
if hasattr(generation, "message"):
|
||||||
|
message = generation.message
|
||||||
|
if isinstance(message, AIMessage) and hasattr(message, "usage_metadata"):
|
||||||
|
usage_metadata = message.usage_metadata
|
||||||
|
|
||||||
|
if usage_metadata:
|
||||||
|
with self._lock:
|
||||||
|
self.tokens_in += usage_metadata.get("input_tokens", 0)
|
||||||
|
self.tokens_out += usage_metadata.get("output_tokens", 0)
|
||||||
|
|
||||||
|
def on_tool_start(
|
||||||
|
self,
|
||||||
|
serialized: Dict[str, Any],
|
||||||
|
input_str: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
"""Increment tool call counter when a tool starts."""
|
||||||
|
with self._lock:
|
||||||
|
self.tool_calls += 1
|
||||||
|
|
||||||
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
|
"""Return current statistics."""
|
||||||
|
with self._lock:
|
||||||
|
return {
|
||||||
|
"llm_calls": self.llm_calls,
|
||||||
|
"tool_calls": self.tool_calls,
|
||||||
|
"tokens_in": self.tokens_in,
|
||||||
|
"tokens_out": self.tokens_out,
|
||||||
|
}
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
# TradingAgents/graph/propagation.py
|
# TradingAgents/graph/propagation.py
|
||||||
|
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any, List, Optional
|
||||||
from tradingagents.agents.utils.agent_states import (
|
from tradingagents.agents.utils.agent_states import (
|
||||||
AgentState,
|
AgentState,
|
||||||
InvestDebateState,
|
InvestDebateState,
|
||||||
|
|
@ -41,9 +41,17 @@ class Propagator:
|
||||||
"news_report": "",
|
"news_report": "",
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_graph_args(self) -> Dict[str, Any]:
|
def get_graph_args(self, callbacks: Optional[List] = None) -> Dict[str, Any]:
|
||||||
"""Get arguments for the graph invocation."""
|
"""Get arguments for the graph invocation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
callbacks: Optional list of callback handlers for tool execution tracking.
|
||||||
|
Note: LLM callbacks are handled separately via LLM constructor.
|
||||||
|
"""
|
||||||
|
config = {"recursion_limit": self.max_recur_limit}
|
||||||
|
if callbacks:
|
||||||
|
config["callbacks"] = callbacks
|
||||||
return {
|
return {
|
||||||
"stream_mode": "values",
|
"stream_mode": "values",
|
||||||
"config": {"recursion_limit": self.max_recur_limit},
|
"config": config,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -48,6 +48,7 @@ class TradingAgentsGraph:
|
||||||
selected_analysts=["market", "social", "news", "fundamentals"],
|
selected_analysts=["market", "social", "news", "fundamentals"],
|
||||||
debug=False,
|
debug=False,
|
||||||
config: Dict[str, Any] = None,
|
config: Dict[str, Any] = None,
|
||||||
|
callbacks: Optional[List] = None,
|
||||||
):
|
):
|
||||||
"""Initialize the trading agents graph and components.
|
"""Initialize the trading agents graph and components.
|
||||||
|
|
||||||
|
|
@ -55,9 +56,11 @@ class TradingAgentsGraph:
|
||||||
selected_analysts: List of analyst types to include
|
selected_analysts: List of analyst types to include
|
||||||
debug: Whether to run in debug mode
|
debug: Whether to run in debug mode
|
||||||
config: Configuration dictionary. If None, uses default config
|
config: Configuration dictionary. If None, uses default config
|
||||||
|
callbacks: Optional list of callback handlers (e.g., for tracking LLM/tool stats)
|
||||||
"""
|
"""
|
||||||
self.debug = debug
|
self.debug = debug
|
||||||
self.config = config or DEFAULT_CONFIG
|
self.config = config or DEFAULT_CONFIG
|
||||||
|
self.callbacks = callbacks or []
|
||||||
|
|
||||||
# Update the interface's config
|
# Update the interface's config
|
||||||
set_config(self.config)
|
set_config(self.config)
|
||||||
|
|
@ -71,6 +74,10 @@ class TradingAgentsGraph:
|
||||||
# Initialize LLMs with provider-specific thinking configuration
|
# Initialize LLMs with provider-specific thinking configuration
|
||||||
llm_kwargs = self._get_provider_kwargs()
|
llm_kwargs = self._get_provider_kwargs()
|
||||||
|
|
||||||
|
# Add callbacks to kwargs if provided (passed to LLM constructor)
|
||||||
|
if self.callbacks:
|
||||||
|
llm_kwargs["callbacks"] = self.callbacks
|
||||||
|
|
||||||
deep_client = create_llm_client(
|
deep_client = create_llm_client(
|
||||||
provider=self.config["llm_provider"],
|
provider=self.config["llm_provider"],
|
||||||
model=self.config["deep_think_llm"],
|
model=self.config["deep_think_llm"],
|
||||||
|
|
@ -83,6 +90,7 @@ class TradingAgentsGraph:
|
||||||
base_url=self.config.get("backend_url"),
|
base_url=self.config.get("backend_url"),
|
||||||
**llm_kwargs,
|
**llm_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.deep_thinking_llm = deep_client.get_llm()
|
self.deep_thinking_llm = deep_client.get_llm()
|
||||||
self.quick_thinking_llm = quick_client.get_llm()
|
self.quick_thinking_llm = quick_client.get_llm()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ class AnthropicClient(BaseLLMClient):
|
||||||
"""Return configured ChatAnthropic instance."""
|
"""Return configured ChatAnthropic instance."""
|
||||||
llm_kwargs = {"model": self.model}
|
llm_kwargs = {"model": self.model}
|
||||||
|
|
||||||
for key in ("timeout", "max_retries", "api_key", "max_tokens"):
|
for key in ("timeout", "max_retries", "api_key", "max_tokens", "callbacks"):
|
||||||
if key in self.kwargs:
|
if key in self.kwargs:
|
||||||
llm_kwargs[key] = self.kwargs[key]
|
llm_kwargs[key] = self.kwargs[key]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@ from .base_client import BaseLLMClient
|
||||||
from .openai_client import OpenAIClient
|
from .openai_client import OpenAIClient
|
||||||
from .anthropic_client import AnthropicClient
|
from .anthropic_client import AnthropicClient
|
||||||
from .google_client import GoogleClient
|
from .google_client import GoogleClient
|
||||||
from .vllm_client import VLLMClient
|
|
||||||
|
|
||||||
|
|
||||||
def create_llm_client(
|
def create_llm_client(
|
||||||
|
|
@ -16,7 +15,7 @@ def create_llm_client(
|
||||||
"""Create an LLM client for the specified provider.
|
"""Create an LLM client for the specified provider.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter, vllm)
|
provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter)
|
||||||
model: Model name/identifier
|
model: Model name/identifier
|
||||||
base_url: Optional base URL for API endpoint
|
base_url: Optional base URL for API endpoint
|
||||||
**kwargs: Additional provider-specific arguments
|
**kwargs: Additional provider-specific arguments
|
||||||
|
|
@ -41,7 +40,4 @@ def create_llm_client(
|
||||||
if provider_lower == "google":
|
if provider_lower == "google":
|
||||||
return GoogleClient(model, base_url, **kwargs)
|
return GoogleClient(model, base_url, **kwargs)
|
||||||
|
|
||||||
if provider_lower == "vllm":
|
|
||||||
return VLLMClient(model, base_url, **kwargs)
|
|
||||||
|
|
||||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,7 @@ class GoogleClient(BaseLLMClient):
|
||||||
"""Return configured ChatGoogleGenerativeAI instance."""
|
"""Return configured ChatGoogleGenerativeAI instance."""
|
||||||
llm_kwargs = {"model": self.model}
|
llm_kwargs = {"model": self.model}
|
||||||
|
|
||||||
for key in ("timeout", "max_retries", "google_api_key"):
|
for key in ("timeout", "max_retries", "google_api_key", "callbacks"):
|
||||||
if key in self.kwargs:
|
if key in self.kwargs:
|
||||||
llm_kwargs[key] = self.kwargs[key]
|
llm_kwargs[key] = self.kwargs[key]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -60,7 +60,7 @@ class OpenAIClient(BaseLLMClient):
|
||||||
elif self.base_url:
|
elif self.base_url:
|
||||||
llm_kwargs["base_url"] = self.base_url
|
llm_kwargs["base_url"] = self.base_url
|
||||||
|
|
||||||
for key in ("timeout", "max_retries", "reasoning_effort", "api_key"):
|
for key in ("timeout", "max_retries", "reasoning_effort", "api_key", "callbacks"):
|
||||||
if key in self.kwargs:
|
if key in self.kwargs:
|
||||||
llm_kwargs[key] = self.kwargs[key]
|
llm_kwargs[key] = self.kwargs[key]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -69,11 +69,11 @@ VALID_MODELS = {
|
||||||
def validate_model(provider: str, model: str) -> bool:
|
def validate_model(provider: str, model: str) -> bool:
|
||||||
"""Check if model name is valid for the given provider.
|
"""Check if model name is valid for the given provider.
|
||||||
|
|
||||||
For ollama, openrouter, vllm - any model is accepted.
|
For ollama, openrouter - any model is accepted.
|
||||||
"""
|
"""
|
||||||
provider_lower = provider.lower()
|
provider_lower = provider.lower()
|
||||||
|
|
||||||
if provider_lower in ("ollama", "openrouter", "vllm"):
|
if provider_lower in ("ollama", "openrouter"):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if provider_lower not in VALID_MODELS:
|
if provider_lower not in VALID_MODELS:
|
||||||
|
|
|
||||||
|
|
@ -1,18 +0,0 @@
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
from .base_client import BaseLLMClient
|
|
||||||
|
|
||||||
|
|
||||||
class VLLMClient(BaseLLMClient):
|
|
||||||
"""Client for vLLM (placeholder for future implementation)."""
|
|
||||||
|
|
||||||
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
|
|
||||||
super().__init__(model, base_url, **kwargs)
|
|
||||||
|
|
||||||
def get_llm(self) -> Any:
|
|
||||||
"""Return configured vLLM instance."""
|
|
||||||
raise NotImplementedError("vLLM client not yet implemented")
|
|
||||||
|
|
||||||
def validate_model(self) -> bool:
|
|
||||||
"""Validate model for vLLM."""
|
|
||||||
return True
|
|
||||||
Loading…
Reference in New Issue