feat: fetch posts from Reddit
This commit is contained in:
parent
669af68337
commit
86d7a9ace9
|
|
@ -25,4 +25,5 @@ rich
|
|||
questionary
|
||||
langchain_anthropic
|
||||
langchain-google-genai
|
||||
binance-futures-connector
|
||||
binance-futures-connector
|
||||
praw
|
||||
|
|
@ -13,7 +13,8 @@ def create_social_media_analyst(llm, toolkit):
|
|||
tools = [
|
||||
toolkit.get_binance_ohlcv,
|
||||
toolkit.get_fear_and_greed_index,
|
||||
toolkit.get_coinstats_btc_dominance
|
||||
toolkit.get_coinstats_btc_dominance,
|
||||
toolkit.get_reddit_posts
|
||||
# toolkit.get_stock_news_openai,
|
||||
# toolkit.get_reddit_stock_info
|
||||
]
|
||||
|
|
|
|||
|
|
@ -207,42 +207,24 @@ class Toolkit:
|
|||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_reddit_news(
|
||||
curr_date: Annotated[str, "Date you want to get news for in yyyy-mm-dd format"],
|
||||
def get_reddit_posts(
|
||||
symbol: Annotated[str, "Ticker symbol of the asset, e.g. 'BTC'"],
|
||||
subreddit: Annotated[str, "Subreddit to search in, e.g. 'CryptoCurrency', 'CryptoMarkets', 'all'"] = "CryptoCurrency",
|
||||
sort: Annotated[str, "Sorting method for posts ('hot', 'new', 'top', etc.)"] = "hot",
|
||||
limit: Annotated[int, "Maximum number of posts to fetch"] = 25,
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve global news from Reddit within a specified time frame.
|
||||
Fetch top posts from a specified subreddit related to a given ticker symbol.
|
||||
Args:
|
||||
curr_date (str): Date you want to get news for in yyyy-mm-dd format
|
||||
symbol (str): Ticker symbol of the asset, e.g. 'BTC'
|
||||
subreddit (str): Subreddit to search in, e.g. 'CryptoCurrency', 'CryptoMarkets', 'all'
|
||||
sort (str): Sorting method for posts ('hot', 'new', 'top', etc.)
|
||||
limit (int): Maximum number of posts to fetch
|
||||
Returns:
|
||||
str: A formatted dataframe containing the latest global news from Reddit in the specified time frame.
|
||||
str: A formatted string containing the top posts from the specified subreddit related to the ticker symbol.
|
||||
"""
|
||||
|
||||
global_news_result = interface.get_reddit_global_news(curr_date, 7, 5)
|
||||
|
||||
return global_news_result
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_reddit_stock_info(
|
||||
ticker: Annotated[
|
||||
str,
|
||||
"Ticker of a asset. e.g. AAPL, TSM",
|
||||
],
|
||||
curr_date: Annotated[str, "Current date you want to get news for"],
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve the latest news about a given stock from Reddit, given the current date.
|
||||
Args:
|
||||
ticker (str): Ticker of a asset. e.g. AAPL, TSM
|
||||
curr_date (str): current date in yyyy-mm-dd format to get news for
|
||||
Returns:
|
||||
str: A formatted dataframe containing the latest news about the asset on the given date
|
||||
"""
|
||||
|
||||
stock_news_results = interface.get_reddit_asset_news(ticker, curr_date, 7, 5)
|
||||
|
||||
return stock_news_results
|
||||
reddit_posts_result = interface.get_reddit_posts(symbol, subreddit, sort, limit)
|
||||
return reddit_posts_result
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
|
|
|
|||
|
|
@ -9,8 +9,7 @@ from .interface import (
|
|||
get_coinstats_news,
|
||||
get_google_news,
|
||||
get_fear_and_greed_index,
|
||||
get_reddit_global_news,
|
||||
get_reddit_asset_news,
|
||||
get_reddit_posts,
|
||||
# Financial statements functions
|
||||
# TODO
|
||||
# Technical analysis functions
|
||||
|
|
@ -28,8 +27,7 @@ __all__ = [
|
|||
"get_coinstats_news",
|
||||
"get_google_news",
|
||||
"get_fear_and_greed_index",
|
||||
"get_reddit_global_news",
|
||||
"get_reddit_asset_news",
|
||||
"get_reddit_posts",
|
||||
# Financial statements functions
|
||||
# TODO
|
||||
# Technical analysis functions
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from typing import Annotated, Dict
|
|||
from .blockbeats_utils import fetch_news_from_blockbeats
|
||||
from .coindesk_utils import fetch_news_from_coindesk
|
||||
from .coinstats_utils import *
|
||||
from .reddit_utils import fetch_top_from_category
|
||||
from .reddit_utils import fetch_posts_from_reddit
|
||||
from .googlenews_utils import *
|
||||
from .binance_utils import *
|
||||
from .alternativeme_utils import fetch_fear_and_greed_from_alternativeme
|
||||
|
|
@ -167,114 +167,33 @@ def get_google_news(
|
|||
|
||||
return f"## {query} Google News, from {before} to {curr_date}:\n\n{news_str}"
|
||||
|
||||
def get_reddit_global_news(
|
||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||
look_back_days: Annotated[int, "how many days to look back"],
|
||||
max_limit_per_day: Annotated[int, "Maximum number of news per day"],
|
||||
def get_reddit_posts(
|
||||
symbol: Annotated[str, "ticker symbol of the asset"],
|
||||
subreddit_name: Annotated[str, "name of the subreddit to fetch posts from, e.g., 'CryptoCurrency', 'CryptoMarkets', 'all'"],
|
||||
sort: Annotated[str, "sorting method for posts ('hot', 'new', 'top', etc.)", "default is 'hot'"] = "hot",
|
||||
limit: Annotated[int, "maximum number of posts to fetch, default is 25"] = 25,
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve the latest top reddit news
|
||||
Fetch top posts from a specified subreddit.
|
||||
|
||||
Args:
|
||||
start_date: Start date in yyyy-mm-dd format
|
||||
end_date: End date in yyyy-mm-dd format
|
||||
symbol (str): The ticker symbol of the asset to filter posts.
|
||||
subreddit_name (str): The name of the subreddit to fetch posts from.
|
||||
sort (str): The sorting method for posts ('hot', 'new', 'top', etc.).
|
||||
limit (int): The maximum number of posts to fetch.
|
||||
|
||||
Returns:
|
||||
str: A formatted dataframe containing the latest news articles posts on reddit and meta information in these columns: "created_utc", "id", "title", "selftext", "score", "num_comments", "url"
|
||||
str: A formatted string containing the top posts from the subreddit.
|
||||
"""
|
||||
|
||||
start_date = datetime.strptime(start_date, "%Y-%m-%d")
|
||||
before = start_date - relativedelta(days=look_back_days)
|
||||
before = before.strftime("%Y-%m-%d")
|
||||
|
||||
posts = []
|
||||
# iterate from start_date to end_date
|
||||
curr_date = datetime.strptime(before, "%Y-%m-%d")
|
||||
|
||||
total_iterations = (start_date - curr_date).days + 1
|
||||
pbar = tqdm(desc=f"Getting Global News on {start_date}", total=total_iterations)
|
||||
|
||||
while curr_date <= start_date:
|
||||
curr_date_str = curr_date.strftime("%Y-%m-%d")
|
||||
fetch_result = fetch_top_from_category(
|
||||
"global_news",
|
||||
curr_date_str,
|
||||
max_limit_per_day,
|
||||
data_path=os.path.join(DATA_DIR, "reddit_data"),
|
||||
)
|
||||
posts.extend(fetch_result)
|
||||
curr_date += relativedelta(days=1)
|
||||
pbar.update(1)
|
||||
|
||||
pbar.close()
|
||||
|
||||
posts = fetch_posts_from_reddit(symbol, subreddit_name, sort, limit)
|
||||
if len(posts) == 0:
|
||||
return ""
|
||||
|
||||
news_str = ""
|
||||
posts_str = ""
|
||||
for post in posts:
|
||||
if post["content"] == "":
|
||||
news_str += f"### {post['title']}\n\n"
|
||||
else:
|
||||
news_str += f"### {post['title']}\n\n{post['content']}\n\n"
|
||||
posts_str += f"### {post['title']} (score: {post['score']}, created at: {datetime.utcfromtimestamp(post['created_utc']).strftime('%Y-%m-%d %H:%M:%S')})\n{post['content']}\n\n"
|
||||
|
||||
return f"## Global News Reddit, from {before} to {curr_date}:\n{news_str}"
|
||||
|
||||
def get_reddit_asset_news(
|
||||
ticker: Annotated[str, "ticker symbol of the asset"],
|
||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||
look_back_days: Annotated[int, "how many days to look back"],
|
||||
max_limit_per_day: Annotated[int, "Maximum number of news per day"],
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve the latest top reddit news
|
||||
Args:
|
||||
ticker: ticker symbol of the asset
|
||||
start_date: Start date in yyyy-mm-dd format
|
||||
end_date: End date in yyyy-mm-dd format
|
||||
Returns:
|
||||
str: A formatted dataframe containing the latest news articles posts on reddit and meta information in these columns: "created_utc", "id", "title", "selftext", "score", "num_comments", "url"
|
||||
"""
|
||||
|
||||
start_date = datetime.strptime(start_date, "%Y-%m-%d")
|
||||
before = start_date - relativedelta(days=look_back_days)
|
||||
before = before.strftime("%Y-%m-%d")
|
||||
|
||||
posts = []
|
||||
# iterate from start_date to end_date
|
||||
curr_date = datetime.strptime(before, "%Y-%m-%d")
|
||||
|
||||
total_iterations = (start_date - curr_date).days + 1
|
||||
pbar = tqdm(
|
||||
desc=f"Getting Asset News for {ticker} on {start_date}",
|
||||
total=total_iterations,
|
||||
)
|
||||
|
||||
while curr_date <= start_date:
|
||||
curr_date_str = curr_date.strftime("%Y-%m-%d")
|
||||
fetch_result = fetch_top_from_category(
|
||||
"asset_news",
|
||||
curr_date_str,
|
||||
max_limit_per_day,
|
||||
ticker,
|
||||
data_path=os.path.join(DATA_DIR, "reddit_data"),
|
||||
)
|
||||
posts.extend(fetch_result)
|
||||
curr_date += relativedelta(days=1)
|
||||
|
||||
pbar.update(1)
|
||||
|
||||
pbar.close()
|
||||
|
||||
if len(posts) == 0:
|
||||
return ""
|
||||
|
||||
news_str = ""
|
||||
for post in posts:
|
||||
if post["content"] == "":
|
||||
news_str += f"### {post['title']}\n\n"
|
||||
else:
|
||||
news_str += f"### {post['title']}\n\n{post['content']}\n\n"
|
||||
|
||||
return f"##{ticker} News Reddit, from {before} to {curr_date}:\n\n{news_str}"
|
||||
return f"## Reddit Posts in r/{subreddit_name} for {symbol}:\n{posts_str}"
|
||||
|
||||
def get_binance_ohlcv(
|
||||
symbol: Annotated[str, "ticker symbol of the asset"],
|
||||
|
|
|
|||
|
|
@ -1,135 +1,50 @@
|
|||
import requests
|
||||
import time
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from contextlib import contextmanager
|
||||
from typing import Annotated
|
||||
|
||||
import os
|
||||
import re
|
||||
import praw
|
||||
|
||||
ticker_to_asset = {
|
||||
"AAPL": "Apple",
|
||||
"MSFT": "Microsoft",
|
||||
"GOOGL": "Google",
|
||||
"AMZN": "Amazon",
|
||||
"TSLA": "Tesla",
|
||||
"NVDA": "Nvidia",
|
||||
"TSM": "Taiwan Semiconductor Manufacturing Asset OR TSMC",
|
||||
"JPM": "JPMorgan Chase OR JP Morgan",
|
||||
"JNJ": "Johnson & Johnson OR JNJ",
|
||||
"V": "Visa",
|
||||
"WMT": "Walmart",
|
||||
"META": "Meta OR Facebook",
|
||||
"AMD": "AMD",
|
||||
"INTC": "Intel",
|
||||
"QCOM": "Qualcomm",
|
||||
"BABA": "Alibaba",
|
||||
"ADBE": "Adobe",
|
||||
"NFLX": "Netflix",
|
||||
"CRM": "Salesforce",
|
||||
"PYPL": "PayPal",
|
||||
"PLTR": "Palantir",
|
||||
"MU": "Micron",
|
||||
"SQ": "Block OR Square",
|
||||
"ZM": "Zoom",
|
||||
"CSCO": "Cisco",
|
||||
"SHOP": "Shopify",
|
||||
"ORCL": "Oracle",
|
||||
"X": "Twitter OR X",
|
||||
"SPOT": "Spotify",
|
||||
"AVGO": "Broadcom",
|
||||
"ASML": "ASML ",
|
||||
"TWLO": "Twilio",
|
||||
"SNAP": "Snap Inc.",
|
||||
"TEAM": "Atlassian",
|
||||
"SQSP": "Squarespace",
|
||||
"UBER": "Uber",
|
||||
"ROKU": "Roku",
|
||||
"PINS": "Pinterest",
|
||||
"BTC": "Bitcoin",
|
||||
"ETH": "Ethereum",
|
||||
"XRP": "XRP",
|
||||
"LTC": "Litecoin",
|
||||
"DOGE": "Dogecoin",
|
||||
"SOL": "Solana",
|
||||
"ADA": "Cardano",
|
||||
"DOT": "Polkadot",
|
||||
"AVAX": "Avalanche",
|
||||
}
|
||||
|
||||
|
||||
def fetch_top_from_category(
|
||||
category: Annotated[
|
||||
str, "Category to fetch top post from. Collection of subreddits."
|
||||
],
|
||||
date: Annotated[str, "Date to fetch top posts from."],
|
||||
max_limit: Annotated[int, "Maximum number of posts to fetch."],
|
||||
query: Annotated[str, "Optional query to search for in the subreddit."] = None,
|
||||
data_path: Annotated[
|
||||
str,
|
||||
"Path to the data folder. Default is 'reddit_data'.",
|
||||
] = "reddit_data",
|
||||
def fetch_posts_from_reddit(
|
||||
symbol: str, subreddit_name: str,
|
||||
sort: str = "hot", limit: int = 25
|
||||
):
|
||||
base_path = data_path
|
||||
"""
|
||||
Fetch top posts from a specified subreddit.
|
||||
|
||||
all_content = []
|
||||
|
||||
if max_limit < len(os.listdir(os.path.join(base_path, category))):
|
||||
raise ValueError(
|
||||
"REDDIT FETCHING ERROR: max limit is less than the number of files in the category. Will not be able to fetch any posts"
|
||||
)
|
||||
|
||||
limit_per_subreddit = max_limit // len(
|
||||
os.listdir(os.path.join(base_path, category))
|
||||
Args:
|
||||
symbol (str): The ticker symbol of the asset to filter posts.
|
||||
subreddit (str): The name of the subreddit to fetch posts from.
|
||||
sort (str): The sorting method for posts ('hot', 'new', 'top', etc.).
|
||||
limit (int): The maximum number of posts to fetch.
|
||||
Returns:
|
||||
list: A list of dictionaries containing post data.
|
||||
"""
|
||||
reddit = praw.Reddit(
|
||||
client_id=os.getenv("REDDIT_CLIENT_ID"),
|
||||
client_secret=os.getenv("REDDIT_CLIENT_SECRET"),
|
||||
username=os.getenv("REDDIT_USERNAME"),
|
||||
password=os.getenv("REDDIT_PASSWORD"),
|
||||
user_agent=os.getenv("REDDIT_USER_AGENT"),
|
||||
)
|
||||
|
||||
for data_file in os.listdir(os.path.join(base_path, category)):
|
||||
# check if data_file is a .jsonl file
|
||||
if not data_file.endswith(".jsonl"):
|
||||
continue
|
||||
|
||||
all_content_curr_subreddit = []
|
||||
|
||||
with open(os.path.join(base_path, category, data_file), "rb") as f:
|
||||
for i, line in enumerate(f):
|
||||
# skip empty lines
|
||||
if not line.strip():
|
||||
continue
|
||||
|
||||
parsed_line = json.loads(line)
|
||||
|
||||
# select only lines that are from the date
|
||||
post_date = datetime.utcfromtimestamp(
|
||||
parsed_line["created_utc"]
|
||||
).strftime("%Y-%m-%d")
|
||||
if post_date != date:
|
||||
continue
|
||||
|
||||
# if is asset_news, check that the title or the content has the asset's name (query) mentioned
|
||||
if "asset" in category and query:
|
||||
search_terms = []
|
||||
if "OR" in ticker_to_asset[query]:
|
||||
search_terms = ticker_to_asset[query].split(" OR ")
|
||||
else:
|
||||
search_terms = [ticker_to_asset[query]]
|
||||
|
||||
search_terms.append(query)
|
||||
|
||||
found = False
|
||||
for term in search_terms:
|
||||
if re.search(
|
||||
term, parsed_line["title"], re.IGNORECASE
|
||||
) or re.search(term, parsed_line["selftext"], re.IGNORECASE):
|
||||
found = True
|
||||
break
|
||||
|
||||
if not found:
|
||||
continue
|
||||
|
||||
post = {
|
||||
"title": parsed_line["title"],
|
||||
"content": parsed_line["selftext"],
|
||||
"url": parsed_line["url"],
|
||||
"upvotes": parsed_line["ups"],
|
||||
"posted_date": post_date,
|
||||
}
|
||||
|
||||
all_content_curr_subreddit.append(post)
|
||||
|
||||
# sort all_content_curr_subreddit by upvote_ratio in descending order
|
||||
all_content_curr_subreddit.sort(key=lambda x: x["upvotes"], reverse=True)
|
||||
|
||||
all_content.extend(all_content_curr_subreddit[:limit_per_subreddit])
|
||||
|
||||
return all_content
|
||||
subreddit = reddit.subreddit(subreddit_name)
|
||||
query = symbol + " OR " + ticker_to_asset.get(symbol, symbol)
|
||||
submissions = subreddit.search(query, sort=sort, time_filter="day", limit=limit)
|
||||
return [
|
||||
{
|
||||
"title": submission.title,
|
||||
"content": submission.selftext,
|
||||
"score": submission.score,
|
||||
"created_utc": submission.created_utc,
|
||||
}
|
||||
for submission in submissions
|
||||
]
|
||||
|
|
@ -126,8 +126,8 @@ class TradingAgentsGraph:
|
|||
self.toolkit.get_binance_ohlcv,
|
||||
self.toolkit.get_fear_and_greed_index,
|
||||
self.toolkit.get_coinstats_btc_dominance,
|
||||
self.toolkit.get_reddit_posts,
|
||||
# self.toolkit.get_stock_news_openai,
|
||||
# self.toolkit.get_reddit_stock_info,
|
||||
]
|
||||
),
|
||||
"news": ToolNode(
|
||||
|
|
@ -138,7 +138,6 @@ class TradingAgentsGraph:
|
|||
# self.toolkit.get_google_news,
|
||||
self.toolkit.get_blockbeats_news,
|
||||
self.toolkit.get_coindesk_news,
|
||||
# self.toolkit.get_reddit_news,
|
||||
]
|
||||
),
|
||||
"fundamentals": ToolNode(
|
||||
|
|
|
|||
Loading…
Reference in New Issue