From 59f17e6ecd43b15de56a58b7ddc7b786a7711ce3 Mon Sep 17 00:00:00 2001 From: basepoint Date: Sun, 22 Mar 2026 22:59:47 +0000 Subject: [PATCH] Improve CLI report tracking, modularity, and test resilience --- cli/main.py | 228 ++------------------------- cli/message_buffer.py | 192 ++++++++++++++++++++++ cli/models.py | 2 - cli/utils.py | 38 ++++- tests/test_message_buffer.py | 29 ++++ tests/test_ticker_symbol_handling.py | 6 +- 6 files changed, 266 insertions(+), 229 deletions(-) create mode 100644 cli/message_buffer.py create mode 100644 tests/test_message_buffer.py diff --git a/cli/main.py b/cli/main.py index f26ae4c5..439916df 100644 --- a/cli/main.py +++ b/cli/main.py @@ -16,7 +16,6 @@ from rich.markdown import Markdown from rich.layout import Layout from rich.text import Text from rich.table import Table -from collections import deque import time from rich.tree import Tree from rich import box @@ -29,6 +28,7 @@ from cli.models import AnalystType from cli.utils import * from cli.announcements import fetch_announcements, display_announcements from cli.stats_handler import StatsCallbackHandler +from cli.message_buffer import MessageBuffer console = Console() @@ -39,193 +39,6 @@ app = typer.Typer( ) -# Create a deque to store recent messages with a maximum length -class MessageBuffer: - # Fixed teams that always run (not user-selectable) - FIXED_AGENTS = { - "Research Team": ["Bull Researcher", "Bear Researcher", "Research Manager"], - "Trading Team": ["Trader"], - "Risk Management": ["Aggressive Analyst", "Neutral Analyst", "Conservative Analyst"], - "Portfolio Management": ["Portfolio Manager"], - } - - # Analyst name mapping - ANALYST_MAPPING = { - "market": "Market Analyst", - "social": "Social Analyst", - "news": "News Analyst", - "fundamentals": "Fundamentals Analyst", - } - - # Report section mapping: section -> (analyst_key for filtering, finalizing_agent) - # analyst_key: which analyst selection controls this section (None = always included) - # finalizing_agent: which agent must be "completed" for this report to count as done - REPORT_SECTIONS = { - "market_report": ("market", "Market Analyst"), - "sentiment_report": ("social", "Social Analyst"), - "news_report": ("news", "News Analyst"), - "fundamentals_report": ("fundamentals", "Fundamentals Analyst"), - "investment_plan": (None, "Research Manager"), - "trader_investment_plan": (None, "Trader"), - "final_trade_decision": (None, "Portfolio Manager"), - } - - def __init__(self, max_length=100): - self.messages = deque(maxlen=max_length) - self.tool_calls = deque(maxlen=max_length) - self.current_report = None - self.final_report = None # Store the complete final report - self.agent_status = {} - self.current_agent = None - self.report_sections = {} - self.selected_analysts = [] - self._last_message_id = None - - def init_for_analysis(self, selected_analysts): - """Initialize agent status and report sections based on selected analysts. - - Args: - selected_analysts: List of analyst type strings (e.g., ["market", "news"]) - """ - self.selected_analysts = [a.lower() for a in selected_analysts] - - # Build agent_status dynamically - self.agent_status = {} - - # Add selected analysts - for analyst_key in self.selected_analysts: - if analyst_key in self.ANALYST_MAPPING: - self.agent_status[self.ANALYST_MAPPING[analyst_key]] = "pending" - - # Add fixed teams - for team_agents in self.FIXED_AGENTS.values(): - for agent in team_agents: - self.agent_status[agent] = "pending" - - # Build report_sections dynamically - self.report_sections = {} - for section, (analyst_key, _) in self.REPORT_SECTIONS.items(): - if analyst_key is None or analyst_key in self.selected_analysts: - self.report_sections[section] = None - - # Reset other state - self.current_report = None - self.final_report = None - self.current_agent = None - self.messages.clear() - self.tool_calls.clear() - self._last_message_id = None - - def get_completed_reports_count(self): - """Count reports that are finalized (their finalizing agent is completed). - - A report is considered complete when: - 1. The report section has content (not None), AND - 2. The agent responsible for finalizing that report has status "completed" - - This prevents interim updates (like debate rounds) from counting as completed. - """ - count = 0 - for section in self.report_sections: - if section not in self.REPORT_SECTIONS: - continue - _, finalizing_agent = self.REPORT_SECTIONS[section] - # Report is complete if it has content AND its finalizing agent is done - has_content = self.report_sections.get(section) is not None - agent_done = self.agent_status.get(finalizing_agent) == "completed" - if has_content and agent_done: - count += 1 - return count - - def add_message(self, message_type, content): - timestamp = datetime.datetime.now().strftime("%H:%M:%S") - self.messages.append((timestamp, message_type, content)) - - def add_tool_call(self, tool_name, args): - timestamp = datetime.datetime.now().strftime("%H:%M:%S") - self.tool_calls.append((timestamp, tool_name, args)) - - def update_agent_status(self, agent, status): - if agent in self.agent_status: - self.agent_status[agent] = status - self.current_agent = agent - - def update_report_section(self, section_name, content): - if section_name in self.report_sections: - self.report_sections[section_name] = content - self._update_current_report() - - def _update_current_report(self): - # For the panel display, only show the most recently updated section - latest_section = None - latest_content = None - - # Find the most recently updated section - for section, content in self.report_sections.items(): - if content is not None: - latest_section = section - latest_content = content - - if latest_section and latest_content: - # Format the current section for display - section_titles = { - "market_report": "Market Analysis", - "sentiment_report": "Social Sentiment", - "news_report": "News Analysis", - "fundamentals_report": "Fundamentals Analysis", - "investment_plan": "Research Team Decision", - "trader_investment_plan": "Trading Team Plan", - "final_trade_decision": "Portfolio Management Decision", - } - self.current_report = ( - f"### {section_titles[latest_section]}\n{latest_content}" - ) - - # Update the final complete report - self._update_final_report() - - def _update_final_report(self): - report_parts = [] - - # Analyst Team Reports - use .get() to handle missing sections - analyst_sections = ["market_report", "sentiment_report", "news_report", "fundamentals_report"] - if any(self.report_sections.get(section) for section in analyst_sections): - report_parts.append("## Analyst Team Reports") - if self.report_sections.get("market_report"): - report_parts.append( - f"### Market Analysis\n{self.report_sections['market_report']}" - ) - if self.report_sections.get("sentiment_report"): - report_parts.append( - f"### Social Sentiment\n{self.report_sections['sentiment_report']}" - ) - if self.report_sections.get("news_report"): - report_parts.append( - f"### News Analysis\n{self.report_sections['news_report']}" - ) - if self.report_sections.get("fundamentals_report"): - report_parts.append( - f"### Fundamentals Analysis\n{self.report_sections['fundamentals_report']}" - ) - - # Research Team Reports - if self.report_sections.get("investment_plan"): - report_parts.append("## Research Team Decision") - report_parts.append(f"{self.report_sections['investment_plan']}") - - # Trading Team Reports - if self.report_sections.get("trader_investment_plan"): - report_parts.append("## Trading Team Plan") - report_parts.append(f"{self.report_sections['trader_investment_plan']}") - - # Portfolio Management Decision - if self.report_sections.get("final_trade_decision"): - report_parts.append("## Portfolio Management Decision") - report_parts.append(f"{self.report_sections['final_trade_decision']}") - - self.final_report = "\n\n".join(report_parts) if report_parts else None - - message_buffer = MessageBuffer() @@ -506,7 +319,7 @@ def get_user_selections(): "SPY", ) ) - selected_ticker = get_ticker() + selected_ticker = get_ticker(prompt_text="") # Step 2: Analysis date default_date = datetime.datetime.now().strftime("%Y-%m-%d") @@ -517,7 +330,7 @@ def get_user_selections(): default_date, ) ) - analysis_date = get_analysis_date() + analysis_date = get_analysis_date(prompt_text="") # Step 3: Select analysts console.print( @@ -538,10 +351,10 @@ def get_user_selections(): ) selected_research_depth = select_research_depth() - # Step 5: OpenAI backend + # Step 5: LLM provider backend console.print( create_question_box( - "Step 5: OpenAI backend", "Select which service to talk to" + "Step 5: LLM Provider", "Select which service to talk to" ) ) selected_llm_provider, backend_url = select_llm_provider() @@ -601,30 +414,6 @@ def get_user_selections(): } -def get_ticker(): - """Get ticker symbol from user input.""" - return typer.prompt("", default="SPY") - - -def get_analysis_date(): - """Get the analysis date from user input.""" - while True: - date_str = typer.prompt( - "", default=datetime.datetime.now().strftime("%Y-%m-%d") - ) - try: - # Validate date format and ensure it's not in the future - analysis_date = datetime.datetime.strptime(date_str, "%Y-%m-%d") - if analysis_date.date() > datetime.datetime.now().date(): - console.print("[red]Error: Analysis date cannot be in the future[/red]") - continue - return date_str - except ValueError: - console.print( - "[red]Error: Invalid date format. Please use YYYY-MM-DD[/red]" - ) - - def save_report_to_disk(final_state, ticker: str, save_path: Path): """Save complete analysis report to disk with organized subfolders.""" save_path.mkdir(parents=True, exist_ok=True) @@ -970,8 +759,11 @@ def run_analysis(): @wraps(func) def wrapper(*args, **kwargs): func(*args, **kwargs) - timestamp, tool_name, args = obj.tool_calls[-1] - args_str = ", ".join(f"{k}={v}" for k, v in args.items()) + timestamp, tool_name, tool_args = obj.tool_calls[-1] + if isinstance(tool_args, dict): + args_str = ", ".join(f"{k}={v}" for k, v in tool_args.items()) + else: + args_str = str(tool_args) with open(log_file, "a", encoding="utf-8") as f: f.write(f"{timestamp} [Tool Call] {tool_name}({args_str})\n") return wrapper diff --git a/cli/message_buffer.py b/cli/message_buffer.py new file mode 100644 index 00000000..6aa2ea3f --- /dev/null +++ b/cli/message_buffer.py @@ -0,0 +1,192 @@ +from collections import deque +import datetime + + +class MessageBuffer: + # Fixed teams that always run (not user-selectable) + FIXED_AGENTS = { + "Research Team": ["Bull Researcher", "Bear Researcher", "Research Manager"], + "Trading Team": ["Trader"], + "Risk Management": ["Aggressive Analyst", "Neutral Analyst", "Conservative Analyst"], + "Portfolio Management": ["Portfolio Manager"], + } + + # Analyst name mapping + ANALYST_MAPPING = { + "market": "Market Analyst", + "social": "Social Analyst", + "news": "News Analyst", + "fundamentals": "Fundamentals Analyst", + } + + # Report section mapping: section -> (analyst_key for filtering, finalizing_agent) + # analyst_key: which analyst selection controls this section (None = always included) + # finalizing_agent: which agent must be "completed" for this report to count as done + REPORT_SECTIONS = { + "market_report": ("market", "Market Analyst"), + "sentiment_report": ("social", "Social Analyst"), + "news_report": ("news", "News Analyst"), + "fundamentals_report": ("fundamentals", "Fundamentals Analyst"), + "investment_plan": (None, "Research Manager"), + "trader_investment_plan": (None, "Trader"), + "final_trade_decision": (None, "Portfolio Manager"), + } + + def __init__(self, max_length=100): + self.messages = deque(maxlen=max_length) + self.tool_calls = deque(maxlen=max_length) + self.current_report = None + self.final_report = None # Store the complete final report + self.agent_status = {} + self.current_agent = None + self.report_sections = {} + self.selected_analysts = [] + self._last_message_id = None + self._last_updated_section = None + + def init_for_analysis(self, selected_analysts): + """Initialize agent status and report sections based on selected analysts. + + Args: + selected_analysts: List of analyst type strings (e.g., ["market", "news"]) + """ + self.selected_analysts = [a.lower() for a in selected_analysts] + + # Build agent_status dynamically + self.agent_status = {} + + # Add selected analysts + for analyst_key in self.selected_analysts: + if analyst_key in self.ANALYST_MAPPING: + self.agent_status[self.ANALYST_MAPPING[analyst_key]] = "pending" + + # Add fixed teams + for team_agents in self.FIXED_AGENTS.values(): + for agent in team_agents: + self.agent_status[agent] = "pending" + + # Build report_sections dynamically + self.report_sections = {} + for section, (analyst_key, _) in self.REPORT_SECTIONS.items(): + if analyst_key is None or analyst_key in self.selected_analysts: + self.report_sections[section] = None + + # Reset other state + self.current_report = None + self.final_report = None + self.current_agent = None + self.messages.clear() + self.tool_calls.clear() + self._last_message_id = None + self._last_updated_section = None + + def get_completed_reports_count(self): + """Count reports that are finalized (their finalizing agent is completed). + + A report is considered complete when: + 1. The report section has content (not None), AND + 2. The agent responsible for finalizing that report has status "completed" + + This prevents interim updates (like debate rounds) from counting as completed. + """ + count = 0 + for section in self.report_sections: + if section not in self.REPORT_SECTIONS: + continue + _, finalizing_agent = self.REPORT_SECTIONS[section] + # Report is complete if it has content AND its finalizing agent is done + has_content = self.report_sections.get(section) is not None + agent_done = self.agent_status.get(finalizing_agent) == "completed" + if has_content and agent_done: + count += 1 + return count + + def add_message(self, message_type, content): + timestamp = datetime.datetime.now().strftime("%H:%M:%S") + self.messages.append((timestamp, message_type, content)) + + def add_tool_call(self, tool_name, args): + timestamp = datetime.datetime.now().strftime("%H:%M:%S") + self.tool_calls.append((timestamp, tool_name, args)) + + def update_agent_status(self, agent, status): + if agent in self.agent_status: + self.agent_status[agent] = status + self.current_agent = agent + + def update_report_section(self, section_name, content): + if section_name in self.report_sections: + self.report_sections[section_name] = content + self._last_updated_section = section_name + self._update_current_report() + + def _update_current_report(self): + # For the panel display, only show the most recently updated section + latest_section = self._last_updated_section + latest_content = self.report_sections.get(latest_section) if latest_section else None + + # Fallback if section tracking is unavailable + if latest_content is None: + for section, content in self.report_sections.items(): + if content is not None: + latest_section = section + latest_content = content + + if latest_section and latest_content: + # Format the current section for display + section_titles = { + "market_report": "Market Analysis", + "sentiment_report": "Social Sentiment", + "news_report": "News Analysis", + "fundamentals_report": "Fundamentals Analysis", + "investment_plan": "Research Team Decision", + "trader_investment_plan": "Trading Team Plan", + "final_trade_decision": "Portfolio Management Decision", + } + self.current_report = ( + f"### {section_titles[latest_section]}\n{latest_content}" + ) + + # Update the final complete report + self._update_final_report() + + def _update_final_report(self): + report_parts = [] + + # Analyst Team Reports - use .get() to handle missing sections + analyst_sections = ["market_report", "sentiment_report", "news_report", "fundamentals_report"] + if any(self.report_sections.get(section) for section in analyst_sections): + report_parts.append("## Analyst Team Reports") + if self.report_sections.get("market_report"): + report_parts.append( + f"### Market Analysis\n{self.report_sections['market_report']}" + ) + if self.report_sections.get("sentiment_report"): + report_parts.append( + f"### Social Sentiment\n{self.report_sections['sentiment_report']}" + ) + if self.report_sections.get("news_report"): + report_parts.append( + f"### News Analysis\n{self.report_sections['news_report']}" + ) + if self.report_sections.get("fundamentals_report"): + report_parts.append( + f"### Fundamentals Analysis\n{self.report_sections['fundamentals_report']}" + ) + + # Research Team Reports + if self.report_sections.get("investment_plan"): + report_parts.append("## Research Team Decision") + report_parts.append(f"{self.report_sections['investment_plan']}") + + # Trading Team Reports + if self.report_sections.get("trader_investment_plan"): + report_parts.append("## Trading Team Plan") + report_parts.append(f"{self.report_sections['trader_investment_plan']}") + + # Portfolio Management Decision + if self.report_sections.get("final_trade_decision"): + report_parts.append("## Portfolio Management Decision") + report_parts.append(f"{self.report_sections['final_trade_decision']}") + + self.final_report = "\n\n".join(report_parts) if report_parts else None diff --git a/cli/models.py b/cli/models.py index f68c3da1..83922d7a 100644 --- a/cli/models.py +++ b/cli/models.py @@ -1,6 +1,4 @@ from enum import Enum -from typing import List, Optional, Dict -from pydantic import BaseModel class AnalystType(str, Enum): diff --git a/cli/utils.py b/cli/utils.py index 18abc3a7..d7b73376 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -1,12 +1,23 @@ -import questionary from typing import List, Optional, Tuple, Dict +try: + import questionary +except ImportError: # pragma: no cover - optional during non-interactive testing + questionary = None + from rich.console import Console from cli.models import AnalystType console = Console() + +def _ensure_questionary(): + if questionary is None: + raise RuntimeError( + "questionary is required for interactive CLI prompts. Install dependencies with `pip install .`." + ) + TICKER_INPUT_EXAMPLES = "Examples: SPY, CNC.TO, 7203.T, 0700.HK" ANALYST_ORDER = [ @@ -17,10 +28,12 @@ ANALYST_ORDER = [ ] -def get_ticker() -> str: +def get_ticker(prompt_text: str | None = None) -> str: """Prompt the user to enter a ticker symbol.""" + _ensure_questionary() + prompt = prompt_text or f"Enter the exact ticker symbol to analyze ({TICKER_INPUT_EXAMPLES}):" ticker = questionary.text( - f"Enter the exact ticker symbol to analyze ({TICKER_INPUT_EXAMPLES}):", + prompt, validate=lambda x: len(x.strip()) > 0 or "Please enter a valid ticker symbol.", style=questionary.Style( [ @@ -42,8 +55,9 @@ def normalize_ticker_symbol(ticker: str) -> str: return ticker.strip().upper() -def get_analysis_date() -> str: +def get_analysis_date(prompt_text: str = "Enter the analysis date (YYYY-MM-DD):") -> str: """Prompt the user to enter a date in YYYY-MM-DD format.""" + _ensure_questionary() import re from datetime import datetime @@ -57,7 +71,7 @@ def get_analysis_date() -> str: return False date = questionary.text( - "Enter the analysis date (YYYY-MM-DD):", + prompt_text, validate=lambda x: validate_date(x.strip()) or "Please enter a valid date in YYYY-MM-DD format.", style=questionary.Style( @@ -77,6 +91,7 @@ def get_analysis_date() -> str: def select_analysts() -> List[AnalystType]: """Select analysts using an interactive checkbox.""" + _ensure_questionary() choices = questionary.checkbox( "Select Your [Analysts Team]:", choices=[ @@ -103,6 +118,7 @@ def select_analysts() -> List[AnalystType]: def select_research_depth() -> int: """Select research depth using an interactive selection.""" + _ensure_questionary() # Define research depth options with their corresponding values DEPTH_OPTIONS = [ @@ -135,6 +151,7 @@ def select_research_depth() -> int: def select_shallow_thinking_agent(provider) -> str: """Select shallow thinking llm engine using an interactive selection.""" + _ensure_questionary() # Define shallow thinking llm engine options with their corresponding model names # Ordering: medium → light → heavy (balanced first for quick tasks) @@ -200,6 +217,7 @@ def select_shallow_thinking_agent(provider) -> str: def select_deep_thinking_agent(provider) -> str: """Select deep thinking llm engine using an interactive selection.""" + _ensure_questionary() # Define deep thinking llm engine options with their corresponding model names # Ordering: heavy → medium → light (most capable first for deep tasks) @@ -263,8 +281,9 @@ def select_deep_thinking_agent(provider) -> str: return choice def select_llm_provider() -> tuple[str, str]: - """Select the OpenAI api url using interactive selection.""" - # Define OpenAI api options with their corresponding endpoints + """Select the LLM provider and API endpoint using interactive selection.""" + _ensure_questionary() + # Define provider API options with their corresponding endpoints BASE_URLS = [ ("OpenAI", "https://api.openai.com/v1"), ("Google", "https://generativelanguage.googleapis.com/v1"), @@ -291,7 +310,7 @@ def select_llm_provider() -> tuple[str, str]: ).ask() if choice is None: - console.print("\n[red]no OpenAI backend selected. Exiting...[/red]") + console.print("\n[red]No LLM provider selected. Exiting...[/red]") exit(1) display_name, url = choice @@ -302,6 +321,7 @@ def select_llm_provider() -> tuple[str, str]: def ask_openai_reasoning_effort() -> str: """Ask for OpenAI reasoning effort level.""" + _ensure_questionary() choices = [ questionary.Choice("Medium (Default)", "medium"), questionary.Choice("High (More thorough)", "high"), @@ -323,6 +343,7 @@ def ask_anthropic_effort() -> str | None: Controls token usage and response thoroughness on Claude 4.5+ and 4.6 models. """ + _ensure_questionary() return questionary.select( "Select Effort Level:", choices=[ @@ -344,6 +365,7 @@ def ask_gemini_thinking_config() -> str | None: Returns thinking_level: "high" or "minimal". Client maps to appropriate API param based on model series. """ + _ensure_questionary() return questionary.select( "Select Thinking Mode:", choices=[ diff --git a/tests/test_message_buffer.py b/tests/test_message_buffer.py new file mode 100644 index 00000000..8bddaa0d --- /dev/null +++ b/tests/test_message_buffer.py @@ -0,0 +1,29 @@ +import unittest + +from cli.message_buffer import MessageBuffer + + +class MessageBufferTests(unittest.TestCase): + def setUp(self): + self.buffer = MessageBuffer() + self.buffer.init_for_analysis(["market", "news"]) + + def test_current_report_tracks_most_recent_updated_section(self): + self.buffer.update_report_section("market_report", "Market content") + self.assertIn("Market Analysis", self.buffer.current_report) + + self.buffer.update_report_section("news_report", "News content") + self.assertIn("News Analysis", self.buffer.current_report) + self.assertNotIn("Market Analysis", self.buffer.current_report) + + def test_init_resets_last_updated_section(self): + self.buffer.update_report_section("market_report", "Market content") + self.assertEqual(self.buffer._last_updated_section, "market_report") + + self.buffer.init_for_analysis(["fundamentals"]) + self.assertIsNone(self.buffer._last_updated_section) + self.assertIsNone(self.buffer.current_report) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_ticker_symbol_handling.py b/tests/test_ticker_symbol_handling.py index 858d26cd..64d672de 100644 --- a/tests/test_ticker_symbol_handling.py +++ b/tests/test_ticker_symbol_handling.py @@ -1,7 +1,6 @@ import unittest from cli.utils import normalize_ticker_symbol -from tradingagents.agents.utils.agent_utils import build_instrument_context class TickerSymbolHandlingTests(unittest.TestCase): @@ -9,6 +8,11 @@ class TickerSymbolHandlingTests(unittest.TestCase): self.assertEqual(normalize_ticker_symbol(" cnc.to "), "CNC.TO") def test_build_instrument_context_mentions_exact_symbol(self): + try: + from tradingagents.agents.utils.agent_utils import build_instrument_context + except ModuleNotFoundError as exc: + self.skipTest(f"optional dependency missing: {exc}") + context = build_instrument_context("7203.T") self.assertIn("7203.T", context) self.assertIn("exchange suffix", context)