diff --git a/backend/app/models/schemas.py b/backend/app/models/schemas.py index 6f191c76..ddc9f922 100644 --- a/backend/app/models/schemas.py +++ b/backend/app/models/schemas.py @@ -1,7 +1,7 @@ """ Pydantic models for request/response schemas """ -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from typing import List, Optional, Dict, Any, Literal, Union from datetime import date @@ -9,6 +9,12 @@ from datetime import date class AnalysisRequest(BaseModel): """Request model for trading analysis""" ticker: str = Field(..., description="Stock ticker symbol (e.g., 'NVDA', 'AAPL')", min_length=1, max_length=10) + + # 防呆:自動將股票代碼轉換為大寫 + @field_validator('ticker') + @classmethod + def uppercase_ticker(cls, v: str) -> str: + return v.strip().upper() analysis_date: str = Field(..., description="Analysis date in YYYY-MM-DD format") analysts: Optional[List[str]] = Field( default=["market", "social", "news", "fundamentals"], @@ -140,3 +146,10 @@ class DownloadRequest(BaseModel): analysis_date: str = Field(..., description="Analysis date in YYYY-MM-DD format") task_id: str = Field(..., description="Task ID of the completed analysis") analysts: List[str] = Field(..., description="List of analyst keys to download", min_length=1) + + # 防呆:自動將股票代碼轉換為大寫 + @field_validator('ticker') + @classmethod + def uppercase_ticker(cls, v: str) -> str: + return v.strip().upper() + diff --git a/cli/main.py b/cli/main.py index 9e947c12..b322672c 100644 --- a/cli/main.py +++ b/cli/main.py @@ -559,7 +559,10 @@ def get_user_selections(): def get_ticker(): """從使用者輸入中獲取股票代碼。""" - return typer.prompt("", default="SPY") + ticker = typer.prompt("", default="SPY") + # 防呆:將股票代碼轉換為大寫 + ticker = ticker.strip().upper() + return ticker def get_analysis_date(): diff --git a/tradingagents/graph/propagation.py b/tradingagents/graph/propagation.py index a1cf9430..46a991e7 100644 --- a/tradingagents/graph/propagation.py +++ b/tradingagents/graph/propagation.py @@ -41,6 +41,9 @@ class Propagator: Returns: Dict[str, Any]: 初始狀態的字典。 """ + # 防呆:將股票代碼轉換為大寫並去除空白 + company_name = company_name.strip().upper() + # 獲取真實公司名稱(從Alpha Vantage獲取公司概況) ticker = company_name # company_name實際上是ticker actual_company_name = ticker # 預設值為ticker diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index fa473f36..10bc5f4d 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -203,7 +203,9 @@ class TradingAgentsXGraph: Returns: tuple: 包含最終狀態和處理後信號的元組。 """ - + # 防呆:將股票代碼轉換為大寫並去除空白 + company_name = company_name.strip().upper() + self.ticker = company_name # 初始化狀態