"""Reddit API integration for social sentiment analysis."""
|
|
|
|
import logging
import re
from datetime import datetime, timedelta, timezone
from typing import Any

import praw

# from .api_clients import RateLimiter
|
|
|
|
logger = logging.getLogger(__name__)

# Maps stock ticker symbols to search-friendly company names.  Values that
# contain " OR " list alternative names/aliases to try when building search
# queries (e.g. "Meta OR Facebook").
ticker_to_company = {
    "AAPL": "Apple",
    "MSFT": "Microsoft",
    "GOOGL": "Google",
    "AMZN": "Amazon",
    "TSLA": "Tesla",
    "NVDA": "Nvidia",
    "TSM": "Taiwan Semiconductor Manufacturing Company OR TSMC",
    "JPM": "JPMorgan Chase OR JP Morgan",
    "JNJ": "Johnson & Johnson OR JNJ",
    "V": "Visa",
    "WMT": "Walmart",
    "META": "Meta OR Facebook",
    "AMD": "AMD",
    "INTC": "Intel",
    "QCOM": "Qualcomm",
    "BABA": "Alibaba",
    "ADBE": "Adobe",
    "NFLX": "Netflix",
    "CRM": "Salesforce",
    "PYPL": "PayPal",
    "PLTR": "Palantir",
    "MU": "Micron",
    "SQ": "Block OR Square",
    "ZM": "Zoom",
    "CSCO": "Cisco",
    "SHOP": "Shopify",
    "ORCL": "Oracle",
    "X": "Twitter OR X",
    "SPOT": "Spotify",
    "AVGO": "Broadcom",
    "ASML": "ASML",
    "TWLO": "Twilio",
    "SNAP": "Snap Inc.",
    "TEAM": "Atlassian",
    "SQSP": "Squarespace",
    "UBER": "Uber",
    "ROKU": "Roku",
    "PINS": "Pinterest",
    # Additional popular tickers (ETFs, meme stocks, fintech).
    "SPY": "SPDR S&P 500 ETF",
    "QQQ": "Invesco QQQ Trust",
    "GME": "GameStop",
    "AMC": "AMC Entertainment",
    "BB": "BlackBerry",
    "NOK": "Nokia",
    "COIN": "Coinbase",
    "HOOD": "Robinhood",
    "RBLX": "Roblox",
    "DKNG": "DraftKings",
    "PENN": "Penn Entertainment",
    "SOFI": "SoFi Technologies",
}
|
|
|
|
|
|
class RedditClient:
    """Client for the Reddit API with rate limiting and caching.

    Wraps a :mod:`praw` instance and provides helpers for searching,
    fetching, and filtering posts across curated subreddit lists.
    """

    def __init__(self, client_id: str, client_secret: str, user_agent: str):
        """
        Initialize Reddit client.

        Args:
            client_id: Reddit application client ID
            client_secret: Reddit application client secret
            user_agent: User agent string for API requests
        """
        self.client_id = client_id
        self.client_secret = client_secret
        self.user_agent = user_agent

        # Reddit allows 100 requests per minute for script applications
        # self.rate_limiter = RateLimiter(100, 60)

        # Initialize Reddit instance
        self.reddit = praw.Reddit(
            client_id=client_id, client_secret=client_secret, user_agent=user_agent
        )

        # Default subreddits for different categories
        self.subreddits = {
            "investing": ["investing", "SecurityAnalysis", "ValueInvesting", "stocks"],
            "trading": ["wallstreetbets", "StockMarket", "pennystocks", "options"],
            "global_news": ["news", "worldnews", "Economics", "business"],
            "company_news": ["investing", "stocks", "StockMarket", "SecurityAnalysis"],
        }

    def test_connection(self) -> bool:
        """Test if the Reddit API connection is working.

        NOTE(review): ``user.me()`` assumes authenticated (non-read-only)
        credentials; for read-only script apps this may not exercise the
        network -- confirm against the deployment's credential type.
        """
        try:
            # Test by fetching user info (read-only operation)
            _ = self.reddit.user.me()
            return True
        except Exception as e:
            logger.error(f"Reddit connection test failed: {e}")
            return False

    @staticmethod
    def _submission_to_dict(submission: Any, subreddit_name: str) -> dict[str, Any]:
        """Flatten a PRAW submission into a plain post dictionary.

        Shared by :meth:`search_posts` and :meth:`get_top_posts` so the two
        methods cannot drift apart.
        """
        return {
            "title": submission.title,
            "content": submission.selftext,
            "url": submission.url,
            "upvotes": submission.ups,
            "score": submission.score,
            "num_comments": submission.num_comments,
            "created_utc": submission.created_utc,
            "subreddit": subreddit_name,
            "author": str(submission.author) if submission.author else "[deleted]",
        }

    def search_posts(
        self,
        query: str,
        subreddit_names: list[str],
        limit: int = 25,
        time_filter: str = "week",
    ) -> list[dict[str, Any]]:
        """
        Search for posts across multiple subreddits.

        Args:
            query: Search query
            subreddit_names: List of subreddit names to search
            limit: Maximum number of posts per subreddit
            time_filter: Time filter ('day', 'week', 'month', 'year', 'all')

        Returns:
            List of post dictionaries, sorted by score descending
        """
        posts: list[dict[str, Any]] = []

        for subreddit_name in subreddit_names:
            try:
                # self.rate_limiter.wait_if_needed()

                subreddit = self.reddit.subreddit(subreddit_name)

                # Search posts in the subreddit
                search_results = subreddit.search(
                    query=query, sort="relevance", time_filter=time_filter, limit=limit
                )
                posts.extend(
                    self._submission_to_dict(s, subreddit_name) for s in search_results
                )

            except Exception as e:
                # Best-effort: skip subreddits that fail (banned, private, ...)
                logger.error(f"Error searching subreddit {subreddit_name}: {e}")
                continue

        # Sort by score (upvotes - downvotes) descending
        posts.sort(key=lambda x: x["score"], reverse=True)
        return posts

    def get_top_posts(
        self, subreddit_names: list[str], limit: int = 25, time_filter: str = "week"
    ) -> list[dict[str, Any]]:
        """
        Get top posts from multiple subreddits.

        Args:
            subreddit_names: List of subreddit names
            limit: Maximum number of posts per subreddit
            time_filter: Time filter ('day', 'week', 'month', 'year', 'all')

        Returns:
            List of post dictionaries, sorted by score descending
        """
        posts: list[dict[str, Any]] = []

        for subreddit_name in subreddit_names:
            try:
                # self.rate_limiter.wait_if_needed()

                subreddit = self.reddit.subreddit(subreddit_name)

                # Get top posts from the subreddit
                top_posts = subreddit.top(time_filter=time_filter, limit=limit)
                posts.extend(
                    self._submission_to_dict(s, subreddit_name) for s in top_posts
                )

            except Exception as e:
                logger.error(f"Error fetching top posts from {subreddit_name}: {e}")
                continue

        # Sort by score descending
        posts.sort(key=lambda x: x["score"], reverse=True)
        return posts

    def filter_posts_by_date(
        self, posts: list[dict[str, Any]], start_date: str, end_date: str
    ) -> list[dict[str, Any]]:
        """
        Filter posts by date range.

        ``created_utc`` is a UTC epoch timestamp, so the date window is
        interpreted in UTC as well (previously this used the host's local
        timezone, shifting the window by the UTC offset).

        Args:
            posts: List of post dictionaries
            start_date: Start date in YYYY-MM-DD format (UTC)
            end_date: End date in YYYY-MM-DD format (UTC)

        Returns:
            Filtered list of posts; each kept post gains a 'posted_date' key
        """
        start_dt = datetime.strptime(start_date, "%Y-%m-%d").replace(
            tzinfo=timezone.utc
        )
        # +1 day so the end date itself is included in the window
        end_dt = datetime.strptime(end_date, "%Y-%m-%d").replace(
            tzinfo=timezone.utc
        ) + timedelta(days=1)

        filtered_posts = []
        for post in posts:
            post_dt = datetime.fromtimestamp(post["created_utc"], tz=timezone.utc)
            if start_dt <= post_dt <= end_dt:
                # Add formatted date string
                post["posted_date"] = post_dt.strftime("%Y-%m-%d")
                filtered_posts.append(post)

        return filtered_posts

    def filter_posts_by_company(
        self, posts: list[dict[str, Any]], ticker: str
    ) -> list[dict[str, Any]]:
        """
        Filter posts that mention a specific company/ticker.

        Args:
            posts: List of post dictionaries
            ticker: Stock ticker symbol

        Returns:
            Filtered list of posts whose title or content contains a
            whole-word match for the company name (or any " OR "-separated
            alias) or the upper-cased ticker
        """
        company_names = ticker_to_company.get(ticker)
        if company_names is None:
            # If ticker not in mapping, search for ticker directly
            search_terms = [ticker.upper()]
        else:
            # " OR " separates alternative names in the mapping values
            search_terms = [name.strip() for name in company_names.split(" OR ")]
            search_terms.append(ticker.upper())

        # Pre-compile one whole-word pattern per term (matching is done on
        # lowercased text, as before)
        patterns = [
            re.compile(r"\b" + re.escape(term.lower()) + r"\b")
            for term in search_terms
        ]

        filtered_posts = []
        for post in posts:
            # "\n" is a non-word character, so joining title and content
            # cannot create a spurious cross-boundary match
            haystack = (post["title"] + "\n" + post["content"]).lower()
            if any(pattern.search(haystack) for pattern in patterns):
                filtered_posts.append(post)

        return filtered_posts
