feat: create enums for status strings

Added AgentStatus enum for CLI agent tracking (pending, in_progress,
completed, error) and BacktestStatus enum for backtest results (pending,
running, completed, failed). Replaces string literals with type-safe
enum values throughout the codebase.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Joseph O'Brien 2025-12-03 04:13:59 -05:00
parent b8be981bc8
commit 1346e20b5e
8 changed files with 77 additions and 58 deletions

View File

@ -13,7 +13,7 @@ from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.dataflows.config import get_config
from cli.state import message_buffer
from cli.models import AnalystType
from cli.models import AnalystType, AgentStatus
from cli.display import (
create_layout,
update_display,
@ -138,32 +138,32 @@ def get_user_selections() -> dict:
def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType]) -> None:
if "market_report" in chunk and chunk["market_report"]:
message_buffer.update_report_section("market_report", chunk["market_report"])
message_buffer.update_agent_status("Market Analyst", "completed")
message_buffer.update_agent_status("Market Analyst", AgentStatus.COMPLETED)
if AnalystType.SOCIAL in selected_analysts:
message_buffer.update_agent_status("Social Analyst", "in_progress")
message_buffer.update_agent_status("Social Analyst", AgentStatus.IN_PROGRESS)
if "sentiment_report" in chunk and chunk["sentiment_report"]:
message_buffer.update_report_section("sentiment_report", chunk["sentiment_report"])
message_buffer.update_agent_status("Social Analyst", "completed")
message_buffer.update_agent_status("Social Analyst", AgentStatus.COMPLETED)
if AnalystType.NEWS in selected_analysts:
message_buffer.update_agent_status("News Analyst", "in_progress")
message_buffer.update_agent_status("News Analyst", AgentStatus.IN_PROGRESS)
if "news_report" in chunk and chunk["news_report"]:
message_buffer.update_report_section("news_report", chunk["news_report"])
message_buffer.update_agent_status("News Analyst", "completed")
message_buffer.update_agent_status("News Analyst", AgentStatus.COMPLETED)
if AnalystType.FUNDAMENTALS in selected_analysts:
message_buffer.update_agent_status("Fundamentals Analyst", "in_progress")
message_buffer.update_agent_status("Fundamentals Analyst", AgentStatus.IN_PROGRESS)
if "fundamentals_report" in chunk and chunk["fundamentals_report"]:
message_buffer.update_report_section("fundamentals_report", chunk["fundamentals_report"])
message_buffer.update_agent_status("Fundamentals Analyst", "completed")
update_research_team_status("in_progress")
message_buffer.update_agent_status("Fundamentals Analyst", AgentStatus.COMPLETED)
update_research_team_status(AgentStatus.IN_PROGRESS)
if "investment_debate_state" in chunk and chunk["investment_debate_state"]:
debate_state = chunk["investment_debate_state"]
if "bull_history" in debate_state and debate_state["bull_history"]:
update_research_team_status("in_progress")
update_research_team_status(AgentStatus.IN_PROGRESS)
bull_responses = debate_state["bull_history"].split("\n")
latest_bull = bull_responses[-1] if bull_responses else ""
if latest_bull:
@ -174,7 +174,7 @@ def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType])
)
if "bear_history" in debate_state and debate_state["bear_history"]:
update_research_team_status("in_progress")
update_research_team_status(AgentStatus.IN_PROGRESS)
bear_responses = debate_state["bear_history"].split("\n")
latest_bear = bear_responses[-1] if bear_responses else ""
if latest_bear:
@ -185,7 +185,7 @@ def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType])
)
if "judge_decision" in debate_state and debate_state["judge_decision"]:
update_research_team_status("in_progress")
update_research_team_status(AgentStatus.IN_PROGRESS)
message_buffer.add_message(
"Reasoning",
f"Research Manager: {debate_state['judge_decision']}",
@ -194,18 +194,18 @@ def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType])
"investment_plan",
f"{message_buffer.report_sections['investment_plan']}\n\n### Research Manager Decision\n{debate_state['judge_decision']}",
)
update_research_team_status("completed")
message_buffer.update_agent_status("Risky Analyst", "in_progress")
update_research_team_status(AgentStatus.COMPLETED)
message_buffer.update_agent_status("Risky Analyst", AgentStatus.IN_PROGRESS)
if "trader_investment_plan" in chunk and chunk["trader_investment_plan"]:
message_buffer.update_report_section("trader_investment_plan", chunk["trader_investment_plan"])
message_buffer.update_agent_status("Risky Analyst", "in_progress")
message_buffer.update_agent_status("Risky Analyst", AgentStatus.IN_PROGRESS)
if "risk_debate_state" in chunk and chunk["risk_debate_state"]:
risk_state = chunk["risk_debate_state"]
if "current_risky_response" in risk_state and risk_state["current_risky_response"]:
message_buffer.update_agent_status("Risky Analyst", "in_progress")
message_buffer.update_agent_status("Risky Analyst", AgentStatus.IN_PROGRESS)
message_buffer.add_message(
"Reasoning",
f"Risky Analyst: {risk_state['current_risky_response']}",
@ -216,7 +216,7 @@ def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType])
)
if "current_safe_response" in risk_state and risk_state["current_safe_response"]:
message_buffer.update_agent_status("Safe Analyst", "in_progress")
message_buffer.update_agent_status("Safe Analyst", AgentStatus.IN_PROGRESS)
message_buffer.add_message(
"Reasoning",
f"Safe Analyst: {risk_state['current_safe_response']}",
@ -227,7 +227,7 @@ def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType])
)
if "current_neutral_response" in risk_state and risk_state["current_neutral_response"]:
message_buffer.update_agent_status("Neutral Analyst", "in_progress")
message_buffer.update_agent_status("Neutral Analyst", AgentStatus.IN_PROGRESS)
message_buffer.add_message(
"Reasoning",
f"Neutral Analyst: {risk_state['current_neutral_response']}",
@ -238,7 +238,7 @@ def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType])
)
if "judge_decision" in risk_state and risk_state["judge_decision"]:
message_buffer.update_agent_status("Portfolio Manager", "in_progress")
message_buffer.update_agent_status("Portfolio Manager", AgentStatus.IN_PROGRESS)
message_buffer.add_message(
"Reasoning",
f"Portfolio Manager: {risk_state['judge_decision']}",
@ -247,10 +247,10 @@ def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType])
"final_trade_decision",
f"### Portfolio Manager Decision\n{risk_state['judge_decision']}",
)
message_buffer.update_agent_status("Risky Analyst", "completed")
message_buffer.update_agent_status("Safe Analyst", "completed")
message_buffer.update_agent_status("Neutral Analyst", "completed")
message_buffer.update_agent_status("Portfolio Manager", "completed")
message_buffer.update_agent_status("Risky Analyst", AgentStatus.COMPLETED)
message_buffer.update_agent_status("Safe Analyst", AgentStatus.COMPLETED)
message_buffer.update_agent_status("Neutral Analyst", AgentStatus.COMPLETED)
message_buffer.update_agent_status("Portfolio Manager", AgentStatus.COMPLETED)
def setup_logging_decorators(report_dir, log_file) -> tuple:
@ -383,7 +383,7 @@ def _run_analysis_with_config(ticker: str, analysis_date: str, selected_analysts
update_display(layout)
for agent in message_buffer.agent_status:
message_buffer.update_agent_status(agent, "pending")
message_buffer.update_agent_status(agent, AgentStatus.PENDING)
for section in message_buffer.report_sections:
message_buffer.report_sections[section] = None
@ -391,7 +391,7 @@ def _run_analysis_with_config(ticker: str, analysis_date: str, selected_analysts
message_buffer.final_report = None
first_analyst = f"{selected_analysts[0].value.capitalize()} Analyst"
message_buffer.update_agent_status(first_analyst, "in_progress")
message_buffer.update_agent_status(first_analyst, AgentStatus.IN_PROGRESS)
update_display(layout)
spinner_text = f"Analyzing {ticker} on {analysis_date}..."
@ -430,7 +430,7 @@ def _run_analysis_with_config(ticker: str, analysis_date: str, selected_analysts
decision = graph.process_signal(final_state["final_trade_decision"])
for agent in message_buffer.agent_status:
message_buffer.update_agent_status(agent, "completed")
message_buffer.update_agent_status(agent, AgentStatus.COMPLETED)
message_buffer.add_message("Analysis", f"Completed analysis for {analysis_date}")

View File

@ -9,7 +9,7 @@ from rich.table import Table
from rich import box
from tradingagents.backtesting import SimpleBacktestEngine, DataLoader
from tradingagents.models.backtest import BacktestConfig
from tradingagents.models.backtest import BacktestConfig, BacktestStatus
from tradingagents.models.portfolio import PortfolioConfig
from cli.display import create_question_box
@ -158,7 +158,7 @@ def run_backtest(
console.print()
if result.status == "failed":
if result.status == BacktestStatus.FAILED:
console.print(f"[red]Backtest failed: {result.error_message}[/red]")
return

View File

@ -1,6 +1,7 @@
from typing import Optional, Dict, Any
from rich.console import Console
from cli.models import AgentStatus
from rich.panel import Panel
from rich.spinner import Spinner
from rich.markdown import Markdown
@ -72,34 +73,34 @@ def update_display(layout: Layout, spinner_text: Optional[str] = None) -> None:
for team, agents in teams.items():
first_agent = agents[0]
status = message_buffer.agent_status[first_agent]
if status == "in_progress":
if status == AgentStatus.IN_PROGRESS:
spinner = Spinner(
"dots", text="[blue]in_progress[/blue]", style="bold cyan"
)
status_cell = spinner
else:
status_color = {
"pending": "yellow",
"completed": "green",
"error": "red",
AgentStatus.PENDING: "yellow",
AgentStatus.COMPLETED: "green",
AgentStatus.ERROR: "red",
}.get(status, "white")
status_cell = f"[{status_color}]{status}[/{status_color}]"
status_cell = f"[{status_color}]{status.value}[/{status_color}]"
progress_table.add_row(team, first_agent, status_cell)
for agent in agents[1:]:
status = message_buffer.agent_status[agent]
if status == "in_progress":
if status == AgentStatus.IN_PROGRESS:
spinner = Spinner(
"dots", text="[blue]in_progress[/blue]", style="bold cyan"
)
status_cell = spinner
else:
status_color = {
"pending": "yellow",
"completed": "green",
"error": "red",
AgentStatus.PENDING: "yellow",
AgentStatus.COMPLETED: "green",
AgentStatus.ERROR: "red",
}.get(status, "white")
status_cell = f"[{status_color}]{status}[/{status_color}]"
status_cell = f"[{status_color}]{status.value}[/{status_color}]"
progress_table.add_row("", agent, status_cell)
progress_table.add_row("-" * 20, "-" * 20, "-" * 20, style="dim")
@ -383,7 +384,7 @@ def display_complete_report(final_state: Dict[str, Any]) -> None:
)
def update_research_team_status(status: str) -> None:
def update_research_team_status(status: AgentStatus) -> None:
research_team = ["Bull Researcher", "Bear Researcher", "Research Manager", "Trader"]
for agent in research_team:
message_buffer.update_agent_status(agent, status)

View File

@ -8,3 +8,10 @@ class AnalystType(str, Enum):
SOCIAL = "social"
NEWS = "news"
FUNDAMENTALS = "fundamentals"
class AgentStatus(str, Enum):
PENDING = "pending"
IN_PROGRESS = "in_progress"
COMPLETED = "completed"
ERROR = "error"

View File

@ -2,6 +2,8 @@ import datetime
from collections import deque
from typing import Optional, Dict, Any, Deque
from cli.models import AgentStatus
class MessageBuffer:
def __init__(self, max_length: int = 100) -> None:
@ -9,19 +11,19 @@ class MessageBuffer:
self.tool_calls: Deque = deque(maxlen=max_length)
self.current_report = None
self.final_report = None
self.agent_status = {
"Market Analyst": "pending",
"Social Analyst": "pending",
"News Analyst": "pending",
"Fundamentals Analyst": "pending",
"Bull Researcher": "pending",
"Bear Researcher": "pending",
"Research Manager": "pending",
"Trader": "pending",
"Risky Analyst": "pending",
"Neutral Analyst": "pending",
"Safe Analyst": "pending",
"Portfolio Manager": "pending",
self.agent_status: Dict[str, AgentStatus] = {
"Market Analyst": AgentStatus.PENDING,
"Social Analyst": AgentStatus.PENDING,
"News Analyst": AgentStatus.PENDING,
"Fundamentals Analyst": AgentStatus.PENDING,
"Bull Researcher": AgentStatus.PENDING,
"Bear Researcher": AgentStatus.PENDING,
"Research Manager": AgentStatus.PENDING,
"Trader": AgentStatus.PENDING,
"Risky Analyst": AgentStatus.PENDING,
"Neutral Analyst": AgentStatus.PENDING,
"Safe Analyst": AgentStatus.PENDING,
"Portfolio Manager": AgentStatus.PENDING,
}
self.current_agent = None
self.report_sections = {
@ -42,7 +44,7 @@ class MessageBuffer:
timestamp = datetime.datetime.now().strftime("%H:%M:%S")
self.tool_calls.append((timestamp, tool_name, args))
def update_agent_status(self, agent: str, status: str) -> None:
def update_agent_status(self, agent: str, status: AgentStatus) -> None:
if agent in self.agent_status:
self.agent_status[agent] = status
self.current_agent = agent
@ -123,7 +125,7 @@ class MessageBuffer:
def reset(self) -> None:
for agent in self.agent_status:
self.agent_status[agent] = "pending"
self.agent_status[agent] = AgentStatus.PENDING
for section in self.report_sections:
self.report_sections[section] = None
self.current_report = None

View File

@ -6,6 +6,7 @@ import pytest
from tradingagents.models.backtest import (
BacktestConfig,
BacktestResult,
BacktestStatus,
BacktestMetrics,
EquityCurvePoint,
TradeLog,
@ -287,7 +288,7 @@ class TestBacktestResult:
)
assert result.duration_seconds == 330.0
assert result.status == "completed"
assert result.status == BacktestStatus.COMPLETED
def test_to_dict(self):
config = BacktestConfig(

View File

@ -7,6 +7,7 @@ from uuid import uuid4
from tradingagents.models.backtest import (
BacktestConfig,
BacktestResult,
BacktestStatus,
EquityCurvePoint,
TradeLog,
)
@ -69,7 +70,7 @@ class BacktestEngine:
daily_returns=self.daily_returns,
started_at=started_at,
completed_at=completed_at,
status="completed",
status=BacktestStatus.COMPLETED,
)
except (ValueError, KeyError, RuntimeError, FileNotFoundError, OSError) as e:
@ -84,7 +85,7 @@ class BacktestEngine:
daily_returns=self.daily_returns,
started_at=started_at,
completed_at=completed_at,
status="failed",
status=BacktestStatus.FAILED,
error_message=str(e),
)

View File

@ -10,6 +10,13 @@ from .portfolio import PortfolioConfig
from .trading import Trade
class BacktestStatus(str, Enum):
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
class BacktestConfig(BaseModel):
id: UUID = Field(default_factory=uuid4)
name: str = Field(default="Backtest")
@ -212,7 +219,7 @@ class BacktestResult(BaseModel):
daily_returns: list[Decimal] = Field(default_factory=list)
started_at: datetime
completed_at: datetime
status: str = Field(default="completed")
status: BacktestStatus = Field(default=BacktestStatus.COMPLETED)
error_message: Optional[str] = None
@computed_field