refactor: address code review feedback on look-ahead bias fix
- Extract repeated fiscal-date filtering into _filter_reports_by_date() helper in alpha_vantage_fundamentals.py (was duplicated 3×)
- Fix return type annotations: get_balance_sheet/get_cashflow/get_income_statement now declare Union[dict, str] instead of str, matching the actual return value from _make_api_request
- Extract pandas column filtering into _filter_financials_by_date() helper in y_finance.py (was duplicated 3×); uses pd.to_datetime + vectorised boolean mask instead of a Python list comprehension

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
abd13c0153
commit
e87ba8f29d
|
|
@ -1,6 +1,24 @@
|
|||
from typing import Union
|
||||
|
||||
from .alpha_vantage_common import _make_api_request
|
||||
|
||||
|
||||
def _filter_reports_by_date(result: Union[dict, str], curr_date: str) -> Union[dict, str]:
    """Remove annualReports/quarterlyReports whose fiscalDateEnding exceeds curr_date.

    Mutates *result* in-place and returns it so callers can chain the call.

    Args:
        result: Parsed API response. Anything that is not a dict (for example
            an error string) is returned untouched.
        curr_date: Cutoff date as a ``yyyy-mm-dd`` string. When it is empty
            or ``None`` no filtering is performed.

    Returns:
        The same *result* object. When it is a dict, the ``annualReports``
        and ``quarterlyReports`` lists (if present) keep only the entries
        whose ``fiscalDateEnding`` is on or before *curr_date*.
    """
    if not (curr_date and isinstance(result, dict)):
        return result

    for key in ("annualReports", "quarterlyReports"):
        reports = result.get(key)
        # A missing key, or a null/malformed value, is left exactly as it is
        # instead of crashing the comprehension below.
        if not isinstance(reports, list):
            continue
        # Dates are compared as plain strings, which is chronological only
        # for zero-padded ISO dates (assumes fiscalDateEnding uses the same
        # yyyy-mm-dd layout as curr_date -- TODO confirm against the API).
        # A missing *or null* fiscalDateEnding is treated as "" so the report
        # is kept, rather than raising TypeError on ``None <= str``.
        result[key] = [
            report for report in reports
            if (report.get("fiscalDateEnding") or "") <= curr_date
        ]
    return result
|
||||
|
||||
|
||||
def get_fundamentals(ticker: str, curr_date: str = None) -> str:
|
||||
"""
|
||||
Retrieve comprehensive fundamental data for a given ticker symbol using Alpha Vantage.
|
||||
|
|
@ -19,7 +37,9 @@ def get_fundamentals(ticker: str, curr_date: str = None) -> str:
|
|||
return _make_api_request("OVERVIEW", params)
|
||||
|
||||
|
||||
def get_balance_sheet(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str:
|
||||
def get_balance_sheet(
|
||||
ticker: str, freq: str = "quarterly", curr_date: str = None
|
||||
) -> Union[dict, str]:
|
||||
"""
|
||||
Retrieve balance sheet data for a given ticker symbol using Alpha Vantage.
|
||||
|
||||
|
|
@ -29,23 +49,19 @@ def get_balance_sheet(ticker: str, freq: str = "quarterly", curr_date: str = Non
|
|||
curr_date (str): Current date you are trading at, yyyy-mm-dd
|
||||
|
||||
Returns:
|
||||
str: Balance sheet data with normalized fields
|
||||
dict | str: Balance sheet data dict, or an error string.
|
||||
"""
|
||||
params = {
|
||||
"symbol": ticker,
|
||||
}
|
||||
|
||||
result = _make_api_request("BALANCE_SHEET", params)
|
||||
# Filter out reports whose fiscalDateEnding is after curr_date to prevent look-ahead bias.
|
||||
if curr_date and isinstance(result, dict):
|
||||
for key in ("annualReports", "quarterlyReports"):
|
||||
if key in result:
|
||||
result[key] = [r for r in result[key]
|
||||
if r.get("fiscalDateEnding", "") <= curr_date]
|
||||
return result
|
||||
return _filter_reports_by_date(result, curr_date)
|
||||
|
||||
|
||||
def get_cashflow(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str:
|
||||
def get_cashflow(
|
||||
ticker: str, freq: str = "quarterly", curr_date: str = None
|
||||
) -> Union[dict, str]:
|
||||
"""
|
||||
Retrieve cash flow statement data for a given ticker symbol using Alpha Vantage.
|
||||
|
||||
|
|
@ -55,23 +71,19 @@ def get_cashflow(ticker: str, freq: str = "quarterly", curr_date: str = None) ->
|
|||
curr_date (str): Current date you are trading at, yyyy-mm-dd
|
||||
|
||||
Returns:
|
||||
str: Cash flow statement data with normalized fields
|
||||
dict | str: Cash flow statement data dict, or an error string.
|
||||
"""
|
||||
params = {
|
||||
"symbol": ticker,
|
||||
}
|
||||
|
||||
result = _make_api_request("CASH_FLOW", params)
|
||||
# Filter out reports whose fiscalDateEnding is after curr_date to prevent look-ahead bias.
|
||||
if curr_date and isinstance(result, dict):
|
||||
for key in ("annualReports", "quarterlyReports"):
|
||||
if key in result:
|
||||
result[key] = [r for r in result[key]
|
||||
if r.get("fiscalDateEnding", "") <= curr_date]
|
||||
return result
|
||||
return _filter_reports_by_date(result, curr_date)
|
||||
|
||||
|
||||
def get_income_statement(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str:
|
||||
def get_income_statement(
|
||||
ticker: str, freq: str = "quarterly", curr_date: str = None
|
||||
) -> Union[dict, str]:
|
||||
"""
|
||||
Retrieve income statement data for a given ticker symbol using Alpha Vantage.
|
||||
|
||||
|
|
@ -81,18 +93,11 @@ def get_income_statement(ticker: str, freq: str = "quarterly", curr_date: str =
|
|||
curr_date (str): Current date you are trading at, yyyy-mm-dd
|
||||
|
||||
Returns:
|
||||
str: Income statement data with normalized fields
|
||||
dict | str: Income statement data dict, or an error string.
|
||||
"""
|
||||
params = {
|
||||
"symbol": ticker,
|
||||
}
|
||||
|
||||
result = _make_api_request("INCOME_STATEMENT", params)
|
||||
# Filter out reports whose fiscalDateEnding is after curr_date to prevent look-ahead bias.
|
||||
if curr_date and isinstance(result, dict):
|
||||
for key in ("annualReports", "quarterlyReports"):
|
||||
if key in result:
|
||||
result[key] = [r for r in result[key]
|
||||
if r.get("fiscalDateEnding", "") <= curr_date]
|
||||
return result
|
||||
|
||||
return _filter_reports_by_date(result, curr_date)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,25 @@
|
|||
from typing import Annotated
|
||||
from datetime import datetime
|
||||
from dateutil.relativedelta import relativedelta
|
||||
import pandas as pd
|
||||
import yfinance as yf
|
||||
import os
|
||||
from .stockstats_utils import StockstatsUtils, _clean_dataframe, yf_retry
|
||||
|
||||
|
||||
def _filter_financials_by_date(data: "pd.DataFrame", curr_date: str) -> "pd.DataFrame":
    """Keep only the fiscal-period columns dated on or before curr_date.

    yfinance financial statements put metric names on the rows and fiscal
    period end dates on the columns, so any column later than the
    simulation's current date is future information and is dropped to
    prevent look-ahead bias.

    Column labels that cannot be parsed as dates are coerced to NaT, which
    never compares as ``<=`` the cutoff, so those columns are dropped too.
    When curr_date is empty/None or the frame is empty, *data* is returned
    unchanged.
    """
    if curr_date and not data.empty:
        latest_allowed = pd.Timestamp(curr_date)
        period_ends = pd.to_datetime(data.columns, errors="coerce")
        return data.loc[:, period_ends <= latest_allowed]
    return data
|
||||
|
||||
def get_YFin_data_online(
|
||||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||
|
|
@ -365,9 +380,7 @@ def get_balance_sheet(
|
|||
data = yf_retry(lambda: ticker_obj.balance_sheet)
|
||||
|
||||
# Filter out fiscal periods after curr_date to prevent look-ahead bias.
|
||||
if curr_date and not data.empty:
|
||||
cutoff = pd.Timestamp(curr_date)
|
||||
data = data.loc[:, [c for c in data.columns if pd.Timestamp(c) <= cutoff]]
|
||||
data = _filter_financials_by_date(data, curr_date)
|
||||
|
||||
if data.empty:
|
||||
return f"No balance sheet data found for symbol '{ticker}'" + (
|
||||
|
|
@ -402,9 +415,7 @@ def get_cashflow(
|
|||
data = yf_retry(lambda: ticker_obj.cashflow)
|
||||
|
||||
# Filter out fiscal periods after curr_date to prevent look-ahead bias.
|
||||
if curr_date and not data.empty:
|
||||
cutoff = pd.Timestamp(curr_date)
|
||||
data = data.loc[:, [c for c in data.columns if pd.Timestamp(c) <= cutoff]]
|
||||
data = _filter_financials_by_date(data, curr_date)
|
||||
|
||||
if data.empty:
|
||||
return f"No cash flow data found for symbol '{ticker}'" + (
|
||||
|
|
@ -439,9 +450,7 @@ def get_income_statement(
|
|||
data = yf_retry(lambda: ticker_obj.income_stmt)
|
||||
|
||||
# Filter out fiscal periods after curr_date to prevent look-ahead bias.
|
||||
if curr_date and not data.empty:
|
||||
cutoff = pd.Timestamp(curr_date)
|
||||
data = data.loc[:, [c for c in data.columns if pd.Timestamp(c) <= cutoff]]
|
||||
data = _filter_financials_by_date(data, curr_date)
|
||||
|
||||
if data.empty:
|
||||
return f"No income statement data found for symbol '{ticker}'" + (
|
||||
|
|
|
|||
Loading…
Reference in New Issue