Improve CLI report tracking, modularity, and test resilience

This commit is contained in:
basepoint 2026-03-22 22:59:47 +00:00
parent c3ba3bf428
commit 59f17e6ecd
6 changed files with 266 additions and 229 deletions

View File

@ -16,7 +16,6 @@ from rich.markdown import Markdown
from rich.layout import Layout
from rich.text import Text
from rich.table import Table
from collections import deque
import time
from rich.tree import Tree
from rich import box
@ -29,6 +28,7 @@ from cli.models import AnalystType
from cli.utils import *
from cli.announcements import fetch_announcements, display_announcements
from cli.stats_handler import StatsCallbackHandler
from cli.message_buffer import MessageBuffer
console = Console()
@ -39,193 +39,6 @@ app = typer.Typer(
)
# Create a deque to store recent messages with a maximum length
class MessageBuffer:
# Fixed teams that always run (not user-selectable)
FIXED_AGENTS = {
"Research Team": ["Bull Researcher", "Bear Researcher", "Research Manager"],
"Trading Team": ["Trader"],
"Risk Management": ["Aggressive Analyst", "Neutral Analyst", "Conservative Analyst"],
"Portfolio Management": ["Portfolio Manager"],
}
# Analyst name mapping
ANALYST_MAPPING = {
"market": "Market Analyst",
"social": "Social Analyst",
"news": "News Analyst",
"fundamentals": "Fundamentals Analyst",
}
# Report section mapping: section -> (analyst_key for filtering, finalizing_agent)
# analyst_key: which analyst selection controls this section (None = always included)
# finalizing_agent: which agent must be "completed" for this report to count as done
REPORT_SECTIONS = {
"market_report": ("market", "Market Analyst"),
"sentiment_report": ("social", "Social Analyst"),
"news_report": ("news", "News Analyst"),
"fundamentals_report": ("fundamentals", "Fundamentals Analyst"),
"investment_plan": (None, "Research Manager"),
"trader_investment_plan": (None, "Trader"),
"final_trade_decision": (None, "Portfolio Manager"),
}
def __init__(self, max_length=100):
self.messages = deque(maxlen=max_length)
self.tool_calls = deque(maxlen=max_length)
self.current_report = None
self.final_report = None # Store the complete final report
self.agent_status = {}
self.current_agent = None
self.report_sections = {}
self.selected_analysts = []
self._last_message_id = None
def init_for_analysis(self, selected_analysts):
"""Initialize agent status and report sections based on selected analysts.
Args:
selected_analysts: List of analyst type strings (e.g., ["market", "news"])
"""
self.selected_analysts = [a.lower() for a in selected_analysts]
# Build agent_status dynamically
self.agent_status = {}
# Add selected analysts
for analyst_key in self.selected_analysts:
if analyst_key in self.ANALYST_MAPPING:
self.agent_status[self.ANALYST_MAPPING[analyst_key]] = "pending"
# Add fixed teams
for team_agents in self.FIXED_AGENTS.values():
for agent in team_agents:
self.agent_status[agent] = "pending"
# Build report_sections dynamically
self.report_sections = {}
for section, (analyst_key, _) in self.REPORT_SECTIONS.items():
if analyst_key is None or analyst_key in self.selected_analysts:
self.report_sections[section] = None
# Reset other state
self.current_report = None
self.final_report = None
self.current_agent = None
self.messages.clear()
self.tool_calls.clear()
self._last_message_id = None
def get_completed_reports_count(self):
"""Count reports that are finalized (their finalizing agent is completed).
A report is considered complete when:
1. The report section has content (not None), AND
2. The agent responsible for finalizing that report has status "completed"
This prevents interim updates (like debate rounds) from counting as completed.
"""
count = 0
for section in self.report_sections:
if section not in self.REPORT_SECTIONS:
continue
_, finalizing_agent = self.REPORT_SECTIONS[section]
# Report is complete if it has content AND its finalizing agent is done
has_content = self.report_sections.get(section) is not None
agent_done = self.agent_status.get(finalizing_agent) == "completed"
if has_content and agent_done:
count += 1
return count
def add_message(self, message_type, content):
timestamp = datetime.datetime.now().strftime("%H:%M:%S")
self.messages.append((timestamp, message_type, content))
def add_tool_call(self, tool_name, args):
timestamp = datetime.datetime.now().strftime("%H:%M:%S")
self.tool_calls.append((timestamp, tool_name, args))
def update_agent_status(self, agent, status):
if agent in self.agent_status:
self.agent_status[agent] = status
self.current_agent = agent
def update_report_section(self, section_name, content):
if section_name in self.report_sections:
self.report_sections[section_name] = content
self._update_current_report()
def _update_current_report(self):
# For the panel display, only show the most recently updated section
latest_section = None
latest_content = None
# Find the most recently updated section
for section, content in self.report_sections.items():
if content is not None:
latest_section = section
latest_content = content
if latest_section and latest_content:
# Format the current section for display
section_titles = {
"market_report": "Market Analysis",
"sentiment_report": "Social Sentiment",
"news_report": "News Analysis",
"fundamentals_report": "Fundamentals Analysis",
"investment_plan": "Research Team Decision",
"trader_investment_plan": "Trading Team Plan",
"final_trade_decision": "Portfolio Management Decision",
}
self.current_report = (
f"### {section_titles[latest_section]}\n{latest_content}"
)
# Update the final complete report
self._update_final_report()
def _update_final_report(self):
report_parts = []
# Analyst Team Reports - use .get() to handle missing sections
analyst_sections = ["market_report", "sentiment_report", "news_report", "fundamentals_report"]
if any(self.report_sections.get(section) for section in analyst_sections):
report_parts.append("## Analyst Team Reports")
if self.report_sections.get("market_report"):
report_parts.append(
f"### Market Analysis\n{self.report_sections['market_report']}"
)
if self.report_sections.get("sentiment_report"):
report_parts.append(
f"### Social Sentiment\n{self.report_sections['sentiment_report']}"
)
if self.report_sections.get("news_report"):
report_parts.append(
f"### News Analysis\n{self.report_sections['news_report']}"
)
if self.report_sections.get("fundamentals_report"):
report_parts.append(
f"### Fundamentals Analysis\n{self.report_sections['fundamentals_report']}"
)
# Research Team Reports
if self.report_sections.get("investment_plan"):
report_parts.append("## Research Team Decision")
report_parts.append(f"{self.report_sections['investment_plan']}")
# Trading Team Reports
if self.report_sections.get("trader_investment_plan"):
report_parts.append("## Trading Team Plan")
report_parts.append(f"{self.report_sections['trader_investment_plan']}")
# Portfolio Management Decision
if self.report_sections.get("final_trade_decision"):
report_parts.append("## Portfolio Management Decision")
report_parts.append(f"{self.report_sections['final_trade_decision']}")
self.final_report = "\n\n".join(report_parts) if report_parts else None
message_buffer = MessageBuffer()
@ -506,7 +319,7 @@ def get_user_selections():
"SPY",
)
)
selected_ticker = get_ticker()
selected_ticker = get_ticker(prompt_text="")
# Step 2: Analysis date
default_date = datetime.datetime.now().strftime("%Y-%m-%d")
@ -517,7 +330,7 @@ def get_user_selections():
default_date,
)
)
analysis_date = get_analysis_date()
analysis_date = get_analysis_date(prompt_text="")
# Step 3: Select analysts
console.print(
@ -538,10 +351,10 @@ def get_user_selections():
)
selected_research_depth = select_research_depth()
# Step 5: OpenAI backend
# Step 5: LLM provider backend
console.print(
create_question_box(
"Step 5: OpenAI backend", "Select which service to talk to"
"Step 5: LLM Provider", "Select which service to talk to"
)
)
selected_llm_provider, backend_url = select_llm_provider()
@ -601,30 +414,6 @@ def get_user_selections():
}
def get_ticker():
"""Get ticker symbol from user input."""
return typer.prompt("", default="SPY")
def get_analysis_date():
"""Get the analysis date from user input."""
while True:
date_str = typer.prompt(
"", default=datetime.datetime.now().strftime("%Y-%m-%d")
)
try:
# Validate date format and ensure it's not in the future
analysis_date = datetime.datetime.strptime(date_str, "%Y-%m-%d")
if analysis_date.date() > datetime.datetime.now().date():
console.print("[red]Error: Analysis date cannot be in the future[/red]")
continue
return date_str
except ValueError:
console.print(
"[red]Error: Invalid date format. Please use YYYY-MM-DD[/red]"
)
def save_report_to_disk(final_state, ticker: str, save_path: Path):
"""Save complete analysis report to disk with organized subfolders."""
save_path.mkdir(parents=True, exist_ok=True)
@ -970,8 +759,11 @@ def run_analysis():
@wraps(func)
def wrapper(*args, **kwargs):
func(*args, **kwargs)
timestamp, tool_name, args = obj.tool_calls[-1]
args_str = ", ".join(f"{k}={v}" for k, v in args.items())
timestamp, tool_name, tool_args = obj.tool_calls[-1]
if isinstance(tool_args, dict):
args_str = ", ".join(f"{k}={v}" for k, v in tool_args.items())
else:
args_str = str(tool_args)
with open(log_file, "a", encoding="utf-8") as f:
f.write(f"{timestamp} [Tool Call] {tool_name}({args_str})\n")
return wrapper

