This commit is contained in:
Bcardo 2026-04-20 16:37:33 -04:00 committed by GitHub
commit 1cc71ebedd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 572 additions and 12 deletions

3
.gitignore vendored
View File

@ -1,3 +1,6 @@
# Claude Code local files
CLAUDE.md
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[codz]

View File

@ -0,0 +1,215 @@
"""Mock-based unit tests for Reddit sentiment and Fear & Greed dataflows."""
import pytest
import requests
from unittest.mock import MagicMock, patch
from tradingagents.dataflows.reddit_sentiment import get_reddit_sentiment
from tradingagents.dataflows.fear_greed import get_fear_greed
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_post(post_id, title, score=100, num_comments=50, upvote_ratio=0.9,
flair=None, created_utc=9_999_999_999):
return {"kind": "t3", "data": {
"id": post_id,
"title": title,
"score": score,
"num_comments": num_comments,
"upvote_ratio": upvote_ratio,
"link_flair_text": flair,
"created_utc": created_utc,
}}
def _search_response(posts):
return {"data": {"children": posts}}
def _comment_response(comments):
comment_items = [
{"kind": "t1", "data": {"author": "user1", "body": c}}
for c in comments
]
return [
{"data": {"children": []}}, # post listing (unused)
{"data": {"children": comment_items}},
]
# ---------------------------------------------------------------------------
# Reddit — get_reddit_sentiment
# ---------------------------------------------------------------------------
class TestRedditSentiment:
    """Unit tests for get_reddit_sentiment with requests.get and the yfinance
    company-name lookup fully mocked — no network access is made."""

    def _patch_search(self, posts_by_subreddit):
        """Return a mock requests.get that returns given posts per subreddit."""
        def fake_get(url, params=None, headers=None, timeout=None):
            resp = MagicMock()
            resp.ok = True
            resp.status_code = 200
            resp.encoding = "utf-8"
            # Derive the subreddit name from the search URL path,
            # e.g. ".../r/wallstreetbets/search.json" -> "wallstreetbets".
            subreddit = url.split("/r/")[1].split("/")[0]
            posts = posts_by_subreddit.get(subreddit, [])
            resp.json.return_value = _search_response(posts)
            return resp
        return fake_get

    def test_happy_path_returns_formatted_post(self):
        """A matching post is rendered with score, comment count, and upvote ratio."""
        posts = {"wallstreetbets": [
            _make_post("abc1", "NVDA calls printing today", score=500, num_comments=80, upvote_ratio=0.92)
        ], "stocks": [], "options": []}
        with patch("tradingagents.dataflows.reddit_sentiment.requests.get", side_effect=self._patch_search(posts)), \
             patch("tradingagents.dataflows.reddit_sentiment._get_company_name", return_value=""):
            result = get_reddit_sentiment("NVDA", days=7)
        assert "NVDA" in result
        assert "NVDA calls printing today" in result
        assert "Score: 500" in result
        assert "Comments: 80" in result
        assert "92%" in result

    def test_no_posts_returns_informative_message(self):
        """Empty results from every subreddit produce the 'no posts found' message."""
        empty = {"wallstreetbets": [], "stocks": [], "options": []}
        with patch("tradingagents.dataflows.reddit_sentiment.requests.get", side_effect=self._patch_search(empty)), \
             patch("tradingagents.dataflows.reddit_sentiment._get_company_name", return_value=""):
            result = get_reddit_sentiment("XYZQ", days=7)
        assert "No Reddit posts found" in result
        assert "XYZQ" in result

    def test_429_skips_subreddit_and_returns_no_posts_message(self):
        """429 from all subreddits → no posts collected → informative message returned."""
        def rate_limited(*args, **kwargs):
            resp = MagicMock()
            resp.ok = False
            resp.status_code = 429
            return resp
        with patch("tradingagents.dataflows.reddit_sentiment.requests.get", side_effect=rate_limited), \
             patch("tradingagents.dataflows.reddit_sentiment._get_company_name", return_value=""):
            result = get_reddit_sentiment("NVDA", days=7)
        assert "No Reddit posts found" in result
        assert "NVDA" in result

    def test_network_error_skips_subreddit_and_returns_no_posts_message(self):
        """Network failure on all subreddits → no posts collected → informative message returned."""
        with patch("tradingagents.dataflows.reddit_sentiment.requests.get",
                   side_effect=requests.RequestException("connection reset")), \
             patch("tradingagents.dataflows.reddit_sentiment._get_company_name", return_value=""):
            result = get_reddit_sentiment("NVDA", days=7)
        assert "No Reddit posts found" in result
        assert "NVDA" in result

    def test_title_filter_removes_off_topic_posts(self):
        """Posts whose title doesn't contain ticker or company name are dropped."""
        posts = {"wallstreetbets": [
            _make_post("abc1", "SanDisk joins QQQ today", score=900),  # off-topic
            _make_post("abc2", "NVDA calls printing today", score=100),  # on-topic
        ], "stocks": [], "options": []}
        with patch("tradingagents.dataflows.reddit_sentiment.requests.get", side_effect=self._patch_search(posts)), \
             patch("tradingagents.dataflows.reddit_sentiment._get_company_name", return_value=""):
            result = get_reddit_sentiment("NVDA", days=7)
        assert "SanDisk" not in result
        assert "NVDA calls printing today" in result

    def test_company_name_keyword_matches_title(self):
        """Posts containing company name but not ticker are included."""
        posts = {"wallstreetbets": [
            _make_post("abc1", "Nvidia GPU demand surging", score=200),
        ], "stocks": [], "options": []}
        with patch("tradingagents.dataflows.reddit_sentiment.requests.get", side_effect=self._patch_search(posts)), \
             patch("tradingagents.dataflows.reddit_sentiment._get_company_name", return_value="Nvidia Corp"):
            result = get_reddit_sentiment("NVDA", days=7)
        assert "Nvidia GPU demand surging" in result

    def test_deduplication_across_subreddits(self):
        """Same post appearing in multiple subreddit results is only shown once."""
        same_post = _make_post("dup1", "NVDA bull case", score=50)
        posts = {
            "wallstreetbets": [same_post],
            "stocks": [same_post],
            "options": [],
        }
        with patch("tradingagents.dataflows.reddit_sentiment.requests.get", side_effect=self._patch_search(posts)), \
             patch("tradingagents.dataflows.reddit_sentiment._get_company_name", return_value=""):
            result = get_reddit_sentiment("NVDA", days=7)
        assert result.count("NVDA bull case") == 1
