Improve CLI report tracking, modularity, and test resilience
This commit is contained in:
parent
c3ba3bf428
commit
59f17e6ecd
228
cli/main.py
228
cli/main.py
|
|
@ -16,7 +16,6 @@ from rich.markdown import Markdown
|
|||
from rich.layout import Layout
|
||||
from rich.text import Text
|
||||
from rich.table import Table
|
||||
from collections import deque
|
||||
import time
|
||||
from rich.tree import Tree
|
||||
from rich import box
|
||||
|
|
@ -29,6 +28,7 @@ from cli.models import AnalystType
|
|||
from cli.utils import *
|
||||
from cli.announcements import fetch_announcements, display_announcements
|
||||
from cli.stats_handler import StatsCallbackHandler
|
||||
from cli.message_buffer import MessageBuffer
|
||||
|
||||
console = Console()
|
||||
|
||||
|
|
@ -39,193 +39,6 @@ app = typer.Typer(
|
|||
)
|
||||
|
||||
|
||||
# Create a deque to store recent messages with a maximum length
|
||||
class MessageBuffer:
|
||||
# Fixed teams that always run (not user-selectable)
|
||||
FIXED_AGENTS = {
|
||||
"Research Team": ["Bull Researcher", "Bear Researcher", "Research Manager"],
|
||||
"Trading Team": ["Trader"],
|
||||
"Risk Management": ["Aggressive Analyst", "Neutral Analyst", "Conservative Analyst"],
|
||||
"Portfolio Management": ["Portfolio Manager"],
|
||||
}
|
||||
|
||||
# Analyst name mapping
|
||||
ANALYST_MAPPING = {
|
||||
"market": "Market Analyst",
|
||||
"social": "Social Analyst",
|
||||
"news": "News Analyst",
|
||||
"fundamentals": "Fundamentals Analyst",
|
||||
}
|
||||
|
||||
# Report section mapping: section -> (analyst_key for filtering, finalizing_agent)
|
||||
# analyst_key: which analyst selection controls this section (None = always included)
|
||||
# finalizing_agent: which agent must be "completed" for this report to count as done
|
||||
REPORT_SECTIONS = {
|
||||
"market_report": ("market", "Market Analyst"),
|
||||
"sentiment_report": ("social", "Social Analyst"),
|
||||
"news_report": ("news", "News Analyst"),
|
||||
"fundamentals_report": ("fundamentals", "Fundamentals Analyst"),
|
||||
"investment_plan": (None, "Research Manager"),
|
||||
"trader_investment_plan": (None, "Trader"),
|
||||
"final_trade_decision": (None, "Portfolio Manager"),
|
||||
}
|
||||
|
||||
def __init__(self, max_length=100):
|
||||
self.messages = deque(maxlen=max_length)
|
||||
self.tool_calls = deque(maxlen=max_length)
|
||||
self.current_report = None
|
||||
self.final_report = None # Store the complete final report
|
||||
self.agent_status = {}
|
||||
self.current_agent = None
|
||||
self.report_sections = {}
|
||||
self.selected_analysts = []
|
||||
self._last_message_id = None
|
||||
|
||||
def init_for_analysis(self, selected_analysts):
|
||||
"""Initialize agent status and report sections based on selected analysts.
|
||||
|
||||
Args:
|
||||
selected_analysts: List of analyst type strings (e.g., ["market", "news"])
|
||||
"""
|
||||
self.selected_analysts = [a.lower() for a in selected_analysts]
|
||||
|
||||
# Build agent_status dynamically
|
||||
self.agent_status = {}
|
||||
|
||||
# Add selected analysts
|
||||
for analyst_key in self.selected_analysts:
|
||||
if analyst_key in self.ANALYST_MAPPING:
|
||||
self.agent_status[self.ANALYST_MAPPING[analyst_key]] = "pending"
|
||||
|
||||
# Add fixed teams
|
||||
for team_agents in self.FIXED_AGENTS.values():
|
||||
for agent in team_agents:
|
||||
self.agent_status[agent] = "pending"
|
||||
|
||||
# Build report_sections dynamically
|
||||
self.report_sections = {}
|
||||
for section, (analyst_key, _) in self.REPORT_SECTIONS.items():
|
||||
if analyst_key is None or analyst_key in self.selected_analysts:
|
||||
self.report_sections[section] = None
|
||||
|
||||
# Reset other state
|
||||
self.current_report = None
|
||||
self.final_report = None
|
||||
self.current_agent = None
|
||||
self.messages.clear()
|
||||
self.tool_calls.clear()
|
||||
self._last_message_id = None
|
||||
|
||||
def get_completed_reports_count(self):
|
||||
"""Count reports that are finalized (their finalizing agent is completed).
|
||||
|
||||
A report is considered complete when:
|
||||
1. The report section has content (not None), AND
|
||||
2. The agent responsible for finalizing that report has status "completed"
|
||||
|
||||
This prevents interim updates (like debate rounds) from counting as completed.
|
||||
"""
|
||||
count = 0
|
||||
for section in self.report_sections:
|
||||
if section not in self.REPORT_SECTIONS:
|
||||
continue
|
||||
_, finalizing_agent = self.REPORT_SECTIONS[section]
|
||||
# Report is complete if it has content AND its finalizing agent is done
|
||||
has_content = self.report_sections.get(section) is not None
|
||||
agent_done = self.agent_status.get(finalizing_agent) == "completed"
|
||||
if has_content and agent_done:
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def add_message(self, message_type, content):
|
||||
timestamp = datetime.datetime.now().strftime("%H:%M:%S")
|
||||
self.messages.append((timestamp, message_type, content))
|
||||
|
||||
def add_tool_call(self, tool_name, args):
|
||||
timestamp = datetime.datetime.now().strftime("%H:%M:%S")
|
||||
self.tool_calls.append((timestamp, tool_name, args))
|
||||
|
||||
def update_agent_status(self, agent, status):
|
||||
if agent in self.agent_status:
|
||||
self.agent_status[agent] = status
|
||||
self.current_agent = agent
|
||||
|
||||
def update_report_section(self, section_name, content):
|
||||
if section_name in self.report_sections:
|
||||
self.report_sections[section_name] = content
|
||||
self._update_current_report()
|
||||
|
||||
def _update_current_report(self):
|
||||
# For the panel display, only show the most recently updated section
|
||||
latest_section = None
|
||||
latest_content = None
|
||||
|
||||
# Find the most recently updated section
|
||||
for section, content in self.report_sections.items():
|
||||
if content is not None:
|
||||
latest_section = section
|
||||
latest_content = content
|
||||
|
||||
if latest_section and latest_content:
|
||||
# Format the current section for display
|
||||
section_titles = {
|
||||
"market_report": "Market Analysis",
|
||||
"sentiment_report": "Social Sentiment",
|
||||
"news_report": "News Analysis",
|
||||
"fundamentals_report": "Fundamentals Analysis",
|
||||
"investment_plan": "Research Team Decision",
|
||||
"trader_investment_plan": "Trading Team Plan",
|
||||
"final_trade_decision": "Portfolio Management Decision",
|
||||
}
|
||||
self.current_report = (
|
||||
f"### {section_titles[latest_section]}\n{latest_content}"
|
||||
)
|
||||
|
||||
# Update the final complete report
|
||||
self._update_final_report()
|
||||
|
||||
def _update_final_report(self):
|
||||
report_parts = []
|
||||
|
||||
# Analyst Team Reports - use .get() to handle missing sections
|
||||
analyst_sections = ["market_report", "sentiment_report", "news_report", "fundamentals_report"]
|
||||
if any(self.report_sections.get(section) for section in analyst_sections):
|
||||
report_parts.append("## Analyst Team Reports")
|
||||
if self.report_sections.get("market_report"):
|
||||
report_parts.append(
|
||||
f"### Market Analysis\n{self.report_sections['market_report']}"
|
||||
)
|
||||
if self.report_sections.get("sentiment_report"):
|
||||
report_parts.append(
|
||||
f"### Social Sentiment\n{self.report_sections['sentiment_report']}"
|
||||
)
|
||||
if self.report_sections.get("news_report"):
|
||||
report_parts.append(
|
||||
f"### News Analysis\n{self.report_sections['news_report']}"
|
||||
)
|
||||
if self.report_sections.get("fundamentals_report"):
|
||||
report_parts.append(
|
||||
f"### Fundamentals Analysis\n{self.report_sections['fundamentals_report']}"
|
||||
)
|
||||
|
||||
# Research Team Reports
|
||||
if self.report_sections.get("investment_plan"):
|
||||
report_parts.append("## Research Team Decision")
|
||||
report_parts.append(f"{self.report_sections['investment_plan']}")
|
||||
|
||||
# Trading Team Reports
|
||||
if self.report_sections.get("trader_investment_plan"):
|
||||
report_parts.append("## Trading Team Plan")
|
||||
report_parts.append(f"{self.report_sections['trader_investment_plan']}")
|
||||
|
||||
# Portfolio Management Decision
|
||||
if self.report_sections.get("final_trade_decision"):
|
||||
report_parts.append("## Portfolio Management Decision")
|
||||
report_parts.append(f"{self.report_sections['final_trade_decision']}")
|
||||
|
||||
self.final_report = "\n\n".join(report_parts) if report_parts else None
|
||||
|
||||
|
||||
message_buffer = MessageBuffer()
|
||||
|
||||
|
||||
|
|
@ -506,7 +319,7 @@ def get_user_selections():
|
|||
"SPY",
|
||||
)
|
||||
)
|
||||
selected_ticker = get_ticker()
|
||||
selected_ticker = get_ticker(prompt_text="")
|
||||
|
||||
# Step 2: Analysis date
|
||||
default_date = datetime.datetime.now().strftime("%Y-%m-%d")
|
||||
|
|
@ -517,7 +330,7 @@ def get_user_selections():
|
|||
default_date,
|
||||
)
|
||||
)
|
||||
analysis_date = get_analysis_date()
|
||||
analysis_date = get_analysis_date(prompt_text="")
|
||||
|
||||
# Step 3: Select analysts
|
||||
console.print(
|
||||
|
|
@ -538,10 +351,10 @@ def get_user_selections():
|
|||
)
|
||||
selected_research_depth = select_research_depth()
|
||||
|
||||
# Step 5: OpenAI backend
|
||||
# Step 5: LLM provider backend
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 5: OpenAI backend", "Select which service to talk to"
|
||||
"Step 5: LLM Provider", "Select which service to talk to"
|
||||
)
|
||||
)
|
||||
selected_llm_provider, backend_url = select_llm_provider()
|
||||
|
|
@ -601,30 +414,6 @@ def get_user_selections():
|
|||
}
|
||||
|
||||
|
||||
def get_ticker():
|
||||
"""Get ticker symbol from user input."""
|
||||
return typer.prompt("", default="SPY")
|
||||
|
||||
|
||||
def get_analysis_date():
|
||||
"""Get the analysis date from user input."""
|
||||
while True:
|
||||
date_str = typer.prompt(
|
||||
"", default=datetime.datetime.now().strftime("%Y-%m-%d")
|
||||
)
|
||||
try:
|
||||
# Validate date format and ensure it's not in the future
|
||||
analysis_date = datetime.datetime.strptime(date_str, "%Y-%m-%d")
|
||||
if analysis_date.date() > datetime.datetime.now().date():
|
||||
console.print("[red]Error: Analysis date cannot be in the future[/red]")
|
||||
continue
|
||||
return date_str
|
||||
except ValueError:
|
||||
console.print(
|
||||
"[red]Error: Invalid date format. Please use YYYY-MM-DD[/red]"
|
||||
)
|
||||
|
||||
|
||||
def save_report_to_disk(final_state, ticker: str, save_path: Path):
|
||||
"""Save complete analysis report to disk with organized subfolders."""
|
||||
save_path.mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -970,8 +759,11 @@ def run_analysis():
|
|||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
func(*args, **kwargs)
|
||||
timestamp, tool_name, args = obj.tool_calls[-1]
|
||||
args_str = ", ".join(f"{k}={v}" for k, v in args.items())
|
||||
timestamp, tool_name, tool_args = obj.tool_calls[-1]
|
||||
if isinstance(tool_args, dict):
|
||||
args_str = ", ".join(f"{k}={v}" for k, v in tool_args.items())
|
||||
else:
|
||||
args_str = str(tool_args)
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"{timestamp} [Tool Call] {tool_name}({args_str})\n")
|
||||
return wrapper
|
||||
|
|
|
|||
|
|
@ -0,0 +1,192 @@
|
|||
from collections import deque
|
||||
import datetime
|
||||
|
||||
|
||||
class MessageBuffer:
|
||||
# Fixed teams that always run (not user-selectable)
|
||||
FIXED_AGENTS = {
|
||||
"Research Team": ["Bull Researcher", "Bear Researcher", "Research Manager"],
|
||||
"Trading Team": ["Trader"],
|
||||
"Risk Management": ["Aggressive Analyst", "Neutral Analyst", "Conservative Analyst"],
|
||||
"Portfolio Management": ["Portfolio Manager"],
|
||||
}
|
||||
|
||||
# Analyst name mapping
|
||||
ANALYST_MAPPING = {
|
||||
"market": "Market Analyst",
|
||||
"social": "Social Analyst",
|
||||
"news": "News Analyst",
|
||||
"fundamentals": "Fundamentals Analyst",
|
||||
}
|
||||
|
||||
# Report section mapping: section -> (analyst_key for filtering, finalizing_agent)
|
||||
# analyst_key: which analyst selection controls this section (None = always included)
|
||||
# finalizing_agent: which agent must be "completed" for this report to count as done
|
||||
REPORT_SECTIONS = {
|
||||
"market_report": ("market", "Market Analyst"),
|
||||
"sentiment_report": ("social", "Social Analyst"),
|
||||
"news_report": ("news", "News Analyst"),
|
||||
"fundamentals_report": ("fundamentals", "Fundamentals Analyst"),
|
||||
"investment_plan": (None, "Research Manager"),
|
||||
"trader_investment_plan": (None, "Trader"),
|
||||
"final_trade_decision": (None, "Portfolio Manager"),
|
||||
}
|
||||
|
||||
def __init__(self, max_length=100):
|
||||
self.messages = deque(maxlen=max_length)
|
||||
self.tool_calls = deque(maxlen=max_length)
|
||||
self.current_report = None
|
||||
self.final_report = None # Store the complete final report
|
||||
self.agent_status = {}
|
||||
self.current_agent = None
|
||||
self.report_sections = {}
|
||||
self.selected_analysts = []
|
||||
self._last_message_id = None
|
||||
self._last_updated_section = None
|
||||
|
||||
def init_for_analysis(self, selected_analysts):
|
||||
"""Initialize agent status and report sections based on selected analysts.
|
||||
|
||||
Args:
|
||||
selected_analysts: List of analyst type strings (e.g., ["market", "news"])
|
||||
"""
|
||||
self.selected_analysts = [a.lower() for a in selected_analysts]
|
||||
|
||||
# Build agent_status dynamically
|
||||
self.agent_status = {}
|
||||
|
||||
# Add selected analysts
|
||||
for analyst_key in self.selected_analysts:
|
||||
if analyst_key in self.ANALYST_MAPPING:
|
||||
self.agent_status[self.ANALYST_MAPPING[analyst_key]] = "pending"
|
||||
|
||||
# Add fixed teams
|
||||
for team_agents in self.FIXED_AGENTS.values():
|
||||
for agent in team_agents:
|
||||
self.agent_status[agent] = "pending"
|
||||
|
||||
# Build report_sections dynamically
|
||||
self.report_sections = {}
|
||||
for section, (analyst_key, _) in self.REPORT_SECTIONS.items():
|
||||
if analyst_key is None or analyst_key in self.selected_analysts:
|
||||
self.report_sections[section] = None
|
||||
|
||||
# Reset other state
|
||||
self.current_report = None
|
||||
self.final_report = None
|
||||
self.current_agent = None
|
||||
self.messages.clear()
|
||||
self.tool_calls.clear()
|
||||
self._last_message_id = None
|
||||
self._last_updated_section = None
|
||||
|
||||
def get_completed_reports_count(self):
|
||||
"""Count reports that are finalized (their finalizing agent is completed).
|
||||
|
||||
A report is considered complete when:
|
||||
1. The report section has content (not None), AND
|
||||
2. The agent responsible for finalizing that report has status "completed"
|
||||
|
||||
This prevents interim updates (like debate rounds) from counting as completed.
|
||||
"""
|
||||
count = 0
|
||||
for section in self.report_sections:
|
||||
if section not in self.REPORT_SECTIONS:
|
||||
continue
|
||||
_, finalizing_agent = self.REPORT_SECTIONS[section]
|
||||
# Report is complete if it has content AND its finalizing agent is done
|
||||
has_content = self.report_sections.get(section) is not None
|
||||
agent_done = self.agent_status.get(finalizing_agent) == "completed"
|
||||
if has_content and agent_done:
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def add_message(self, message_type, content):
|
||||
timestamp = datetime.datetime.now().strftime("%H:%M:%S")
|
||||
self.messages.append((timestamp, message_type, content))
|
||||
|
||||
def add_tool_call(self, tool_name, args):
|
||||
timestamp = datetime.datetime.now().strftime("%H:%M:%S")
|
||||
self.tool_calls.append((timestamp, tool_name, args))
|
||||
|
||||
def update_agent_status(self, agent, status):
|
||||
if agent in self.agent_status:
|
||||
self.agent_status[agent] = status
|
||||
self.current_agent = agent
|
||||
|
||||
def update_report_section(self, section_name, content):
|
||||
if section_name in self.report_sections:
|
||||
self.report_sections[section_name] = content
|
||||
self._last_updated_section = section_name
|
||||
self._update_current_report()
|
||||
|
||||
def _update_current_report(self):
|
||||
# For the panel display, only show the most recently updated section
|
||||
latest_section = self._last_updated_section
|
||||
latest_content = self.report_sections.get(latest_section) if latest_section else None
|
||||
|
||||
# Fallback if section tracking is unavailable
|
||||
if latest_content is None:
|
||||
for section, content in self.report_sections.items():
|
||||
if content is not None:
|
||||
latest_section = section
|
||||
latest_content = content
|
||||
|
||||
if latest_section and latest_content:
|
||||
# Format the current section for display
|
||||
section_titles = {
|
||||
"market_report": "Market Analysis",
|
||||
"sentiment_report": "Social Sentiment",
|
||||
"news_report": "News Analysis",
|
||||
"fundamentals_report": "Fundamentals Analysis",
|
||||
"investment_plan": "Research Team Decision",
|
||||
"trader_investment_plan": "Trading Team Plan",
|
||||
"final_trade_decision": "Portfolio Management Decision",
|
||||
}
|
||||
self.current_report = (
|
||||
f"### {section_titles[latest_section]}\n{latest_content}"
|
||||
)
|
||||
|
||||
# Update the final complete report
|
||||
self._update_final_report()
|
||||
|
||||
def _update_final_report(self):
|
||||
report_parts = []
|
||||
|
||||
# Analyst Team Reports - use .get() to handle missing sections
|
||||
analyst_sections = ["market_report", "sentiment_report", "news_report", "fundamentals_report"]
|
||||
if any(self.report_sections.get(section) for section in analyst_sections):
|
||||
report_parts.append("## Analyst Team Reports")
|
||||
if self.report_sections.get("market_report"):
|
||||
report_parts.append(
|
||||
f"### Market Analysis\n{self.report_sections['market_report']}"
|
||||
)
|
||||
if self.report_sections.get("sentiment_report"):
|
||||
report_parts.append(
|
||||
f"### Social Sentiment\n{self.report_sections['sentiment_report']}"
|
||||
)
|
||||
if self.report_sections.get("news_report"):
|
||||
report_parts.append(
|
||||
f"### News Analysis\n{self.report_sections['news_report']}"
|
||||
)
|
||||
if self.report_sections.get("fundamentals_report"):
|
||||
report_parts.append(
|
||||
f"### Fundamentals Analysis\n{self.report_sections['fundamentals_report']}"
|
||||
)
|
||||
|
||||
# Research Team Reports
|
||||
if self.report_sections.get("investment_plan"):
|
||||
report_parts.append("## Research Team Decision")
|
||||
report_parts.append(f"{self.report_sections['investment_plan']}")
|
||||
|
||||
# Trading Team Reports
|
||||
if self.report_sections.get("trader_investment_plan"):
|
||||
report_parts.append("## Trading Team Plan")
|
||||
report_parts.append(f"{self.report_sections['trader_investment_plan']}")
|
||||
|
||||
# Portfolio Management Decision
|
||||
if self.report_sections.get("final_trade_decision"):
|
||||
report_parts.append("## Portfolio Management Decision")
|
||||
report_parts.append(f"{self.report_sections['final_trade_decision']}")
|
||||
|
||||
self.final_report = "\n\n".join(report_parts) if report_parts else None
|
||||
|
|
@ -1,6 +1,4 @@
|
|||
from enum import Enum
|
||||
from typing import List, Optional, Dict
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AnalystType(str, Enum):
|
||||
|
|
|
|||
38
cli/utils.py
38
cli/utils.py
|
|
@ -1,12 +1,23 @@
|
|||
import questionary
|
||||
from typing import List, Optional, Tuple, Dict
|
||||
|
||||
try:
|
||||
import questionary
|
||||
except ImportError: # pragma: no cover - optional during non-interactive testing
|
||||
questionary = None
|
||||
|
||||
from rich.console import Console
|
||||
|
||||
from cli.models import AnalystType
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
def _ensure_questionary():
|
||||
if questionary is None:
|
||||
raise RuntimeError(
|
||||
"questionary is required for interactive CLI prompts. Install dependencies with `pip install .`."
|
||||
)
|
||||
|
||||
TICKER_INPUT_EXAMPLES = "Examples: SPY, CNC.TO, 7203.T, 0700.HK"
|
||||
|
||||
ANALYST_ORDER = [
|
||||
|
|
@ -17,10 +28,12 @@ ANALYST_ORDER = [
|
|||
]
|
||||
|
||||
|
||||
def get_ticker() -> str:
|
||||
def get_ticker(prompt_text: str | None = None) -> str:
|
||||
"""Prompt the user to enter a ticker symbol."""
|
||||
_ensure_questionary()
|
||||
prompt = prompt_text or f"Enter the exact ticker symbol to analyze ({TICKER_INPUT_EXAMPLES}):"
|
||||
ticker = questionary.text(
|
||||
f"Enter the exact ticker symbol to analyze ({TICKER_INPUT_EXAMPLES}):",
|
||||
prompt,
|
||||
validate=lambda x: len(x.strip()) > 0 or "Please enter a valid ticker symbol.",
|
||||
style=questionary.Style(
|
||||
[
|
||||
|
|
@ -42,8 +55,9 @@ def normalize_ticker_symbol(ticker: str) -> str:
|
|||
return ticker.strip().upper()
|
||||
|
||||
|
||||
def get_analysis_date() -> str:
|
||||
def get_analysis_date(prompt_text: str = "Enter the analysis date (YYYY-MM-DD):") -> str:
|
||||
"""Prompt the user to enter a date in YYYY-MM-DD format."""
|
||||
_ensure_questionary()
|
||||
import re
|
||||
from datetime import datetime
|
||||
|
||||
|
|
@ -57,7 +71,7 @@ def get_analysis_date() -> str:
|
|||
return False
|
||||
|
||||
date = questionary.text(
|
||||
"Enter the analysis date (YYYY-MM-DD):",
|
||||
prompt_text,
|
||||
validate=lambda x: validate_date(x.strip())
|
||||
or "Please enter a valid date in YYYY-MM-DD format.",
|
||||
style=questionary.Style(
|
||||
|
|
@ -77,6 +91,7 @@ def get_analysis_date() -> str:
|
|||
|
||||
def select_analysts() -> List[AnalystType]:
|
||||
"""Select analysts using an interactive checkbox."""
|
||||
_ensure_questionary()
|
||||
choices = questionary.checkbox(
|
||||
"Select Your [Analysts Team]:",
|
||||
choices=[
|
||||
|
|
@ -103,6 +118,7 @@ def select_analysts() -> List[AnalystType]:
|
|||
|
||||
def select_research_depth() -> int:
|
||||
"""Select research depth using an interactive selection."""
|
||||
_ensure_questionary()
|
||||
|
||||
# Define research depth options with their corresponding values
|
||||
DEPTH_OPTIONS = [
|
||||
|
|
@ -135,6 +151,7 @@ def select_research_depth() -> int:
|
|||
|
||||
def select_shallow_thinking_agent(provider) -> str:
|
||||
"""Select shallow thinking llm engine using an interactive selection."""
|
||||
_ensure_questionary()
|
||||
|
||||
# Define shallow thinking llm engine options with their corresponding model names
|
||||
# Ordering: medium → light → heavy (balanced first for quick tasks)
|
||||
|
|
@ -200,6 +217,7 @@ def select_shallow_thinking_agent(provider) -> str:
|
|||
|
||||
def select_deep_thinking_agent(provider) -> str:
|
||||
"""Select deep thinking llm engine using an interactive selection."""
|
||||
_ensure_questionary()
|
||||
|
||||
# Define deep thinking llm engine options with their corresponding model names
|
||||
# Ordering: heavy → medium → light (most capable first for deep tasks)
|
||||
|
|
@ -263,8 +281,9 @@ def select_deep_thinking_agent(provider) -> str:
|
|||
return choice
|
||||
|
||||
def select_llm_provider() -> tuple[str, str]:
|
||||
"""Select the OpenAI api url using interactive selection."""
|
||||
# Define OpenAI api options with their corresponding endpoints
|
||||
"""Select the LLM provider and API endpoint using interactive selection."""
|
||||
_ensure_questionary()
|
||||
# Define provider API options with their corresponding endpoints
|
||||
BASE_URLS = [
|
||||
("OpenAI", "https://api.openai.com/v1"),
|
||||
("Google", "https://generativelanguage.googleapis.com/v1"),
|
||||
|
|
@ -291,7 +310,7 @@ def select_llm_provider() -> tuple[str, str]:
|
|||
).ask()
|
||||
|
||||
if choice is None:
|
||||
console.print("\n[red]no OpenAI backend selected. Exiting...[/red]")
|
||||
console.print("\n[red]No LLM provider selected. Exiting...[/red]")
|
||||
exit(1)
|
||||
|
||||
display_name, url = choice
|
||||
|
|
@ -302,6 +321,7 @@ def select_llm_provider() -> tuple[str, str]:
|
|||
|
||||
def ask_openai_reasoning_effort() -> str:
|
||||
"""Ask for OpenAI reasoning effort level."""
|
||||
_ensure_questionary()
|
||||
choices = [
|
||||
questionary.Choice("Medium (Default)", "medium"),
|
||||
questionary.Choice("High (More thorough)", "high"),
|
||||
|
|
@ -323,6 +343,7 @@ def ask_anthropic_effort() -> str | None:
|
|||
|
||||
Controls token usage and response thoroughness on Claude 4.5+ and 4.6 models.
|
||||
"""
|
||||
_ensure_questionary()
|
||||
return questionary.select(
|
||||
"Select Effort Level:",
|
||||
choices=[
|
||||
|
|
@ -344,6 +365,7 @@ def ask_gemini_thinking_config() -> str | None:
|
|||
Returns thinking_level: "high" or "minimal".
|
||||
Client maps to appropriate API param based on model series.
|
||||
"""
|
||||
_ensure_questionary()
|
||||
return questionary.select(
|
||||
"Select Thinking Mode:",
|
||||
choices=[
|
||||
|
|
|
|||
|
|
@ -0,0 +1,29 @@
|
|||
import unittest
|
||||
|
||||
from cli.message_buffer import MessageBuffer
|
||||
|
||||
|
||||
class MessageBufferTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.buffer = MessageBuffer()
|
||||
self.buffer.init_for_analysis(["market", "news"])
|
||||
|
||||
def test_current_report_tracks_most_recent_updated_section(self):
|
||||
self.buffer.update_report_section("market_report", "Market content")
|
||||
self.assertIn("Market Analysis", self.buffer.current_report)
|
||||
|
||||
self.buffer.update_report_section("news_report", "News content")
|
||||
self.assertIn("News Analysis", self.buffer.current_report)
|
||||
self.assertNotIn("Market Analysis", self.buffer.current_report)
|
||||
|
||||
def test_init_resets_last_updated_section(self):
|
||||
self.buffer.update_report_section("market_report", "Market content")
|
||||
self.assertEqual(self.buffer._last_updated_section, "market_report")
|
||||
|
||||
self.buffer.init_for_analysis(["fundamentals"])
|
||||
self.assertIsNone(self.buffer._last_updated_section)
|
||||
self.assertIsNone(self.buffer.current_report)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -1,7 +1,6 @@
|
|||
import unittest
|
||||
|
||||
from cli.utils import normalize_ticker_symbol
|
||||
from tradingagents.agents.utils.agent_utils import build_instrument_context
|
||||
|
||||
|
||||
class TickerSymbolHandlingTests(unittest.TestCase):
|
||||
|
|
@ -9,6 +8,11 @@ class TickerSymbolHandlingTests(unittest.TestCase):
|
|||
self.assertEqual(normalize_ticker_symbol(" cnc.to "), "CNC.TO")
|
||||
|
||||
def test_build_instrument_context_mentions_exact_symbol(self):
|
||||
try:
|
||||
from tradingagents.agents.utils.agent_utils import build_instrument_context
|
||||
except ModuleNotFoundError as exc:
|
||||
self.skipTest(f"optional dependency missing: {exc}")
|
||||
|
||||
context = build_instrument_context("7203.T")
|
||||
self.assertIn("7203.T", context)
|
||||
self.assertIn("exchange suffix", context)
|
||||
|
|
|
|||
Loading…
Reference in New Issue