feat: create enums for status strings
Added AgentStatus enum for CLI agent tracking (pending, in_progress, completed, error) and BacktestStatus enum for backtest results (pending, running, completed, failed). Replaces string literals with type-safe enum values throughout the codebase. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
b8be981bc8
commit
1346e20b5e
|
|
@ -13,7 +13,7 @@ from tradingagents.graph.trading_graph import TradingAgentsGraph
|
|||
from tradingagents.dataflows.config import get_config
|
||||
|
||||
from cli.state import message_buffer
|
||||
from cli.models import AnalystType
|
||||
from cli.models import AnalystType, AgentStatus
|
||||
from cli.display import (
|
||||
create_layout,
|
||||
update_display,
|
||||
|
|
@ -138,32 +138,32 @@ def get_user_selections() -> dict:
|
|||
def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType]) -> None:
|
||||
if "market_report" in chunk and chunk["market_report"]:
|
||||
message_buffer.update_report_section("market_report", chunk["market_report"])
|
||||
message_buffer.update_agent_status("Market Analyst", "completed")
|
||||
message_buffer.update_agent_status("Market Analyst", AgentStatus.COMPLETED)
|
||||
if AnalystType.SOCIAL in selected_analysts:
|
||||
message_buffer.update_agent_status("Social Analyst", "in_progress")
|
||||
message_buffer.update_agent_status("Social Analyst", AgentStatus.IN_PROGRESS)
|
||||
|
||||
if "sentiment_report" in chunk and chunk["sentiment_report"]:
|
||||
message_buffer.update_report_section("sentiment_report", chunk["sentiment_report"])
|
||||
message_buffer.update_agent_status("Social Analyst", "completed")
|
||||
message_buffer.update_agent_status("Social Analyst", AgentStatus.COMPLETED)
|
||||
if AnalystType.NEWS in selected_analysts:
|
||||
message_buffer.update_agent_status("News Analyst", "in_progress")
|
||||
message_buffer.update_agent_status("News Analyst", AgentStatus.IN_PROGRESS)
|
||||
|
||||
if "news_report" in chunk and chunk["news_report"]:
|
||||
message_buffer.update_report_section("news_report", chunk["news_report"])
|
||||
message_buffer.update_agent_status("News Analyst", "completed")
|
||||
message_buffer.update_agent_status("News Analyst", AgentStatus.COMPLETED)
|
||||
if AnalystType.FUNDAMENTALS in selected_analysts:
|
||||
message_buffer.update_agent_status("Fundamentals Analyst", "in_progress")
|
||||
message_buffer.update_agent_status("Fundamentals Analyst", AgentStatus.IN_PROGRESS)
|
||||
|
||||
if "fundamentals_report" in chunk and chunk["fundamentals_report"]:
|
||||
message_buffer.update_report_section("fundamentals_report", chunk["fundamentals_report"])
|
||||
message_buffer.update_agent_status("Fundamentals Analyst", "completed")
|
||||
update_research_team_status("in_progress")
|
||||
message_buffer.update_agent_status("Fundamentals Analyst", AgentStatus.COMPLETED)
|
||||
update_research_team_status(AgentStatus.IN_PROGRESS)
|
||||
|
||||
if "investment_debate_state" in chunk and chunk["investment_debate_state"]:
|
||||
debate_state = chunk["investment_debate_state"]
|
||||
|
||||
if "bull_history" in debate_state and debate_state["bull_history"]:
|
||||
update_research_team_status("in_progress")
|
||||
update_research_team_status(AgentStatus.IN_PROGRESS)
|
||||
bull_responses = debate_state["bull_history"].split("\n")
|
||||
latest_bull = bull_responses[-1] if bull_responses else ""
|
||||
if latest_bull:
|
||||
|
|
@ -174,7 +174,7 @@ def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType])
|
|||
)
|
||||
|
||||
if "bear_history" in debate_state and debate_state["bear_history"]:
|
||||
update_research_team_status("in_progress")
|
||||
update_research_team_status(AgentStatus.IN_PROGRESS)
|
||||
bear_responses = debate_state["bear_history"].split("\n")
|
||||
latest_bear = bear_responses[-1] if bear_responses else ""
|
||||
if latest_bear:
|
||||
|
|
@ -185,7 +185,7 @@ def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType])
|
|||
)
|
||||
|
||||
if "judge_decision" in debate_state and debate_state["judge_decision"]:
|
||||
update_research_team_status("in_progress")
|
||||
update_research_team_status(AgentStatus.IN_PROGRESS)
|
||||
message_buffer.add_message(
|
||||
"Reasoning",
|
||||
f"Research Manager: {debate_state['judge_decision']}",
|
||||
|
|
@ -194,18 +194,18 @@ def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType])
|
|||
"investment_plan",
|
||||
f"{message_buffer.report_sections['investment_plan']}\n\n### Research Manager Decision\n{debate_state['judge_decision']}",
|
||||
)
|
||||
update_research_team_status("completed")
|
||||
message_buffer.update_agent_status("Risky Analyst", "in_progress")
|
||||
update_research_team_status(AgentStatus.COMPLETED)
|
||||
message_buffer.update_agent_status("Risky Analyst", AgentStatus.IN_PROGRESS)
|
||||
|
||||
if "trader_investment_plan" in chunk and chunk["trader_investment_plan"]:
|
||||
message_buffer.update_report_section("trader_investment_plan", chunk["trader_investment_plan"])
|
||||
message_buffer.update_agent_status("Risky Analyst", "in_progress")
|
||||
message_buffer.update_agent_status("Risky Analyst", AgentStatus.IN_PROGRESS)
|
||||
|
||||
if "risk_debate_state" in chunk and chunk["risk_debate_state"]:
|
||||
risk_state = chunk["risk_debate_state"]
|
||||
|
||||
if "current_risky_response" in risk_state and risk_state["current_risky_response"]:
|
||||
message_buffer.update_agent_status("Risky Analyst", "in_progress")
|
||||
message_buffer.update_agent_status("Risky Analyst", AgentStatus.IN_PROGRESS)
|
||||
message_buffer.add_message(
|
||||
"Reasoning",
|
||||
f"Risky Analyst: {risk_state['current_risky_response']}",
|
||||
|
|
@ -216,7 +216,7 @@ def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType])
|
|||
)
|
||||
|
||||
if "current_safe_response" in risk_state and risk_state["current_safe_response"]:
|
||||
message_buffer.update_agent_status("Safe Analyst", "in_progress")
|
||||
message_buffer.update_agent_status("Safe Analyst", AgentStatus.IN_PROGRESS)
|
||||
message_buffer.add_message(
|
||||
"Reasoning",
|
||||
f"Safe Analyst: {risk_state['current_safe_response']}",
|
||||
|
|
@ -227,7 +227,7 @@ def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType])
|
|||
)
|
||||
|
||||
if "current_neutral_response" in risk_state and risk_state["current_neutral_response"]:
|
||||
message_buffer.update_agent_status("Neutral Analyst", "in_progress")
|
||||
message_buffer.update_agent_status("Neutral Analyst", AgentStatus.IN_PROGRESS)
|
||||
message_buffer.add_message(
|
||||
"Reasoning",
|
||||
f"Neutral Analyst: {risk_state['current_neutral_response']}",
|
||||
|
|
@ -238,7 +238,7 @@ def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType])
|
|||
)
|
||||
|
||||
if "judge_decision" in risk_state and risk_state["judge_decision"]:
|
||||
message_buffer.update_agent_status("Portfolio Manager", "in_progress")
|
||||
message_buffer.update_agent_status("Portfolio Manager", AgentStatus.IN_PROGRESS)
|
||||
message_buffer.add_message(
|
||||
"Reasoning",
|
||||
f"Portfolio Manager: {risk_state['judge_decision']}",
|
||||
|
|
@ -247,10 +247,10 @@ def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType])
|
|||
"final_trade_decision",
|
||||
f"### Portfolio Manager Decision\n{risk_state['judge_decision']}",
|
||||
)
|
||||
message_buffer.update_agent_status("Risky Analyst", "completed")
|
||||
message_buffer.update_agent_status("Safe Analyst", "completed")
|
||||
message_buffer.update_agent_status("Neutral Analyst", "completed")
|
||||
message_buffer.update_agent_status("Portfolio Manager", "completed")
|
||||
message_buffer.update_agent_status("Risky Analyst", AgentStatus.COMPLETED)
|
||||
message_buffer.update_agent_status("Safe Analyst", AgentStatus.COMPLETED)
|
||||
message_buffer.update_agent_status("Neutral Analyst", AgentStatus.COMPLETED)
|
||||
message_buffer.update_agent_status("Portfolio Manager", AgentStatus.COMPLETED)
|
||||
|
||||
|
||||
def setup_logging_decorators(report_dir, log_file) -> tuple:
|
||||
|
|
@ -383,7 +383,7 @@ def _run_analysis_with_config(ticker: str, analysis_date: str, selected_analysts
|
|||
update_display(layout)
|
||||
|
||||
for agent in message_buffer.agent_status:
|
||||
message_buffer.update_agent_status(agent, "pending")
|
||||
message_buffer.update_agent_status(agent, AgentStatus.PENDING)
|
||||
|
||||
for section in message_buffer.report_sections:
|
||||
message_buffer.report_sections[section] = None
|
||||
|
|
@ -391,7 +391,7 @@ def _run_analysis_with_config(ticker: str, analysis_date: str, selected_analysts
|
|||
message_buffer.final_report = None
|
||||
|
||||
first_analyst = f"{selected_analysts[0].value.capitalize()} Analyst"
|
||||
message_buffer.update_agent_status(first_analyst, "in_progress")
|
||||
message_buffer.update_agent_status(first_analyst, AgentStatus.IN_PROGRESS)
|
||||
update_display(layout)
|
||||
|
||||
spinner_text = f"Analyzing {ticker} on {analysis_date}..."
|
||||
|
|
@ -430,7 +430,7 @@ def _run_analysis_with_config(ticker: str, analysis_date: str, selected_analysts
|
|||
decision = graph.process_signal(final_state["final_trade_decision"])
|
||||
|
||||
for agent in message_buffer.agent_status:
|
||||
message_buffer.update_agent_status(agent, "completed")
|
||||
message_buffer.update_agent_status(agent, AgentStatus.COMPLETED)
|
||||
|
||||
message_buffer.add_message("Analysis", f"Completed analysis for {analysis_date}")
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from rich.table import Table
|
|||
from rich import box
|
||||
|
||||
from tradingagents.backtesting import SimpleBacktestEngine, DataLoader
|
||||
from tradingagents.models.backtest import BacktestConfig
|
||||
from tradingagents.models.backtest import BacktestConfig, BacktestStatus
|
||||
from tradingagents.models.portfolio import PortfolioConfig
|
||||
|
||||
from cli.display import create_question_box
|
||||
|
|
@ -158,7 +158,7 @@ def run_backtest(
|
|||
|
||||
console.print()
|
||||
|
||||
if result.status == "failed":
|
||||
if result.status == BacktestStatus.FAILED:
|
||||
console.print(f"[red]Backtest failed: {result.error_message}[/red]")
|
||||
return
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from typing import Optional, Dict, Any
|
||||
|
||||
from rich.console import Console
|
||||
from cli.models import AgentStatus
|
||||
from rich.panel import Panel
|
||||
from rich.spinner import Spinner
|
||||
from rich.markdown import Markdown
|
||||
|
|
@ -72,34 +73,34 @@ def update_display(layout: Layout, spinner_text: Optional[str] = None) -> None:
|
|||
for team, agents in teams.items():
|
||||
first_agent = agents[0]
|
||||
status = message_buffer.agent_status[first_agent]
|
||||
if status == "in_progress":
|
||||
if status == AgentStatus.IN_PROGRESS:
|
||||
spinner = Spinner(
|
||||
"dots", text="[blue]in_progress[/blue]", style="bold cyan"
|
||||
)
|
||||
status_cell = spinner
|
||||
else:
|
||||
status_color = {
|
||||
"pending": "yellow",
|
||||
"completed": "green",
|
||||
"error": "red",
|
||||
AgentStatus.PENDING: "yellow",
|
||||
AgentStatus.COMPLETED: "green",
|
||||
AgentStatus.ERROR: "red",
|
||||
}.get(status, "white")
|
||||
status_cell = f"[{status_color}]{status}[/{status_color}]"
|
||||
status_cell = f"[{status_color}]{status.value}[/{status_color}]"
|
||||
progress_table.add_row(team, first_agent, status_cell)
|
||||
|
||||
for agent in agents[1:]:
|
||||
status = message_buffer.agent_status[agent]
|
||||
if status == "in_progress":
|
||||
if status == AgentStatus.IN_PROGRESS:
|
||||
spinner = Spinner(
|
||||
"dots", text="[blue]in_progress[/blue]", style="bold cyan"
|
||||
)
|
||||
status_cell = spinner
|
||||
else:
|
||||
status_color = {
|
||||
"pending": "yellow",
|
||||
"completed": "green",
|
||||
"error": "red",
|
||||
AgentStatus.PENDING: "yellow",
|
||||
AgentStatus.COMPLETED: "green",
|
||||
AgentStatus.ERROR: "red",
|
||||
}.get(status, "white")
|
||||
status_cell = f"[{status_color}]{status}[/{status_color}]"
|
||||
status_cell = f"[{status_color}]{status.value}[/{status_color}]"
|
||||
progress_table.add_row("", agent, status_cell)
|
||||
|
||||
progress_table.add_row("-" * 20, "-" * 20, "-" * 20, style="dim")
|
||||
|
|
@ -383,7 +384,7 @@ def display_complete_report(final_state: Dict[str, Any]) -> None:
|
|||
)
|
||||
|
||||
|
||||
def update_research_team_status(status: str) -> None:
|
||||
def update_research_team_status(status: AgentStatus) -> None:
|
||||
research_team = ["Bull Researcher", "Bear Researcher", "Research Manager", "Trader"]
|
||||
for agent in research_team:
|
||||
message_buffer.update_agent_status(agent, status)
|
||||
|
|
|
|||
|
|
@ -8,3 +8,10 @@ class AnalystType(str, Enum):
|
|||
SOCIAL = "social"
|
||||
NEWS = "news"
|
||||
FUNDAMENTALS = "fundamentals"
|
||||
|
||||
|
||||
class AgentStatus(str, Enum):
|
||||
PENDING = "pending"
|
||||
IN_PROGRESS = "in_progress"
|
||||
COMPLETED = "completed"
|
||||
ERROR = "error"
|
||||
|
|
|
|||
32
cli/state.py
32
cli/state.py
|
|
@ -2,6 +2,8 @@ import datetime
|
|||
from collections import deque
|
||||
from typing import Optional, Dict, Any, Deque
|
||||
|
||||
from cli.models import AgentStatus
|
||||
|
||||
|
||||
class MessageBuffer:
|
||||
def __init__(self, max_length: int = 100) -> None:
|
||||
|
|
@ -9,19 +11,19 @@ class MessageBuffer:
|
|||
self.tool_calls: Deque = deque(maxlen=max_length)
|
||||
self.current_report = None
|
||||
self.final_report = None
|
||||
self.agent_status = {
|
||||
"Market Analyst": "pending",
|
||||
"Social Analyst": "pending",
|
||||
"News Analyst": "pending",
|
||||
"Fundamentals Analyst": "pending",
|
||||
"Bull Researcher": "pending",
|
||||
"Bear Researcher": "pending",
|
||||
"Research Manager": "pending",
|
||||
"Trader": "pending",
|
||||
"Risky Analyst": "pending",
|
||||
"Neutral Analyst": "pending",
|
||||
"Safe Analyst": "pending",
|
||||
"Portfolio Manager": "pending",
|
||||
self.agent_status: Dict[str, AgentStatus] = {
|
||||
"Market Analyst": AgentStatus.PENDING,
|
||||
"Social Analyst": AgentStatus.PENDING,
|
||||
"News Analyst": AgentStatus.PENDING,
|
||||
"Fundamentals Analyst": AgentStatus.PENDING,
|
||||
"Bull Researcher": AgentStatus.PENDING,
|
||||
"Bear Researcher": AgentStatus.PENDING,
|
||||
"Research Manager": AgentStatus.PENDING,
|
||||
"Trader": AgentStatus.PENDING,
|
||||
"Risky Analyst": AgentStatus.PENDING,
|
||||
"Neutral Analyst": AgentStatus.PENDING,
|
||||
"Safe Analyst": AgentStatus.PENDING,
|
||||
"Portfolio Manager": AgentStatus.PENDING,
|
||||
}
|
||||
self.current_agent = None
|
||||
self.report_sections = {
|
||||
|
|
@ -42,7 +44,7 @@ class MessageBuffer:
|
|||
timestamp = datetime.datetime.now().strftime("%H:%M:%S")
|
||||
self.tool_calls.append((timestamp, tool_name, args))
|
||||
|
||||
def update_agent_status(self, agent: str, status: str) -> None:
|
||||
def update_agent_status(self, agent: str, status: AgentStatus) -> None:
|
||||
if agent in self.agent_status:
|
||||
self.agent_status[agent] = status
|
||||
self.current_agent = agent
|
||||
|
|
@ -123,7 +125,7 @@ class MessageBuffer:
|
|||
|
||||
def reset(self) -> None:
|
||||
for agent in self.agent_status:
|
||||
self.agent_status[agent] = "pending"
|
||||
self.agent_status[agent] = AgentStatus.PENDING
|
||||
for section in self.report_sections:
|
||||
self.report_sections[section] = None
|
||||
self.current_report = None
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import pytest
|
|||
from tradingagents.models.backtest import (
|
||||
BacktestConfig,
|
||||
BacktestResult,
|
||||
BacktestStatus,
|
||||
BacktestMetrics,
|
||||
EquityCurvePoint,
|
||||
TradeLog,
|
||||
|
|
@ -287,7 +288,7 @@ class TestBacktestResult:
|
|||
)
|
||||
|
||||
assert result.duration_seconds == 330.0
|
||||
assert result.status == "completed"
|
||||
assert result.status == BacktestStatus.COMPLETED
|
||||
|
||||
def test_to_dict(self):
|
||||
config = BacktestConfig(
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from uuid import uuid4
|
|||
from tradingagents.models.backtest import (
|
||||
BacktestConfig,
|
||||
BacktestResult,
|
||||
BacktestStatus,
|
||||
EquityCurvePoint,
|
||||
TradeLog,
|
||||
)
|
||||
|
|
@ -69,7 +70,7 @@ class BacktestEngine:
|
|||
daily_returns=self.daily_returns,
|
||||
started_at=started_at,
|
||||
completed_at=completed_at,
|
||||
status="completed",
|
||||
status=BacktestStatus.COMPLETED,
|
||||
)
|
||||
|
||||
except (ValueError, KeyError, RuntimeError, FileNotFoundError, OSError) as e:
|
||||
|
|
@ -84,7 +85,7 @@ class BacktestEngine:
|
|||
daily_returns=self.daily_returns,
|
||||
started_at=started_at,
|
||||
completed_at=completed_at,
|
||||
status="failed",
|
||||
status=BacktestStatus.FAILED,
|
||||
error_message=str(e),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,13 @@ from .portfolio import PortfolioConfig
|
|||
from .trading import Trade
|
||||
|
||||
|
||||
class BacktestStatus(str, Enum):
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class BacktestConfig(BaseModel):
|
||||
id: UUID = Field(default_factory=uuid4)
|
||||
name: str = Field(default="Backtest")
|
||||
|
|
@ -212,7 +219,7 @@ class BacktestResult(BaseModel):
|
|||
daily_returns: list[Decimal] = Field(default_factory=list)
|
||||
started_at: datetime
|
||||
completed_at: datetime
|
||||
status: str = Field(default="completed")
|
||||
status: BacktestStatus = Field(default=BacktestStatus.COMPLETED)
|
||||
error_message: Optional[str] = None
|
||||
|
||||
@computed_field
|
||||
|
|
|
|||
Loading…
Reference in New Issue