# TradingAgents/tradingagents/dataflows/reddit.py
#
# 193 lines
# 6.6 KiB
# Python

"""Reddit social media data provider.
Supports two modes:
- Live API (default): uses PRAW with REDDIT_CLIENT_ID / REDDIT_CLIENT_SECRET env vars
- Cache mode: reads pre-downloaded JSONL files from a local reddit_data/ directory
"""
import json
import os
import re
from datetime import datetime, timezone
# Ticker symbol -> company search phrase.
# A value may list several alternatives separated by " OR ": the live Reddit
# search embeds the value in its query string as-is, while the cache reader
# splits it on " OR " into individual search terms.
TICKER_TO_COMPANY = {
    "AAPL": "Apple",
    "MSFT": "Microsoft",
    "GOOGL": "Google",
    "AMZN": "Amazon",
    "TSLA": "Tesla",
    "NVDA": "Nvidia",
    "TSM": "Taiwan Semiconductor Manufacturing Company OR TSMC",
    "JPM": "JPMorgan Chase OR JP Morgan",
    "JNJ": "Johnson & Johnson OR JNJ",
    "V": "Visa",
    "WMT": "Walmart",
    "META": "Meta OR Facebook",
    "AMD": "AMD",
    "INTC": "Intel",
    "QCOM": "Qualcomm",
    "BABA": "Alibaba",
    "ADBE": "Adobe",
    "NFLX": "Netflix",
    "CRM": "Salesforce",
    "PYPL": "PayPal",
    "PLTR": "Palantir",
    "MU": "Micron",
    "SQ": "Block OR Square",
    "ZM": "Zoom",
    "CSCO": "Cisco",
    "SHOP": "Shopify",
    "ORCL": "Oracle",
    "X": "Twitter OR X",
    "SPOT": "Spotify",
    "AVGO": "Broadcom",
    "ASML": "ASML",
    "TWLO": "Twilio",
    "SNAP": "Snap Inc.",
    "TEAM": "Atlassian",
    "SQSP": "Squarespace",
    "UBER": "Uber",
    "ROKU": "Roku",
    "PINS": "Pinterest",
}
# Subreddits queried, in this order, by the live Reddit search.
SUBREDDITS = [
    "wallstreetbets",
    "stocks",
    "investing",
    "SecurityAnalysis",
    "options",
    "StockMarket",
]
def _format_posts(
    ticker: str,
    start_date: str,
    end_date: str,
    posts: list,
    *,
    max_posts: int = 20,
    snippet_length: int = 300,
) -> str:
    """Render Reddit posts as a Markdown report, most upvoted first.

    Args:
        ticker: Ticker symbol shown in the report header.
        start_date: Start of the date range, shown verbatim in the header.
        end_date: End of the date range, shown verbatim in the header.
        posts: Post dicts with the keys ``subreddit``, ``title``, ``upvotes``,
            ``url`` and ``date``, plus an optional ``content`` body.
            The list is not modified.
        max_posts: Maximum number of posts included in the report.
        snippet_length: Maximum number of body characters shown per post;
            longer bodies are cut and suffixed with ``...``.

    Returns:
        The Markdown report, or a one-line "no posts found" message when
        ``posts`` is empty.
    """
    if not posts:
        return f"No Reddit posts found for {ticker} between {start_date} and {end_date}."

    # sorted() builds a new list, so the caller's list keeps its original
    # order (an in-place sort would silently reorder the caller's data).
    ranked = sorted(posts, key=lambda post: post["upvotes"], reverse=True)

    lines = [f"## Reddit Posts for {ticker} ({start_date} to {end_date})\n"]
    for post in ranked[:max_posts]:
        lines.append(f"**[r/{post['subreddit']}] {post['title']}** (↑{post['upvotes']})")
        content = post.get("content")
        if content:
            snippet = content[:snippet_length]
            if len(content) > snippet_length:
                snippet += "..."
            lines.append(snippet)
        # The trailing newline leaves a blank line between consecutive posts.
        lines.append(f"Date: {post['date']} | URL: {post['url']}\n")
    return "\n".join(lines)
def get_reddit_posts(ticker: str, start_date: str, end_date: str) -> str:
    """Fetch Reddit posts about a ticker via PRAW (live Reddit API).

    Requires REDDIT_CLIENT_ID and REDDIT_CLIENT_SECRET environment variables.
    Register an app at https://www.reddit.com/prefs/apps (script type).

    Each subreddit in ``SUBREDDITS`` is searched for the ticker (and the
    company phrase from ``TICKER_TO_COMPANY`` when there is one), newest
    first, up to 30 results per subreddit. Only submissions whose UTC
    creation date lies within ``start_date``..``end_date`` (inclusive) are
    kept.

    Args:
        ticker: Ticker symbol to search for.
        start_date: First day of the range, ``YYYY-MM-DD``.
        end_date: Last day of the range, ``YYYY-MM-DD``.

    Returns:
        A Markdown report of the matching posts, or a plain-text message when
        the credentials are not configured or nothing was found.

    Raises:
        ImportError: If credentials are configured but ``praw`` is not
            installed.
        ValueError: If a date is not in ``YYYY-MM-DD`` format.
    """
    client_id = os.environ.get("REDDIT_CLIENT_ID")
    client_secret = os.environ.get("REDDIT_CLIENT_SECRET")
    if not client_id or not client_secret:
        return (
            "Reddit API credentials not configured. "
            "Set REDDIT_CLIENT_ID and REDDIT_CLIENT_SECRET environment variables. "
            "Register a script app at https://www.reddit.com/prefs/apps to get credentials."
        )

    # Imported lazily, and only after the credential check, so that a missing
    # configuration is reported with the message above even on machines where
    # praw is not installed.
    import praw

    first_day = datetime.strptime(start_date, "%Y-%m-%d").date()
    last_day = datetime.strptime(end_date, "%Y-%m-%d").date()

    reddit = praw.Reddit(
        client_id=client_id,
        client_secret=client_secret,
        user_agent="TradingAgents:social-media-analyst:v1.0",
        read_only=True,
    )

    symbol = ticker.upper()
    company = TICKER_TO_COMPANY.get(symbol, ticker)
    # Compare case-insensitively: the lookup above already ignores case, so a
    # lowercase ticker must not yield a redundant query such as "amd OR AMD".
    query = ticker if company.upper() == symbol else f"{ticker} OR {company}"

    posts = []
    for subreddit_name in SUBREDDITS:
        subreddit = reddit.subreddit(subreddit_name)
        for submission in subreddit.search(query, sort="new", limit=30):
            # created_utc is an epoch timestamp; the range check is done on
            # the UTC calendar date so the whole end day is included.
            post_dt = datetime.fromtimestamp(submission.created_utc, tz=timezone.utc)
            if not (first_day <= post_dt.date() <= last_day):
                continue
            posts.append({
                "subreddit": subreddit_name,
                "title": submission.title,
                "content": submission.selftext,
                "upvotes": submission.score,
                "url": f"https://reddit.com{submission.permalink}",
                "date": post_dt.strftime("%Y-%m-%d"),
            })
    return _format_posts(ticker, start_date, end_date, posts)
def get_reddit_posts_from_cache(
    ticker: str,
    start_date: str,
    end_date: str,
    data_path: str = "reddit_data",
) -> str:
    """Fetch Reddit posts from pre-downloaded local JSONL files.

    Expects files at: {data_path}/{category}/{subreddit}.jsonl
    Each JSONL line must have: created_utc, title, selftext, url, ups.

    Posts whose UTC creation date lies within ``start_date``..``end_date``
    (inclusive) are collected. In categories whose directory name contains
    "company", a post is kept only if its title or body mentions the ticker
    or one of the company terms from ``TICKER_TO_COMPANY``. Terms are matched
    literally, case-insensitively and as whole terms, so a one-letter ticker
    such as "V" does not match every word containing that letter.

    Args:
        ticker: Ticker symbol to look for.
        start_date: First day of the range, ``YYYY-MM-DD``.
        end_date: Last day of the range, ``YYYY-MM-DD``.
        data_path: Root directory of the cached Reddit data.

    Returns:
        A Markdown report of the matching posts, or a plain-text message when
        the cache directory is missing or nothing was found.

    Raises:
        ValueError: If a date is not in ``YYYY-MM-DD`` format, or a non-blank
            line is not valid JSON.
        KeyError: If a record lacks ``created_utc``, ``title`` or ``ups``.
    """
    if not os.path.isdir(data_path):
        return (
            f"Reddit cache directory '{data_path}' not found. "
            "Download Reddit data first or use the live API mode."
        )

    first_day = datetime.strptime(start_date, "%Y-%m-%d").date()
    last_day = datetime.strptime(end_date, "%Y-%m-%d").date()

    # The ticker plus every " OR "-separated company alternative;
    # dict.fromkeys drops duplicates (e.g. "JNJ") while keeping order.
    company_name = TICKER_TO_COMPANY.get(ticker.upper(), ticker)
    candidates = [ticker, *company_name.split(" OR ")]
    terms = list(dict.fromkeys(t.strip() for t in candidates if t.strip()))

    # One pattern compiled up front instead of a regex search per term per
    # line. re.escape keeps characters like "." or "&" literal; the
    # lookarounds demand a non-word character (or string edge) on both sides
    # and, unlike \b, also work for terms that end in punctuation.
    term_pattern = None
    if terms:
        alternatives = "|".join(re.escape(term) for term in terms)
        term_pattern = re.compile(rf"(?<!\w)(?:{alternatives})(?!\w)", re.IGNORECASE)

    posts = []
    # Directory listings are sorted so the result does not depend on the
    # file system's arbitrary listing order.
    for category in sorted(os.listdir(data_path)):
        category_path = os.path.join(data_path, category)
        if not os.path.isdir(category_path):
            continue
        # Only "company" categories are filtered by ticker/company terms.
        filter_by_term = term_pattern is not None and "company" in category
        for data_file in sorted(os.listdir(category_path)):
            if not data_file.endswith(".jsonl"):
                continue
            # Strip only the trailing extension, not every ".jsonl" occurrence.
            subreddit_name = data_file[: -len(".jsonl")]
            with open(os.path.join(category_path, data_file), "rb") as f:
                for line in f:
                    if not line.strip():
                        continue
                    parsed = json.loads(line)
                    # float() also accepts epoch values stored as strings.
                    post_dt = datetime.fromtimestamp(
                        float(parsed["created_utc"]), tz=timezone.utc
                    )
                    if not (first_day <= post_dt.date() <= last_day):
                        continue
                    title = parsed["title"]
                    # A JSON null selftext becomes "" so it can be searched.
                    body = parsed.get("selftext") or ""
                    if filter_by_term and not (
                        term_pattern.search(title or "") or term_pattern.search(body)
                    ):
                        continue
                    posts.append({
                        "subreddit": subreddit_name,
                        "title": title,
                        "content": body,
                        "upvotes": parsed["ups"],
                        "url": parsed.get("url", ""),
                        "date": post_dt.strftime("%Y-%m-%d"),
                    })
    return _format_posts(ticker, start_date, end_date, posts)