def get_ticker():
    """Prompt for a ticker symbol until the user enters a valid one.

    The prompt has empty text and defaults to "SPY". An entry is accepted
    when, after trimming surrounding whitespace, it:

    * is 1-10 characters long,
    * contains no "..", "/" or backslash (path-traversal building blocks),
    * consists only of ASCII letters, ASCII digits, dots and hyphens.

    On any violation an error is printed to the console and the user is
    prompted again.

    Returns:
        str: The validated ticker, normalized to upper case.
    """
    while True:
        # Trim so that " SPY" is treated as "SPY" instead of being rejected
        # with a confusing "invalid characters" message.
        ticker = typer.prompt("", default="SPY").strip()

        if not ticker or len(ticker) > 10:
            console.print("[red]Error: Ticker must be 1-10 characters[/red]")
            continue

        # Dots are legal on their own (e.g. "BRK.B"), so ".." has to be
        # rejected explicitly; the separators get their own message here.
        if ".." in ticker or "/" in ticker or "\\" in ticker:
            console.print("[red]Error: Invalid characters in ticker symbol[/red]")
            continue

        # str.isalnum() alone is Unicode-aware and would accept characters
        # such as "é" or "²"; requiring isascii() keeps this check in line
        # with the ASCII-only pattern used by validate_ticker_symbol() and
        # guarantees that upper() cannot change the string's length.
        if not all((c.isascii() and c.isalnum()) or c in ".-" for c in ticker):
            console.print("[red]Error: Ticker can only contain letters, numbers, dots, and hyphens[/red]")
            continue

        return ticker.upper()
import re


def validate_ticker_symbol(symbol: str) -> str:
    """Validate a ticker symbol and return it in upper case.

    The symbol is used to build file paths, so anything that could escape
    the data directory is rejected.

    Args:
        symbol: Ticker symbol to validate.

    Returns:
        The same symbol converted to upper case.

    Raises:
        ValueError: If the symbol is empty, contains characters other than
            ASCII letters, digits, dots and hyphens, contains "..", or is
            longer than 10 characters.
    """
    # fullmatch() instead of match() with "^...$": "$" also matches just
    # before a trailing newline, which previously let "AAPL\n" through.
    if not re.fullmatch(r"[A-Za-z0-9.\-]+", symbol):
        raise ValueError(f"Invalid ticker symbol: {symbol}")

    # The pattern above already excludes "/" and "\\"; they stay in this
    # check as defense in depth. ".." is the case that can actually reach
    # here, because single dots are allowed.
    if ".." in symbol or "/" in symbol or "\\" in symbol:
        raise ValueError(f"Path traversal attempt detected in ticker: {symbol}")

    # Typical tickers are 1-5 characters; extended ones can be up to 10.
    if len(symbol) > 10:
        raise ValueError(f"Ticker symbol too long: {symbol}")

    return symbol.upper()
date_obj - relativedelta(days=15) # Default 15 days lookback @@ -166,6 +204,8 @@ def get_finnhub_company_insider_transactions( Returns: str: a report of the company's insider transaction/trading informtaion in the past 15 days """ + # Validate ticker symbol to prevent path traversal + ticker = validate_ticker_symbol(ticker) date_obj = datetime.strptime(curr_date, "%Y-%m-%d") before = date_obj - relativedelta(days=15) # Default 15 days lookback @@ -201,6 +241,8 @@ def get_data_in_range(ticker, start_date, end_date, data_type, data_dir, period= data_dir (str): Directory where the data is saved. period (str): Default to none, if there is a period specified, should be annual or quarterly. """ + # Validate ticker symbol to prevent path traversal + ticker = validate_ticker_symbol(ticker) if period: data_path = os.path.join(