Update
This commit is contained in:
parent
5cf57e5d97
commit
ccc78c694b
|
|
@ -39,7 +39,7 @@ Volatility Indicators:
|
|||
Volume-Based Indicators:
|
||||
- vwma: VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses.
|
||||
|
||||
- Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_stock_data first to retrieve the CSV that is needed to generate indicators. Then use get_indicators with the specific indicator names. Write a very detailed and nuanced report of the trends you observe. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."""
|
||||
- Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_stock_data first to retrieve the CSV that is needed to generate indicators. Then call get_indicators SEPARATELY for EACH indicator you want to analyze (e.g., call get_indicators once with indicator="rsi", then call it again with indicator="macd", etc.). Do NOT pass multiple indicators in a single call. Write a very detailed and nuanced report of the trends you observe. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."""
|
||||
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from typing import Union, Dict, Optional
|
||||
from .alpha_vantage_common import _make_api_request, format_datetime_for_api
|
||||
|
||||
def get_news(ticker: str = None, start_date: str = None, end_date: str = None, query: str = None) -> dict[str, str] | str:
|
||||
def get_news(ticker: str = None, start_date: str = None, end_date: str = None, query: str = None) -> Union[Dict[str, str], str]:
|
||||
"""Returns live and historical market news & sentiment data.
|
||||
|
||||
Args:
|
||||
|
|
@ -28,7 +29,7 @@ def get_news(ticker: str = None, start_date: str = None, end_date: str = None, q
|
|||
return _make_api_request("NEWS_SENTIMENT", params)
|
||||
|
||||
|
||||
def get_global_news(date: str, look_back_days: int = 7, limit: int = 5) -> dict[str, str] | str:
|
||||
def get_global_news(date: str, look_back_days: int = 7, limit: int = 5) -> Union[Dict[str, str], str]:
|
||||
"""Returns global market news & sentiment data.
|
||||
|
||||
Args:
|
||||
|
|
@ -48,7 +49,7 @@ def get_global_news(date: str, look_back_days: int = 7, limit: int = 5) -> dict[
|
|||
|
||||
return _make_api_request("NEWS_SENTIMENT", params)
|
||||
|
||||
def get_insider_transactions(symbol: str = None, ticker: str = None, curr_date: str = None) -> dict[str, str] | str:
|
||||
def get_insider_transactions(symbol: str = None, ticker: str = None, curr_date: str = None) -> Union[Dict[str, str], str]:
|
||||
"""Returns latest and historical insider transactions.
|
||||
|
||||
Args:
|
||||
|
|
|
|||
|
|
@ -63,3 +63,157 @@ def get_recommendation_trends(
|
|||
|
||||
except Exception as e:
|
||||
return f"Error fetching recommendation trends for {ticker}: {str(e)}"
|
||||
|
||||
|
||||
def get_earnings_calendar(
    from_date: Annotated[str, "Start date in yyyy-mm-dd format"],
    to_date: Annotated[str, "End date in yyyy-mm-dd format"]
) -> str:
    """
    Get earnings calendar for stocks with upcoming earnings announcements.

    Args:
        from_date: Start date in yyyy-mm-dd format
        to_date: End date in yyyy-mm-dd format

    Returns:
        str: Formatted Markdown report of upcoming earnings, or an error
        message string if the Finnhub request fails.
    """
    def _value(entry, key):
        # Finnhub emits JSON null for unknown figures, so entry.get(key, 'N/A')
        # can still return None when the key is present; normalize to 'N/A' so
        # the report never renders the literal text "None".
        value = entry.get(key, 'N/A')
        return 'N/A' if value is None else value

    try:
        client = get_finnhub_client()
        data = client.earnings_calendar(
            _from=from_date,
            to=to_date,
            symbol="",  # Empty string returns all stocks
            international=False
        )

        if not data or 'earningsCalendar' not in data:
            return f"No earnings data found for period {from_date} to {to_date}"

        earnings = data['earningsCalendar']

        if not earnings:
            return f"No earnings scheduled between {from_date} and {to_date}"

        # Format the response
        result = f"## Earnings Calendar ({from_date} to {to_date})\n\n"
        result += f"**Total Companies**: {len(earnings)}\n\n"

        # Group entries by announcement date
        by_date = {}
        for entry in earnings:
            by_date.setdefault(entry.get('date', 'Unknown'), []).append(entry)

        # Emit one section per date, in chronological order
        for date in sorted(by_date.keys()):
            result += f"### {date}\n\n"

            for entry in by_date[date]:
                symbol = _value(entry, 'symbol')
                eps_estimate = _value(entry, 'epsEstimate')
                eps_actual = _value(entry, 'epsActual')
                revenue_estimate = _value(entry, 'revenueEstimate')
                revenue_actual = _value(entry, 'revenueActual')
                hour = _value(entry, 'hour')

                result += f"**{symbol}**"
                if hour != 'N/A':
                    # hour is the announcement timing tag (e.g. bmo/amc) —
                    # presumably Finnhub's session codes; TODO confirm.
                    result += f" ({hour})"
                result += "\n"

                if eps_estimate != 'N/A':
                    result += f" - EPS Estimate: ${eps_estimate:.2f}" if isinstance(eps_estimate, (int, float)) else f" - EPS Estimate: {eps_estimate}"
                if eps_actual != 'N/A':
                    result += f" | Actual: ${eps_actual:.2f}" if isinstance(eps_actual, (int, float)) else f" | Actual: {eps_actual}"
                result += "\n"

                if revenue_estimate != 'N/A':
                    result += f" - Revenue Estimate: ${revenue_estimate:,.0f}M" if isinstance(revenue_estimate, (int, float)) else f" - Revenue Estimate: {revenue_estimate}"
                if revenue_actual != 'N/A':
                    result += f" | Actual: ${revenue_actual:,.0f}M" if isinstance(revenue_actual, (int, float)) else f" | Actual: {revenue_actual}"
                result += "\n"

                # Blank line between companies
                result += "\n"

        return result

    except Exception as e:
        return f"Error fetching earnings calendar: {str(e)}"
|
||||
|
||||
|
||||
def get_ipo_calendar(
    from_date: Annotated[str, "Start date in yyyy-mm-dd format"],
    to_date: Annotated[str, "End date in yyyy-mm-dd format"]
) -> str:
    """
    Get IPO calendar for upcoming and recent initial public offerings.

    Args:
        from_date: Start date in yyyy-mm-dd format
        to_date: End date in yyyy-mm-dd format

    Returns:
        str: Formatted Markdown report of IPOs, or an error message string
        if the Finnhub request fails.
    """
    def _value(entry, key):
        # Finnhub emits JSON null for unknown fields, so entry.get(key, 'N/A')
        # can still return None when the key is present; normalize to 'N/A' so
        # the report never renders the literal text "None" (e.g. "Price: $None").
        value = entry.get(key, 'N/A')
        return 'N/A' if value is None else value

    try:
        client = get_finnhub_client()
        data = client.ipo_calendar(
            _from=from_date,
            to=to_date
        )

        if not data or 'ipoCalendar' not in data:
            return f"No IPO data found for period {from_date} to {to_date}"

        ipos = data['ipoCalendar']

        if not ipos:
            return f"No IPOs scheduled between {from_date} and {to_date}"

        # Format the response
        result = f"## IPO Calendar ({from_date} to {to_date})\n\n"
        result += f"**Total IPOs**: {len(ipos)}\n\n"

        # Group entries by listing date
        by_date = {}
        for entry in ipos:
            by_date.setdefault(entry.get('date', 'Unknown'), []).append(entry)

        # Emit one section per date, in chronological order
        for date in sorted(by_date.keys()):
            result += f"### {date}\n\n"

            for entry in by_date[date]:
                symbol = _value(entry, 'symbol')
                name = _value(entry, 'name')
                exchange = _value(entry, 'exchange')
                price = _value(entry, 'price')
                shares = _value(entry, 'numberOfShares')
                total_shares = _value(entry, 'totalSharesValue')
                status = _value(entry, 'status')

                result += f"**{symbol}** - {name}\n"
                result += f" - Exchange: {exchange}\n"

                if price != 'N/A':
                    result += f" - Price: ${price}\n"

                if shares != 'N/A':
                    result += f" - Shares Offered: {shares:,}\n" if isinstance(shares, (int, float)) else f" - Shares Offered: {shares}\n"

                if total_shares != 'N/A':
                    result += f" - Total Value: ${total_shares:,.0f}M\n" if isinstance(total_shares, (int, float)) else f" - Total Value: {total_shares}\n"

                result += f" - Status: {status}\n\n"

        return result

    except Exception as e:
        return f"Error fetching IPO calendar: {str(e)}"
|
||||
|
|
|
|||
|
|
@ -5,13 +5,23 @@ from .googlenews_utils import getNewsData
|
|||
|
||||
|
||||
def get_google_news(
|
||||
query: Annotated[str, "Query to search with"],
|
||||
query: Annotated[str, "Query to search with"] = None,
|
||||
ticker: Annotated[str, "Ticker symbol (alias for query)"] = None,
|
||||
curr_date: Annotated[str, "Curr date in yyyy-mm-dd format"] = None,
|
||||
look_back_days: Annotated[int, "how many days to look back"] = None,
|
||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"] = None,
|
||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"] = None,
|
||||
) -> str:
|
||||
query = query.replace(" ", "+")
|
||||
# Handle parameter aliasing (query or ticker)
|
||||
if query:
|
||||
search_query = query
|
||||
elif ticker:
|
||||
# Format ticker as a natural language query for better results
|
||||
search_query = f"latest news on {ticker} stock"
|
||||
else:
|
||||
raise ValueError("Must provide either 'query' or 'ticker' parameter")
|
||||
|
||||
search_query = search_query.replace(" ", "+")
|
||||
|
||||
# Determine date range
|
||||
if start_date and end_date:
|
||||
|
|
@ -24,7 +34,7 @@ def get_google_news(
|
|||
else:
|
||||
raise ValueError("Must provide either (start_date, end_date) or (curr_date, look_back_days)")
|
||||
|
||||
news_results = getNewsData(query, before, target_date)
|
||||
news_results = getNewsData(search_query, before, target_date)
|
||||
|
||||
news_str = ""
|
||||
|
||||
|
|
@ -36,7 +46,7 @@ def get_google_news(
|
|||
if len(news_results) == 0:
|
||||
return ""
|
||||
|
||||
return f"## {query} Google News, from {before} to {target_date}:\n\n{news_str}"
|
||||
return f"## {search_query} Google News, from {before} to {target_date}:\n\n{news_str}"
|
||||
|
||||
|
||||
def get_global_news_google(
|
||||
|
|
|
|||
|
|
@ -2,7 +2,24 @@ from openai import OpenAI
|
|||
from .config import get_config
|
||||
|
||||
|
||||
def get_stock_news_openai(query, start_date, end_date):
|
||||
def get_stock_news_openai(query=None, ticker=None, start_date=None, end_date=None):
|
||||
"""Get stock news from OpenAI web search.
|
||||
|
||||
Args:
|
||||
query: Search query or ticker symbol
|
||||
ticker: Ticker symbol (alias for query)
|
||||
start_date: Start date yyyy-mm-dd
|
||||
end_date: End date yyyy-mm-dd
|
||||
"""
|
||||
# Handle parameter aliasing
|
||||
if query:
|
||||
search_query = query
|
||||
elif ticker:
|
||||
# Format ticker as a natural language query for better results
|
||||
search_query = f"latest news and market analysis on {ticker} stock"
|
||||
else:
|
||||
raise ValueError("Must provide either 'query' or 'ticker' parameter")
|
||||
|
||||
config = get_config()
|
||||
client = OpenAI(base_url=config["backend_url"])
|
||||
|
||||
|
|
@ -10,7 +27,7 @@ def get_stock_news_openai(query, start_date, end_date):
|
|||
response = client.responses.create(
|
||||
model="gpt-4o-mini",
|
||||
tools=[{"type": "web_search_preview"}],
|
||||
input=f"Search Social Media for {query} from {start_date} to {end_date}. Make sure you only get the data posted during that period."
|
||||
input=f"Search Social Media and news sources for {search_query} from {start_date} to {end_date}. Make sure you only get the data posted during that period."
|
||||
)
|
||||
return response.output_text
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from tradingagents.agents.utils.agent_utils import (
|
|||
get_indicators
|
||||
)
|
||||
from tradingagents.tools.executor import execute_tool
|
||||
from tradingagents.schemas import TickerList, MarketMovers
|
||||
from tradingagents.schemas import TickerList, MarketMovers, ThemeList
|
||||
|
||||
class DiscoveryGraph:
|
||||
def __init__(self, config=None):
|
||||
|
|
@ -76,6 +76,66 @@ class DiscoveryGraph:
|
|||
|
||||
candidates = []
|
||||
|
||||
# 0. Macro Theme Discovery (Top-Down)
|
||||
try:
|
||||
from datetime import datetime
|
||||
today = datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
# Get Global News
|
||||
global_news = execute_tool("get_global_news", date=today, limit=5)
|
||||
|
||||
# Extract Themes
|
||||
prompt = f"""Based on this global news, identify 3 trending market themes or sectors (e.g., 'Artificial Intelligence', 'Oil', 'Biotech').
|
||||
Return a JSON object with a 'themes' array of strings.
|
||||
|
||||
News:
|
||||
{global_news}
|
||||
"""
|
||||
|
||||
structured_llm = self.quick_thinking_llm.with_structured_output(
|
||||
schema=ThemeList.model_json_schema(),
|
||||
method="json_schema"
|
||||
)
|
||||
response = structured_llm.invoke([HumanMessage(content=prompt)])
|
||||
themes = response.get("themes", [])
|
||||
|
||||
print(f" Identified Macro Themes: {themes}")
|
||||
|
||||
# Find tickers for each theme
|
||||
for theme in themes:
|
||||
try:
|
||||
tweets_report = execute_tool("get_tweets", query=f"{theme} stocks", count=15)
|
||||
|
||||
prompt = f"""Extract ONLY valid stock ticker symbols related to the theme '{theme}' from this report.
|
||||
Return a comma-separated list of tickers (1-5 uppercase letters).
|
||||
|
||||
Report:
|
||||
{tweets_report}
|
||||
|
||||
Return a JSON object with a 'tickers' array."""
|
||||
|
||||
structured_llm = self.quick_thinking_llm.with_structured_output(
|
||||
schema=TickerList.model_json_schema(),
|
||||
method="json_schema"
|
||||
)
|
||||
response = structured_llm.invoke([HumanMessage(content=prompt)])
|
||||
theme_tickers = response.get("tickers", [])
|
||||
|
||||
for t in theme_tickers:
|
||||
t = t.upper().strip()
|
||||
if re.match(r'^[A-Z]{1,5}$', t):
|
||||
# Use validate_ticker tool logic (via execute_tool)
|
||||
try:
|
||||
if execute_tool("validate_ticker", symbol=t):
|
||||
candidates.append({"ticker": t, "source": f"macro_theme_{theme}", "sentiment": "unknown"})
|
||||
except Exception:
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f" Error fetching tickers for theme {theme}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
print(f" Error in Macro Theme Discovery: {e}")
|
||||
|
||||
# 1. Get Reddit Trending (Social Sentiment)
|
||||
try:
|
||||
reddit_report = execute_tool("get_trending_tickers", limit=self.reddit_trending_limit)
|
||||
|
|
@ -188,7 +248,73 @@ Data:
|
|||
|
||||
except Exception as e:
|
||||
print(f" Error fetching Market Movers: {e}")
|
||||
|
||||
|
||||
# 3. Get Earnings Calendar (Event-based Discovery)
|
||||
try:
|
||||
from datetime import datetime, timedelta
|
||||
today = datetime.now()
|
||||
from_date = today.strftime("%Y-%m-%d")
|
||||
to_date = (today + timedelta(days=7)).strftime("%Y-%m-%d") # Next 7 days
|
||||
|
||||
earnings_report = execute_tool("get_earnings_calendar", from_date=from_date, to_date=to_date)
|
||||
|
||||
# Extract tickers from earnings calendar
|
||||
prompt = """Extract ONLY valid stock ticker symbols from this earnings calendar.
|
||||
Return a comma-separated list of tickers (1-5 uppercase letters).
|
||||
Only include actual stock tickers, not indexes or other symbols.
|
||||
|
||||
Earnings Calendar:
|
||||
{report}
|
||||
|
||||
Return a JSON object with a 'tickers' array containing only valid stock ticker symbols.""".format(report=earnings_report)
|
||||
|
||||
structured_llm = self.quick_thinking_llm.with_structured_output(
|
||||
schema=TickerList.model_json_schema(),
|
||||
method="json_schema"
|
||||
)
|
||||
response = structured_llm.invoke([HumanMessage(content=prompt)])
|
||||
|
||||
earnings_tickers = response.get("tickers", [])
|
||||
for t in earnings_tickers:
|
||||
t = t.upper().strip()
|
||||
if re.match(r'^[A-Z]{1,5}$', t):
|
||||
candidates.append({"ticker": t, "source": "earnings_catalyst", "sentiment": "unknown"})
|
||||
except Exception as e:
|
||||
print(f" Error fetching Earnings Calendar: {e}")
|
||||
|
||||
# 4. Get IPO Calendar (New Listings Discovery)
|
||||
try:
|
||||
from datetime import datetime, timedelta
|
||||
today = datetime.now()
|
||||
from_date = (today - timedelta(days=7)).strftime("%Y-%m-%d") # Past 7 days
|
||||
to_date = (today + timedelta(days=14)).strftime("%Y-%m-%d") # Next 14 days
|
||||
|
||||
ipo_report = execute_tool("get_ipo_calendar", from_date=from_date, to_date=to_date)
|
||||
|
||||
# Extract tickers from IPO calendar
|
||||
prompt = """Extract ONLY valid stock ticker symbols from this IPO calendar.
|
||||
Return a comma-separated list of tickers (1-5 uppercase letters).
|
||||
Only include actual stock tickers that are listed or about to be listed.
|
||||
|
||||
IPO Calendar:
|
||||
{report}
|
||||
|
||||
Return a JSON object with a 'tickers' array containing only valid stock ticker symbols.""".format(report=ipo_report)
|
||||
|
||||
structured_llm = self.quick_thinking_llm.with_structured_output(
|
||||
schema=TickerList.model_json_schema(),
|
||||
method="json_schema"
|
||||
)
|
||||
response = structured_llm.invoke([HumanMessage(content=prompt)])
|
||||
|
||||
ipo_tickers = response.get("tickers", [])
|
||||
for t in ipo_tickers:
|
||||
t = t.upper().strip()
|
||||
if re.match(r'^[A-Z]{1,5}$', t):
|
||||
candidates.append({"ticker": t, "source": "ipo_listing", "sentiment": "unknown"})
|
||||
except Exception as e:
|
||||
print(f" Error fetching IPO Calendar: {e}")
|
||||
|
||||
# Deduplicate
|
||||
unique_candidates = {}
|
||||
for c in candidates:
|
||||
|
|
@ -232,8 +358,30 @@ Data:
|
|||
strategy = "contrarian_value"
|
||||
elif source == "social_trending" or source == "twitter_sentiment":
|
||||
strategy = "social_hype"
|
||||
elif source == "earnings_catalyst":
|
||||
strategy = "earnings_play"
|
||||
elif source == "ipo_listing":
|
||||
strategy = "ipo_opportunity"
|
||||
|
||||
cand['strategy'] = strategy
|
||||
|
||||
# Technical Analysis Check (New)
|
||||
try:
|
||||
from datetime import datetime
|
||||
today = datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
# Get RSI
|
||||
rsi_data = execute_tool("get_indicators", symbol=ticker, indicator="rsi", curr_date=today, look_back_days=14)
|
||||
|
||||
# Simple parsing of the string report to find the latest value
|
||||
# The report format is usually "## rsi values...\n\nDATE: VALUE"
|
||||
# We'll just store the report for the LLM to analyze in deep dive if needed,
|
||||
# OR we can try to parse it here. For now, let's just add it to metadata.
|
||||
cand['technical_indicators'] = rsi_data
|
||||
|
||||
except Exception as e:
|
||||
print(f" Error getting technicals for {ticker}: {e}")
|
||||
|
||||
filtered_candidates.append(cand)
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -276,8 +424,9 @@ Data:
|
|||
# 1. Get News Sentiment
|
||||
news = execute_tool("get_news", ticker=ticker, start_date=start_date, end_date=end_date)
|
||||
|
||||
# 2. Get Insider Transactions
|
||||
# 2. Get Insider Transactions & Sentiment
|
||||
insider = execute_tool("get_insider_transactions", ticker=ticker)
|
||||
insider_sentiment = execute_tool("get_insider_sentiment", ticker=ticker)
|
||||
|
||||
# 3. Get Fundamentals (for the Contrarian check)
|
||||
fundamentals = execute_tool("get_fundamentals", ticker=ticker, curr_date=end_date)
|
||||
|
|
@ -290,6 +439,7 @@ Data:
|
|||
"strategy": strategy,
|
||||
"news": news,
|
||||
"insider_transactions": insider,
|
||||
"insider_sentiment": insider_sentiment,
|
||||
"fundamentals": fundamentals,
|
||||
"recommendations": recommendations
|
||||
})
|
||||
|
|
@ -313,6 +463,7 @@ Data:
|
|||
"strategy": opp["strategy"],
|
||||
# Truncate to ~1000 chars each (roughly 250 tokens)
|
||||
"news": opp["news"][:1000] + "..." if len(opp["news"]) > 1000 else opp["news"],
|
||||
"insider_sentiment": opp.get("insider_sentiment", "")[:500],
|
||||
"insider_transactions": opp["insider_transactions"][:1000] + "..." if len(opp["insider_transactions"]) > 1000 else opp["insider_transactions"],
|
||||
"fundamentals": opp["fundamentals"][:1000] + "..." if len(opp["fundamentals"]) > 1000 else opp["fundamentals"],
|
||||
"recommendations": opp["recommendations"][:1000] + "..." if len(opp["recommendations"]) > 1000 else opp["recommendations"],
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
from .llm_outputs import (
|
||||
TradeDecision,
|
||||
TickerList,
|
||||
ThemeList,
|
||||
MarketMover,
|
||||
MarketMovers,
|
||||
InvestmentOpportunity,
|
||||
|
|
@ -14,6 +15,7 @@ from .llm_outputs import (
|
|||
__all__ = [
|
||||
"TradeDecision",
|
||||
"TickerList",
|
||||
"ThemeList",
|
||||
"MarketMovers",
|
||||
"MarketMover",
|
||||
"InvestmentOpportunity",
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ class TradeDecision(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
|
||||
class TickerList(BaseModel):
|
||||
"""Structured output for ticker symbol lists."""
|
||||
|
||||
|
|
@ -34,6 +35,14 @@ class TickerList(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class ThemeList(BaseModel):
    """Structured output for market themes."""

    # Trending macro themes/sectors extracted by the LLM from global news,
    # e.g. "Artificial Intelligence", "Oil", "Biotech".
    themes: List[str] = Field(
        description="List of trending market themes or sectors"
    )
|
||||
|
||||
|
||||
class MarketMover(BaseModel):
|
||||
"""Individual market mover entry."""
|
||||
|
||||
|
|
|
|||
|
|
@ -6,14 +6,16 @@ registry-based approach. All routing decisions are driven by the tool registry.
|
|||
|
||||
Key improvements over old system:
|
||||
- Single registry lookup instead of multiple dictionary lookups
|
||||
- Clear fallback logic per tool (optional, not mandatory)
|
||||
- Supports both fallback and aggregate execution modes
|
||||
- Parallel vendor execution for aggregate mode
|
||||
- Better error messages and debugging
|
||||
- No dual registry systems
|
||||
"""
|
||||
|
||||
from typing import Any, Optional, List
|
||||
from typing import Any, Optional, List, Dict
|
||||
import logging
|
||||
from tradingagents.tools.registry import TOOL_REGISTRY, get_vendor_config
|
||||
import concurrent.futures
|
||||
from tradingagents.tools.registry import TOOL_REGISTRY, get_vendor_config, get_tool_metadata
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -28,43 +30,30 @@ class VendorNotFoundError(Exception):
|
|||
pass
|
||||
|
||||
|
||||
def execute_tool(tool_name: str, *args, **kwargs) -> Any:
|
||||
"""Execute a tool using the registry-based routing system.
|
||||
def _execute_fallback(tool_name: str, vendor_config: Dict, *args, **kwargs) -> Any:
|
||||
"""Execute vendors sequentially with fallback (original behavior).
|
||||
|
||||
This is the simplified replacement for route_to_vendor().
|
||||
Tries vendors in priority order and returns the first successful result.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool to execute (e.g., "get_stock_data")
|
||||
*args: Positional arguments to pass to the tool
|
||||
**kwargs: Keyword arguments to pass to the tool
|
||||
tool_name: Name of the tool
|
||||
vendor_config: Vendor configuration from registry
|
||||
*args: Positional arguments for vendor function
|
||||
**kwargs: Keyword arguments for vendor function
|
||||
|
||||
Returns:
|
||||
Result from the vendor function
|
||||
Result from first successful vendor
|
||||
|
||||
Raises:
|
||||
VendorNotFoundError: If tool or vendor implementation not found
|
||||
ToolExecutionError: If all vendors fail to execute the tool
|
||||
ToolExecutionError: If all vendors fail
|
||||
"""
|
||||
|
||||
# Step 1: Get vendor configuration from registry
|
||||
vendor_config = get_vendor_config(tool_name)
|
||||
|
||||
if not vendor_config["vendor_priority"]:
|
||||
raise VendorNotFoundError(
|
||||
f"Tool '{tool_name}' not found in registry or has no vendors configured"
|
||||
)
|
||||
|
||||
# Step 2: Get vendor functions and priority list
|
||||
vendor_functions = vendor_config["vendors"]
|
||||
vendors_to_try = vendor_config["vendor_priority"]
|
||||
|
||||
logger.debug(f"Executing tool '{tool_name}' with vendors: {vendors_to_try}")
|
||||
|
||||
# Step 3: Try each vendor in priority order
|
||||
errors = []
|
||||
|
||||
logger.debug(f"Executing tool '{tool_name}' in fallback mode with vendors: {vendors_to_try}")
|
||||
|
||||
for vendor_name in vendors_to_try:
|
||||
# Get the vendor function directly from registry
|
||||
vendor_func = vendor_functions.get(vendor_name)
|
||||
|
||||
if not vendor_func:
|
||||
|
|
@ -80,16 +69,131 @@ def execute_tool(tool_name: str, *args, **kwargs) -> Any:
|
|||
error_msg = f"Vendor '{vendor_name}' failed: {str(e)}"
|
||||
logger.warning(f"Tool '{tool_name}': {error_msg}")
|
||||
errors.append(error_msg)
|
||||
|
||||
# Continue to next vendor (fallback)
|
||||
continue
|
||||
|
||||
# Step 4: All vendors failed
|
||||
# All vendors failed
|
||||
error_summary = f"Tool '{tool_name}' failed with all vendors:\n" + "\n".join(f" - {err}" for err in errors)
|
||||
logger.error(error_summary)
|
||||
raise ToolExecutionError(error_summary)
|
||||
|
||||
|
||||
def _execute_aggregate(tool_name: str, vendor_config: Dict, metadata: Dict, *args, **kwargs) -> str:
    """Execute multiple vendors in parallel and aggregate results.

    Executes all specified vendors simultaneously using ThreadPoolExecutor,
    collects successful results, and combines them with vendor labels.

    Args:
        tool_name: Name of the tool
        vendor_config: Vendor configuration from registry
        metadata: Tool metadata from registry
        *args: Positional arguments for vendor functions
        **kwargs: Keyword arguments for vendor functions

    Returns:
        Aggregated results from all successful vendors, formatted with labels

    Raises:
        ToolExecutionError: If all vendors fail or none can be resolved
    """
    vendor_functions = vendor_config["vendors"]

    # Get list of vendors to aggregate (default to all in priority list)
    vendors_to_aggregate = metadata.get("aggregate_vendors") or vendor_config["vendor_priority"]

    logger.debug(f"Executing tool '{tool_name}' in aggregate mode with vendors: {vendors_to_aggregate}")

    results = []
    errors = []

    # Execute vendors in parallel using ThreadPoolExecutor
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(vendors_to_aggregate)) as executor:
        # Submit all vendor calls
        future_to_vendor = {}
        for vendor_name in vendors_to_aggregate:
            vendor_func = vendor_functions.get(vendor_name)
            if vendor_func:
                future = executor.submit(vendor_func, *args, **kwargs)
                future_to_vendor[future] = vendor_name
            else:
                # Record the miss as an error as well as logging it; otherwise,
                # if every vendor name were unresolvable, the "all vendors
                # failed" ToolExecutionError below would carry an empty list.
                errors.append(f"Vendor '{vendor_name}' not found in vendors dict")
                logger.warning(f"Vendor '{vendor_name}' not found in vendors dict for tool '{tool_name}'")

        # Collect results as they complete
        for future in concurrent.futures.as_completed(future_to_vendor):
            vendor_name = future_to_vendor[future]
            try:
                result = future.result()
                results.append({
                    "vendor": vendor_name,
                    "data": result
                })
                logger.debug(f"Tool '{tool_name}': vendor '{vendor_name}' succeeded")
            except Exception as e:
                error_msg = f"Vendor '{vendor_name}' failed: {str(e)}"
                errors.append(error_msg)
                logger.warning(f"Tool '{tool_name}': {error_msg}")

    # Check if we got any results
    if not results:
        error_summary = f"Tool '{tool_name}' aggregate mode: all vendors failed:\n" + "\n".join(f" - {err}" for err in errors)
        logger.error(error_summary)
        raise ToolExecutionError(error_summary)

    # Format aggregated results with clear vendor labels
    formatted_results = []
    for item in results:
        vendor_label = f"=== {item['vendor'].upper()} ==="
        formatted_results.append(f"{vendor_label}\n{item['data']}")

    # Log partial success if some vendors failed
    if errors:
        logger.info(f"Tool '{tool_name}': {len(results)} vendors succeeded, {len(errors)} failed")

    return "\n\n".join(formatted_results)
|
||||
|
||||
|
||||
def execute_tool(tool_name: str, *args, **kwargs) -> Any:
    """Execute a tool using fallback or aggregate mode based on configuration.

    Main entry point for tool execution. Looks the tool up in the registry,
    validates its configuration, then dispatches to aggregate mode (parallel
    vendors, combined output) or fallback mode (sequential vendors, first
    success wins) according to the tool's execution_mode setting.

    Args:
        tool_name: Name of the tool to execute (e.g., "get_stock_data")
        *args: Positional arguments to pass to the tool
        **kwargs: Keyword arguments to pass to the tool

    Returns:
        Result from vendor function(s). String for aggregate mode (formatted
        with vendor labels), Any for fallback mode (raw vendor result).

    Raises:
        VendorNotFoundError: If tool or vendor implementation not found
        ToolExecutionError: If all vendors fail to execute the tool
    """
    vendor_config = get_vendor_config(tool_name)
    metadata = get_tool_metadata(tool_name)

    # Guard clauses: the registry must know both the vendors and the metadata.
    if not vendor_config["vendor_priority"]:
        raise VendorNotFoundError(
            f"Tool '{tool_name}' not found in registry or has no vendors configured"
        )
    if not metadata:
        raise VendorNotFoundError(f"Tool '{tool_name}' metadata not found in registry")

    # A missing or unrecognized execution_mode means sequential fallback,
    # which preserves the pre-aggregate behavior for existing tools.
    if metadata.get("execution_mode", "fallback") == "aggregate":
        return _execute_aggregate(tool_name, vendor_config, metadata, *args, **kwargs)
    return _execute_fallback(tool_name, vendor_config, *args, **kwargs)
|
||||
|
||||
|
||||
def get_tool_info(tool_name: str) -> Optional[dict]:
|
||||
"""Get information about a tool from the registry.
|
||||
|
||||
|
|
|
|||
|
|
@ -54,6 +54,10 @@ def generate_langchain_tool(tool_name: str, metadata: Dict[str, Any]) -> Callabl
|
|||
# Use **kwargs to handle all parameters
|
||||
def tool_function(**kwargs):
|
||||
"""Dynamically generated tool function."""
|
||||
# Ensure defaults are applied for missing parameters
|
||||
for param_name, param_info in parameters.items():
|
||||
if param_name not in kwargs and "default" in param_info:
|
||||
kwargs[param_name] = param_info["default"]
|
||||
return execute_tool(tool_name, **kwargs)
|
||||
|
||||
# Set function metadata
|
||||
|
|
|
|||
|
|
@ -53,6 +53,8 @@ from tradingagents.dataflows.reddit_api import (
|
|||
)
|
||||
from tradingagents.dataflows.finnhub_api import (
|
||||
get_recommendation_trends as get_finnhub_recommendation_trends,
|
||||
get_earnings_calendar as get_finnhub_earnings_calendar,
|
||||
get_ipo_calendar as get_finnhub_ipo_calendar,
|
||||
)
|
||||
from tradingagents.dataflows.twitter_data import (
|
||||
get_tweets as get_twitter_tweets,
|
||||
|
|
@ -109,6 +111,8 @@ TOOL_REGISTRY: Dict[str, Dict[str, Any]] = {
|
|||
"alpha_vantage": get_alpha_vantage_indicator,
|
||||
},
|
||||
"vendor_priority": ["yfinance", "alpha_vantage"],
|
||||
"execution_mode": "aggregate",
|
||||
"aggregate_vendors": ["yfinance", "alpha_vantage"],
|
||||
"parameters": {
|
||||
"symbol": {"type": "str", "description": "Ticker symbol"},
|
||||
"indicator": {"type": "str", "description": "Technical indicator (rsi, macd, sma, ema, etc.)"},
|
||||
|
|
@ -208,6 +212,8 @@ TOOL_REGISTRY: Dict[str, Dict[str, Any]] = {
|
|||
"google": get_google_news,
|
||||
},
|
||||
"vendor_priority": ["alpha_vantage", "reddit", "openai", "google"],
|
||||
"execution_mode": "aggregate",
|
||||
"aggregate_vendors": ["alpha_vantage", "reddit", "google"],
|
||||
"parameters": {
|
||||
"query": {"type": "str", "description": "Search query or ticker symbol"},
|
||||
"start_date": {"type": "str", "description": "Start date, yyyy-mm-dd"},
|
||||
|
|
@ -227,6 +233,7 @@ TOOL_REGISTRY: Dict[str, Dict[str, Any]] = {
|
|||
"alpha_vantage": get_alpha_vantage_global_news,
|
||||
},
|
||||
"vendor_priority": ["openai", "google", "reddit", "alpha_vantage"],
|
||||
"execution_mode": "aggregate",
|
||||
"parameters": {
|
||||
"date": {"type": "str", "description": "Date for news, yyyy-mm-dd"},
|
||||
"look_back_days": {"type": "int", "description": "Days to look back", "default": 7},
|
||||
|
|
@ -310,6 +317,36 @@ TOOL_REGISTRY: Dict[str, Dict[str, Any]] = {
|
|||
"returns": "str: Tweets matching the query",
|
||||
},
|
||||
|
||||
"get_earnings_calendar": {
|
||||
"description": "Get upcoming earnings announcements (catalysts for volatility)",
|
||||
"category": "discovery",
|
||||
"agents": [],
|
||||
"vendors": {
|
||||
"finnhub": get_finnhub_earnings_calendar,
|
||||
},
|
||||
"vendor_priority": ["finnhub"],
|
||||
"parameters": {
|
||||
"from_date": {"type": "str", "description": "Start date in yyyy-mm-dd format"},
|
||||
"to_date": {"type": "str", "description": "End date in yyyy-mm-dd format"},
|
||||
},
|
||||
"returns": "str: Formatted earnings calendar with EPS and revenue estimates",
|
||||
},
|
||||
|
||||
"get_ipo_calendar": {
|
||||
"description": "Get upcoming and recent IPOs (new listing opportunities)",
|
||||
"category": "discovery",
|
||||
"agents": [],
|
||||
"vendors": {
|
||||
"finnhub": get_finnhub_ipo_calendar,
|
||||
},
|
||||
"vendor_priority": ["finnhub"],
|
||||
"parameters": {
|
||||
"from_date": {"type": "str", "description": "Start date in yyyy-mm-dd format"},
|
||||
"to_date": {"type": "str", "description": "End date in yyyy-mm-dd format"},
|
||||
},
|
||||
"returns": "str: Formatted IPO calendar with pricing and share details",
|
||||
},
|
||||
|
||||
"get_reddit_discussions": {
|
||||
"description": "Get Reddit discussions about a specific ticker",
|
||||
"category": "news_data",
|
||||
|
|
@ -465,6 +502,26 @@ def validate_registry() -> List[str]:
|
|||
if not isinstance(metadata.get("parameters"), dict):
|
||||
issues.append(f"{tool_name}: Parameters must be a dictionary")
|
||||
|
||||
# Validate execution_mode if present
|
||||
if "execution_mode" in metadata:
|
||||
execution_mode = metadata["execution_mode"]
|
||||
if execution_mode not in ["fallback", "aggregate"]:
|
||||
issues.append(f"{tool_name}: Invalid execution_mode '{execution_mode}', must be 'fallback' or 'aggregate'")
|
||||
|
||||
# Validate aggregate_vendors if present
|
||||
if "aggregate_vendors" in metadata:
|
||||
aggregate_vendors = metadata["aggregate_vendors"]
|
||||
if not isinstance(aggregate_vendors, list):
|
||||
issues.append(f"{tool_name}: aggregate_vendors must be a list")
|
||||
else:
|
||||
for vendor_name in aggregate_vendors:
|
||||
if vendor_name not in vendors:
|
||||
issues.append(f"{tool_name}: aggregate_vendor '{vendor_name}' not in vendors dict")
|
||||
|
||||
# Warn if aggregate_vendors specified but execution_mode is not aggregate
|
||||
if metadata.get("execution_mode") != "aggregate":
|
||||
issues.append(f"{tool_name}: aggregate_vendors specified but execution_mode is not 'aggregate'")
|
||||
|
||||
return issues
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue