From 86d7a9ace9ac3fc869e9427e2c17ccd21e199091 Mon Sep 17 00:00:00 2001 From: Tomortec Date: Wed, 2 Jul 2025 17:24:07 +0800 Subject: [PATCH] feat: fetch posts from Reddit --- requirements.txt | 3 +- .../agents/analysts/social_media_analyst.py | 3 +- tradingagents/agents/utils/agent_utils.py | 44 ++--- tradingagents/dataflows/__init__.py | 6 +- tradingagents/dataflows/interface.py | 117 ++---------- tradingagents/dataflows/reddit_utils.py | 169 +++++------------- tradingagents/graph/trading_graph.py | 3 +- 7 files changed, 80 insertions(+), 265 deletions(-) diff --git a/requirements.txt b/requirements.txt index 005883ec..463c5b10 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,4 +25,5 @@ rich questionary langchain_anthropic langchain-google-genai -binance-futures-connector \ No newline at end of file +binance-futures-connector +praw \ No newline at end of file diff --git a/tradingagents/agents/analysts/social_media_analyst.py b/tradingagents/agents/analysts/social_media_analyst.py index 2fc94c3c..ad9c6ae7 100644 --- a/tradingagents/agents/analysts/social_media_analyst.py +++ b/tradingagents/agents/analysts/social_media_analyst.py @@ -13,7 +13,8 @@ def create_social_media_analyst(llm, toolkit): tools = [ toolkit.get_binance_ohlcv, toolkit.get_fear_and_greed_index, - toolkit.get_coinstats_btc_dominance + toolkit.get_coinstats_btc_dominance, + toolkit.get_reddit_posts # toolkit.get_stock_news_openai, # toolkit.get_reddit_stock_info ] diff --git a/tradingagents/agents/utils/agent_utils.py b/tradingagents/agents/utils/agent_utils.py index 184885f3..8d4de4ae 100644 --- a/tradingagents/agents/utils/agent_utils.py +++ b/tradingagents/agents/utils/agent_utils.py @@ -207,42 +207,24 @@ class Toolkit: @staticmethod @tool - def get_reddit_news( - curr_date: Annotated[str, "Date you want to get news for in yyyy-mm-dd format"], + def get_reddit_posts( + symbol: Annotated[str, "Ticker symbol of the asset, e.g. 'BTC'"], + subreddit: Annotated[str, "Subreddit to search in, e.g. 'CryptoCurrency', 'CryptoMarkets', 'all'"] = "CryptoCurrency", + sort: Annotated[str, "Sorting method for posts ('hot', 'new', 'top', etc.)"] = "hot", + limit: Annotated[int, "Maximum number of posts to fetch"] = 25, ) -> str: """ - Retrieve global news from Reddit within a specified time frame. + Fetch top posts from a specified subreddit related to a given ticker symbol. Args: - curr_date (str): Date you want to get news for in yyyy-mm-dd format + symbol (str): Ticker symbol of the asset, e.g. 'BTC' + subreddit (str): Subreddit to search in, e.g. 'CryptoCurrency', 'CryptoMarkets', 'all' + sort (str): Sorting method for posts ('hot', 'new', 'top', etc.) + limit (int): Maximum number of posts to fetch Returns: - str: A formatted dataframe containing the latest global news from Reddit in the specified time frame. + str: A formatted string containing the top posts from the specified subreddit related to the ticker symbol. """ - - global_news_result = interface.get_reddit_global_news(curr_date, 7, 5) - - return global_news_result - - @staticmethod - @tool - def get_reddit_stock_info( - ticker: Annotated[ - str, - "Ticker of a asset. e.g. AAPL, TSM", - ], - curr_date: Annotated[str, "Current date you want to get news for"], - ) -> str: - """ - Retrieve the latest news about a given stock from Reddit, given the current date. - Args: - ticker (str): Ticker of a asset. e.g. AAPL, TSM - curr_date (str): current date in yyyy-mm-dd format to get news for - Returns: - str: A formatted dataframe containing the latest news about the asset on the given date - """ - - stock_news_results = interface.get_reddit_asset_news(ticker, curr_date, 7, 5) - - return stock_news_results + reddit_posts_result = interface.get_reddit_posts(symbol, subreddit, sort, limit) + return reddit_posts_result @staticmethod @tool diff --git a/tradingagents/dataflows/__init__.py b/tradingagents/dataflows/__init__.py index df50cb3d..2184fa5a 100644 --- a/tradingagents/dataflows/__init__.py +++ b/tradingagents/dataflows/__init__.py @@ -9,8 +9,7 @@ from .interface import ( get_coinstats_news, get_google_news, get_fear_and_greed_index, - get_reddit_global_news, - get_reddit_asset_news, + get_reddit_posts, # Financial statements functions # TODO # Technical analysis functions @@ -28,8 +27,7 @@ __all__ = [ "get_coinstats_news", "get_google_news", "get_fear_and_greed_index", - "get_reddit_global_news", - "get_reddit_asset_news", + "get_reddit_posts", # Financial statements functions # TODO # Technical analysis functions diff --git a/tradingagents/dataflows/interface.py b/tradingagents/dataflows/interface.py index 9b6f170b..e49df346 100644 --- a/tradingagents/dataflows/interface.py +++ b/tradingagents/dataflows/interface.py @@ -2,7 +2,7 @@ from typing import Annotated, Dict from .blockbeats_utils import fetch_news_from_blockbeats from .coindesk_utils import fetch_news_from_coindesk from .coinstats_utils import * -from .reddit_utils import fetch_top_from_category +from .reddit_utils import fetch_posts_from_reddit from .googlenews_utils import * from .binance_utils import * from .alternativeme_utils import fetch_fear_and_greed_from_alternativeme @@ -167,114 +167,33 @@ def get_google_news( return f"## {query} Google News, from {before} to {curr_date}:\n\n{news_str}" -def get_reddit_global_news( - start_date: Annotated[str, "Start date in yyyy-mm-dd format"], - look_back_days: Annotated[int, "how many days to look back"], - max_limit_per_day: Annotated[int, "Maximum number of news per day"], +def get_reddit_posts( + symbol: Annotated[str, "ticker symbol of the asset"], + subreddit_name: Annotated[str, "name of the subreddit to fetch posts from, e.g., 'CryptoCurrency', 'CryptoMarkets', 'all'"], + sort: Annotated[str, "sorting method for posts ('hot', 'new', 'top', etc.)", "default is 'hot'"] = "hot", + limit: Annotated[int, "maximum number of posts to fetch, default is 25"] = 25, ) -> str: """ - Retrieve the latest top reddit news + Fetch top posts from a specified subreddit. + Args: - start_date: Start date in yyyy-mm-dd format - end_date: End date in yyyy-mm-dd format + symbol (str): The ticker symbol of the asset to filter posts. + subreddit_name (str): The name of the subreddit to fetch posts from. + sort (str): The sorting method for posts ('hot', 'new', 'top', etc.). + limit (int): The maximum number of posts to fetch. + Returns: - str: A formatted dataframe containing the latest news articles posts on reddit and meta information in these columns: "created_utc", "id", "title", "selftext", "score", "num_comments", "url" + str: A formatted string containing the top posts from the subreddit. """ - - start_date = datetime.strptime(start_date, "%Y-%m-%d") - before = start_date - relativedelta(days=look_back_days) - before = before.strftime("%Y-%m-%d") - - posts = [] - # iterate from start_date to end_date - curr_date = datetime.strptime(before, "%Y-%m-%d") - - total_iterations = (start_date - curr_date).days + 1 - pbar = tqdm(desc=f"Getting Global News on {start_date}", total=total_iterations) - - while curr_date <= start_date: - curr_date_str = curr_date.strftime("%Y-%m-%d") - fetch_result = fetch_top_from_category( - "global_news", - curr_date_str, - max_limit_per_day, - data_path=os.path.join(DATA_DIR, "reddit_data"), - ) - posts.extend(fetch_result) - curr_date += relativedelta(days=1) - pbar.update(1) - - pbar.close() - + posts = fetch_posts_from_reddit(symbol, subreddit_name, sort, limit) if len(posts) == 0: return "" - news_str = "" + posts_str = "" for post in posts: - if post["content"] == "": - news_str += f"### {post['title']}\n\n" - else: - news_str += f"### {post['title']}\n\n{post['content']}\n\n" + posts_str += f"### {post['title']} (score: {post['score']}, created at: {datetime.utcfromtimestamp(post['created_utc']).strftime('%Y-%m-%d %H:%M:%S')})\n{post['content']}\n\n" - return f"## Global News Reddit, from {before} to {curr_date}:\n{news_str}" - -def get_reddit_asset_news( - ticker: Annotated[str, "ticker symbol of the asset"], - start_date: Annotated[str, "Start date in yyyy-mm-dd format"], - look_back_days: Annotated[int, "how many days to look back"], - max_limit_per_day: Annotated[int, "Maximum number of news per day"], -) -> str: - """ - Retrieve the latest top reddit news - Args: - ticker: ticker symbol of the asset - start_date: Start date in yyyy-mm-dd format - end_date: End date in yyyy-mm-dd format - Returns: - str: A formatted dataframe containing the latest news articles posts on reddit and meta information in these columns: "created_utc", "id", "title", "selftext", "score", "num_comments", "url" - """ - - start_date = datetime.strptime(start_date, "%Y-%m-%d") - before = start_date - relativedelta(days=look_back_days) - before = before.strftime("%Y-%m-%d") - - posts = [] - # iterate from start_date to end_date - curr_date = datetime.strptime(before, "%Y-%m-%d") - - total_iterations = (start_date - curr_date).days + 1 - pbar = tqdm( - desc=f"Getting Asset News for {ticker} on {start_date}", - total=total_iterations, - ) - - while curr_date <= start_date: - curr_date_str = curr_date.strftime("%Y-%m-%d") - fetch_result = fetch_top_from_category( - "asset_news", - curr_date_str, - max_limit_per_day, - ticker, - data_path=os.path.join(DATA_DIR, "reddit_data"), - ) - posts.extend(fetch_result) - curr_date += relativedelta(days=1) - - pbar.update(1) - - pbar.close() - - if len(posts) == 0: - return "" - - news_str = "" - for post in posts: - if post["content"] == "": - news_str += f"### {post['title']}\n\n" - else: - news_str += f"### {post['title']}\n\n{post['content']}\n\n" - - return f"##{ticker} News Reddit, from {before} to {curr_date}:\n\n{news_str}" + return f"## Reddit Posts in r/{subreddit_name} for {symbol}:\n{posts_str}" def get_binance_ohlcv( symbol: Annotated[str, "ticker symbol of the asset"], diff --git a/tradingagents/dataflows/reddit_utils.py b/tradingagents/dataflows/reddit_utils.py index 7cb233e5..9da36da0 100644 --- a/tradingagents/dataflows/reddit_utils.py +++ b/tradingagents/dataflows/reddit_utils.py @@ -1,135 +1,50 @@ -import requests -import time -import json -from datetime import datetime, timedelta -from contextlib import contextmanager -from typing import Annotated + import os -import re +import praw ticker_to_asset = { - "AAPL": "Apple", - "MSFT": "Microsoft", - "GOOGL": "Google", - "AMZN": "Amazon", - "TSLA": "Tesla", - "NVDA": "Nvidia", - "TSM": "Taiwan Semiconductor Manufacturing Asset OR TSMC", - "JPM": "JPMorgan Chase OR JP Morgan", - "JNJ": "Johnson & Johnson OR JNJ", - "V": "Visa", - "WMT": "Walmart", - "META": "Meta OR Facebook", - "AMD": "AMD", - "INTC": "Intel", - "QCOM": "Qualcomm", - "BABA": "Alibaba", - "ADBE": "Adobe", - "NFLX": "Netflix", - "CRM": "Salesforce", - "PYPL": "PayPal", - "PLTR": "Palantir", - "MU": "Micron", - "SQ": "Block OR Square", - "ZM": "Zoom", - "CSCO": "Cisco", - "SHOP": "Shopify", - "ORCL": "Oracle", - "X": "Twitter OR X", - "SPOT": "Spotify", - "AVGO": "Broadcom", - "ASML": "ASML ", - "TWLO": "Twilio", - "SNAP": "Snap Inc.", - "TEAM": "Atlassian", - "SQSP": "Squarespace", - "UBER": "Uber", - "ROKU": "Roku", - "PINS": "Pinterest", + "BTC": "Bitcoin", + "ETH": "Ethereum", + "XRP": "XRP", + "LTC": "Litecoin", + "DOGE": "Dogecoin", + "SOL": "Solana", + "ADA": "Cardano", + "DOT": "Polkadot", + "AVAX": "Avalanche", } - -def fetch_top_from_category( - category: Annotated[ - str, "Category to fetch top post from. Collection of subreddits." - ], - date: Annotated[str, "Date to fetch top posts from."], - max_limit: Annotated[int, "Maximum number of posts to fetch."], - query: Annotated[str, "Optional query to search for in the subreddit."] = None, - data_path: Annotated[ - str, - "Path to the data folder. Default is 'reddit_data'.", - ] = "reddit_data", +def fetch_posts_from_reddit( + symbol: str, subreddit_name: str, + sort: str = "hot", limit: int = 25 ): - base_path = data_path + """ + Fetch top posts from a specified subreddit. - all_content = [] - - if max_limit < len(os.listdir(os.path.join(base_path, category))): - raise ValueError( - "REDDIT FETCHING ERROR: max limit is less than the number of files in the category. Will not be able to fetch any posts" - ) - - limit_per_subreddit = max_limit // len( - os.listdir(os.path.join(base_path, category)) + Args: + symbol (str): The ticker symbol of the asset to filter posts. + subreddit (str): The name of the subreddit to fetch posts from. + sort (str): The sorting method for posts ('hot', 'new', 'top', etc.). + limit (int): The maximum number of posts to fetch. + Returns: + list: A list of dictionaries containing post data. + """ + reddit = praw.Reddit( + client_id=os.getenv("REDDIT_CLIENT_ID"), + client_secret=os.getenv("REDDIT_CLIENT_SECRET"), + username=os.getenv("REDDIT_USERNAME"), + password=os.getenv("REDDIT_PASSWORD"), + user_agent=os.getenv("REDDIT_USER_AGENT"), ) - - for data_file in os.listdir(os.path.join(base_path, category)): - # check if data_file is a .jsonl file - if not data_file.endswith(".jsonl"): - continue - - all_content_curr_subreddit = [] - - with open(os.path.join(base_path, category, data_file), "rb") as f: - for i, line in enumerate(f): - # skip empty lines - if not line.strip(): - continue - - parsed_line = json.loads(line) - - # select only lines that are from the date - post_date = datetime.utcfromtimestamp( - parsed_line["created_utc"] - ).strftime("%Y-%m-%d") - if post_date != date: - continue - - # if is asset_news, check that the title or the content has the asset's name (query) mentioned - if "asset" in category and query: - search_terms = [] - if "OR" in ticker_to_asset[query]: - search_terms = ticker_to_asset[query].split(" OR ") - else: - search_terms = [ticker_to_asset[query]] - - search_terms.append(query) - - found = False - for term in search_terms: - if re.search( - term, parsed_line["title"], re.IGNORECASE - ) or re.search(term, parsed_line["selftext"], re.IGNORECASE): - found = True - break - - if not found: - continue - - post = { - "title": parsed_line["title"], - "content": parsed_line["selftext"], - "url": parsed_line["url"], - "upvotes": parsed_line["ups"], - "posted_date": post_date, - } - - all_content_curr_subreddit.append(post) - - # sort all_content_curr_subreddit by upvote_ratio in descending order - all_content_curr_subreddit.sort(key=lambda x: x["upvotes"], reverse=True) - - all_content.extend(all_content_curr_subreddit[:limit_per_subreddit]) - - return all_content + subreddit = reddit.subreddit(subreddit_name) + query = symbol + " OR " + ticker_to_asset.get(symbol, symbol) + submissions = subreddit.search(query, sort=sort, time_filter="day", limit=limit) + return [ + { + "title": submission.title, + "content": submission.selftext, + "score": submission.score, + "created_utc": submission.created_utc, + } + for submission in submissions + ] \ No newline at end of file diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index adf40304..328eea82 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -126,8 +126,8 @@ class TradingAgentsGraph: self.toolkit.get_binance_ohlcv, self.toolkit.get_fear_and_greed_index, self.toolkit.get_coinstats_btc_dominance, + self.toolkit.get_reddit_posts, # self.toolkit.get_stock_news_openai, - # self.toolkit.get_reddit_stock_info, ] ), "news": ToolNode( @@ -138,7 +138,6 @@ class TradingAgentsGraph: # self.toolkit.get_google_news, self.toolkit.get_blockbeats_news, self.toolkit.get_coindesk_news, - # self.toolkit.get_reddit_news, ] ), "fundamentals": ToolNode(