feat: add centralized input validation for ticker symbols and dates

- Create tradingagents/validation.py with comprehensive validation functions:
  - validate_ticker/validate_tickers for ticker symbol validation
  - validate_date/validate_date_range for date validation
  - parse_date for flexible date parsing (multiple formats)
  - Helper functions: is_valid_ticker, is_valid_date, is_trading_day
  - get_previous_trading_day, get_next_trading_day utilities
- Add validation to y_finance.py (get_YFin_data_online, get_stock_stats_indicators_window)
- Add validation to trading_graph.py propagate method
- Create comprehensive test suite (55 tests) in tests/test_validation.py

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Joseph O'Brien 2025-12-03 03:09:05 -05:00
parent f70874982a
commit e862c4f803
4 changed files with 585 additions and 4 deletions

275
tests/test_validation.py Normal file
View File

@ -0,0 +1,275 @@
import pytest
from datetime import date, datetime, timedelta
from tradingagents.validation import (
ValidationError,
TickerValidationError,
DateValidationError,
validate_ticker,
validate_tickers,
parse_date,
validate_date,
validate_date_range,
format_date,
is_valid_ticker,
is_valid_date,
is_trading_day,
get_previous_trading_day,
get_next_trading_day,
)
class TestValidateTicker:
    """Checks for ``validate_ticker``: normalisation, accepted shapes, rejections."""

    def test_valid_simple_ticker(self):
        # Already-normalised symbols must come back unchanged.
        for symbol in ("AAPL", "A", "GOOGL"):
            assert validate_ticker(symbol) == symbol

    def test_valid_ticker_lowercase_converted(self):
        # Input is upper-cased before validation.
        cases = (("aapl", "AAPL"), ("Msft", "MSFT"))
        for raw, expected in cases:
            assert validate_ticker(raw) == expected

    def test_valid_ticker_with_whitespace(self):
        # Surrounding whitespace (spaces, tabs, newlines) is stripped.
        cases = ((" AAPL ", "AAPL"), ("\tTSLA\n", "TSLA"))
        for raw, expected in cases:
            assert validate_ticker(raw) == expected

    def test_valid_ticker_with_class_indicator(self):
        # Both "-" and "." are accepted as share-class separators.
        for symbol in ("BRK-B", "BRK-A", "BRK.B"):
            assert validate_ticker(symbol) == symbol

    def test_invalid_ticker_none(self):
        expected_message = "cannot be None"
        with pytest.raises(TickerValidationError, match=expected_message):
            validate_ticker(None)

    def test_invalid_ticker_none_allowed(self):
        # With allow_empty=True, None is mapped to the empty string.
        result = validate_ticker(None, allow_empty=True)
        assert result == ""

    def test_invalid_ticker_empty(self):
        expected_message = "cannot be empty"
        with pytest.raises(TickerValidationError, match=expected_message):
            validate_ticker("")

    def test_invalid_ticker_empty_allowed(self):
        result = validate_ticker("", allow_empty=True)
        assert result == ""

    def test_invalid_ticker_wrong_type(self):
        expected_message = "must be a string"
        with pytest.raises(TickerValidationError, match=expected_message):
            validate_ticker(123)

    def test_invalid_ticker_too_long(self):
        expected_message = "too long"
        with pytest.raises(TickerValidationError, match=expected_message):
            validate_ticker("VERYLONGTICKER")

    def test_invalid_ticker_format(self):
        # Digits, a second separator, and punctuation are all rejected.
        for bad_symbol in ("AAPL123", "AA-PL-B", "A@PL"):
            with pytest.raises(TickerValidationError, match="Invalid ticker format"):
                validate_ticker(bad_symbol)
