"""Reddit social media data provider.
|
|
|
|
Supports two modes:
|
|
- Live API (default): uses PRAW with REDDIT_CLIENT_ID / REDDIT_CLIENT_SECRET env vars
|
|
- Cache mode: reads pre-downloaded JSONL files from a local reddit_data/ directory
|
|
"""
|
|
|
|
import json
|
|
import os
|
|
import re
|
|
from datetime import datetime, timezone
|
|
|
|
# Maps a ticker symbol to the company name(s) used when searching for posts.
# Values containing " OR " list alternative names: the live-API mode passes
# the string through to Reddit's search query as-is, while the cache mode
# splits on " OR " and matches any of the alternatives.
TICKER_TO_COMPANY = {
    "AAPL": "Apple",
    "MSFT": "Microsoft",
    "GOOGL": "Google",
    "AMZN": "Amazon",
    "TSLA": "Tesla",
    "NVDA": "Nvidia",
    "TSM": "Taiwan Semiconductor Manufacturing Company OR TSMC",
    "JPM": "JPMorgan Chase OR JP Morgan",
    "JNJ": "Johnson & Johnson OR JNJ",
    "V": "Visa",
    "WMT": "Walmart",
    "META": "Meta OR Facebook",
    "AMD": "AMD",
    "INTC": "Intel",
    "QCOM": "Qualcomm",
    "BABA": "Alibaba",
    "ADBE": "Adobe",
    "NFLX": "Netflix",
    "CRM": "Salesforce",
    "PYPL": "PayPal",
    "PLTR": "Palantir",
    "MU": "Micron",
    "SQ": "Block OR Square",
    "ZM": "Zoom",
    "CSCO": "Cisco",
    "SHOP": "Shopify",
    "ORCL": "Oracle",
    "X": "Twitter OR X",
    "SPOT": "Spotify",
    "AVGO": "Broadcom",
    "ASML": "ASML",
    "TWLO": "Twilio",
    "SNAP": "Snap Inc.",
    "TEAM": "Atlassian",
    "SQSP": "Squarespace",
    "UBER": "Uber",
    "ROKU": "Roku",
    "PINS": "Pinterest",
}
|
|
|
|
# Subreddits searched by the live-API mode (get_reddit_posts). The cache mode
# instead reads whatever per-subreddit JSONL files exist under data_path.
SUBREDDITS = [
    "wallstreetbets",
    "stocks",
    "investing",
    "SecurityAnalysis",
    "options",
    "StockMarket",
]
|
|
|
|
|
|
def _format_posts(ticker: str, start_date: str, end_date: str, posts: list) -> str:
|
|
if not posts:
|
|
return f"No Reddit posts found for {ticker} between {start_date} and {end_date}."
|
|
|
|
posts.sort(key=lambda x: x["upvotes"], reverse=True)
|
|
lines = [f"## Reddit Posts for {ticker} ({start_date} to {end_date})\n"]
|
|
for post in posts[:20]:
|
|
lines.append(f"**[r/{post['subreddit']}] {post['title']}** (↑{post['upvotes']})")
|
|
if post.get("content"):
|
|
snippet = post["content"][:300]
|
|
if len(post["content"]) > 300:
|
|
snippet += "..."
|
|
lines.append(snippet)
|
|
lines.append(f"Date: {post['date']} | URL: {post['url']}\n")
|
|
return "\n".join(lines)
|
|
|
|
|
|
def get_reddit_posts(ticker: str, start_date: str, end_date: str) -> str:
    """Fetch Reddit posts about a ticker via PRAW (live Reddit API).

    Searches each subreddit in SUBREDDITS for the ticker (and its known
    company name, when TICKER_TO_COMPANY has one) and returns a markdown
    report of posts created within [start_date, end_date], both inclusive
    (YYYY-MM-DD). Subreddits that error (private, banned, network issues)
    are skipped and listed as warnings instead of aborting the whole scan.

    Requires REDDIT_CLIENT_ID and REDDIT_CLIENT_SECRET environment variables.
    Register an app at https://www.reddit.com/prefs/apps (script type).
    """
    client_id = os.environ.get("REDDIT_CLIENT_ID")
    client_secret = os.environ.get("REDDIT_CLIENT_SECRET")

    # Check credentials before importing praw so a missing optional
    # dependency cannot mask the more actionable configuration message.
    if not client_id or not client_secret:
        return (
            "Reddit API credentials not configured. "
            "Set REDDIT_CLIENT_ID and REDDIT_CLIENT_SECRET environment variables. "
            "Register a script app at https://www.reddit.com/prefs/apps to get credentials."
        )

    import praw

    reddit = praw.Reddit(
        client_id=client_id,
        client_secret=client_secret,
        user_agent="TradingAgents:social-media-analyst:v1.0",
        read_only=True,
    )

    # Naive datetimes; end_dt is pushed to end-of-day so end_date is inclusive.
    start_dt = datetime.strptime(start_date, "%Y-%m-%d")
    end_dt = datetime.strptime(end_date, "%Y-%m-%d").replace(hour=23, minute=59, second=59)
    company = TICKER_TO_COMPANY.get(ticker.upper(), ticker)
    query = f"{ticker} OR {company}" if company != ticker else ticker

    posts = []
    skipped = []
    for subreddit_name in SUBREDDITS:
        try:
            subreddit = reddit.subreddit(subreddit_name)
            for submission in subreddit.search(query, sort="new", limit=30):
                # created_utc is a POSIX timestamp; normalize to naive UTC so
                # it compares against the naive bounds parsed above.
                post_dt = datetime.fromtimestamp(submission.created_utc, tz=timezone.utc).replace(tzinfo=None)
                if not (start_dt <= post_dt <= end_dt):
                    continue
                posts.append({
                    "subreddit": subreddit_name,
                    "title": submission.title,
                    "content": submission.selftext,
                    "upvotes": submission.score,
                    "url": f"https://reddit.com{submission.permalink}",
                    "date": post_dt.strftime("%Y-%m-%d"),
                })
        except Exception as exc:  # network boundary: one bad subreddit must not abort the rest
            skipped.append(f"r/{subreddit_name}: {exc}")

    report = _format_posts(ticker, start_date, end_date, posts)
    if skipped:
        report += "\n\nWarning - subreddits skipped due to errors:\n" + "\n".join(skipped)
    return report
