diff --git a/tests/test_technical_indicators_tools.py b/tests/test_technical_indicators_tools.py new file mode 100644 index 00000000..ee4ddeab --- /dev/null +++ b/tests/test_technical_indicators_tools.py @@ -0,0 +1,32 @@ +import unittest +from unittest.mock import call, patch + +from tradingagents.agents.utils import technical_indicators_tools as tools + + +class TechnicalIndicatorsToolTests(unittest.TestCase): + def test_get_indicators_splits_and_normalizes_indicator_names(self): + with patch.object( + tools, + "route_to_vendor", + side_effect=["rsi output", "macd output"], + ) as mock_route: + result = tools.get_indicators.func( + "AAPL", + " RSI, MACD ", + "2026-03-31", + 30, + ) + + self.assertEqual(result, "rsi output\n\nmacd output") + self.assertEqual( + mock_route.call_args_list, + [ + call("get_indicators", "AAPL", "rsi", "2026-03-31", 30), + call("get_indicators", "AAPL", "macd", "2026-03-31", 30), + ], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tradingagents/agents/utils/technical_indicators_tools.py b/tradingagents/agents/utils/technical_indicators_tools.py index c6c08bca..218567f5 100644 --- a/tradingagents/agents/utils/technical_indicators_tools.py +++ b/tradingagents/agents/utils/technical_indicators_tools.py @@ -14,10 +14,26 @@ def get_indicators( Uses the configured technical_indicators vendor. Args: symbol (str): Ticker symbol of the company, e.g. AAPL, TSM - indicator (str): Technical indicator to get the analysis and report of + indicator (str): Technical indicator to get the analysis and report of. + Comma-separated indicator names are supported. curr_date (str): The current trading date you are trading on, YYYY-mm-dd look_back_days (int): How many days to look back, default is 30 Returns: str: A formatted dataframe containing the technical indicators for the specified ticker symbol and indicator. """ - return route_to_vendor("get_indicators", symbol, indicator, curr_date, look_back_days) \ No newline at end of file + indicators = [item.strip().lower() for item in indicator.split(",") if item.strip()] + results = [] + for indicator_name in indicators: + try: + results.append( + route_to_vendor( + "get_indicators", + symbol, + indicator_name, + curr_date, + look_back_days, + ) + ) + except ValueError as exc: + results.append(str(exc)) + return "\n\n".join(results)