diff --git a/tests/test_y_finance_bulk_indicator.py b/tests/test_y_finance_bulk_indicator.py new file mode 100644 index 00000000..6ea7f258 --- /dev/null +++ b/tests/test_y_finance_bulk_indicator.py @@ -0,0 +1,39 @@ +import importlib +import sys +import types +import unittest +from unittest.mock import patch + +import pandas as pd + + +class YFinanceBulkIndicatorTests(unittest.TestCase): + def test_bulk_indicator_maps_nan_to_na(self): + fake_stockstats = types.ModuleType("stockstats") + fake_stockstats.wrap = lambda df: df + + with patch.dict(sys.modules, {"stockstats": fake_stockstats}): + sys.modules.pop("tradingagents.dataflows.stockstats_utils", None) + sys.modules.pop("tradingagents.dataflows.y_finance", None) + + try: + y_finance = importlib.import_module("tradingagents.dataflows.y_finance") + sample_df = pd.DataFrame( + { + "Date": pd.to_datetime(["2024-01-02", "2024-01-03"]), + "rsi": [float("nan"), 55.5], + } + ) + + with patch.object(y_finance, "load_ohlcv", return_value=sample_df): + result = y_finance._get_stock_stats_bulk("AAPL", "rsi", "2024-01-03") + finally: + sys.modules.pop("tradingagents.dataflows.stockstats_utils", None) + sys.modules.pop("tradingagents.dataflows.y_finance", None) + + self.assertEqual(result["2024-01-02"], "N/A") + self.assertEqual(result["2024-01-03"], "55.5") + + +if __name__ == "__main__": + unittest.main() diff --git a/tradingagents/dataflows/y_finance.py b/tradingagents/dataflows/y_finance.py index 8b4b93f5..64bf98d0 100644 --- a/tradingagents/dataflows/y_finance.py +++ b/tradingagents/dataflows/y_finance.py @@ -1,6 +1,7 @@ from typing import Annotated from datetime import datetime from dateutil.relativedelta import relativedelta +import pandas as pd import yfinance as yf import os from .stockstats_utils import StockstatsUtils, _clean_dataframe, yf_retry, load_ohlcv, filter_financials_by_date @@ -418,4 +419,4 @@ def get_insider_transactions( return header + csv_string except Exception as e: - return f"Error retrieving insider transactions for {ticker}: {str(e)}" \ No newline at end of file + return f"Error retrieving insider transactions for {ticker}: {str(e)}"