refactor: address code review feedback on look-ahead bias fix
- Extract repeated fiscal-date filtering into _filter_reports_by_date() helper in alpha_vantage_fundamentals.py (was duplicated 3×) - Fix return type annotations: get_balance_sheet/get_cashflow/ get_income_statement now declare Union[dict, str] instead of str, matching the actual return value from _make_api_request - Extract pandas column filtering into _filter_financials_by_date() helper in y_finance.py (was duplicated 3×); uses pd.to_datetime + vectorised boolean mask instead of a Python list comprehension Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
abd13c0153
commit
e87ba8f29d
|
|
@ -1,6 +1,24 @@
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
from .alpha_vantage_common import _make_api_request
|
from .alpha_vantage_common import _make_api_request
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_reports_by_date(result: dict, curr_date: str) -> dict:
|
||||||
|
"""Remove annualReports/quarterlyReports whose fiscalDateEnding exceeds curr_date.
|
||||||
|
|
||||||
|
Mutates *result* in-place and returns it so callers can chain the call.
|
||||||
|
"""
|
||||||
|
if not (curr_date and isinstance(result, dict)):
|
||||||
|
return result
|
||||||
|
for key in ("annualReports", "quarterlyReports"):
|
||||||
|
if key in result:
|
||||||
|
result[key] = [
|
||||||
|
r for r in result[key]
|
||||||
|
if r.get("fiscalDateEnding", "") <= curr_date
|
||||||
|
]
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def get_fundamentals(ticker: str, curr_date: str = None) -> str:
|
def get_fundamentals(ticker: str, curr_date: str = None) -> str:
|
||||||
"""
|
"""
|
||||||
Retrieve comprehensive fundamental data for a given ticker symbol using Alpha Vantage.
|
Retrieve comprehensive fundamental data for a given ticker symbol using Alpha Vantage.
|
||||||
|
|
@ -19,7 +37,9 @@ def get_fundamentals(ticker: str, curr_date: str = None) -> str:
|
||||||
return _make_api_request("OVERVIEW", params)
|
return _make_api_request("OVERVIEW", params)
|
||||||
|
|
||||||
|
|
||||||
def get_balance_sheet(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str:
|
def get_balance_sheet(
|
||||||
|
ticker: str, freq: str = "quarterly", curr_date: str = None
|
||||||
|
) -> Union[dict, str]:
|
||||||
"""
|
"""
|
||||||
Retrieve balance sheet data for a given ticker symbol using Alpha Vantage.
|
Retrieve balance sheet data for a given ticker symbol using Alpha Vantage.
|
||||||
|
|
||||||
|
|
@ -29,23 +49,19 @@ def get_balance_sheet(ticker: str, freq: str = "quarterly", curr_date: str = Non
|
||||||
curr_date (str): Current date you are trading at, yyyy-mm-dd
|
curr_date (str): Current date you are trading at, yyyy-mm-dd
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: Balance sheet data with normalized fields
|
dict | str: Balance sheet data dict, or an error string.
|
||||||
"""
|
"""
|
||||||
params = {
|
params = {
|
||||||
"symbol": ticker,
|
"symbol": ticker,
|
||||||
}
|
}
|
||||||
|
|
||||||
result = _make_api_request("BALANCE_SHEET", params)
|
result = _make_api_request("BALANCE_SHEET", params)
|
||||||
# Filter out reports whose fiscalDateEnding is after curr_date to prevent look-ahead bias.
|
return _filter_reports_by_date(result, curr_date)
|
||||||
if curr_date and isinstance(result, dict):
|
|
||||||
for key in ("annualReports", "quarterlyReports"):
|
|
||||||
if key in result:
|
|
||||||
result[key] = [r for r in result[key]
|
|
||||||
if r.get("fiscalDateEnding", "") <= curr_date]
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def get_cashflow(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str:
|
def get_cashflow(
|
||||||
|
ticker: str, freq: str = "quarterly", curr_date: str = None
|
||||||
|
) -> Union[dict, str]:
|
||||||
"""
|
"""
|
||||||
Retrieve cash flow statement data for a given ticker symbol using Alpha Vantage.
|
Retrieve cash flow statement data for a given ticker symbol using Alpha Vantage.
|
||||||
|
|
||||||
|
|
@ -55,23 +71,19 @@ def get_cashflow(ticker: str, freq: str = "quarterly", curr_date: str = None) ->
|
||||||
curr_date (str): Current date you are trading at, yyyy-mm-dd
|
curr_date (str): Current date you are trading at, yyyy-mm-dd
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: Cash flow statement data with normalized fields
|
dict | str: Cash flow statement data dict, or an error string.
|
||||||
"""
|
"""
|
||||||
params = {
|
params = {
|
||||||
"symbol": ticker,
|
"symbol": ticker,
|
||||||
}
|
}
|
||||||
|
|
||||||
result = _make_api_request("CASH_FLOW", params)
|
result = _make_api_request("CASH_FLOW", params)
|
||||||
# Filter out reports whose fiscalDateEnding is after curr_date to prevent look-ahead bias.
|
return _filter_reports_by_date(result, curr_date)
|
||||||
if curr_date and isinstance(result, dict):
|
|
||||||
for key in ("annualReports", "quarterlyReports"):
|
|
||||||
if key in result:
|
|
||||||
result[key] = [r for r in result[key]
|
|
||||||
if r.get("fiscalDateEnding", "") <= curr_date]
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def get_income_statement(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str:
|
def get_income_statement(
|
||||||
|
ticker: str, freq: str = "quarterly", curr_date: str = None
|
||||||
|
) -> Union[dict, str]:
|
||||||
"""
|
"""
|
||||||
Retrieve income statement data for a given ticker symbol using Alpha Vantage.
|
Retrieve income statement data for a given ticker symbol using Alpha Vantage.
|
||||||
|
|
||||||
|
|
@ -81,18 +93,11 @@ def get_income_statement(ticker: str, freq: str = "quarterly", curr_date: str =
|
||||||
curr_date (str): Current date you are trading at, yyyy-mm-dd
|
curr_date (str): Current date you are trading at, yyyy-mm-dd
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: Income statement data with normalized fields
|
dict | str: Income statement data dict, or an error string.
|
||||||
"""
|
"""
|
||||||
params = {
|
params = {
|
||||||
"symbol": ticker,
|
"symbol": ticker,
|
||||||
}
|
}
|
||||||
|
|
||||||
result = _make_api_request("INCOME_STATEMENT", params)
|
result = _make_api_request("INCOME_STATEMENT", params)
|
||||||
# Filter out reports whose fiscalDateEnding is after curr_date to prevent look-ahead bias.
|
return _filter_reports_by_date(result, curr_date)
|
||||||
if curr_date and isinstance(result, dict):
|
|
||||||
for key in ("annualReports", "quarterlyReports"):
|
|
||||||
if key in result:
|
|
||||||
result[key] = [r for r in result[key]
|
|
||||||
if r.get("fiscalDateEnding", "") <= curr_date]
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,25 @@
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from dateutil.relativedelta import relativedelta
|
from dateutil.relativedelta import relativedelta
|
||||||
|
import pandas as pd
|
||||||
import yfinance as yf
|
import yfinance as yf
|
||||||
import os
|
import os
|
||||||
from .stockstats_utils import StockstatsUtils, _clean_dataframe, yf_retry
|
from .stockstats_utils import StockstatsUtils, _clean_dataframe, yf_retry
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_financials_by_date(data: "pd.DataFrame", curr_date: str) -> "pd.DataFrame":
|
||||||
|
"""Drop DataFrame columns (fiscal period timestamps) that exceed curr_date.
|
||||||
|
|
||||||
|
yfinance financial statements are indexed by metric name (rows) and fiscal
|
||||||
|
period end date (columns). Columns that post-date the simulation's current
|
||||||
|
date represent future data and must be removed to prevent look-ahead bias.
|
||||||
|
"""
|
||||||
|
if not curr_date or data.empty:
|
||||||
|
return data
|
||||||
|
cutoff = pd.Timestamp(curr_date)
|
||||||
|
mask = pd.to_datetime(data.columns, errors="coerce") <= cutoff
|
||||||
|
return data.loc[:, mask]
|
||||||
|
|
||||||
def get_YFin_data_online(
|
def get_YFin_data_online(
|
||||||
symbol: Annotated[str, "ticker symbol of the company"],
|
symbol: Annotated[str, "ticker symbol of the company"],
|
||||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||||
|
|
@ -365,9 +380,7 @@ def get_balance_sheet(
|
||||||
data = yf_retry(lambda: ticker_obj.balance_sheet)
|
data = yf_retry(lambda: ticker_obj.balance_sheet)
|
||||||
|
|
||||||
# Filter out fiscal periods after curr_date to prevent look-ahead bias.
|
# Filter out fiscal periods after curr_date to prevent look-ahead bias.
|
||||||
if curr_date and not data.empty:
|
data = _filter_financials_by_date(data, curr_date)
|
||||||
cutoff = pd.Timestamp(curr_date)
|
|
||||||
data = data.loc[:, [c for c in data.columns if pd.Timestamp(c) <= cutoff]]
|
|
||||||
|
|
||||||
if data.empty:
|
if data.empty:
|
||||||
return f"No balance sheet data found for symbol '{ticker}'" + (
|
return f"No balance sheet data found for symbol '{ticker}'" + (
|
||||||
|
|
@ -402,9 +415,7 @@ def get_cashflow(
|
||||||
data = yf_retry(lambda: ticker_obj.cashflow)
|
data = yf_retry(lambda: ticker_obj.cashflow)
|
||||||
|
|
||||||
# Filter out fiscal periods after curr_date to prevent look-ahead bias.
|
# Filter out fiscal periods after curr_date to prevent look-ahead bias.
|
||||||
if curr_date and not data.empty:
|
data = _filter_financials_by_date(data, curr_date)
|
||||||
cutoff = pd.Timestamp(curr_date)
|
|
||||||
data = data.loc[:, [c for c in data.columns if pd.Timestamp(c) <= cutoff]]
|
|
||||||
|
|
||||||
if data.empty:
|
if data.empty:
|
||||||
return f"No cash flow data found for symbol '{ticker}'" + (
|
return f"No cash flow data found for symbol '{ticker}'" + (
|
||||||
|
|
@ -439,9 +450,7 @@ def get_income_statement(
|
||||||
data = yf_retry(lambda: ticker_obj.income_stmt)
|
data = yf_retry(lambda: ticker_obj.income_stmt)
|
||||||
|
|
||||||
# Filter out fiscal periods after curr_date to prevent look-ahead bias.
|
# Filter out fiscal periods after curr_date to prevent look-ahead bias.
|
||||||
if curr_date and not data.empty:
|
data = _filter_financials_by_date(data, curr_date)
|
||||||
cutoff = pd.Timestamp(curr_date)
|
|
||||||
data = data.loc[:, [c for c in data.columns if pd.Timestamp(c) <= cutoff]]
|
|
||||||
|
|
||||||
if data.empty:
|
if data.empty:
|
||||||
return f"No income statement data found for symbol '{ticker}'" + (
|
return f"No income statement data found for symbol '{ticker}'" + (
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue