# TradingAgents/tests/test_sentiment_tools.py

"""Mock-based unit tests for Reddit sentiment and Fear & Greed dataflows."""
import pytest
import requests
from unittest.mock import MagicMock, patch
from tradingagents.dataflows.reddit_sentiment import get_reddit_sentiment
from tradingagents.dataflows.fear_greed import get_fear_greed
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_post(post_id, title, score=100, num_comments=50, upvote_ratio=0.9,
flair=None, created_utc=9_999_999_999):
return {"kind": "t3", "data": {
"id": post_id,
"title": title,
"score": score,
"num_comments": num_comments,
"upvote_ratio": upvote_ratio,
"link_flair_text": flair,
"created_utc": created_utc,
}}
def _search_response(posts):
return {"data": {"children": posts}}
def _comment_response(comments):
comment_items = [
{"kind": "t1", "data": {"author": "user1", "body": c}}
for c in comments
]
return [
{"data": {"children": []}}, # post listing (unused)
{"data": {"children": comment_items}},
]
# ---------------------------------------------------------------------------
# Reddit — get_reddit_sentiment
# ---------------------------------------------------------------------------
class TestRedditSentiment:
    """Tests for get_reddit_sentiment with requests fully mocked out."""

    def _patch_search(self, posts_by_subreddit):
        """Return a requests.get stand-in serving canned posts per subreddit.

        The subreddit name is parsed from the requested URL, so each test
        can supply a different listing for wallstreetbets/stocks/options.
        """
        def fake_get(url, params=None, headers=None, timeout=None):
            name = url.split("/r/")[1].split("/")[0]
            response = MagicMock()
            response.ok = True
            response.status_code = 200
            response.encoding = "utf-8"
            response.json.return_value = _search_response(
                posts_by_subreddit.get(name, []))
            return response
        return fake_get

    def test_happy_path_returns_formatted_post(self):
        """A single matching post is rendered with score/comments/ratio."""
        canned = {
            "wallstreetbets": [_make_post(
                "abc1", "NVDA calls printing today",
                score=500, num_comments=80, upvote_ratio=0.92)],
            "stocks": [],
            "options": [],
        }
        with patch("tradingagents.dataflows.reddit_sentiment.requests.get",
                   side_effect=self._patch_search(canned)), \
             patch("tradingagents.dataflows.reddit_sentiment._get_company_name",
                   return_value=""):
            report = get_reddit_sentiment("NVDA", days=7)
        for expected in ("NVDA", "NVDA calls printing today",
                         "Score: 500", "Comments: 80", "92%"):
            assert expected in report

    def test_no_posts_returns_informative_message(self):
        """Empty listings from every subreddit yield a clear no-data message."""
        empty = {"wallstreetbets": [], "stocks": [], "options": []}
        with patch("tradingagents.dataflows.reddit_sentiment.requests.get",
                   side_effect=self._patch_search(empty)), \
             patch("tradingagents.dataflows.reddit_sentiment._get_company_name",
                   return_value=""):
            report = get_reddit_sentiment("XYZQ", days=7)
        assert "No Reddit posts found" in report
        assert "XYZQ" in report

    def test_429_skips_subreddit_and_returns_no_posts_message(self):
        """429 from all subreddits → no posts collected → informative message returned."""
        def rate_limited(*args, **kwargs):
            blocked = MagicMock()
            blocked.ok = False
            blocked.status_code = 429
            return blocked
        with patch("tradingagents.dataflows.reddit_sentiment.requests.get",
                   side_effect=rate_limited), \
             patch("tradingagents.dataflows.reddit_sentiment._get_company_name",
                   return_value=""):
            report = get_reddit_sentiment("NVDA", days=7)
        assert "No Reddit posts found" in report
        assert "NVDA" in report

    def test_network_error_skips_subreddit_and_returns_no_posts_message(self):
        """Network failure on all subreddits → no posts collected → informative message returned."""
        with patch("tradingagents.dataflows.reddit_sentiment.requests.get",
                   side_effect=requests.RequestException("connection reset")), \
             patch("tradingagents.dataflows.reddit_sentiment._get_company_name",
                   return_value=""):
            report = get_reddit_sentiment("NVDA", days=7)
        assert "No Reddit posts found" in report
        assert "NVDA" in report

    def test_title_filter_removes_off_topic_posts(self):
        """Posts whose title doesn't contain ticker or company name are dropped."""
        canned = {
            "wallstreetbets": [
                _make_post("abc1", "SanDisk joins QQQ today", score=900),   # off-topic
                _make_post("abc2", "NVDA calls printing today", score=100),  # on-topic
            ],
            "stocks": [],
            "options": [],
        }
        with patch("tradingagents.dataflows.reddit_sentiment.requests.get",
                   side_effect=self._patch_search(canned)), \
             patch("tradingagents.dataflows.reddit_sentiment._get_company_name",
                   return_value=""):
            report = get_reddit_sentiment("NVDA", days=7)
        assert "SanDisk" not in report
        assert "NVDA calls printing today" in report

    def test_company_name_keyword_matches_title(self):
        """Posts containing company name but not ticker are included."""
        canned = {
            "wallstreetbets": [_make_post("abc1", "Nvidia GPU demand surging",
                                          score=200)],
            "stocks": [],
            "options": [],
        }
        with patch("tradingagents.dataflows.reddit_sentiment.requests.get",
                   side_effect=self._patch_search(canned)), \
             patch("tradingagents.dataflows.reddit_sentiment._get_company_name",
                   return_value="Nvidia Corp"):
            report = get_reddit_sentiment("NVDA", days=7)
        assert "Nvidia GPU demand surging" in report

    def test_deduplication_across_subreddits(self):
        """Same post appearing in multiple subreddit results is only shown once."""
        duplicate = _make_post("dup1", "NVDA bull case", score=50)
        canned = {
            "wallstreetbets": [duplicate],
            "stocks": [duplicate],
            "options": [],
        }
        with patch("tradingagents.dataflows.reddit_sentiment.requests.get",
                   side_effect=self._patch_search(canned)), \
             patch("tradingagents.dataflows.reddit_sentiment._get_company_name",
                   return_value=""):
            report = get_reddit_sentiment("NVDA", days=7)
        assert report.count("NVDA bull case") == 1
