From e862c4f803abb762f134a18a65f3858a78c72628 Mon Sep 17 00:00:00 2001 From: Joseph O'Brien <98370624+89jobrien@users.noreply.github.com> Date: Wed, 3 Dec 2025 03:09:05 -0500 Subject: [PATCH] feat: add centralized input validation for ticker symbols and dates MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Create tradingagents/validation.py with comprehensive validation functions: - validate_ticker/validate_tickers for ticker symbol validation - validate_date/validate_date_range for date validation - parse_date for flexible date parsing (multiple formats) - Helper functions: is_valid_ticker, is_valid_date, is_trading_day - get_previous_trading_day, get_next_trading_day utilities - Add validation to y_finance.py (get_YFin_data_online, get_stock_stats_indicators_window) - Add validation to trading_graph.py propagate method - Create comprehensive test suite (55 tests) in tests/test_validation.py 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- tests/test_validation.py | 275 +++++++++++++++++++++++++ tradingagents/dataflows/y_finance.py | 13 +- tradingagents/graph/trading_graph.py | 6 + tradingagents/validation.py | 295 +++++++++++++++++++++++++++ 4 files changed, 585 insertions(+), 4 deletions(-) create mode 100644 tests/test_validation.py create mode 100644 tradingagents/validation.py diff --git a/tests/test_validation.py b/tests/test_validation.py new file mode 100644 index 00000000..2cebe9d5 --- /dev/null +++ b/tests/test_validation.py @@ -0,0 +1,275 @@ +import pytest +from datetime import date, datetime, timedelta + +from tradingagents.validation import ( + ValidationError, + TickerValidationError, + DateValidationError, + validate_ticker, + validate_tickers, + parse_date, + validate_date, + validate_date_range, + format_date, + is_valid_ticker, + is_valid_date, + is_trading_day, + get_previous_trading_day, + get_next_trading_day, +) + + +class TestValidateTicker: + def test_valid_simple_ticker(self): + assert validate_ticker("AAPL") == "AAPL" + assert validate_ticker("A") == "A" + assert validate_ticker("GOOGL") == "GOOGL" + + def test_valid_ticker_lowercase_converted(self): + assert validate_ticker("aapl") == "AAPL" + assert validate_ticker("Msft") == "MSFT" + + def test_valid_ticker_with_whitespace(self): + assert validate_ticker(" AAPL ") == "AAPL" + assert validate_ticker("\tTSLA\n") == "TSLA" + + def test_valid_ticker_with_class_indicator(self): + assert validate_ticker("BRK-B") == "BRK-B" + assert validate_ticker("BRK-A") == "BRK-A" + assert validate_ticker("BRK.B") == "BRK.B" + + def test_invalid_ticker_none(self): + with pytest.raises(TickerValidationError, match="cannot be None"): + validate_ticker(None) + + def test_invalid_ticker_none_allowed(self): + assert validate_ticker(None, allow_empty=True) == "" + + def test_invalid_ticker_empty(self): + with pytest.raises(TickerValidationError, match="cannot be empty"): + validate_ticker("") + + def test_invalid_ticker_empty_allowed(self): + assert validate_ticker("", allow_empty=True) == "" + + def test_invalid_ticker_wrong_type(self): + with pytest.raises(TickerValidationError, match="must be a string"): + validate_ticker(123) + + def test_invalid_ticker_too_long(self): + with pytest.raises(TickerValidationError, match="too long"): + validate_ticker("VERYLONGTICKER") + + def test_invalid_ticker_format(self): + with pytest.raises(TickerValidationError, match="Invalid ticker format"): + validate_ticker("AAPL123") + + with pytest.raises(TickerValidationError, match="Invalid ticker format"): + validate_ticker("AA-PL-B") + + with pytest.raises(TickerValidationError, match="Invalid ticker format"): + validate_ticker("A@PL") + + +class TestValidateTickers: + def test_valid_tickers_list(self): + result = validate_tickers(["AAPL", "MSFT", "GOOGL"]) + assert result == ["AAPL", "MSFT", "GOOGL"] + + def test_valid_tickers_lowercase_converted(self): + result = validate_tickers(["aapl", "msft"]) + assert result == ["AAPL", "MSFT"] + + def test_valid_tickers_tuple(self): + result = validate_tickers(("AAPL", "MSFT")) + assert result == ["AAPL", "MSFT"] + + def test_invalid_tickers_none(self): + with pytest.raises(TickerValidationError, match="cannot be None"): + validate_tickers(None) + + def test_invalid_tickers_none_allowed(self): + assert validate_tickers(None, allow_empty_list=True) == [] + + def test_invalid_tickers_empty_list(self): + with pytest.raises(TickerValidationError, match="cannot be empty"): + validate_tickers([]) + + def test_invalid_tickers_empty_list_allowed(self): + assert validate_tickers([], allow_empty_list=True) == [] + + def test_invalid_tickers_wrong_type(self): + with pytest.raises(TickerValidationError, match="must be a list"): + validate_tickers("AAPL") + + def test_invalid_tickers_contains_invalid(self): + with pytest.raises(TickerValidationError, match="Invalid tickers"): + validate_tickers(["AAPL", "INVALID123", "MSFT"]) + + +class TestParseDate: + def test_parse_date_string_default_format(self): + result = parse_date("2024-01-15") + assert result == date(2024, 1, 15) + + def test_parse_date_string_various_formats(self): + assert parse_date("2024/01/15") == date(2024, 1, 15) + assert parse_date("01/15/2024") == date(2024, 1, 15) + assert parse_date("20240115") == date(2024, 1, 15) + + def test_parse_date_from_date(self): + d = date(2024, 1, 15) + assert parse_date(d) == d + + def test_parse_date_from_datetime(self): + dt = datetime(2024, 1, 15, 10, 30) + assert parse_date(dt) == date(2024, 1, 15) + + def test_parse_date_none(self): + assert parse_date(None) is None + + def test_parse_date_empty_string(self): + assert parse_date("") is None + assert parse_date(" ") is None + + def test_parse_date_invalid_format(self): + with pytest.raises(DateValidationError, match="Could not parse date"): + parse_date("not-a-date") + + def test_parse_date_invalid_type(self): + with pytest.raises(DateValidationError, match="must be string"): + parse_date(123) + + +class TestValidateDate: + def test_validate_date_valid(self): + result = validate_date("2024-01-15") + assert result == date(2024, 1, 15) + + def test_validate_date_none_not_allowed(self): + with pytest.raises(DateValidationError, match="cannot be None"): + validate_date(None) + + def test_validate_date_none_allowed(self): + assert validate_date(None, allow_none=True) is None + + def test_validate_date_before_min(self): + with pytest.raises(DateValidationError, match="before minimum"): + validate_date("1960-01-01") + + def test_validate_date_custom_min(self): + with pytest.raises(DateValidationError, match="before minimum"): + validate_date("2020-01-01", min_date=date(2021, 1, 1)) + + def test_validate_date_future_not_allowed(self): + future = date.today() + timedelta(days=30) + with pytest.raises(DateValidationError, match="in the future"): + validate_date(future.strftime("%Y-%m-%d"), allow_future=False) + + def test_validate_date_future_allowed(self): + future = date.today() + timedelta(days=30) + result = validate_date(future.strftime("%Y-%m-%d"), allow_future=True) + assert result == future + + def test_validate_date_after_max(self): + far_future = date.today() + timedelta(days=500) + with pytest.raises(DateValidationError, match="after maximum"): + validate_date(far_future.strftime("%Y-%m-%d")) + + def test_validate_date_weekend_not_allowed(self): + saturday = date(2024, 1, 6) + with pytest.raises(DateValidationError, match="Saturday"): + validate_date(saturday, allow_weekend=False) + + sunday = date(2024, 1, 7) + with pytest.raises(DateValidationError, match="Sunday"): + validate_date(sunday, allow_weekend=False) + + def test_validate_date_weekend_allowed(self): + saturday = date(2024, 1, 6) + result = validate_date(saturday, allow_weekend=True) + assert result == saturday + + +class TestValidateDateRange: + def test_validate_date_range_valid(self): + start, end = validate_date_range("2024-01-01", "2024-01-31") + assert start == date(2024, 1, 1) + assert end == date(2024, 1, 31) + + def test_validate_date_range_end_before_start(self): + with pytest.raises(DateValidationError, match="must be on or after"): + validate_date_range("2024-01-31", "2024-01-01") + + def test_validate_date_range_same_day(self): + with pytest.raises(DateValidationError, match="must be after"): + validate_date_range("2024-01-15", "2024-01-15") + + def test_validate_date_range_max_range(self): + with pytest.raises(DateValidationError, match="exceeds maximum"): + validate_date_range("2020-01-01", "2024-01-01", max_range_days=365) + + +class TestFormatDate: + def test_format_date_default(self): + result = format_date(date(2024, 1, 15)) + assert result == "2024-01-15" + + def test_format_date_custom_format(self): + result = format_date(date(2024, 1, 15), output_format="%m/%d/%Y") + assert result == "01/15/2024" + + def test_format_date_from_string(self): + result = format_date("2024-01-15", output_format="%Y%m%d") + assert result == "20240115" + + def test_format_date_none(self): + with pytest.raises(DateValidationError, match="Cannot format None"): + format_date(None) + + +class TestHelperFunctions: + def test_is_valid_ticker(self): + assert is_valid_ticker("AAPL") is True + assert is_valid_ticker("INVALID123") is False + assert is_valid_ticker("") is False + + def test_is_valid_date(self): + assert is_valid_date("2024-01-15") is True + assert is_valid_date("not-a-date") is False + assert is_valid_date(date(2024, 1, 15)) is True + + def test_is_trading_day(self): + assert is_trading_day(date(2024, 1, 15)) is True + assert is_trading_day(date(2024, 1, 13)) is False + assert is_trading_day("2024-01-15") is True + + def test_get_previous_trading_day_from_monday(self): + monday = date(2024, 1, 15) + result = get_previous_trading_day(monday) + assert result == monday + + def test_get_previous_trading_day_from_saturday(self): + saturday = date(2024, 1, 13) + result = get_previous_trading_day(saturday) + assert result == date(2024, 1, 12) + + def test_get_previous_trading_day_from_sunday(self): + sunday = date(2024, 1, 14) + result = get_previous_trading_day(sunday) + assert result == date(2024, 1, 12) + + def test_get_next_trading_day_from_friday(self): + friday = date(2024, 1, 12) + result = get_next_trading_day(friday) + assert result == friday + + def test_get_next_trading_day_from_saturday(self): + saturday = date(2024, 1, 13) + result = get_next_trading_day(saturday) + assert result == date(2024, 1, 15) + + def test_get_next_trading_day_from_sunday(self): + sunday = date(2024, 1, 14) + result = get_next_trading_day(sunday) + assert result == date(2024, 1, 15) diff --git a/tradingagents/dataflows/y_finance.py b/tradingagents/dataflows/y_finance.py index 874ccea2..a8ab7e8d 100644 --- a/tradingagents/dataflows/y_finance.py +++ b/tradingagents/dataflows/y_finance.py @@ -5,6 +5,7 @@ from dateutil.relativedelta import relativedelta import yfinance as yf import os from .stockstats_utils import StockstatsUtils +from tradingagents.validation import validate_ticker, validate_date_range, validate_date logger = logging.getLogger(__name__) @@ -13,11 +14,12 @@ def get_YFin_data_online( start_date: Annotated[str, "Start date in yyyy-mm-dd format"], end_date: Annotated[str, "End date in yyyy-mm-dd format"], ): + symbol = validate_ticker(symbol) + start, end = validate_date_range(start_date, end_date) + start_date = start.strftime("%Y-%m-%d") + end_date = end.strftime("%Y-%m-%d") - datetime.strptime(start_date, "%Y-%m-%d") - datetime.strptime(end_date, "%Y-%m-%d") - - ticker = yf.Ticker(symbol.upper()) + ticker = yf.Ticker(symbol) data = ticker.history(start=start_date, end=end_date) @@ -50,6 +52,9 @@ def get_stock_stats_indicators_window( ], look_back_days: Annotated[int, "how many days to look back"], ) -> str: + symbol = validate_ticker(symbol) + validated_date = validate_date(curr_date, allow_future=False) + curr_date = validated_date.strftime("%Y-%m-%d") best_ind_params = { "close_50_sma": ( diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index b2c2a5dd..8bc71d8a 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -48,6 +48,7 @@ from tradingagents.agents.discovery import ( calculate_trending_scores, ) from tradingagents.dataflows.interface import get_bulk_news +from tradingagents.validation import validate_ticker, validate_date, parse_date from .conditional_logic import ConditionalLogic from .setup import GraphSetup @@ -155,6 +156,11 @@ class TradingAgentsGraph: } def propagate(self, company_name, trade_date): + company_name = validate_ticker(company_name) + validated_date = validate_date(trade_date, allow_future=False) + if isinstance(trade_date, str): + trade_date = validated_date + self.ticker = company_name init_agent_state = self.propagator.create_initial_state( company_name, trade_date diff --git a/tradingagents/validation.py b/tradingagents/validation.py new file mode 100644 index 00000000..d1178b15 --- /dev/null +++ b/tradingagents/validation.py @@ -0,0 +1,295 @@ +import logging +import re +from datetime import date, datetime, timedelta +from typing import Optional, Union + +logger = logging.getLogger(__name__) + +TICKER_PATTERN = re.compile(r"^[A-Z]{1,5}(-[A-Z]{1,2})?$") + +TICKER_SPECIAL_PATTERN = re.compile(r"^[A-Z]{1,5}(\.[A-Z]{1,2})?$") + +MIN_VALID_DATE = date(1970, 1, 1) +MAX_FUTURE_DAYS = 365 + + +class ValidationError(Exception): + pass + + +class TickerValidationError(ValidationError): + pass + + +class DateValidationError(ValidationError): + pass + + +def validate_ticker( + ticker: str, + allow_empty: bool = False, + check_format_only: bool = True, +) -> str: + if ticker is None: + if allow_empty: + return "" + raise TickerValidationError("Ticker cannot be None") + + if not isinstance(ticker, str): + raise TickerValidationError(f"Ticker must be a string, got {type(ticker).__name__}") + + ticker = ticker.strip().upper() + + if not ticker: + if allow_empty: + return "" + raise TickerValidationError("Ticker cannot be empty") + + if len(ticker) > 10: + raise TickerValidationError(f"Ticker '{ticker}' is too long (max 10 characters)") + + if not TICKER_PATTERN.match(ticker) and not TICKER_SPECIAL_PATTERN.match(ticker): + raise TickerValidationError( + f"Invalid ticker format '{ticker}'. Must be 1-5 uppercase letters, " + "optionally followed by a class indicator (e.g., BRK-B, BRK.A)" + ) + + return ticker + + +def validate_tickers( + tickers: list[str], + allow_empty_list: bool = False, + check_format_only: bool = True, +) -> list[str]: + if tickers is None: + if allow_empty_list: + return [] + raise TickerValidationError("Tickers list cannot be None") + + if not isinstance(tickers, (list, tuple)): + raise TickerValidationError(f"Tickers must be a list, got {type(tickers).__name__}") + + if not tickers: + if allow_empty_list: + return [] + raise TickerValidationError("Tickers list cannot be empty") + + validated = [] + errors = [] + + for i, ticker in enumerate(tickers): + try: + validated.append(validate_ticker(ticker, check_format_only=check_format_only)) + except TickerValidationError as e: + errors.append(f"Index {i}: {e}") + + if errors: + raise TickerValidationError(f"Invalid tickers: {'; '.join(errors)}") + + return validated + + +def parse_date( + date_input: Union[str, date, datetime, None], + date_format: str = "%Y-%m-%d", +) -> Optional[date]: + if date_input is None: + return None + + if isinstance(date_input, datetime): + return date_input.date() + + if isinstance(date_input, date): + return date_input + + if not isinstance(date_input, str): + raise DateValidationError(f"Date must be string, date, or datetime, got {type(date_input).__name__}") + + date_input = date_input.strip() + + if not date_input: + return None + + formats_to_try = [ + date_format, + "%Y-%m-%d", + "%Y/%m/%d", + "%m/%d/%Y", + "%m-%d-%Y", + "%d/%m/%Y", + "%d-%m-%Y", + "%Y%m%d", + ] + + for fmt in formats_to_try: + try: + return datetime.strptime(date_input, fmt).date() + except ValueError: + continue + + raise DateValidationError( + f"Could not parse date '{date_input}'. Expected format: {date_format} " + f"(e.g., {datetime.now().strftime(date_format)})" + ) + + +def validate_date( + date_input: Union[str, date, datetime, None], + date_format: str = "%Y-%m-%d", + allow_none: bool = False, + min_date: Optional[date] = None, + max_date: Optional[date] = None, + allow_future: bool = True, + allow_weekend: bool = True, +) -> Optional[date]: + if date_input is None: + if allow_none: + return None + raise DateValidationError("Date cannot be None") + + parsed = parse_date(date_input, date_format) + + if parsed is None: + if allow_none: + return None + raise DateValidationError("Date cannot be empty") + + effective_min = min_date or MIN_VALID_DATE + if parsed < effective_min: + raise DateValidationError( + f"Date {parsed} is before minimum allowed date {effective_min}" + ) + + today = date.today() + + if not allow_future and parsed > today: + raise DateValidationError( + f"Date {parsed} is in the future. Future dates are not allowed." + ) + + effective_max = max_date or (today + timedelta(days=MAX_FUTURE_DAYS)) + if parsed > effective_max: + raise DateValidationError( + f"Date {parsed} is after maximum allowed date {effective_max}" + ) + + if not allow_weekend and parsed.weekday() >= 5: + day_name = "Saturday" if parsed.weekday() == 5 else "Sunday" + raise DateValidationError( + f"Date {parsed} falls on a {day_name}. Weekend dates are not allowed." + ) + + return parsed + + +def validate_date_range( + start_date: Union[str, date, datetime], + end_date: Union[str, date, datetime], + date_format: str = "%Y-%m-%d", + min_date: Optional[date] = None, + max_date: Optional[date] = None, + allow_future: bool = True, + max_range_days: Optional[int] = None, +) -> tuple[date, date]: + start = validate_date( + start_date, + date_format=date_format, + allow_none=False, + min_date=min_date, + max_date=max_date, + allow_future=allow_future, + ) + + end = validate_date( + end_date, + date_format=date_format, + allow_none=False, + min_date=min_date, + max_date=max_date, + allow_future=allow_future, + ) + + if end < start: + raise DateValidationError( + f"End date ({end}) must be on or after start date ({start})" + ) + + if end == start: + raise DateValidationError( + f"End date ({end}) must be after start date ({start})" + ) + + if max_range_days is not None: + range_days = (end - start).days + if range_days > max_range_days: + raise DateValidationError( + f"Date range of {range_days} days exceeds maximum of {max_range_days} days" + ) + + return start, end + + +def format_date( + date_input: Union[str, date, datetime], + output_format: str = "%Y-%m-%d", + input_format: str = "%Y-%m-%d", +) -> str: + parsed = parse_date(date_input, input_format) + if parsed is None: + raise DateValidationError("Cannot format None date") + return parsed.strftime(output_format) + + +def is_valid_ticker(ticker: str) -> bool: + try: + validate_ticker(ticker) + return True + except TickerValidationError: + return False + + +def is_valid_date( + date_input: Union[str, date, datetime], + date_format: str = "%Y-%m-%d", +) -> bool: + try: + validate_date(date_input, date_format=date_format) + return True + except DateValidationError: + return False + + +def is_trading_day(check_date: Union[str, date, datetime]) -> bool: + parsed = parse_date(check_date) + if parsed is None: + return False + return parsed.weekday() < 5 + + +def get_previous_trading_day(from_date: Union[str, date, datetime, None] = None) -> date: + if from_date is None: + check = date.today() + else: + check = parse_date(from_date) + if check is None: + check = date.today() + + while check.weekday() >= 5: + check = check - timedelta(days=1) + + return check + + +def get_next_trading_day(from_date: Union[str, date, datetime, None] = None) -> date: + if from_date is None: + check = date.today() + else: + check = parse_date(from_date) + if check is None: + check = date.today() + + while check.weekday() >= 5: + check = check + timedelta(days=1) + + return check