class TestValidateTickers:
    """Checks for ``validate_tickers``: list handling and error aggregation."""

    def test_valid_tickers_list(self):
        symbols = ["AAPL", "MSFT", "GOOGL"]
        assert validate_tickers(symbols) == ["AAPL", "MSFT", "GOOGL"]

    def test_valid_tickers_lowercase_converted(self):
        # Every element is normalised to upper case.
        assert validate_tickers(["aapl", "msft"]) == ["AAPL", "MSFT"]

    def test_valid_tickers_tuple(self):
        # Tuples are accepted as input; the result is always a list.
        assert validate_tickers(("AAPL", "MSFT")) == ["AAPL", "MSFT"]

    def test_invalid_tickers_none(self):
        expected_message = "cannot be None"
        with pytest.raises(TickerValidationError, match=expected_message):
            validate_tickers(None)

    def test_invalid_tickers_none_allowed(self):
        result = validate_tickers(None, allow_empty_list=True)
        assert result == []

    def test_invalid_tickers_empty_list(self):
        expected_message = "cannot be empty"
        with pytest.raises(TickerValidationError, match=expected_message):
            validate_tickers([])

    def test_invalid_tickers_empty_list_allowed(self):
        result = validate_tickers([], allow_empty_list=True)
        assert result == []

    def test_invalid_tickers_wrong_type(self):
        # A bare string is not treated as a one-element list.
        expected_message = "must be a list"
        with pytest.raises(TickerValidationError, match=expected_message):
            validate_tickers("AAPL")

    def test_invalid_tickers_contains_invalid(self):
        # One bad element is enough to fail the whole batch.
        mixed = ["AAPL", "INVALID123", "MSFT"]
        with pytest.raises(TickerValidationError, match="Invalid tickers"):
            validate_tickers(mixed)
class TestParseDate:
    """Checks for ``parse_date``: accepted input types, formats and failures."""

    def test_parse_date_string_default_format(self):
        assert parse_date("2024-01-15") == date(2024, 1, 15)

    def test_parse_date_string_various_formats(self):
        # Slash-separated, US-style and compact forms all resolve to the same day.
        expected = date(2024, 1, 15)
        for text in ("2024/01/15", "01/15/2024", "20240115"):
            assert parse_date(text) == expected

    def test_parse_date_from_date(self):
        original = date(2024, 1, 15)
        assert parse_date(original) == original

    def test_parse_date_from_datetime(self):
        # The time component of a datetime is dropped.
        stamp = datetime(2024, 1, 15, 10, 30)
        assert parse_date(stamp) == date(2024, 1, 15)

    def test_parse_date_none(self):
        result = parse_date(None)
        assert result is None

    def test_parse_date_empty_string(self):
        # Empty and whitespace-only strings are treated like None.
        for blank in ("", " "):
            assert parse_date(blank) is None

    def test_parse_date_invalid_format(self):
        expected_message = "Could not parse date"
        with pytest.raises(DateValidationError, match=expected_message):
            parse_date("not-a-date")

    def test_parse_date_invalid_type(self):
        expected_message = "must be string"
        with pytest.raises(DateValidationError, match=expected_message):
            parse_date(123)
class TestValidateDate:
    """Checks for ``validate_date``: bounds, future dates and weekend handling."""

    def test_validate_date_valid(self):
        assert validate_date("2024-01-15") == date(2024, 1, 15)

    def test_validate_date_none_not_allowed(self):
        expected_message = "cannot be None"
        with pytest.raises(DateValidationError, match=expected_message):
            validate_date(None)

    def test_validate_date_none_allowed(self):
        result = validate_date(None, allow_none=True)
        assert result is None

    def test_validate_date_before_min(self):
        # No min_date given, so the module default lower bound applies.
        expected_message = "before minimum"
        with pytest.raises(DateValidationError, match=expected_message):
            validate_date("1960-01-01")

    def test_validate_date_custom_min(self):
        lower_bound = date(2021, 1, 1)
        with pytest.raises(DateValidationError, match="before minimum"):
            validate_date("2020-01-01", min_date=lower_bound)

    def test_validate_date_future_not_allowed(self):
        future_text = (date.today() + timedelta(days=30)).strftime("%Y-%m-%d")
        with pytest.raises(DateValidationError, match="in the future"):
            validate_date(future_text, allow_future=False)

    def test_validate_date_future_allowed(self):
        future = date.today() + timedelta(days=30)
        future_text = future.strftime("%Y-%m-%d")
        assert validate_date(future_text, allow_future=True) == future

    def test_validate_date_after_max(self):
        # 500 days ahead is past the default upper bound even with futures allowed.
        far_future_text = (date.today() + timedelta(days=500)).strftime("%Y-%m-%d")
        with pytest.raises(DateValidationError, match="after maximum"):
            validate_date(far_future_text)

    def test_validate_date_weekend_not_allowed(self):
        # The error message names the offending weekend day.
        weekend_days = (
            (date(2024, 1, 6), "Saturday"),
            (date(2024, 1, 7), "Sunday"),
        )
        for day, day_name in weekend_days:
            with pytest.raises(DateValidationError, match=day_name):
                validate_date(day, allow_weekend=False)

    def test_validate_date_weekend_allowed(self):
        saturday = date(2024, 1, 6)
        assert validate_date(saturday, allow_weekend=True) == saturday
class TestValidateDateRange:
    """Checks for ``validate_date_range``: ordering and maximum span."""

    def test_validate_date_range_valid(self):
        bounds = validate_date_range("2024-01-01", "2024-01-31")
        start, end = bounds
        assert start == date(2024, 1, 1)
        assert end == date(2024, 1, 31)

    def test_validate_date_range_end_before_start(self):
        expected_message = "must be on or after"
        with pytest.raises(DateValidationError, match=expected_message):
            validate_date_range("2024-01-31", "2024-01-01")

    def test_validate_date_range_same_day(self):
        # A zero-length range is rejected, with its own message.
        expected_message = "must be after"
        with pytest.raises(DateValidationError, match=expected_message):
            validate_date_range("2024-01-15", "2024-01-15")

    def test_validate_date_range_max_range(self):
        expected_message = "exceeds maximum"
        with pytest.raises(DateValidationError, match=expected_message):
            validate_date_range("2020-01-01", "2024-01-01", max_range_days=365)
class TestFormatDate:
    """Checks for ``format_date``: default/custom output formats and None input."""

    def test_format_date_default(self):
        assert format_date(date(2024, 1, 15)) == "2024-01-15"

    def test_format_date_custom_format(self):
        formatted = format_date(date(2024, 1, 15), output_format="%m/%d/%Y")
        assert formatted == "01/15/2024"

    def test_format_date_from_string(self):
        # String input is parsed first, then re-rendered in the output format.
        formatted = format_date("2024-01-15", output_format="%Y%m%d")
        assert formatted == "20240115"

    def test_format_date_none(self):
        expected_message = "Cannot format None"
        with pytest.raises(DateValidationError, match=expected_message):
            format_date(None)
class TestHelperFunctions:
    """Checks for the boolean helpers and the trading-day navigation helpers.

    Reference calendar used below: 2024-01-12 is a Friday, 2024-01-13 a
    Saturday, 2024-01-14 a Sunday and 2024-01-15 a Monday.
    """

    def test_is_valid_ticker(self):
        assert is_valid_ticker("AAPL") is True
        for rejected in ("INVALID123", ""):
            assert is_valid_ticker(rejected) is False

    def test_is_valid_date(self):
        assert is_valid_date("2024-01-15") is True
        assert is_valid_date("not-a-date") is False
        assert is_valid_date(date(2024, 1, 15)) is True

    def test_is_trading_day(self):
        monday, saturday = date(2024, 1, 15), date(2024, 1, 13)
        assert is_trading_day(monday) is True
        assert is_trading_day(saturday) is False
        # String input is accepted as well.
        assert is_trading_day("2024-01-15") is True

    def test_get_previous_trading_day_from_monday(self):
        # A weekday is returned unchanged.
        monday = date(2024, 1, 15)
        assert get_previous_trading_day(monday) == monday

    def test_get_previous_trading_day_from_saturday(self):
        saturday = date(2024, 1, 13)
        assert get_previous_trading_day(saturday) == date(2024, 1, 12)

    def test_get_previous_trading_day_from_sunday(self):
        sunday = date(2024, 1, 14)
        assert get_previous_trading_day(sunday) == date(2024, 1, 12)

    def test_get_next_trading_day_from_friday(self):
        # A weekday is returned unchanged.
        friday = date(2024, 1, 12)
        assert get_next_trading_day(friday) == friday

    def test_get_next_trading_day_from_saturday(self):
        saturday = date(2024, 1, 13)
        assert get_next_trading_day(saturday) == date(2024, 1, 15)

    def test_get_next_trading_day_from_sunday(self):
        sunday = date(2024, 1, 14)
        assert get_next_trading_day(sunday) == date(2024, 1, 15)

