feat: fetch posts from Reddit

This commit is contained in:
Tomortec 2025-07-02 17:24:07 +08:00
parent 669af68337
commit 86d7a9ace9
7 changed files with 80 additions and 265 deletions

View File

@ -25,4 +25,5 @@ rich
questionary
langchain_anthropic
langchain-google-genai
binance-futures-connector
binance-futures-connector
praw

View File

@ -13,7 +13,8 @@ def create_social_media_analyst(llm, toolkit):
tools = [
toolkit.get_binance_ohlcv,
toolkit.get_fear_and_greed_index,
toolkit.get_coinstats_btc_dominance
toolkit.get_coinstats_btc_dominance,
toolkit.get_reddit_posts
# toolkit.get_stock_news_openai,
# toolkit.get_reddit_stock_info
]

View File

@ -207,42 +207,24 @@ class Toolkit:
@staticmethod
@tool
def get_reddit_news(
curr_date: Annotated[str, "Date you want to get news for in yyyy-mm-dd format"],
def get_reddit_posts(
symbol: Annotated[str, "Ticker symbol of the asset, e.g. 'BTC'"],
subreddit: Annotated[str, "Subreddit to search in, e.g. 'CryptoCurrency', 'CryptoMarkets', 'all'"] = "CryptoCurrency",
sort: Annotated[str, "Sorting method for posts ('hot', 'new', 'top', etc.)"] = "hot",
limit: Annotated[int, "Maximum number of posts to fetch"] = 25,
) -> str:
"""
Retrieve global news from Reddit within a specified time frame.
Fetch top posts from a specified subreddit related to a given ticker symbol.
Args:
curr_date (str): Date you want to get news for in yyyy-mm-dd format
symbol (str): Ticker symbol of the asset, e.g. 'BTC'
subreddit (str): Subreddit to search in, e.g. 'CryptoCurrency', 'CryptoMarkets', 'all'
sort (str): Sorting method for posts ('hot', 'new', 'top', etc.)
limit (int): Maximum number of posts to fetch
Returns:
str: A formatted dataframe containing the latest global news from Reddit in the specified time frame.
str: A formatted string containing the top posts from the specified subreddit related to the ticker symbol.
"""
global_news_result = interface.get_reddit_global_news(curr_date, 7, 5)
return global_news_result
@staticmethod
@tool
def get_reddit_stock_info(
ticker: Annotated[
str,
"Ticker of a asset. e.g. AAPL, TSM",
],
curr_date: Annotated[str, "Current date you want to get news for"],
) -> str:
"""
Retrieve the latest news about a given stock from Reddit, given the current date.
Args:
ticker (str): Ticker of a asset. e.g. AAPL, TSM
curr_date (str): current date in yyyy-mm-dd format to get news for
Returns:
str: A formatted dataframe containing the latest news about the asset on the given date
"""
stock_news_results = interface.get_reddit_asset_news(ticker, curr_date, 7, 5)
return stock_news_results
reddit_posts_result = interface.get_reddit_posts(symbol, subreddit, sort, limit)
return reddit_posts_result
@staticmethod
@tool

View File

@ -9,8 +9,7 @@ from .interface import (
get_coinstats_news,
get_google_news,
get_fear_and_greed_index,
get_reddit_global_news,
get_reddit_asset_news,
get_reddit_posts,
# Financial statements functions
# TODO
# Technical analysis functions
@ -28,8 +27,7 @@ __all__ = [
"get_coinstats_news",
"get_google_news",
"get_fear_and_greed_index",
"get_reddit_global_news",
"get_reddit_asset_news",
"get_reddit_posts",
# Financial statements functions
# TODO
# Technical analysis functions

View File

@ -2,7 +2,7 @@ from typing import Annotated, Dict
from .blockbeats_utils import fetch_news_from_blockbeats
from .coindesk_utils import fetch_news_from_coindesk
from .coinstats_utils import *
from .reddit_utils import fetch_top_from_category
from .reddit_utils import fetch_posts_from_reddit
from .googlenews_utils import *
from .binance_utils import *
from .alternativeme_utils import fetch_fear_and_greed_from_alternativeme
@ -167,114 +167,33 @@ def get_google_news(
return f"## {query} Google News, from {before} to {curr_date}:\n\n{news_str}"
def get_reddit_global_news(
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
look_back_days: Annotated[int, "how many days to look back"],
max_limit_per_day: Annotated[int, "Maximum number of news per day"],
def get_reddit_posts(
symbol: Annotated[str, "ticker symbol of the asset"],
subreddit_name: Annotated[str, "name of the subreddit to fetch posts from, e.g., 'CryptoCurrency', 'CryptoMarkets', 'all'"],
sort: Annotated[str, "sorting method for posts ('hot', 'new', 'top', etc.)", "default is 'hot'"] = "hot",
limit: Annotated[int, "maximum number of posts to fetch, default is 25"] = 25,
) -> str:
"""
Retrieve the latest top reddit news
Fetch top posts from a specified subreddit.
Args:
start_date: Start date in yyyy-mm-dd format
end_date: End date in yyyy-mm-dd format
symbol (str): The ticker symbol of the asset to filter posts.
subreddit_name (str): The name of the subreddit to fetch posts from.
sort (str): The sorting method for posts ('hot', 'new', 'top', etc.).
limit (int): The maximum number of posts to fetch.
Returns:
str: A formatted dataframe containing the latest news articles posts on reddit and meta information in these columns: "created_utc", "id", "title", "selftext", "score", "num_comments", "url"
str: A formatted string containing the top posts from the subreddit.
"""
start_date = datetime.strptime(start_date, "%Y-%m-%d")
before = start_date - relativedelta(days=look_back_days)
before = before.strftime("%Y-%m-%d")
posts = []
# iterate from start_date to end_date
curr_date = datetime.strptime(before, "%Y-%m-%d")
total_iterations = (start_date - curr_date).days + 1
pbar = tqdm(desc=f"Getting Global News on {start_date}", total=total_iterations)
while curr_date <= start_date:
curr_date_str = curr_date.strftime("%Y-%m-%d")
fetch_result = fetch_top_from_category(
"global_news",
curr_date_str,
max_limit_per_day,
data_path=os.path.join(DATA_DIR, "reddit_data"),
)
posts.extend(fetch_result)
curr_date += relativedelta(days=1)
pbar.update(1)
pbar.close()
posts = fetch_posts_from_reddit(symbol, subreddit_name, sort, limit)
if len(posts) == 0:
return ""
news_str = ""
posts_str = ""
for post in posts:
if post["content"] == "":
news_str += f"### {post['title']}\n\n"
else:
news_str += f"### {post['title']}\n\n{post['content']}\n\n"
posts_str += f"### {post['title']} (score: {post['score']}, created at: {datetime.utcfromtimestamp(post['created_utc']).strftime('%Y-%m-%d %H:%M:%S')})\n{post['content']}\n\n"
return f"## Global News Reddit, from {before} to {curr_date}:\n{news_str}"
def get_reddit_asset_news(
ticker: Annotated[str, "ticker symbol of the asset"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
look_back_days: Annotated[int, "how many days to look back"],
max_limit_per_day: Annotated[int, "Maximum number of news per day"],
) -> str:
"""
Retrieve the latest top reddit news
Args:
ticker: ticker symbol of the asset
start_date: Start date in yyyy-mm-dd format
end_date: End date in yyyy-mm-dd format
Returns:
str: A formatted dataframe containing the latest news articles posts on reddit and meta information in these columns: "created_utc", "id", "title", "selftext", "score", "num_comments", "url"
"""
start_date = datetime.strptime(start_date, "%Y-%m-%d")
before = start_date - relativedelta(days=look_back_days)
before = before.strftime("%Y-%m-%d")
posts = []
# iterate from start_date to end_date
curr_date = datetime.strptime(before, "%Y-%m-%d")
total_iterations = (start_date - curr_date).days + 1
pbar = tqdm(
desc=f"Getting Asset News for {ticker} on {start_date}",
total=total_iterations,
)
while curr_date <= start_date:
curr_date_str = curr_date.strftime("%Y-%m-%d")
fetch_result = fetch_top_from_category(
"asset_news",
curr_date_str,
max_limit_per_day,
ticker,
data_path=os.path.join(DATA_DIR, "reddit_data"),
)
posts.extend(fetch_result)
curr_date += relativedelta(days=1)
pbar.update(1)
pbar.close()
if len(posts) == 0:
return ""
news_str = ""
for post in posts:
if post["content"] == "":
news_str += f"### {post['title']}\n\n"
else:
news_str += f"### {post['title']}\n\n{post['content']}\n\n"
return f"##{ticker} News Reddit, from {before} to {curr_date}:\n\n{news_str}"
return f"## Reddit Posts in r/{subreddit_name} for {symbol}:\n{posts_str}"
def get_binance_ohlcv(
symbol: Annotated[str, "ticker symbol of the asset"],

View File

@ -1,135 +1,50 @@
import requests
import time
import json
from datetime import datetime, timedelta
from contextlib import contextmanager
from typing import Annotated
import os
import re
import praw
ticker_to_asset = {
"AAPL": "Apple",
"MSFT": "Microsoft",
"GOOGL": "Google",
"AMZN": "Amazon",
"TSLA": "Tesla",
"NVDA": "Nvidia",
"TSM": "Taiwan Semiconductor Manufacturing Asset OR TSMC",
"JPM": "JPMorgan Chase OR JP Morgan",
"JNJ": "Johnson & Johnson OR JNJ",
"V": "Visa",
"WMT": "Walmart",
"META": "Meta OR Facebook",
"AMD": "AMD",
"INTC": "Intel",
"QCOM": "Qualcomm",
"BABA": "Alibaba",
"ADBE": "Adobe",
"NFLX": "Netflix",
"CRM": "Salesforce",
"PYPL": "PayPal",
"PLTR": "Palantir",
"MU": "Micron",
"SQ": "Block OR Square",
"ZM": "Zoom",
"CSCO": "Cisco",
"SHOP": "Shopify",
"ORCL": "Oracle",
"X": "Twitter OR X",
"SPOT": "Spotify",
"AVGO": "Broadcom",
"ASML": "ASML ",
"TWLO": "Twilio",
"SNAP": "Snap Inc.",
"TEAM": "Atlassian",
"SQSP": "Squarespace",
"UBER": "Uber",
"ROKU": "Roku",
"PINS": "Pinterest",
"BTC": "Bitcoin",
"ETH": "Ethereum",
"XRP": "XRP",
"LTC": "Litecoin",
"DOGE": "Dogecoin",
"SOL": "Solana",
"ADA": "Cardano",
"DOT": "Polkadot",
"AVAX": "Avalanche",
}
def fetch_top_from_category(
category: Annotated[
str, "Category to fetch top post from. Collection of subreddits."
],
date: Annotated[str, "Date to fetch top posts from."],
max_limit: Annotated[int, "Maximum number of posts to fetch."],
query: Annotated[str, "Optional query to search for in the subreddit."] = None,
data_path: Annotated[
str,
"Path to the data folder. Default is 'reddit_data'.",
] = "reddit_data",
def fetch_posts_from_reddit(
symbol: str, subreddit_name: str,
sort: str = "hot", limit: int = 25
):
base_path = data_path
"""
Fetch top posts from a specified subreddit.
all_content = []
if max_limit < len(os.listdir(os.path.join(base_path, category))):
raise ValueError(
"REDDIT FETCHING ERROR: max limit is less than the number of files in the category. Will not be able to fetch any posts"
)
limit_per_subreddit = max_limit // len(
os.listdir(os.path.join(base_path, category))
Args:
symbol (str): The ticker symbol of the asset to filter posts.
subreddit (str): The name of the subreddit to fetch posts from.
sort (str): The sorting method for posts ('hot', 'new', 'top', etc.).
limit (int): The maximum number of posts to fetch.
Returns:
list: A list of dictionaries containing post data.
"""
reddit = praw.Reddit(
client_id=os.getenv("REDDIT_CLIENT_ID"),
client_secret=os.getenv("REDDIT_CLIENT_SECRET"),
username=os.getenv("REDDIT_USERNAME"),
password=os.getenv("REDDIT_PASSWORD"),
user_agent=os.getenv("REDDIT_USER_AGENT"),
)
for data_file in os.listdir(os.path.join(base_path, category)):
# check if data_file is a .jsonl file
if not data_file.endswith(".jsonl"):
continue
all_content_curr_subreddit = []
with open(os.path.join(base_path, category, data_file), "rb") as f:
for i, line in enumerate(f):
# skip empty lines
if not line.strip():
continue
parsed_line = json.loads(line)
# select only lines that are from the date
post_date = datetime.utcfromtimestamp(
parsed_line["created_utc"]
).strftime("%Y-%m-%d")
if post_date != date:
continue
# if is asset_news, check that the title or the content has the asset's name (query) mentioned
if "asset" in category and query:
search_terms = []
if "OR" in ticker_to_asset[query]:
search_terms = ticker_to_asset[query].split(" OR ")
else:
search_terms = [ticker_to_asset[query]]
search_terms.append(query)
found = False
for term in search_terms:
if re.search(
term, parsed_line["title"], re.IGNORECASE
) or re.search(term, parsed_line["selftext"], re.IGNORECASE):
found = True
break
if not found:
continue
post = {
"title": parsed_line["title"],
"content": parsed_line["selftext"],
"url": parsed_line["url"],
"upvotes": parsed_line["ups"],
"posted_date": post_date,
}
all_content_curr_subreddit.append(post)
# sort all_content_curr_subreddit by upvote_ratio in descending order
all_content_curr_subreddit.sort(key=lambda x: x["upvotes"], reverse=True)
all_content.extend(all_content_curr_subreddit[:limit_per_subreddit])
return all_content
subreddit = reddit.subreddit(subreddit_name)
query = symbol + " OR " + ticker_to_asset.get(symbol, symbol)
submissions = subreddit.search(query, sort=sort, time_filter="day", limit=limit)
return [
{
"title": submission.title,
"content": submission.selftext,
"score": submission.score,
"created_utc": submission.created_utc,
}
for submission in submissions
]

View File

@ -126,8 +126,8 @@ class TradingAgentsGraph:
self.toolkit.get_binance_ohlcv,
self.toolkit.get_fear_and_greed_index,
self.toolkit.get_coinstats_btc_dominance,
self.toolkit.get_reddit_posts,
# self.toolkit.get_stock_news_openai,
# self.toolkit.get_reddit_stock_info,
]
),
"news": ToolNode(
@ -138,7 +138,6 @@ class TradingAgentsGraph:
# self.toolkit.get_google_news,
self.toolkit.get_blockbeats_news,
self.toolkit.get_coindesk_news,
# self.toolkit.get_reddit_news,
]
),
"fundamentals": ToolNode(