diff --git a/tests/test_alpha_vantage_fundamentals.py b/tests/test_alpha_vantage_fundamentals.py new file mode 100644 index 00000000..4e090953 --- /dev/null +++ b/tests/test_alpha_vantage_fundamentals.py @@ -0,0 +1,45 @@ +import json +import unittest +from unittest.mock import patch + +from tradingagents.dataflows import alpha_vantage_fundamentals as fundamentals + + +class AlphaVantageFundamentalsTests(unittest.TestCase): + def test_curr_date_filters_future_reports_from_json_response(self): + payload = json.dumps( + { + "annualReports": [ + {"fiscalDateEnding": "2023-12-31", "totalAssets": "100"}, + {"fiscalDateEnding": "2024-12-31", "totalAssets": "200"}, + ], + "quarterlyReports": [ + {"fiscalDateEnding": "2024-03-31", "totalAssets": "110"}, + {"fiscalDateEnding": "2024-06-30", "totalAssets": "120"}, + ], + } + ) + + funcs = [ + fundamentals.get_balance_sheet, + fundamentals.get_cashflow, + fundamentals.get_income_statement, + ] + + for func in funcs: + with self.subTest(func=func.__name__): + with patch.object(fundamentals, "_make_api_request", return_value=payload): + result = json.loads(func("AAPL", curr_date="2024-03-31")) + + self.assertEqual( + result["annualReports"], + [{"fiscalDateEnding": "2023-12-31", "totalAssets": "100"}], + ) + self.assertEqual( + result["quarterlyReports"], + [{"fiscalDateEnding": "2024-03-31", "totalAssets": "110"}], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tradingagents/dataflows/alpha_vantage_fundamentals.py b/tradingagents/dataflows/alpha_vantage_fundamentals.py index 8b92faa6..cb53fa03 100644 --- a/tradingagents/dataflows/alpha_vantage_fundamentals.py +++ b/tradingagents/dataflows/alpha_vantage_fundamentals.py @@ -1,6 +1,38 @@ +import json + from .alpha_vantage_common import _make_api_request +def _filter_reports_by_date(result, curr_date: str): + """Filter annualReports/quarterlyReports to exclude entries after curr_date.""" + if not curr_date or not result: + return result + + serialized = False + if isinstance(result, str): + try: + result = json.loads(result) + serialized = True + except json.JSONDecodeError: + return result + + if not isinstance(result, dict): + return result + + for key in ("annualReports", "quarterlyReports"): + if key in result: + result[key] = [ + report + for report in result[key] + if report.get("fiscalDateEnding", "") <= curr_date + ] + + if serialized: + return json.dumps(result) + + return result + + def get_fundamentals(ticker: str, curr_date: str = None) -> str: """ Retrieve comprehensive fundamental data for a given ticker symbol using Alpha Vantage. @@ -35,7 +67,7 @@ def get_balance_sheet(ticker: str, freq: str = "quarterly", curr_date: str = Non "symbol": ticker, } - return _make_api_request("BALANCE_SHEET", params) + return _filter_reports_by_date(_make_api_request("BALANCE_SHEET", params), curr_date) def get_cashflow(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str: @@ -54,7 +86,7 @@ def get_cashflow(ticker: str, freq: str = "quarterly", curr_date: str = None) -> "symbol": ticker, } - return _make_api_request("CASH_FLOW", params) + return _filter_reports_by_date(_make_api_request("CASH_FLOW", params), curr_date) def get_income_statement(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str: @@ -73,5 +105,4 @@ def get_income_statement(ticker: str, freq: str = "quarterly", curr_date: str = "symbol": ticker, } - return _make_api_request("INCOME_STATEMENT", params) - + return _filter_reports_by_date(_make_api_request("INCOME_STATEMENT", params), curr_date)