216 lines
8.3 KiB
Python
216 lines
8.3 KiB
Python
"""Mock-based unit tests for Reddit sentiment and Fear & Greed dataflows."""
|
|
|
|
import pytest
|
|
import requests
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from tradingagents.dataflows.reddit_sentiment import get_reddit_sentiment
|
|
from tradingagents.dataflows.fear_greed import get_fear_greed
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _make_post(post_id, title, score=100, num_comments=50, upvote_ratio=0.9,
|
|
flair=None, created_utc=9_999_999_999):
|
|
return {"kind": "t3", "data": {
|
|
"id": post_id,
|
|
"title": title,
|
|
"score": score,
|
|
"num_comments": num_comments,
|
|
"upvote_ratio": upvote_ratio,
|
|
"link_flair_text": flair,
|
|
"created_utc": created_utc,
|
|
}}
|
|
|
|
|
|
def _search_response(posts):
|
|
return {"data": {"children": posts}}
|
|
|
|
|
|
def _comment_response(comments):
|
|
comment_items = [
|
|
{"kind": "t1", "data": {"author": "user1", "body": c}}
|
|
for c in comments
|
|
]
|
|
return [
|
|
{"data": {"children": []}}, # post listing (unused)
|
|
{"data": {"children": comment_items}},
|
|
]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Reddit — get_reddit_sentiment
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestRedditSentiment:
|
|
|
|
def _patch_search(self, posts_by_subreddit):
|
|
"""Return a mock requests.get that returns given posts per subreddit."""
|
|
def fake_get(url, params=None, headers=None, timeout=None):
|
|
resp = MagicMock()
|
|
resp.ok = True
|
|
resp.status_code = 200
|
|
resp.encoding = "utf-8"
|
|
subreddit = url.split("/r/")[1].split("/")[0]
|
|
posts = posts_by_subreddit.get(subreddit, [])
|
|
resp.json.return_value = _search_response(posts)
|
|
return resp
|
|
return fake_get
|
|
|
|
def test_happy_path_returns_formatted_post(self):
|
|
posts = {"wallstreetbets": [
|
|
_make_post("abc1", "NVDA calls printing today", score=500, num_comments=80, upvote_ratio=0.92)
|
|
], "stocks": [], "options": []}
|
|
|
|
with patch("tradingagents.dataflows.reddit_sentiment.requests.get", side_effect=self._patch_search(posts)), \
|
|
patch("tradingagents.dataflows.reddit_sentiment._get_company_name", return_value=""):
|
|
result = get_reddit_sentiment("NVDA", days=7)
|
|
|
|
assert "NVDA" in result
|
|
assert "NVDA calls printing today" in result
|
|
assert "Score: 500" in result
|
|
assert "Comments: 80" in result
|
|
assert "92%" in result
|
|
|
|
def test_no_posts_returns_informative_message(self):
|
|
empty = {"wallstreetbets": [], "stocks": [], "options": []}
|
|
|
|
with patch("tradingagents.dataflows.reddit_sentiment.requests.get", side_effect=self._patch_search(empty)), \
|
|
patch("tradingagents.dataflows.reddit_sentiment._get_company_name", return_value=""):
|
|
result = get_reddit_sentiment("XYZQ", days=7)
|
|
|
|
assert "No Reddit posts found" in result
|
|
assert "XYZQ" in result
|
|
|
|
def test_429_skips_subreddit_and_returns_no_posts_message(self):
|
|
"""429 from all subreddits → no posts collected → informative message returned."""
|
|
def rate_limited(*args, **kwargs):
|
|
resp = MagicMock()
|
|
resp.ok = False
|
|
resp.status_code = 429
|
|
return resp
|
|
|
|
with patch("tradingagents.dataflows.reddit_sentiment.requests.get", side_effect=rate_limited), \
|
|
patch("tradingagents.dataflows.reddit_sentiment._get_company_name", return_value=""):
|
|
result = get_reddit_sentiment("NVDA", days=7)
|
|
|
|
assert "No Reddit posts found" in result
|
|
assert "NVDA" in result
|
|
|
|
def test_network_error_skips_subreddit_and_returns_no_posts_message(self):
|
|
"""Network failure on all subreddits → no posts collected → informative message returned."""
|
|
with patch("tradingagents.dataflows.reddit_sentiment.requests.get",
|
|
side_effect=requests.RequestException("connection reset")), \
|
|
patch("tradingagents.dataflows.reddit_sentiment._get_company_name", return_value=""):
|
|
result = get_reddit_sentiment("NVDA", days=7)
|
|
|
|
assert "No Reddit posts found" in result
|
|
assert "NVDA" in result
|
|
|
|
def test_title_filter_removes_off_topic_posts(self):
|
|
"""Posts whose title doesn't contain ticker or company name are dropped."""
|
|
posts = {"wallstreetbets": [
|
|
_make_post("abc1", "SanDisk joins QQQ today", score=900), # off-topic
|
|
_make_post("abc2", "NVDA calls printing today", score=100), # on-topic
|
|
], "stocks": [], "options": []}
|
|
|
|
with patch("tradingagents.dataflows.reddit_sentiment.requests.get", side_effect=self._patch_search(posts)), \
|
|
patch("tradingagents.dataflows.reddit_sentiment._get_company_name", return_value=""):
|
|
result = get_reddit_sentiment("NVDA", days=7)
|
|
|
|
assert "SanDisk" not in result
|
|
assert "NVDA calls printing today" in result
|
|
|
|
def test_company_name_keyword_matches_title(self):
|
|
"""Posts containing company name but not ticker are included."""
|
|
posts = {"wallstreetbets": [
|
|
_make_post("abc1", "Nvidia GPU demand surging", score=200),
|
|
], "stocks": [], "options": []}
|
|
|
|
with patch("tradingagents.dataflows.reddit_sentiment.requests.get", side_effect=self._patch_search(posts)), \
|
|
patch("tradingagents.dataflows.reddit_sentiment._get_company_name", return_value="Nvidia Corp"):
|
|
result = get_reddit_sentiment("NVDA", days=7)
|
|
|
|
assert "Nvidia GPU demand surging" in result
|
|
|
|
def test_deduplication_across_subreddits(self):
|
|
"""Same post appearing in multiple subreddit results is only shown once."""
|
|
same_post = _make_post("dup1", "NVDA bull case", score=50)
|
|
posts = {
|
|
"wallstreetbets": [same_post],
|
|
"stocks": [same_post],
|
|
"options": [],
|
|
}
|
|
|
|
with patch("tradingagents.dataflows.reddit_sentiment.requests.get", side_effect=self._patch_search(posts)), \
|
|
patch("tradingagents.dataflows.reddit_sentiment._get_company_name", return_value=""):
|
|
result = get_reddit_sentiment("NVDA", days=7)
|
|
|
|
assert result.count("NVDA bull case") == 1
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fear & Greed — get_fear_greed
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestFearGreed:
|
|
|
|
def _fng_response(self, days):
|
|
import time
|
|
data = []
|
|
for i in range(days):
|
|
ts = int(time.time()) - i * 86400
|
|
data.append({
|
|
"value": str(30 + i),
|
|
"value_classification": "Fear",
|
|
"timestamp": str(ts),
|
|
})
|
|
return {"data": data}
|
|
|
|
def test_happy_path_returns_n_entries(self):
|
|
resp = MagicMock()
|
|
resp.ok = True
|
|
resp.status_code = 200
|
|
resp.encoding = "utf-8"
|
|
resp.json.return_value = self._fng_response(7)
|
|
|
|
with patch("tradingagents.dataflows.fear_greed.requests.get", return_value=resp):
|
|
result = get_fear_greed(7)
|
|
|
|
lines = [l for l in result.splitlines() if "Score:" in l]
|
|
assert len(lines) == 7
|
|
assert "Fear" in result
|
|
assert "/100" in result
|
|
|
|
def test_single_day_returns_one_entry(self):
|
|
resp = MagicMock()
|
|
resp.ok = True
|
|
resp.status_code = 200
|
|
resp.encoding = "utf-8"
|
|
resp.json.return_value = self._fng_response(1)
|
|
|
|
with patch("tradingagents.dataflows.fear_greed.requests.get", return_value=resp):
|
|
result = get_fear_greed(1)
|
|
|
|
lines = [l for l in result.splitlines() if "Score:" in l]
|
|
assert len(lines) == 1
|
|
|
|
def test_api_failure_returns_empty_string(self):
|
|
resp = MagicMock()
|
|
resp.ok = False
|
|
resp.status_code = 500
|
|
|
|
with patch("tradingagents.dataflows.fear_greed.requests.get", return_value=resp):
|
|
result = get_fear_greed(7)
|
|
|
|
assert result == ""
|
|
|
|
def test_network_error_returns_empty_string(self):
|
|
with patch("tradingagents.dataflows.fear_greed.requests.get",
|
|
side_effect=requests.RequestException("timeout")):
|
|
result = get_fear_greed(7)
|
|
|
|
assert result == ""
|