192
cli/message_buffer.py Normal file
View File

@ -0,0 +1,192 @@
from collections import deque
import datetime
class MessageBuffer:
# Fixed teams that always run (not user-selectable)
FIXED_AGENTS = {
"Research Team": ["Bull Researcher", "Bear Researcher", "Research Manager"],
"Trading Team": ["Trader"],
"Risk Management": ["Aggressive Analyst", "Neutral Analyst", "Conservative Analyst"],
"Portfolio Management": ["Portfolio Manager"],
}
# Analyst name mapping
ANALYST_MAPPING = {
"market": "Market Analyst",
"social": "Social Analyst",
"news": "News Analyst",
"fundamentals": "Fundamentals Analyst",
}
# Report section mapping: section -> (analyst_key for filtering, finalizing_agent)
# analyst_key: which analyst selection controls this section (None = always included)
# finalizing_agent: which agent must be "completed" for this report to count as done
REPORT_SECTIONS = {
"market_report": ("market", "Market Analyst"),
"sentiment_report": ("social", "Social Analyst"),
"news_report": ("news", "News Analyst"),
"fundamentals_report": ("fundamentals", "Fundamentals Analyst"),
"investment_plan": (None, "Research Manager"),
"trader_investment_plan": (None, "Trader"),
"final_trade_decision": (None, "Portfolio Manager"),
}
def __init__(self, max_length=100):
self.messages = deque(maxlen=max_length)
self.tool_calls = deque(maxlen=max_length)
self.current_report = None
self.final_report = None # Store the complete final report
self.agent_status = {}
self.current_agent = None
self.report_sections = {}
self.selected_analysts = []
self._last_message_id = None
self._last_updated_section = None
def init_for_analysis(self, selected_analysts):
"""Initialize agent status and report sections based on selected analysts.
Args:
selected_analysts: List of analyst type strings (e.g., ["market", "news"])
"""
self.selected_analysts = [a.lower() for a in selected_analysts]
# Build agent_status dynamically
self.agent_status = {}
# Add selected analysts
for analyst_key in self.selected_analysts:
if analyst_key in self.ANALYST_MAPPING:
self.agent_status[self.ANALYST_MAPPING[analyst_key]] = "pending"
# Add fixed teams
for team_agents in self.FIXED_AGENTS.values():
for agent in team_agents:
self.agent_status[agent] = "pending"
# Build report_sections dynamically
self.report_sections = {}
for section, (analyst_key, _) in self.REPORT_SECTIONS.items():
if analyst_key is None or analyst_key in self.selected_analysts:
self.report_sections[section] = None
# Reset other state
self.current_report = None
self.final_report = None
self.current_agent = None
self.messages.clear()
self.tool_calls.clear()
self._last_message_id = None
self._last_updated_section = None
def get_completed_reports_count(self):
"""Count reports that are finalized (their finalizing agent is completed).
A report is considered complete when:
1. The report section has content (not None), AND
2. The agent responsible for finalizing that report has status "completed"
This prevents interim updates (like debate rounds) from counting as completed.
"""
count = 0
for section in self.report_sections:
if section not in self.REPORT_SECTIONS:
continue
_, finalizing_agent = self.REPORT_SECTIONS[section]
# Report is complete if it has content AND its finalizing agent is done
has_content = self.report_sections.get(section) is not None
agent_done = self.agent_status.get(finalizing_agent) == "completed"
if has_content and agent_done:
count += 1
return count
def add_message(self, message_type, content):
timestamp = datetime.datetime.now().strftime("%H:%M:%S")
self.messages.append((timestamp, message_type, content))
def add_tool_call(self, tool_name, args):
timestamp = datetime.datetime.now().strftime("%H:%M:%S")
self.tool_calls.append((timestamp, tool_name, args))
def update_agent_status(self, agent, status):
if agent in self.agent_status:
self.agent_status[agent] = status
self.current_agent = agent
def update_report_section(self, section_name, content):
if section_name in self.report_sections:
self.report_sections[section_name] = content
self._last_updated_section = section_name
self._update_current_report()
def _update_current_report(self):
# For the panel display, only show the most recently updated section
latest_section = self._last_updated_section
latest_content = self.report_sections.get(latest_section) if latest_section else None
# Fallback if section tracking is unavailable
if latest_content is None:
for section, content in self.report_sections.items():
if content is not None:
latest_section = section
latest_content = content
if latest_section and latest_content:
# Format the current section for display
section_titles = {
"market_report": "Market Analysis",
"sentiment_report": "Social Sentiment",
"news_report": "News Analysis",
"fundamentals_report": "Fundamentals Analysis",
"investment_plan": "Research Team Decision",
"trader_investment_plan": "Trading Team Plan",
"final_trade_decision": "Portfolio Management Decision",
}
self.current_report = (
f"### {section_titles[latest_section]}\n{latest_content}"
)
# Update the final complete report
self._update_final_report()
def _update_final_report(self):
report_parts = []
# Analyst Team Reports - use .get() to handle missing sections
analyst_sections = ["market_report", "sentiment_report", "news_report", "fundamentals_report"]
if any(self.report_sections.get(section) for section in analyst_sections):
report_parts.append("## Analyst Team Reports")
if self.report_sections.get("market_report"):
report_parts.append(
f"### Market Analysis\n{self.report_sections['market_report']}"
)
if self.report_sections.get("sentiment_report"):
report_parts.append(
f"### Social Sentiment\n{self.report_sections['sentiment_report']}"
)
if self.report_sections.get("news_report"):
report_parts.append(
f"### News Analysis\n{self.report_sections['news_report']}"
)
if self.report_sections.get("fundamentals_report"):
report_parts.append(
f"### Fundamentals Analysis\n{self.report_sections['fundamentals_report']}"
)
# Research Team Reports
if self.report_sections.get("investment_plan"):
report_parts.append("## Research Team Decision")
report_parts.append(f"{self.report_sections['investment_plan']}")
# Trading Team Reports
if self.report_sections.get("trader_investment_plan"):
report_parts.append("## Trading Team Plan")
report_parts.append(f"{self.report_sections['trader_investment_plan']}")
# Portfolio Management Decision
if self.report_sections.get("final_trade_decision"):
report_parts.append("## Portfolio Management Decision")
report_parts.append(f"{self.report_sections['final_trade_decision']}")
self.final_report = "\n\n".join(report_parts) if report_parts else None

