TradingAgents/tradingagents/agents/utils/agent_utils.py

323 lines
13 KiB
Python

from langchain_core.messages import BaseMessage, HumanMessage, ToolMessage, AIMessage
from typing import List
from typing import Annotated
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import RemoveMessage
from langchain_core.tools import tool
from datetime import date, timedelta, datetime
import functools
import pandas as pd
import os
from dateutil.relativedelta import relativedelta
from langchain_openai import ChatOpenAI
import tradingagents.dataflows.interface as interface
from tradingagents.default_config import DEFAULT_CONFIG
from langchain_core.messages import HumanMessage
def create_msg_delete():
def delete_messages(state):
"""Clear messages and add placeholder for Anthropic compatibility"""
messages = state["messages"]
# Remove all messages
removal_operations = [RemoveMessage(id=m.id) for m in messages]
# Add a minimal placeholder message
placeholder = HumanMessage(content="Continue")
return {"messages": removal_operations + [placeholder]}
return delete_messages
class Toolkit:
_config = DEFAULT_CONFIG.copy()
@classmethod
def update_config(cls, config):
"""Update the class-level configuration."""
cls._config.update(config)
@property
def config(self):
"""Access the configuration."""
return self._config
def __init__(self, config=None):
if config:
self.update_config(config)
@staticmethod
@tool
def get_blockbeats_news(
count: Annotated[int, "Number of news articles to retrieve, no more than 30"] = 10,
) -> str:
"""
Retrieve the latest news from BlockBeats, especially useful for Cryptos.
Args:
count (int): Number of news articles to retrieve, no more than 30
Returns:
str: A formatted string containing the latest news from BlockBeats.
"""
blockbeats_news_result = interface.get_blockbeats_news(count)
return blockbeats_news_result
@staticmethod
@tool
def get_coindesk_news(
tickers: Annotated[
List[str],
"List of tickers to get news for, e.g. ['BTC', 'ETH']",
] = [],
count: Annotated[int, "Number of news articles to retrieve, default is 10"] = 10,
) -> str:
"""
Retrieve the latest news from Coindesk for a given list of tickers.
Args:
tickers (List[str]): List of tickers to get news for, e.g. ['BTC', 'ETH']
count (int): Number of news articles to retrieve, default is 10
Returns:
str: A formatted string containing the latest news from Coindesk for the specified tickers.
"""
coindesk_news_result = interface.get_coindesk_news(tickers, count)
return coindesk_news_result
@staticmethod
@tool
def get_coinstats_news() -> str:
"""
Retrieve the latest news from CoinStats.
Returns:
str: A formatted string containing the latest news from CoinStats.
"""
coinstats_news_result = interface.get_coinstats_news()
return coinstats_news_result
@staticmethod
@tool
def get_binance_ohlcv(
symbol: Annotated[str, "ticker symbol of the asset"],
interval: Annotated[
str,
"Interval for the data, e.g. '1m', '5m', '1h', '1d'",
] = "15m",
) -> str:
"""
Retrieve the latest OHLCV data from Binance for a given symbol and interval.
Args:
symbol (str): Ticker symbol of the asset, e.g. 'BTCUSDT'
interval (str): Interval for the data, e.g. '1m', '5m', '1h', '1d'
Returns:
str: A formatted string containing the latest OHLCV data from Binance for the specified symbol and interval.
"""
binance_ohlcv_result = interface.get_binance_ohlcv(symbol, interval)
return binance_ohlcv_result
@staticmethod
@tool
def get_coinstats_btc_dominance() -> str:
"""
Retrieve the current Bitcoin dominance percentage from CoinStats.
Returns:
str: A formatted string containing the last daily and weekly Bitcoin dominance percentage.
"""
btc_dominance_result = interface.get_coinstats_btc_dominance()
return btc_dominance_result
@staticmethod
@tool
def get_binance_data(
symbol: Annotated[str, "ticker symbol of the asset"],
interval: Annotated[
str,
"Interval for the data, e.g. '1m', '5m', '1h', '1d'",
] = "15m",
) -> str:
"""
Retrieve the latest market data from Binance for a given symbol and interval.
Args:
symbol (str): Ticker symbol of the asset, e.g. 'BTCUSDT'
interval (str): Interval for the data, e.g. '1m', '5m', '1h', '1d'
Returns:
str: A formatted string containing the latest market data from Binance for the specified symbol and interval.
"""
binance_data_result = interface.get_binance_data(symbol, interval)
return binance_data_result
@staticmethod
@tool
def get_fear_and_greed_index() -> str:
"""
Get current crypto market Fear and Greed Index. 0 means "Extreme Fear", while 100 means "Extreme Greed"
Returns:
str: A formatted string containing the current crypto market Fear and Greed Index.
"""
return interface.get_fear_and_greed_index()
@staticmethod
@tool
def get_taapi_bulk_indicators(
symbol: Annotated[str, "Ticker symbol of the asset, e.g. 'BTC'"],
interval: Annotated[
str,
"Interval for the data, e.g. '1m', '5m', '1h', '1d'",
] = "15m",
ema_period: Annotated[int, "EMA period, default is 30"] = 30,
ichimoku_conversionPeriod: Annotated[int, "Ichimoku conversion line period, default is 9"] = 9,
ichimoku_basePeriod: Annotated[int, "Ichimoku base line period, default is 26"] = 26,
ichimoku_spanPeriod: Annotated[int, "Ichimoku span period, default is 52"] = 52,
ichimoku_displacement: Annotated[int, "Ichimoku displacement, default is 26"] = 26,
supertrend_period: Annotated[int, "Supertrend period, default is 7"] = 7,
supertrend_multiplier: Annotated[float, "Supertrend multiplier, default is 3.0"] = 3.0,
psar_start: Annotated[float, "Parabolic SAR start, default is 0.02"] = 0.02,
psar_increment: Annotated[float, "Parabolic SAR increment, default is 0.02"] = 0.02,
psar_maximum: Annotated[float, "Parabolic SAR maximum, default is 0.2"] = 0.2,
donchianchannels_period: Annotated[int, "Donchian Channels period, default is 20"] = 20,
macd_optInFastPeriod: Annotated[int, "MACD fast period, default is 12"] = 12,
macd_optInSlowPeriod: Annotated[int, "MACD slow period, default is 26"] = 26,
macd_optInSignalPeriod: Annotated[int, "MACD signal period, default is 9"] = 9,
rsi_period: Annotated[int, "RSI period, default is 14"] = 14,
stochrsi_rsiPeriod: Annotated[int, "Stochastic RSI RSI period, default is 14"] = 14,
stochrsi_kPeriod: Annotated[int, "Stochastic RSI %K period, default is 5"] = 5,
stochrsi_dPeriod: Annotated[int, "Stochastic RSI %D period, default is 3"] = 3,
stochrsi_stochasticPeriod: Annotated[int, "Stochastic RSI stochastic period, default is 14"] = 14,
trix_period: Annotated[int, "TRIX period, default is 30"] = 30,
stc_fastLength: Annotated[int, "STC fast length, default is 23"] = 23,
stc_slowLength: Annotated[int, "STC slow length, default is 50"] = 50,
stc_cycleLength: Annotated[int, "STC cycle length, default is 10"] = 10,
atr_period: Annotated[int, "ATR period, default is 14"] = 14,
bbands_period: Annotated[int, "Bollinger Bands period, default is 20"] = 20,
bbands_stddev: Annotated[float, "Bollinger Bands standard deviation, default is 2.0"] = 2.0,
keltnerchannels_period: Annotated[int, "Keltner Channels period, default is 20"] = 20,
keltnerchannels_multiplier: Annotated[float, "Keltner Channels multiplier, default is 2"] = 2,
keltnerchannels_atrLength: Annotated[int, "Keltner Channels ATR length, default is 10"] = 10,
chop_period: Annotated[int, "Chop period, default is 14"] = 14,
) -> str:
"""
Retrieve bulk technical indicators from TAAPI.io for a given symbol and interval.
Args:
symbol (str): Ticker symbol of the asset, e.g. 'BTC'
interval (str): Interval for the data, e.g. '1m', '5m', '1h', '1d'
Returns:
str: A formatted string containing the bulk technical indicators from TAAPI.io for the specified symbol and interval.
"""
taapi_bulk_indicators_result = interface.get_taapi_bulk_indicators("BTC", "15m")
return taapi_bulk_indicators_result
@staticmethod
@tool
def get_reddit_news(
curr_date: Annotated[str, "Date you want to get news for in yyyy-mm-dd format"],
) -> str:
"""
Retrieve global news from Reddit within a specified time frame.
Args:
curr_date (str): Date you want to get news for in yyyy-mm-dd format
Returns:
str: A formatted dataframe containing the latest global news from Reddit in the specified time frame.
"""
global_news_result = interface.get_reddit_global_news(curr_date, 7, 5)
return global_news_result
@staticmethod
@tool
def get_reddit_stock_info(
ticker: Annotated[
str,
"Ticker of a asset. e.g. AAPL, TSM",
],
curr_date: Annotated[str, "Current date you want to get news for"],
) -> str:
"""
Retrieve the latest news about a given stock from Reddit, given the current date.
Args:
ticker (str): Ticker of a asset. e.g. AAPL, TSM
curr_date (str): current date in yyyy-mm-dd format to get news for
Returns:
str: A formatted dataframe containing the latest news about the asset on the given date
"""
stock_news_results = interface.get_reddit_asset_news(ticker, curr_date, 7, 5)
return stock_news_results
@staticmethod
@tool
def get_google_news(
query: Annotated[str, "Query to search with"],
curr_date: Annotated[str, "Curr date in yyyy-mm-dd format"],
):
"""
Retrieve the latest news from Google News based on a query and date range.
Args:
query (str): Query to search with
curr_date (str): Current date in yyyy-mm-dd format
look_back_days (int): How many days to look back
Returns:
str: A formatted string containing the latest news from Google News based on the query and date range.
"""
google_news_results = interface.get_google_news(query, curr_date, 7)
return google_news_results
@staticmethod
@tool
def get_stock_news_openai(
ticker: Annotated[str, "the asset's ticker"],
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
):
"""
Retrieve the latest news about a given stock by using OpenAI's news API.
Args:
ticker (str): Ticker of a asset. e.g. AAPL, TSM
curr_date (str): Current date in yyyy-mm-dd format
Returns:
str: A formatted string containing the latest news about the asset on the given date.
"""
openai_news_results = interface.get_stock_news_openai(ticker, curr_date)
return openai_news_results
@staticmethod
@tool
def get_global_news_openai(
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
):
"""
Retrieve the latest macroeconomics news on a given date using OpenAI's macroeconomics news API.
Args:
curr_date (str): Current date in yyyy-mm-dd format
Returns:
str: A formatted string containing the latest macroeconomic news on the given date.
"""
openai_news_results = interface.get_global_news_openai(curr_date)
return openai_news_results
@staticmethod
@tool
def get_fundamentals_openai(
ticker: Annotated[str, "the asset's ticker"],
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
):
"""
Retrieve the latest fundamental information about a given stock on a given date by using OpenAI's news API.
Args:
ticker (str): Ticker of a asset. e.g. AAPL, TSM
curr_date (str): Current date in yyyy-mm-dd format
Returns:
str: A formatted string containing the latest fundamental information about the asset on the given date.
"""
openai_fundamentals_results = interface.get_fundamentals_openai(
ticker, curr_date
)
return openai_fundamentals_results