# ---------------------------------------------------------------------------
# Fear & Greed — get_fear_greed
# ---------------------------------------------------------------------------
class TestFearGreed:
    """Tests for get_fear_greed with the Fear & Greed API mocked."""

    def _fng_response(self, days):
        """Build a fake FNG API payload with *days* daily entries, newest first."""
        import time
        now = int(time.time())
        return {"data": [
            {
                "value": str(30 + offset),
                "value_classification": "Fear",
                "timestamp": str(now - offset * 86400),
            }
            for offset in range(days)
        ]}

    def _mock_ok(self, days):
        """Return a successful MagicMock response carrying an FNG payload."""
        ok = MagicMock()
        ok.ok = True
        ok.status_code = 200
        ok.encoding = "utf-8"
        ok.json.return_value = self._fng_response(days)
        return ok

    def test_happy_path_returns_n_entries(self):
        """Seven days of data render seven score lines plus classification."""
        with patch("tradingagents.dataflows.fear_greed.requests.get",
                   return_value=self._mock_ok(7)):
            report = get_fear_greed(7)
        score_lines = [ln for ln in report.splitlines() if "Score:" in ln]
        assert len(score_lines) == 7
        assert "Fear" in report
        assert "/100" in report

    def test_single_day_returns_one_entry(self):
        """One day of data renders exactly one score line."""
        with patch("tradingagents.dataflows.fear_greed.requests.get",
                   return_value=self._mock_ok(1)):
            report = get_fear_greed(1)
        score_lines = [ln for ln in report.splitlines() if "Score:" in ln]
        assert len(score_lines) == 1

    def test_api_failure_returns_empty_string(self):
        """A non-OK HTTP status degrades gracefully to an empty string."""
        failed = MagicMock()
        failed.ok = False
        failed.status_code = 500
        with patch("tradingagents.dataflows.fear_greed.requests.get",
                   return_value=failed):
            assert get_fear_greed(7) == ""

    def test_network_error_returns_empty_string(self):
        """A transport-level exception degrades gracefully to an empty string."""
        with patch("tradingagents.dataflows.fear_greed.requests.get",
                   side_effect=requests.RequestException("timeout")):
            assert get_fear_greed(7) == ""