diff --git a/cli/main.py b/cli/main.py index b21e30f1..a00812f8 100644 --- a/cli/main.py +++ b/cli/main.py @@ -13,16 +13,21 @@ from rich.layout import Layout from rich.text import Text from rich.table import Table from collections import deque -import time -from rich.tree import Tree from rich import box from rich.align import Align -from rich.rule import Rule from tradingagents.graph.trading_graph import TradingAgentsGraph from tradingagents.default_config import DEFAULT_CONFIG from cli.models import AnalystType -from cli.utils import * +from cli.utils import ( + get_ticker, + get_analysis_date, + select_analysts, + select_research_depth, + select_shallow_thinking_agent, + select_deep_thinking_agent, + select_llm_provider, +) console = Console() @@ -121,31 +126,29 @@ class MessageBuffer: report_parts = [] # Analyst Team Reports - if any( - self.report_sections[section] - for section in [ - "market_report", - "sentiment_report", - "news_report", - "fundamentals_report", - ] - ): + sections = self.report_sections + if any(sections[sec] for sec in [ + "market_report", + "sentiment_report", + "news_report", + "fundamentals_report", + ]): report_parts.append("## Analyst Team Reports") - if self.report_sections["market_report"]: + if sections["market_report"]: report_parts.append( - f"### Market Analysis\n{self.report_sections['market_report']}" + f"### Market Analysis\n{sections['market_report']}" ) - if self.report_sections["sentiment_report"]: + if sections["sentiment_report"]: report_parts.append( - f"### Social Sentiment\n{self.report_sections['sentiment_report']}" + f"### Social Sentiment\n{sections['sentiment_report']}" ) - if self.report_sections["news_report"]: + if sections["news_report"]: report_parts.append( - f"### News Analysis\n{self.report_sections['news_report']}" + f"### News Analysis\n{sections['news_report']}" ) - if self.report_sections["fundamentals_report"]: + if sections["fundamentals_report"]: report_parts.append( - f"### Fundamentals Analysis\n{self.report_sections['fundamentals_report']}" + f"### Fundamentals Analysis\n{sections['fundamentals_report']}" ) # Research Team Reports @@ -181,6 +184,49 @@ class MessageBuffer: """Return a copy of the current report sections.""" return dict(self.report_sections) + def export_final_report(self, path: Path) -> None: + """Write the final report to ``path``. + + Raises: + ValueError: If ``final_report`` has not been generated yet. + """ + if self.final_report is None: + raise ValueError("final_report has not been generated") + with open(path, "w", encoding="utf-8") as f: + f.write(self.final_report) + + def get_messages(self, limit=None): + """Return the most recent messages. + + Args: + limit: Optional maximum number of messages to return. If ``None``, + all stored messages are returned. + + Returns: + List of ``(timestamp, type, content)`` tuples ordered from oldest + to newest. + """ + msgs = list(self.messages) + if limit is not None: + return msgs[-limit:] + return msgs + + def get_tool_calls(self, limit=None): + """Return the most recent tool calls. + + Args: + limit: Optional maximum number of tool calls to return. If ``None``, + all stored tool calls are returned. + + Returns: + List of ``(timestamp, tool_name, args)`` tuples ordered from oldest + to newest. + """ + calls = list(self.tool_calls) + if limit is not None: + return calls[-limit:] + return calls + message_buffer = MessageBuffer()