# TradingAgents/tradingagents/services/social_media_service.py
"""
Social Media Service for aggregating and analyzing social media data.
"""
import logging
from datetime import datetime
from typing import Any
from tradingagents.clients.base import BaseClient
from tradingagents.models.context import (
DataQuality,
PostData,
SentimentScore,
SocialContext,
)
from tradingagents.repositories.base import BaseRepository
from tradingagents.services.base import BaseService
logger = logging.getLogger(__name__)


class SocialMediaService(BaseService):
    """Service for social media data aggregation and analysis."""

    def __init__(
        self,
        reddit_client: BaseClient | None = None,
        repository: BaseRepository | None = None,
        online_mode: bool = True,
        data_dir: str = "data",
        **kwargs,
    ):
        """Initialize Social Media Service.

        Args:
            reddit_client: Client for Reddit API access
            repository: Repository for cached social data
            online_mode: Whether to fetch live data
            data_dir: Directory for data storage
        """
        super().__init__(online_mode=online_mode, data_dir=data_dir, **kwargs)
        self.reddit_client = reddit_client
        self.repository = repository

    def get_context(
        self,
        query: str,
        start_date: str,
        end_date: str,
        symbol: str | None = None,
        subreddits: list[str] | None = None,
        force_refresh: bool = False,
        **kwargs,
    ) -> SocialContext:
        """Get social media context for a query.

        Args:
            query: Search query
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format
            symbol: Optional stock symbol
            subreddits: Optional list of subreddits to search
            force_refresh: If True, skip local data and fetch fresh from APIs

        Returns:
            SocialContext with posts and sentiment analysis
        """
        posts = []
        error_info = {}
        data_source = "unknown"
        try:
            # Local-first data strategy with force-refresh option
            if force_refresh:
                # Skip local data, fetch fresh from APIs
                posts, data_source = self._fetch_and_cache_fresh_social_data(
                    query, start_date, end_date, symbol, subreddits
                )
            else:
                # Check local data first, fetch missing if needed
                posts, data_source = self._get_social_data_local_first(
                    query, start_date, end_date, symbol, subreddits
                )
        except Exception as e:
            logger.error(f"Error fetching social media data: {e}")
            error_info = {"error": str(e)}

        # Calculate sentiment and engagement metrics
        sentiment_summary = self._calculate_sentiment(posts)
        engagement_metrics = self._calculate_engagement_metrics(posts)

        # Determine data quality based on data source
        data_quality = self._determine_data_quality(
            data_source=data_source,
            record_count=len(posts),
            has_errors=bool(error_info),
        )

        # Separate float metrics from non-float metadata
        float_metrics = {
            k: v for k, v in engagement_metrics.items() if isinstance(v, int | float)
        }
        metadata_info = {
            k: v
            for k, v in engagement_metrics.items()
            if not isinstance(v, int | float)
        }

        return SocialContext(
            symbol=symbol,
            period={"start": start_date, "end": end_date},
            posts=posts,
            engagement_metrics=float_metrics,
            sentiment_summary=sentiment_summary,
            post_count=len(posts),
            platforms=["reddit"],
            metadata={
                "data_quality": data_quality,
                "service": "social_media",
                "online_mode": self.is_online(),
                "subreddits": subreddits or [],
                "data_source": data_source,
                "force_refresh": force_refresh,
                **metadata_info,
                **error_info,
            },
        )
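
    # Data-flow summary for get_context (describing the implementation above):
    #   force_refresh=True  -> clear cached data, hit the APIs, re-cache
    #                          (data_source "live_api_refresh")
    #   force_refresh=False -> serve from the repository when it covers the
    #                          period ("local_cache"); otherwise fetch and
    #                          cache ("live_api", or "offline" when no live
    #                          source is available)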

    def get_company_social_context(
        self,
        symbol: str,
        start_date: str,
        end_date: str,
        subreddits: list[str] | None = None,
        **kwargs,
    ) -> SocialContext:
        """Get company-specific social media context.

        Args:
            symbol: Stock ticker symbol
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format
            subreddits: Optional list of subreddits

        Returns:
            SocialContext for the company
        """
        return self.get_context(
            query=symbol,
            start_date=start_date,
            end_date=end_date,
            symbol=symbol,
            subreddits=subreddits,
            **kwargs,
        )

    def get_global_trends(
        self,
        start_date: str,
        end_date: str,
        subreddits: list[str] | None = None,
        **kwargs,
    ) -> SocialContext:
        """Get global social media trends.

        Args:
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format
            subreddits: Optional list of subreddits

        Returns:
            SocialContext with global trends
        """
        posts = []
        error_info = {}
        try:
            if self.is_online() and self.reddit_client:
                subreddit_list = subreddits or ["news", "worldnews", "Economics"]

                # Get top posts from the subreddits
                raw_posts = self.reddit_client.get_top_posts(
                    subreddit_names=subreddit_list, limit=50, time_filter="week"
                )

                # Filter by date
                if hasattr(self.reddit_client, "filter_posts_by_date"):
                    raw_posts = self.reddit_client.filter_posts_by_date(
                        raw_posts, start_date, end_date
                    )
                posts = self._convert_to_post_data(raw_posts)
        except Exception as e:
            logger.error(f"Error fetching global trends: {e}")
            error_info = {"error": str(e)}

        sentiment_summary = self._calculate_sentiment(posts)
        engagement_metrics = self._calculate_engagement_metrics(posts)

        # Separate float metrics from non-float metadata
        float_metrics = {
            k: v for k, v in engagement_metrics.items() if isinstance(v, int | float)
        }
        metadata_info = {
            k: v
            for k, v in engagement_metrics.items()
            if not isinstance(v, int | float)
        }

        return SocialContext(
            symbol=None,  # No specific symbol for global trends
            period={"start": start_date, "end": end_date},
            posts=posts,
            engagement_metrics=float_metrics,
            sentiment_summary=sentiment_summary,
            post_count=len(posts),
            platforms=["reddit"],
            metadata={
                "data_quality": self._determine_data_quality(
                    data_source="live_api" if self.is_online() else "offline",
                    record_count=len(posts),
                    has_errors=bool(error_info),
                ),
                "service": "social_media",
                "type": "global_trends",
                "subreddits": subreddits or [],
                **metadata_info,
                **error_info,
            },
        )

    def _convert_to_post_data(self, raw_posts: list[dict[str, Any]]) -> list[PostData]:
        """Convert raw Reddit posts to PostData objects."""
        posts = []
        for post in raw_posts:
            try:
                # Calculate engagement score
                engagement = post.get("upvotes", 0) + post.get("num_comments", 0)

                # Get posted date; created_utc is a UTC epoch timestamp, so
                # convert it in UTC rather than local time
                if "posted_date" in post:
                    date_str = post["posted_date"]
                elif "created_utc" in post:
                    date_str = datetime.fromtimestamp(
                        post["created_utc"], tz=timezone.utc
                    ).strftime("%Y-%m-%d")
                else:
                    date_str = datetime.now().strftime("%Y-%m-%d")

                post_data = PostData(
                    title=post.get("title", ""),
                    content=post.get("content", ""),
                    author=post.get("author", "unknown"),
                    source=post.get("subreddit", "reddit"),
                    date=date_str,
                    url=post.get("url", ""),
                    score=post.get("score", 0),
                    comments=post.get("num_comments", 0),
                    engagement_score=engagement,
                    subreddit=post.get("subreddit"),
                    metadata={
                        "upvotes": post.get("upvotes", 0),
                        "num_comments": post.get("num_comments", 0),
                        "subreddit": post.get("subreddit", ""),
                    },
                )
                posts.append(post_data)
            except Exception as e:
                logger.warning(f"Error converting post: {e}")
                continue
        return posts
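
    # The converter above assumes raw posts shaped roughly like the dict below
    # (keys inferred from the .get() calls; the exact schema depends on the
    # Reddit client implementation):
    #   {
    #       "title": "...", "content": "...", "author": "...",
    #       "subreddit": "investing", "url": "https://...",
    #       "score": 120, "upvotes": 150, "num_comments": 30,
    #       "created_utc": 1700000000,  # or "posted_date": "YYYY-MM-DD"
    #   }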

    def _convert_cached_to_posts(self, cached_data: dict[str, Any]) -> list[PostData]:
        """Convert cached repository data to PostData objects."""
        posts = []
        if not cached_data or "posts" not in cached_data:
            return posts
        for post in cached_data.get("posts", []):
            try:
                posts.append(PostData(**post))
            except Exception as e:
                logger.warning(f"Error converting cached post: {e}")
        return posts

    def _get_social_data_local_first(
        self,
        query: str,
        start_date: str,
        end_date: str,
        symbol: str | None,
        subreddits: list[str] | None,
    ) -> tuple[list[PostData], str]:
        """Get social data using a local-first strategy: check local data
        first, fetch from the APIs only when it is missing."""
        try:
            # Check if we have sufficient local data
            search_key = symbol or query
            if self.repository and self.repository.has_data_for_period(
                search_key, start_date, end_date, symbol=symbol
            ):
                logger.info(
                    f"Using local social data for {search_key} "
                    f"({start_date} to {end_date})"
                )
                cached_data = self.repository.get_data(
                    query=search_key, start_date=start_date, end_date=end_date
                )
                posts = self._convert_cached_to_posts(cached_data)
                return posts, "local_cache"

            # We don't have sufficient local data - need to fetch from APIs
            logger.info(
                f"Local data insufficient, fetching from APIs for {search_key} "
                f"({start_date} to {end_date})"
            )
            posts, data_source = self._fetch_fresh_social_data(
                query, start_date, end_date, symbol, subreddits
            )

            # Cache the fresh data if we have a repository
            if posts and self.repository:
                try:
                    posts_data = [post.model_dump() for post in posts]
                    cache_data = {
                        "query": query,
                        "symbol": symbol,
                        "posts": posts_data,
                        "subreddits": subreddits,
                        "metadata": {
                            "cached_at": datetime.now(timezone.utc).isoformat()
                        },
                    }
                    self.repository.store_data(search_key, cache_data, symbol=symbol)
                    logger.debug(f"Cached fresh social data for {search_key}")
                except Exception as e:
                    logger.warning(
                        f"Failed to cache social data for {search_key}: {e}"
                    )
            return posts, data_source
        except Exception as e:
            logger.error(f"Error fetching social data for {query}: {e}")
            return [], "error"

    def _fetch_and_cache_fresh_social_data(
        self,
        query: str,
        start_date: str,
        end_date: str,
        symbol: str | None,
        subreddits: list[str] | None,
    ) -> tuple[list[PostData], str]:
        """Force-fetch fresh social data from APIs and cache it, bypassing
        local data."""
        try:
            search_key = symbol or query
            logger.info(
                f"Force refreshing social data from APIs for {search_key} "
                f"({start_date} to {end_date})"
            )

            # Clear existing data if we have a repository
            if self.repository:
                try:
                    self.repository.clear_data(
                        search_key, start_date, end_date, symbol=symbol
                    )
                    logger.debug(f"Cleared existing social data for {search_key}")
                except Exception as e:
                    logger.warning(
                        f"Failed to clear existing social data for {search_key}: {e}"
                    )

            # Fetch fresh data
            posts, _ = self._fetch_fresh_social_data(
                query, start_date, end_date, symbol, subreddits
            )

            # Cache the fresh data
            if posts and self.repository:
                try:
                    posts_data = [post.model_dump() for post in posts]
                    cache_data = {
                        "query": query,
                        "symbol": symbol,
                        "posts": posts_data,
                        "subreddits": subreddits,
                        "metadata": {
                            "refreshed_at": datetime.now(timezone.utc).isoformat()
                        },
                    }
                    self.repository.store_data(
                        search_key, cache_data, symbol=symbol, overwrite=True
                    )
                    logger.debug(f"Cached refreshed social data for {search_key}")
                except Exception as e:
                    logger.warning(
                        f"Failed to cache refreshed social data for {search_key}: {e}"
                    )
            return posts, "live_api_refresh"
        except Exception as e:
            logger.error(f"Error force refreshing social data for {query}: {e}")
            return [], "refresh_error"

    def _fetch_fresh_social_data(
        self,
        query: str,
        start_date: str,
        end_date: str,
        symbol: str | None,
        subreddits: list[str] | None,
    ) -> tuple[list[PostData], str]:
        """Fetch fresh social data from APIs."""
        posts = []
        if not (self.is_online() and self.reddit_client):
            # No live source available; report the source accurately instead
            # of labeling an empty result "live_api"
            return posts, "offline"

        # Get live Reddit data
        subreddit_list = subreddits or ["investing", "stocks", "wallstreetbets"]

        # Search for posts
        raw_posts = self.reddit_client.search_posts(
            query=query,
            subreddit_names=subreddit_list,
            limit=50,
            time_filter="week",
        )

        # Filter by date
        if hasattr(self.reddit_client, "filter_posts_by_date"):
            raw_posts = self.reddit_client.filter_posts_by_date(
                raw_posts, start_date, end_date
            )

        # Convert to PostData objects
        posts = self._convert_to_post_data(raw_posts)
        return posts, "live_api"

    def _calculate_sentiment(self, posts: list[PostData]) -> SentimentScore:
        """Calculate overall sentiment from posts."""
        if not posts:
            return SentimentScore(score=0.0, confidence=0.0, label="neutral")

        total_score = 0.0
        total_weight = 0.0
        for post in posts:
            # Simple sentiment analysis based on keywords and engagement
            sentiment_score = self._analyze_post_sentiment(post)

            # Weight by engagement: higher engagement = more weight
            weight = 1 + (post.engagement_score / 1000)
            total_score += sentiment_score * weight
            total_weight += weight

            # Set individual post sentiment
            label = (
                "positive"
                if sentiment_score > 0.2
                else "negative" if sentiment_score < -0.2 else "neutral"
            )
            post.sentiment = SentimentScore(
                score=sentiment_score,
                confidence=0.7,  # Moderate confidence for keyword-based analysis
                label=label,
            )

        # Calculate weighted average
        avg_score = total_score / total_weight if total_weight > 0 else 0.0

        # Determine label
        if avg_score > 0.2:
            label = "positive"
        elif avg_score < -0.2:
            label = "negative"
        else:
            label = "neutral"

        # Confidence grows with the number of posts, capped at 0.9
        confidence = min(0.9, 0.5 + (len(posts) / 100))
        return SentimentScore(score=avg_score, confidence=confidence, label=label)
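
    # Worked example of the engagement weighting above (hypothetical numbers):
    # a post with engagement_score 500 gets weight 1 + 500/1000 = 1.5, so it
    # counts 1.5x as much toward the weighted average as a post with no
    # engagement (weight 1.0).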

    def _analyze_post_sentiment(self, post: PostData) -> float:
        """Analyze sentiment of a single post."""
        text = f"{post.title} {post.content or ''}".lower()

        # Simple keyword-based sentiment
        positive_words = [
            "bullish", "moon", "gains", "buy", "hold", "amazing", "great",
            "excellent", "positive", "growth", "beat", "upgrade", "🚀",
        ]
        negative_words = [
            "bearish", "crash", "sell", "loss", "decline", "terrible", "bad",
            "negative", "downgrade", "warning", "overvalued",
        ]
        positive_count = sum(1 for word in positive_words if word in text)
        negative_count = sum(1 for word in negative_words if word in text)

        # Score from -1 to 1
        if positive_count + negative_count == 0:
            return 0.0
        score = (positive_count - negative_count) / (positive_count + negative_count)

        # Nudge the score using the post's Reddit score (upvotes vs downvotes implied)
        if post.score > 0:
            score_adjustment = min(0.2, post.score / 1000)
            score = score * 0.8 + score_adjustment * 0.2
        return max(-1.0, min(1.0, score))
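
    # Worked example (hypothetical): a post titled "Bullish on gains, buy now"
    # matches three positive words and no negative words, so the raw score is
    # (3 - 0) / (3 + 0) = 1.0. If post.score is 500, the adjustment is
    # min(0.2, 500 / 1000) = 0.2, giving 1.0 * 0.8 + 0.2 * 0.2 = 0.84.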

    def _calculate_engagement_metrics(self, posts: list[PostData]) -> dict[str, Any]:
        """Calculate engagement metrics from posts.

        Returns a dict of float metrics plus a "top_posts" list, which the
        callers split apart before building the SocialContext.
        """
        if not posts:
            return {
                "total_engagement": 0,
                "average_engagement": 0,
                "max_engagement": 0,
                "total_posts": 0,
            }
        engagements = [post.engagement_score for post in posts]
        metrics = {
            "total_engagement": sum(engagements),
            "average_engagement": sum(engagements) / len(engagements),
            "max_engagement": max(engagements),
            "total_posts": len(posts),
        }

        # Add top posts info
        sorted_posts = sorted(posts, key=lambda p: p.engagement_score, reverse=True)
        metrics["top_posts"] = [
            {"title": p.title[:100], "engagement": p.engagement_score}
            for p in sorted_posts[:3]
        ]
        return metrics

    def _determine_data_quality(
        self, data_source: str, record_count: int, has_errors: bool = False
    ) -> DataQuality:
        """Determine data quality based on source, record count, and errors."""
        if has_errors or record_count == 0:
            return DataQuality.LOW
        if data_source in ["local_cache", "error", "refresh_error"]:
            # Cached and error sources are rated conservatively
            return DataQuality.LOW
        elif data_source in ["live_api", "live_api_refresh"]:
            if record_count >= 20:
                return DataQuality.HIGH
            elif record_count >= 5:
                return DataQuality.MEDIUM
            else:
                return DataQuality.LOW
        else:
            return DataQuality.MEDIUM
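

# Minimal usage sketch, assuming BaseService accepts the online_mode/data_dir
# arguments shown in __init__ above. With no Reddit client or repository and
# online_mode=False, the service returns an empty, LOW-quality context, which
# makes this safe to run without credentials.
if __name__ == "__main__":
    service = SocialMediaService(
        reddit_client=None, repository=None, online_mode=False
    )
    context = service.get_company_social_context(
        symbol="AAPL", start_date="2024-01-01", end_date="2024-01-07"
    )
    print(context.post_count, context.sentiment_summary.label)
    print(context.metadata["data_source"], context.metadata["data_quality"])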