268 lines
9.8 KiB
Python
268 lines
9.8 KiB
Python
"""
|
|
Updated Toolkit class using Service/Client/Repository architecture with JSON context.
|
|
"""
|
|
|
|
import logging
|
|
from datetime import datetime
|
|
from typing import TYPE_CHECKING, Annotated, Any
|
|
|
|
from langchain_core.messages import HumanMessage, RemoveMessage
|
|
from langchain_core.tools import tool
|
|
|
|
from tradingagents.config import TradingAgentsConfig
|
|
from tradingagents.services.builders import build_toolkit_services
|
|
|
|
if TYPE_CHECKING:
|
|
from tradingagents.services.market_data_service import MarketDataService
|
|
from tradingagents.services.news_service import NewsService
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Fallback configuration, constructed once when this module is imported.
# Toolkit.__init__ uses it whenever no (or a falsy) `config` argument is given.
DEFAULT_CONFIG = TradingAgentsConfig()
|
|
|
|
|
|
def create_msg_delete():
    """Build the callback agents use to wipe their message history.

    Returns:
        A function ``delete_messages(state)`` that reads ``state["messages"]``
        and returns a ``{"messages": [...]}`` update containing one
        ``RemoveMessage`` per existing message followed by a single
        ``HumanMessage`` placeholder.
    """

    def delete_messages(state):
        """Clear messages and add placeholder for Anthropic compatibility"""
        update = []

        # One removal marker per message currently held in the state.
        for existing in state["messages"]:
            update.append(RemoveMessage(id=existing.id))

        # Minimal placeholder so the history is never left empty.
        update.append(HumanMessage(content="Continue"))

        return {"messages": update}

    return delete_messages
