refactor: address code review feedback on look-ahead bias fix

- Extract repeated fiscal-date filtering into _filter_reports_by_date()
  helper in alpha_vantage_fundamentals.py (was duplicated 3×)
- Fix return type annotations: get_balance_sheet/get_cashflow/
  get_income_statement now declare Union[dict, str] instead of str,
  matching the actual return value from _make_api_request
- Extract pandas column filtering into _filter_financials_by_date()
  helper in y_finance.py (was duplicated 3×); uses pd.to_datetime +
  vectorised boolean mask instead of a Python list comprehension

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Claude Lab 2026-03-29 07:51:10 +02:00
parent abd13c0153
commit e87ba8f29d
2 changed files with 51 additions and 37 deletions

View File

@ -1,6 +1,24 @@
from typing import Union
from .alpha_vantage_common import _make_api_request from .alpha_vantage_common import _make_api_request
def _filter_reports_by_date(result: dict, curr_date: str) -> dict:
"""Remove annualReports/quarterlyReports whose fiscalDateEnding exceeds curr_date.
Mutates *result* in-place and returns it so callers can chain the call.
"""
if not (curr_date and isinstance(result, dict)):
return result
for key in ("annualReports", "quarterlyReports"):
if key in result:
result[key] = [
r for r in result[key]
if r.get("fiscalDateEnding", "") <= curr_date
]
return result
def get_fundamentals(ticker: str, curr_date: str = None) -> str: def get_fundamentals(ticker: str, curr_date: str = None) -> str:
""" """
Retrieve comprehensive fundamental data for a given ticker symbol using Alpha Vantage. Retrieve comprehensive fundamental data for a given ticker symbol using Alpha Vantage.
@ -19,7 +37,9 @@ def get_fundamentals(ticker: str, curr_date: str = None) -> str:
return _make_api_request("OVERVIEW", params) return _make_api_request("OVERVIEW", params)
def get_balance_sheet(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str: def get_balance_sheet(
ticker: str, freq: str = "quarterly", curr_date: str = None
) -> Union[dict, str]:
""" """
Retrieve balance sheet data for a given ticker symbol using Alpha Vantage. Retrieve balance sheet data for a given ticker symbol using Alpha Vantage.
@ -29,23 +49,19 @@ def get_balance_sheet(ticker: str, freq: str = "quarterly", curr_date: str = Non
curr_date (str): Current date you are trading at, yyyy-mm-dd curr_date (str): Current date you are trading at, yyyy-mm-dd
Returns: Returns:
str: Balance sheet data with normalized fields dict | str: Balance sheet data dict, or an error string.
""" """
params = { params = {
"symbol": ticker, "symbol": ticker,
} }
result = _make_api_request("BALANCE_SHEET", params) result = _make_api_request("BALANCE_SHEET", params)
# Filter out reports whose fiscalDateEnding is after curr_date to prevent look-ahead bias. return _filter_reports_by_date(result, curr_date)
if curr_date and isinstance(result, dict):
for key in ("annualReports", "quarterlyReports"):
if key in result:
result[key] = [r for r in result[key]
if r.get("fiscalDateEnding", "") <= curr_date]
return result
def get_cashflow(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str: def get_cashflow(
ticker: str, freq: str = "quarterly", curr_date: str = None
) -> Union[dict, str]:
""" """
Retrieve cash flow statement data for a given ticker symbol using Alpha Vantage. Retrieve cash flow statement data for a given ticker symbol using Alpha Vantage.
@ -55,23 +71,19 @@ def get_cashflow(ticker: str, freq: str = "quarterly", curr_date: str = None) ->
curr_date (str): Current date you are trading at, yyyy-mm-dd curr_date (str): Current date you are trading at, yyyy-mm-dd
Returns: Returns:
str: Cash flow statement data with normalized fields dict | str: Cash flow statement data dict, or an error string.
""" """
params = { params = {
"symbol": ticker, "symbol": ticker,
} }
result = _make_api_request("CASH_FLOW", params) result = _make_api_request("CASH_FLOW", params)
# Filter out reports whose fiscalDateEnding is after curr_date to prevent look-ahead bias. return _filter_reports_by_date(result, curr_date)
if curr_date and isinstance(result, dict):
for key in ("annualReports", "quarterlyReports"):
if key in result:
result[key] = [r for r in result[key]
if r.get("fiscalDateEnding", "") <= curr_date]
return result
def get_income_statement(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str: def get_income_statement(
ticker: str, freq: str = "quarterly", curr_date: str = None
) -> Union[dict, str]:
""" """
Retrieve income statement data for a given ticker symbol using Alpha Vantage. Retrieve income statement data for a given ticker symbol using Alpha Vantage.
@ -81,18 +93,11 @@ def get_income_statement(ticker: str, freq: str = "quarterly", curr_date: str =
curr_date (str): Current date you are trading at, yyyy-mm-dd curr_date (str): Current date you are trading at, yyyy-mm-dd
Returns: Returns:
str: Income statement data with normalized fields dict | str: Income statement data dict, or an error string.
""" """
params = { params = {
"symbol": ticker, "symbol": ticker,
} }
result = _make_api_request("INCOME_STATEMENT", params) result = _make_api_request("INCOME_STATEMENT", params)
# Filter out reports whose fiscalDateEnding is after curr_date to prevent look-ahead bias. return _filter_reports_by_date(result, curr_date)
if curr_date and isinstance(result, dict):
for key in ("annualReports", "quarterlyReports"):
if key in result:
result[key] = [r for r in result[key]
if r.get("fiscalDateEnding", "") <= curr_date]
return result

View File

@ -1,10 +1,25 @@
from typing import Annotated from typing import Annotated
from datetime import datetime from datetime import datetime
from dateutil.relativedelta import relativedelta from dateutil.relativedelta import relativedelta
import pandas as pd
import yfinance as yf import yfinance as yf
import os import os
from .stockstats_utils import StockstatsUtils, _clean_dataframe, yf_retry from .stockstats_utils import StockstatsUtils, _clean_dataframe, yf_retry
def _filter_financials_by_date(data: "pd.DataFrame", curr_date: str) -> "pd.DataFrame":
"""Drop DataFrame columns (fiscal period timestamps) that exceed curr_date.
yfinance financial statements are indexed by metric name (rows) and fiscal
period end date (columns). Columns that post-date the simulation's current
date represent future data and must be removed to prevent look-ahead bias.
"""
if not curr_date or data.empty:
return data
cutoff = pd.Timestamp(curr_date)
mask = pd.to_datetime(data.columns, errors="coerce") <= cutoff
return data.loc[:, mask]
def get_YFin_data_online( def get_YFin_data_online(
symbol: Annotated[str, "ticker symbol of the company"], symbol: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"], start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
@ -365,9 +380,7 @@ def get_balance_sheet(
data = yf_retry(lambda: ticker_obj.balance_sheet) data = yf_retry(lambda: ticker_obj.balance_sheet)
# Filter out fiscal periods after curr_date to prevent look-ahead bias. # Filter out fiscal periods after curr_date to prevent look-ahead bias.
if curr_date and not data.empty: data = _filter_financials_by_date(data, curr_date)
cutoff = pd.Timestamp(curr_date)
data = data.loc[:, [c for c in data.columns if pd.Timestamp(c) <= cutoff]]
if data.empty: if data.empty:
return f"No balance sheet data found for symbol '{ticker}'" + ( return f"No balance sheet data found for symbol '{ticker}'" + (
@ -402,9 +415,7 @@ def get_cashflow(
data = yf_retry(lambda: ticker_obj.cashflow) data = yf_retry(lambda: ticker_obj.cashflow)
# Filter out fiscal periods after curr_date to prevent look-ahead bias. # Filter out fiscal periods after curr_date to prevent look-ahead bias.
if curr_date and not data.empty: data = _filter_financials_by_date(data, curr_date)
cutoff = pd.Timestamp(curr_date)
data = data.loc[:, [c for c in data.columns if pd.Timestamp(c) <= cutoff]]
if data.empty: if data.empty:
return f"No cash flow data found for symbol '{ticker}'" + ( return f"No cash flow data found for symbol '{ticker}'" + (
@ -439,9 +450,7 @@ def get_income_statement(
data = yf_retry(lambda: ticker_obj.income_stmt) data = yf_retry(lambda: ticker_obj.income_stmt)
# Filter out fiscal periods after curr_date to prevent look-ahead bias. # Filter out fiscal periods after curr_date to prevent look-ahead bias.
if curr_date and not data.empty: data = _filter_financials_by_date(data, curr_date)
cutoff = pd.Timestamp(curr_date)
data = data.loc[:, [c for c in data.columns if pd.Timestamp(c) <= cutoff]]
if data.empty: if data.empty:
return f"No income statement data found for symbol '{ticker}'" + ( return f"No income statement data found for symbol '{ticker}'" + (