This commit is contained in:
MarkLo 2025-12-11 15:00:30 +08:00
parent 20e0c6a2d9
commit 482d8fa6aa
4 changed files with 24 additions and 3 deletions

View File

@ -1,7 +1,7 @@
"""
Pydantic models for request/response schemas
"""
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from typing import List, Optional, Dict, Any, Literal, Union
from datetime import date
@ -9,6 +9,12 @@ from datetime import date
class AnalysisRequest(BaseModel):
"""Request model for trading analysis"""
ticker: str = Field(..., description="Stock ticker symbol (e.g., 'NVDA', 'AAPL')", min_length=1, max_length=10)
# 防呆:自動將股票代碼轉換為大寫
@field_validator('ticker')
@classmethod
def uppercase_ticker(cls, v: str) -> str:
return v.strip().upper()
analysis_date: str = Field(..., description="Analysis date in YYYY-MM-DD format")
analysts: Optional[List[str]] = Field(
default=["market", "social", "news", "fundamentals"],
@ -140,3 +146,10 @@ class DownloadRequest(BaseModel):
analysis_date: str = Field(..., description="Analysis date in YYYY-MM-DD format")
task_id: str = Field(..., description="Task ID of the completed analysis")
analysts: List[str] = Field(..., description="List of analyst keys to download", min_length=1)
# 防呆:自動將股票代碼轉換為大寫
@field_validator('ticker')
@classmethod
def uppercase_ticker(cls, v: str) -> str:
return v.strip().upper()

View File

@ -559,7 +559,10 @@ def get_user_selections():
def get_ticker():
"""從使用者輸入中獲取股票代碼。"""
return typer.prompt("", default="SPY")
ticker = typer.prompt("", default="SPY")
# 防呆:將股票代碼轉換為大寫
ticker = ticker.strip().upper()
return ticker
def get_analysis_date():

View File

@ -41,6 +41,9 @@ class Propagator:
Returns:
Dict[str, Any]: 初始狀態的字典
"""
# 防呆:將股票代碼轉換為大寫並去除空白
company_name = company_name.strip().upper()
# 獲取真實公司名稱從Alpha Vantage獲取公司概況
ticker = company_name # company_name實際上是ticker
actual_company_name = ticker # 預設值為ticker

View File

@ -203,7 +203,9 @@ class TradingAgentsXGraph:
Returns:
tuple: 包含最終狀態和處理後信號的元組
"""
# 防呆:將股票代碼轉換為大寫並去除空白
company_name = company_name.strip().upper()
self.ticker = company_name
# 初始化狀態