""" Social Media Service for aggregating and analyzing social media data. """ import logging from datetime import datetime from typing import Any from tradingagents.clients.base import BaseClient from tradingagents.models.context import ( DataQuality, PostData, SentimentScore, SocialContext, ) from tradingagents.repositories.base import BaseRepository from tradingagents.services.base import BaseService logger = logging.getLogger(__name__) class SocialMediaService(BaseService): """Service for social media data aggregation and analysis.""" def __init__( self, reddit_client: BaseClient | None = None, repository: BaseRepository | None = None, online_mode: bool = True, data_dir: str = "data", **kwargs, ): """Initialize Social Media Service. Args: reddit_client: Client for Reddit API access repository: Repository for cached social data online_mode: Whether to fetch live data data_dir: Directory for data storage """ super().__init__(online_mode=online_mode, data_dir=data_dir, **kwargs) self.reddit_client = reddit_client self.repository = repository def get_context( self, query: str, start_date: str, end_date: str, symbol: str | None = None, subreddits: list[str] | None = None, force_refresh: bool = False, **kwargs, ) -> SocialContext: """Get social media context for a query. Args: query: Search query start_date: Start date in YYYY-MM-DD format end_date: End date in YYYY-MM-DD format symbol: Optional stock symbol subreddits: Optional list of subreddits to search force_refresh: If True, skip local data and fetch fresh from APIs Returns: SocialContext with posts and sentiment analysis """ posts = [] error_info = {} data_source = "unknown" try: # Local-first data strategy with force refresh option if force_refresh: # Skip local data, fetch fresh from APIs posts, data_source = self._fetch_and_cache_fresh_social_data( query, start_date, end_date, symbol, subreddits ) else: # Check local data first, fetch missing if needed posts, data_source = self._get_social_data_local_first( query, start_date, end_date, symbol, subreddits ) except Exception as e: logger.error(f"Error fetching social media data: {e}") error_info = {"error": str(e)} # Calculate sentiment and engagement metrics sentiment_summary = self._calculate_sentiment(posts) engagement_metrics = self._calculate_engagement_metrics(posts) # Determine data quality based on data source data_quality = self._determine_data_quality( data_source=data_source, record_count=len(posts), has_errors=bool(error_info), ) # Separate float metrics from metadata float_metrics = { k: v for k, v in engagement_metrics.items() if isinstance(v, int | float) } metadata_info = { k: v for k, v in engagement_metrics.items() if not isinstance(v, int | float) } return SocialContext( symbol=symbol, period={"start": start_date, "end": end_date}, posts=posts, engagement_metrics=float_metrics, sentiment_summary=sentiment_summary, post_count=len(posts), platforms=["reddit"], metadata={ "data_quality": data_quality, "service": "social_media", "online_mode": self.is_online(), "subreddits": subreddits or [], "data_source": data_source, "force_refresh": force_refresh, **metadata_info, **error_info, }, ) def get_company_social_context( self, symbol: str, start_date: str, end_date: str, subreddits: list[str] | None = None, **kwargs, ) -> SocialContext: """Get company-specific social media context. Args: symbol: Stock ticker symbol start_date: Start date in YYYY-MM-DD format end_date: End date in YYYY-MM-DD format subreddits: Optional list of subreddits Returns: SocialContext for the company """ return self.get_context( query=symbol, start_date=start_date, end_date=end_date, symbol=symbol, subreddits=subreddits, **kwargs, ) def get_global_trends( self, start_date: str, end_date: str, subreddits: list[str] | None = None, **kwargs, ) -> SocialContext: """Get global social media trends. Args: start_date: Start date in YYYY-MM-DD format end_date: End date in YYYY-MM-DD format subreddits: Optional list of subreddits Returns: SocialContext with global trends """ posts = [] try: if self.is_online() and self.reddit_client: subreddit_list = subreddits or ["news", "worldnews", "Economics"] # Get top posts from subreddits raw_posts = self.reddit_client.get_top_posts( subreddit_names=subreddit_list, limit=50, time_filter="week" ) # Filter by date if hasattr(self.reddit_client, "filter_posts_by_date"): raw_posts = self.reddit_client.filter_posts_by_date( raw_posts, start_date, end_date ) posts = self._convert_to_post_data(raw_posts) except Exception as e: logger.error(f"Error fetching global trends: {e}") sentiment_summary = self._calculate_sentiment(posts) engagement_metrics = self._calculate_engagement_metrics(posts) # Separate float metrics from metadata float_metrics = { k: v for k, v in engagement_metrics.items() if isinstance(v, int | float) } metadata_info = { k: v for k, v in engagement_metrics.items() if not isinstance(v, int | float) } return SocialContext( symbol=None, # No specific symbol for global trends period={"start": start_date, "end": end_date}, posts=posts, engagement_metrics=float_metrics, sentiment_summary=sentiment_summary, post_count=len(posts), platforms=["reddit"], metadata={ "data_quality": self._determine_data_quality( data_source="live_api" if self.is_online() else "offline", record_count=len(posts), has_errors=False, ), "service": "social_media", "type": "global_trends", "subreddits": subreddits or [], **metadata_info, }, ) def _convert_to_post_data(self, raw_posts: list[dict[str, Any]]) -> list[PostData]: """Convert raw Reddit posts to PostData objects.""" posts = [] for post in raw_posts: try: # Calculate engagement score engagement = post.get("upvotes", 0) + post.get("num_comments", 0) # Get posted date if "posted_date" in post: date_str = post["posted_date"] elif "created_utc" in post: date_str = datetime.fromtimestamp(post["created_utc"]).strftime( "%Y-%m-%d" ) else: date_str = datetime.now().strftime("%Y-%m-%d") post_data = PostData( title=post.get("title", ""), content=post.get("content", ""), author=post.get("author", "unknown"), source=post.get("subreddit", "reddit"), date=date_str, url=post.get("url", ""), score=post.get("score", 0), comments=post.get("num_comments", 0), engagement_score=engagement, subreddit=post.get("subreddit"), metadata={ "upvotes": post.get("upvotes", 0), "num_comments": post.get("num_comments", 0), "subreddit": post.get("subreddit", ""), }, ) posts.append(post_data) except Exception as e: logger.warning(f"Error converting post: {e}") continue return posts def _convert_cached_to_posts(self, cached_data: dict[str, Any]) -> list[PostData]: """Convert cached repository data to PostData objects.""" posts = [] if not cached_data or "posts" not in cached_data: return posts for post in cached_data.get("posts", []): try: posts.append(PostData(**post)) except Exception as e: logger.warning(f"Error converting cached post: {e}") return posts def _get_social_data_local_first( self, query: str, start_date: str, end_date: str, symbol: str | None, subreddits: list[str] | None, ) -> tuple[list[PostData], str]: """Get social data using local-first strategy: check local data first, fetch missing if needed.""" try: # Check if we have sufficient local data search_key = symbol or query if self.repository and self.repository.has_data_for_period( search_key, start_date, end_date, symbol=symbol ): logger.info( f"Using local social data for {search_key} ({start_date} to {end_date})" ) cached_data = self.repository.get_data( query=search_key, start_date=start_date, end_date=end_date ) posts = self._convert_cached_to_posts(cached_data) return posts, "local_cache" # We don't have sufficient local data - need to fetch from APIs logger.info( f"Local data insufficient, fetching from APIs for {search_key} ({start_date} to {end_date})" ) posts, _ = self._fetch_fresh_social_data( query, start_date, end_date, symbol, subreddits ) # Cache the fresh data if we have a repository if posts and self.repository: try: posts_data = [post.model_dump() for post in posts] cache_data = { "query": query, "symbol": symbol, "posts": posts_data, "subreddits": subreddits, "metadata": {"cached_at": datetime.utcnow().isoformat()}, } self.repository.store_data(search_key, cache_data, symbol=symbol) logger.debug(f"Cached fresh social data for {search_key}") except Exception as e: logger.warning(f"Failed to cache social data for {search_key}: {e}") return posts, "live_api" except Exception as e: logger.error(f"Error fetching social data for {query}: {e}") return [], "error" def _fetch_and_cache_fresh_social_data( self, query: str, start_date: str, end_date: str, symbol: str | None, subreddits: list[str] | None, ) -> tuple[list[PostData], str]: """Force fetch fresh social data from APIs and cache it, bypassing local data.""" try: search_key = symbol or query logger.info( f"Force refreshing social data from APIs for {search_key} ({start_date} to {end_date})" ) # Clear existing data if we have a repository if self.repository: try: self.repository.clear_data( search_key, start_date, end_date, symbol=symbol ) logger.debug(f"Cleared existing social data for {search_key}") except Exception as e: logger.warning( f"Failed to clear existing social data for {search_key}: {e}" ) # Fetch fresh data posts, _ = self._fetch_fresh_social_data( query, start_date, end_date, symbol, subreddits ) # Cache the fresh data if posts and self.repository: try: posts_data = [post.model_dump() for post in posts] cache_data = { "query": query, "symbol": symbol, "posts": posts_data, "subreddits": subreddits, "metadata": {"refreshed_at": datetime.utcnow().isoformat()}, } self.repository.store_data( search_key, cache_data, symbol=symbol, overwrite=True ) logger.debug(f"Cached refreshed social data for {search_key}") except Exception as e: logger.warning( f"Failed to cache refreshed social data for {search_key}: {e}" ) return posts, "live_api_refresh" except Exception as e: logger.error(f"Error force refreshing social data for {query}: {e}") return [], "refresh_error" def _fetch_fresh_social_data( self, query: str, start_date: str, end_date: str, symbol: str | None, subreddits: list[str] | None, ) -> tuple[list[PostData], str]: """Fetch fresh social data from APIs.""" posts = [] if self.is_online() and self.reddit_client: # Get live Reddit data subreddit_list = subreddits or ["investing", "stocks", "wallstreetbets"] # Search for posts raw_posts = self.reddit_client.search_posts( query=query, subreddit_names=subreddit_list, limit=50, time_filter="week", ) # Filter by date if hasattr(self.reddit_client, "filter_posts_by_date"): raw_posts = self.reddit_client.filter_posts_by_date( raw_posts, start_date, end_date ) # Convert to PostData objects posts = self._convert_to_post_data(raw_posts) return posts, "live_api" def _calculate_sentiment(self, posts: list[PostData]) -> SentimentScore: """Calculate overall sentiment from posts.""" if not posts: return SentimentScore(score=0.0, confidence=0.0, label="neutral") total_score = 0.0 total_weight = 0.0 for post in posts: # Simple sentiment analysis based on keywords and engagement sentiment_score = self._analyze_post_sentiment(post) # Weight by engagement weight = 1 + ( post.engagement_score / 1000 ) # Higher engagement = more weight total_score += sentiment_score * weight total_weight += weight # Set individual post sentiment post.sentiment = SentimentScore( score=sentiment_score, confidence=0.7, # Moderate confidence for keyword-based analysis label="positive" if sentiment_score > 0.2 else "negative" if sentiment_score < -0.2 else "neutral", ) # Calculate weighted average avg_score = total_score / total_weight if total_weight > 0 else 0.0 # Determine label if avg_score > 0.2: label = "positive" elif avg_score < -0.2: label = "negative" else: label = "neutral" # Confidence based on number of posts confidence = min(0.9, 0.5 + (len(posts) / 100)) return SentimentScore(score=avg_score, confidence=confidence, label=label) def _analyze_post_sentiment(self, post: PostData) -> float: """Analyze sentiment of a single post.""" text = f"{post.title} {post.content or ''}".lower() # Simple keyword-based sentiment positive_words = [ "bullish", "moon", "gains", "buy", "hold", "amazing", "great", "excellent", "positive", "growth", "beat", "upgrade", "🚀", ] negative_words = [ "bearish", "crash", "sell", "loss", "decline", "terrible", "bad", "negative", "downgrade", "warning", "overvalued", ] positive_count = sum(1 for word in positive_words if word in text) negative_count = sum(1 for word in negative_words if word in text) # Score from -1 to 1 if positive_count + negative_count == 0: return 0.0 score = (positive_count - negative_count) / (positive_count + negative_count) # Adjust for score ratio (upvotes vs downvotes implied) if post.score > 0: score_adjustment = min(0.2, post.score / 1000) score = score * 0.8 + score_adjustment * 0.2 return max(-1.0, min(1.0, score)) def _calculate_engagement_metrics(self, posts: list[PostData]) -> dict[str, float]: """Calculate engagement metrics from posts.""" if not posts: return { "total_engagement": 0, "average_engagement": 0, "max_engagement": 0, "total_posts": 0, } engagements = [post.engagement_score for post in posts] metrics = { "total_engagement": sum(engagements), "average_engagement": sum(engagements) / len(engagements), "max_engagement": max(engagements), "total_posts": len(posts), } # Add top posts info sorted_posts = sorted(posts, key=lambda p: p.engagement_score, reverse=True) metrics["top_posts"] = [ {"title": p.title[:100], "engagement": p.engagement_score} for p in sorted_posts[:3] ] return metrics def _determine_data_quality( self, data_source: str, record_count: int, has_errors: bool = False ) -> DataQuality: """Determine data quality based on source, record count, and errors.""" if has_errors or record_count == 0: return DataQuality.LOW if data_source in ["local_cache", "error", "refresh_error"]: return DataQuality.LOW elif data_source in ["live_api", "live_api_refresh"]: if record_count >= 20: return DataQuality.HIGH elif record_count >= 5: return DataQuality.MEDIUM else: return DataQuality.LOW else: return DataQuality.MEDIUM