Normalize technical indicator names in tool routing
This commit is contained in:
parent
32be17c606
commit
abf82e2cec
|
|
@ -0,0 +1,32 @@
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import call, patch
|
||||||
|
|
||||||
|
from tradingagents.agents.utils import technical_indicators_tools as tools
|
||||||
|
|
||||||
|
|
||||||
|
class TechnicalIndicatorsToolTests(unittest.TestCase):
|
||||||
|
def test_get_indicators_splits_and_normalizes_indicator_names(self):
|
||||||
|
with patch.object(
|
||||||
|
tools,
|
||||||
|
"route_to_vendor",
|
||||||
|
side_effect=["rsi output", "macd output"],
|
||||||
|
) as mock_route:
|
||||||
|
result = tools.get_indicators.func(
|
||||||
|
"AAPL",
|
||||||
|
" RSI, MACD ",
|
||||||
|
"2026-03-31",
|
||||||
|
30,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(result, "rsi output\n\nmacd output")
|
||||||
|
self.assertEqual(
|
||||||
|
mock_route.call_args_list,
|
||||||
|
[
|
||||||
|
call("get_indicators", "AAPL", "rsi", "2026-03-31", 30),
|
||||||
|
call("get_indicators", "AAPL", "macd", "2026-03-31", 30),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
|
|
@ -14,10 +14,26 @@ def get_indicators(
|
||||||
Uses the configured technical_indicators vendor.
|
Uses the configured technical_indicators vendor.
|
||||||
Args:
|
Args:
|
||||||
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
|
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
|
||||||
indicator (str): Technical indicator to get the analysis and report of
|
indicator (str): Technical indicator to get the analysis and report of.
|
||||||
|
Comma-separated indicator names are supported.
|
||||||
curr_date (str): The current trading date you are trading on, YYYY-mm-dd
|
curr_date (str): The current trading date you are trading on, YYYY-mm-dd
|
||||||
look_back_days (int): How many days to look back, default is 30
|
look_back_days (int): How many days to look back, default is 30
|
||||||
Returns:
|
Returns:
|
||||||
str: A formatted dataframe containing the technical indicators for the specified ticker symbol and indicator.
|
str: A formatted dataframe containing the technical indicators for the specified ticker symbol and indicator.
|
||||||
"""
|
"""
|
||||||
return route_to_vendor("get_indicators", symbol, indicator, curr_date, look_back_days)
|
indicators = [item.strip().lower() for item in indicator.split(",") if item.strip()]
|
||||||
|
results = []
|
||||||
|
for indicator_name in indicators:
|
||||||
|
try:
|
||||||
|
results.append(
|
||||||
|
route_to_vendor(
|
||||||
|
"get_indicators",
|
||||||
|
symbol,
|
||||||
|
indicator_name,
|
||||||
|
curr_date,
|
||||||
|
look_back_days,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except ValueError as exc:
|
||||||
|
results.append(str(exc))
|
||||||
|
return "\n\n".join(results)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue