diff --git a/tests/test_y_finance_bulk_indicator.py b/tests/test_y_finance_bulk_indicator.py index 6ea7f258..a870929b 100644 --- a/tests/test_y_finance_bulk_indicator.py +++ b/tests/test_y_finance_bulk_indicator.py @@ -13,10 +13,10 @@ class YFinanceBulkIndicatorTests(unittest.TestCase): fake_stockstats.wrap = lambda df: df with patch.dict(sys.modules, {"stockstats": fake_stockstats}): - sys.modules.pop("tradingagents.dataflows.stockstats_utils", None) - sys.modules.pop("tradingagents.dataflows.y_finance", None) + with patch.dict(sys.modules): + sys.modules.pop("tradingagents.dataflows.stockstats_utils", None) + sys.modules.pop("tradingagents.dataflows.y_finance", None) - try: y_finance = importlib.import_module("tradingagents.dataflows.y_finance") sample_df = pd.DataFrame( { @@ -27,12 +27,9 @@ class YFinanceBulkIndicatorTests(unittest.TestCase): with patch.object(y_finance, "load_ohlcv", return_value=sample_df): result = y_finance._get_stock_stats_bulk("AAPL", "rsi", "2024-01-03") - finally: - sys.modules.pop("tradingagents.dataflows.stockstats_utils", None) - sys.modules.pop("tradingagents.dataflows.y_finance", None) - self.assertEqual(result["2024-01-02"], "N/A") - self.assertEqual(result["2024-01-03"], "55.5") + self.assertEqual(result["2024-01-02"], "N/A") + self.assertEqual(result["2024-01-03"], "55.5") if __name__ == "__main__":