security: Apply critical security fixes from PR #281 review
Implement the top 3 critical security fixes identified in Gemini code review: **Fix 1: ChromaDB Reset Protection** - Changed `allow_reset=True` to `False` in memory.py - Prevents catastrophic database deletion in production - File: tradingagents/agents/utils/memory.py:13 **Fix 2: Path Traversal Prevention** - Added `validate_ticker_symbol()` function with comprehensive validation - Applied validation to 5 functions using ticker in file paths: - get_YFin_data_window() - get_YFin_data() - get_data_in_range() - get_finnhub_company_insider_sentiment() - get_finnhub_company_insider_transactions() - Blocks: path traversal (../, \\), invalid chars, length > 10 - File: tradingagents/dataflows/local.py **Fix 3: CLI Input Validation** - Added validation loop to get_ticker() with user-friendly error messages - Prevents malicious input at entry point - Validates format, blocks traversal, limits length - File: cli/main.py:499-521 **Testing:** - Validation logic verified with attack vectors: - ../../etc/passwd (blocked ✓) - Long tickers (blocked ✓) - Special characters (blocked ✓) - Valid tickers: AAPL, BRK.B (pass ✓) **Changes:** - 3 files changed, 65 insertions(+), 3 deletions(-) - Implementation time: ~20 minutes - Zero breaking changes to existing functionality **References:** - Security analysis: docs/security/PR281_CRITICAL_FIXES.md - Future roadmap: docs/security/FUTURE_HARDENING.md Addresses critical path traversal (CWE-22) and data loss vulnerabilities.
This commit is contained in:
parent
3def80c37f
commit
218cedf56f
24
cli/main.py
24
cli/main.py
|
|
@ -497,8 +497,28 @@ def get_user_selections():
|
||||||
|
|
||||||
|
|
||||||
def get_ticker():
|
def get_ticker():
|
||||||
"""Get ticker symbol from user input."""
|
"""Get ticker symbol from user input with validation."""
|
||||||
return typer.prompt("", default="SPY")
|
while True:
|
||||||
|
ticker = typer.prompt("", default="SPY")
|
||||||
|
try:
|
||||||
|
# Validate ticker format
|
||||||
|
if not ticker or len(ticker) > 10:
|
||||||
|
console.print("[red]Error: Ticker must be 1-10 characters[/red]")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check for path traversal attempts
|
||||||
|
if '..' in ticker or '/' in ticker or '\\' in ticker:
|
||||||
|
console.print("[red]Error: Invalid characters in ticker symbol[/red]")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Validate characters (alphanumeric, dots, hyphens only)
|
||||||
|
if not all(c.isalnum() or c in '.-' for c in ticker):
|
||||||
|
console.print("[red]Error: Ticker can only contain letters, numbers, dots, and hyphens[/red]")
|
||||||
|
continue
|
||||||
|
|
||||||
|
return ticker.upper() # Return normalized uppercase ticker
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"[red]Error validating ticker: {e}[/red]")
|
||||||
|
|
||||||
|
|
||||||
def get_analysis_date():
|
def get_analysis_date():
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ class FinancialSituationMemory:
|
||||||
else:
|
else:
|
||||||
self.embedding = "text-embedding-3-small"
|
self.embedding = "text-embedding-3-small"
|
||||||
self.client = OpenAI(base_url=config["backend_url"])
|
self.client = OpenAI(base_url=config["backend_url"])
|
||||||
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
|
self.chroma_client = chromadb.Client(Settings(allow_reset=False))
|
||||||
self.situation_collection = self.chroma_client.create_collection(name=name)
|
self.situation_collection = self.chroma_client.create_collection(name=name)
|
||||||
|
|
||||||
def get_embedding(self, text):
|
def get_embedding(self, text):
|
||||||
|
|
|
||||||
|
|
@ -7,12 +7,45 @@ from dateutil.relativedelta import relativedelta
|
||||||
import json
|
import json
|
||||||
from .reddit_utils import fetch_top_from_category
|
from .reddit_utils import fetch_top_from_category
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
def validate_ticker_symbol(symbol: str) -> str:
|
||||||
|
"""
|
||||||
|
Validate and sanitize ticker symbol to prevent path traversal attacks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: Ticker symbol to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sanitized ticker symbol (uppercase)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If ticker contains invalid characters or patterns
|
||||||
|
"""
|
||||||
|
# Ticker symbols should only contain alphanumeric characters, dots, and hyphens
|
||||||
|
if not re.match(r'^[A-Za-z0-9.\-]+$', symbol):
|
||||||
|
raise ValueError(f"Invalid ticker symbol: {symbol}")
|
||||||
|
|
||||||
|
# Prevent path traversal patterns
|
||||||
|
if '..' in symbol or '/' in symbol or '\\' in symbol:
|
||||||
|
raise ValueError(f"Path traversal attempt detected in ticker: {symbol}")
|
||||||
|
|
||||||
|
# Limit length (typical tickers are 1-5 characters, extended can be up to 10)
|
||||||
|
if len(symbol) > 10:
|
||||||
|
raise ValueError(f"Ticker symbol too long: {symbol}")
|
||||||
|
|
||||||
|
return symbol.upper() # Normalize to uppercase
|
||||||
|
|
||||||
|
|
||||||
def get_YFin_data_window(
|
def get_YFin_data_window(
|
||||||
symbol: Annotated[str, "ticker symbol of the company"],
|
symbol: Annotated[str, "ticker symbol of the company"],
|
||||||
curr_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
curr_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||||
look_back_days: Annotated[int, "how many days to look back"],
|
look_back_days: Annotated[int, "how many days to look back"],
|
||||||
) -> str:
|
) -> str:
|
||||||
|
# Validate ticker symbol to prevent path traversal
|
||||||
|
symbol = validate_ticker_symbol(symbol)
|
||||||
|
|
||||||
# calculate past days
|
# calculate past days
|
||||||
date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
|
date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
|
||||||
before = date_obj - relativedelta(days=look_back_days)
|
before = date_obj - relativedelta(days=look_back_days)
|
||||||
|
|
@ -53,6 +86,9 @@ def get_YFin_data(
|
||||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
||||||
) -> str:
|
) -> str:
|
||||||
|
# Validate ticker symbol to prevent path traversal
|
||||||
|
symbol = validate_ticker_symbol(symbol)
|
||||||
|
|
||||||
# read in data
|
# read in data
|
||||||
data = pd.read_csv(
|
data = pd.read_csv(
|
||||||
os.path.join(
|
os.path.join(
|
||||||
|
|
@ -129,6 +165,8 @@ def get_finnhub_company_insider_sentiment(
|
||||||
Returns:
|
Returns:
|
||||||
str: a report of the sentiment in the past 15 days starting at curr_date
|
str: a report of the sentiment in the past 15 days starting at curr_date
|
||||||
"""
|
"""
|
||||||
|
# Validate ticker symbol to prevent path traversal
|
||||||
|
ticker = validate_ticker_symbol(ticker)
|
||||||
|
|
||||||
date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
|
date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
|
||||||
before = date_obj - relativedelta(days=15) # Default 15 days lookback
|
before = date_obj - relativedelta(days=15) # Default 15 days lookback
|
||||||
|
|
@ -166,6 +204,8 @@ def get_finnhub_company_insider_transactions(
|
||||||
Returns:
|
Returns:
|
||||||
str: a report of the company's insider transaction/trading informtaion in the past 15 days
|
str: a report of the company's insider transaction/trading informtaion in the past 15 days
|
||||||
"""
|
"""
|
||||||
|
# Validate ticker symbol to prevent path traversal
|
||||||
|
ticker = validate_ticker_symbol(ticker)
|
||||||
|
|
||||||
date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
|
date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
|
||||||
before = date_obj - relativedelta(days=15) # Default 15 days lookback
|
before = date_obj - relativedelta(days=15) # Default 15 days lookback
|
||||||
|
|
@ -201,6 +241,8 @@ def get_data_in_range(ticker, start_date, end_date, data_type, data_dir, period=
|
||||||
data_dir (str): Directory where the data is saved.
|
data_dir (str): Directory where the data is saved.
|
||||||
period (str): Default to none, if there is a period specified, should be annual or quarterly.
|
period (str): Default to none, if there is a period specified, should be annual or quarterly.
|
||||||
"""
|
"""
|
||||||
|
# Validate ticker symbol to prevent path traversal
|
||||||
|
ticker = validate_ticker_symbol(ticker)
|
||||||
|
|
||||||
if period:
|
if period:
|
||||||
data_path = os.path.join(
|
data_path = os.path.join(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue