diff --git a/pyproject.toml b/pyproject.toml index d222ca4d..32916ee0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "typer>=0.21.0", "setuptools>=80.9.0", "finnhub-python>=2.4.20", + "praw>=7.8.1", "python-dateutil>=2.9.0", "simfin>=1.0.3", "stockstats>=0.6.5", diff --git a/tradingagents/agents/analysts/social_media_analyst.py b/tradingagents/agents/analysts/social_media_analyst.py index 34a53c46..b54d49b6 100644 --- a/tradingagents/agents/analysts/social_media_analyst.py +++ b/tradingagents/agents/analysts/social_media_analyst.py @@ -1,5 +1,6 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction, get_news +from tradingagents.agents.utils.social_media_tools import get_reddit_posts from tradingagents.dataflows.config import get_config @@ -10,10 +11,11 @@ def create_social_media_analyst(llm): tools = [ get_news, + get_reddit_posts, ] system_message = ( - "You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Use the get_news(query, start_date, end_date) tool to search for company-specific news and social media discussions. Try to look at all sources possible from social media to sentiment to news. Provide specific, actionable insights with supporting evidence to help traders make informed decisions." 
+ "You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Use the get_news(ticker, start_date, end_date) tool to search for company-specific news, and use get_reddit_posts(ticker, start_date, end_date) to gather social media discussions and sentiment from Reddit communities like r/wallstreetbets, r/stocks, and r/investing. Try to look at all sources possible from social media to sentiment to news. Provide specific, actionable insights with supporting evidence to help traders make informed decisions." + """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.""" + get_language_instruction() ) diff --git a/tradingagents/agents/utils/social_media_tools.py b/tradingagents/agents/utils/social_media_tools.py new file mode 100644 index 00000000..0295f1e5 --- /dev/null +++ b/tradingagents/agents/utils/social_media_tools.py @@ -0,0 +1,23 @@ +from langchain_core.tools import tool +from typing import Annotated +from tradingagents.dataflows.interface import route_to_vendor + + +@tool +def get_reddit_posts( + ticker: Annotated[str, "Ticker symbol (e.g. AAPL, TSLA)"], + start_date: Annotated[str, "Start date in yyyy-mm-dd format"], + end_date: Annotated[str, "End date in yyyy-mm-dd format"], +) -> str: + """ + Retrieve Reddit posts discussing a given stock ticker from social media communities + such as r/wallstreetbets, r/stocks, r/investing, and more. 
+ Uses the configured social_media_data vendor. + Args: + ticker (str): Ticker symbol (e.g. AAPL, TSLA) + start_date (str): Start date in yyyy-mm-dd format + end_date (str): End date in yyyy-mm-dd format + Returns: + str: Formatted Reddit posts with titles, content snippets, upvotes, and dates + """ + return route_to_vendor("get_social_media_posts", ticker, start_date, end_date) diff --git a/tradingagents/dataflows/interface.py b/tradingagents/dataflows/interface.py index e9a36c78..ba22d328 100644 --- a/tradingagents/dataflows/interface.py +++ b/tradingagents/dataflows/interface.py @@ -33,6 +33,7 @@ from .alpha_vantage import ( get_global_news as get_alpha_vantage_global_news, ) from .alpha_vantage_common import AlphaVantageRateLimitError +from .reddit import get_reddit_posts, get_reddit_posts_from_cache # Configuration and routing logic from .config import get_config @@ -67,7 +68,13 @@ TOOLS_CATEGORIES = { "get_global_news", "get_insider_transactions", ] - } + }, + "social_media_data": { + "description": "Social media posts and sentiment", + "tools": [ + "get_social_media_posts", + ] + }, } VENDOR_LIST = [ @@ -75,6 +82,7 @@ VENDOR_LIST = [ "alpha_vantage", "finnhub", "simfin", + "reddit", ] # Mapping of methods to their vendor-specific implementations @@ -125,6 +133,11 @@ VENDOR_METHODS = { "yfinance": get_yfinance_insider_transactions, "finnhub": get_finnhub_insider_transactions, }, + # social_media_data + "get_social_media_posts": { + "reddit": get_reddit_posts, + "reddit_cache": get_reddit_posts_from_cache, + }, } def get_category_for_method(method: str) -> str: @@ -134,7 +147,7 @@ def get_category_for_method(method: str) -> str: return category raise ValueError(f"Method '{method}' not found in any category") -def get_vendor(category: str, method: str = None) -> str: +def get_vendor(category: str, method: str | None = None) -> str: """Get the configured vendor for a data category or specific tool method. 
Tool-level configuration takes precedence over category-level. """ diff --git a/tradingagents/dataflows/reddit.py b/tradingagents/dataflows/reddit.py new file mode 100644 index 00000000..7b515f2e --- /dev/null +++ b/tradingagents/dataflows/reddit.py @@ -0,0 +1,192 @@ +"""Reddit social media data provider. + +Supports two modes: +- Live API (default): uses PRAW with REDDIT_CLIENT_ID / REDDIT_CLIENT_SECRET env vars +- Cache mode: reads pre-downloaded JSONL files from a local reddit_data/ directory +""" + +import json +import os +import re +from datetime import datetime, timezone + +TICKER_TO_COMPANY = { + "AAPL": "Apple", + "MSFT": "Microsoft", + "GOOGL": "Google", + "AMZN": "Amazon", + "TSLA": "Tesla", + "NVDA": "Nvidia", + "TSM": "Taiwan Semiconductor Manufacturing Company OR TSMC", + "JPM": "JPMorgan Chase OR JP Morgan", + "JNJ": "Johnson & Johnson OR JNJ", + "V": "Visa", + "WMT": "Walmart", + "META": "Meta OR Facebook", + "AMD": "AMD", + "INTC": "Intel", + "QCOM": "Qualcomm", + "BABA": "Alibaba", + "ADBE": "Adobe", + "NFLX": "Netflix", + "CRM": "Salesforce", + "PYPL": "PayPal", + "PLTR": "Palantir", + "MU": "Micron", + "SQ": "Block OR Square", + "ZM": "Zoom", + "CSCO": "Cisco", + "SHOP": "Shopify", + "ORCL": "Oracle", + "X": "Twitter OR X", + "SPOT": "Spotify", + "AVGO": "Broadcom", + "ASML": "ASML", + "TWLO": "Twilio", + "SNAP": "Snap Inc.", + "TEAM": "Atlassian", + "SQSP": "Squarespace", + "UBER": "Uber", + "ROKU": "Roku", + "PINS": "Pinterest", +} + +SUBREDDITS = [ + "wallstreetbets", + "stocks", + "investing", + "SecurityAnalysis", + "options", + "StockMarket", +] + + +def _format_posts(ticker: str, start_date: str, end_date: str, posts: list) -> str: + if not posts: + return f"No Reddit posts found for {ticker} between {start_date} and {end_date}." 
+ + posts.sort(key=lambda x: x["upvotes"], reverse=True) + lines = [f"## Reddit Posts for {ticker} ({start_date} to {end_date})\n"] + for post in posts[:20]: + lines.append(f"**[r/{post['subreddit']}] {post['title']}** (↑{post['upvotes']})") + if post.get("content"): + snippet = post["content"][:300] + if len(post["content"]) > 300: + snippet += "..." + lines.append(snippet) + lines.append(f"Date: {post['date']} | URL: {post['url']}\n") + return "\n".join(lines) + + +def get_reddit_posts(ticker: str, start_date: str, end_date: str) -> str: + """Fetch Reddit posts about a ticker via PRAW (live Reddit API). + + Requires REDDIT_CLIENT_ID and REDDIT_CLIENT_SECRET environment variables. + Register an app at https://www.reddit.com/prefs/apps (script type). + """ + import praw + + client_id = os.environ.get("REDDIT_CLIENT_ID") + client_secret = os.environ.get("REDDIT_CLIENT_SECRET") + + if not client_id or not client_secret: + return ( + "Reddit API credentials not configured. " + "Set REDDIT_CLIENT_ID and REDDIT_CLIENT_SECRET environment variables. " + "Register a script app at https://www.reddit.com/prefs/apps to get credentials." 
+ ) + + reddit = praw.Reddit( + client_id=client_id, + client_secret=client_secret, + user_agent="TradingAgents:social-media-analyst:v1.0", + read_only=True, + ) + + start_dt = datetime.strptime(start_date, "%Y-%m-%d") + end_dt = datetime.strptime(end_date, "%Y-%m-%d").replace(hour=23, minute=59, second=59) + company = TICKER_TO_COMPANY.get(ticker.upper(), ticker) + query = f"{ticker} OR {company}" if company != ticker else ticker + + posts = [] + for subreddit_name in SUBREDDITS: + subreddit = reddit.subreddit(subreddit_name) + for submission in subreddit.search(query, sort="new", limit=30): + post_dt = datetime.fromtimestamp(submission.created_utc, tz=timezone.utc).replace(tzinfo=None) + if not (start_dt <= post_dt <= end_dt): + continue + posts.append({ + "subreddit": subreddit_name, + "title": submission.title, + "content": submission.selftext, + "upvotes": submission.score, + "url": f"https://reddit.com{submission.permalink}", + "date": post_dt.strftime("%Y-%m-%d"), + }) + + return _format_posts(ticker, start_date, end_date, posts) + + +def get_reddit_posts_from_cache( + ticker: str, + start_date: str, + end_date: str, + data_path: str = "reddit_data", +) -> str: + """Fetch Reddit posts from pre-downloaded local JSONL files. + + Expects files at: {data_path}/{category}/{subreddit}.jsonl + Each JSONL line must have: created_utc, title, selftext, url, ups. + """ + if not os.path.isdir(data_path): + return ( + f"Reddit cache directory '{data_path}' not found. " + "Download Reddit data first or use the live API mode." 
+ ) + + start_dt = datetime.strptime(start_date, "%Y-%m-%d") + end_dt = datetime.strptime(end_date, "%Y-%m-%d").replace(hour=23, minute=59, second=59) + company_name = TICKER_TO_COMPANY.get(ticker.upper(), ticker) + + search_terms = [ticker] + if " OR " in company_name: + search_terms.extend(company_name.split(" OR ")) + elif company_name != ticker: + search_terms.append(company_name) + + posts = [] + for category in os.listdir(data_path): + category_path = os.path.join(data_path, category) + if not os.path.isdir(category_path): + continue + for data_file in os.listdir(category_path): + if not data_file.endswith(".jsonl"): + continue + subreddit_name = data_file.replace(".jsonl", "") + with open(os.path.join(category_path, data_file), "rb") as f: + for line in f: + if not line.strip(): + continue + parsed = json.loads(line) + post_dt = datetime.fromtimestamp(parsed["created_utc"], tz=timezone.utc).replace(tzinfo=None) + if not (start_dt <= post_dt <= end_dt): + continue + # Filter by company/ticker if it's a company category + if "company" in category: + found = any( + re.search(re.escape(term), parsed["title"], re.IGNORECASE) + or re.search(re.escape(term), parsed.get("selftext", ""), re.IGNORECASE) + for term in search_terms + ) + if not found: + continue + posts.append({ + "subreddit": subreddit_name, + "title": parsed["title"], + "content": parsed.get("selftext", ""), + "upvotes": parsed["ups"], + "url": parsed.get("url", ""), + "date": post_dt.strftime("%Y-%m-%d"), + }) + + return _format_posts(ticker, start_date, end_date, posts) diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index 902befc5..6f9e1b03 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -30,6 +30,7 @@ DEFAULT_CONFIG = { "technical_indicators": "yfinance", # Options: alpha_vantage, yfinance "fundamental_data": "yfinance", # Options: alpha_vantage, yfinance, simfin "news_data": "yfinance", # Options: alpha_vantage, yfinance, finnhub + 
"social_media_data": "reddit", # Options: reddit, reddit_cache }, # Tool-level configuration (takes precedence over category-level) "tool_vendors": {