# Social Media Domain Implementation Tasks

## Overview

Complete greenfield implementation of the socialmedia domain from empty stubs to production-ready system with PRAW Reddit API integration, PostgreSQL migration, OpenRouter LLM sentiment analysis, and AgentToolkit RAG methods.

**Total Estimated Time: 32 hours (3-phase parallel development approach)**

## Phase Structure

### Phase 1: Foundation (12 hours) - Database & Core Models
**Parallel Execution Ready**: Multiple agents can work on different components simultaneously

### Phase 2: API Integration & Processing (12 hours) - Clients & Services
**Parallel Execution Ready**: API clients and LLM services can be developed in parallel

### Phase 3: Integration & Validation (8 hours) - AgentToolkit & Dagster
**Parallel Execution Ready**: AgentToolkit and pipeline development with comprehensive testing

---

## Phase 1: Foundation (12 hours)

### Task 1.1: Database Schema Migration (3 hours)
**Priority: Blocking** | **Agent: Database Specialist**

Create PostgreSQL migration for social_media_posts table with TimescaleDB and pgvectorscale support.

**Implementation:**
```sql
-- Migration: 003_create_social_media_posts.sql

-- Enable the extensions in this database if they are not already enabled.
-- vectorscale depends on pgvector, which provides the VECTOR type.
CREATE EXTENSION IF NOT EXISTS timescaledb;
CREATE EXTENSION IF NOT EXISTS vectorscale CASCADE;

CREATE TABLE social_media_posts (
    -- uuid7() is not a PostgreSQL built-in; this assumes a project-provided function
    id UUID NOT NULL DEFAULT uuid7(),
    post_id VARCHAR(50) NOT NULL,
    title TEXT NOT NULL,
    content TEXT,
    author VARCHAR(100) NOT NULL,
    subreddit VARCHAR(50) NOT NULL,
    created_utc TIMESTAMPTZ NOT NULL,
    upvotes INTEGER NOT NULL DEFAULT 0,
    downvotes INTEGER NOT NULL DEFAULT 0,
    comments_count INTEGER NOT NULL DEFAULT 0,
    url TEXT NOT NULL,
    sentiment_score JSONB,
    sentiment_label VARCHAR(20),
    tickers TEXT[] DEFAULT '{}',
    title_embedding VECTOR(1536),
    content_embedding VECTOR(1536),
    inserted_at TIMESTAMPTZ DEFAULT NOW(),
    updated_at TIMESTAMPTZ DEFAULT NOW(),
    -- TimescaleDB requires primary keys and unique indexes on a hypertable
    -- to include the partitioning column (created_utc)
    PRIMARY KEY (id, created_utc)
);

SELECT create_hypertable('social_media_posts', 'created_utc', chunk_time_interval => INTERVAL '1 day');

-- Performance indexes
-- A post's created_utc never changes, so (post_id, created_utc) still deduplicates posts
CREATE UNIQUE INDEX idx_social_posts_post_id ON social_media_posts (post_id, created_utc);
CREATE INDEX idx_social_posts_subreddit_time ON social_media_posts (subreddit, created_utc DESC);
CREATE INDEX idx_social_posts_tickers_gin ON social_media_posts USING GIN (tickers);
-- pgvectorscale exposes its index type as "diskann"
CREATE INDEX idx_social_posts_title_embedding ON social_media_posts USING diskann (title_embedding vector_cosine_ops);
CREATE INDEX idx_social_posts_content_embedding ON social_media_posts USING diskann (content_embedding vector_cosine_ops);
CREATE INDEX idx_social_posts_sentiment ON social_media_posts (((sentiment_score->>'sentiment'))) WHERE sentiment_score IS NOT NULL;

-- Constraints
ALTER TABLE social_media_posts ADD CONSTRAINT chk_sentiment_score CHECK (
    sentiment_score IS NULL OR
    ((sentiment_score->>'confidence')::float BETWEEN 0 AND 1)
);
ALTER TABLE social_media_posts ADD CONSTRAINT chk_created_utc CHECK (created_utc <= NOW());
```

**Acceptance Criteria:**
- [ ] Migration script creates social_media_posts table
- [ ] TimescaleDB hypertable configured for time-series optimization
- [ ] pgvectorscale indexes for title_embedding and content_embedding
- [ ] All constraints and indexes properly created
- [ ] Migration runs
successfully in test and development environments

**Dependencies:** PostgreSQL + TimescaleDB + pgvectorscale installed

**Risk:** Medium - Extension compatibility issues

---

### Task 1.2: SQLAlchemy Entity Implementation (3 hours)
**Priority: Blocking** | **Agent: Entity Specialist**

Create SocialMediaPostEntity with proper field mappings and domain transformations.

**File:** `tradingagents/domains/socialmedia/entities.py`

**Implementation:**
```python
from sqlalchemy import Column, String, Text, Integer, TIMESTAMP, Index
from sqlalchemy.dialects.postgresql import UUID, ARRAY, JSONB
from sqlalchemy.sql import func
# The vector column type comes from the pgvector package, not from SQLAlchemy itself
from pgvector.sqlalchemy import Vector
from tradingagents.database.base import Base
from tradingagents.domains.socialmedia.models import SocialPost
from typing import Optional, List, Dict, Any
import uuid


class SocialMediaPostEntity(Base):
    __tablename__ = 'social_media_posts'

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    post_id = Column(String(50), unique=True, nullable=False, index=True)
    title = Column(Text, nullable=False)
    content = Column(Text)
    author = Column(String(100), nullable=False)
    subreddit = Column(String(50), nullable=False, index=True)
    created_utc = Column(TIMESTAMP(timezone=True), nullable=False, index=True)
    upvotes = Column(Integer, nullable=False, default=0)
    downvotes = Column(Integer, nullable=False, default=0)
    comments_count = Column(Integer, nullable=False, default=0)
    url = Column(Text, nullable=False)

    # Enhanced fields
    sentiment_score = Column(JSONB)
    sentiment_label = Column(String(20))
    tickers = Column(ARRAY(String(10)), default=lambda: [])
    title_embedding = Column(Vector(1536))
    content_embedding = Column(Vector(1536))

    # Metadata
    inserted_at = Column(TIMESTAMP(timezone=True), server_default=func.now())
    updated_at = Column(TIMESTAMP(timezone=True), server_default=func.now(), onupdate=func.now())

    def to_domain(self) -> 'SocialPost':
        """Convert entity to domain model with proper field mapping"""
        sentiment_data = self.sentiment_score or {}

        return SocialPost(
            post_id=self.post_id,
            title=self.title,
            content=self.content,
            author=self.author,
            subreddit=self.subreddit,
            created_utc=self.created_utc,
            upvotes=self.upvotes,
            downvotes=self.downvotes,
            comments_count=self.comments_count,
            url=self.url,
            sentiment_score=sentiment_data.get('score'),
            sentiment_label=self.sentiment_label,
            sentiment_confidence=sentiment_data.get('confidence'),
            sentiment_reasoning=sentiment_data.get('reasoning'),
            tickers=list(self.tickers) if self.tickers else [],
            # Compare against None: embeddings load as arrays, whose truth value is ambiguous
            title_embedding=list(self.title_embedding) if self.title_embedding is not None else None,
            content_embedding=list(self.content_embedding) if self.content_embedding is not None else None
        )

    @classmethod
    def from_domain(cls, post: 'SocialPost') -> 'SocialMediaPostEntity':
        """Create entity from domain model"""
        sentiment_data = None
        if post.sentiment_score is not None and post.sentiment_confidence is not None:
            sentiment_data = {
                'score': post.sentiment_score,
                'confidence': post.sentiment_confidence,
                'reasoning': getattr(post, 'sentiment_reasoning', None)
            }

        return cls(
            post_id=post.post_id,
            title=post.title,
            content=post.content,
            author=post.author,
            subreddit=post.subreddit,
            created_utc=post.created_utc,
            upvotes=post.upvotes,
            downvotes=post.downvotes,
            comments_count=post.comments_count,
            url=post.url,
            sentiment_score=sentiment_data,
            sentiment_label=post.sentiment_label,
            tickers=post.tickers or [],
            title_embedding=post.title_embedding,
            content_embedding=post.content_embedding
        )
```

**Acceptance Criteria:**
- [ ] SocialMediaPostEntity properly maps all database fields
- [ ] to_domain() and from_domain() methods handle all field conversions
- [ ] Proper handling of vector fields and JSONB sentiment data
- [ ] Entity integrates with existing database session management
- [ ] All field types match database schema exactly

**Dependencies:** Task 1.1 (database schema)

**Risk:** Low - Standard SQLAlchemy patterns

---

### Task 1.3: Domain Model Enhancement (3 hours)
**Priority: Blocking** | **Agent: Domain Specialist**

Enhance SocialPost domain entity with comprehensive validation, transformations, and business rules.

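One transformation worth pinning down early is the timestamp: PRAW exposes `created_utc` as epoch seconds, while the `created_utc` column is `TIMESTAMPTZ`. The sketch below uses only the standard library; the helper name `reddit_epoch_to_utc` is hypothetical and not part of the planned module. It shows the conversion with an explicit UTC zone, so the result does not depend on the machine's local time zone.

```python
from datetime import datetime, timezone


def reddit_epoch_to_utc(epoch_seconds: float) -> datetime:
    """Convert epoch seconds (as in Reddit's created_utc) to an aware UTC datetime."""
    # Passing tz avoids the naive, local-time result of a bare fromtimestamp() call
    return datetime.fromtimestamp(epoch_seconds, tz=timezone.utc)


if __name__ == "__main__":
    created = reddit_epoch_to_utc(1_700_000_000)
    print(created.isoformat())
```

An aware datetime also compares cleanly against `datetime.now(timezone.utc)` when computing cutoff dates for time-window queries.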
**File:** `tradingagents/domains/socialmedia/models.py`

**Implementation:**
```python
# Note: this uses the Pydantic v1 API (validator, root_validator, Field(regex=...))
from pydantic import BaseModel, Field, validator, root_validator
from typing import Optional, List, Dict, Any, Literal
from datetime import datetime, timezone
import re


class SentimentScore(BaseModel):
    """Structured sentiment analysis result from OpenRouter LLM"""
    sentiment: Literal['positive', 'negative', 'neutral']
    confidence: float = Field(..., ge=0.0, le=1.0)
    reasoning: Optional[str] = None

    @validator('reasoning')
    def reasoning_not_empty(cls, v):
        if v is not None and len(v.strip()) == 0:
            return None
        return v


class SocialPost(BaseModel):
    """Core domain entity with business rules and transformations"""
    # Base fields from Reddit API
    post_id: str = Field(..., regex=r'^[a-zA-Z0-9_-]+$')
    title: str = Field(..., min_length=1, max_length=300)
    content: Optional[str] = None
    author: str = Field(..., min_length=1, max_length=100)
    subreddit: str = Field(..., min_length=1, max_length=50)
    created_utc: datetime
    upvotes: int = Field(..., ge=0)
    downvotes: int = Field(..., ge=0)
    comments_count: int = Field(..., ge=0)
    url: str = Field(..., min_length=1)

    # Enhanced fields
    sentiment_score: Optional[float] = Field(None, ge=-1.0, le=1.0)
    sentiment_label: Optional[str] = Field(None, regex=r'^(positive|negative|neutral)$')
    sentiment_confidence: Optional[float] = Field(None, ge=0.0, le=1.0)
    sentiment_reasoning: Optional[str] = None
    tickers: Optional[List[str]] = Field(default_factory=list)
    title_embedding: Optional[List[float]] = None
    content_embedding: Optional[List[float]] = None

    @validator('tickers')
    def validate_tickers(cls, v):
        """Validate ticker symbols format"""
        if v is None:
            return []
        # Ensure tickers are uppercase and valid format
        return [ticker.upper() for ticker in v if re.match(r'^[A-Z]{1,5}$', ticker.upper())]

    @validator('title_embedding', 'content_embedding')
    def validate_embedding_dimensions(cls, v):
        """Ensure embeddings have correct dimensions"""
        if v is not None and len(v) != 1536:
            raise ValueError('Embedding must be 1536 dimensions')
        return v

    @root_validator
    def validate_sentiment_consistency(cls, values):
        """Ensure sentiment fields are consistent"""
        score = values.get('sentiment_score')
        label = values.get('sentiment_label')
        confidence = values.get('sentiment_confidence')

        # All sentiment fields should be present or all None
        sentiment_fields = [score, label, confidence]
        non_none_count = sum(1 for field in sentiment_fields if field is not None)

        if non_none_count > 0 and non_none_count < 3:
            raise ValueError('All sentiment fields (score, label, confidence) must be provided together')

        return values

    @classmethod
    def from_praw_submission(cls, submission: Any) -> 'SocialPost':
        """Create SocialPost from PRAW Reddit submission"""
        return cls(
            post_id=submission.id,
            title=submission.title[:300],  # Truncate long titles
            content=submission.selftext if submission.selftext else None,
            author=str(submission.author) if submission.author else '[deleted]',
            subreddit=submission.subreddit.display_name,
            # created_utc is epoch seconds; build an aware UTC datetime, not a local naive one
            created_utc=datetime.fromtimestamp(submission.created_utc, tz=timezone.utc),
            upvotes=submission.ups if hasattr(submission, 'ups') else submission.score,
            # score = ups - downs, so downs = ups - score
            downvotes=max(0, submission.ups - submission.score) if hasattr(submission, 'ups') else 0,
            comments_count=submission.num_comments,
            url=f"https://reddit.com{submission.permalink}"
        )

    def extract_tickers(self) -> List[str]:
        """Extract ticker symbols from title and content"""
        text = f"{self.title} {self.content or ''}"

        # Look for $TICKER or TICKER patterns. The text is not upper-cased first:
        # doing so would turn every short lowercase word into a candidate ticker.
        ticker_pattern = r'\b(?:\$)?([A-Z]{1,5})\b'
        potential_tickers = re.findall(ticker_pattern, text)

        # Filter out common words that look like tickers
        excluded = {'THE', 'AND', 'OR', 'FOR', 'TO', 'OF', 'IN', 'ON', 'AT', 'BY', 'UP', 'IS', 'IT', 'BE', 'AS',
                    'ARE', 'WAS', 'HE', 'SHE', 'WE', 'YOU', 'THEY', 'ALL', 'ANY', 'CAN', 'HAD', 'HER', 'HIS',
                    'HOW', 'ITS', 'MAY', 'NEW', 'NOW', 'OLD', 'SEE', 'TWO', 'WHO', 'BOY', 'DID', 'HAS', 'LET',
                    'PUT', 'SAY', 'SIX', 'TEN', 'USE', 'WIN', 'YES'}

        tickers = [ticker for ticker in potential_tickers if ticker not in excluded]
        return list(set(tickers))  # Remove duplicates

    def has_reliable_sentiment(self) -> bool:
        """Check if sentiment analysis has sufficient confidence"""
        return (self.sentiment_confidence is not None and
                self.sentiment_confidence >= 0.5)

    def to_agent_context(self) -> Dict[str, Any]:
        """Format post for agent consumption"""
        sentiment_emoji = {"positive": "📈", "negative": "📉", "neutral": "➡️"}.get(self.sentiment_label, "❓")

        return {
            'post_id': self.post_id,
            'subreddit': self.subreddit,
            'title': self.title,
            'content': self.content[:200] + '...' if self.content and len(self.content) > 200 else self.content,
            'author': self.author,
            'created_utc': self.created_utc.isoformat(),
            'engagement': {
                'upvotes': self.upvotes,
                'comments_count': self.comments_count,
                'score': self.upvotes - self.downvotes
            },
            'sentiment': {
                'label': self.sentiment_label,
                'score': self.sentiment_score,
                'confidence': self.sentiment_confidence,
                'emoji': sentiment_emoji,
                'reliable': self.has_reliable_sentiment()
            },
            'tickers': self.tickers or [],
            'url': self.url
        }
```

**Acceptance Criteria:**
- [ ] SocialPost model handles all Reddit API fields properly
- [ ] Comprehensive validation for all fields including sentiment and embeddings
- [ ] from_praw_submission() creates valid domain objects from Reddit data
- [ ] extract_tickers() accurately finds ticker symbols in text
- [ ] to_agent_context() formats data for AI agent consumption
- [ ] Business rule validation prevents invalid state combinations

**Dependencies:** None (can run parallel with other tasks)

**Risk:** Low - Standard domain modeling

---

### Task 1.4: Repository Implementation (3 hours)
**Priority: Medium** | **Agent: Repository Specialist**

Implement SocialRepository with PostgreSQL operations, vector similarity search, and performance optimization.

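The similarity search in this task is built on pgvector's `<=>` operator, which returns cosine distance, so similarity is `1 - distance`. The pure-Python sketch below (function names are hypothetical; no database is involved) spells out that relationship, which can help when sanity-checking a threshold such as 0.8 in unit tests.

```python
import math
from typing import Sequence


def cosine_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine distance between two vectors: 1 - cos(angle between a and b)."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same dimension")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine distance is undefined for zero vectors")
    return 1.0 - dot / (norm_a * norm_b)


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Similarity as used for thresholding: 1 - distance, in the range [-1, 1]."""
    return 1.0 - cosine_distance(a, b)


if __name__ == "__main__":
    print(cosine_similarity([1.0, 0.0], [1.0, 0.0]))  # identical direction
    print(cosine_similarity([1.0, 0.0], [0.0, 1.0]))  # orthogonal
```

Because similarity depends only on direction, scaling an embedding does not change the score; a threshold of 0.8 therefore corresponds to a cosine distance of at most 0.2.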
**File:** `tradingagents/domains/socialmedia/repositories.py`

**Implementation:**
```python
from typing import List, Optional, Dict, Any, Tuple
from sqlalchemy import and_, or_, desc, text, func
from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError
from tradingagents.domains.socialmedia.entities import SocialMediaPostEntity
from tradingagents.domains.socialmedia.models import SocialPost
from tradingagents.database import DatabaseManager
from datetime import datetime, timedelta, timezone
import logging

logger = logging.getLogger(__name__)


class SocialRepository:
    """PostgreSQL repository for social media posts with vector search capabilities"""

    def __init__(self, db_manager: DatabaseManager):
        self.db_manager = db_manager

    async def upsert_batch(self, posts: List[SocialPost]) -> List[str]:
        """Batch upsert social media posts with deduplication"""
        async with self.db_manager.get_session() as session:
            saved_ids = []

            for post in posts:
                # Check for existing post
                existing = await session.execute(
                    text("SELECT id FROM social_media_posts WHERE post_id = :post_id"),
                    {"post_id": post.post_id}
                )

                if existing.first():
                    logger.debug(f"Skipping duplicate post: {post.post_id}")
                    continue

                try:
                    # Use a savepoint and flush inside it: IntegrityError surfaces on flush,
                    # and rolling back the savepoint keeps the posts already added.
                    # (Assumes get_session() yields a SQLAlchemy AsyncSession.)
                    async with session.begin_nested():
                        session.add(SocialMediaPostEntity.from_domain(post))
                        await session.flush()
                except IntegrityError as e:
                    logger.warning(f"Integrity error saving post {post.post_id}: {e}")
                    continue

                saved_ids.append(post.post_id)

            await session.commit()
            logger.info(f"Saved {len(saved_ids)} new posts to database")
            return saved_ids

    async def find_by_ticker(self, ticker: str, days: int = 30, limit: int = 50) -> List[SocialPost]:
        """Find posts mentioning specific ticker symbol"""
        async with self.db_manager.get_session() as session:
            # created_utc is TIMESTAMPTZ, so compare against an aware UTC datetime
            cutoff_date = datetime.now(timezone.utc) - timedelta(days=days)

            result = await session.execute(
                text("""
                    SELECT * FROM social_media_posts
                    WHERE :ticker = ANY(tickers)
                      AND created_utc >= :cutoff_date
                    ORDER BY created_utc DESC
                    LIMIT :limit
                """),
                {
                    "ticker": ticker.upper(),
                    "cutoff_date": cutoff_date,
                    "limit": limit
                }
            )

            entities = [SocialMediaPostEntity(**row) for row in result.mappings()]
            return [entity.to_domain() for entity in entities]

    async def find_by_subreddit(self, subreddit: str, hours: int = 24, limit: int = 100) -> List[SocialPost]:
        """Find recent posts from specific subreddit"""
        async with self.db_manager.get_session() as session:
            cutoff_date = datetime.now(timezone.utc) - timedelta(hours=hours)

            result = await session.execute(
                text("""
                    SELECT * FROM social_media_posts
                    WHERE subreddit = :subreddit
                      AND created_utc >= :cutoff_date
                    ORDER BY created_utc DESC
                    LIMIT :limit
                """),
                {
                    "subreddit": subreddit,
                    "cutoff_date": cutoff_date,
                    "limit": limit
                }
            )

            entities = [SocialMediaPostEntity(**row) for row in result.mappings()]
            return [entity.to_domain() for entity in entities]

    async def find_similar_posts(
        self,
        query_embedding: List[float],
        ticker: Optional[str] = None,
        limit: int = 10,
        similarity_threshold: float = 0.8
    ) -> List[Tuple[SocialPost, float]]:
        """Find similar posts using vector similarity search"""
        async with self.db_manager.get_session() as session:
            # pgvector accepts the '[x, y, ...]' text form; cast it explicitly in SQL
            embedding_str = str(query_embedding)

            base_query = """
                SELECT *,
                       LEAST(
                           1 - (title_embedding <=> CAST(:embedding AS vector)),
                           1 - (content_embedding <=> CAST(:embedding AS vector))
                       ) as similarity_score
                FROM social_media_posts
                WHERE (title_embedding IS NOT NULL OR content_embedding IS NOT NULL)
            """

            params = {"embedding": embedding_str}

            if ticker:
                base_query += " AND :ticker = ANY(tickers)"
                params["ticker"] = ticker.upper()

            base_query += """
                AND LEAST(
                    1 - (title_embedding <=> CAST(:embedding AS vector)),
                    1 - (content_embedding <=> CAST(:embedding AS vector))
                ) >= :threshold
                ORDER BY similarity_score DESC
                LIMIT :limit
            """

            params.update({
                "threshold": similarity_threshold,
                "limit": limit
            })

            result = await session.execute(text(base_query), params)

            posts_with_scores = []
            for row in result.mappings():
                entity = SocialMediaPostEntity(**{k: v for k, v in row.items() if k != 'similarity_score'})
                post = entity.to_domain()
                similarity = row['similarity_score']
                posts_with_scores.append((post, similarity))

            return posts_with_scores

    async def get_sentiment_summary(
        self,
        ticker: Optional[str] = None,
        subreddit: Optional[str] = None,
        hours: int = 24
    ) -> Dict[str, Any]:
        """Get aggregated sentiment analysis for ticker or subreddit"""
        async with self.db_manager.get_session() as session:
            cutoff_date = datetime.now(timezone.utc) - timedelta(hours=hours)

            base_query = """
                SELECT sentiment_label,
                       COUNT(*) as count,
                       AVG((sentiment_score->>'score')::float) as avg_score,
                       AVG((sentiment_score->>'confidence')::float) as avg_confidence,
                       SUM(upvotes) as total_upvotes,
                       SUM(comments_count) as total_comments
                FROM social_media_posts
                WHERE created_utc >= :cutoff_date
                  AND sentiment_score IS NOT NULL
            """

            params = {"cutoff_date": cutoff_date}

            if ticker:
                base_query += " AND :ticker = ANY(tickers)"
                params["ticker"] = ticker.upper()

            if subreddit:
                base_query += " AND subreddit = :subreddit"
                params["subreddit"] = subreddit

            base_query += " GROUP BY sentiment_label"

            result = await session.execute(text(base_query), params)

            sentiment_counts = {}
            total_posts = 0
            weighted_score = 0
            total_engagement = 0

            for row in result.mappings():
                label = row['sentiment_label']
                count = row['count']
                avg_score = float(row['avg_score'] or 0)
                engagement = (row['total_upvotes'] or 0) + (row['total_comments'] or 0)

                sentiment_counts[label] = {
                    'count': count,
                    'avg_score': avg_score,
                    'avg_confidence': float(row['avg_confidence'] or 0),
                    'engagement': engagement
                }

                total_posts += count
                weighted_score += avg_score * count
                total_engagement += engagement

            return {
                'ticker': ticker,
                'subreddit': subreddit,
                'period_hours': hours,
                'total_posts': total_posts,
                'sentiment_breakdown': sentiment_counts,
                'overall_sentiment': weighted_score / total_posts if total_posts > 0 else 0.0,
                'total_engagement': total_engagement,
                'data_quality': {
                    'posts_with_sentiment': total_posts,
                    'period_start': cutoff_date.isoformat(),
                    'generated_at': datetime.now(timezone.utc).isoformat()
                }
            }

    async def cleanup_old_posts(self, days: int = 90) -> int:
        """Remove posts older than specified days"""
        async with self.db_manager.get_session() as session:
            cutoff_date = datetime.now(timezone.utc) - timedelta(days=days)

            result = await session.execute(
                text("DELETE FROM social_media_posts WHERE created_utc < :cutoff_date"),
                {"cutoff_date": cutoff_date}
            )

            deleted_count = result.rowcount
            await session.commit()

            logger.info(f"Cleaned up {deleted_count} posts older than {days} days")
            return deleted_count

    async def get_trending_tickers(self, hours: int = 24, min_mentions: int = 5) -> List[Dict[str, Any]]:
        """Find trending ticker symbols by mention frequency and sentiment"""
        async with self.db_manager.get_session() as session:
            cutoff_date = datetime.now(timezone.utc) - timedelta(hours=hours)

            result = await session.execute(
                text("""
                    SELECT unnest(tickers) as ticker,
                           COUNT(*) as mention_count,
                           AVG((sentiment_score->>'score')::float) as avg_sentiment,
                           SUM(upvotes) as total_upvotes,
                           SUM(comments_count) as total_comments
                    FROM social_media_posts
                    WHERE created_utc >= :cutoff_date
                      AND sentiment_score IS NOT NULL
                      AND array_length(tickers, 1) > 0
                    GROUP BY ticker
                    HAVING COUNT(*) >= :min_mentions
                    ORDER BY mention_count DESC, avg_sentiment DESC
                    LIMIT 20
                """),
                {
                    "cutoff_date": cutoff_date,
                    "min_mentions": min_mentions
                }
            )

            trending = []
            for row in result.mappings():
                trending.append({
                    'ticker': row['ticker'],
                    'mention_count': row['mention_count'],
                    'avg_sentiment': float(row['avg_sentiment'] or 0),
                    'total_upvotes': row['total_upvotes'] or 0,
                    'total_comments': row['total_comments'] or 0,
                    'engagement_score': (row['total_upvotes'] or 0) + (row['total_comments'] or 0)
                })

            return trending
```

**Acceptance Criteria:**
- [ ] Batch upsert operations with proper deduplication
- [ ] Vector similarity search using pgvectorscale indexes
- [ ] Efficient ticker-based queries with TimescaleDB optimization
- [ ] Comprehensive sentiment aggregation with engagement metrics
- [ ] Data cleanup operations with configurable retention
- [ ] Trending ticker analysis with minimum mention
thresholds
- [ ] Proper error handling and logging throughout

**Dependencies:** Task 1.1 (database schema), Task 1.2 (entity model)

**Risk:** Medium - Complex vector search queries

---

## Phase 2: API Integration & Processing (12 hours)

### Task 2.1: Reddit Client Implementation (4 hours)
**Priority: Blocking** | **Agent: API Integration Specialist**

Implement RedditClient using PRAW with comprehensive rate limiting, error handling, and financial subreddit focus.

**File:** `tradingagents/domains/socialmedia/clients.py`

**Implementation:**
```python
import praw
import asyncio
from typing import List, Optional, Dict, Any, AsyncIterator
from datetime import datetime, timedelta, timezone
from tradingagents.config import TradingAgentsConfig
import logging
import time
from contextlib import asynccontextmanager

logger = logging.getLogger(__name__)


class RedditClient:
    """PRAW-based Reddit client with rate limiting and error handling"""

    def __init__(self, config: TradingAgentsConfig):
        self.config = config
        self.reddit = None
        self.last_request_time = 0
        self.min_request_interval = 1.0  # 1 second between requests
        self.financial_subreddits = [
            'wallstreetbets', 'investing', 'stocks', 'SecurityAnalysis',
            'ValueInvesting', 'financialindependence', 'StockMarket',
            'options', 'dividends', 'pennystocks'
        ]

    async def __aenter__(self):
        """Async context manager entry"""
        self._initialize_reddit()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit"""
        pass

    def _initialize_reddit(self):
        """Initialize PRAW Reddit instance"""
        try:
            self.reddit = praw.Reddit(
                client_id=self.config.reddit_client_id,
                client_secret=self.config.reddit_client_secret,
                user_agent=self.config.reddit_user_agent,
                check_for_async=False
            )

            # Test authentication
            self.reddit.user.me()
            logger.info("Reddit client initialized successfully")

        except Exception as e:
            logger.error(f"Failed to initialize Reddit client: {e}")
            raise

    async def _rate_limit_delay(self):
        """Implement rate limiting between requests"""
        current_time = time.time()
        time_since_last = current_time - self.last_request_time

        if time_since_last < self.min_request_interval:
            delay = self.min_request_interval - time_since_last
            await asyncio.sleep(delay)

        self.last_request_time = time.time()

    async def fetch_subreddit_posts(
        self,
        subreddit_name: str,
        time_filter: str = 'day',
        limit: int = 50,
        sort_type: str = 'hot'
    ) -> List[Dict[str, Any]]:
        """Fetch posts from a specific subreddit"""
        if not self.reddit:
            self._initialize_reddit()

        await self._rate_limit_delay()

        try:
            subreddit = self.reddit.subreddit(subreddit_name)

            # Get submissions based on sort type
            if sort_type == 'hot':
                submissions = subreddit.hot(limit=limit)
            elif sort_type == 'top':
                submissions = subreddit.top(time_filter=time_filter, limit=limit)
            elif sort_type == 'new':
                submissions = subreddit.new(limit=limit)
            else:
                submissions = subreddit.hot(limit=limit)

            posts = []
            for submission in submissions:
                # Skip removed or deleted posts
                if submission.selftext == '[removed]' or submission.selftext == '[deleted]':
                    continue

                post_data = self._extract_post_data(submission, subreddit_name)
                # _extract_post_data returns None when extraction fails
                if post_data is not None:
                    posts.append(post_data)

            logger.info(f"Fetched {len(posts)} posts from r/{subreddit_name}")
            return posts

        except Exception as e:
            logger.error(f"Error fetching posts from r/{subreddit_name}: {e}")
            return []

    async def fetch_financial_posts_batch(
        self,
        subreddits: Optional[List[str]] = None,
        time_filter: str = 'day',
        posts_per_subreddit: int = 50
    ) -> Dict[str, List[Dict[str, Any]]]:
        """Fetch posts from multiple financial subreddits"""
        if not subreddits:
            subreddits = self.financial_subreddits

        results = {}

        for subreddit_name in subreddits:
            try:
                posts = await self.fetch_subreddit_posts(
                    subreddit_name=subreddit_name,
                    time_filter=time_filter,
                    limit=posts_per_subreddit
                )
                results[subreddit_name] = posts

            except Exception as e:
                logger.error(f"Failed to fetch from r/{subreddit_name}: {e}")
                results[subreddit_name] = []

        total_posts = sum(len(posts) for posts in results.values())
        logger.info(f"Fetched {total_posts} total posts from {len(subreddits)} subreddits")

        return results

    async def search_posts(
        self,
        query: str,
        subreddit_names: Optional[List[str]] = None,
        time_filter: str = 'week',
        limit: int = 25
    ) -> List[Dict[str, Any]]:
        """Search for posts containing specific terms"""
        if not self.reddit:
            self._initialize_reddit()

        if not subreddit_names:
            subreddit_names = self.financial_subreddits

        all_posts = []

        for subreddit_name in subreddit_names:
            await self._rate_limit_delay()

            try:
                subreddit = self.reddit.subreddit(subreddit_name)
                search_results = subreddit.search(
                    query=query,
                    time_filter=time_filter,
                    limit=limit,
                    sort='relevance'
                )

                for submission in search_results:
                    if submission.selftext not in ['[removed]', '[deleted]']:
                        post_data = self._extract_post_data(submission, subreddit_name)
                        if post_data is not None:
                            all_posts.append(post_data)

            except Exception as e:
                logger.error(f"Search error in r/{subreddit_name}: {e}")
                continue

        logger.info(f"Found {len(all_posts)} posts matching query: {query}")
        return all_posts

    def _extract_post_data(self, submission: Any, subreddit_name: str) -> Optional[Dict[str, Any]]:
        """Extract structured data from PRAW submission; returns None on failure"""
        try:
            ups = getattr(submission, 'ups', submission.score)
            return {
                'post_id': submission.id,
                'title': submission.title[:300],  # Limit title length
                'content': submission.selftext if submission.selftext else None,
                'author': str(submission.author) if submission.author else '[deleted]',
                'subreddit': subreddit_name,
                # created_utc is epoch seconds; keep it timezone-aware in UTC
                'created_utc': datetime.fromtimestamp(submission.created_utc, tz=timezone.utc),
                'upvotes': ups,
                # score = ups - downs, so downs = ups - score
                'downvotes': max(0, ups - submission.score),
                'comments_count': submission.num_comments,
                'url': f"https://reddit.com{submission.permalink}",
                'reddit_score': submission.score,
                'upvote_ratio': getattr(submission, 'upvote_ratio', 0.5),
                'is_self': submission.is_self,
                'domain': submission.domain,
                'flair_text': getattr(submission, 'link_flair_text', None)
            }
        except Exception as e:
            logger.error(f"Error extracting post data: {e}")
            return None

    async def get_post_details(self, post_id: str) -> Optional[Dict[str, Any]]:
        """Get detailed information for a specific post"""
        if not self.reddit:
            self._initialize_reddit()

        await self._rate_limit_delay()

        try:
            submission = self.reddit.submission(id=post_id)
            return self._extract_post_data(submission, submission.subreddit.display_name)

        except Exception as e:
            logger.error(f"Error fetching post details for {post_id}: {e}")
            return None

    async def health_check(self) -> bool:
        """Check if Reddit API is accessible"""
        try:
            if not self.reddit:
                self._initialize_reddit()

            # Listings are lazy: consume one item so a request is actually sent
            next(iter(self.reddit.subreddit('wallstreetbets').hot(limit=1)), None)
            return True

        except Exception as e:
            logger.error(f"Reddit health check failed: {e}")
            return False
```

**Testing Implementation:**
```python
# tests/domains/socialmedia/test_reddit_client.py
import pytest
from unittest.mock import MagicMock, patch
from tradingagents.domains.socialmedia.clients import RedditClient
from tradingagents.config import TradingAgentsConfig

# reddit_client and trading_config are assumed to be fixtures defined in conftest.py


@pytest.mark.vcr  # pytest-vcr marker; by default the cassette is named after the test
@pytest.mark.asyncio
async def test_fetch_subreddit_posts(reddit_client, trading_config):
    """Test fetching posts from a specific subreddit"""
    async with reddit_client:
        posts = await reddit_client.fetch_subreddit_posts('wallstreetbets', limit=10)

        assert len(posts) > 0
        for post in posts:
            assert 'post_id' in post
            assert 'title' in post
            assert 'subreddit' in post
            assert post['subreddit'] == 'wallstreetbets'
```

**Acceptance Criteria:**
- [ ] PRAW Reddit client properly authenticated and initialized
- [ ] Rate limiting implemented (at most 1 request per second)
- [ ] Comprehensive error handling for network issues and API limits
- [ ] Financial subreddit focus with configurable subreddit lists
- [ ] Structured data extraction from Reddit submissions
- [ ] Search functionality across multiple subreddits
- [ ] Health check capabilities for monitoring
- [ ] Test coverage with
pytest-vcr cassettes

**Dependencies:** Reddit API credentials in TradingAgentsConfig

**Risk:** High - External API dependency, rate limiting complexity

---

### Task 2.2: OpenRouter LLM Sentiment Analysis (3 hours)
**Priority: Medium** | **Agent: LLM Integration Specialist**

Implement sentiment analysis using OpenRouter with social media-specific prompts and structured output parsing.

**File:** `tradingagents/domains/socialmedia/sentiment.py`

**Implementation:**
```python
from typing import Optional, Dict, Any, List
import json
import asyncio
from tradingagents.llm.openrouter_client import OpenRouterClient
from tradingagents.config import TradingAgentsConfig
from tradingagents.domains.socialmedia.models import SentimentScore
import logging

logger = logging.getLogger(__name__)

class SocialSentimentAnalyzer:
    """OpenRouter-based sentiment analysis for social media posts"""

    def __init__(self, config: TradingAgentsConfig):
        self.config = config
        self.client = OpenRouterClient(config)
        self.batch_size = 5  # Process posts in batches

    async def analyze_post_sentiment(self, post_text: str, ticker: Optional[str] = None) -> Optional[SentimentScore]:
        """Analyze sentiment of a single social media post"""
        prompt = self._create_sentiment_prompt(post_text, ticker)

        try:
            response = await self.client.generate_response(
                model=self.config.quick_think_llm,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=150,
                temperature=0.1,
                response_format={"type": "json_object"}
            )

            result = json.loads(response)

            return SentimentScore(
                sentiment=result.get('sentiment', 'neutral'),
                confidence=float(result.get('confidence', 0.0)),
                reasoning=result.get('reasoning')
            )

        except Exception as e:
            logger.error(f"Sentiment analysis failed: {e}")
            return None

    async def analyze_batch(
        self,
        posts: List[Dict[str, Any]],
        include_ticker: bool = True
    ) -> List[Optional[SentimentScore]]:
        """Analyze sentiment for multiple posts with rate limiting"""
        results = []

        for i in range(0, len(posts), self.batch_size):
            batch = posts[i:i + self.batch_size]
            batch_tasks = []

            for post in batch:
                text = self._combine_post_text(post)
                ticker = None
                if include_ticker and 'tickers' in post and post['tickers']:
                    ticker = post['tickers'][0]  # Use first ticker if available

                task = self.analyze_post_sentiment(text, ticker)
                batch_tasks.append(task)

            # Process batch with concurrency limit
            batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)

            for result in batch_results:
                if isinstance(result, Exception):
                    logger.error(f"Batch sentiment analysis error: {result}")
                    results.append(None)
                else:
                    results.append(result)

            # Rate limiting between batches
            if i + self.batch_size < len(posts):
                await asyncio.sleep(1.0)

        successful_count = sum(1 for r in results if r is not None)
        logger.info(f"Sentiment analysis completed: {successful_count}/{len(posts)} successful")
        return results

    def _create_sentiment_prompt(self, text: str, ticker: Optional[str] = None) -> str:
        """Create social media-specific sentiment analysis prompt"""
        ticker_context = f" for ticker ${ticker}" if ticker else ""

        return f"""
Analyze the financial sentiment of this Reddit post{ticker_context}.

Consider:
- Trading/investment sentiment (not general mood)
- Informal language, slang, and memes common in financial social media
- Context clues like "diamond hands", "to the moon", "bearish", etc.
- Overall market outlook expressed in the post

Post text: "{text}"

Respond with JSON only:
{{
    "sentiment": "positive|negative|neutral",
    "confidence": 0.0-1.0,
    "reasoning": "brief explanation of key factors"
}}

Guidelines:
- "positive": Bullish, optimistic about price/performance
- "negative": Bearish, pessimistic about price/performance
- "neutral": Mixed signals or no clear directional sentiment
- Confidence: How certain are you? (0.5+ for reliable sentiment)
- Reasoning: Key words/phrases that influenced the classification
""".strip()

    def _combine_post_text(self, post: Dict[str, Any]) -> str:
        """Combine title and content for sentiment analysis"""
        title = post.get('title', '')
        content = post.get('content', '')

        if content:
            # Limit total text length for efficient processing
            combined = f"{title} {content}"[:1000]
        else:
            combined = title

        return combined.strip()

    async def analyze_market_sentiment(
        self,
        posts: List[Dict[str, Any]],
        ticker: str
    ) -> Dict[str, Any]:
        """Analyze overall market sentiment for a ticker from multiple posts"""
        sentiments = await self.analyze_batch(posts, include_ticker=True)

        # Filter out failed analyses
        valid_sentiments = [s for s in sentiments if s is not None and s.confidence >= 0.5]

        if not valid_sentiments:
            return {
                'ticker': ticker,
                'overall_sentiment': 'neutral',
                'confidence': 0.0,
                'post_count': len(posts),
                'analysis_success_rate': 0.0,
                'sentiment_distribution': {'positive': 0, 'negative': 0, 'neutral': 0}
            }

        # Calculate sentiment distribution
        sentiment_counts = {'positive': 0, 'negative': 0, 'neutral': 0}
        confidence_sum = 0

        for sentiment in valid_sentiments:
            sentiment_counts[sentiment.sentiment] += 1
            confidence_sum += sentiment.confidence

        # Determine overall sentiment
        total_valid = len(valid_sentiments)
        positive_ratio = sentiment_counts['positive'] / total_valid
        negative_ratio = sentiment_counts['negative'] / total_valid

        if positive_ratio > 0.6:
            overall_sentiment = 'positive'
        elif negative_ratio > 0.6:
            overall_sentiment = 'negative'
        else:
            overall_sentiment = 'neutral'

        return {
            'ticker': ticker,
            'overall_sentiment': overall_sentiment,
            'confidence': confidence_sum / total_valid,
            'post_count': len(posts),
            'analyzed_posts': total_valid,
            'analysis_success_rate': total_valid / len(posts),
            'sentiment_distribution': sentiment_counts,
            'positive_ratio': positive_ratio,
            'negative_ratio': negative_ratio,
            'neutral_ratio': sentiment_counts['neutral'] / total_valid
        }
```
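
The aggregation rule in `analyze_market_sentiment` can be checked in isolation. The following is a minimal synchronous sketch of that rule, assuming plain `(label, confidence)` tuples in place of `SentimentScore` objects. The name `aggregate_sentiment`, the tuple shape, and the choice to count unknown labels as neutral are illustrative assumptions, not part of the module.

```python
from typing import Dict, List, Optional, Tuple

# Hypothetical stand-in for SentimentScore: (label, confidence).
# None marks a post whose analysis failed.
Scored = Optional[Tuple[str, float]]

def aggregate_sentiment(
    scores: List[Scored],
    min_confidence: float = 0.5,
    majority: float = 0.6,
) -> Dict[str, object]:
    """Apply the same thresholds as analyze_market_sentiment to plain tuples."""
    # Drop failed analyses and low-confidence results.
    valid = [s for s in scores if s is not None and s[1] >= min_confidence]
    counts = {"positive": 0, "negative": 0, "neutral": 0}
    if not valid:
        return {"overall_sentiment": "neutral", "confidence": 0.0, "distribution": counts}

    for label, _ in valid:
        # Assumption: labels outside the three expected values count as neutral.
        counts[label if label in counts else "neutral"] += 1

    total = len(valid)
    if counts["positive"] / total > majority:
        overall = "positive"
    elif counts["negative"] / total > majority:
        overall = "negative"
    else:
        overall = "neutral"

    return {
        "overall_sentiment": overall,
        "confidence": sum(conf for _, conf in valid) / total,
        "distribution": counts,
    }

example = aggregate_sentiment(
    [("positive", 0.9), ("positive", 0.8), ("negative", 0.7), None, ("positive", 0.3)]
)
```

In this example the failed analysis and the 0.3-confidence post are excluded, leaving two positive posts out of three. That ratio is above the 0.6 majority threshold, so the overall label is positive.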
**Acceptance Criteria:**
- [ ] OpenRouter integration for sentiment analysis with structured JSON output
- [ ] Social media-specific prompts handling informal language and financial slang
- [ ] Batch processing with rate limiting and error handling
- [ ] Confidence scoring for sentiment reliability
- [ ] Market sentiment aggregation across multiple posts
- [ ] Comprehensive error handling and logging
- [ ] Test coverage with mocked LLM responses

**Dependencies:** OpenRouter client implementation

**Risk:** Medium - LLM API reliability and cost management

---

### Task 2.3: Vector Embedding Generation (2 hours)
**Priority: Medium** | **Agent: ML Integration Specialist**

Implement vector embedding generation for semantic similarity search using OpenRouter embedding models.

**File:** `tradingagents/domains/socialmedia/embeddings.py`

**Implementation:**
```python
from typing import List, Optional, Dict, Any
import asyncio
import numpy as np
from tradingagents.llm.openrouter_client import OpenRouterClient
from tradingagents.config import TradingAgentsConfig
import logging

logger = logging.getLogger(__name__)

class SocialEmbeddingGenerator:
    """Generate vector embeddings for social media posts using OpenRouter"""

    def __init__(self, config: TradingAgentsConfig):
        self.config = config
        self.client = OpenRouterClient(config)
        # text-embedding-3-small returns 1536 dimensions, matching VECTOR(1536) in the schema
        # (text-embedding-3-large defaults to 3072 and would fail the dimension check below)
        self.embedding_model = "text-embedding-3-small"
        self.max_text_length = 8000  # Character cap applied before embedding (the model limit is in tokens)
        self.batch_size = 10

    async def generate_post_embeddings(
        self,
        post: Dict[str, Any]
    ) -> Dict[str, Optional[List[float]]]:
        """Generate embeddings for post title and content separately"""
        embeddings = {
            'title_embedding': None,
            'content_embedding': None
        }

        # Generate title embedding (`or ''` guards against an explicit None value)
        title = (post.get('title') or '').strip()
        if title:
            embeddings['title_embedding'] = await self._generate_embedding(title)

        # Generate content embedding if content exists
        content = (post.get('content') or '').strip()
        if content:
            # Combine title and content for content embedding
            combined_text = f"{title} {content}"[:self.max_text_length]
            embeddings['content_embedding'] = await self._generate_embedding(combined_text)

        return embeddings

    async def generate_batch_embeddings(
        self,
        posts: List[Dict[str, Any]]
    ) -> List[Dict[str, Optional[List[float]]]]:
        """Generate embeddings for multiple posts with batching"""
        results = []

        for i in range(0, len(posts), self.batch_size):
            batch = posts[i:i + self.batch_size]

            # Create tasks for concurrent processing
            tasks = [self.generate_post_embeddings(post) for post in batch]
            batch_results = await asyncio.gather(*tasks, return_exceptions=True)

            for result in batch_results:
                if isinstance(result, Exception):
                    logger.error(f"Embedding generation error: {result}")
                    results.append({'title_embedding': None, 'content_embedding': None})
                else:
                    results.append(result)

            # Rate limiting between batches
            if i + self.batch_size < len(posts):
                await asyncio.sleep(0.5)

        successful_count = sum(
            1 for r in results
            if r.get('title_embedding') is not None or r.get('content_embedding') is not None
        )
        logger.info(f"Embedding generation completed: {successful_count}/{len(posts)} successful")
        return results

    async def generate_query_embedding(self, query: str) -> Optional[List[float]]:
        """Generate embedding for search query"""
        return await self._generate_embedding(query[:self.max_text_length])

    async def _generate_embedding(self, text: str) -> Optional[List[float]]:
        """Generate single embedding using OpenRouter"""
        if not text.strip():
            return None

        try:
            response = await self.client.create_embeddings(
                model=self.embedding_model,
                input=[text],
                encoding_format="float"
            )

            if response and response.data:
                embedding = response.data[0].embedding

                # Validate embedding dimensions
                if len(embedding) != 1536:
                    logger.error(f"Unexpected embedding dimension: {len(embedding)}")
                    return None

                return embedding

        except Exception as e:
            logger.error(f"Embedding generation failed for text: {e}")
            return None

        return None

    def calculate_similarity(
        self,
        embedding1: List[float],
        embedding2: List[float]
    ) -> float:
        """Calculate cosine similarity between two embeddings"""
        try:
            # Convert to numpy arrays for efficient computation
            vec1 = np.array(embedding1)
            vec2 = np.array(embedding2)

            # Cosine similarity: dot product / (magnitude1 * magnitude2)
            dot_product = np.dot(vec1, vec2)
            magnitude1 = np.linalg.norm(vec1)
            magnitude2 = np.linalg.norm(vec2)

            if magnitude1 == 0 or magnitude2 == 0:
                return 0.0

            similarity = dot_product / (magnitude1 * magnitude2)
            return float(similarity)

        except Exception as e:
            logger.error(f"Similarity calculation error: {e}")
            return 0.0

    def find_most_similar(
        self,
        query_embedding: List[float],
        post_embeddings: List[Dict[str, Any]],
        top_k: int = 10
    ) -> List[Dict[str, Any]]:
        """Find most similar posts to query embedding"""
        similarities = []

        for i, post_data in enumerate(post_embeddings):
            max_similarity = 0.0
            best_embedding_type = None

            # Check title embedding similarity
            title_emb = post_data.get('title_embedding')
            if title_emb:
                title_sim = self.calculate_similarity(query_embedding, title_emb)
                if title_sim > max_similarity:
                    max_similarity = title_sim
                    best_embedding_type = 'title'

            # Check content embedding similarity
            content_emb = post_data.get('content_embedding')
            if content_emb:
                content_sim = self.calculate_similarity(query_embedding, content_emb)
                if content_sim > max_similarity:
                    max_similarity = content_sim
                    best_embedding_type = 'content'

            if max_similarity > 0:
                similarities.append({
                    'post_index': i,
                    'similarity_score': max_similarity,
                    'embedding_type': best_embedding_type,
                    'post_data': post_data
                })

        # Sort by similarity score and return top k
        similarities.sort(key=lambda x: x['similarity_score'], reverse=True)
        return similarities[:top_k]

    async def create_semantic_clusters(
        self,
        posts: List[Dict[str, Any]],
        similarity_threshold: float = 0.8
    ) -> List[List[Dict[str, Any]]]:
        """Group similar posts into semantic clusters"""
        if not posts:
            return []

        # Generate embeddings for all posts
        embeddings_data = await self.generate_batch_embeddings(posts)

        # Combine posts with their embeddings
        posts_with_embeddings = []
        for post, embeddings in zip(posts, embeddings_data):
            if embeddings.get('title_embedding') or embeddings.get('content_embedding'):
                posts_with_embeddings.append({**post, **embeddings})

        clusters = []
        processed = set()

        for i, post in enumerate(posts_with_embeddings):
            if i in processed:
                continue

            current_cluster = [post]
            processed.add(i)

            # Find similar posts for current cluster
            for j, other_post in enumerate(posts_with_embeddings):
                if j in processed or i == j:
                    continue

                # Calculate similarity between posts
                max_sim = 0.0

                # Compare all embedding combinations
                for emb1_type in ['title_embedding', 'content_embedding']:
                    for emb2_type in ['title_embedding', 'content_embedding']:
                        emb1 = post.get(emb1_type)
                        emb2 = other_post.get(emb2_type)

                        if emb1 and emb2:
                            sim = self.calculate_similarity(emb1, emb2)
                            max_sim = max(max_sim, sim)

                if max_sim >= similarity_threshold:
                    current_cluster.append(other_post)
                    processed.add(j)

            if len(current_cluster) > 1:  # Only include clusters with multiple posts
                clusters.append(current_cluster)

        logger.info(f"Created {len(clusters)} semantic clusters from {len(posts)} posts")
        return clusters
```

**Acceptance Criteria:**
- [ ] Vector embedding generation for post titles and content separately
- [ ] Batch processing with rate limiting for efficiency
- [ ] Cosine similarity calculation for semantic search
- [ ] Query embedding generation for search functionality
- [ ] Semantic clustering capabilities for related post discovery
- [ ] Proper error handling and dimension validation
- [ ] Test coverage with mocked embedding responses

**Dependencies:** OpenRouter client with embedding support

**Risk:** Low - Standard embedding generation patterns

---

### Task 2.4: Service Layer Implementation (3 hours)
**Priority: Medium** | **Agent: Service Integration Specialist**

Implement SocialMediaService that orchestrates Reddit collection, sentiment analysis, and embedding
generation.

**File:** `tradingagents/domains/socialmedia/services.py`

**Implementation:**
```python
from typing import List, Optional, Dict, Any, Tuple
import asyncio
import logging
from datetime import datetime, timedelta
from tradingagents.domains.socialmedia.clients import RedditClient
from tradingagents.domains.socialmedia.repositories import SocialRepository
from tradingagents.domains.socialmedia.sentiment import SocialSentimentAnalyzer
from tradingagents.domains.socialmedia.embeddings import SocialEmbeddingGenerator
from tradingagents.domains.socialmedia.models import SocialPost, SocialContext
from tradingagents.config import TradingAgentsConfig
from tradingagents.database import DatabaseManager

logger = logging.getLogger(__name__)

class SocialMediaService:
    """Orchestrates social media data collection, analysis, and storage"""

    def __init__(self, config: TradingAgentsConfig, db_manager: DatabaseManager):
        self.config = config
        self.db_manager = db_manager
        self.repository = SocialRepository(db_manager)
        self.sentiment_analyzer = SocialSentimentAnalyzer(config)
        self.embedding_generator = SocialEmbeddingGenerator(config)

        # Configuration
        self.financial_subreddits = [
            'wallstreetbets', 'investing', 'stocks', 'SecurityAnalysis',
            'ValueInvesting', 'financialindependence', 'StockMarket'
        ]
        self.min_score_threshold = 10  # Minimum upvotes
        self.max_posts_per_subreddit = 50

    async def collect_and_process_posts(
        self,
        subreddits: Optional[List[str]] = None,
        time_filter: str = 'day',
        process_sentiment: bool = True,
        generate_embeddings: bool = True
    ) -> Dict[str, Any]:
        """Main entry point for collecting and processing social media posts"""
        if not subreddits:
            subreddits = self.financial_subreddits

        collection_start = datetime.now()
        logger.info(f"Starting social media collection from {len(subreddits)} subreddits")

        async with RedditClient(self.config) as reddit_client:
            # Collect raw posts from Reddit
            raw_posts_by_subreddit = await reddit_client.fetch_financial_posts_batch(
                subreddits=subreddits,
                time_filter=time_filter,
                posts_per_subreddit=self.max_posts_per_subreddit
            )

        # Flatten and filter posts
        all_raw_posts = []
        for subreddit, posts in raw_posts_by_subreddit.items():
            filtered_posts = [
                post for post in posts
                if post and post.get('reddit_score', 0) >= self.min_score_threshold
            ]
            all_raw_posts.extend(filtered_posts)

        logger.info(f"Collected {len(all_raw_posts)} posts meeting quality thresholds")

        # Convert to domain objects and extract tickers
        domain_posts = []
        for raw_post in all_raw_posts:
            try:
                post = SocialPost(**raw_post)
                post.tickers = post.extract_tickers()  # Extract tickers from content
                domain_posts.append(post)
            except Exception as e:
                logger.error(f"Error creating domain object: {e}")
                continue

        # Process sentiment analysis if requested
        if process_sentiment and domain_posts:
            await self._process_sentiment_analysis(domain_posts)

        # Generate embeddings if requested
        if generate_embeddings and domain_posts:
            await self._process_embeddings(domain_posts)

        # Save to database
        saved_post_ids = await self.repository.upsert_batch(domain_posts)

        collection_end = datetime.now()
        processing_time = (collection_end - collection_start).total_seconds()

        # Calculate success metrics
        results = {
            'collection_timestamp': collection_start.isoformat(),
            'processing_time_seconds': processing_time,
            'subreddits_processed': subreddits,
            'total_posts_collected': len(all_raw_posts),
            'posts_processed': len(domain_posts),
            'posts_saved': len(saved_post_ids),
            'sentiment_analysis_enabled': process_sentiment,
            'embeddings_enabled': generate_embeddings,
            'subreddit_breakdown': {}
        }

        # Add per-subreddit breakdown (same filter as above, including the None guard)
        for subreddit, posts in raw_posts_by_subreddit.items():
            results['subreddit_breakdown'][subreddit] = {
                'posts_collected': len(posts),
                'posts_filtered': len([
                    p for p in posts
                    if p and p.get('reddit_score', 0) >= self.min_score_threshold
                ])
            }

        logger.info(f"Collection completed: {len(saved_post_ids)} posts saved in {processing_time:.2f}s")
        return results

    async def get_social_context(
        self,
        ticker: str,
        days: int = 7,
        include_similar: bool = True,
        similarity_query: Optional[str] = None
    ) -> SocialContext:
        """Get comprehensive social media context for a ticker"""
        logger.info(f"Generating social context for {ticker} ({days} days)")

        # Get direct ticker mentions
        ticker_posts = await self.repository.find_by_ticker(ticker, days=days, limit=50)

        similar_posts = []
        if include_similar and ticker_posts:
            # Use semantic search to find related discussions
            if similarity_query:
                query_embedding = await self.embedding_generator.generate_query_embedding(similarity_query)
                if query_embedding:
                    similar_results = await self.repository.find_similar_posts(
                        query_embedding=query_embedding,
                        ticker=ticker,
                        limit=10
                    )
                    similar_posts = [post for post, score in similar_results]

        # Get sentiment summary
        sentiment_summary = await self.repository.get_sentiment_summary(
            ticker=ticker,
            hours=days * 24
        )

        # Find trending discussions
        trending_tickers = await self.repository.get_trending_tickers(
            hours=days * 24,
            min_mentions=3
        )

        ticker_trend = next(
            (trend for trend in trending_tickers if trend['ticker'] == ticker.upper()),
            None
        )

        return SocialContext(
            ticker=ticker,
            period_days=days,
            direct_mentions=ticker_posts,
            similar_posts=similar_posts,
            sentiment_summary=sentiment_summary,
            trending_info=ticker_trend,
            total_posts=len(ticker_posts) + len(similar_posts),
            data_quality_score=self._calculate_data_quality(ticker_posts + similar_posts)
        )

    async def search_posts_semantic(
        self,
        query: str,
        ticker: Optional[str] = None,
        limit: int = 10,
        min_similarity: float = 0.7
    ) -> List[Tuple[SocialPost, float]]:
        """Semantic search for social media posts"""
        query_embedding = await self.embedding_generator.generate_query_embedding(query)
        if not query_embedding:
            logger.error(f"Failed to generate query embedding for: {query}")
            return []

        return await self.repository.find_similar_posts(
            query_embedding=query_embedding,
            ticker=ticker,
            limit=limit,
            similarity_threshold=min_similarity
        )

    async def get_subreddit_analysis(
        self,
        subreddit: str,
        hours: int = 24
    ) -> Dict[str, Any]:
        """Get analysis of a specific subreddit's activity"""
        posts = await self.repository.find_by_subreddit(subreddit, hours=hours)

        if not posts:
            return {
                'subreddit': subreddit,
                'period_hours': hours,
                'total_posts': 0,
                'message': f'No posts found for r/{subreddit} in the last {hours} hours'
            }

        # Analyze ticker mentions
        ticker_counts = {}
        for post in posts:
            for ticker in post.tickers or []:
                ticker_counts[ticker] = ticker_counts.get(ticker, 0) + 1

        top_tickers = sorted(ticker_counts.items(), key=lambda x: x[1], reverse=True)[:10]

        # Analyze sentiment distribution
        sentiment_counts = {'positive': 0, 'negative': 0, 'neutral': 0}
        reliable_sentiment_count = 0

        for post in posts:
            if post.sentiment_label:
                sentiment_counts[post.sentiment_label] += 1
            if post.has_reliable_sentiment():
                reliable_sentiment_count += 1

        # Calculate engagement metrics
        total_upvotes = sum(post.upvotes for post in posts)
        total_comments = sum(post.comments_count for post in posts)
        avg_score = total_upvotes / len(posts) if posts else 0

        return {
            'subreddit': subreddit,
            'period_hours': hours,
            'total_posts': len(posts),
            'engagement_metrics': {
                'total_upvotes': total_upvotes,
                'total_comments': total_comments,
                'avg_score': avg_score,
                'top_post_score': max(post.upvotes for post in posts) if posts else 0
            },
            'sentiment_analysis': {
                'distribution': sentiment_counts,
                'reliable_sentiment_posts': reliable_sentiment_count,
                'sentiment_reliability': reliable_sentiment_count / len(posts) if posts else 0
            },
            'ticker_mentions': {
                'top_tickers': top_tickers,
                'unique_tickers': len(ticker_counts),
                'total_mentions': sum(ticker_counts.values())
            },
            'data_quality': self._calculate_data_quality(posts)
        }

    async def _process_sentiment_analysis(self, posts: List[SocialPost]) -> None:
        """Process sentiment analysis for posts"""
        logger.info(f"Processing sentiment analysis for {len(posts)} posts")

        # Convert to dict format for sentiment analyzer
        posts_data = []
        for post in posts:
            post_dict = post.dict()
            posts_data.append(post_dict)

        # Analyze sentiment in batches
        sentiments = await self.sentiment_analyzer.analyze_batch(posts_data)

        # Update posts with sentiment results
        for post, sentiment in zip(posts, sentiments):
            if sentiment:
                post.sentiment_score = sentiment.score if hasattr(sentiment, 'score') else None
                post.sentiment_label = sentiment.sentiment
                post.sentiment_confidence = sentiment.confidence
                post.sentiment_reasoning = sentiment.reasoning

        successful_count = sum(1 for s in sentiments if s is not None)
        logger.info(f"Sentiment analysis completed: {successful_count}/{len(posts)} successful")

    async def _process_embeddings(self, posts: List[SocialPost]) -> None:
        """Process embedding generation for posts"""
        logger.info(f"Generating embeddings for {len(posts)} posts")

        # Convert to dict format for embedding generator
        posts_data = []
        for post in posts:
            post_dict = post.dict()
            posts_data.append(post_dict)

        # Generate embeddings in batches
        embeddings = await self.embedding_generator.generate_batch_embeddings(posts_data)

        # Update posts with embedding results
        for post, embedding_data in zip(posts, embeddings):
            post.title_embedding = embedding_data.get('title_embedding')
            post.content_embedding = embedding_data.get('content_embedding')

        successful_count = sum(
            1 for e in embeddings
            if e.get('title_embedding') is not None or e.get('content_embedding') is not None
        )
        logger.info(f"Embedding generation completed: {successful_count}/{len(posts)} successful")

    def _calculate_data_quality(self, posts: List[SocialPost]) -> Dict[str, float]:
        """Calculate data quality metrics for posts"""
        if not posts:
            return {'overall_score': 0.0}

        sentiment_coverage = sum(1 for p in posts if p.sentiment_label is not None) / len(posts)
        reliable_sentiment = sum(1 for p in posts if p.has_reliable_sentiment()) / len(posts)
        embedding_coverage = sum(
            1 for p in posts
            if p.title_embedding is not None or p.content_embedding is not None
        ) / len(posts)
        ticker_extraction = sum(1 for p in posts if p.tickers) / len(posts)

        overall_score = (sentiment_coverage + reliable_sentiment + embedding_coverage + ticker_extraction) / 4

        return {
            'overall_score': overall_score,
            'sentiment_coverage': sentiment_coverage,
            'reliable_sentiment_ratio': reliable_sentiment,
            'embedding_coverage': embedding_coverage,
            'ticker_extraction_ratio': ticker_extraction
        }
```

**Acceptance Criteria:**
- [ ] Orchestrates complete collection, analysis, and storage pipeline
- [ ] Integrates Reddit client, sentiment analyzer, and embedding generator
- [ ] Handles batch processing with proper error handling and logging
- [ ] Provides ticker-specific social context with sentiment and similarity
- [ ] Semantic search capabilities with configurable similarity thresholds
- [ ] Subreddit analysis with engagement and sentiment metrics
- [ ] Data quality scoring and monitoring
- [ ] Comprehensive test coverage with mocked dependencies

**Dependencies:** All Phase 2 tasks (clients, sentiment, embeddings)

**Risk:** Medium - Complex orchestration of multiple async services

---

## Phase 3: Integration & Validation (8 hours)

### Task 3.1: AgentToolkit Integration (3 hours)
**Priority: High** | **Agent: Agent Integration Specialist**

Add RAG-enhanced social media methods to AgentToolkit for AI agent consumption.
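
The output contract for these toolkit methods can be pictured as a pure formatting step that turns a sentiment summary into agent-readable text. The sketch below assumes a summary dict with a float `overall_sentiment` and a `sentiment_breakdown` mapping. The helper name `format_sentiment_summary` and the exact dict shape are illustrative assumptions, not the repository's confirmed schema.

```python
from typing import Any, Dict

def format_sentiment_summary(
    ticker: str, days: int, summary: Dict[str, Any], total_posts: int
) -> str:
    """Render a sentiment summary as plain text for an LLM agent (hypothetical helper)."""
    if not total_posts:
        return f"No Reddit sentiment data found for ${ticker} in the last {days} days."

    score = float(summary.get("overall_sentiment", 0.0))
    # Same +/-0.1 dead band that the toolkit method uses to pick a direction.
    direction = "bullish" if score > 0.1 else "bearish" if score < -0.1 else "neutral"

    lines = [
        f"Reddit Sentiment Analysis for ${ticker} ({days}-day period):",
        f"Overall Sentiment: {score:.2f} ({direction})",
        f"Analysis Coverage: {total_posts} posts analyzed",
    ]
    breakdown = summary.get("sentiment_breakdown", {})
    for label in ("positive", "negative", "neutral"):
        # Missing labels fall back to a count of zero.
        count = breakdown.get(label, {}).get("count", 0)
        lines.append(f"  - {label.capitalize()}: {count} posts")
    return "\n".join(lines)

report = format_sentiment_summary(
    "AAPL",
    7,
    {
        "overall_sentiment": 0.42,
        "sentiment_breakdown": {"positive": {"count": 8}, "negative": {"count": 2}},
    },
    total_posts=12,
)
```

Keeping the formatting separate from the service calls makes the string output easy to unit test without mocking `SocialMediaService`.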
**File:** `tradingagents/agents/libs/agent_toolkit.py` (additions)

**Implementation:**
```python
# Additional methods for AgentToolkit class

async def get_reddit_sentiment(
    self,
    ticker: str,
    days: int = 7,
    include_context: bool = True
) -> str:
    """Get Reddit sentiment analysis for a specific ticker with RAG context"""
    try:
        if not hasattr(self, 'social_service'):
            self.social_service = SocialMediaService(self.config, self.db_manager)

        # Get comprehensive social context
        social_context = await self.social_service.get_social_context(
            ticker=ticker,
            days=days,
            include_similar=include_context
        )

        if not social_context.total_posts:
            return f"No Reddit sentiment data found for ${ticker} in the last {days} days."

        # Format for agent consumption
        sentiment_summary = social_context.sentiment_summary
        trending_info = social_context.trending_info

        context = f"Reddit Sentiment Analysis for ${ticker} ({days}-day period):\n\n"

        # Overall sentiment metrics
        if sentiment_summary:
            overall_score = sentiment_summary.get('overall_sentiment', 0.0)
            sentiment_emoji = "📈" if overall_score > 0.1 else "📉" if overall_score < -0.1 else "➡️"

            context += f"{sentiment_emoji} Overall Sentiment: {overall_score:.2f}/1.0\n"
            context += f"📊 Analysis Coverage: {social_context.total_posts} posts analyzed\n"

            # Sentiment breakdown
            breakdown = sentiment_summary.get('sentiment_breakdown', {})
            if breakdown:
                context += f" • Positive: {breakdown.get('positive', {}).get('count', 0)} posts\n"
                context += f" • Negative: {breakdown.get('negative', {}).get('count', 0)} posts\n"
                context += f" • Neutral: {breakdown.get('neutral', {}).get('count', 0)} posts\n"

        # Trending information
        if trending_info:
            context += f"\n🔥 Trending Status:\n"
            context += f" • Mentions: {trending_info['mention_count']} posts\n"
            context += f" • Engagement: {trending_info['engagement_score']} (upvotes + comments)\n"
            context += f" • Avg Sentiment: {trending_info['avg_sentiment']:.2f}\n"

        # Top discussions (sample posts)
        if social_context.direct_mentions:
            context += f"\n💬 Recent Discussions:\n"
            for i, post in enumerate(social_context.direct_mentions[:5]):
                sentiment_emoji = {"positive": "📈", "negative": "📉", "neutral": "➡️"}.get(
                    post.sentiment_label, "❓"
                )
                context += f"{i+1}. {sentiment_emoji} r/{post.subreddit}: {post.title[:100]}...\n"
                context += f" Score: {post.upvotes} upvotes, {post.comments_count} comments\n"
                if post.has_reliable_sentiment():
                    context += f" Sentiment: {post.sentiment_label} ({post.sentiment_confidence:.2f})\n"

        # Data quality indicators
        quality = social_context.data_quality_score
        context += f"\n📋 Data Quality: {quality.get('overall_score', 0):.1%} coverage\n"

        return context

    except Exception as e:
        logger.error(f"Error getting Reddit sentiment for {ticker}: {e}")
        return f"Error retrieving Reddit sentiment for ${ticker}: {str(e)}"

async def get_reddit_stock_info(
    self,
    ticker: str,
    query: Optional[str] = None,
    days: int = 7
) -> str:
    """Get Reddit stock information with optional semantic search"""
    try:
        if not hasattr(self, 'social_service'):
            self.social_service = SocialMediaService(self.config, self.db_manager)

        context = f"Reddit Stock Information for ${ticker}:\n\n"

        if query:
            # Semantic search for specific information
            search_results = await self.social_service.search_posts_semantic(
                query=query,
                ticker=ticker,
                limit=10,
                min_similarity=0.7
            )

            if search_results:
                context += f"🔍 Semantic Search Results for '{query}':\n"
                for i, (post, similarity) in enumerate(search_results[:5]):
                    context += f"{i+1}. (Similarity: {similarity:.2f}) r/{post.subreddit}\n"
                    context += f" Title: {post.title}\n"
                    if post.content:
                        context += f" Content: {post.content[:150]}...\n"
                    context += f" Engagement: {post.upvotes} upvotes, {post.comments_count} comments\n\n"
            else:
                context += f"🔍 No relevant discussions found for '{query}' about ${ticker}\n\n"

        # Get general stock context
        social_context = await self.social_service.get_social_context(
            ticker=ticker,
            days=days,
            include_similar=False
        )

        if social_context.direct_mentions:
            context += f"📈 Recent Stock Discussions ({len(social_context.direct_mentions)} posts):\n"

            # Group by subreddit for better organization
            by_subreddit = {}
            for post in social_context.direct_mentions:
                if post.subreddit not in by_subreddit:
                    by_subreddit[post.subreddit] = []
                by_subreddit[post.subreddit].append(post)

            for subreddit, posts in by_subreddit.items():
                context += f"\nr/{subreddit} ({len(posts)} posts):\n"
                for post in posts[:3]:  # Top 3 per subreddit
                    sentiment_info = ""
                    if post.has_reliable_sentiment():
                        sentiment_emoji = {"positive": "📈", "negative": "📉", "neutral": "➡️"}
                        emoji = sentiment_emoji.get(post.sentiment_label, "❓")
                        sentiment_info = f" {emoji} {post.sentiment_label}"

                    context += f" • {post.title[:80]}...{sentiment_info}\n"
                    context += f" {post.upvotes} upvotes, {post.comments_count} comments\n"

        # Add trending context if available
        if social_context.trending_info:
            trend = social_context.trending_info
            context += f"\n📊 Trending Analysis:\n"
            context += f" • Market attention: {trend['mention_count']} recent mentions\n"
            context += f" • Community sentiment: {trend['avg_sentiment']:.2f}/1.0\n"
            context += f" • Total engagement: {trend['engagement_score']}\n"

        return context

    except Exception as e:
        logger.error(f"Error getting Reddit stock info for {ticker}: {e}")
        return f"Error retrieving Reddit stock information for ${ticker}: {str(e)}"

async def search_social_posts(
    self,
    query: str,
    ticker: Optional[str] = None,
    limit: int = 10,
    days: int = 30
) -> str:
    """Search social media posts using semantic similarity"""
    try:
        if not hasattr(self, 'social_service'):
            self.social_service = SocialMediaService(self.config, self.db_manager)

        # Perform semantic search
        search_results = await self.social_service.search_posts_semantic(
            query=query,
            ticker=ticker,
            limit=limit,
            min_similarity=0.6
        )

        if not search_results:
            ticker_context = f" about ${ticker}" if ticker else ""
            return f"No relevant social media posts found for '{query}'{ticker_context}."

        ticker_context = f" (${ticker})" if ticker else ""
        context = f"Social Media Search Results for '{query}'{ticker_context}:\n\n"
        context += f"Found {len(search_results)} relevant posts:\n\n"

        for i, (post, similarity) in enumerate(search_results):
            context += f"{i+1}. Relevance: {similarity:.2%} | r/{post.subreddit}\n"
            context += f" Title: {post.title}\n"

            if post.content:
                # Show relevant snippet
                content_preview = post.content[:200] + "..." if len(post.content) > 200 else post.content
                context += f" Content: {content_preview}\n"

            # Add sentiment if available
            if post.has_reliable_sentiment():
                sentiment_emoji = {"positive": "📈", "negative": "📉", "neutral": "➡️"}.get(
                    post.sentiment_label, "❓"
                )
                context += f" Sentiment: {sentiment_emoji} {post.sentiment_label} ({post.sentiment_confidence:.2f})\n"

            # Add engagement metrics
            context += f" Engagement: {post.upvotes} upvotes, {post.comments_count} comments\n"
            context += f" Posted: {post.created_utc.strftime('%Y-%m-%d %H:%M')} UTC\n\n"

        return context

    except Exception as e:
        logger.error(f"Error searching social posts for '{query}': {e}")
        return f"Error searching social media posts: {str(e)}"

async def get_subreddit_analysis(
    self,
    subreddit: str,
    ticker: Optional[str] = None,
    hours: int = 24
) -> str:
    """Get analysis of activity in a specific financial subreddit"""
    try:
        if not hasattr(self, 'social_service'):
            self.social_service = SocialMediaService(self.config, self.db_manager)

        analysis = await self.social_service.get_subreddit_analysis(subreddit, hours=hours)

        if analysis['total_posts'] == 0:
            return f"No recent activity found in r/{subreddit} in the last {hours} hours."

        context = f"r/{subreddit} Analysis ({hours}-hour period):\n\n"

        # Activity overview
        context += f"📊 Activity Overview:\n"
        context += f" • Total Posts: {analysis['total_posts']}\n"
        context += f" • Total Upvotes: {analysis['engagement_metrics']['total_upvotes']:,}\n"
        context += f" • Total Comments: {analysis['engagement_metrics']['total_comments']:,}\n"
        context += f" • Avg Score: {analysis['engagement_metrics']['avg_score']:.1f}\n"
        context += f" • Top Post Score: {analysis['engagement_metrics']['top_post_score']:,}\n\n"

        # Sentiment analysis
        sentiment_dist = analysis['sentiment_analysis']['distribution']
        reliable_ratio = analysis['sentiment_analysis']['sentiment_reliability']

        context += f"😊 Sentiment Analysis:\n"
        context += f" • Positive: {sentiment_dist['positive']} posts\n"
        context += f" • Negative: {sentiment_dist['negative']} posts\n"
        context += f" • Neutral: {sentiment_dist['neutral']} posts\n"
        context += f" • Reliability: {reliable_ratio:.1%} of posts have confident sentiment scores\n\n"

        # Ticker mentions
        ticker_info = analysis['ticker_mentions']
        context += f"💰 Stock Mentions:\n"
        context += f" • Unique Tickers: {ticker_info['unique_tickers']}\n"
        context += f" • Total Mentions: {ticker_info['total_mentions']}\n"

        if ticker_info['top_tickers']:
            context += f" • Most Discussed:\n"
            for ticker_symbol, count in ticker_info['top_tickers'][:5]:
                context += f" - ${ticker_symbol}: {count} mentions\n"

        # Filter for specific ticker if requested
        if ticker:
            ticker_mentions = next(
                (count for symbol, count in ticker_info['top_tickers'] if symbol == ticker.upper()),
                0
            )
            if ticker_mentions > 0:
                context += f"\n🎯 ${ticker} Activity: {ticker_mentions} mentions in this period\n"
            else:
                context += f"\n🎯 ${ticker}: No mentions found in r/{subreddit} during this period\n"

        # Data quality
        quality = analysis['data_quality']['overall_score']
        context += f"\n📋 Data Quality Score: {quality:.1%}\n"

        return context

    except Exception as e:
        logger.error(f"Error analyzing subreddit {subreddit}: {e}")
        return f"Error analyzing r/{subreddit}: {str(e)}"
```

**Acceptance Criteria:**
- [ ] get_reddit_sentiment() provides comprehensive sentiment analysis with visual formatting
- [ ] get_reddit_stock_info() supports both general info and semantic search queries
- [ ] search_social_posts() enables semantic search across all social media content
- [ ] get_subreddit_analysis() provides detailed subreddit activity and ticker analysis
- [ ] All methods return human-readable formatted strings for AI agent consumption
- [ ] Proper error handling with fallback responses
- [ ] Methods integrate seamlessly with existing AgentToolkit patterns
- [ ] Test coverage with mocked service dependencies

**Dependencies:** Task 2.4 (SocialMediaService implementation)

**Risk:** Low - Standard AgentToolkit integration patterns

---

### Task 3.2: Dagster Pipeline Implementation (2 hours)
**Priority: Medium** | **Agent: Pipeline Specialist**

Implement Dagster asset for scheduled social media collection and processing.
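
A daily-partitioned asset usually maps each partition key to a fixed UTC collection window. The sketch below shows that mapping with the standard library only, assuming the `YYYY-MM-DD` key format that Dagster's daily partitions use by default. The helper names `partition_window` and `retention_cutoff` are hypothetical and are not part of Dagster or this codebase.

```python
from datetime import datetime, timedelta, timezone
from typing import Tuple

def partition_window(partition_key: str) -> Tuple[datetime, datetime]:
    """Return the [start, end) UTC window covered by a daily partition key."""
    start = datetime.strptime(partition_key, "%Y-%m-%d").replace(tzinfo=timezone.utc)
    return start, start + timedelta(days=1)

def retention_cutoff(partition_key: str, retention_days: int = 90) -> datetime:
    """Posts created before this instant fall outside the retention period."""
    start, _ = partition_window(partition_key)
    return start - timedelta(days=retention_days)

start, end = partition_window("2024-03-15")
cutoff = retention_cutoff("2024-03-15", retention_days=90)
```

An explicit window like this is one way to make backfills deterministic, since a relative filter such as `time_filter='day'` always refers to the most recent day regardless of which partition is being materialized.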
**File:** `tradingagents/data/assets/social_media.py` **Implementation:** ```python from dagster import asset, AssetExecutionContext, Config, DailyPartitionsDefinition from typing import Dict, Any, List import asyncio from datetime import datetime, timedelta from tradingagents.domains.socialmedia.services import SocialMediaService from tradingagents.config import TradingAgentsConfig from tradingagents.database import DatabaseManager class SocialMediaCollectionConfig(Config): """Configuration for social media collection""" subreddits: List[str] = [ 'wallstreetbets', 'investing', 'stocks', 'SecurityAnalysis', 'ValueInvesting', 'StockMarket', 'options' ] time_filter: str = 'day' process_sentiment: bool = True generate_embeddings: bool = True max_posts_per_subreddit: int = 50 cleanup_old_data: bool = True retention_days: int = 90 @asset( partitions_def=DailyPartitionsDefinition(start_date="2024-01-01"), group_name="social_media", description="Daily collection of Reddit posts from financial subreddits with sentiment analysis and embeddings", compute_kind="python", tags={"domain": "socialmedia", "source": "reddit"} ) async def reddit_financial_posts( context: AssetExecutionContext, config: SocialMediaCollectionConfig ) -> Dict[str, Any]: """Daily collection and processing of Reddit financial posts""" partition_date = context.partition_key context.log.info(f"Starting social media collection for partition: {partition_date}") # Initialize services trading_config = TradingAgentsConfig.from_env() db_manager = DatabaseManager(trading_config) social_service = SocialMediaService(trading_config, db_manager) collection_start = datetime.now() try: # Main collection and processing results = await social_service.collect_and_process_posts( subreddits=config.subreddits, time_filter=config.time_filter, process_sentiment=config.process_sentiment, generate_embeddings=config.generate_embeddings ) # Log detailed results context.log.info(f"Collection completed successfully:") 
        context.log.info(f" - Total posts collected: {results['total_posts_collected']}")
        context.log.info(f" - Posts processed: {results['posts_processed']}")
        context.log.info(f" - Posts saved: {results['posts_saved']}")
        context.log.info(f" - Processing time: {results['processing_time_seconds']:.2f}s")

        # Log per-subreddit breakdown
        for subreddit, breakdown in results['subreddit_breakdown'].items():
            context.log.info(f" - r/{subreddit}: {breakdown['posts_collected']} collected, "
                             f"{breakdown['posts_filtered']} after filtering")

        # Data quality check
        if results['posts_saved'] == 0:
            context.log.warning("No posts were saved - possible data quality issues")
        elif results['posts_saved'] < results['posts_processed'] * 0.5:
            context.log.warning(f"Low save rate: {results['posts_saved']}/{results['posts_processed']} posts saved")

        # Cleanup old data if configured
        if config.cleanup_old_data:
            try:
                deleted_count = await social_service.repository.cleanup_old_posts(
                    days=config.retention_days
                )
                context.log.info(f"Cleaned up {deleted_count} posts older than {config.retention_days} days")
                results['cleanup_deleted_count'] = deleted_count
            except Exception as e:
                context.log.error(f"Cleanup failed: {e}")
                results['cleanup_error'] = str(e)

        # Add partition metadata
        results.update({
            'partition_date': partition_date,
            'asset_name': 'reddit_financial_posts',
            'collection_success': True
        })

        return results

    except Exception as e:
        context.log.error(f"Social media collection failed: {e}")

        # Return error results for monitoring
        return {
            'partition_date': partition_date,
            'asset_name': 'reddit_financial_posts',
            'collection_success': False,
            'error_message': str(e),
            'processing_time_seconds': (datetime.now() - collection_start).total_seconds(),
            'total_posts_collected': 0,
            'posts_processed': 0,
            'posts_saved': 0
        }

    finally:
        # Always close database connections
        if 'db_manager' in locals():
            await db_manager.close_all()

@asset(
    deps=[reddit_financial_posts],
    group_name="social_media",
    description="Generate daily social 
media analytics and trending analysis",
    compute_kind="python",
    tags={"domain": "socialmedia", "analytics": "trending"}
)
async def social_media_analytics(context: AssetExecutionContext) -> Dict[str, Any]:
    """Generate analytics and trending analysis from collected social media data"""

    context.log.info("Generating social media analytics")

    # Initialize services
    trading_config = TradingAgentsConfig.from_env()
    db_manager = DatabaseManager(trading_config)
    social_service = SocialMediaService(trading_config, db_manager)

    try:
        # Get trending tickers analysis
        trending_tickers = await social_service.repository.get_trending_tickers(
            hours=24,
            min_mentions=5
        )

        context.log.info(f"Found {len(trending_tickers)} trending tickers")

        # Analyze top subreddits
        financial_subreddits = [
            'wallstreetbets', 'investing', 'stocks',
            'SecurityAnalysis', 'ValueInvesting', 'StockMarket'
        ]

        subreddit_analysis = {}
        for subreddit in financial_subreddits:
            analysis = await social_service.get_subreddit_analysis(subreddit, hours=24)
            subreddit_analysis[subreddit] = analysis

            if analysis['total_posts'] > 0:
                context.log.info(f"r/{subreddit}: {analysis['total_posts']} posts, "
                                 f"{analysis['ticker_mentions']['unique_tickers']} unique tickers")

        # Calculate overall sentiment trends
        overall_sentiment_summary = {}
        for ticker_info in trending_tickers[:10]:  # Top 10 trending
            ticker = ticker_info['ticker']
            sentiment_data = await social_service.repository.get_sentiment_summary(
                ticker=ticker,
                hours=24
            )
            overall_sentiment_summary[ticker] = sentiment_data

        analytics_results = {
            'generated_at': datetime.now().isoformat(),
            'period_hours': 24,
            'trending_tickers': trending_tickers,
            'subreddit_analysis': subreddit_analysis,
            'sentiment_trends': overall_sentiment_summary,
            'analytics_success': True
        }

        # Log key insights
        if trending_tickers:
            top_ticker = trending_tickers[0]
            context.log.info(f"Most trending ticker: ${top_ticker['ticker']} "
                             f"({top_ticker['mention_count']} mentions, "
                             f"{top_ticker['avg_sentiment']:.2f} sentiment)")

        return analytics_results

    except Exception as e:
        context.log.error(f"Analytics generation failed: {e}")
        return {
            'generated_at': datetime.now().isoformat(),
            'analytics_success': False,
            'error_message': str(e)
        }

    finally:
        if 'db_manager' in locals():
            await db_manager.close_all()

@asset(
    deps=[social_media_analytics],
    group_name="social_media",
    description="Data quality monitoring and validation for social media pipeline",
    compute_kind="python",
    tags={"domain": "socialmedia", "monitoring": "data_quality"}
)
async def social_media_quality_check(context: AssetExecutionContext) -> Dict[str, Any]:
    """Monitor data quality and pipeline health for social media assets"""

    context.log.info("Performing social media data quality checks")

    trading_config = TradingAgentsConfig.from_env()
    db_manager = DatabaseManager(trading_config)
    social_service = SocialMediaService(trading_config, db_manager)

    try:
        # Check recent data volume
        recent_posts = await social_service.repository.find_by_subreddit(
            'wallstreetbets',  # Use as representative subreddit
            hours=24,
            limit=1000
        )

        # Quality metrics
        total_posts = len(recent_posts)
        posts_with_sentiment = sum(1 for p in recent_posts if p.sentiment_label is not None)
        posts_with_embeddings = sum(
            1 for p in recent_posts
            if p.title_embedding is not None or p.content_embedding is not None
        )
        posts_with_tickers = sum(1 for p in recent_posts if p.tickers)

        # Calculate quality percentages
        sentiment_coverage = posts_with_sentiment / total_posts if total_posts > 0 else 0
        embedding_coverage = posts_with_embeddings / total_posts if total_posts > 0 else 0
        ticker_coverage = posts_with_tickers / total_posts if total_posts > 0 else 0

        # Quality thresholds
        quality_checks = {
            'data_volume_check': total_posts >= 100,  # Expect at least 100 posts per day
            'sentiment_coverage_check': sentiment_coverage >= 0.8,  # 80% should have sentiment
            'embedding_coverage_check': embedding_coverage >= 0.7,  # 70% should have embeddings
            'ticker_coverage_check': ticker_coverage >= 0.3  # 30%
            # should have ticker mentions
        }

        overall_health = all(quality_checks.values())

        # Log quality results
        context.log.info(f"Data quality assessment:")
        context.log.info(f" - Total posts (24h): {total_posts}")
        context.log.info(f" - Sentiment coverage: {sentiment_coverage:.1%}")
        context.log.info(f" - Embedding coverage: {embedding_coverage:.1%}")
        context.log.info(f" - Ticker coverage: {ticker_coverage:.1%}")
        context.log.info(f" - Overall health: {'PASS' if overall_health else 'FAIL'}")

        # Alert on quality issues
        for check_name, passed in quality_checks.items():
            if not passed:
                context.log.warning(f"Quality check failed: {check_name}")

        return {
            'check_timestamp': datetime.now().isoformat(),
            'total_posts_24h': total_posts,
            'quality_metrics': {
                'sentiment_coverage': sentiment_coverage,
                'embedding_coverage': embedding_coverage,
                'ticker_coverage': ticker_coverage
            },
            'quality_checks': quality_checks,
            'overall_health': overall_health,
            'quality_check_success': True
        }

    except Exception as e:
        context.log.error(f"Quality check failed: {e}")
        return {
            'check_timestamp': datetime.now().isoformat(),
            'quality_check_success': False,
            'error_message': str(e)
        }

    finally:
        if 'db_manager' in locals():
            await db_manager.close_all()

# Schedule configuration for the social media pipeline (cron expressions, UTC)
SOCIAL_MEDIA_SCHEDULE = {
    "reddit_financial_posts": "0 6,18 * * *",      # 06:00 and 18:00 UTC daily
    "social_media_analytics": "30 7,19 * * *",     # 07:30 and 19:30 UTC, 90 minutes after collection starts
    "social_media_quality_check": "0 8,20 * * *"   # 08:00 and 20:00 UTC, 2 hours after collection starts
}
```

**Acceptance Criteria:**
- [ ] Daily scheduled collection from financial subreddits
- [ ] Sentiment analysis and embedding generation in pipeline
- [ ] Analytics generation with trending ticker analysis
- [ ] Data quality monitoring with configurable thresholds
- [ ] Proper error handling and logging throughout pipeline
- [ ] Cleanup of old data based on retention policies
- [ ] Integration with existing Dagster infrastructure
- [ ] Monitoring and alerting on pipeline failures
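
The acceptance criteria call for configurable thresholds, while the quality-check asset above compares against literal values. One possible way to lift those values into configuration is sketched below. `QualityThresholds` and `evaluate_quality` are hypothetical names introduced for illustration, not part of the existing codebase; the defaults simply mirror the numbers already used in `social_media_quality_check`.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass(frozen=True)
class QualityThresholds:
    """Hypothetical threshold container; defaults mirror the asset's literals."""
    min_posts: int = 100
    min_sentiment_coverage: float = 0.8
    min_embedding_coverage: float = 0.7
    min_ticker_coverage: float = 0.3


def evaluate_quality(
    total_posts: int,
    posts_with_sentiment: int,
    posts_with_embeddings: int,
    posts_with_tickers: int,
    thresholds: QualityThresholds = QualityThresholds(),
) -> Dict[str, bool]:
    """Return one pass/fail flag per quality check."""

    def coverage(count: int) -> float:
        # Guard against division by zero when no posts were collected
        return count / total_posts if total_posts > 0 else 0.0

    return {
        'data_volume_check': total_posts >= thresholds.min_posts,
        'sentiment_coverage_check': coverage(posts_with_sentiment) >= thresholds.min_sentiment_coverage,
        'embedding_coverage_check': coverage(posts_with_embeddings) >= thresholds.min_embedding_coverage,
        'ticker_coverage_check': coverage(posts_with_tickers) >= thresholds.min_ticker_coverage,
    }
```

Keeping the evaluation in a pure function like this would also let the thresholds be unit-tested without a database or a Dagster context; how the thresholds are supplied to the asset (for example through a Dagster `Config` class) is left open here.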

**Dependencies:** Task 2.4 (SocialMediaService)
**Risk:** Low - Standard Dagster asset patterns

---

### Task 3.3: Comprehensive Testing Suite (3 hours)
**Priority: High** | **Agent: Testing Specialist**

Implement comprehensive test suite covering all socialmedia domain components with >85% coverage.

**Test Structure:**
```
tests/domains/socialmedia/
├── conftest.py                    # Fixtures and test configuration
├── test_entities.py               # SQLAlchemy entity tests
├── test_models.py                 # Domain model validation tests
├── test_reddit_client.py          # API integration with VCR
├── test_sentiment_analyzer.py     # LLM sentiment analysis
├── test_embedding_generator.py    # Vector embedding generation
├── test_social_repository.py      # Database operations
├── test_social_service.py         # Service orchestration
├── test_agent_toolkit.py          # AgentToolkit integration
├── test_dagster_assets.py         # Pipeline testing
└── fixtures/
    ├── reddit_responses.yaml      # VCR cassettes
    ├── sample_posts.json          # Test data
    └── embeddings.json            # Sample embeddings
```

**Implementation Samples:**

**conftest.py:**
```python
import pytest
import asyncio
from datetime import datetime  # used by the sample_social_post fixture
from unittest.mock import MagicMock, AsyncMock
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from tradingagents.config import TradingAgentsConfig
from tradingagents.database import DatabaseManager
from tradingagents.domains.socialmedia.entities import SocialMediaPostEntity
from tradingagents.domains.socialmedia.models import SocialPost
from tradingagents.domains.socialmedia.services import SocialMediaService

@pytest.fixture(scope="session")
def event_loop():
    """Create event loop for async tests"""
    loop = asyncio.new_event_loop()
    yield loop
    loop.close()

@pytest.fixture
def test_config():
    """Test configuration"""
    return TradingAgentsConfig(
        reddit_client_id="test_client_id",
        reddit_client_secret="test_secret",
        reddit_user_agent="test_agent",
        openrouter_api_key="test_openrouter_key",
        quick_think_llm="test/model",
        database_url="sqlite:///test.db"
    )

@pytest.fixture
def db_session(test_config):
    """Test database session (synchronous engine, so a plain fixture suffices)"""
    engine = create_engine(test_config.database_url, echo=False)
    SocialMediaPostEntity.metadata.create_all(engine)

    SessionLocal = sessionmaker(bind=engine)
    session = SessionLocal()

    yield session

    session.close()
    SocialMediaPostEntity.metadata.drop_all(engine)

@pytest.fixture
def sample_social_post():
    """Sample SocialPost for testing"""
    return SocialPost(
        post_id="test123",
        title="AAPL to the moon! 🚀",
        content="Apple stock is going to explode higher after earnings!",
        author="test_user",
        subreddit="wallstreetbets",
        created_utc=datetime(2024, 1, 15, 10, 0, 0),
        upvotes=150,
        downvotes=25,
        comments_count=45,
        url="https://reddit.com/r/wallstreetbets/test123",
        tickers=["AAPL"],
        sentiment_score=0.8,
        sentiment_label="positive",
        sentiment_confidence=0.9
    )

@pytest.fixture
def mock_social_service(test_config):
    """Mocked SocialMediaService"""
    service = MagicMock(spec=SocialMediaService)
    service.config = test_config
    service.repository = AsyncMock()
    service.sentiment_analyzer = AsyncMock()
    service.embedding_generator = AsyncMock()
    return service
```

**test_models.py:**
```python
import pytest
from datetime import datetime
from tradingagents.domains.socialmedia.models import SocialPost, SentimentScore

def test_social_post_validation():
    """Test SocialPost validation rules"""
    # Valid post
    post = SocialPost(
        post_id="abc123",
        title="Test post",
        author="test_user",
        subreddit="stocks",
        created_utc=datetime.now(),
        upvotes=10,
        downvotes=2,
        comments_count=5,
        url="https://reddit.com/test"
    )
    assert post.post_id == "abc123"
    assert post.tickers == []

def test_extract_tickers():
    """Test ticker extraction from post content"""
    post = SocialPost(
        post_id="abc123",
        title="AAPL and $TSLA are great buys",
        content="I think MSFT will outperform this year",
        author="test_user",
        subreddit="investing",
        created_utc=datetime.now(),
        upvotes=10,
        downvotes=0,
        comments_count=3,
        url="https://reddit.com/test"
    )

    tickers = post.extract_tickers()
    assert "AAPL" in tickers
    assert "TSLA" in tickers
    assert "MSFT" in tickers
    assert len(tickers) == 3

def test_sentiment_validation():
    """Test sentiment score validation"""
    # Valid sentiment
    sentiment = SentimentScore(
        sentiment="positive",
        confidence=0.85,
        reasoning="Bullish language and positive outlook"
    )
    assert sentiment.confidence == 0.85

    # Invalid confidence
    with pytest.raises(ValueError):
        SentimentScore(
            sentiment="positive",
            confidence=1.5  # > 1.0
        )

@pytest.mark.parametrize("sentiment_score,sentiment_label,confidence,expected_reliable", [
    (0.8, "positive", 0.9, True),
    (0.3, "neutral", 0.4, False),
    (-0.6, "negative", 0.7, True),
    (None, None, None, False)
])
def test_has_reliable_sentiment(sentiment_score, sentiment_label, confidence, expected_reliable):
    """Test sentiment reliability check"""
    post = SocialPost(
        post_id="test",
        title="Test",
        author="user",
        subreddit="test",
        created_utc=datetime.now(),
        upvotes=1,
        downvotes=0,
        comments_count=0,
        url="test",
        sentiment_score=sentiment_score,
        sentiment_label=sentiment_label,
        sentiment_confidence=confidence
    )

    assert post.has_reliable_sentiment() == expected_reliable
```

**test_social_repository.py:**
```python
import pytest
from datetime import datetime, timedelta
from tradingagents.domains.socialmedia.repositories import SocialRepository
from tradingagents.domains.socialmedia.models import SocialPost

@pytest.mark.asyncio
async def test_upsert_batch_deduplication(social_repository, sample_social_post):
    """Test batch upsert with deduplication"""
    posts = [sample_social_post, sample_social_post]  # Duplicate posts

    saved_ids = await social_repository.upsert_batch(posts)

    assert len(saved_ids) == 1  # Only one saved due to deduplication
    assert saved_ids[0] == sample_social_post.post_id

@pytest.mark.asyncio
async def test_find_by_ticker(social_repository, sample_social_post):
    """Test finding posts by ticker symbol"""
    await 
    social_repository.upsert_batch([sample_social_post])

    posts = await social_repository.find_by_ticker("AAPL", days=7)

    assert len(posts) == 1
    assert posts[0].post_id == sample_social_post.post_id
    assert "AAPL" in posts[0].tickers

@pytest.mark.asyncio
async def test_vector_similarity_search(social_repository, sample_social_post):
    """Test vector similarity search"""
    # Add post with embedding
    sample_social_post.title_embedding = [0.1] * 1536  # Mock embedding
    await social_repository.upsert_batch([sample_social_post])

    # Search with similar embedding
    query_embedding = [0.1] * 1536
    results = await social_repository.find_similar_posts(
        query_embedding=query_embedding,
        limit=5
    )

    assert len(results) >= 0  # May be empty if similarity too low
    if results:
        post, similarity = results[0]
        assert isinstance(similarity, float)
        assert 0 <= similarity <= 1

@pytest.mark.asyncio
async def test_sentiment_summary(social_repository, sample_social_post):
    """Test sentiment aggregation"""
    await social_repository.upsert_batch([sample_social_post])

    summary = await social_repository.get_sentiment_summary(
        ticker="AAPL",
        hours=24
    )

    assert summary['ticker'] == "AAPL"
    assert summary['total_posts'] >= 0
    assert 'sentiment_breakdown' in summary
    assert 'overall_sentiment' in summary

@pytest.mark.asyncio
async def test_cleanup_old_posts(social_repository, sample_social_post):
    """Test cleanup of old posts"""
    # Create old post
    old_post = sample_social_post.copy()
    old_post.post_id = "old_post"
    old_post.created_utc = datetime.now() - timedelta(days=100)

    await social_repository.upsert_batch([old_post])

    deleted_count = await social_repository.cleanup_old_posts(days=90)

    assert deleted_count >= 1
```

**test_reddit_client.py (with VCR):**
```python
import pytest
import pytest_vcr
from tradingagents.domains.socialmedia.clients import RedditClient

@pytest_vcr.use_cassette('fixtures/reddit_fetch_posts.yaml')
@pytest.mark.asyncio
async def test_fetch_subreddit_posts(test_config):
    """Test fetching posts from Reddit 
    API"""
    async with RedditClient(test_config) as client:
        posts = await client.fetch_subreddit_posts(
            subreddit_name="wallstreetbets",
            limit=10
        )

        assert len(posts) > 0
        for post in posts:
            assert 'post_id' in post
            assert 'title' in post
            assert 'subreddit' in post
            assert post['subreddit'] == 'wallstreetbets'

@pytest_vcr.use_cassette('fixtures/reddit_search.yaml')
@pytest.mark.asyncio
async def test_search_posts(test_config):
    """Test Reddit post search functionality"""
    async with RedditClient(test_config) as client:
        posts = await client.search_posts(
            query="AAPL",
            subreddit_names=["investing"],
            limit=5
        )

        assert isinstance(posts, list)
        if posts:  # May be empty in test
            for post in posts:
                assert 'post_id' in post
                assert 'title' in post

@pytest.mark.asyncio
async def test_health_check(test_config):
    """Test Reddit API health check"""
    async with RedditClient(test_config) as client:
        health = await client.health_check()
        assert isinstance(health, bool)
```

**Acceptance Criteria:**
- [ ] >85% test coverage across all socialmedia domain components
- [ ] Unit tests for all domain models with validation edge cases
- [ ] Integration tests for Reddit API client with VCR cassettes
- [ ] Repository tests with real PostgreSQL database operations
- [ ] Service layer tests with proper mocking of dependencies
- [ ] AgentToolkit integration tests
- [ ] Dagster pipeline asset tests with mocked data
- [ ] Performance benchmarks for vector similarity queries
- [ ] Error handling and edge case coverage
- [ ] Test fixtures and sample data for consistent testing

**Dependencies:** All implementation tasks
**Risk:** Low - Standard testing patterns

---

## Implementation Dependencies & Parallel Execution

### Phase 1 Dependencies
- Task 1.1 → Task 1.2 (Entity depends on database schema)
- Task 1.3 can run parallel with 1.1 and 1.2
- Task 1.4 depends on 1.1 and 1.2

### Phase 2 Dependencies
- Tasks 2.1, 2.2, and 2.3 can run in parallel
- Task 2.4 depends on 2.1, 2.2, and 2.3

### Phase 3 Dependencies
- Task 3.1 
depends on Task 2.4
- Task 3.2 depends on Task 2.4
- Task 3.3 can start after any component is complete

### Risk Assessment

**High Risk Tasks:**
- Task 2.1 (Reddit Client) - External API complexity, rate limiting

**Medium Risk Tasks:**
- Task 1.1 (Database Migration) - Extension dependencies
- Task 1.4 (Repository) - Complex vector queries
- Task 2.2 (Sentiment Analysis) - LLM API reliability
- Task 2.4 (Service Layer) - Complex orchestration

**Low Risk Tasks:**
- Task 1.2 (Entity Implementation)
- Task 1.3 (Domain Models)
- Task 2.3 (Embedding Generation)
- Task 3.1 (AgentToolkit Integration)
- Task 3.2 (Dagster Pipeline)
- Task 3.3 (Testing Suite)

## Success Criteria Summary

### Functionality
- ✅ Complete Reddit data collection with PRAW integration
- ✅ OpenRouter LLM sentiment analysis with confidence scoring
- ✅ Vector embeddings for semantic similarity search
- ✅ PostgreSQL + TimescaleDB + pgvectorscale data persistence
- ✅ AgentToolkit RAG methods for AI agent integration
- ✅ Daily Dagster pipeline for automated collection
- ✅ Comprehensive error handling and resilience

### Performance
- ✅ <2 second social context queries for AI agents
- ✅ <1 second vector similarity search (top 10 results)
- ✅ <5 seconds batch processing 1000 posts
- ✅ Efficient TimescaleDB time-series queries

### Quality
- ✅ >85% test coverage across all components
- ✅ Data quality monitoring and validation
- ✅ Comprehensive logging and observability
- ✅ Best-effort processing with graceful degradation

### Integration
- ✅ Seamless integration with existing TradingAgents architecture
- ✅ Follows news domain patterns for consistency
- ✅ Compatible with multi-agent trading workflows
- ✅ Production-ready deployment capability

This comprehensive task breakdown enables efficient parallel development by multiple AI agents while ensuring complete coverage of the socialmedia domain implementation requirements.
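
The dependency rules listed under "Implementation Dependencies & Parallel Execution" can be checked mechanically. The sketch below is an illustration only: it encodes just the edges stated in that section (`TASK_DEPENDENCIES` and `execution_waves` are hypothetical names), omits Task 3.3 because it may start once any component exists, and assumes there are no cross-phase dependencies beyond the ones listed. It groups tasks into waves that could be handed to parallel agents, using the standard-library `graphlib` module (Python 3.9+).

```python
from graphlib import TopologicalSorter
from typing import Dict, List, Set

# Each task maps to the set of tasks it depends on, as listed in the plan.
TASK_DEPENDENCIES: Dict[str, Set[str]] = {
    "1.1": set(),
    "1.2": {"1.1"},
    "1.3": set(),
    "1.4": {"1.1", "1.2"},
    "2.1": set(),
    "2.2": set(),
    "2.3": set(),
    "2.4": {"2.1", "2.2", "2.3"},
    "3.1": {"2.4"},
    "3.2": {"2.4"},
}


def execution_waves(dependencies: Dict[str, Set[str]]) -> List[List[str]]:
    """Group tasks into waves; every task in a wave has all prerequisites done."""
    sorter = TopologicalSorter(dependencies)
    sorter.prepare()  # raises graphlib.CycleError if the plan is circular

    waves: List[List[str]] = []
    while sorter.is_active():
        ready = sorter.get_ready()   # tasks whose prerequisites are all complete
        waves.append(sorted(ready))
        sorter.done(*ready)          # mark the whole wave as finished
    return waves


if __name__ == "__main__":
    for number, wave in enumerate(execution_waves(TASK_DEPENDENCIES), start=1):
        print(f"Wave {number}: {', '.join(wave)}")
```

Under these assumptions the listed edges alone yield three waves, which is a lower bound on sequencing rather than a schedule: the phase hour estimates and any unstated dependencies (for example, the service layer needing the repository) would still have to be layered on top.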