# TradingAgents/tradingagents/dataflows/reddit_utils.py
"""
Reddit API integration for social sentiment analysis.
"""
import logging
import re
from datetime import datetime, timedelta
from typing import Any
import praw
# from .api_clients import RateLimiter
logger = logging.getLogger(__name__)
# Extended ticker -> company-name mapping used to build Reddit search
# queries and mention filters. Values containing " OR " list alternate
# names for the same company; RedditClient.filter_posts_by_company
# splits on the "OR" to derive individual search terms.
ticker_to_company: dict[str, str] = {
    "AAPL": "Apple",
    "MSFT": "Microsoft",
    "GOOGL": "Google",
    "AMZN": "Amazon",
    "TSLA": "Tesla",
    "NVDA": "Nvidia",
    "TSM": "Taiwan Semiconductor Manufacturing Company OR TSMC",
    "JPM": "JPMorgan Chase OR JP Morgan",
    "JNJ": "Johnson & Johnson OR JNJ",
    "V": "Visa",
    "WMT": "Walmart",
    "META": "Meta OR Facebook",
    "AMD": "AMD",
    "INTC": "Intel",
    "QCOM": "Qualcomm",
    "BABA": "Alibaba",
    "ADBE": "Adobe",
    "NFLX": "Netflix",
    "CRM": "Salesforce",
    "PYPL": "PayPal",
    "PLTR": "Palantir",
    "MU": "Micron",
    "SQ": "Block OR Square",
    "ZM": "Zoom",
    "CSCO": "Cisco",
    "SHOP": "Shopify",
    "ORCL": "Oracle",
    "X": "Twitter OR X",
    "SPOT": "Spotify",
    "AVGO": "Broadcom",
    "ASML": "ASML",
    "TWLO": "Twilio",
    "SNAP": "Snap Inc.",
    "TEAM": "Atlassian",
    "SQSP": "Squarespace",
    "UBER": "Uber",
    "ROKU": "Roku",
    "PINS": "Pinterest",
    # Additional popular tickers
    "SPY": "SPDR S&P 500 ETF",
    "QQQ": "Invesco QQQ Trust",
    "GME": "GameStop",
    "AMC": "AMC Entertainment",
    "BB": "BlackBerry",
    "NOK": "Nokia",
    "COIN": "Coinbase",
    "HOOD": "Robinhood",
    "RBLX": "Roblox",
    "DKNG": "DraftKings",
    "PENN": "Penn Entertainment",
    "SOFI": "SoFi Technologies",
}
class RedditClient:
    """Client for the Reddit API (via PRAW) used for social sentiment data.

    Provides subreddit search, top-post retrieval, and post filtering by
    date range or company mention. Rate limiting is currently disabled
    (the RateLimiter calls are commented out) and no caching is
    implemented.
    """

    def __init__(self, client_id: str, client_secret: str, user_agent: str):
        """
        Initialize Reddit client.

        Args:
            client_id: Reddit application client ID
            client_secret: Reddit application client secret
            user_agent: User agent string for API requests
        """
        self.client_id = client_id
        self.client_secret = client_secret
        self.user_agent = user_agent

        # Reddit allows 100 requests per minute for script applications
        # self.rate_limiter = RateLimiter(100, 60)

        # No username/password supplied, so this instance is effectively
        # read-only.
        self.reddit = praw.Reddit(
            client_id=client_id, client_secret=client_secret, user_agent=user_agent
        )

        # Default subreddits for each content category.
        self.subreddits = {
            "investing": ["investing", "SecurityAnalysis", "ValueInvesting", "stocks"],
            "trading": ["wallstreetbets", "StockMarket", "pennystocks", "options"],
            "global_news": ["news", "worldnews", "Economics", "business"],
            "company_news": ["investing", "stocks", "StockMarket", "SecurityAnalysis"],
        }

    def test_connection(self) -> bool:
        """Test if the Reddit API connection is working.

        Returns:
            True when the credentials are accepted, False otherwise.
        """
        try:
            # Read-only probe; raises if the API rejects the credentials.
            _ = self.reddit.user.me()
            return True
        except Exception as e:
            logger.error(f"Reddit connection test failed: {e}")
            return False

    @staticmethod
    def _submission_to_dict(submission: Any, subreddit_name: str) -> dict[str, Any]:
        """Convert a PRAW submission into the plain dict used downstream.

        Shared by search_posts and get_top_posts so the post schema is
        defined in exactly one place.
        """
        return {
            "title": submission.title,
            "content": submission.selftext,
            "url": submission.url,
            "upvotes": submission.ups,
            "score": submission.score,
            "num_comments": submission.num_comments,
            "created_utc": submission.created_utc,
            "subreddit": subreddit_name,
            "author": str(submission.author) if submission.author else "[deleted]",
        }

    def search_posts(
        self,
        query: str,
        subreddit_names: list[str],
        limit: int = 25,
        time_filter: str = "week",
    ) -> list[dict[str, Any]]:
        """
        Search for posts across multiple subreddits.

        Args:
            query: Search query
            subreddit_names: List of subreddit names to search
            limit: Maximum number of posts per subreddit
            time_filter: Time filter ('day', 'week', 'month', 'year', 'all')

        Returns:
            List of post dictionaries sorted by score, descending.
        """
        posts: list[dict[str, Any]] = []
        for subreddit_name in subreddit_names:
            try:
                # self.rate_limiter.wait_if_needed()
                subreddit = self.reddit.subreddit(subreddit_name)
                search_results = subreddit.search(
                    query=query, sort="relevance", time_filter=time_filter, limit=limit
                )
                posts.extend(
                    self._submission_to_dict(s, subreddit_name) for s in search_results
                )
            except Exception as e:
                # One failing subreddit must not abort the whole search.
                logger.error(f"Error searching subreddit {subreddit_name}: {e}")
                continue

        # Sort by score (upvotes - downvotes) descending
        posts.sort(key=lambda x: x["score"], reverse=True)
        return posts

    def get_top_posts(
        self, subreddit_names: list[str], limit: int = 25, time_filter: str = "week"
    ) -> list[dict[str, Any]]:
        """
        Get top posts from multiple subreddits.

        Args:
            subreddit_names: List of subreddit names
            limit: Maximum number of posts per subreddit
            time_filter: Time filter ('day', 'week', 'month', 'year', 'all')

        Returns:
            List of post dictionaries sorted by score, descending.
        """
        posts: list[dict[str, Any]] = []
        for subreddit_name in subreddit_names:
            try:
                # self.rate_limiter.wait_if_needed()
                subreddit = self.reddit.subreddit(subreddit_name)
                top_posts = subreddit.top(time_filter=time_filter, limit=limit)
                posts.extend(
                    self._submission_to_dict(s, subreddit_name) for s in top_posts
                )
            except Exception as e:
                # One failing subreddit must not abort the whole fetch.
                logger.error(f"Error fetching top posts from {subreddit_name}: {e}")
                continue

        # Sort by score descending
        posts.sort(key=lambda x: x["score"], reverse=True)
        return posts

    def filter_posts_by_date(
        self, posts: list[dict[str, Any]], start_date: str, end_date: str
    ) -> list[dict[str, Any]]:
        """
        Filter posts by date range (both endpoints inclusive).

        Args:
            posts: List of post dictionaries (must contain 'created_utc')
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format

        Returns:
            Filtered list of posts; each kept post gains a 'posted_date'
            key holding its YYYY-MM-DD creation date.
        """
        start_dt = datetime.strptime(start_date, "%Y-%m-%d")
        # Half-open upper bound: strictly before midnight of the day after
        # end_date, so the whole end date is included. (The previous `<=`
        # comparison also admitted posts made exactly at that midnight —
        # an off-by-one instant belonging to the next day.)
        end_dt = datetime.strptime(end_date, "%Y-%m-%d") + timedelta(days=1)

        filtered_posts = []
        for post in posts:
            # NOTE(review): fromtimestamp converts the UTC epoch seconds to
            # *local* time, so date boundaries are local — confirm intended.
            post_dt = datetime.fromtimestamp(post["created_utc"])
            if start_dt <= post_dt < end_dt:
                # Add formatted date string for downstream consumers.
                post["posted_date"] = post_dt.strftime("%Y-%m-%d")
                filtered_posts.append(post)
        return filtered_posts

    def filter_posts_by_company(
        self, posts: list[dict[str, Any]], ticker: str
    ) -> list[dict[str, Any]]:
        """
        Filter posts that mention a specific company/ticker.

        Args:
            posts: List of post dictionaries
            ticker: Stock ticker symbol

        Returns:
            Filtered list of posts whose title or content mentions the
            company name or the ticker as a whole word.
        """
        if ticker not in ticker_to_company:
            # Unknown ticker: fall back to matching the symbol itself.
            search_terms = [ticker.upper()]
        else:
            company_names = ticker_to_company[ticker]
            if "OR" in company_names:
                # Mapping values encode alternates as "Name OR Alias";
                # split on the full " OR " delimiter (the previous
                # `split(" OR")` left a leading space for strip() to fix).
                search_terms = [name.strip() for name in company_names.split(" OR ")]
            else:
                search_terms = [company_names]
            search_terms.append(ticker.upper())

        # Compile one alternation once instead of re-compiling per post.
        # NOTE(review): \b after a term ending in punctuation (e.g.
        # "Snap Inc.") requires a word character right after the period,
        # so sentence-final mentions are missed — pre-existing behavior.
        pattern = re.compile(
            r"\b(?:"
            + "|".join(re.escape(term.lower()) for term in search_terms)
            + r")\b"
        )

        return [
            post
            for post in posts
            if pattern.search(post["title"].lower())
            or pattern.search(post["content"].lower())
        ]
def fetch_top_from_category(
category: str,
date: str,
max_limit: int,
query: str | None = None,
data_path: str = "reddit_data",
client_id: str | None = None,
client_secret: str | None = None,
user_agent: str | None = None,
) -> list[dict[str, Any]]:
"""
Legacy function to maintain backward compatibility.
Now uses live Reddit API if credentials are provided.
Args:
category: Category ('global_news', 'company_news', etc.)
date: Date in YYYY-MM-DD format
max_limit: Maximum number of posts
query: Optional search query (ticker for company_news)
data_path: Unused in API mode
client_id: Reddit client ID
client_secret: Reddit client secret
user_agent: Reddit user agent
Returns:
List of post dictionaries
"""
if not all([client_id, client_secret, user_agent]):
logger.warning("Reddit API credentials not provided. Returning empty data.")
return []
try:
# Type check ensures these are not None
assert client_id is not None
assert client_secret is not None
assert user_agent is not None
client = RedditClient(client_id, client_secret, user_agent)
# Determine subreddits based on category
if category == "global_news":
subreddit_names = client.subreddits["global_news"]
elif category == "company_news":
subreddit_names = client.subreddits["company_news"]
else:
# Default to investing subreddits
subreddit_names = client.subreddits["investing"]
# Calculate time filter based on date (Reddit doesn't support exact date filtering)
post_date = datetime.strptime(date, "%Y-%m-%d")
days_ago = (datetime.now() - post_date).days
if days_ago <= 1:
time_filter = "day"
elif days_ago <= 7:
time_filter = "week"
elif days_ago <= 30:
time_filter = "month"
else:
time_filter = "year"
# Get posts
if query and category == "company_news":
# Search for specific company
posts = client.search_posts(
query=query,
subreddit_names=subreddit_names,
limit=max_limit // len(subreddit_names),
time_filter=time_filter,
)
# Filter by company mentions
posts = client.filter_posts_by_company(posts, query)
else:
# Get top posts
posts = client.get_top_posts(
subreddit_names=subreddit_names,
limit=max_limit // len(subreddit_names),
time_filter=time_filter,
)
# Filter by date (approximate)
start_date = date
end_date = date
posts = client.filter_posts_by_date(posts, start_date, end_date)
# Limit results
return posts[:max_limit]
except Exception as e:
logger.error(f"Error fetching Reddit data: {e}")
return []