Filter Alpha Vantage reports by curr_date
This commit is contained in:
parent
32be17c606
commit
9f9ef2568c
|
|
@ -0,0 +1,45 @@
|
|||
import json
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from tradingagents.dataflows import alpha_vantage_fundamentals as fundamentals
|
||||
|
||||
|
||||
class AlphaVantageFundamentalsTests(unittest.TestCase):
|
||||
def test_curr_date_filters_future_reports_from_json_response(self):
|
||||
payload = json.dumps(
|
||||
{
|
||||
"annualReports": [
|
||||
{"fiscalDateEnding": "2023-12-31", "totalAssets": "100"},
|
||||
{"fiscalDateEnding": "2024-12-31", "totalAssets": "200"},
|
||||
],
|
||||
"quarterlyReports": [
|
||||
{"fiscalDateEnding": "2024-03-31", "totalAssets": "110"},
|
||||
{"fiscalDateEnding": "2024-06-30", "totalAssets": "120"},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
funcs = [
|
||||
fundamentals.get_balance_sheet,
|
||||
fundamentals.get_cashflow,
|
||||
fundamentals.get_income_statement,
|
||||
]
|
||||
|
||||
for func in funcs:
|
||||
with self.subTest(func=func.__name__):
|
||||
with patch.object(fundamentals, "_make_api_request", return_value=payload):
|
||||
result = json.loads(func("AAPL", curr_date="2024-03-31"))
|
||||
|
||||
self.assertEqual(
|
||||
result["annualReports"],
|
||||
[{"fiscalDateEnding": "2023-12-31", "totalAssets": "100"}],
|
||||
)
|
||||
self.assertEqual(
|
||||
result["quarterlyReports"],
|
||||
[{"fiscalDateEnding": "2024-03-31", "totalAssets": "110"}],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -1,6 +1,38 @@
|
|||
import json
|
||||
|
||||
from .alpha_vantage_common import _make_api_request
|
||||
|
||||
|
||||
def _filter_reports_by_date(result, curr_date: str):
|
||||
"""Filter annualReports/quarterlyReports to exclude entries after curr_date."""
|
||||
if not curr_date or not result:
|
||||
return result
|
||||
|
||||
serialized = False
|
||||
if isinstance(result, str):
|
||||
try:
|
||||
result = json.loads(result)
|
||||
serialized = True
|
||||
except json.JSONDecodeError:
|
||||
return result
|
||||
|
||||
if not isinstance(result, dict):
|
||||
return result
|
||||
|
||||
for key in ("annualReports", "quarterlyReports"):
|
||||
if key in result:
|
||||
result[key] = [
|
||||
report
|
||||
for report in result[key]
|
||||
if report.get("fiscalDateEnding", "") <= curr_date
|
||||
]
|
||||
|
||||
if serialized:
|
||||
return json.dumps(result)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_fundamentals(ticker: str, curr_date: str = None) -> str:
|
||||
"""
|
||||
Retrieve comprehensive fundamental data for a given ticker symbol using Alpha Vantage.
|
||||
|
|
@ -35,7 +67,7 @@ def get_balance_sheet(ticker: str, freq: str = "quarterly", curr_date: str = Non
|
|||
"symbol": ticker,
|
||||
}
|
||||
|
||||
return _make_api_request("BALANCE_SHEET", params)
|
||||
return _filter_reports_by_date(_make_api_request("BALANCE_SHEET", params), curr_date)
|
||||
|
||||
|
||||
def get_cashflow(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str:
|
||||
|
|
@ -54,7 +86,7 @@ def get_cashflow(ticker: str, freq: str = "quarterly", curr_date: str = None) ->
|
|||
"symbol": ticker,
|
||||
}
|
||||
|
||||
return _make_api_request("CASH_FLOW", params)
|
||||
return _filter_reports_by_date(_make_api_request("CASH_FLOW", params), curr_date)
|
||||
|
||||
|
||||
def get_income_statement(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str:
|
||||
|
|
@ -73,5 +105,4 @@ def get_income_statement(ticker: str, freq: str = "quarterly", curr_date: str =
|
|||
"symbol": ticker,
|
||||
}
|
||||
|
||||
return _make_api_request("INCOME_STATEMENT", params)
|
||||
|
||||
return _filter_reports_by_date(_make_api_request("INCOME_STATEMENT", params), curr_date)
|
||||
|
|
|
|||
Loading…
Reference in New Issue