# ---------------------------------------------------------------------------
# Fear & Greed — get_fear_greed
# ---------------------------------------------------------------------------
class TestFearGreed:
    """Unit tests for get_fear_greed with requests.get fully mocked."""

    def _fng_response(self, days):
        """Build a fake alternative.me /fng/ payload with one entry per day,
        timestamps counting back one day (86400 s) at a time from now."""
        import time
        data = []
        for i in range(days):
            ts = int(time.time()) - i * 86400
            data.append({
                "value": str(30 + i),
                "value_classification": "Fear",
                "timestamp": str(ts),
            })
        return {"data": data}

    def test_happy_path_returns_n_entries(self):
        """Seven days of data produce seven 'Score:' lines plus labels."""
        resp = MagicMock()
        resp.ok = True
        resp.status_code = 200
        resp.encoding = "utf-8"
        resp.json.return_value = self._fng_response(7)
        with patch("tradingagents.dataflows.fear_greed.requests.get", return_value=resp):
            result = get_fear_greed(7)
        lines = [l for l in result.splitlines() if "Score:" in l]
        assert len(lines) == 7
        assert "Fear" in result
        assert "/100" in result

    def test_single_day_returns_one_entry(self):
        """A one-day payload yields exactly one 'Score:' line."""
        resp = MagicMock()
        resp.ok = True
        resp.status_code = 200
        resp.encoding = "utf-8"
        resp.json.return_value = self._fng_response(1)
        with patch("tradingagents.dataflows.fear_greed.requests.get", return_value=resp):
            result = get_fear_greed(1)
        lines = [l for l in result.splitlines() if "Score:" in l]
        assert len(lines) == 1

    def test_api_failure_returns_empty_string(self):
        """HTTP 500 from the API yields an empty string, not an exception."""
        resp = MagicMock()
        resp.ok = False
        resp.status_code = 500
        with patch("tradingagents.dataflows.fear_greed.requests.get", return_value=resp):
            result = get_fear_greed(7)
        assert result == ""

    def test_network_error_returns_empty_string(self):
        """A RequestException (e.g. timeout) yields an empty string."""
        with patch("tradingagents.dataflows.fear_greed.requests.get",
                   side_effect=requests.RequestException("timeout")):
            result = get_fear_greed(7)
        assert result == ""