View File

@ -1,6 +1,4 @@
from enum import Enum
from typing import List, Optional, Dict
from pydantic import BaseModel
class AnalystType(str, Enum):

View File

@ -1,12 +1,23 @@
import questionary
from typing import List, Optional, Tuple, Dict
try:
import questionary
except ImportError: # pragma: no cover - optional during non-interactive testing
questionary = None
from rich.console import Console
from cli.models import AnalystType
console = Console()
def _ensure_questionary():
if questionary is None:
raise RuntimeError(
"questionary is required for interactive CLI prompts. Install dependencies with `pip install .`."
)
TICKER_INPUT_EXAMPLES = "Examples: SPY, CNC.TO, 7203.T, 0700.HK"
ANALYST_ORDER = [
@ -17,10 +28,12 @@ ANALYST_ORDER = [
]
def get_ticker() -> str:
def get_ticker(prompt_text: str | None = None) -> str:
"""Prompt the user to enter a ticker symbol."""
_ensure_questionary()
prompt = prompt_text or f"Enter the exact ticker symbol to analyze ({TICKER_INPUT_EXAMPLES}):"
ticker = questionary.text(
f"Enter the exact ticker symbol to analyze ({TICKER_INPUT_EXAMPLES}):",
prompt,
validate=lambda x: len(x.strip()) > 0 or "Please enter a valid ticker symbol.",
style=questionary.Style(
[
@ -42,8 +55,9 @@ def normalize_ticker_symbol(ticker: str) -> str:
return ticker.strip().upper()
def get_analysis_date() -> str:
def get_analysis_date(prompt_text: str = "Enter the analysis date (YYYY-MM-DD):") -> str:
"""Prompt the user to enter a date in YYYY-MM-DD format."""
_ensure_questionary()
import re
from datetime import datetime
@ -57,7 +71,7 @@ def get_analysis_date() -> str:
return False
date = questionary.text(
"Enter the analysis date (YYYY-MM-DD):",
prompt_text,
validate=lambda x: validate_date(x.strip())
or "Please enter a valid date in YYYY-MM-DD format.",
style=questionary.Style(
@ -77,6 +91,7 @@ def get_analysis_date() -> str:
def select_analysts() -> List[AnalystType]:
"""Select analysts using an interactive checkbox."""
_ensure_questionary()
choices = questionary.checkbox(
"Select Your [Analysts Team]:",
choices=[
@ -103,6 +118,7 @@ def select_analysts() -> List[AnalystType]:
def select_research_depth() -> int:
"""Select research depth using an interactive selection."""
_ensure_questionary()
# Define research depth options with their corresponding values
DEPTH_OPTIONS = [
@ -135,6 +151,7 @@ def select_research_depth() -> int:
def select_shallow_thinking_agent(provider) -> str:
"""Select shallow thinking llm engine using an interactive selection."""
_ensure_questionary()
# Define shallow thinking llm engine options with their corresponding model names
# Ordering: medium → light → heavy (balanced first for quick tasks)
@ -200,6 +217,7 @@ def select_shallow_thinking_agent(provider) -> str:
def select_deep_thinking_agent(provider) -> str:
"""Select deep thinking llm engine using an interactive selection."""
_ensure_questionary()
# Define deep thinking llm engine options with their corresponding model names
# Ordering: heavy → medium → light (most capable first for deep tasks)
@ -263,8 +281,9 @@ def select_deep_thinking_agent(provider) -> str:
return choice
def select_llm_provider() -> tuple[str, str]:
"""Select the OpenAI api url using interactive selection."""
# Define OpenAI api options with their corresponding endpoints
"""Select the LLM provider and API endpoint using interactive selection."""
_ensure_questionary()
# Define provider API options with their corresponding endpoints
BASE_URLS = [
("OpenAI", "https://api.openai.com/v1"),
("Google", "https://generativelanguage.googleapis.com/v1"),
@ -291,7 +310,7 @@ def select_llm_provider() -> tuple[str, str]:
).ask()
if choice is None:
console.print("\n[red]no OpenAI backend selected. Exiting...[/red]")
console.print("\n[red]No LLM provider selected. Exiting...[/red]")
exit(1)
display_name, url = choice
@ -302,6 +321,7 @@ def select_llm_provider() -> tuple[str, str]:
def ask_openai_reasoning_effort() -> str:
"""Ask for OpenAI reasoning effort level."""
_ensure_questionary()
choices = [
questionary.Choice("Medium (Default)", "medium"),
questionary.Choice("High (More thorough)", "high"),
@ -323,6 +343,7 @@ def ask_anthropic_effort() -> str | None:
Controls token usage and response thoroughness on Claude 4.5+ and 4.6 models.
"""
_ensure_questionary()
return questionary.select(
"Select Effort Level:",
choices=[
@ -344,6 +365,7 @@ def ask_gemini_thinking_config() -> str | None:
Returns thinking_level: "high" or "minimal".
Client maps to appropriate API param based on model series.
"""
_ensure_questionary()
return questionary.select(
"Select Thinking Mode:",
choices=[

View File

@ -0,0 +1,29 @@
import unittest
from cli.message_buffer import MessageBuffer
class MessageBufferTests(unittest.TestCase):
def setUp(self):
self.buffer = MessageBuffer()
self.buffer.init_for_analysis(["market", "news"])
def test_current_report_tracks_most_recent_updated_section(self):
self.buffer.update_report_section("market_report", "Market content")
self.assertIn("Market Analysis", self.buffer.current_report)
self.buffer.update_report_section("news_report", "News content")
self.assertIn("News Analysis", self.buffer.current_report)
self.assertNotIn("Market Analysis", self.buffer.current_report)
def test_init_resets_last_updated_section(self):
self.buffer.update_report_section("market_report", "Market content")
self.assertEqual(self.buffer._last_updated_section, "market_report")
self.buffer.init_for_analysis(["fundamentals"])
self.assertIsNone(self.buffer._last_updated_section)
self.assertIsNone(self.buffer.current_report)
if __name__ == "__main__":
unittest.main()

View File

@ -1,7 +1,6 @@
import unittest
from cli.utils import normalize_ticker_symbol
from tradingagents.agents.utils.agent_utils import build_instrument_context
class TickerSymbolHandlingTests(unittest.TestCase):
@ -9,6 +8,11 @@ class TickerSymbolHandlingTests(unittest.TestCase):
self.assertEqual(normalize_ticker_symbol(" cnc.to "), "CNC.TO")
def test_build_instrument_context_mentions_exact_symbol(self):
try:
from tradingagents.agents.utils.agent_utils import build_instrument_context
except ModuleNotFoundError as exc:
self.skipTest(f"optional dependency missing: {exc}")
context = build_instrument_context("7203.T")
self.assertIn("7203.T", context)
self.assertIn("exchange suffix", context)