diff --git a/tradingagents/dataflows/alpha_vantage_fundamentals.py b/tradingagents/dataflows/alpha_vantage_fundamentals.py index 3f401337..8e0e9768 100644 --- a/tradingagents/dataflows/alpha_vantage_fundamentals.py +++ b/tradingagents/dataflows/alpha_vantage_fundamentals.py @@ -1,6 +1,24 @@ +from typing import Union + from .alpha_vantage_common import _make_api_request +def _filter_reports_by_date(result: dict, curr_date: str) -> dict: + """Remove annualReports/quarterlyReports whose fiscalDateEnding exceeds curr_date. + + Mutates *result* in-place and returns it so callers can chain the call. + """ + if not (curr_date and isinstance(result, dict)): + return result + for key in ("annualReports", "quarterlyReports"): + if key in result: + result[key] = [ + r for r in result[key] + if r.get("fiscalDateEnding", "") <= curr_date + ] + return result + + def get_fundamentals(ticker: str, curr_date: str = None) -> str: """ Retrieve comprehensive fundamental data for a given ticker symbol using Alpha Vantage. @@ -19,7 +37,9 @@ def get_fundamentals(ticker: str, curr_date: str = None) -> str: return _make_api_request("OVERVIEW", params) -def get_balance_sheet(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str: +def get_balance_sheet( + ticker: str, freq: str = "quarterly", curr_date: str = None +) -> Union[dict, str]: """ Retrieve balance sheet data for a given ticker symbol using Alpha Vantage. @@ -29,23 +49,19 @@ def get_balance_sheet(ticker: str, freq: str = "quarterly", curr_date: str = Non curr_date (str): Current date you are trading at, yyyy-mm-dd Returns: - str: Balance sheet data with normalized fields + dict | str: Balance sheet data dict, or an error string. """ params = { "symbol": ticker, } result = _make_api_request("BALANCE_SHEET", params) - # Filter out reports whose fiscalDateEnding is after curr_date to prevent look-ahead bias. - if curr_date and isinstance(result, dict): - for key in ("annualReports", "quarterlyReports"): - if key in result: - result[key] = [r for r in result[key] - if r.get("fiscalDateEnding", "") <= curr_date] - return result + return _filter_reports_by_date(result, curr_date) -def get_cashflow(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str: +def get_cashflow( + ticker: str, freq: str = "quarterly", curr_date: str = None +) -> Union[dict, str]: """ Retrieve cash flow statement data for a given ticker symbol using Alpha Vantage. @@ -55,23 +71,19 @@ def get_cashflow(ticker: str, freq: str = "quarterly", curr_date: str = None) -> curr_date (str): Current date you are trading at, yyyy-mm-dd Returns: - str: Cash flow statement data with normalized fields + dict | str: Cash flow statement data dict, or an error string. """ params = { "symbol": ticker, } result = _make_api_request("CASH_FLOW", params) - # Filter out reports whose fiscalDateEnding is after curr_date to prevent look-ahead bias. - if curr_date and isinstance(result, dict): - for key in ("annualReports", "quarterlyReports"): - if key in result: - result[key] = [r for r in result[key] - if r.get("fiscalDateEnding", "") <= curr_date] - return result + return _filter_reports_by_date(result, curr_date) -def get_income_statement(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str: +def get_income_statement( + ticker: str, freq: str = "quarterly", curr_date: str = None +) -> Union[dict, str]: """ Retrieve income statement data for a given ticker symbol using Alpha Vantage. @@ -81,18 +93,11 @@ def get_income_statement(ticker: str, freq: str = "quarterly", curr_date: str = curr_date (str): Current date you are trading at, yyyy-mm-dd Returns: - str: Income statement data with normalized fields + dict | str: Income statement data dict, or an error string. """ params = { "symbol": ticker, } result = _make_api_request("INCOME_STATEMENT", params) - # Filter out reports whose fiscalDateEnding is after curr_date to prevent look-ahead bias. - if curr_date and isinstance(result, dict): - for key in ("annualReports", "quarterlyReports"): - if key in result: - result[key] = [r for r in result[key] - if r.get("fiscalDateEnding", "") <= curr_date] - return result - + return _filter_reports_by_date(result, curr_date) diff --git a/tradingagents/dataflows/y_finance.py b/tradingagents/dataflows/y_finance.py index 8f6d6680..cd35c165 100644 --- a/tradingagents/dataflows/y_finance.py +++ b/tradingagents/dataflows/y_finance.py @@ -1,10 +1,25 @@ from typing import Annotated from datetime import datetime from dateutil.relativedelta import relativedelta +import pandas as pd import yfinance as yf import os from .stockstats_utils import StockstatsUtils, _clean_dataframe, yf_retry + +def _filter_financials_by_date(data: "pd.DataFrame", curr_date: str) -> "pd.DataFrame": + """Drop DataFrame columns (fiscal period timestamps) that exceed curr_date. + + yfinance financial statements are indexed by metric name (rows) and fiscal + period end date (columns). Columns that post-date the simulation's current + date represent future data and must be removed to prevent look-ahead bias. + """ + if not curr_date or data.empty: + return data + cutoff = pd.Timestamp(curr_date) + mask = pd.to_datetime(data.columns, errors="coerce") <= cutoff + return data.loc[:, mask] + def get_YFin_data_online( symbol: Annotated[str, "ticker symbol of the company"], start_date: Annotated[str, "Start date in yyyy-mm-dd format"], @@ -365,9 +380,7 @@ def get_balance_sheet( data = yf_retry(lambda: ticker_obj.balance_sheet) # Filter out fiscal periods after curr_date to prevent look-ahead bias. - if curr_date and not data.empty: - cutoff = pd.Timestamp(curr_date) - data = data.loc[:, [c for c in data.columns if pd.Timestamp(c) <= cutoff]] + data = _filter_financials_by_date(data, curr_date) if data.empty: return f"No balance sheet data found for symbol '{ticker}'" + ( @@ -402,9 +415,7 @@ def get_cashflow( data = yf_retry(lambda: ticker_obj.cashflow) # Filter out fiscal periods after curr_date to prevent look-ahead bias. - if curr_date and not data.empty: - cutoff = pd.Timestamp(curr_date) - data = data.loc[:, [c for c in data.columns if pd.Timestamp(c) <= cutoff]] + data = _filter_financials_by_date(data, curr_date) if data.empty: return f"No cash flow data found for symbol '{ticker}'" + ( @@ -439,9 +450,7 @@ def get_income_statement( data = yf_retry(lambda: ticker_obj.income_stmt) # Filter out fiscal periods after curr_date to prevent look-ahead bias. - if curr_date and not data.empty: - cutoff = pd.Timestamp(curr_date) - data = data.loc[:, [c for c in data.columns if pd.Timestamp(c) <= cutoff]] + data = _filter_financials_by_date(data, curr_date) if data.empty: return f"No income statement data found for symbol '{ticker}'" + (