|
|
|
|
|
|
def fetch_top_from_category(
    category: str,
    date: str,
    max_limit: int,
    query: str | None = None,
    data_path: str = "reddit_data",
    client_id: str | None = None,
    client_secret: str | None = None,
    user_agent: str | None = None,
) -> list[dict[str, Any]]:
    """
    Legacy function to maintain backward compatibility.
    Now uses live Reddit API if credentials are provided.

    Args:
        category: Category ('global_news', 'company_news', etc.)
        date: Date in YYYY-MM-DD format
        max_limit: Maximum number of posts
        query: Optional search query (ticker for company_news)
        data_path: Unused in API mode
        client_id: Reddit client ID
        client_secret: Reddit client secret
        user_agent: Reddit user agent

    Returns:
        List of post dictionaries (empty on missing credentials or any error)
    """
    if not all([client_id, client_secret, user_agent]):
        logger.warning("Reddit API credentials not provided. Returning empty data.")
        return []

    try:
        # Type check ensures these are not None
        assert client_id is not None
        assert client_secret is not None
        assert user_agent is not None
        client = RedditClient(client_id, client_secret, user_agent)

        # Determine subreddits based on category; anything unrecognized
        # falls back to the investing subreddits (original behavior)
        if category in ("global_news", "company_news"):
            subreddit_names = client.subreddits[category]
        else:
            subreddit_names = client.subreddits["investing"]

        # Calculate time filter based on date (Reddit doesn't support exact
        # date filtering, so we over-fetch and filter locally below)
        post_date = datetime.strptime(date, "%Y-%m-%d")
        days_ago = (datetime.now() - post_date).days

        if days_ago <= 1:
            time_filter = "day"
        elif days_ago <= 7:
            time_filter = "week"
        elif days_ago <= 30:
            time_filter = "month"
        else:
            time_filter = "year"

        # Per-subreddit quota; clamp to at least 1 so a small max_limit
        # (< number of subreddits) doesn't silently request zero posts
        per_subreddit = max(1, max_limit // len(subreddit_names))

        # Get posts
        if query and category == "company_news":
            # Search for specific company, then keep only posts that
            # actually mention it
            posts = client.search_posts(
                query=query,
                subreddit_names=subreddit_names,
                limit=per_subreddit,
                time_filter=time_filter,
            )
            posts = client.filter_posts_by_company(posts, query)
        else:
            # Get top posts
            posts = client.get_top_posts(
                subreddit_names=subreddit_names,
                limit=per_subreddit,
                time_filter=time_filter,
            )

        # Filter by date (approximate, single-day window)
        posts = client.filter_posts_by_date(posts, date, date)

        # Limit results
        return posts[:max_limit]

    except Exception as e:
        logger.error(f"Error fetching Reddit data: {e}")
        return []
|