View File

@ -5,6 +5,7 @@ from dateutil.relativedelta import relativedelta
import yfinance as yf
import os
from .stockstats_utils import StockstatsUtils
from tradingagents.validation import validate_ticker, validate_date_range, validate_date
logger = logging.getLogger(__name__)
@ -13,11 +14,12 @@ def get_YFin_data_online(
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
):
symbol = validate_ticker(symbol)
start, end = validate_date_range(start_date, end_date)
start_date = start.strftime("%Y-%m-%d")
end_date = end.strftime("%Y-%m-%d")
datetime.strptime(start_date, "%Y-%m-%d")
datetime.strptime(end_date, "%Y-%m-%d")
ticker = yf.Ticker(symbol.upper())
ticker = yf.Ticker(symbol)
data = ticker.history(start=start_date, end=end_date)
@ -50,6 +52,9 @@ def get_stock_stats_indicators_window(
],
look_back_days: Annotated[int, "how many days to look back"],
) -> str:
symbol = validate_ticker(symbol)
validated_date = validate_date(curr_date, allow_future=False)
curr_date = validated_date.strftime("%Y-%m-%d")
best_ind_params = {
"close_50_sma": (

View File

@ -48,6 +48,7 @@ from tradingagents.agents.discovery import (
calculate_trending_scores,
)
from tradingagents.dataflows.interface import get_bulk_news
from tradingagents.validation import validate_ticker, validate_date, parse_date
from .conditional_logic import ConditionalLogic
from .setup import GraphSetup
@ -155,6 +156,11 @@ class TradingAgentsGraph:
}
def propagate(self, company_name, trade_date):
company_name = validate_ticker(company_name)
validated_date = validate_date(trade_date, allow_future=False)
if isinstance(trade_date, str):
trade_date = validated_date
self.ticker = company_name
init_agent_state = self.propagator.create_initial_state(
company_name, trade_date

295
tradingagents/validation.py Normal file
View File

@ -0,0 +1,295 @@
import logging
import re
from datetime import date, datetime, timedelta
from typing import Optional, Union
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# 1-5 uppercase letters, optionally followed by "-" and a 1-2 letter suffix
# (e.g. "BRK-B").
TICKER_PATTERN = re.compile(r"^[A-Z]{1,5}(-[A-Z]{1,2})?$")
# Same shape as TICKER_PATTERN but with "." as the separator (e.g. "BRK.B").
TICKER_SPECIAL_PATTERN = re.compile(r"^[A-Z]{1,5}(\.[A-Z]{1,2})?$")
# Default lower bound applied by validate_date when no min_date is supplied.
MIN_VALID_DATE = date(1970, 1, 1)
# Default upper bound for validate_date, expressed as days after today; used
# when no max_date is supplied.
MAX_FUTURE_DAYS = 365
class ValidationError(Exception):
    """Base class for the input-validation errors defined in this module."""
class TickerValidationError(ValidationError):
    """Raised when a ticker symbol (or list of symbols) fails validation."""
class DateValidationError(ValidationError):
    """Raised when a date or date range fails parsing or validation."""
def validate_ticker(
    ticker: str,
    allow_empty: bool = False,
    check_format_only: bool = True,
) -> str:
    """Normalise a ticker symbol and check that it has an accepted shape.

    The input is stripped of surrounding whitespace and upper-cased, then
    matched against ``TICKER_PATTERN`` and ``TICKER_SPECIAL_PATTERN``.

    Args:
        ticker: Raw ticker symbol.
        allow_empty: When true, ``None`` and blank input yield ``""`` instead
            of raising.
        check_format_only: Accepted for interface compatibility; this
            function does not consult it.

    Returns:
        The normalised (stripped, upper-cased) ticker, or ``""`` for
        permitted empty input.

    Raises:
        TickerValidationError: If the input is ``None`` or blank (and not
            allowed), is not a string, is longer than 10 characters, or
            matches neither accepted pattern.
    """
    if ticker is None:
        if not allow_empty:
            raise TickerValidationError("Ticker cannot be None")
        return ""

    if not isinstance(ticker, str):
        raise TickerValidationError(f"Ticker must be a string, got {type(ticker).__name__}")

    ticker = ticker.strip().upper()
    if ticker == "":
        if not allow_empty:
            raise TickerValidationError("Ticker cannot be empty")
        return ""

    # The length guard runs before the regex so over-long input gets its own
    # message rather than the generic format error.
    if len(ticker) > 10:
        raise TickerValidationError(f"Ticker '{ticker}' is too long (max 10 characters)")

    accepted_shapes = (TICKER_PATTERN, TICKER_SPECIAL_PATTERN)
    if not any(shape.match(ticker) for shape in accepted_shapes):
        raise TickerValidationError(
            f"Invalid ticker format '{ticker}'. Must be 1-5 uppercase letters, "
            "optionally followed by a class indicator (e.g., BRK-B, BRK.A)"
        )
    return ticker
def validate_tickers(
    tickers: list[str],
    allow_empty_list: bool = False,
    check_format_only: bool = True,
) -> list[str]:
    """Validate every ticker in a list or tuple and return the normalised list.

    Each element is passed through ``validate_ticker``. All failures are
    collected and reported together in a single exception.

    Args:
        tickers: List or tuple of raw ticker symbols.
        allow_empty_list: When true, ``None`` and empty input yield ``[]``
            instead of raising.
        check_format_only: Forwarded unchanged to ``validate_ticker``.

    Returns:
        A new list of normalised tickers, in input order.

    Raises:
        TickerValidationError: If the input is ``None`` or empty (and not
            allowed), is not a list/tuple, or any element is invalid. In the
            last case the message lists each failing index.
    """
    if tickers is None:
        if not allow_empty_list:
            raise TickerValidationError("Tickers list cannot be None")
        return []

    if not isinstance(tickers, (list, tuple)):
        raise TickerValidationError(f"Tickers must be a list, got {type(tickers).__name__}")

    if not tickers:
        if not allow_empty_list:
            raise TickerValidationError("Tickers list cannot be empty")
        return []

    accepted = []
    errors = []
    for i, candidate in enumerate(tickers):
        try:
            normalised = validate_ticker(candidate, check_format_only=check_format_only)
        except TickerValidationError as e:
            # Keep going so the caller sees every bad entry at once.
            errors.append(f"Index {i}: {e}")
        else:
            accepted.append(normalised)

    if errors:
        raise TickerValidationError(f"Invalid tickers: {'; '.join(errors)}")
    return accepted
def parse_date(
    date_input: Union[str, date, datetime, None],
    date_format: str = "%Y-%m-%d",
) -> Optional[date]:
    """Convert a string, ``date`` or ``datetime`` into a ``date``.

    Strings are stripped and then tried against ``date_format`` followed by
    a fixed list of fallback formats; the first format that parses wins.
    Because ``%m/%d/%Y`` is tried before ``%d/%m/%Y`` (and likewise for the
    dash-separated forms), ambiguous input such as ``"01/02/2024"`` is read
    month-first.

    Args:
        date_input: Value to convert.
        date_format: Preferred ``strptime`` format, tried first.

    Returns:
        The parsed ``date``, or ``None`` when the input is ``None`` or a
        blank string.

    Raises:
        DateValidationError: If the input is of an unsupported type or no
            format matches the string.
    """
    if date_input is None:
        return None

    if not isinstance(date_input, str):
        # datetime is a subclass of date, so it has to be tested first.
        if isinstance(date_input, datetime):
            return date_input.date()
        if isinstance(date_input, date):
            return date_input
        raise DateValidationError(f"Date must be string, date, or datetime, got {type(date_input).__name__}")

    date_input = date_input.strip()
    if date_input == "":
        return None

    candidate_formats = (
        date_format,
        "%Y-%m-%d",
        "%Y/%m/%d",
        "%m/%d/%Y",
        "%m-%d-%Y",
        "%d/%m/%Y",
        "%d-%m-%Y",
        "%Y%m%d",
    )
    for candidate in candidate_formats:
        try:
            moment = datetime.strptime(date_input, candidate)
        except ValueError:
            continue
        return moment.date()

    raise DateValidationError(
        f"Could not parse date '{date_input}'. Expected format: {date_format} "
        f"(e.g., {datetime.now().strftime(date_format)})"
    )
def validate_date(
    date_input: Union[str, date, datetime, None],
    date_format: str = "%Y-%m-%d",
    allow_none: bool = False,
    min_date: Optional[date] = None,
    max_date: Optional[date] = None,
    allow_future: bool = True,
    allow_weekend: bool = True,
) -> Optional[date]:
    """Parse a date and check it against range, future and weekend rules.

    Checks run in this order: minimum bound, future restriction, maximum
    bound, weekend restriction. The first failing check raises.

    Args:
        date_input: Value to parse (see ``parse_date``).
        date_format: Preferred format forwarded to ``parse_date``.
        allow_none: When true, ``None`` or blank input returns ``None``.
        min_date: Earliest accepted date; defaults to ``MIN_VALID_DATE``.
        max_date: Latest accepted date; defaults to today plus
            ``MAX_FUTURE_DAYS`` days.
        allow_future: When false, dates after today are rejected.
        allow_weekend: When false, Saturdays and Sundays are rejected.

    Returns:
        The validated ``date``, or ``None`` for permitted empty input.

    Raises:
        DateValidationError: If the input is missing (and not allowed),
            cannot be parsed, or violates any of the rules above.
    """
    if date_input is None:
        if not allow_none:
            raise DateValidationError("Date cannot be None")
        return None

    parsed = parse_date(date_input, date_format)
    if parsed is None:
        # parse_date returns None for blank strings.
        if not allow_none:
            raise DateValidationError("Date cannot be empty")
        return None

    effective_min = min_date if min_date else MIN_VALID_DATE
    if parsed < effective_min:
        raise DateValidationError(
            f"Date {parsed} is before minimum allowed date {effective_min}"
        )

    today = date.today()
    if parsed > today and not allow_future:
        raise DateValidationError(
            f"Date {parsed} is in the future. Future dates are not allowed."
        )

    effective_max = max_date if max_date else today + timedelta(days=MAX_FUTURE_DAYS)
    if parsed > effective_max:
        raise DateValidationError(
            f"Date {parsed} is after maximum allowed date {effective_max}"
        )

    if not allow_weekend:
        weekday_index = parsed.weekday()
        if weekday_index >= 5:
            # weekday() is 5 for Saturday and 6 for Sunday.
            day_name = ("Saturday", "Sunday")[weekday_index - 5]
            raise DateValidationError(
                f"Date {parsed} falls on a {day_name}. Weekend dates are not allowed."
            )
    return parsed
def validate_date_range(
    start_date: Union[str, date, datetime],
    end_date: Union[str, date, datetime],
    date_format: str = "%Y-%m-%d",
    min_date: Optional[date] = None,
    max_date: Optional[date] = None,
    allow_future: bool = True,
    max_range_days: Optional[int] = None,
) -> tuple[date, date]:
    """Validate both ends of a date range and their relative order.

    Each end is validated with ``validate_date`` using the same settings.
    The end must be strictly later than the start: equal dates are
    rejected too, even though the message for a reversed range says
    "on or after".

    Args:
        start_date: First day of the range.
        end_date: Last day of the range.
        date_format: Preferred format forwarded to ``validate_date``.
        min_date: Lower bound forwarded to ``validate_date``.
        max_date: Upper bound forwarded to ``validate_date``.
        allow_future: Forwarded to ``validate_date``.
        max_range_days: Optional cap on ``(end - start).days``.

    Returns:
        ``(start, end)`` as ``date`` objects.

    Raises:
        DateValidationError: If either end is invalid, the end is not after
            the start, or the span exceeds ``max_range_days``.
    """
    # Both ends share identical validation settings.
    shared_rules = {
        "date_format": date_format,
        "allow_none": False,
        "min_date": min_date,
        "max_date": max_date,
        "allow_future": allow_future,
    }
    start = validate_date(start_date, **shared_rules)
    end = validate_date(end_date, **shared_rules)

    if end < start:
        raise DateValidationError(
            f"End date ({end}) must be on or after start date ({start})"
        )
    if end == start:
        raise DateValidationError(
            f"End date ({end}) must be after start date ({start})"
        )

    if max_range_days is None:
        return start, end

    range_days = (end - start).days
    if range_days > max_range_days:
        raise DateValidationError(
            f"Date range of {range_days} days exceeds maximum of {max_range_days} days"
        )
    return start, end
def format_date(
    date_input: Union[str, date, datetime],
    output_format: str = "%Y-%m-%d",
    input_format: str = "%Y-%m-%d",
) -> str:
    """Parse a date-like value and render it with ``output_format``.

    Args:
        date_input: Value to parse (see ``parse_date``).
        output_format: ``strftime`` format for the result.
        input_format: Preferred format forwarded to ``parse_date``.

    Returns:
        The formatted date string.

    Raises:
        DateValidationError: If the input cannot be parsed, or parses to
            ``None`` (``None`` or blank input).
    """
    as_date = parse_date(date_input, input_format)
    if as_date is not None:
        return as_date.strftime(output_format)
    raise DateValidationError("Cannot format None date")
def is_valid_ticker(ticker: str) -> bool:
    """Return ``True`` if ``validate_ticker`` accepts ``ticker``, else ``False``.

    Only ``TickerValidationError`` is treated as "invalid"; any other
    exception propagates.
    """
    try:
        validate_ticker(ticker)
    except TickerValidationError:
        return False
    else:
        return True
def is_valid_date(
    date_input: Union[str, date, datetime],
    date_format: str = "%Y-%m-%d",
) -> bool:
    """Return ``True`` if ``validate_date`` accepts the input with default rules.

    Only ``date_format`` is forwarded, so the default bounds of
    ``validate_date`` apply and ``None``/blank input counts as invalid.
    Only ``DateValidationError`` is treated as "invalid"; any other
    exception propagates.
    """
    try:
        validate_date(date_input, date_format=date_format)
    except DateValidationError:
        return False
    else:
        return True
def is_trading_day(check_date: Union[str, date, datetime]) -> bool:
    """Return ``True`` when the date falls on Monday through Friday.

    This is a weekday check only; exchange holidays are not consulted.
    ``None`` and blank strings yield ``False``. Input that ``parse_date``
    cannot parse raises ``DateValidationError`` rather than returning
    ``False``.
    """
    resolved = parse_date(check_date)
    # weekday() is 0-4 for Monday-Friday.
    return resolved is not None and resolved.weekday() < 5
def get_previous_trading_day(from_date: Union[str, date, datetime, None] = None) -> date:
    """Return ``from_date`` itself if it is a weekday, else the preceding Friday.

    Only weekends are skipped; exchange holidays are not consulted. When
    ``from_date`` is ``None`` or a blank string, today's date is used as the
    starting point.

    Raises:
        DateValidationError: If ``from_date`` cannot be parsed.
    """
    cursor = parse_date(from_date) if from_date is not None else None
    if cursor is None:
        cursor = date.today()

    one_day = timedelta(days=1)
    # Step back over Sunday (6) and Saturday (5).
    while cursor.weekday() >= 5:
        cursor = cursor - one_day
    return cursor
def get_next_trading_day(from_date: Union[str, date, datetime, None] = None) -> date:
    """Return ``from_date`` itself if it is a weekday, else the following Monday.

    Only weekends are skipped; exchange holidays are not consulted. When
    ``from_date`` is ``None`` or a blank string, today's date is used as the
    starting point.

    Raises:
        DateValidationError: If ``from_date`` cannot be parsed.
    """
    cursor = parse_date(from_date) if from_date is not None else None
    if cursor is None:
        cursor = date.today()

    one_day = timedelta(days=1)
    # Step forward over Saturday (5) and Sunday (6).
    while cursor.weekday() >= 5:
        cursor = cursor + one_day
    return cursor