# Viewer metadata: 394 lines, 14 KiB, Python
import logging
|
|
import re
|
|
from datetime import datetime, timedelta
|
|
from typing import Annotated
|
|
|
|
from langchain_core.tools import tool
|
|
|
|
from tradingagents.config import DEFAULT_CONFIG, TradingAgentsConfig
|
|
from tradingagents.domains.marketdata.fundamental_data_service import (
|
|
BalanceSheetContext,
|
|
CashFlowContext,
|
|
FundamentalDataService,
|
|
IncomeStatementContext,
|
|
)
|
|
from tradingagents.domains.marketdata.insider_data_service import (
|
|
InsiderDataService,
|
|
InsiderSentimentContext,
|
|
InsiderTransactionContext,
|
|
)
|
|
from tradingagents.domains.marketdata.market_data_service import (
|
|
MarketDataService,
|
|
PriceDataContext,
|
|
TAReportContext,
|
|
)
|
|
|
|
# Import context models
|
|
from tradingagents.domains.news.news_service import (
|
|
GlobalNewsContext,
|
|
NewsContext,
|
|
NewsService,
|
|
)
|
|
from tradingagents.domains.socialmedia.social_media_service import (
|
|
SocialMediaService,
|
|
StockSocialContext,
|
|
)
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class AgentToolkit:
    """Tool collection that exposes the trading data services to agents.

    Every public method is decorated with ``@tool`` and forwards to one of the
    injected domain services after validating its arguments: ticker symbols
    are stripped, upper-cased and checked by ``_validate_ticker``; dates must
    be in ``YYYY-MM-DD`` format and are passed on zero-padded. Any failure is
    logged and then re-raised.

    NOTE(review): ``@tool`` is applied directly to instance methods, so the
    class attributes are whatever ``tool`` returns rather than plain
    functions; confirm against the langchain_core documentation that ``self``
    is bound the way callers expect.
    """

    # Date format accepted from callers and handed on to the services.
    _DATE_FORMAT = "%Y-%m-%d"
    # 1-10 characters: upper-case letters, digits, "." and "-".
    _TICKER_PATTERN = re.compile(r"[A-Z0-9.-]{1,10}")

    def __init__(
        self,
        news_service: NewsService,
        marketdata_service: MarketDataService,
        fundamentaldata_service: FundamentalDataService,
        socialmedia_service: SocialMediaService,
        insiderdata_service: InsiderDataService,
        config: TradingAgentsConfig = DEFAULT_CONFIG,
    ):
        """Store the services the tool methods delegate to.

        Args:
            news_service: Backs ``get_global_news`` and ``get_news``.
            marketdata_service: Backs ``get_market_data`` and ``get_ta_report``.
            fundamentaldata_service: Backs the balance sheet, cash flow and
                income statement tools.
            socialmedia_service: Backs ``get_socialmedia_stock_info``.
            insiderdata_service: Backs the insider sentiment and insider
                transaction tools.
            config: Supplies ``default_lookback_days`` and
                ``default_ta_lookback_days``.
        """
        self._news_service = news_service
        self._marketdata_service = marketdata_service
        self._fundamentaldata_service = fundamentaldata_service
        self._socialmedia_service = socialmedia_service
        self._insiderdata_service = insiderdata_service
        self._config = config

    @tool
    def get_global_news(
        self,
        curr_date: Annotated[str, "Date you want to get news for in yyyy-mm-dd format"],
    ) -> GlobalNewsContext:
        """
        Retrieve global news from Reddit within a specified time frame.
        Args:
            curr_date (str): Date you want to get news for in yyyy-mm-dd format
        Returns:
            GlobalNewsContext: Structured global news context with articles and sentiment analysis.
        Raises:
            ValueError: If curr_date is not a valid yyyy-mm-dd date.
        """
        try:
            # The window is a single day: start and end are both curr_date.
            day = self._normalize_date(curr_date)
            return self._news_service.get_global_news_context(
                start_date=day,
                end_date=day,
                categories=["general", "business", "politics"],
            )
        except Exception as e:
            logger.error("Error getting global news for %s: %s", curr_date, e)
            raise

    @tool
    def get_news(
        self,
        ticker: Annotated[
            str,
            "Search query of a company, e.g. 'AAPL, TSM, etc.",
        ],
        start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
        end_date: Annotated[str, "End date in yyyy-mm-dd format"],
    ) -> NewsContext:
        """
        Retrieve the latest news about a given stock from Finnhub within a date range
        Args:
            ticker (str): Ticker of a company. e.g. AAPL, TSM
            start_date (str): Start date in yyyy-mm-dd format
            end_date (str): End date in yyyy-mm-dd format
        Returns:
            NewsContext: Structured news context with articles and sentiment analysis for the company.
        Raises:
            ValueError: If the ticker or either date is invalid.
        """
        try:
            ticker = self._validate_ticker(ticker)
            start_date = self._normalize_date(start_date)
            end_date = self._normalize_date(end_date)

            # The validated ticker is used both as search query and as symbol.
            return self._news_service.get_context(
                query=ticker, start_date=start_date, end_date=end_date, symbol=ticker
            )
        except Exception as e:
            logger.error("Error getting news for %s: %s", ticker, e)
            raise

    @tool
    def get_socialmedia_stock_info(
        self,
        ticker: Annotated[
            str,
            "Ticker of a company. e.g. AAPL, TSM",
        ],
        curr_date: Annotated[str, "Current date you want to get news for"],
    ) -> StockSocialContext:
        """
        Retrieve the latest news about a given stock from Reddit, given the current date.
        Args:
            ticker (str): Ticker of a company. e.g. AAPL, TSM
            curr_date (str): current date in yyyy-mm-dd format to get news for
        Returns:
            StockSocialContext: Structured social media context with posts and sentiment analysis for the stock.
        Raises:
            ValueError: If the ticker or the date is invalid.
        """
        try:
            ticker = self._validate_ticker(ticker)
            # Single-day window: start and end are both curr_date.
            day = self._normalize_date(curr_date)

            return self._socialmedia_service.get_stock_social_context(
                symbol=ticker, start_date=day, end_date=day
            )
        except Exception as e:
            logger.error("Error getting social media info for %s: %s", ticker, e)
            raise

    @tool
    def get_market_data(
        self,
        symbol: Annotated[str, "ticker symbol of the company"],
        start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
        end_date: Annotated[str, "End date in yyyy-mm-dd format"],
    ) -> PriceDataContext:
        """
        Retrieve the stock price data for a given ticker symbol from Yahoo Finance.
        Args:
            symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
            start_date (str): Start date in yyyy-mm-dd format
            end_date (str): End date in yyyy-mm-dd format
        Returns:
            PriceDataContext: Structured price data context with historical prices and key metrics.
        Raises:
            ValueError: If the symbol or either date is invalid.
        """
        try:
            symbol = self._validate_ticker(symbol)
            start_date = self._normalize_date(start_date)
            end_date = self._normalize_date(end_date)

            return self._marketdata_service.get_market_data_context(
                symbol=symbol, start_date=start_date, end_date=end_date
            )
        except Exception as e:
            logger.error("Error getting market data for %s: %s", symbol, e)
            raise

    @tool
    def get_ta_report(
        self,
        symbol: Annotated[str, "ticker symbol of the company"],
        indicator: Annotated[
            str, "technical indicator to get the analysis and report of"
        ],
        curr_date: Annotated[
            str, "The current trading date you are trading on, YYYY-mm-dd"
        ],
        look_back_days: Annotated[int | None, "how many days to look back"] = None,
    ) -> TAReportContext:
        """
        Retrieve stock stats indicators for a given ticker symbol and indicator.
        Args:
            symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
            indicator (str): Technical indicator to get the analysis and report of
            curr_date (str): The current trading date you are trading on, YYYY-mm-dd
            look_back_days (int): How many days to look back, uses config default if None
        Returns:
            TAReportContext: Structured technical analysis context with indicator data and signals.
        Raises:
            ValueError: If the symbol or date is invalid, or look_back_days is negative.
        """
        try:
            symbol = self._validate_ticker(symbol)
            # Technical analysis has its own default window, distinct from the
            # general ``default_lookback_days`` used by _calculate_date_range.
            if look_back_days is None:
                look_back_days = self._config.default_ta_lookback_days
            start_date, end_date = self._calculate_date_range(curr_date, look_back_days)

            return self._marketdata_service.get_ta_report_context(
                symbol=symbol,
                indicator=indicator,
                start_date=start_date,
                end_date=end_date,
            )
        except Exception as e:
            logger.error("Error getting TA report for %s: %s", symbol, e)
            raise

    @tool
    def get_insider_sentiment(
        self,
        ticker: Annotated[str, "ticker symbol for the company"],
        curr_date: Annotated[
            str,
            "current date of you are trading at, yyyy-mm-dd",
        ],
    ) -> InsiderSentimentContext:
        """
        Retrieve insider sentiment information about a company (retrieved from public SEC information) for the configured lookback period
        Args:
            ticker (str): ticker symbol of the company
            curr_date (str): current date you are trading at, yyyy-mm-dd
        Returns:
            InsiderSentimentContext: Structured insider sentiment analysis with transaction data and sentiment scores.
        Raises:
            ValueError: If the ticker or the date is invalid.
        """
        try:
            ticker = self._validate_ticker(ticker)
            # No explicit lookback: the config default window is used.
            start_date, end_date = self._calculate_date_range(curr_date)

            return self._insiderdata_service.get_insider_sentiment_context(
                symbol=ticker, start_date=start_date, end_date=end_date
            )
        except Exception as e:
            logger.error("Error getting insider sentiment for %s: %s", ticker, e)
            raise

    @tool
    def get_insider_transactions(
        self,
        ticker: Annotated[str, "ticker symbol"],
        curr_date: Annotated[
            str,
            "current date you are trading at, yyyy-mm-dd",
        ],
    ) -> InsiderTransactionContext:
        """
        Retrieve insider transaction information about a company (retrieved from public SEC information) for the configured lookback period
        Args:
            ticker (str): ticker symbol of the company
            curr_date (str): current date you are trading at, yyyy-mm-dd
        Returns:
            InsiderTransactionContext: Structured insider transaction analysis with detailed transaction data.
        Raises:
            ValueError: If the ticker or the date is invalid.
        """
        try:
            ticker = self._validate_ticker(ticker)
            # No explicit lookback: the config default window is used.
            start_date, end_date = self._calculate_date_range(curr_date)

            return self._insiderdata_service.get_insider_transaction_context(
                symbol=ticker, start_date=start_date, end_date=end_date
            )
        except Exception as e:
            logger.error("Error getting insider transactions for %s: %s", ticker, e)
            raise

    @tool
    def get_balance_sheet(
        self,
        ticker: Annotated[str, "ticker symbol"],
        freq: Annotated[
            str,
            "reporting frequency of the company's financial history: annual/quarterly",
        ],
        curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
    ) -> BalanceSheetContext:
        """
        Retrieve the most recent balance sheet of a company
        Args:
            ticker (str): ticker symbol of the company
            freq (str): reporting frequency of the company's financial history: annual / quarterly
            curr_date (str): current date you are trading at, yyyy-mm-dd
        Returns:
            BalanceSheetContext: Structured balance sheet analysis with key liquidity and debt metrics.
        Raises:
            ValueError: If the ticker or the date is invalid.
        """
        return self._fetch_statement(
            self._fundamentaldata_service.get_balance_sheet_context,
            "balance sheet",
            ticker,
            freq,
            curr_date,
        )

    @tool
    def get_cashflow(
        self,
        ticker: Annotated[str, "ticker symbol"],
        freq: Annotated[
            str,
            "reporting frequency of the company's financial history: annual/quarterly",
        ],
        curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
    ) -> CashFlowContext:
        """
        Retrieve the most recent cash flow statement of a company
        Args:
            ticker (str): ticker symbol of the company
            freq (str): reporting frequency of the company's financial history: annual / quarterly
            curr_date (str): current date you are trading at, yyyy-mm-dd
        Returns:
            CashFlowContext: Structured cash flow analysis with operating cash flow metrics.
        Raises:
            ValueError: If the ticker or the date is invalid.
        """
        return self._fetch_statement(
            self._fundamentaldata_service.get_cashflow_context,
            "cash flow",
            ticker,
            freq,
            curr_date,
        )

    @tool
    def get_income_stmt(
        self,
        ticker: Annotated[str, "ticker symbol"],
        freq: Annotated[
            str,
            "reporting frequency of the company's financial history: annual/quarterly",
        ],
        curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
    ) -> IncomeStatementContext:
        """
        Retrieve the most recent income statement of a company
        Args:
            ticker (str): ticker symbol of the company
            freq (str): reporting frequency of the company's financial history: annual / quarterly
            curr_date (str): current date you are trading at, yyyy-mm-dd
        Returns:
            IncomeStatementContext: Structured income statement analysis with profitability metrics.
        Raises:
            ValueError: If the ticker or the date is invalid.
        """
        return self._fetch_statement(
            self._fundamentaldata_service.get_income_statement_context,
            "income statement",
            ticker,
            freq,
            curr_date,
        )

    def _fetch_statement(self, fetch, statement_name: str, ticker: str, freq: str, curr_date: str):
        """
        Validate the arguments of a fundamental-data tool and call its service method.

        Shared by the balance sheet, cash flow and income statement tools so
        they validate and log exactly like the other tool methods.

        Args:
            fetch: Bound service method accepting ``symbol``, ``start_date``,
                ``end_date`` and ``frequency`` keyword arguments.
            statement_name: Human-readable statement name used in the log message.
            ticker: Ticker symbol to validate
            freq: Reporting frequency; stripped and lower-cased before use
            curr_date: Date in YYYY-MM-DD format, used as both start and end

        Returns:
            Whatever ``fetch`` returns.

        Raises:
            ValueError: If the ticker or the date is invalid
        """
        try:
            symbol = self._validate_ticker(ticker)
            day = self._normalize_date(curr_date)
            return fetch(
                symbol=symbol,
                start_date=day,
                end_date=day,
                frequency=freq.strip().lower(),
            )
        except Exception as e:
            logger.error("Error getting %s for %s: %s", statement_name, ticker, e)
            raise

    def _parse_date(self, date_str: str) -> datetime:
        """
        Parse a YYYY-MM-DD string into a datetime.

        Args:
            date_str: Date in YYYY-MM-DD format

        Returns:
            The parsed datetime

        Raises:
            ValueError: If date_str is not a string or does not match the format
        """
        try:
            return datetime.strptime(date_str, self._DATE_FORMAT)
        except (TypeError, ValueError) as e:
            # TypeError covers non-string input (e.g. None); both are reported
            # as ValueError, with the original error chained as the cause.
            raise ValueError(
                f"Date must be in YYYY-MM-DD format, got: {date_str}"
            ) from e

    def _normalize_date(self, date_str: str) -> str:
        """
        Validate a date string and return it in zero-padded YYYY-MM-DD form.

        ``strptime`` accepts un-padded input such as ``2024-1-5``; re-formatting
        ensures the services always receive ``2024-01-05``.

        Args:
            date_str: Date in YYYY-MM-DD format

        Returns:
            The same date, zero-padded

        Raises:
            ValueError: If date format is invalid
        """
        return self._parse_date(date_str).strftime(self._DATE_FORMAT)

    def _calculate_date_range(
        self, curr_date: str, lookback_days: int | None = None
    ) -> tuple[str, str]:
        """
        Calculate start and end dates based on current date and lookback period.

        This method does not log; every caller in this class logs the error itself.

        Args:
            curr_date: Current date in YYYY-MM-DD format
            lookback_days: Number of days to look back (uses config default if None)

        Returns:
            Tuple of (start_date, end_date) in YYYY-MM-DD format, both zero-padded

        Raises:
            ValueError: If date format is invalid or lookback_days is negative
        """
        end_date_obj = self._parse_date(curr_date)

        if lookback_days is None:
            lookback_days = self._config.default_lookback_days

        # A negative lookback would put the start date after the end date.
        if lookback_days < 0:
            raise ValueError(
                f"lookback_days must be non-negative, got: {lookback_days}"
            )

        start_date_obj = end_date_obj - timedelta(days=lookback_days)
        # Format both ends from the parsed objects so they share one format
        # even when curr_date was supplied un-padded.
        return (
            start_date_obj.strftime(self._DATE_FORMAT),
            end_date_obj.strftime(self._DATE_FORMAT),
        )

    def _validate_ticker(self, ticker: str) -> str:
        """
        Validate and sanitize ticker symbol.

        Args:
            ticker: Ticker symbol to validate

        Returns:
            Sanitized ticker symbol (whitespace stripped, upper-cased)

        Raises:
            ValueError: If ticker is invalid
        """
        # Check the type first so non-string values never reach the
        # truthiness test or the string methods below.
        if not isinstance(ticker, str) or not ticker:
            raise ValueError("Ticker must be a non-empty string")

        # Remove whitespace and convert to uppercase
        ticker = ticker.strip().upper()

        # Basic validation: only letters, numbers, "." and "-", 1-10 characters
        if not self._TICKER_PATTERN.fullmatch(ticker):
            raise ValueError(f"Invalid ticker format: {ticker}")

        return ticker