Merge 9c9c336a93 into a438acdbbd
This commit is contained in:
commit
a397f1085e
105
cli/main.py
105
cli/main.py
|
|
@ -11,19 +11,23 @@ from rich.columns import Columns
|
|||
from rich.markdown import Markdown
|
||||
from rich.layout import Layout
|
||||
from rich.text import Text
|
||||
from rich.live import Live
|
||||
from rich.table import Table
|
||||
from collections import deque
|
||||
import time
|
||||
from rich.tree import Tree
|
||||
from rich import box
|
||||
from rich.align import Align
|
||||
from rich.rule import Rule
|
||||
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
from cli.models import AnalystType
|
||||
from cli.utils import *
|
||||
from cli.utils import (
|
||||
get_ticker,
|
||||
get_analysis_date,
|
||||
select_analysts,
|
||||
select_research_depth,
|
||||
select_shallow_thinking_agent,
|
||||
select_deep_thinking_agent,
|
||||
select_llm_provider,
|
||||
)
|
||||
|
||||
console = Console()
|
||||
|
||||
|
|
@ -122,31 +126,29 @@ class MessageBuffer:
|
|||
report_parts = []
|
||||
|
||||
# Analyst Team Reports
|
||||
if any(
|
||||
self.report_sections[section]
|
||||
for section in [
|
||||
"market_report",
|
||||
"sentiment_report",
|
||||
"news_report",
|
||||
"fundamentals_report",
|
||||
]
|
||||
):
|
||||
sections = self.report_sections
|
||||
if any(sections[sec] for sec in [
|
||||
"market_report",
|
||||
"sentiment_report",
|
||||
"news_report",
|
||||
"fundamentals_report",
|
||||
]):
|
||||
report_parts.append("## Analyst Team Reports")
|
||||
if self.report_sections["market_report"]:
|
||||
if sections["market_report"]:
|
||||
report_parts.append(
|
||||
f"### Market Analysis\n{self.report_sections['market_report']}"
|
||||
f"### Market Analysis\n{sections['market_report']}"
|
||||
)
|
||||
if self.report_sections["sentiment_report"]:
|
||||
if sections["sentiment_report"]:
|
||||
report_parts.append(
|
||||
f"### Social Sentiment\n{self.report_sections['sentiment_report']}"
|
||||
f"### Social Sentiment\n{sections['sentiment_report']}"
|
||||
)
|
||||
if self.report_sections["news_report"]:
|
||||
if sections["news_report"]:
|
||||
report_parts.append(
|
||||
f"### News Analysis\n{self.report_sections['news_report']}"
|
||||
f"### News Analysis\n{sections['news_report']}"
|
||||
)
|
||||
if self.report_sections["fundamentals_report"]:
|
||||
if sections["fundamentals_report"]:
|
||||
report_parts.append(
|
||||
f"### Fundamentals Analysis\n{self.report_sections['fundamentals_report']}"
|
||||
f"### Fundamentals Analysis\n{sections['fundamentals_report']}"
|
||||
)
|
||||
|
||||
# Research Team Reports
|
||||
|
|
@ -166,6 +168,65 @@ class MessageBuffer:
|
|||
|
||||
self.final_report = "\n\n".join(report_parts) if report_parts else None
|
||||
|
||||
def reset(self):
|
||||
"""Clear all stored messages and reports and reset agent status."""
|
||||
self.messages.clear()
|
||||
self.tool_calls.clear()
|
||||
self.current_report = None
|
||||
self.final_report = None
|
||||
for key in self.report_sections:
|
||||
self.report_sections[key] = None
|
||||
for key in self.agent_status:
|
||||
self.agent_status[key] = "pending"
|
||||
self.current_agent = None
|
||||
|
||||
def get_report_sections(self):
|
||||
"""Return a copy of the current report sections."""
|
||||
return dict(self.report_sections)
|
||||
|
||||
def export_final_report(self, path: Path) -> None:
|
||||
"""Write the final report to ``path``.
|
||||
|
||||
Raises:
|
||||
ValueError: If ``final_report`` has not been generated yet.
|
||||
"""
|
||||
if self.final_report is None:
|
||||
raise ValueError("final_report has not been generated")
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
f.write(self.final_report)
|
||||
|
||||
def get_messages(self, limit=None):
|
||||
"""Return the most recent messages.
|
||||
|
||||
Args:
|
||||
limit: Optional maximum number of messages to return. If ``None``,
|
||||
all stored messages are returned.
|
||||
|
||||
Returns:
|
||||
List of ``(timestamp, type, content)`` tuples ordered from oldest
|
||||
to newest.
|
||||
"""
|
||||
msgs = list(self.messages)
|
||||
if limit is not None:
|
||||
return msgs[-limit:]
|
||||
return msgs
|
||||
|
||||
def get_tool_calls(self, limit=None):
|
||||
"""Return the most recent tool calls.
|
||||
|
||||
Args:
|
||||
limit: Optional maximum number of tool calls to return. If ``None``,
|
||||
all stored tool calls are returned.
|
||||
|
||||
Returns:
|
||||
List of ``(timestamp, tool_name, args)`` tuples ordered from oldest
|
||||
to newest.
|
||||
"""
|
||||
calls = list(self.tool_calls)
|
||||
if limit is not None:
|
||||
return calls[-limit:]
|
||||
return calls
|
||||
|
||||
|
||||
message_buffer = MessageBuffer()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue