379 lines
13 KiB
Python
379 lines
13 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Test SocialMediaService with mock RedditClient and real SocialRepository.
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
from datetime import datetime, timedelta
|
|
from typing import Any
|
|
|
|
# Add the project root to the path
|
|
sys.path.insert(0, os.path.abspath("."))
|
|
|
|
from tradingagents.clients.base import BaseClient
|
|
from tradingagents.models.context import (
|
|
DataQuality,
|
|
PostData,
|
|
SentimentScore,
|
|
SocialContext,
|
|
)
|
|
from tradingagents.repositories.social_repository import SocialRepository
|
|
from tradingagents.services.social_media_service import SocialMediaService
|
|
|
|
|
|
class MockRedditClient(BaseClient):
    """Mock Reddit client that returns sample social media data.

    All generated posts are dated relative to a fixed base date of
    2024-01-02, so the output is deterministic. ``search_posts`` and
    ``get_top_posts`` only look at the first two requested subreddits.
    """

    def __init__(self, **kwargs):
        """Forward ``kwargs`` to the base client and mark the connection healthy."""
        super().__init__(**kwargs)
        # Value reported by test_connection().
        self.connection_works = True

    def test_connection(self) -> bool:
        """Return the current value of ``connection_works``."""
        return self.connection_works

    def get_data(self, *args, **kwargs) -> dict[str, Any]:
        """Not used directly by SocialMediaService."""
        return {}

    def search_posts(
        self,
        query: str,
        subreddit_names: list[str],
        limit: int = 25,
        time_filter: str = "week",
    ) -> list[dict[str, Any]]:
        """Return mock Reddit search results.

        Produces one bullish and one bearish post mentioning ``query`` for
        each of the first two entries of ``subreddit_names``.

        Args:
            query: Text interpolated into every post title and body.
            subreddit_names: Subreddits to generate posts for; only the
                first two are used.
            limit: Accepted for interface compatibility; ignored by the mock.
            time_filter: Accepted for interface compatibility; ignored by
                the mock.

        Returns:
            A list of post dicts (two per subreddit used).
        """
        posts = []
        # Use fixed dates that will work with our date filter
        base_date = datetime(2024, 1, 2)  # Within our test range

        # Limit to 2 subreddits for testing
        for i, subreddit in enumerate(subreddit_names[:2]):
            # NOTE: these are naive datetimes, so .timestamp() interprets
            # them in the local timezone of the machine running the tests.
            bullish_time = base_date + timedelta(hours=i)
            bearish_time = base_date + timedelta(hours=i + 1)
            posts.extend(
                [
                    {
                        "title": f"{query} to the moon! 🚀",
                        "content": f"DD on {query}: Strong fundamentals, great earnings beat. Buy and hold!",
                        "url": f"https://reddit.com/r/{subreddit}/post1",
                        "upvotes": 1500 - (i * 100),
                        "score": 1450 - (i * 100),
                        "num_comments": 234,
                        "created_utc": bullish_time.timestamp(),
                        "subreddit": subreddit,
                        "author": f"WSBtrader{i}",
                        "posted_date": bullish_time.strftime("%Y-%m-%d"),
                    },
                    {
                        "title": f"Why I'm bearish on {query}",
                        "content": f"Overvalued, competition increasing, margins declining. Time to sell {query}.",
                        "url": f"https://reddit.com/r/{subreddit}/post2",
                        "upvotes": 800 - (i * 50),
                        "score": 750 - (i * 50),
                        "num_comments": 156,
                        "created_utc": bearish_time.timestamp(),
                        "subreddit": subreddit,
                        "author": f"BearishTrader{i}",
                        "posted_date": bearish_time.strftime("%Y-%m-%d"),
                    },
                ]
            )
        return posts

    def get_top_posts(
        self, subreddit_names: list[str], limit: int = 25, time_filter: str = "week"
    ) -> list[dict[str, Any]]:
        """Return mock top posts from subreddits.

        Produces a single market-update post for each of the first two
        entries of ``subreddit_names``.

        Args:
            subreddit_names: Subreddits to generate posts for; only the
                first two are used.
            limit: Accepted for interface compatibility; ignored by the mock.
            time_filter: Accepted for interface compatibility; ignored by
                the mock.

        Returns:
            A list of post dicts (one per subreddit used).
        """
        posts = []
        # Use fixed dates that will work with our date filter
        base_date = datetime(2024, 1, 2)  # Within our test range

        for subreddit in subreddit_names[:2]:
            posts.append(
                {
                    "title": "Market Update: Tech stocks rally continues",
                    "content": "FAANG stocks leading the charge. SPY hit new ATH. Bull market confirmed.",
                    "url": f"https://reddit.com/r/{subreddit}/top1",
                    "upvotes": 2500,
                    "score": 2400,
                    "num_comments": 456,
                    "created_utc": base_date.timestamp(),
                    "subreddit": subreddit,
                    "author": "MarketWatcher",
                    "posted_date": base_date.strftime("%Y-%m-%d"),
                }
            )
        return posts

    def filter_posts_by_date(
        self, posts: list[dict[str, Any]], start_date: str, end_date: str
    ) -> list[dict[str, Any]]:
        """Filter posts by date range.

        Keeps posts whose ``posted_date`` falls on or between ``start_date``
        and ``end_date`` (both inclusive). Posts without a ``posted_date``
        key are dropped.

        Args:
            posts: Post dicts, each optionally carrying a ``posted_date``
                string in ``YYYY-MM-DD`` format.
            start_date: First day of the range, ``YYYY-MM-DD``.
            end_date: Last day of the range, ``YYYY-MM-DD``.

        Returns:
            The matching posts, in their original order.

        Raises:
            ValueError: If any date string does not match ``YYYY-MM-DD``.
        """
        start_dt = datetime.strptime(start_date, "%Y-%m-%d")
        # Midnight at the start of the day *after* end_date. This bound must
        # be exclusive: posted_date parses to midnight, so an inclusive
        # comparison would wrongly admit posts dated end_date + 1 day.
        end_exclusive = datetime.strptime(end_date, "%Y-%m-%d") + timedelta(days=1)

        filtered = []
        for post in posts:
            if "posted_date" not in post:
                continue
            post_dt = datetime.strptime(post["posted_date"], "%Y-%m-%d")
            if start_dt <= post_dt < end_exclusive:
                filtered.append(post)
        return filtered
|
|
|
|
|
|
def test_online_mode_with_mock_reddit():
    """Run the company social-context lookup online with a mock client and a real repository."""
    storage_dir = "test_data"

    # Mock client first, then the real repository, then the service wiring them together.
    service = SocialMediaService(
        reddit_client=MockRedditClient(),
        repository=SocialRepository(storage_dir),
        online_mode=True,
        data_dir=storage_dir,
    )

    # Company-specific request, bypassing any cached data.
    request = {
        "symbol": "TSLA",
        "start_date": "2024-01-01",
        "end_date": "2024-01-05",
        "subreddits": ["wallstreetbets", "stocks"],
        "force_refresh": True,
    }
    context = service.get_company_social_context(**request)

    # The returned object has the expected shape.
    assert isinstance(context, SocialContext)
    assert context.symbol == "TSLA"
    assert context.period["start"] == "2024-01-01"
    assert context.period["end"] == "2024-01-05"
    assert len(context.posts) > 0
    assert isinstance(context.sentiment_summary, SentimentScore)
    assert context.post_count == len(context.posts)
    assert "data_quality" in context.metadata

    # The context serializes to a non-empty JSON document.
    serialized = context.model_dump_json(indent=2)
    assert len(serialized) > 0

    # Every post is a populated PostData instance.
    for entry in context.posts:
        assert isinstance(entry, PostData)
        assert entry.title
        assert entry.author
        assert entry.date
        assert entry.score >= 0
|
|
|
|
|
|
def test_global_social_trends():
    """Check that a global-trends request returns a symbol-less context with Reddit posts."""
    service = SocialMediaService(
        reddit_client=MockRedditClient(),
        repository=SocialRepository("test_data"),
        online_mode=True,
    )

    # Ask for market-wide trends across three subreddits, bypassing cached data.
    trends = service.get_global_trends(
        start_date="2024-01-01",
        end_date="2024-01-03",
        subreddits=["investing", "stocks", "wallstreetbets"],
        force_refresh=True,
    )

    # Global trends have no specific symbol.
    assert trends.symbol is None
    assert len(trends.posts) > 0
    assert "reddit" in trends.platforms
    assert "subreddits" in trends.metadata
|
|
|
|
|
|
def test_sentiment_analysis():
    """Feed the service one strongly bullish and one strongly bearish post and check the scores."""

    class SentimentTestClient(MockRedditClient):
        """Mock client whose search results carry unambiguous sentiment."""

        def search_posts(self, query, subreddit_names, limit=25, time_filter="week"):
            bullish_post = {
                "title": f"{query} is the best investment ever! 🚀🚀🚀",
                "content": "Amazing earnings, incredible growth, bullish AF!",
                "upvotes": 5000,
                "score": 4900,
                "num_comments": 500,
                "subreddit": "wallstreetbets",
                "author": "BullishTrader",
                "posted_date": "2024-01-01",
            }
            bearish_post = {
                "title": f"WARNING: {query} is about to crash hard",
                "content": "Terrible fundamentals, overvalued, sell now before it's too late!",
                "upvotes": 100,
                "score": 50,
                "num_comments": 30,
                "subreddit": "stocks",
                "author": "BearishAnalyst",
                "posted_date": "2024-01-01",
            }
            return [bullish_post, bearish_post]

    service = SocialMediaService(
        reddit_client=SentimentTestClient(), repository=None, online_mode=True
    )

    context = service.get_context("GME", "2024-01-01", "2024-01-02")

    # Aggregate sentiment: non-zero score, positive confidence, known label.
    summary = context.sentiment_summary
    assert summary.score != 0
    assert summary.confidence > 0
    assert summary.label in ("positive", "negative", "neutral")

    # Per-post sentiment, where present, stays inside [-1, 1].
    for entry in context.posts:
        if not entry.sentiment:
            continue
        assert -1.0 <= entry.sentiment.score <= 1.0
|
|
|
|
|
|
def test_offline_mode():
    """Check that the service still answers, at low data quality, with no Reddit client."""
    repository = SocialRepository("test_data")
    service = SocialMediaService(
        reddit_client=None,
        repository=repository,
        online_mode=False,
    )

    # No client is available, so the call must not fail.
    result = service.get_context("AAPL", "2024-01-01", "2024-01-05", symbol="AAPL")

    assert result.symbol == "AAPL"
    assert isinstance(result.posts, list)
    quality = result.metadata.get("data_quality")
    assert quality == DataQuality.LOW
|
|
|
|
|
|
def test_engagement_metrics():
    """Check that engagement metrics are populated and per-post counts are non-negative."""
    service = SocialMediaService(
        reddit_client=MockRedditClient(),
        repository=None,
        online_mode=True,
    )

    context = service.get_company_social_context(
        symbol="NVDA",
        start_date="2024-01-01",
        end_date="2024-01-02",
        subreddits=["nvidia", "stocks"],
    )

    # Metrics must exist, and the total may be reported in either container.
    metrics = context.engagement_metrics
    assert len(metrics) > 0
    total_key = "total_engagement"
    assert total_key in metrics or total_key in context.metadata

    # Each post carries a non-negative score and comment count.
    for entry in context.posts:
        assert entry.score >= 0
        assert entry.comments >= 0
|
|
|
|
|
|
def test_subreddit_filtering():
    """Check that returned posts come only from the subreddits that were requested."""
    service = SocialMediaService(
        reddit_client=MockRedditClient(),
        repository=None,
        online_mode=True,
    )

    # Restrict the lookup to two specific subreddits.
    context = service.get_company_social_context(
        symbol="AMD",
        start_date="2024-01-01",
        end_date="2024-01-02",
        subreddits=["AMD_Stock", "wallstreetbets"],
    )

    # Gather the distinct, non-empty subreddit names seen on the posts.
    seen = {entry.subreddit for entry in context.posts if entry.subreddit}

    assert len(seen) > 0
    for name in seen:
        assert name in ("AMD_Stock", "wallstreetbets")
|
|
|
|
|
|
def test_error_handling():
    """Check that a client whose every fetch raises yields an empty, low-quality context."""

    class BrokenRedditClient(BaseClient):
        """Client that reports a dead connection and raises on every fetch."""

        def test_connection(self):
            return False

        def get_data(self, *args, **kwargs):
            raise Exception("Reddit API error")

        def search_posts(self, *args, **kwargs):
            raise Exception("Reddit API error")

        def get_top_posts(self, *args, **kwargs):
            raise Exception("Reddit API error")

    service = SocialMediaService(
        reddit_client=BrokenRedditClient(),
        repository=None,
        online_mode=True,
    )

    # The client failure must be absorbed by the service, not propagated.
    result = service.get_context("TSLA", "2024-01-01", "2024-01-02", symbol="TSLA")

    assert result.symbol == "TSLA"
    assert len(result.posts) == 0
    quality = result.metadata.get("data_quality")
    assert quality == DataQuality.LOW
|
|
|
|
|
|
def test_json_structure():
    """Check that the dumped social context exposes the expected keys at each level."""
    service = SocialMediaService(
        reddit_client=MockRedditClient(),
        repository=None,
        online_mode=True,
    )

    context = service.get_context("PLTR", "2024-01-01", "2024-01-02")
    payload = context.model_dump()

    # Top-level keys every dumped context must carry.
    for key in (
        "symbol",
        "period",
        "posts",
        "sentiment_summary",
        "post_count",
        "platforms",
        "metadata",
    ):
        assert key in payload

    # Only inspect post structure when at least one post came back.
    dumped_posts = payload["posts"]
    if dumped_posts:
        head = dumped_posts[0]
        for key in ("title", "author", "date", "score"):
            assert key in head

    # The sentiment summary exposes its three fields.
    summary = payload["sentiment_summary"]
    for key in ("score", "confidence", "label"):
        assert key in summary
|