|
|
|
|
|
|
def get_reddit_posts_from_cache(
    ticker: str,
    start_date: str,
    end_date: str,
    data_path: str = "reddit_data",
) -> str:
    """Fetch Reddit posts from pre-downloaded local JSONL files.

    Expects files at: {data_path}/{category}/{subreddit}.jsonl
    Each JSONL line must have: created_utc, title, selftext, url, ups.

    Posts in categories whose name contains "company" are additionally
    filtered to lines mentioning the ticker or its known company name(s);
    other categories are included on date range alone.
    """
    if not os.path.isdir(data_path):
        return (
            f"Reddit cache directory '{data_path}' not found. "
            "Download Reddit data first or use the live API mode."
        )

    start_dt = datetime.strptime(start_date, "%Y-%m-%d")
    # End-of-day so end_date itself falls inside the inclusive range.
    end_dt = datetime.strptime(end_date, "%Y-%m-%d").replace(hour=23, minute=59, second=59)
    company_name = TICKER_TO_COMPANY.get(ticker.upper(), ticker)

    search_terms = [ticker]
    if " OR " in company_name:
        # "Name1 OR Name2" entries encode alternative company names.
        search_terms.extend(company_name.split(" OR "))
    elif company_name != ticker:
        search_terms.append(company_name)

    # re.escape: terms such as "Snap Inc." contain regex metacharacters and
    # must match literally. Compile once here rather than per JSONL line.
    patterns = [re.compile(re.escape(term), re.IGNORECASE) for term in search_terms]

    posts = []
    for category in os.listdir(data_path):
        category_path = os.path.join(data_path, category)
        if not os.path.isdir(category_path):
            continue
        for data_file in os.listdir(category_path):
            if not data_file.endswith(".jsonl"):
                continue
            # Strip only the trailing extension (str.replace would also hit
            # ".jsonl" occurring mid-name).
            subreddit_name = data_file[: -len(".jsonl")]
            with open(os.path.join(category_path, data_file), "rb") as f:
                for line in f:
                    if not line.strip():
                        continue  # tolerate blank lines in the cache files
                    parsed = json.loads(line)
                    post_dt = datetime.fromtimestamp(parsed["created_utc"], tz=timezone.utc).replace(tzinfo=None)
                    if not (start_dt <= post_dt <= end_dt):
                        continue
                    # Filter by company/ticker if it's a company category
                    if "company" in category:
                        haystack = parsed["title"] + "\n" + parsed.get("selftext", "")
                        if not any(p.search(haystack) for p in patterns):
                            continue
                    posts.append({
                        "subreddit": subreddit_name,
                        "title": parsed["title"],
                        "content": parsed.get("selftext", ""),
                        "upvotes": parsed["ups"],
                        "url": parsed.get("url", ""),
                        "date": post_dt.strftime("%Y-%m-%d"),
                    })

    return _format_posts(ticker, start_date, end_date, posts)
|