feat: create enums for status strings

Added AgentStatus enum for CLI agent tracking (pending, in_progress,
completed, error) and BacktestStatus enum for backtest results (pending,
running, completed, failed). Replaces string literals with type-safe
enum values throughout the codebase.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Joseph O'Brien 2025-12-03 04:13:59 -05:00
parent b8be981bc8
commit 1346e20b5e
8 changed files with 77 additions and 58 deletions

View File

@ -13,7 +13,7 @@ from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.dataflows.config import get_config from tradingagents.dataflows.config import get_config
from cli.state import message_buffer from cli.state import message_buffer
from cli.models import AnalystType from cli.models import AnalystType, AgentStatus
from cli.display import ( from cli.display import (
create_layout, create_layout,
update_display, update_display,
@ -138,32 +138,32 @@ def get_user_selections() -> dict:
def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType]) -> None: def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType]) -> None:
if "market_report" in chunk and chunk["market_report"]: if "market_report" in chunk and chunk["market_report"]:
message_buffer.update_report_section("market_report", chunk["market_report"]) message_buffer.update_report_section("market_report", chunk["market_report"])
message_buffer.update_agent_status("Market Analyst", "completed") message_buffer.update_agent_status("Market Analyst", AgentStatus.COMPLETED)
if AnalystType.SOCIAL in selected_analysts: if AnalystType.SOCIAL in selected_analysts:
message_buffer.update_agent_status("Social Analyst", "in_progress") message_buffer.update_agent_status("Social Analyst", AgentStatus.IN_PROGRESS)
if "sentiment_report" in chunk and chunk["sentiment_report"]: if "sentiment_report" in chunk and chunk["sentiment_report"]:
message_buffer.update_report_section("sentiment_report", chunk["sentiment_report"]) message_buffer.update_report_section("sentiment_report", chunk["sentiment_report"])
message_buffer.update_agent_status("Social Analyst", "completed") message_buffer.update_agent_status("Social Analyst", AgentStatus.COMPLETED)
if AnalystType.NEWS in selected_analysts: if AnalystType.NEWS in selected_analysts:
message_buffer.update_agent_status("News Analyst", "in_progress") message_buffer.update_agent_status("News Analyst", AgentStatus.IN_PROGRESS)
if "news_report" in chunk and chunk["news_report"]: if "news_report" in chunk and chunk["news_report"]:
message_buffer.update_report_section("news_report", chunk["news_report"]) message_buffer.update_report_section("news_report", chunk["news_report"])
message_buffer.update_agent_status("News Analyst", "completed") message_buffer.update_agent_status("News Analyst", AgentStatus.COMPLETED)
if AnalystType.FUNDAMENTALS in selected_analysts: if AnalystType.FUNDAMENTALS in selected_analysts:
message_buffer.update_agent_status("Fundamentals Analyst", "in_progress") message_buffer.update_agent_status("Fundamentals Analyst", AgentStatus.IN_PROGRESS)
if "fundamentals_report" in chunk and chunk["fundamentals_report"]: if "fundamentals_report" in chunk and chunk["fundamentals_report"]:
message_buffer.update_report_section("fundamentals_report", chunk["fundamentals_report"]) message_buffer.update_report_section("fundamentals_report", chunk["fundamentals_report"])
message_buffer.update_agent_status("Fundamentals Analyst", "completed") message_buffer.update_agent_status("Fundamentals Analyst", AgentStatus.COMPLETED)
update_research_team_status("in_progress") update_research_team_status(AgentStatus.IN_PROGRESS)
if "investment_debate_state" in chunk and chunk["investment_debate_state"]: if "investment_debate_state" in chunk and chunk["investment_debate_state"]:
debate_state = chunk["investment_debate_state"] debate_state = chunk["investment_debate_state"]
if "bull_history" in debate_state and debate_state["bull_history"]: if "bull_history" in debate_state and debate_state["bull_history"]:
update_research_team_status("in_progress") update_research_team_status(AgentStatus.IN_PROGRESS)
bull_responses = debate_state["bull_history"].split("\n") bull_responses = debate_state["bull_history"].split("\n")
latest_bull = bull_responses[-1] if bull_responses else "" latest_bull = bull_responses[-1] if bull_responses else ""
if latest_bull: if latest_bull:
@ -174,7 +174,7 @@ def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType])
) )
if "bear_history" in debate_state and debate_state["bear_history"]: if "bear_history" in debate_state and debate_state["bear_history"]:
update_research_team_status("in_progress") update_research_team_status(AgentStatus.IN_PROGRESS)
bear_responses = debate_state["bear_history"].split("\n") bear_responses = debate_state["bear_history"].split("\n")
latest_bear = bear_responses[-1] if bear_responses else "" latest_bear = bear_responses[-1] if bear_responses else ""
if latest_bear: if latest_bear:
@ -185,7 +185,7 @@ def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType])
) )
if "judge_decision" in debate_state and debate_state["judge_decision"]: if "judge_decision" in debate_state and debate_state["judge_decision"]:
update_research_team_status("in_progress") update_research_team_status(AgentStatus.IN_PROGRESS)
message_buffer.add_message( message_buffer.add_message(
"Reasoning", "Reasoning",
f"Research Manager: {debate_state['judge_decision']}", f"Research Manager: {debate_state['judge_decision']}",
@ -194,18 +194,18 @@ def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType])
"investment_plan", "investment_plan",
f"{message_buffer.report_sections['investment_plan']}\n\n### Research Manager Decision\n{debate_state['judge_decision']}", f"{message_buffer.report_sections['investment_plan']}\n\n### Research Manager Decision\n{debate_state['judge_decision']}",
) )
update_research_team_status("completed") update_research_team_status(AgentStatus.COMPLETED)
message_buffer.update_agent_status("Risky Analyst", "in_progress") message_buffer.update_agent_status("Risky Analyst", AgentStatus.IN_PROGRESS)
if "trader_investment_plan" in chunk and chunk["trader_investment_plan"]: if "trader_investment_plan" in chunk and chunk["trader_investment_plan"]:
message_buffer.update_report_section("trader_investment_plan", chunk["trader_investment_plan"]) message_buffer.update_report_section("trader_investment_plan", chunk["trader_investment_plan"])
message_buffer.update_agent_status("Risky Analyst", "in_progress") message_buffer.update_agent_status("Risky Analyst", AgentStatus.IN_PROGRESS)
if "risk_debate_state" in chunk and chunk["risk_debate_state"]: if "risk_debate_state" in chunk and chunk["risk_debate_state"]:
risk_state = chunk["risk_debate_state"] risk_state = chunk["risk_debate_state"]
if "current_risky_response" in risk_state and risk_state["current_risky_response"]: if "current_risky_response" in risk_state and risk_state["current_risky_response"]:
message_buffer.update_agent_status("Risky Analyst", "in_progress") message_buffer.update_agent_status("Risky Analyst", AgentStatus.IN_PROGRESS)
message_buffer.add_message( message_buffer.add_message(
"Reasoning", "Reasoning",
f"Risky Analyst: {risk_state['current_risky_response']}", f"Risky Analyst: {risk_state['current_risky_response']}",
@ -216,7 +216,7 @@ def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType])
) )
if "current_safe_response" in risk_state and risk_state["current_safe_response"]: if "current_safe_response" in risk_state and risk_state["current_safe_response"]:
message_buffer.update_agent_status("Safe Analyst", "in_progress") message_buffer.update_agent_status("Safe Analyst", AgentStatus.IN_PROGRESS)
message_buffer.add_message( message_buffer.add_message(
"Reasoning", "Reasoning",
f"Safe Analyst: {risk_state['current_safe_response']}", f"Safe Analyst: {risk_state['current_safe_response']}",
@ -227,7 +227,7 @@ def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType])
) )
if "current_neutral_response" in risk_state and risk_state["current_neutral_response"]: if "current_neutral_response" in risk_state and risk_state["current_neutral_response"]:
message_buffer.update_agent_status("Neutral Analyst", "in_progress") message_buffer.update_agent_status("Neutral Analyst", AgentStatus.IN_PROGRESS)
message_buffer.add_message( message_buffer.add_message(
"Reasoning", "Reasoning",
f"Neutral Analyst: {risk_state['current_neutral_response']}", f"Neutral Analyst: {risk_state['current_neutral_response']}",
@ -238,7 +238,7 @@ def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType])
) )
if "judge_decision" in risk_state and risk_state["judge_decision"]: if "judge_decision" in risk_state and risk_state["judge_decision"]:
message_buffer.update_agent_status("Portfolio Manager", "in_progress") message_buffer.update_agent_status("Portfolio Manager", AgentStatus.IN_PROGRESS)
message_buffer.add_message( message_buffer.add_message(
"Reasoning", "Reasoning",
f"Portfolio Manager: {risk_state['judge_decision']}", f"Portfolio Manager: {risk_state['judge_decision']}",
@ -247,10 +247,10 @@ def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType])
"final_trade_decision", "final_trade_decision",
f"### Portfolio Manager Decision\n{risk_state['judge_decision']}", f"### Portfolio Manager Decision\n{risk_state['judge_decision']}",
) )
message_buffer.update_agent_status("Risky Analyst", "completed") message_buffer.update_agent_status("Risky Analyst", AgentStatus.COMPLETED)
message_buffer.update_agent_status("Safe Analyst", "completed") message_buffer.update_agent_status("Safe Analyst", AgentStatus.COMPLETED)
message_buffer.update_agent_status("Neutral Analyst", "completed") message_buffer.update_agent_status("Neutral Analyst", AgentStatus.COMPLETED)
message_buffer.update_agent_status("Portfolio Manager", "completed") message_buffer.update_agent_status("Portfolio Manager", AgentStatus.COMPLETED)
def setup_logging_decorators(report_dir, log_file) -> tuple: def setup_logging_decorators(report_dir, log_file) -> tuple:
@ -383,7 +383,7 @@ def _run_analysis_with_config(ticker: str, analysis_date: str, selected_analysts
update_display(layout) update_display(layout)
for agent in message_buffer.agent_status: for agent in message_buffer.agent_status:
message_buffer.update_agent_status(agent, "pending") message_buffer.update_agent_status(agent, AgentStatus.PENDING)
for section in message_buffer.report_sections: for section in message_buffer.report_sections:
message_buffer.report_sections[section] = None message_buffer.report_sections[section] = None
@ -391,7 +391,7 @@ def _run_analysis_with_config(ticker: str, analysis_date: str, selected_analysts
message_buffer.final_report = None message_buffer.final_report = None
first_analyst = f"{selected_analysts[0].value.capitalize()} Analyst" first_analyst = f"{selected_analysts[0].value.capitalize()} Analyst"
message_buffer.update_agent_status(first_analyst, "in_progress") message_buffer.update_agent_status(first_analyst, AgentStatus.IN_PROGRESS)
update_display(layout) update_display(layout)
spinner_text = f"Analyzing {ticker} on {analysis_date}..." spinner_text = f"Analyzing {ticker} on {analysis_date}..."
@ -430,7 +430,7 @@ def _run_analysis_with_config(ticker: str, analysis_date: str, selected_analysts
decision = graph.process_signal(final_state["final_trade_decision"]) decision = graph.process_signal(final_state["final_trade_decision"])
for agent in message_buffer.agent_status: for agent in message_buffer.agent_status:
message_buffer.update_agent_status(agent, "completed") message_buffer.update_agent_status(agent, AgentStatus.COMPLETED)
message_buffer.add_message("Analysis", f"Completed analysis for {analysis_date}") message_buffer.add_message("Analysis", f"Completed analysis for {analysis_date}")

View File

@ -9,7 +9,7 @@ from rich.table import Table
from rich import box from rich import box
from tradingagents.backtesting import SimpleBacktestEngine, DataLoader from tradingagents.backtesting import SimpleBacktestEngine, DataLoader
from tradingagents.models.backtest import BacktestConfig from tradingagents.models.backtest import BacktestConfig, BacktestStatus
from tradingagents.models.portfolio import PortfolioConfig from tradingagents.models.portfolio import PortfolioConfig
from cli.display import create_question_box from cli.display import create_question_box
@ -158,7 +158,7 @@ def run_backtest(
console.print() console.print()
if result.status == "failed": if result.status == BacktestStatus.FAILED:
console.print(f"[red]Backtest failed: {result.error_message}[/red]") console.print(f"[red]Backtest failed: {result.error_message}[/red]")
return return

View File

@ -1,6 +1,7 @@
from typing import Optional, Dict, Any from typing import Optional, Dict, Any
from rich.console import Console from rich.console import Console
from cli.models import AgentStatus
from rich.panel import Panel from rich.panel import Panel
from rich.spinner import Spinner from rich.spinner import Spinner
from rich.markdown import Markdown from rich.markdown import Markdown
@ -72,34 +73,34 @@ def update_display(layout: Layout, spinner_text: Optional[str] = None) -> None:
for team, agents in teams.items(): for team, agents in teams.items():
first_agent = agents[0] first_agent = agents[0]
status = message_buffer.agent_status[first_agent] status = message_buffer.agent_status[first_agent]
if status == "in_progress": if status == AgentStatus.IN_PROGRESS:
spinner = Spinner( spinner = Spinner(
"dots", text="[blue]in_progress[/blue]", style="bold cyan" "dots", text="[blue]in_progress[/blue]", style="bold cyan"
) )
status_cell = spinner status_cell = spinner
else: else:
status_color = { status_color = {
"pending": "yellow", AgentStatus.PENDING: "yellow",
"completed": "green", AgentStatus.COMPLETED: "green",
"error": "red", AgentStatus.ERROR: "red",
}.get(status, "white") }.get(status, "white")
status_cell = f"[{status_color}]{status}[/{status_color}]" status_cell = f"[{status_color}]{status.value}[/{status_color}]"
progress_table.add_row(team, first_agent, status_cell) progress_table.add_row(team, first_agent, status_cell)
for agent in agents[1:]: for agent in agents[1:]:
status = message_buffer.agent_status[agent] status = message_buffer.agent_status[agent]
if status == "in_progress": if status == AgentStatus.IN_PROGRESS:
spinner = Spinner( spinner = Spinner(
"dots", text="[blue]in_progress[/blue]", style="bold cyan" "dots", text="[blue]in_progress[/blue]", style="bold cyan"
) )
status_cell = spinner status_cell = spinner
else: else:
status_color = { status_color = {
"pending": "yellow", AgentStatus.PENDING: "yellow",
"completed": "green", AgentStatus.COMPLETED: "green",
"error": "red", AgentStatus.ERROR: "red",
}.get(status, "white") }.get(status, "white")
status_cell = f"[{status_color}]{status}[/{status_color}]" status_cell = f"[{status_color}]{status.value}[/{status_color}]"
progress_table.add_row("", agent, status_cell) progress_table.add_row("", agent, status_cell)
progress_table.add_row("-" * 20, "-" * 20, "-" * 20, style="dim") progress_table.add_row("-" * 20, "-" * 20, "-" * 20, style="dim")
@ -383,7 +384,7 @@ def display_complete_report(final_state: Dict[str, Any]) -> None:
) )
def update_research_team_status(status: str) -> None: def update_research_team_status(status: AgentStatus) -> None:
research_team = ["Bull Researcher", "Bear Researcher", "Research Manager", "Trader"] research_team = ["Bull Researcher", "Bear Researcher", "Research Manager", "Trader"]
for agent in research_team: for agent in research_team:
message_buffer.update_agent_status(agent, status) message_buffer.update_agent_status(agent, status)

View File

@ -8,3 +8,10 @@ class AnalystType(str, Enum):
SOCIAL = "social" SOCIAL = "social"
NEWS = "news" NEWS = "news"
FUNDAMENTALS = "fundamentals" FUNDAMENTALS = "fundamentals"
class AgentStatus(str, Enum):
PENDING = "pending"
IN_PROGRESS = "in_progress"
COMPLETED = "completed"
ERROR = "error"

View File

@ -2,6 +2,8 @@ import datetime
from collections import deque from collections import deque
from typing import Optional, Dict, Any, Deque from typing import Optional, Dict, Any, Deque
from cli.models import AgentStatus
class MessageBuffer: class MessageBuffer:
def __init__(self, max_length: int = 100) -> None: def __init__(self, max_length: int = 100) -> None:
@ -9,19 +11,19 @@ class MessageBuffer:
self.tool_calls: Deque = deque(maxlen=max_length) self.tool_calls: Deque = deque(maxlen=max_length)
self.current_report = None self.current_report = None
self.final_report = None self.final_report = None
self.agent_status = { self.agent_status: Dict[str, AgentStatus] = {
"Market Analyst": "pending", "Market Analyst": AgentStatus.PENDING,
"Social Analyst": "pending", "Social Analyst": AgentStatus.PENDING,
"News Analyst": "pending", "News Analyst": AgentStatus.PENDING,
"Fundamentals Analyst": "pending", "Fundamentals Analyst": AgentStatus.PENDING,
"Bull Researcher": "pending", "Bull Researcher": AgentStatus.PENDING,
"Bear Researcher": "pending", "Bear Researcher": AgentStatus.PENDING,
"Research Manager": "pending", "Research Manager": AgentStatus.PENDING,
"Trader": "pending", "Trader": AgentStatus.PENDING,
"Risky Analyst": "pending", "Risky Analyst": AgentStatus.PENDING,
"Neutral Analyst": "pending", "Neutral Analyst": AgentStatus.PENDING,
"Safe Analyst": "pending", "Safe Analyst": AgentStatus.PENDING,
"Portfolio Manager": "pending", "Portfolio Manager": AgentStatus.PENDING,
} }
self.current_agent = None self.current_agent = None
self.report_sections = { self.report_sections = {
@ -42,7 +44,7 @@ class MessageBuffer:
timestamp = datetime.datetime.now().strftime("%H:%M:%S") timestamp = datetime.datetime.now().strftime("%H:%M:%S")
self.tool_calls.append((timestamp, tool_name, args)) self.tool_calls.append((timestamp, tool_name, args))
def update_agent_status(self, agent: str, status: str) -> None: def update_agent_status(self, agent: str, status: AgentStatus) -> None:
if agent in self.agent_status: if agent in self.agent_status:
self.agent_status[agent] = status self.agent_status[agent] = status
self.current_agent = agent self.current_agent = agent
@ -123,7 +125,7 @@ class MessageBuffer:
def reset(self) -> None: def reset(self) -> None:
for agent in self.agent_status: for agent in self.agent_status:
self.agent_status[agent] = "pending" self.agent_status[agent] = AgentStatus.PENDING
for section in self.report_sections: for section in self.report_sections:
self.report_sections[section] = None self.report_sections[section] = None
self.current_report = None self.current_report = None

View File

@ -6,6 +6,7 @@ import pytest
from tradingagents.models.backtest import ( from tradingagents.models.backtest import (
BacktestConfig, BacktestConfig,
BacktestResult, BacktestResult,
BacktestStatus,
BacktestMetrics, BacktestMetrics,
EquityCurvePoint, EquityCurvePoint,
TradeLog, TradeLog,
@ -287,7 +288,7 @@ class TestBacktestResult:
) )
assert result.duration_seconds == 330.0 assert result.duration_seconds == 330.0
assert result.status == "completed" assert result.status == BacktestStatus.COMPLETED
def test_to_dict(self): def test_to_dict(self):
config = BacktestConfig( config = BacktestConfig(

View File

@ -7,6 +7,7 @@ from uuid import uuid4
from tradingagents.models.backtest import ( from tradingagents.models.backtest import (
BacktestConfig, BacktestConfig,
BacktestResult, BacktestResult,
BacktestStatus,
EquityCurvePoint, EquityCurvePoint,
TradeLog, TradeLog,
) )
@ -69,7 +70,7 @@ class BacktestEngine:
daily_returns=self.daily_returns, daily_returns=self.daily_returns,
started_at=started_at, started_at=started_at,
completed_at=completed_at, completed_at=completed_at,
status="completed", status=BacktestStatus.COMPLETED,
) )
except (ValueError, KeyError, RuntimeError, FileNotFoundError, OSError) as e: except (ValueError, KeyError, RuntimeError, FileNotFoundError, OSError) as e:
@ -84,7 +85,7 @@ class BacktestEngine:
daily_returns=self.daily_returns, daily_returns=self.daily_returns,
started_at=started_at, started_at=started_at,
completed_at=completed_at, completed_at=completed_at,
status="failed", status=BacktestStatus.FAILED,
error_message=str(e), error_message=str(e),
) )

View File

@ -10,6 +10,13 @@ from .portfolio import PortfolioConfig
from .trading import Trade from .trading import Trade
class BacktestStatus(str, Enum):
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
class BacktestConfig(BaseModel): class BacktestConfig(BaseModel):
id: UUID = Field(default_factory=uuid4) id: UUID = Field(default_factory=uuid4)
name: str = Field(default="Backtest") name: str = Field(default="Backtest")
@ -212,7 +219,7 @@ class BacktestResult(BaseModel):
daily_returns: list[Decimal] = Field(default_factory=list) daily_returns: list[Decimal] = Field(default_factory=list)
started_at: datetime started_at: datetime
completed_at: datetime completed_at: datetime
status: str = Field(default="completed") status: BacktestStatus = Field(default=BacktestStatus.COMPLETED)
error_message: Optional[str] = None error_message: Optional[str] = None
@computed_field @computed_field