diff --git a/cli/analysis.py b/cli/analysis.py index 56829a3d..bd59e1fb 100644 --- a/cli/analysis.py +++ b/cli/analysis.py @@ -13,7 +13,7 @@ from tradingagents.graph.trading_graph import TradingAgentsGraph from tradingagents.dataflows.config import get_config from cli.state import message_buffer -from cli.models import AnalystType +from cli.models import AnalystType, AgentStatus from cli.display import ( create_layout, update_display, @@ -138,32 +138,32 @@ def get_user_selections() -> dict: def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType]) -> None: if "market_report" in chunk and chunk["market_report"]: message_buffer.update_report_section("market_report", chunk["market_report"]) - message_buffer.update_agent_status("Market Analyst", "completed") + message_buffer.update_agent_status("Market Analyst", AgentStatus.COMPLETED) if AnalystType.SOCIAL in selected_analysts: - message_buffer.update_agent_status("Social Analyst", "in_progress") + message_buffer.update_agent_status("Social Analyst", AgentStatus.IN_PROGRESS) if "sentiment_report" in chunk and chunk["sentiment_report"]: message_buffer.update_report_section("sentiment_report", chunk["sentiment_report"]) - message_buffer.update_agent_status("Social Analyst", "completed") + message_buffer.update_agent_status("Social Analyst", AgentStatus.COMPLETED) if AnalystType.NEWS in selected_analysts: - message_buffer.update_agent_status("News Analyst", "in_progress") + message_buffer.update_agent_status("News Analyst", AgentStatus.IN_PROGRESS) if "news_report" in chunk and chunk["news_report"]: message_buffer.update_report_section("news_report", chunk["news_report"]) - message_buffer.update_agent_status("News Analyst", "completed") + message_buffer.update_agent_status("News Analyst", AgentStatus.COMPLETED) if AnalystType.FUNDAMENTALS in selected_analysts: - message_buffer.update_agent_status("Fundamentals Analyst", "in_progress") + message_buffer.update_agent_status("Fundamentals Analyst", AgentStatus.IN_PROGRESS) if "fundamentals_report" in chunk and chunk["fundamentals_report"]: message_buffer.update_report_section("fundamentals_report", chunk["fundamentals_report"]) - message_buffer.update_agent_status("Fundamentals Analyst", "completed") - update_research_team_status("in_progress") + message_buffer.update_agent_status("Fundamentals Analyst", AgentStatus.COMPLETED) + update_research_team_status(AgentStatus.IN_PROGRESS) if "investment_debate_state" in chunk and chunk["investment_debate_state"]: debate_state = chunk["investment_debate_state"] if "bull_history" in debate_state and debate_state["bull_history"]: - update_research_team_status("in_progress") + update_research_team_status(AgentStatus.IN_PROGRESS) bull_responses = debate_state["bull_history"].split("\n") latest_bull = bull_responses[-1] if bull_responses else "" if latest_bull: @@ -174,7 +174,7 @@ def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType]) ) if "bear_history" in debate_state and debate_state["bear_history"]: - update_research_team_status("in_progress") + update_research_team_status(AgentStatus.IN_PROGRESS) bear_responses = debate_state["bear_history"].split("\n") latest_bear = bear_responses[-1] if bear_responses else "" if latest_bear: @@ -185,7 +185,7 @@ def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType]) ) if "judge_decision" in debate_state and debate_state["judge_decision"]: - update_research_team_status("in_progress") + update_research_team_status(AgentStatus.IN_PROGRESS) message_buffer.add_message( "Reasoning", f"Research Manager: {debate_state['judge_decision']}", @@ -194,18 +194,18 @@ def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType]) "investment_plan", f"{message_buffer.report_sections['investment_plan']}\n\n### Research Manager Decision\n{debate_state['judge_decision']}", ) - update_research_team_status("completed") - message_buffer.update_agent_status("Risky Analyst", "in_progress") + update_research_team_status(AgentStatus.COMPLETED) + message_buffer.update_agent_status("Risky Analyst", AgentStatus.IN_PROGRESS) if "trader_investment_plan" in chunk and chunk["trader_investment_plan"]: message_buffer.update_report_section("trader_investment_plan", chunk["trader_investment_plan"]) - message_buffer.update_agent_status("Risky Analyst", "in_progress") + message_buffer.update_agent_status("Risky Analyst", AgentStatus.IN_PROGRESS) if "risk_debate_state" in chunk and chunk["risk_debate_state"]: risk_state = chunk["risk_debate_state"] if "current_risky_response" in risk_state and risk_state["current_risky_response"]: - message_buffer.update_agent_status("Risky Analyst", "in_progress") + message_buffer.update_agent_status("Risky Analyst", AgentStatus.IN_PROGRESS) message_buffer.add_message( "Reasoning", f"Risky Analyst: {risk_state['current_risky_response']}", @@ -216,7 +216,7 @@ def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType]) ) if "current_safe_response" in risk_state and risk_state["current_safe_response"]: - message_buffer.update_agent_status("Safe Analyst", "in_progress") + message_buffer.update_agent_status("Safe Analyst", AgentStatus.IN_PROGRESS) message_buffer.add_message( "Reasoning", f"Safe Analyst: {risk_state['current_safe_response']}", @@ -227,7 +227,7 @@ def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType]) ) if "current_neutral_response" in risk_state and risk_state["current_neutral_response"]: - message_buffer.update_agent_status("Neutral Analyst", "in_progress") + message_buffer.update_agent_status("Neutral Analyst", AgentStatus.IN_PROGRESS) message_buffer.add_message( "Reasoning", f"Neutral Analyst: {risk_state['current_neutral_response']}", @@ -238,7 +238,7 @@ def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType]) ) if "judge_decision" in risk_state and risk_state["judge_decision"]: - message_buffer.update_agent_status("Portfolio Manager", "in_progress") + message_buffer.update_agent_status("Portfolio Manager", AgentStatus.IN_PROGRESS) message_buffer.add_message( "Reasoning", f"Portfolio Manager: {risk_state['judge_decision']}", @@ -247,10 +247,10 @@ def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType]) "final_trade_decision", f"### Portfolio Manager Decision\n{risk_state['judge_decision']}", ) - message_buffer.update_agent_status("Risky Analyst", "completed") - message_buffer.update_agent_status("Safe Analyst", "completed") - message_buffer.update_agent_status("Neutral Analyst", "completed") - message_buffer.update_agent_status("Portfolio Manager", "completed") + message_buffer.update_agent_status("Risky Analyst", AgentStatus.COMPLETED) + message_buffer.update_agent_status("Safe Analyst", AgentStatus.COMPLETED) + message_buffer.update_agent_status("Neutral Analyst", AgentStatus.COMPLETED) + message_buffer.update_agent_status("Portfolio Manager", AgentStatus.COMPLETED) def setup_logging_decorators(report_dir, log_file) -> tuple: @@ -383,7 +383,7 @@ def _run_analysis_with_config(ticker: str, analysis_date: str, selected_analysts update_display(layout) for agent in message_buffer.agent_status: - message_buffer.update_agent_status(agent, "pending") + message_buffer.update_agent_status(agent, AgentStatus.PENDING) for section in message_buffer.report_sections: message_buffer.report_sections[section] = None @@ -391,7 +391,7 @@ def _run_analysis_with_config(ticker: str, analysis_date: str, selected_analysts message_buffer.final_report = None first_analyst = f"{selected_analysts[0].value.capitalize()} Analyst" - message_buffer.update_agent_status(first_analyst, "in_progress") + message_buffer.update_agent_status(first_analyst, AgentStatus.IN_PROGRESS) update_display(layout) spinner_text = f"Analyzing {ticker} on {analysis_date}..." @@ -430,7 +430,7 @@ def _run_analysis_with_config(ticker: str, analysis_date: str, selected_analysts decision = graph.process_signal(final_state["final_trade_decision"]) for agent in message_buffer.agent_status: - message_buffer.update_agent_status(agent, "completed") + message_buffer.update_agent_status(agent, AgentStatus.COMPLETED) message_buffer.add_message("Analysis", f"Completed analysis for {analysis_date}") diff --git a/cli/backtest_cmd.py b/cli/backtest_cmd.py index 34124498..0aa5a124 100644 --- a/cli/backtest_cmd.py +++ b/cli/backtest_cmd.py @@ -9,7 +9,7 @@ from rich.table import Table from rich import box from tradingagents.backtesting import SimpleBacktestEngine, DataLoader -from tradingagents.models.backtest import BacktestConfig +from tradingagents.models.backtest import BacktestConfig, BacktestStatus from tradingagents.models.portfolio import PortfolioConfig from cli.display import create_question_box @@ -158,7 +158,7 @@ def run_backtest( console.print() - if result.status == "failed": + if result.status == BacktestStatus.FAILED: console.print(f"[red]Backtest failed: {result.error_message}[/red]") return diff --git a/cli/display.py b/cli/display.py index 449ee41b..879a7870 100644 --- a/cli/display.py +++ b/cli/display.py @@ -1,6 +1,7 @@ from typing import Optional, Dict, Any from rich.console import Console +from cli.models import AgentStatus from rich.panel import Panel from rich.spinner import Spinner from rich.markdown import Markdown @@ -72,34 +73,34 @@ def update_display(layout: Layout, spinner_text: Optional[str] = None) -> None: for team, agents in teams.items(): first_agent = agents[0] status = message_buffer.agent_status[first_agent] - if status == "in_progress": + if status == AgentStatus.IN_PROGRESS: spinner = Spinner( "dots", text="[blue]in_progress[/blue]", style="bold cyan" ) status_cell = spinner else: status_color = { - "pending": "yellow", - "completed": "green", - "error": "red", + AgentStatus.PENDING: "yellow", + AgentStatus.COMPLETED: "green", + AgentStatus.ERROR: "red", }.get(status, "white") - status_cell = f"[{status_color}]{status}[/{status_color}]" + status_cell = f"[{status_color}]{status.value}[/{status_color}]" progress_table.add_row(team, first_agent, status_cell) for agent in agents[1:]: status = message_buffer.agent_status[agent] - if status == "in_progress": + if status == AgentStatus.IN_PROGRESS: spinner = Spinner( "dots", text="[blue]in_progress[/blue]", style="bold cyan" ) status_cell = spinner else: status_color = { - "pending": "yellow", - "completed": "green", - "error": "red", + AgentStatus.PENDING: "yellow", + AgentStatus.COMPLETED: "green", + AgentStatus.ERROR: "red", }.get(status, "white") - status_cell = f"[{status_color}]{status}[/{status_color}]" + status_cell = f"[{status_color}]{status.value}[/{status_color}]" progress_table.add_row("", agent, status_cell) progress_table.add_row("-" * 20, "-" * 20, "-" * 20, style="dim") @@ -383,7 +384,7 @@ def display_complete_report(final_state: Dict[str, Any]) -> None: ) -def update_research_team_status(status: str) -> None: +def update_research_team_status(status: AgentStatus) -> None: research_team = ["Bull Researcher", "Bear Researcher", "Research Manager", "Trader"] for agent in research_team: message_buffer.update_agent_status(agent, status) diff --git a/cli/models.py b/cli/models.py index f68c3da1..fc5b28c1 100644 --- a/cli/models.py +++ b/cli/models.py @@ -8,3 +8,10 @@ class AnalystType(str, Enum): SOCIAL = "social" NEWS = "news" FUNDAMENTALS = "fundamentals" + + +class AgentStatus(str, Enum): + PENDING = "pending" + IN_PROGRESS = "in_progress" + COMPLETED = "completed" + ERROR = "error" diff --git a/cli/state.py b/cli/state.py index b9d64651..b032712c 100644 --- a/cli/state.py +++ b/cli/state.py @@ -2,6 +2,8 @@ import datetime from collections import deque from typing import Optional, Dict, Any, Deque +from cli.models import AgentStatus + class MessageBuffer: def __init__(self, max_length: int = 100) -> None: @@ -9,19 +11,19 @@ class MessageBuffer: self.tool_calls: Deque = deque(maxlen=max_length) self.current_report = None self.final_report = None - self.agent_status = { - "Market Analyst": "pending", - "Social Analyst": "pending", - "News Analyst": "pending", - "Fundamentals Analyst": "pending", - "Bull Researcher": "pending", - "Bear Researcher": "pending", - "Research Manager": "pending", - "Trader": "pending", - "Risky Analyst": "pending", - "Neutral Analyst": "pending", - "Safe Analyst": "pending", - "Portfolio Manager": "pending", + self.agent_status: Dict[str, AgentStatus] = { + "Market Analyst": AgentStatus.PENDING, + "Social Analyst": AgentStatus.PENDING, + "News Analyst": AgentStatus.PENDING, + "Fundamentals Analyst": AgentStatus.PENDING, + "Bull Researcher": AgentStatus.PENDING, + "Bear Researcher": AgentStatus.PENDING, + "Research Manager": AgentStatus.PENDING, + "Trader": AgentStatus.PENDING, + "Risky Analyst": AgentStatus.PENDING, + "Neutral Analyst": AgentStatus.PENDING, + "Safe Analyst": AgentStatus.PENDING, + "Portfolio Manager": AgentStatus.PENDING, } self.current_agent = None self.report_sections = { @@ -42,7 +44,7 @@ class MessageBuffer: timestamp = datetime.datetime.now().strftime("%H:%M:%S") self.tool_calls.append((timestamp, tool_name, args)) - def update_agent_status(self, agent: str, status: str) -> None: + def update_agent_status(self, agent: str, status: AgentStatus) -> None: if agent in self.agent_status: self.agent_status[agent] = status self.current_agent = agent @@ -123,7 +125,7 @@ class MessageBuffer: def reset(self) -> None: for agent in self.agent_status: - self.agent_status[agent] = "pending" + self.agent_status[agent] = AgentStatus.PENDING for section in self.report_sections: self.report_sections[section] = None self.current_report = None diff --git a/tests/models/test_backtest.py b/tests/models/test_backtest.py index 809fe281..0718b5dd 100644 --- a/tests/models/test_backtest.py +++ b/tests/models/test_backtest.py @@ -6,6 +6,7 @@ import pytest from tradingagents.models.backtest import ( BacktestConfig, BacktestResult, + BacktestStatus, BacktestMetrics, EquityCurvePoint, TradeLog, @@ -287,7 +288,7 @@ class TestBacktestResult: ) assert result.duration_seconds == 330.0 - assert result.status == "completed" + assert result.status == BacktestStatus.COMPLETED def test_to_dict(self): config = BacktestConfig( diff --git a/tradingagents/backtesting/engine.py b/tradingagents/backtesting/engine.py index 589290c6..0a64242f 100644 --- a/tradingagents/backtesting/engine.py +++ b/tradingagents/backtesting/engine.py @@ -7,6 +7,7 @@ from uuid import uuid4 from tradingagents.models.backtest import ( BacktestConfig, BacktestResult, + BacktestStatus, EquityCurvePoint, TradeLog, ) @@ -69,7 +70,7 @@ class BacktestEngine: daily_returns=self.daily_returns, started_at=started_at, completed_at=completed_at, - status="completed", + status=BacktestStatus.COMPLETED, ) except (ValueError, KeyError, RuntimeError, FileNotFoundError, OSError) as e: @@ -84,7 +85,7 @@ class BacktestEngine: daily_returns=self.daily_returns, started_at=started_at, completed_at=completed_at, - status="failed", + status=BacktestStatus.FAILED, error_message=str(e), ) diff --git a/tradingagents/models/backtest.py b/tradingagents/models/backtest.py index c18aaa13..d8a967e0 100644 --- a/tradingagents/models/backtest.py +++ b/tradingagents/models/backtest.py @@ -10,6 +10,13 @@ from .portfolio import PortfolioConfig from .trading import Trade +class BacktestStatus(str, Enum): + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + + class BacktestConfig(BaseModel): id: UUID = Field(default_factory=uuid4) name: str = Field(default="Backtest") @@ -212,7 +219,7 @@ class BacktestResult(BaseModel): daily_returns: list[Decimal] = Field(default_factory=list) started_at: datetime completed_at: datetime - status: str = Field(default="completed") + status: BacktestStatus = Field(default=BacktestStatus.COMPLETED) error_message: Optional[str] = None @computed_field