refactor: address code review feedback on look-ahead bias fix

- Extract repeated fiscal-date filtering into _filter_reports_by_date()
  helper in alpha_vantage_fundamentals.py (was duplicated 3×)
- Fix return type annotations: get_balance_sheet/get_cashflow/
  get_income_statement now declare Union[dict, str] instead of str,
  matching the actual return value from _make_api_request
- Extract pandas column filtering into _filter_financials_by_date()
  helper in y_finance.py (was duplicated 3×); uses pd.to_datetime +
  vectorised boolean mask instead of a Python list comprehension

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Claude Lab 2026-03-29 07:51:10 +02:00
parent abd13c0153
commit e87ba8f29d
2 changed files with 51 additions and 37 deletions

View File

@ -1,6 +1,24 @@
from typing import Union
from .alpha_vantage_common import _make_api_request
def _filter_reports_by_date(result: dict, curr_date: str) -> dict:
"""Remove annualReports/quarterlyReports whose fiscalDateEnding exceeds curr_date.
Mutates *result* in-place and returns it so callers can chain the call.
"""
if not (curr_date and isinstance(result, dict)):
return result
for key in ("annualReports", "quarterlyReports"):
if key in result:
result[key] = [
r for r in result[key]
if r.get("fiscalDateEnding", "") <= curr_date
]
return result
def get_fundamentals(ticker: str, curr_date: str = None) -> str:
"""
Retrieve comprehensive fundamental data for a given ticker symbol using Alpha Vantage.
@ -19,7 +37,9 @@ def get_fundamentals(ticker: str, curr_date: str = None) -> str:
return _make_api_request("OVERVIEW", params)
def get_balance_sheet(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str:
def get_balance_sheet(
ticker: str, freq: str = "quarterly", curr_date: str = None
) -> Union[dict, str]:
"""
Retrieve balance sheet data for a given ticker symbol using Alpha Vantage.
@ -29,23 +49,19 @@ def get_balance_sheet(ticker: str, freq: str = "quarterly", curr_date: str = Non
curr_date (str): Current date you are trading at, yyyy-mm-dd
Returns:
str: Balance sheet data with normalized fields
dict | str: Balance sheet data dict, or an error string.
"""
params = {
"symbol": ticker,
}
result = _make_api_request("BALANCE_SHEET", params)
# Filter out reports whose fiscalDateEnding is after curr_date to prevent look-ahead bias.
if curr_date and isinstance(result, dict):
for key in ("annualReports", "quarterlyReports"):
if key in result:
result[key] = [r for r in result[key]
if r.get("fiscalDateEnding", "") <= curr_date]
return result
return _filter_reports_by_date(result, curr_date)
def get_cashflow(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str:
def get_cashflow(
ticker: str, freq: str = "quarterly", curr_date: str = None
) -> Union[dict, str]:
"""
Retrieve cash flow statement data for a given ticker symbol using Alpha Vantage.
@ -55,23 +71,19 @@ def get_cashflow(ticker: str, freq: str = "quarterly", curr_date: str = None) ->
curr_date (str): Current date you are trading at, yyyy-mm-dd
Returns:
str: Cash flow statement data with normalized fields
dict | str: Cash flow statement data dict, or an error string.
"""
params = {
"symbol": ticker,
}
result = _make_api_request("CASH_FLOW", params)
# Filter out reports whose fiscalDateEnding is after curr_date to prevent look-ahead bias.
if curr_date and isinstance(result, dict):
for key in ("annualReports", "quarterlyReports"):
if key in result:
result[key] = [r for r in result[key]
if r.get("fiscalDateEnding", "") <= curr_date]
return result
return _filter_reports_by_date(result, curr_date)
def get_income_statement(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str:
def get_income_statement(
ticker: str, freq: str = "quarterly", curr_date: str = None
) -> Union[dict, str]:
"""
Retrieve income statement data for a given ticker symbol using Alpha Vantage.
@ -81,18 +93,11 @@ def get_income_statement(ticker: str, freq: str = "quarterly", curr_date: str =
curr_date (str): Current date you are trading at, yyyy-mm-dd
Returns:
str: Income statement data with normalized fields
dict | str: Income statement data dict, or an error string.
"""
params = {
"symbol": ticker,
}
result = _make_api_request("INCOME_STATEMENT", params)
# Filter out reports whose fiscalDateEnding is after curr_date to prevent look-ahead bias.
if curr_date and isinstance(result, dict):
for key in ("annualReports", "quarterlyReports"):
if key in result:
result[key] = [r for r in result[key]
if r.get("fiscalDateEnding", "") <= curr_date]
return result
return _filter_reports_by_date(result, curr_date)

View File

@ -1,10 +1,25 @@
from typing import Annotated
from datetime import datetime
from dateutil.relativedelta import relativedelta
import pandas as pd
import yfinance as yf
import os
from .stockstats_utils import StockstatsUtils, _clean_dataframe, yf_retry
def _filter_financials_by_date(data: "pd.DataFrame", curr_date: str) -> "pd.DataFrame":
"""Drop DataFrame columns (fiscal period timestamps) that exceed curr_date.
yfinance financial statements are indexed by metric name (rows) and fiscal
period end date (columns). Columns that post-date the simulation's current
date represent future data and must be removed to prevent look-ahead bias.
"""
if not curr_date or data.empty:
return data
cutoff = pd.Timestamp(curr_date)
mask = pd.to_datetime(data.columns, errors="coerce") <= cutoff
return data.loc[:, mask]
def get_YFin_data_online(
symbol: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
@ -365,9 +380,7 @@ def get_balance_sheet(
data = yf_retry(lambda: ticker_obj.balance_sheet)
# Filter out fiscal periods after curr_date to prevent look-ahead bias.
if curr_date and not data.empty:
cutoff = pd.Timestamp(curr_date)
data = data.loc[:, [c for c in data.columns if pd.Timestamp(c) <= cutoff]]
data = _filter_financials_by_date(data, curr_date)
if data.empty:
return f"No balance sheet data found for symbol '{ticker}'" + (
@ -402,9 +415,7 @@ def get_cashflow(
data = yf_retry(lambda: ticker_obj.cashflow)
# Filter out fiscal periods after curr_date to prevent look-ahead bias.
if curr_date and not data.empty:
cutoff = pd.Timestamp(curr_date)
data = data.loc[:, [c for c in data.columns if pd.Timestamp(c) <= cutoff]]
data = _filter_financials_by_date(data, curr_date)
if data.empty:
return f"No cash flow data found for symbol '{ticker}'" + (
@ -439,9 +450,7 @@ def get_income_statement(
data = yf_retry(lambda: ticker_obj.income_stmt)
# Filter out fiscal periods after curr_date to prevent look-ahead bias.
if curr_date and not data.empty:
cutoff = pd.Timestamp(curr_date)
data = data.loc[:, [c for c in data.columns if pd.Timestamp(c) <= cutoff]]
data = _filter_financials_by_date(data, curr_date)
if data.empty:
return f"No income statement data found for symbol '{ticker}'" + (