|
|
|
|
|
|
class Toolkit:
    """
    Toolkit class that uses services to provide JSON context to agents.

    This replaces the old interface.py approach with structured Pydantic models
    that agents can process more dynamically.

    Attributes:
        config: Configuration in effect (the supplied one or DEFAULT_CONFIG).
        services: Mapping of service name to service object.
        market_service: ``services.get("market_data")``; None when absent.
        news_service: ``services.get("news")``; None when absent.
    """

    def __init__(
        self,
        config: TradingAgentsConfig | None = None,
        services: dict[str, Any] | None = None,
    ):
        """
        Initialize Toolkit with services.

        Args:
            config: TradingAgents configuration; DEFAULT_CONFIG is used when
                this is None (or otherwise falsy).
            services: Pre-built services dict. When None or empty, services
                are built from the config via build_toolkit_services.
        """
        self.config = config or DEFAULT_CONFIG

        # NOTE: truthiness check, so an empty dict also triggers a build.
        if services:
            self.services = services
        else:
            logger.info("Building services from config")
            self.services = build_toolkit_services(self.config)

        # Set up individual services
        self.market_service: MarketDataService | None = self.services.get("market_data")
        self.news_service: NewsService | None = self.services.get("news")

        logger.info("Toolkit initialized with %d services", len(self.services))

    @staticmethod
    def _render_context(service, service_name, description, fetch, symbol=None) -> str:
        """Run ``fetch`` and serialize its result, or return a JSON error context.

        Shared by every tool closure so the guard / serialize / error-handling
        sequence exists in exactly one place.

        Args:
            service: Service object the tool depends on; a falsy value
                short-circuits to a "not available" error context.
            service_name: Name used in the "not available" message.
            description: What is being fetched (e.g. "market data"); used in
                the log line and in the returned error message.
            fetch: Zero-argument callable returning an object that provides
                ``model_dump_json(indent=...)``.
            symbol: Optional ticker appended to the log line as " for <symbol>".

        Returns:
            The JSON produced by ``model_dump_json(indent=2)``, or the string
            returned by ``_create_error_context`` on failure.
        """
        if not service:
            return _create_error_context(f"{service_name} not available")

        try:
            return fetch().model_dump_json(indent=2)
        except Exception as exc:
            # Tools report failures to the agent as JSON instead of raising;
            # logger.exception keeps the traceback that logger.error dropped.
            subject = f" for {symbol}" if symbol is not None else ""
            logger.exception("Error getting %s%s: %s", description, subject, exc)
            return _create_error_context(f"Error fetching {description}: {exc}")

    # Create tool methods as static methods with service access via closure
    def _create_market_data_tool(self):
        """Create market data tool with service access."""
        market_service = self.market_service
        render = self._render_context

        @tool
        def get_market_data(
            symbol: Annotated[str, "ticker symbol of the company"],
            start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
            end_date: Annotated[str, "End date in yyyy-mm-dd format"],
        ) -> str:
            """
            Retrieve market data context for a given ticker symbol.
            Args:
                symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
                start_date (str): Start date in yyyy-mm-dd format
                end_date (str): End date in yyyy-mm-dd format
            Returns:
                str: JSON context containing market data with price data and metadata
            """
            return render(
                market_service,
                "MarketDataService",
                "market data",
                lambda: market_service.get_price_context(symbol, start_date, end_date),
                symbol=symbol,
            )

        return get_market_data

    def _create_market_indicators_tool(self):
        """Create market data with indicators tool."""
        market_service = self.market_service
        render = self._render_context

        @tool
        def get_market_data_with_indicators(
            symbol: Annotated[str, "ticker symbol of the company"],
            start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
            end_date: Annotated[str, "End date in yyyy-mm-dd format"],
            indicators: Annotated[
                str, "Comma-separated list of indicators (e.g. 'rsi,macd,close_50_sma')"
            ] = "rsi,macd",
        ) -> str:
            """
            Retrieve market data context with technical indicators.
            Args:
                symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
                start_date (str): Start date in yyyy-mm-dd format
                end_date (str): End date in yyyy-mm-dd format
                indicators (str): Comma-separated indicators
            Returns:
                str: JSON context containing market data with technical indicators
            """

            def fetch():
                # Parsing happens inside fetch so a malformed argument is
                # reported as an error context, like any other failure.
                indicator_list = [i.strip() for i in indicators.split(",") if i.strip()]
                return market_service.get_context(
                    symbol, start_date, end_date, indicators=indicator_list
                )

            return render(
                market_service,
                "MarketDataService",
                "market data with indicators",
                fetch,
                symbol=symbol,
            )

        return get_market_data_with_indicators

    def _create_company_news_tool(self):
        """Create company news tool."""
        news_service = self.news_service
        render = self._render_context

        @tool
        def get_company_news(
            symbol: Annotated[str, "ticker symbol of the company"],
            start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
            end_date: Annotated[str, "End date in yyyy-mm-dd format"],
        ) -> str:
            """
            Retrieve news context for a specific company.
            Args:
                symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
                start_date (str): Start date in yyyy-mm-dd format
                end_date (str): End date in yyyy-mm-dd format
            Returns:
                str: JSON context containing news articles, sentiment analysis, and metadata
            """
            return render(
                news_service,
                "NewsService",
                "company news",
                lambda: news_service.get_company_news_context(symbol, start_date, end_date),
                symbol=symbol,
            )

        return get_company_news

    def _create_global_news_tool(self):
        """Create global news tool."""
        news_service = self.news_service
        render = self._render_context

        @tool
        def get_global_news(
            start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
            end_date: Annotated[str, "End date in yyyy-mm-dd format"],
            categories: Annotated[
                str, "Comma-separated news categories (e.g. 'economy,markets,finance')"
            ] = "economy,markets",
        ) -> str:
            """
            Retrieve global/macro news context.
            Args:
                start_date (str): Start date in yyyy-mm-dd format
                end_date (str): End date in yyyy-mm-dd format
                categories (str): Comma-separated news categories
            Returns:
                str: JSON context containing global news articles and sentiment analysis
            """

            def fetch():
                # Parsed inside fetch so bad input becomes an error context.
                category_list = [c.strip() for c in categories.split(",") if c.strip()]
                return news_service.get_global_news_context(
                    start_date, end_date, categories=category_list
                )

            return render(news_service, "NewsService", "global news", fetch)

        return get_global_news

    def get_tools(self):
        """Get all available tools as LangChain tools.

        Returns:
            A new list of tool objects on every call: the two market tools
            when a market data service is set, then the two news tools when
            a news service is set.
        """
        tools = []

        if self.market_service:
            tools.append(self._create_market_data_tool())
            tools.append(self._create_market_indicators_tool())

        if self.news_service:
            tools.append(self._create_company_news_tool())
            tools.append(self._create_global_news_tool())

        return tools

    def get_available_tools(self) -> list:
        """Get list of available tool names based on configured services.

        Returns:
            Tool names in the same order get_tools() creates the tools.
        """
        tools = []

        if self.market_service:
            tools.extend(["get_market_data", "get_market_data_with_indicators"])

        if self.news_service:
            tools.extend(["get_company_news", "get_global_news"])

        return tools

    def get_toolkit_info(self) -> dict[str, Any]:
        """Get information about the toolkit configuration.

        Returns:
            A dict with the toolkit type, the ``online_tools`` and ``data_dir``
            config values, the registered service names and the tool names.
        """
        return {
            "toolkit_type": "service_based",
            "config": {
                "online_mode": self.config.online_tools,
                "data_dir": self.config.data_dir,
            },
            "services": list(self.services.keys()),
            "available_tools": self.get_available_tools(),
        }
|
|
|
|
|
|
def _create_error_context(error_message: str) -> str:
    """Create a JSON error context.

    Args:
        error_message: Human-readable description of what went wrong.

    Returns:
        A JSON string (2-space indent) with ``error`` set to true, the
        ``message``, and ``metadata`` holding ``source`` ("toolkit") and
        ``created_at``, a timezone-aware UTC ISO 8601 timestamp.
    """
    import json
    from datetime import timezone

    # datetime.utcnow() is deprecated (Python 3.12+) and returns a naive
    # value; an aware UTC timestamp makes the "+00:00" offset explicit.
    created_at = datetime.now(timezone.utc).isoformat()

    error_context = {
        "error": True,
        "message": error_message,
        "metadata": {"created_at": created_at, "source": "toolkit"},
    }
    return json.dumps(error_context, indent=2)
|