security: Apply critical security fixes from PR #281 review
Implement the top 3 critical security fixes identified in Gemini code review: **Fix 1: ChromaDB Reset Protection** - Changed `allow_reset=True` to `False` in memory.py - Prevents catastrophic database deletion in production - File: tradingagents/agents/utils/memory.py:13 **Fix 2: Path Traversal Prevention** - Added `validate_ticker_symbol()` function with comprehensive validation - Applied validation to 5 functions using ticker in file paths: - get_YFin_data_window() - get_YFin_data() - get_data_in_range() - get_finnhub_company_insider_sentiment() - get_finnhub_company_insider_transactions() - Blocks: path traversal (../, \\), invalid chars, length > 10 - File: tradingagents/dataflows/local.py **Fix 3: CLI Input Validation** - Added validation loop to get_ticker() with user-friendly error messages - Prevents malicious input at entry point - Validates format, blocks traversal, limits length - File: cli/main.py:499-521 **Testing:** - Validation logic verified with attack vectors: - ../../etc/passwd (blocked ✓) - Long tickers (blocked ✓) - Special characters (blocked ✓) - Valid tickers: AAPL, BRK.B (pass ✓) **Changes:** - 3 files changed, 65 insertions(+), 3 deletions(-) - Implementation time: ~20 minutes - Zero breaking changes to existing functionality **References:** - Security analysis: docs/security/PR281_CRITICAL_FIXES.md - Future roadmap: docs/security/FUTURE_HARDENING.md Addresses critical path traversal (CWE-22) and data loss vulnerabilities.
This commit is contained in:
parent
3def80c37f
commit
218cedf56f
24
cli/main.py
24
cli/main.py
|
|
@ -497,8 +497,28 @@ def get_user_selections():
|
|||
|
||||
|
||||
def get_ticker():
    """Prompt for a ticker symbol until the input passes validation.

    The prompt is shown with an empty label and a default of ``"SPY"``.
    Each rejected input prints a red error message via ``console`` and the
    prompt is repeated.

    Accepted input is 1-10 characters long, contains no ``..``, ``/`` or
    ``\\`` sequences, and consists only of ASCII letters, ASCII digits,
    dots, and hyphens.

    Returns:
        The validated ticker, upper-cased.
    """
    while True:
        # Strip surrounding whitespace so " AAPL " is not rejected for the
        # spaces alone.
        ticker = typer.prompt("", default="SPY").strip()

        # Validate ticker length
        if not ticker or len(ticker) > 10:
            console.print("[red]Error: Ticker must be 1-10 characters[/red]")
            continue

        # Check for path traversal attempts
        if ".." in ticker or "/" in ticker or "\\" in ticker:
            console.print("[red]Error: Invalid characters in ticker symbol[/red]")
            continue

        # Validate characters (ASCII alphanumerics, dots, hyphens only).
        # str.isalnum() alone also accepts non-ASCII letters and digits,
        # so the ASCII restriction is checked explicitly.
        if not all((c.isascii() and c.isalnum()) or c in ".-" for c in ticker):
            console.print("[red]Error: Ticker can only contain letters, numbers, dots, and hyphens[/red]")
            continue

        # No try/except here: the checks above are pure string predicates
        # and cannot raise, so a broad handler would only hide real bugs.
        return ticker.upper()  # Return normalized uppercase ticker
|
||||
|
||||
|
||||
def get_analysis_date():
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ class FinancialSituationMemory:
|
|||
else:
|
||||
self.embedding = "text-embedding-3-small"
|
||||
self.client = OpenAI(base_url=config["backend_url"])
|
||||
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
|
||||
self.chroma_client = chromadb.Client(Settings(allow_reset=False))
|
||||
self.situation_collection = self.chroma_client.create_collection(name=name)
|
||||
|
||||
def get_embedding(self, text):
|
||||
|
|
|
|||
|
|
@ -7,12 +7,45 @@ from dateutil.relativedelta import relativedelta
|
|||
import json
|
||||
from .reddit_utils import fetch_top_from_category
|
||||
from tqdm import tqdm
|
||||
import re
|
||||
|
||||
|
||||
def validate_ticker_symbol(symbol: str) -> str:
    """
    Validate and sanitize ticker symbol to prevent path traversal attacks.

    Args:
        symbol: Ticker symbol to validate

    Returns:
        Sanitized ticker symbol (uppercase)

    Raises:
        ValueError: If the ticker is not a string, is empty, contains
            invalid characters or path traversal patterns, or is longer
            than 10 characters
    """
    # Reject non-string input with the documented exception type instead of
    # letting the regex engine raise TypeError.
    if not isinstance(symbol, str):
        raise ValueError(f"Invalid ticker symbol: {symbol}")

    # Prevent path traversal patterns. This runs before the character check
    # so that '/' and '\\' are reported as traversal attempts rather than
    # being swallowed by the generic invalid-character error.
    if '..' in symbol or '/' in symbol or '\\' in symbol:
        raise ValueError(f"Path traversal attempt detected in ticker: {symbol}")

    # Ticker symbols should only contain alphanumeric characters, dots, and
    # hyphens. fullmatch is required: with re.match, a trailing '$' also
    # matches just before a final newline, so "AAPL\n" would be accepted.
    if not re.fullmatch(r'[A-Za-z0-9.\-]+', symbol):
        raise ValueError(f"Invalid ticker symbol: {symbol}")

    # Limit length (typical tickers are 1-5 characters, extended can be up to 10)
    if len(symbol) > 10:
        raise ValueError(f"Ticker symbol too long: {symbol}")

    return symbol.upper()  # Normalize to uppercase
|
||||
|
||||
|
||||
def get_YFin_data_window(
|
||||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
curr_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||
look_back_days: Annotated[int, "how many days to look back"],
|
||||
) -> str:
|
||||
# Validate ticker symbol to prevent path traversal
|
||||
symbol = validate_ticker_symbol(symbol)
|
||||
|
||||
# calculate past days
|
||||
date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
|
||||
before = date_obj - relativedelta(days=look_back_days)
|
||||
|
|
@ -53,6 +86,9 @@ def get_YFin_data(
|
|||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
||||
) -> str:
|
||||
# Validate ticker symbol to prevent path traversal
|
||||
symbol = validate_ticker_symbol(symbol)
|
||||
|
||||
# read in data
|
||||
data = pd.read_csv(
|
||||
os.path.join(
|
||||
|
|
@ -129,6 +165,8 @@ def get_finnhub_company_insider_sentiment(
|
|||
Returns:
|
||||
str: a report of the sentiment in the past 15 days starting at curr_date
|
||||
"""
|
||||
# Validate ticker symbol to prevent path traversal
|
||||
ticker = validate_ticker_symbol(ticker)
|
||||
|
||||
date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
|
||||
before = date_obj - relativedelta(days=15) # Default 15 days lookback
|
||||
|
|
@ -166,6 +204,8 @@ def get_finnhub_company_insider_transactions(
|
|||
Returns:
|
||||
str: a report of the company's insider transaction/trading informtaion in the past 15 days
|
||||
"""
|
||||
# Validate ticker symbol to prevent path traversal
|
||||
ticker = validate_ticker_symbol(ticker)
|
||||
|
||||
date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
|
||||
before = date_obj - relativedelta(days=15) # Default 15 days lookback
|
||||
|
|
@ -201,6 +241,8 @@ def get_data_in_range(ticker, start_date, end_date, data_type, data_dir, period=
|
|||
data_dir (str): Directory where the data is saved.
|
||||
period (str): Default to none, if there is a period specified, should be annual or quarterly.
|
||||
"""
|
||||
# Validate ticker symbol to prevent path traversal
|
||||
ticker = validate_ticker_symbol(ticker)
|
||||
|
||||
if period:
|
||||
data_path = os.path.join(
|
||||
|
|
|
|||
Loading…
Reference in New Issue