feat: add centralized input validation for ticker symbols and dates
- Create tradingagents/validation.py with comprehensive validation functions: - validate_ticker/validate_tickers for ticker symbol validation - validate_date/validate_date_range for date validation - parse_date for flexible date parsing (multiple formats) - Helper functions: is_valid_ticker, is_valid_date, is_trading_day - get_previous_trading_day, get_next_trading_day utilities - Add validation to y_finance.py (get_YFin_data_online, get_stock_stats_indicators_window) - Add validation to trading_graph.py propagate method - Create comprehensive test suite (55 tests) in tests/test_validation.py 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
f70874982a
commit
e862c4f803
|
|
@ -0,0 +1,275 @@
|
|||
import pytest
|
||||
from datetime import date, datetime, timedelta
|
||||
|
||||
from tradingagents.validation import (
|
||||
ValidationError,
|
||||
TickerValidationError,
|
||||
DateValidationError,
|
||||
validate_ticker,
|
||||
validate_tickers,
|
||||
parse_date,
|
||||
validate_date,
|
||||
validate_date_range,
|
||||
format_date,
|
||||
is_valid_ticker,
|
||||
is_valid_date,
|
||||
is_trading_day,
|
||||
get_previous_trading_day,
|
||||
get_next_trading_day,
|
||||
)
|
||||
|
||||
|
||||
class TestValidateTicker:
|
||||
def test_valid_simple_ticker(self):
|
||||
assert validate_ticker("AAPL") == "AAPL"
|
||||
assert validate_ticker("A") == "A"
|
||||
assert validate_ticker("GOOGL") == "GOOGL"
|
||||
|
||||
def test_valid_ticker_lowercase_converted(self):
|
||||
assert validate_ticker("aapl") == "AAPL"
|
||||
assert validate_ticker("Msft") == "MSFT"
|
||||
|
||||
def test_valid_ticker_with_whitespace(self):
|
||||
assert validate_ticker(" AAPL ") == "AAPL"
|
||||
assert validate_ticker("\tTSLA\n") == "TSLA"
|
||||
|
||||
def test_valid_ticker_with_class_indicator(self):
|
||||
assert validate_ticker("BRK-B") == "BRK-B"
|
||||
assert validate_ticker("BRK-A") == "BRK-A"
|
||||
assert validate_ticker("BRK.B") == "BRK.B"
|
||||
|
||||
def test_invalid_ticker_none(self):
|
||||
with pytest.raises(TickerValidationError, match="cannot be None"):
|
||||
validate_ticker(None)
|
||||
|
||||
def test_invalid_ticker_none_allowed(self):
|
||||
assert validate_ticker(None, allow_empty=True) == ""
|
||||
|
||||
def test_invalid_ticker_empty(self):
|
||||
with pytest.raises(TickerValidationError, match="cannot be empty"):
|
||||
validate_ticker("")
|
||||
|
||||
def test_invalid_ticker_empty_allowed(self):
|
||||
assert validate_ticker("", allow_empty=True) == ""
|
||||
|
||||
def test_invalid_ticker_wrong_type(self):
|
||||
with pytest.raises(TickerValidationError, match="must be a string"):
|
||||
validate_ticker(123)
|
||||
|
||||
def test_invalid_ticker_too_long(self):
|
||||
with pytest.raises(TickerValidationError, match="too long"):
|
||||
validate_ticker("VERYLONGTICKER")
|
||||
|
||||
def test_invalid_ticker_format(self):
|
||||
with pytest.raises(TickerValidationError, match="Invalid ticker format"):
|
||||
validate_ticker("AAPL123")
|
||||
|
||||
with pytest.raises(TickerValidationError, match="Invalid ticker format"):
|
||||
validate_ticker("AA-PL-B")
|
||||
|
||||
with pytest.raises(TickerValidationError, match="Invalid ticker format"):
|
||||
validate_ticker("A@PL")
|
||||
|
||||
|
||||
class TestValidateTickers:
|
||||
def test_valid_tickers_list(self):
|
||||
result = validate_tickers(["AAPL", "MSFT", "GOOGL"])
|
||||
assert result == ["AAPL", "MSFT", "GOOGL"]
|
||||
|
||||
def test_valid_tickers_lowercase_converted(self):
|
||||
result = validate_tickers(["aapl", "msft"])
|
||||
assert result == ["AAPL", "MSFT"]
|
||||
|
||||
def test_valid_tickers_tuple(self):
|
||||
result = validate_tickers(("AAPL", "MSFT"))
|
||||
assert result == ["AAPL", "MSFT"]
|
||||
|
||||
def test_invalid_tickers_none(self):
|
||||
with pytest.raises(TickerValidationError, match="cannot be None"):
|
||||
validate_tickers(None)
|
||||
|
||||
def test_invalid_tickers_none_allowed(self):
|
||||
assert validate_tickers(None, allow_empty_list=True) == []
|
||||
|
||||
def test_invalid_tickers_empty_list(self):
|
||||
with pytest.raises(TickerValidationError, match="cannot be empty"):
|
||||
validate_tickers([])
|
||||
|
||||
def test_invalid_tickers_empty_list_allowed(self):
|
||||
assert validate_tickers([], allow_empty_list=True) == []
|
||||
|
||||
def test_invalid_tickers_wrong_type(self):
|
||||
with pytest.raises(TickerValidationError, match="must be a list"):
|
||||
validate_tickers("AAPL")
|
||||
|
||||
def test_invalid_tickers_contains_invalid(self):
|
||||
with pytest.raises(TickerValidationError, match="Invalid tickers"):
|
||||
validate_tickers(["AAPL", "INVALID123", "MSFT"])
|
||||
|
||||
|
||||
class TestParseDate:
|
||||
def test_parse_date_string_default_format(self):
|
||||
result = parse_date("2024-01-15")
|
||||
assert result == date(2024, 1, 15)
|
||||
|
||||
def test_parse_date_string_various_formats(self):
|
||||
assert parse_date("2024/01/15") == date(2024, 1, 15)
|
||||
assert parse_date("01/15/2024") == date(2024, 1, 15)
|
||||
assert parse_date("20240115") == date(2024, 1, 15)
|
||||
|
||||
def test_parse_date_from_date(self):
|
||||
d = date(2024, 1, 15)
|
||||
assert parse_date(d) == d
|
||||
|
||||
def test_parse_date_from_datetime(self):
|
||||
dt = datetime(2024, 1, 15, 10, 30)
|
||||
assert parse_date(dt) == date(2024, 1, 15)
|
||||
|
||||
def test_parse_date_none(self):
|
||||
assert parse_date(None) is None
|
||||
|
||||
def test_parse_date_empty_string(self):
|
||||
assert parse_date("") is None
|
||||
assert parse_date(" ") is None
|
||||
|
||||
def test_parse_date_invalid_format(self):
|
||||
with pytest.raises(DateValidationError, match="Could not parse date"):
|
||||
parse_date("not-a-date")
|
||||
|
||||
def test_parse_date_invalid_type(self):
|
||||
with pytest.raises(DateValidationError, match="must be string"):
|
||||
parse_date(123)
|
||||
|
||||
|
||||
class TestValidateDate:
|
||||
def test_validate_date_valid(self):
|
||||
result = validate_date("2024-01-15")
|
||||
assert result == date(2024, 1, 15)
|
||||
|
||||
def test_validate_date_none_not_allowed(self):
|
||||
with pytest.raises(DateValidationError, match="cannot be None"):
|
||||
validate_date(None)
|
||||
|
||||
def test_validate_date_none_allowed(self):
|
||||
assert validate_date(None, allow_none=True) is None
|
||||
|
||||
def test_validate_date_before_min(self):
|
||||
with pytest.raises(DateValidationError, match="before minimum"):
|
||||
validate_date("1960-01-01")
|
||||
|
||||
def test_validate_date_custom_min(self):
|
||||
with pytest.raises(DateValidationError, match="before minimum"):
|
||||
validate_date("2020-01-01", min_date=date(2021, 1, 1))
|
||||
|
||||
def test_validate_date_future_not_allowed(self):
|
||||
future = date.today() + timedelta(days=30)
|
||||
with pytest.raises(DateValidationError, match="in the future"):
|
||||
validate_date(future.strftime("%Y-%m-%d"), allow_future=False)
|
||||
|
||||
def test_validate_date_future_allowed(self):
|
||||
future = date.today() + timedelta(days=30)
|
||||
result = validate_date(future.strftime("%Y-%m-%d"), allow_future=True)
|
||||
assert result == future
|
||||
|
||||
def test_validate_date_after_max(self):
|
||||
far_future = date.today() + timedelta(days=500)
|
||||
with pytest.raises(DateValidationError, match="after maximum"):
|
||||
validate_date(far_future.strftime("%Y-%m-%d"))
|
||||
|
||||
def test_validate_date_weekend_not_allowed(self):
|
||||
saturday = date(2024, 1, 6)
|
||||
with pytest.raises(DateValidationError, match="Saturday"):
|
||||
validate_date(saturday, allow_weekend=False)
|
||||
|
||||
sunday = date(2024, 1, 7)
|
||||
with pytest.raises(DateValidationError, match="Sunday"):
|
||||
validate_date(sunday, allow_weekend=False)
|
||||
|
||||
def test_validate_date_weekend_allowed(self):
|
||||
saturday = date(2024, 1, 6)
|
||||
result = validate_date(saturday, allow_weekend=True)
|
||||
assert result == saturday
|
||||
|
||||
|
||||
class TestValidateDateRange:
|
||||
def test_validate_date_range_valid(self):
|
||||
start, end = validate_date_range("2024-01-01", "2024-01-31")
|
||||
assert start == date(2024, 1, 1)
|
||||
assert end == date(2024, 1, 31)
|
||||
|
||||
def test_validate_date_range_end_before_start(self):
|
||||
with pytest.raises(DateValidationError, match="must be on or after"):
|
||||
validate_date_range("2024-01-31", "2024-01-01")
|
||||
|
||||
def test_validate_date_range_same_day(self):
|
||||
with pytest.raises(DateValidationError, match="must be after"):
|
||||
validate_date_range("2024-01-15", "2024-01-15")
|
||||
|
||||
def test_validate_date_range_max_range(self):
|
||||
with pytest.raises(DateValidationError, match="exceeds maximum"):
|
||||
validate_date_range("2020-01-01", "2024-01-01", max_range_days=365)
|
||||
|
||||
|
||||
class TestFormatDate:
|
||||
def test_format_date_default(self):
|
||||
result = format_date(date(2024, 1, 15))
|
||||
assert result == "2024-01-15"
|
||||
|
||||
def test_format_date_custom_format(self):
|
||||
result = format_date(date(2024, 1, 15), output_format="%m/%d/%Y")
|
||||
assert result == "01/15/2024"
|
||||
|
||||
def test_format_date_from_string(self):
|
||||
result = format_date("2024-01-15", output_format="%Y%m%d")
|
||||
assert result == "20240115"
|
||||
|
||||
def test_format_date_none(self):
|
||||
with pytest.raises(DateValidationError, match="Cannot format None"):
|
||||
format_date(None)
|
||||
|
||||
|
||||
class TestHelperFunctions:
|
||||
def test_is_valid_ticker(self):
|
||||
assert is_valid_ticker("AAPL") is True
|
||||
assert is_valid_ticker("INVALID123") is False
|
||||
assert is_valid_ticker("") is False
|
||||
|
||||
def test_is_valid_date(self):
|
||||
assert is_valid_date("2024-01-15") is True
|
||||
assert is_valid_date("not-a-date") is False
|
||||
assert is_valid_date(date(2024, 1, 15)) is True
|
||||
|
||||
def test_is_trading_day(self):
|
||||
assert is_trading_day(date(2024, 1, 15)) is True
|
||||
assert is_trading_day(date(2024, 1, 13)) is False
|
||||
assert is_trading_day("2024-01-15") is True
|
||||
|
||||
def test_get_previous_trading_day_from_monday(self):
|
||||
monday = date(2024, 1, 15)
|
||||
result = get_previous_trading_day(monday)
|
||||
assert result == monday
|
||||
|
||||
def test_get_previous_trading_day_from_saturday(self):
|
||||
saturday = date(2024, 1, 13)
|
||||
result = get_previous_trading_day(saturday)
|
||||
assert result == date(2024, 1, 12)
|
||||
|
||||
def test_get_previous_trading_day_from_sunday(self):
|
||||
sunday = date(2024, 1, 14)
|
||||
result = get_previous_trading_day(sunday)
|
||||
assert result == date(2024, 1, 12)
|
||||
|
||||
def test_get_next_trading_day_from_friday(self):
|
||||
friday = date(2024, 1, 12)
|
||||
result = get_next_trading_day(friday)
|
||||
assert result == friday
|
||||
|
||||
def test_get_next_trading_day_from_saturday(self):
|
||||
saturday = date(2024, 1, 13)
|
||||
result = get_next_trading_day(saturday)
|
||||
assert result == date(2024, 1, 15)
|
||||
|
||||
def test_get_next_trading_day_from_sunday(self):
|
||||
sunday = date(2024, 1, 14)
|
||||
result = get_next_trading_day(sunday)
|
||||
assert result == date(2024, 1, 15)
|
||||
|
|
@ -5,6 +5,7 @@ from dateutil.relativedelta import relativedelta
|
|||
import yfinance as yf
|
||||
import os
|
||||
from .stockstats_utils import StockstatsUtils
|
||||
from tradingagents.validation import validate_ticker, validate_date_range, validate_date
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -13,11 +14,12 @@ def get_YFin_data_online(
|
|||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
||||
):
|
||||
symbol = validate_ticker(symbol)
|
||||
start, end = validate_date_range(start_date, end_date)
|
||||
start_date = start.strftime("%Y-%m-%d")
|
||||
end_date = end.strftime("%Y-%m-%d")
|
||||
|
||||
datetime.strptime(start_date, "%Y-%m-%d")
|
||||
datetime.strptime(end_date, "%Y-%m-%d")
|
||||
|
||||
ticker = yf.Ticker(symbol.upper())
|
||||
ticker = yf.Ticker(symbol)
|
||||
|
||||
data = ticker.history(start=start_date, end=end_date)
|
||||
|
||||
|
|
@ -50,6 +52,9 @@ def get_stock_stats_indicators_window(
|
|||
],
|
||||
look_back_days: Annotated[int, "how many days to look back"],
|
||||
) -> str:
|
||||
symbol = validate_ticker(symbol)
|
||||
validated_date = validate_date(curr_date, allow_future=False)
|
||||
curr_date = validated_date.strftime("%Y-%m-%d")
|
||||
|
||||
best_ind_params = {
|
||||
"close_50_sma": (
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@ from tradingagents.agents.discovery import (
|
|||
calculate_trending_scores,
|
||||
)
|
||||
from tradingagents.dataflows.interface import get_bulk_news
|
||||
from tradingagents.validation import validate_ticker, validate_date, parse_date
|
||||
|
||||
from .conditional_logic import ConditionalLogic
|
||||
from .setup import GraphSetup
|
||||
|
|
@ -155,6 +156,11 @@ class TradingAgentsGraph:
|
|||
}
|
||||
|
||||
def propagate(self, company_name, trade_date):
|
||||
company_name = validate_ticker(company_name)
|
||||
validated_date = validate_date(trade_date, allow_future=False)
|
||||
if isinstance(trade_date, str):
|
||||
trade_date = validated_date
|
||||
|
||||
self.ticker = company_name
|
||||
init_agent_state = self.propagator.create_initial_state(
|
||||
company_name, trade_date
|
||||
|
|
|
|||
|
|
@ -0,0 +1,295 @@
|
|||
import logging
|
||||
import re
|
||||
from datetime import date, datetime, timedelta
|
||||
from typing import Optional, Union
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TICKER_PATTERN = re.compile(r"^[A-Z]{1,5}(-[A-Z]{1,2})?$")
|
||||
|
||||
TICKER_SPECIAL_PATTERN = re.compile(r"^[A-Z]{1,5}(\.[A-Z]{1,2})?$")
|
||||
|
||||
MIN_VALID_DATE = date(1970, 1, 1)
|
||||
MAX_FUTURE_DAYS = 365
|
||||
|
||||
|
||||
class ValidationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class TickerValidationError(ValidationError):
|
||||
pass
|
||||
|
||||
|
||||
class DateValidationError(ValidationError):
|
||||
pass
|
||||
|
||||
|
||||
def validate_ticker(
|
||||
ticker: str,
|
||||
allow_empty: bool = False,
|
||||
check_format_only: bool = True,
|
||||
) -> str:
|
||||
if ticker is None:
|
||||
if allow_empty:
|
||||
return ""
|
||||
raise TickerValidationError("Ticker cannot be None")
|
||||
|
||||
if not isinstance(ticker, str):
|
||||
raise TickerValidationError(f"Ticker must be a string, got {type(ticker).__name__}")
|
||||
|
||||
ticker = ticker.strip().upper()
|
||||
|
||||
if not ticker:
|
||||
if allow_empty:
|
||||
return ""
|
||||
raise TickerValidationError("Ticker cannot be empty")
|
||||
|
||||
if len(ticker) > 10:
|
||||
raise TickerValidationError(f"Ticker '{ticker}' is too long (max 10 characters)")
|
||||
|
||||
if not TICKER_PATTERN.match(ticker) and not TICKER_SPECIAL_PATTERN.match(ticker):
|
||||
raise TickerValidationError(
|
||||
f"Invalid ticker format '{ticker}'. Must be 1-5 uppercase letters, "
|
||||
"optionally followed by a class indicator (e.g., BRK-B, BRK.A)"
|
||||
)
|
||||
|
||||
return ticker
|
||||
|
||||
|
||||
def validate_tickers(
|
||||
tickers: list[str],
|
||||
allow_empty_list: bool = False,
|
||||
check_format_only: bool = True,
|
||||
) -> list[str]:
|
||||
if tickers is None:
|
||||
if allow_empty_list:
|
||||
return []
|
||||
raise TickerValidationError("Tickers list cannot be None")
|
||||
|
||||
if not isinstance(tickers, (list, tuple)):
|
||||
raise TickerValidationError(f"Tickers must be a list, got {type(tickers).__name__}")
|
||||
|
||||
if not tickers:
|
||||
if allow_empty_list:
|
||||
return []
|
||||
raise TickerValidationError("Tickers list cannot be empty")
|
||||
|
||||
validated = []
|
||||
errors = []
|
||||
|
||||
for i, ticker in enumerate(tickers):
|
||||
try:
|
||||
validated.append(validate_ticker(ticker, check_format_only=check_format_only))
|
||||
except TickerValidationError as e:
|
||||
errors.append(f"Index {i}: {e}")
|
||||
|
||||
if errors:
|
||||
raise TickerValidationError(f"Invalid tickers: {'; '.join(errors)}")
|
||||
|
||||
return validated
|
||||
|
||||
|
||||
def parse_date(
|
||||
date_input: Union[str, date, datetime, None],
|
||||
date_format: str = "%Y-%m-%d",
|
||||
) -> Optional[date]:
|
||||
if date_input is None:
|
||||
return None
|
||||
|
||||
if isinstance(date_input, datetime):
|
||||
return date_input.date()
|
||||
|
||||
if isinstance(date_input, date):
|
||||
return date_input
|
||||
|
||||
if not isinstance(date_input, str):
|
||||
raise DateValidationError(f"Date must be string, date, or datetime, got {type(date_input).__name__}")
|
||||
|
||||
date_input = date_input.strip()
|
||||
|
||||
if not date_input:
|
||||
return None
|
||||
|
||||
formats_to_try = [
|
||||
date_format,
|
||||
"%Y-%m-%d",
|
||||
"%Y/%m/%d",
|
||||
"%m/%d/%Y",
|
||||
"%m-%d-%Y",
|
||||
"%d/%m/%Y",
|
||||
"%d-%m-%Y",
|
||||
"%Y%m%d",
|
||||
]
|
||||
|
||||
for fmt in formats_to_try:
|
||||
try:
|
||||
return datetime.strptime(date_input, fmt).date()
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
raise DateValidationError(
|
||||
f"Could not parse date '{date_input}'. Expected format: {date_format} "
|
||||
f"(e.g., {datetime.now().strftime(date_format)})"
|
||||
)
|
||||
|
||||
|
||||
def validate_date(
|
||||
date_input: Union[str, date, datetime, None],
|
||||
date_format: str = "%Y-%m-%d",
|
||||
allow_none: bool = False,
|
||||
min_date: Optional[date] = None,
|
||||
max_date: Optional[date] = None,
|
||||
allow_future: bool = True,
|
||||
allow_weekend: bool = True,
|
||||
) -> Optional[date]:
|
||||
if date_input is None:
|
||||
if allow_none:
|
||||
return None
|
||||
raise DateValidationError("Date cannot be None")
|
||||
|
||||
parsed = parse_date(date_input, date_format)
|
||||
|
||||
if parsed is None:
|
||||
if allow_none:
|
||||
return None
|
||||
raise DateValidationError("Date cannot be empty")
|
||||
|
||||
effective_min = min_date or MIN_VALID_DATE
|
||||
if parsed < effective_min:
|
||||
raise DateValidationError(
|
||||
f"Date {parsed} is before minimum allowed date {effective_min}"
|
||||
)
|
||||
|
||||
today = date.today()
|
||||
|
||||
if not allow_future and parsed > today:
|
||||
raise DateValidationError(
|
||||
f"Date {parsed} is in the future. Future dates are not allowed."
|
||||
)
|
||||
|
||||
effective_max = max_date or (today + timedelta(days=MAX_FUTURE_DAYS))
|
||||
if parsed > effective_max:
|
||||
raise DateValidationError(
|
||||
f"Date {parsed} is after maximum allowed date {effective_max}"
|
||||
)
|
||||
|
||||
if not allow_weekend and parsed.weekday() >= 5:
|
||||
day_name = "Saturday" if parsed.weekday() == 5 else "Sunday"
|
||||
raise DateValidationError(
|
||||
f"Date {parsed} falls on a {day_name}. Weekend dates are not allowed."
|
||||
)
|
||||
|
||||
return parsed
|
||||
|
||||
|
||||
def validate_date_range(
|
||||
start_date: Union[str, date, datetime],
|
||||
end_date: Union[str, date, datetime],
|
||||
date_format: str = "%Y-%m-%d",
|
||||
min_date: Optional[date] = None,
|
||||
max_date: Optional[date] = None,
|
||||
allow_future: bool = True,
|
||||
max_range_days: Optional[int] = None,
|
||||
) -> tuple[date, date]:
|
||||
start = validate_date(
|
||||
start_date,
|
||||
date_format=date_format,
|
||||
allow_none=False,
|
||||
min_date=min_date,
|
||||
max_date=max_date,
|
||||
allow_future=allow_future,
|
||||
)
|
||||
|
||||
end = validate_date(
|
||||
end_date,
|
||||
date_format=date_format,
|
||||
allow_none=False,
|
||||
min_date=min_date,
|
||||
max_date=max_date,
|
||||
allow_future=allow_future,
|
||||
)
|
||||
|
||||
if end < start:
|
||||
raise DateValidationError(
|
||||
f"End date ({end}) must be on or after start date ({start})"
|
||||
)
|
||||
|
||||
if end == start:
|
||||
raise DateValidationError(
|
||||
f"End date ({end}) must be after start date ({start})"
|
||||
)
|
||||
|
||||
if max_range_days is not None:
|
||||
range_days = (end - start).days
|
||||
if range_days > max_range_days:
|
||||
raise DateValidationError(
|
||||
f"Date range of {range_days} days exceeds maximum of {max_range_days} days"
|
||||
)
|
||||
|
||||
return start, end
|
||||
|
||||
|
||||
def format_date(
|
||||
date_input: Union[str, date, datetime],
|
||||
output_format: str = "%Y-%m-%d",
|
||||
input_format: str = "%Y-%m-%d",
|
||||
) -> str:
|
||||
parsed = parse_date(date_input, input_format)
|
||||
if parsed is None:
|
||||
raise DateValidationError("Cannot format None date")
|
||||
return parsed.strftime(output_format)
|
||||
|
||||
|
||||
def is_valid_ticker(ticker: str) -> bool:
|
||||
try:
|
||||
validate_ticker(ticker)
|
||||
return True
|
||||
except TickerValidationError:
|
||||
return False
|
||||
|
||||
|
||||
def is_valid_date(
|
||||
date_input: Union[str, date, datetime],
|
||||
date_format: str = "%Y-%m-%d",
|
||||
) -> bool:
|
||||
try:
|
||||
validate_date(date_input, date_format=date_format)
|
||||
return True
|
||||
except DateValidationError:
|
||||
return False
|
||||
|
||||
|
||||
def is_trading_day(check_date: Union[str, date, datetime]) -> bool:
|
||||
parsed = parse_date(check_date)
|
||||
if parsed is None:
|
||||
return False
|
||||
return parsed.weekday() < 5
|
||||
|
||||
|
||||
def get_previous_trading_day(from_date: Union[str, date, datetime, None] = None) -> date:
|
||||
if from_date is None:
|
||||
check = date.today()
|
||||
else:
|
||||
check = parse_date(from_date)
|
||||
if check is None:
|
||||
check = date.today()
|
||||
|
||||
while check.weekday() >= 5:
|
||||
check = check - timedelta(days=1)
|
||||
|
||||
return check
|
||||
|
||||
|
||||
def get_next_trading_day(from_date: Union[str, date, datetime, None] = None) -> date:
|
||||
if from_date is None:
|
||||
check = date.today()
|
||||
else:
|
||||
check = parse_date(from_date)
|
||||
if check is None:
|
||||
check = date.today()
|
||||
|
||||
while check.weekday() >= 5:
|
||||
check = check + timedelta(days=1)
|
||||
|
||||
return check
|
||||
Loading…
Reference in New Issue