View File

@ -1,5 +1,6 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction, get_news
from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction
from tradingagents.agents.utils.sentiment_tools import get_reddit_sentiment, get_market_fear_greed
from tradingagents.dataflows.config import get_config
@ -9,12 +10,26 @@ def create_social_media_analyst(llm):
instrument_context = build_instrument_context(state["company_of_interest"])
tools = [
get_news,
get_reddit_sentiment,
get_market_fear_greed,
]
system_message = (
"You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Use the get_news(query, start_date, end_date) tool to search for company-specific news and social media discussions. Try to look at all sources possible from social media to sentiment to news. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
"You are a Sentiment Analyst. Your job is to gauge retail investor sentiment "
"and macro market mood for a specific stock.\n\n"
"Use get_reddit_sentiment(ticker, days) to fetch recent Reddit posts from "
"r/wallstreetbets, r/stocks, and r/options. Analyse titles, scores, comment "
"counts, upvote ratios, and the actual comments to assess retail mood.\n\n"
"Use get_market_fear_greed(days) to fetch the CNN Fear & Greed Index — a "
"market-wide macro signal. Use it to contextualise retail sentiment: bullish "
"Reddit posts carry more weight in a Greed market; bearish posts in an Extreme "
"Fear market may signal capitulation.\n\n"
"If Reddit returns no posts (obscure or small-cap ticker), state that clearly — "
"absence of retail coverage is itself a signal.\n\n"
"Write a comprehensive sentiment report covering: overall retail mood (bullish / "
"bearish / mixed), engagement level, notable narratives, current Fear & Greed "
"reading, and implications for short-term trader sentiment. "
"Append a Markdown table at the end summarising key data points."
+ get_language_instruction()
)

View File

@ -0,0 +1,35 @@
from langchain_core.tools import tool
from typing import Annotated
from tradingagents.dataflows.interface import route_to_vendor
@tool
def get_market_fear_greed(
    days: Annotated[int, "Number of past days to fetch (default 7)"] = 7,
) -> str:
    """
    Fetch the CNN Fear & Greed Index time series from alternative.me.

    Returns one entry per day with a numeric score (0-100) and classification
    label (Extreme Fear / Fear / Neutral / Greed / Extreme Greed).

    This is a market-wide macro signal — not ticker-specific. Use it to
    contextualise retail sentiment against the broader market mood.
    """
    return route_to_vendor("get_market_fear_greed", days)
@tool
def get_reddit_sentiment(
    ticker: Annotated[str, "Stock ticker symbol, e.g. NVDA or AAPL"],
    days: Annotated[int, "Number of past days to search for posts (default 3)"] = 3,
) -> str:
    """
    Fetch recent Reddit posts mentioning a stock ticker from investing subreddits
    (r/wallstreetbets, r/stocks, r/options). Returns post titles, upvote scores,
    comment counts, and upvote ratios as a formatted string. This is a
    ticker-specific retail sentiment signal, not a news source.

    Automatically searches by both ticker symbol and company name (e.g. NVDA OR
    Nvidia) so posts using the company name are not missed.

    Returns a "no posts found" message when no matching posts are collected —
    including when the Reddit API is unavailable (rate limit or network
    failure) — so obscure tickers and API failures both yield an informative
    message rather than an empty string.
    """
    return route_to_vendor("get_reddit_sentiment", ticker, days)

View File

@ -0,0 +1,56 @@
# -*- coding: utf-8 -*-
"""CNN Fear & Greed Index fetching via alternative.me public API (no auth required)."""
import logging
import requests
from datetime import datetime, timezone
logger = logging.getLogger(__name__)
_URL = "https://api.alternative.me/fng/"
_TIMEOUT = 10
def get_fear_greed(days: int = 7) -> str:
    """
    Fetch the CNN Fear & Greed Index time series from alternative.me.

    Each entry carries a numeric score (0-100) and a classification label
    (Extreme Fear / Fear / Neutral / Greed / Extreme Greed). This is a
    market-wide macro signal — not ticker-specific.

    Args:
        days: Number of past days to fetch (default 7)

    Returns:
        Formatted string of daily entries, or empty string on API failure.
    """
    try:
        response = requests.get(_URL, params={"limit": days}, timeout=_TIMEOUT)
    except requests.RequestException as exc:
        logger.warning("Fear & Greed API request failed: %s", exc)
        return ""

    if not response.ok:
        logger.warning("Fear & Greed API returned HTTP %s", response.status_code)
        return ""

    response.encoding = "utf-8"
    try:
        entries = response.json().get("data", [])
    except ValueError:
        logger.warning("Fear & Greed API returned invalid JSON")
        return ""
    if not entries:
        return ""

    output = [f"CNN Fear & Greed Index (last {days} days, most recent first):\n"]
    for entry in entries:
        day = datetime.fromtimestamp(
            int(entry.get("timestamp", 0)), tz=timezone.utc
        ).strftime("%Y-%m-%d")
        score = entry.get("value", "?")
        label = entry.get("value_classification", "?")
        output.append(f"{day} | Score: {score}/100 | {label}")
    return "\n".join(output)

View File

@ -23,6 +23,8 @@ from .alpha_vantage import (
get_global_news as get_alpha_vantage_global_news,
)
from .alpha_vantage_common import AlphaVantageRateLimitError
from .reddit_sentiment import get_reddit_sentiment as get_reddit_sentiment_impl
from .fear_greed import get_fear_greed as get_fear_greed_impl
# Configuration and routing logic
from .config import get_config
@ -57,6 +59,13 @@ TOOLS_CATEGORIES = {
"get_global_news",
"get_insider_transactions",
]
},
"sentiment_data": {
"description": "Retail sentiment and market mood data",
"tools": [
"get_reddit_sentiment",
"get_market_fear_greed",
]
}
}
@ -107,6 +116,13 @@ VENDOR_METHODS = {
"alpha_vantage": get_alpha_vantage_insider_transactions,
"yfinance": get_yfinance_insider_transactions,
},
# sentiment_data
"get_reddit_sentiment": {
"default": get_reddit_sentiment_impl,
},
"get_market_fear_greed": {
"default": get_fear_greed_impl,
},
}
def get_category_for_method(method: str) -> str:

View File

@ -0,0 +1,211 @@
# -*- coding: utf-8 -*-
"""Reddit-based retail sentiment fetching via public JSON API (no auth required)."""
import logging
import requests
import yfinance as yf
from datetime import datetime, timedelta, timezone
logger = logging.getLogger(__name__)
_HEADERS = {"User-Agent": "TradingAgentsBot/0.1"}
_SUBREDDITS = ["wallstreetbets", "stocks", "options"]
_TIMEOUT = 10
_COMMENT_PREVIEW = 200 # max chars per comment shown to LLM
_TOP_POSTS_WITH_COMMENTS = 3 # fetch comments for this many top posts only
def _get_company_name(ticker: str) -> str:
    """Look up the short company name for a ticker via yfinance.

    Returns "" when the lookup fails for any reason (network error, unknown
    ticker) or when neither name field is present.
    """
    try:
        profile = yf.Ticker(ticker).info
        name = profile.get("shortName") or profile.get("longName")
    except Exception:
        return ""
    return name or ""
def _search_subreddit(subreddit: str, query: str) -> list:
    """Fetch up to 25 posts from a subreddit matching query.

    Uses Reddit's public search.json endpoint restricted to the subreddit and
    the past week. Returns the raw list of post items; empty list on any
    failure (network error, 429 rate limit, non-2xx status, bad JSON).
    """
    search_params = {
        "q": query,
        "restrict_sr": 1,
        "sort": "relevance",
        "limit": 25,
        "t": "week",
    }
    try:
        response = requests.get(
            f"https://www.reddit.com/r/{subreddit}/search.json",
            params=search_params,
            headers=_HEADERS,
            timeout=_TIMEOUT,
        )
    except requests.RequestException as exc:
        logger.warning("Reddit API request failed for r/%s: %s", subreddit, exc)
        return []

    # Rate limiting gets its own log line so operators can distinguish it
    # from other HTTP failures.
    if response.status_code == 429:
        logger.warning("Reddit API rate limit hit (429) for r/%s", subreddit)
        return []
    if not response.ok:
        logger.warning("Reddit API returned HTTP %s for r/%s", response.status_code, subreddit)
        return []

    response.encoding = "utf-8"
    try:
        payload = response.json()
    except ValueError:
        logger.warning("Reddit API returned invalid JSON for r/%s", subreddit)
        return []
    return payload.get("data", {}).get("children", [])
def _fetch_top_comments(subreddit: str, post_id: str, limit: int = 20) -> list[str]:
    """
    Fetch top-level comments for a post, sorted by score.

    Skips known bot authors and deleted/removed comments. Returns a list of
    comment body strings (whitespace-collapsed, truncated to _COMMENT_PREVIEW
    chars); empty list on any API failure.
    """
    try:
        r = requests.get(
            f"https://www.reddit.com/r/{subreddit}/comments/{post_id}.json",
            params={"sort": "top", "limit": limit, "depth": 1},
            headers=_HEADERS,
            timeout=_TIMEOUT,
        )
    except requests.RequestException as e:
        logger.warning("Reddit comment fetch failed for %s/%s: %s", subreddit, post_id, e)
        return []
    if r.status_code == 429:
        logger.warning("Reddit API rate limit hit (429) fetching comments for %s", post_id)
        return []
    if not r.ok:
        logger.warning("Reddit comment API returned HTTP %s for %s", r.status_code, post_id)
        return []
    r.encoding = "utf-8"
    try:
        data = r.json()
    except ValueError:
        return []
    # Response is [post_listing, comment_listing]
    if len(data) < 2:
        return []
    _BOT_AUTHORS = {"automoderator", "visualmod"}
    comments = []
    for item in data[1].get("data", {}).get("children", []):
        cdata = item.get("data", {})
        # Guard against an explicit null "author" value (e.g. deleted accounts):
        # dict.get's default only applies when the key is absent, so
        # cdata.get("author", "") can still return None and crash .lower().
        author = (cdata.get("author") or "").lower()
        if author in _BOT_AUTHORS:
            continue
        body = cdata.get("body", "")
        if not body or body in ("[deleted]", "[removed]"):
            continue
        # Truncate and collapse internal whitespace so each comment is one line.
        body = " ".join(body.split())
        comments.append(body[:_COMMENT_PREVIEW])
    return comments
def get_reddit_sentiment(ticker: str, days: int = 3) -> str:
    """
    Fetch recent Reddit posts mentioning a ticker from investing subreddits.

    Searches r/wallstreetbets, r/stocks, and r/options via Reddit's public
    JSON API (no authentication required). Runs separate queries for the ticker
    symbol and company name so posts using either form are captured. Only keeps
    posts whose title contains the ticker or company name. Fetches top comments
    for the three highest-scoring posts to capture actual retail discussion.

    Args:
        ticker: Stock ticker symbol (e.g., "NVDA")
        days: Number of days to look back (default 3)

    Returns:
        Formatted string of matching posts with top comments, or a
        "no posts found" message when nothing matches — including when all
        API calls fail and no posts could be collected.
    """
    cutoff = datetime.now(tz=timezone.utc) - timedelta(days=days)
    company_name = _get_company_name(ticker)

    # Use the first meaningful word of the company name as a separate search term
    # e.g. "NVIDIA Corporation" → "NVIDIA", "Apple Inc." → "Apple"
    name_keyword = ""
    if company_name:
        first_word = company_name.split()[0]
        if len(first_word) > 3 and first_word.upper() != ticker.upper():
            name_keyword = first_word

    search_terms = [ticker]
    if name_keyword:
        search_terms.append(name_keyword)

    # Hoist loop-invariant lowercasing out of the per-post filter.
    ticker_lower = ticker.lower()
    keyword_lower = name_keyword.lower()

    seen_ids = set()
    all_posts = []
    for subreddit in _SUBREDDITS:
        for term in search_terms:
            for item in _search_subreddit(subreddit, term):
                post = item.get("data", {})
                post_id = post.get("id", "")
                if post_id in seen_ids:
                    continue
                # Only keep posts whose title mentions ticker or company name
                title_lower = post.get("title", "").lower()
                if ticker_lower not in title_lower and (
                    not name_keyword or keyword_lower not in title_lower
                ):
                    continue
                created_utc = post.get("created_utc", 0)
                post_time = datetime.fromtimestamp(created_utc, tz=timezone.utc)
                if post_time < cutoff:
                    continue
                seen_ids.add(post_id)
                all_posts.append({
                    "id": post_id,
                    "subreddit": subreddit,
                    "title": post.get("title", ""),
                    "score": post.get("score", 0),
                    "num_comments": post.get("num_comments", 0),
                    "upvote_ratio": post.get("upvote_ratio", 0.0),
                    "flair": post.get("link_flair_text") or "",
                })

    if not all_posts:
        return (
            f"No Reddit posts found mentioning {ticker} in the last {days} days "
            f"across r/wallstreetbets, r/stocks, r/options."
        )

    # Sort by score descending — highest engagement first
    all_posts.sort(key=lambda p: p["score"], reverse=True)

    # Fetch comments for top N posts only to keep API calls bounded
    for post in all_posts[:_TOP_POSTS_WITH_COMMENTS]:
        post["comments"] = _fetch_top_comments(post["subreddit"], post["id"])

    label = f"{ticker}" + (f" / {name_keyword}" if name_keyword else "")
    lines = [f"Reddit posts mentioning {label} (last {days} days, sorted by upvotes):\n"]
    # Plain iteration: the original used enumerate() but never used the index.
    for p in all_posts:
        flair = f" [{p['flair']}]" if p["flair"] else ""
        lines.append(
            f"r/{p['subreddit']}{flair} | "
            f"Score: {p['score']} | "
            f"Comments: {p['num_comments']} | "
            f"Upvote ratio: {p['upvote_ratio']:.0%} | "
            f"{p['title']}"
        )
        # Iterating an empty list is a no-op, so no emptiness guard is needed.
        for c in p.get("comments", []):
            lines.append(f"  > {c}")
    return "\n".join(lines)

View File

@ -29,6 +29,7 @@ DEFAULT_CONFIG = {
"technical_indicators": "yfinance", # Options: alpha_vantage, yfinance
"fundamental_data": "yfinance", # Options: alpha_vantage, yfinance
"news_data": "yfinance", # Options: alpha_vantage, yfinance
"sentiment_data": "default", # Reddit, Fear&Greed, Discord UW (no vendor alt)
},
# Tool-level configuration (takes precedence over category-level)
"tool_vendors": {

View File

@ -104,14 +104,21 @@ class GraphSetup:
self.deep_thinking_llm, self.portfolio_manager_memory
)
# Maps analyst selector keys to their display names in the graph.
# Add entries here when a key and its label should differ.
_labels = {"social": "Sentiment"}
def _label(analyst_type: str) -> str:
return _labels.get(analyst_type, analyst_type.capitalize())
# Create workflow
workflow = StateGraph(AgentState)
# Add analyst nodes to the graph
for analyst_type, node in analyst_nodes.items():
workflow.add_node(f"{analyst_type.capitalize()} Analyst", node)
workflow.add_node(f"{_label(analyst_type)} Analyst", node)
workflow.add_node(
f"Msg Clear {analyst_type.capitalize()}", delete_nodes[analyst_type]
f"Msg Clear {_label(analyst_type)}", delete_nodes[analyst_type]
)
workflow.add_node(f"tools_{analyst_type}", tool_nodes[analyst_type])
@ -128,13 +135,13 @@ class GraphSetup:
# Define edges
# Start with the first analyst
first_analyst = selected_analysts[0]
workflow.add_edge(START, f"{first_analyst.capitalize()} Analyst")
workflow.add_edge(START, f"{_label(first_analyst)} Analyst")
# Connect analysts in sequence
for i, analyst_type in enumerate(selected_analysts):
current_analyst = f"{analyst_type.capitalize()} Analyst"
current_analyst = f"{_label(analyst_type)} Analyst"
current_tools = f"tools_{analyst_type}"
current_clear = f"Msg Clear {analyst_type.capitalize()}"
current_clear = f"Msg Clear {_label(analyst_type)}"
# Add conditional edges for current analyst
workflow.add_conditional_edges(
@ -146,7 +153,7 @@ class GraphSetup:
# Connect to next analyst or to Bull Researcher if this is the last analyst
if i < len(selected_analysts) - 1:
next_analyst = f"{selected_analysts[i+1].capitalize()} Analyst"
next_analyst = f"{_label(selected_analysts[i+1])} Analyst"
workflow.add_edge(current_clear, next_analyst)
else:
workflow.add_edge(current_clear, "Bull Researcher")

View File

@ -32,6 +32,7 @@ from tradingagents.agents.utils.agent_utils import (
get_insider_transactions,
get_global_news
)
from tradingagents.agents.utils.sentiment_tools import get_reddit_sentiment, get_market_fear_greed
from .conditional_logic import ConditionalLogic
from .setup import GraphSetup
@ -166,8 +167,8 @@ class TradingAgentsGraph:
),
"social": ToolNode(
[
# News tools for social media analysis
get_news,
get_reddit_sentiment,
get_market_fear_greed,
]
),
"news": ToolNode(