feat: fetch posts from Reddit

This commit is contained in:
Tomortec 2025-07-02 17:24:07 +08:00
parent 669af68337
commit 86d7a9ace9
7 changed files with 80 additions and 265 deletions

View File

@ -25,4 +25,5 @@ rich
questionary questionary
langchain_anthropic langchain_anthropic
langchain-google-genai langchain-google-genai
binance-futures-connector binance-futures-connector
praw

View File

@ -13,7 +13,8 @@ def create_social_media_analyst(llm, toolkit):
tools = [ tools = [
toolkit.get_binance_ohlcv, toolkit.get_binance_ohlcv,
toolkit.get_fear_and_greed_index, toolkit.get_fear_and_greed_index,
toolkit.get_coinstats_btc_dominance toolkit.get_coinstats_btc_dominance,
toolkit.get_reddit_posts
# toolkit.get_stock_news_openai, # toolkit.get_stock_news_openai,
# toolkit.get_reddit_stock_info # toolkit.get_reddit_stock_info
] ]

View File

@ -207,42 +207,24 @@ class Toolkit:
@staticmethod @staticmethod
@tool @tool
def get_reddit_news( def get_reddit_posts(
curr_date: Annotated[str, "Date you want to get news for in yyyy-mm-dd format"], symbol: Annotated[str, "Ticker symbol of the asset, e.g. 'BTC'"],
subreddit: Annotated[str, "Subreddit to search in, e.g. 'CryptoCurrency', 'CryptoMarkets', 'all'"] = "CryptoCurrency",
sort: Annotated[str, "Sorting method for posts ('hot', 'new', 'top', etc.)"] = "hot",
limit: Annotated[int, "Maximum number of posts to fetch"] = 25,
) -> str: ) -> str:
""" """
Retrieve global news from Reddit within a specified time frame. Fetch top posts from a specified subreddit related to a given ticker symbol.
Args: Args:
curr_date (str): Date you want to get news for in yyyy-mm-dd format symbol (str): Ticker symbol of the asset, e.g. 'BTC'
subreddit (str): Subreddit to search in, e.g. 'CryptoCurrency', 'CryptoMarkets', 'all'
sort (str): Sorting method for posts ('hot', 'new', 'top', etc.)
limit (int): Maximum number of posts to fetch
Returns: Returns:
str: A formatted dataframe containing the latest global news from Reddit in the specified time frame. str: A formatted string containing the top posts from the specified subreddit related to the ticker symbol.
""" """
reddit_posts_result = interface.get_reddit_posts(symbol, subreddit, sort, limit)
global_news_result = interface.get_reddit_global_news(curr_date, 7, 5) return reddit_posts_result
return global_news_result
@staticmethod
@tool
def get_reddit_stock_info(
ticker: Annotated[
str,
"Ticker of a asset. e.g. AAPL, TSM",
],
curr_date: Annotated[str, "Current date you want to get news for"],
) -> str:
"""
Retrieve the latest news about a given stock from Reddit, given the current date.
Args:
ticker (str): Ticker of a asset. e.g. AAPL, TSM
curr_date (str): current date in yyyy-mm-dd format to get news for
Returns:
str: A formatted dataframe containing the latest news about the asset on the given date
"""
stock_news_results = interface.get_reddit_asset_news(ticker, curr_date, 7, 5)
return stock_news_results
@staticmethod @staticmethod
@tool @tool

View File

@ -9,8 +9,7 @@ from .interface import (
get_coinstats_news, get_coinstats_news,
get_google_news, get_google_news,
get_fear_and_greed_index, get_fear_and_greed_index,
get_reddit_global_news, get_reddit_posts,
get_reddit_asset_news,
# Financial statements functions # Financial statements functions
# TODO # TODO
# Technical analysis functions # Technical analysis functions
@ -28,8 +27,7 @@ __all__ = [
"get_coinstats_news", "get_coinstats_news",
"get_google_news", "get_google_news",
"get_fear_and_greed_index", "get_fear_and_greed_index",
"get_reddit_global_news", "get_reddit_posts",
"get_reddit_asset_news",
# Financial statements functions # Financial statements functions
# TODO # TODO
# Technical analysis functions # Technical analysis functions

View File

@ -2,7 +2,7 @@ from typing import Annotated, Dict
from .blockbeats_utils import fetch_news_from_blockbeats from .blockbeats_utils import fetch_news_from_blockbeats
from .coindesk_utils import fetch_news_from_coindesk from .coindesk_utils import fetch_news_from_coindesk
from .coinstats_utils import * from .coinstats_utils import *
from .reddit_utils import fetch_top_from_category from .reddit_utils import fetch_posts_from_reddit
from .googlenews_utils import * from .googlenews_utils import *
from .binance_utils import * from .binance_utils import *
from .alternativeme_utils import fetch_fear_and_greed_from_alternativeme from .alternativeme_utils import fetch_fear_and_greed_from_alternativeme
@ -167,114 +167,33 @@ def get_google_news(
return f"## {query} Google News, from {before} to {curr_date}:\n\n{news_str}" return f"## {query} Google News, from {before} to {curr_date}:\n\n{news_str}"
def get_reddit_global_news( def get_reddit_posts(
start_date: Annotated[str, "Start date in yyyy-mm-dd format"], symbol: Annotated[str, "ticker symbol of the asset"],
look_back_days: Annotated[int, "how many days to look back"], subreddit_name: Annotated[str, "name of the subreddit to fetch posts from, e.g., 'CryptoCurrency', 'CryptoMarkets', 'all'"],
max_limit_per_day: Annotated[int, "Maximum number of news per day"], sort: Annotated[str, "sorting method for posts ('hot', 'new', 'top', etc.)", "default is 'hot'"] = "hot",
limit: Annotated[int, "maximum number of posts to fetch, default is 25"] = 25,
) -> str: ) -> str:
""" """
Retrieve the latest top reddit news Fetch top posts from a specified subreddit.
Args: Args:
start_date: Start date in yyyy-mm-dd format symbol (str): The ticker symbol of the asset to filter posts.
end_date: End date in yyyy-mm-dd format subreddit_name (str): The name of the subreddit to fetch posts from.
sort (str): The sorting method for posts ('hot', 'new', 'top', etc.).
limit (int): The maximum number of posts to fetch.
Returns: Returns:
str: A formatted dataframe containing the latest news articles posts on reddit and meta information in these columns: "created_utc", "id", "title", "selftext", "score", "num_comments", "url" str: A formatted string containing the top posts from the subreddit.
""" """
posts = fetch_posts_from_reddit(symbol, subreddit_name, sort, limit)
start_date = datetime.strptime(start_date, "%Y-%m-%d")
before = start_date - relativedelta(days=look_back_days)
before = before.strftime("%Y-%m-%d")
posts = []
# iterate from start_date to end_date
curr_date = datetime.strptime(before, "%Y-%m-%d")
total_iterations = (start_date - curr_date).days + 1
pbar = tqdm(desc=f"Getting Global News on {start_date}", total=total_iterations)
while curr_date <= start_date:
curr_date_str = curr_date.strftime("%Y-%m-%d")
fetch_result = fetch_top_from_category(
"global_news",
curr_date_str,
max_limit_per_day,
data_path=os.path.join(DATA_DIR, "reddit_data"),
)
posts.extend(fetch_result)
curr_date += relativedelta(days=1)
pbar.update(1)
pbar.close()
if len(posts) == 0: if len(posts) == 0:
return "" return ""
news_str = "" posts_str = ""
for post in posts: for post in posts:
if post["content"] == "": posts_str += f"### {post['title']} (score: {post['score']}, created at: {datetime.utcfromtimestamp(post['created_utc']).strftime('%Y-%m-%d %H:%M:%S')})\n{post['content']}\n\n"
news_str += f"### {post['title']}\n\n"
else:
news_str += f"### {post['title']}\n\n{post['content']}\n\n"
return f"## Global News Reddit, from {before} to {curr_date}:\n{news_str}" return f"## Reddit Posts in r/{subreddit_name} for {symbol}:\n{posts_str}"
def get_reddit_asset_news(
ticker: Annotated[str, "ticker symbol of the asset"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
look_back_days: Annotated[int, "how many days to look back"],
max_limit_per_day: Annotated[int, "Maximum number of news per day"],
) -> str:
"""
Retrieve the latest top reddit news
Args:
ticker: ticker symbol of the asset
start_date: Start date in yyyy-mm-dd format
end_date: End date in yyyy-mm-dd format
Returns:
str: A formatted dataframe containing the latest news articles posts on reddit and meta information in these columns: "created_utc", "id", "title", "selftext", "score", "num_comments", "url"
"""
start_date = datetime.strptime(start_date, "%Y-%m-%d")
before = start_date - relativedelta(days=look_back_days)
before = before.strftime("%Y-%m-%d")
posts = []
# iterate from start_date to end_date
curr_date = datetime.strptime(before, "%Y-%m-%d")
total_iterations = (start_date - curr_date).days + 1
pbar = tqdm(
desc=f"Getting Asset News for {ticker} on {start_date}",
total=total_iterations,
)
while curr_date <= start_date:
curr_date_str = curr_date.strftime("%Y-%m-%d")
fetch_result = fetch_top_from_category(
"asset_news",
curr_date_str,
max_limit_per_day,
ticker,
data_path=os.path.join(DATA_DIR, "reddit_data"),
)
posts.extend(fetch_result)
curr_date += relativedelta(days=1)
pbar.update(1)
pbar.close()
if len(posts) == 0:
return ""
news_str = ""
for post in posts:
if post["content"] == "":
news_str += f"### {post['title']}\n\n"
else:
news_str += f"### {post['title']}\n\n{post['content']}\n\n"
return f"##{ticker} News Reddit, from {before} to {curr_date}:\n\n{news_str}"
def get_binance_ohlcv( def get_binance_ohlcv(
symbol: Annotated[str, "ticker symbol of the asset"], symbol: Annotated[str, "ticker symbol of the asset"],

View File

@ -1,135 +1,50 @@
import requests
import time
import json
from datetime import datetime, timedelta
from contextlib import contextmanager
from typing import Annotated
import os import os
import re import praw
ticker_to_asset = { ticker_to_asset = {
"AAPL": "Apple", "BTC": "Bitcoin",
"MSFT": "Microsoft", "ETH": "Ethereum",
"GOOGL": "Google", "XRP": "XRP",
"AMZN": "Amazon", "LTC": "Litecoin",
"TSLA": "Tesla", "DOGE": "Dogecoin",
"NVDA": "Nvidia", "SOL": "Solana",
"TSM": "Taiwan Semiconductor Manufacturing Asset OR TSMC", "ADA": "Cardano",
"JPM": "JPMorgan Chase OR JP Morgan", "DOT": "Polkadot",
"JNJ": "Johnson & Johnson OR JNJ", "AVAX": "Avalanche",
"V": "Visa",
"WMT": "Walmart",
"META": "Meta OR Facebook",
"AMD": "AMD",
"INTC": "Intel",
"QCOM": "Qualcomm",
"BABA": "Alibaba",
"ADBE": "Adobe",
"NFLX": "Netflix",
"CRM": "Salesforce",
"PYPL": "PayPal",
"PLTR": "Palantir",
"MU": "Micron",
"SQ": "Block OR Square",
"ZM": "Zoom",
"CSCO": "Cisco",
"SHOP": "Shopify",
"ORCL": "Oracle",
"X": "Twitter OR X",
"SPOT": "Spotify",
"AVGO": "Broadcom",
"ASML": "ASML ",
"TWLO": "Twilio",
"SNAP": "Snap Inc.",
"TEAM": "Atlassian",
"SQSP": "Squarespace",
"UBER": "Uber",
"ROKU": "Roku",
"PINS": "Pinterest",
} }
def fetch_posts_from_reddit(
def fetch_top_from_category( symbol: str, subreddit_name: str,
category: Annotated[ sort: str = "hot", limit: int = 25
str, "Category to fetch top post from. Collection of subreddits."
],
date: Annotated[str, "Date to fetch top posts from."],
max_limit: Annotated[int, "Maximum number of posts to fetch."],
query: Annotated[str, "Optional query to search for in the subreddit."] = None,
data_path: Annotated[
str,
"Path to the data folder. Default is 'reddit_data'.",
] = "reddit_data",
): ):
base_path = data_path """
Fetch top posts from a specified subreddit.
all_content = [] Args:
symbol (str): The ticker symbol of the asset to filter posts.
if max_limit < len(os.listdir(os.path.join(base_path, category))): subreddit (str): The name of the subreddit to fetch posts from.
raise ValueError( sort (str): The sorting method for posts ('hot', 'new', 'top', etc.).
"REDDIT FETCHING ERROR: max limit is less than the number of files in the category. Will not be able to fetch any posts" limit (int): The maximum number of posts to fetch.
) Returns:
list: A list of dictionaries containing post data.
limit_per_subreddit = max_limit // len( """
os.listdir(os.path.join(base_path, category)) reddit = praw.Reddit(
client_id=os.getenv("REDDIT_CLIENT_ID"),
client_secret=os.getenv("REDDIT_CLIENT_SECRET"),
username=os.getenv("REDDIT_USERNAME"),
password=os.getenv("REDDIT_PASSWORD"),
user_agent=os.getenv("REDDIT_USER_AGENT"),
) )
subreddit = reddit.subreddit(subreddit_name)
for data_file in os.listdir(os.path.join(base_path, category)): query = symbol + " OR " + ticker_to_asset.get(symbol, symbol)
# check if data_file is a .jsonl file submissions = subreddit.search(query, sort=sort, time_filter="day", limit=limit)
if not data_file.endswith(".jsonl"): return [
continue {
"title": submission.title,
all_content_curr_subreddit = [] "content": submission.selftext,
"score": submission.score,
with open(os.path.join(base_path, category, data_file), "rb") as f: "created_utc": submission.created_utc,
for i, line in enumerate(f): }
# skip empty lines for submission in submissions
if not line.strip(): ]
continue
parsed_line = json.loads(line)
# select only lines that are from the date
post_date = datetime.utcfromtimestamp(
parsed_line["created_utc"]
).strftime("%Y-%m-%d")
if post_date != date:
continue
# if is asset_news, check that the title or the content has the asset's name (query) mentioned
if "asset" in category and query:
search_terms = []
if "OR" in ticker_to_asset[query]:
search_terms = ticker_to_asset[query].split(" OR ")
else:
search_terms = [ticker_to_asset[query]]
search_terms.append(query)
found = False
for term in search_terms:
if re.search(
term, parsed_line["title"], re.IGNORECASE
) or re.search(term, parsed_line["selftext"], re.IGNORECASE):
found = True
break
if not found:
continue
post = {
"title": parsed_line["title"],
"content": parsed_line["selftext"],
"url": parsed_line["url"],
"upvotes": parsed_line["ups"],
"posted_date": post_date,
}
all_content_curr_subreddit.append(post)
# sort all_content_curr_subreddit by upvote_ratio in descending order
all_content_curr_subreddit.sort(key=lambda x: x["upvotes"], reverse=True)
all_content.extend(all_content_curr_subreddit[:limit_per_subreddit])
return all_content

View File

@ -126,8 +126,8 @@ class TradingAgentsGraph:
self.toolkit.get_binance_ohlcv, self.toolkit.get_binance_ohlcv,
self.toolkit.get_fear_and_greed_index, self.toolkit.get_fear_and_greed_index,
self.toolkit.get_coinstats_btc_dominance, self.toolkit.get_coinstats_btc_dominance,
self.toolkit.get_reddit_posts,
# self.toolkit.get_stock_news_openai, # self.toolkit.get_stock_news_openai,
# self.toolkit.get_reddit_stock_info,
] ]
), ),
"news": ToolNode( "news": ToolNode(
@ -138,7 +138,6 @@ class TradingAgentsGraph:
# self.toolkit.get_google_news, # self.toolkit.get_google_news,
self.toolkit.get_blockbeats_news, self.toolkit.get_blockbeats_news,
self.toolkit.get_coindesk_news, self.toolkit.get_coindesk_news,
# self.toolkit.get_reddit_news,
] ]
), ),
"fundamentals": ToolNode( "fundamentals": ToolNode(