# Social Media Domain Implementation Tasks

## Overview

Complete greenfield implementation of the socialmedia domain from empty stubs to production-ready system with PRAW Reddit API integration, PostgreSQL migration, OpenRouter LLM sentiment analysis, and AgentToolkit RAG methods.

**Total Estimated Time: 32 hours (3-phase parallel development approach)**

## Phase Structure

### Phase 1: Foundation (12 hours) - Database & Core Models

**Parallel Execution Ready**: Multiple agents can work on different components simultaneously

### Phase 2: API Integration & Processing (12 hours) - Clients & Services

**Parallel Execution Ready**: API clients and LLM services can be developed in parallel

### Phase 3: Integration & Validation (8 hours) - AgentToolkit & Dagster

**Parallel Execution Ready**: AgentToolkit and pipeline development with comprehensive testing

---

## Phase 1: Foundation (12 hours)

### Task 1.1: Database Schema Migration (3 hours)

**Priority: Blocking** | **Agent: Database Specialist**

Create PostgreSQL migration for social_media_posts table with TimescaleDB and pgvectorscale support.

**Implementation:**
```sql
-- Migration: 003_create_social_media_posts.sql
-- uuid7() assumes a UUIDv7 extension (e.g. pg_uuidv7) is installed.
CREATE TABLE social_media_posts (
    id UUID PRIMARY KEY DEFAULT uuid7(),
    post_id VARCHAR(50) NOT NULL,
    title TEXT NOT NULL,
    content TEXT,
    author VARCHAR(100) NOT NULL,
    subreddit VARCHAR(50) NOT NULL,
    created_utc TIMESTAMPTZ NOT NULL,
    upvotes INTEGER NOT NULL DEFAULT 0,
    downvotes INTEGER NOT NULL DEFAULT 0,
    comments_count INTEGER NOT NULL DEFAULT 0,
    url TEXT NOT NULL,
    sentiment_score JSONB,
    sentiment_label VARCHAR(20),
    tickers TEXT[] DEFAULT '{}',
    title_embedding VECTOR(1536),
    content_embedding VECTOR(1536),
    inserted_at TIMESTAMPTZ DEFAULT NOW(),
    updated_at TIMESTAMPTZ DEFAULT NOW()
);

SELECT create_hypertable('social_media_posts', 'created_utc', chunk_time_interval => INTERVAL '1 day');

-- Performance indexes
-- TimescaleDB requires unique indexes to include the partitioning column,
-- so post_id uniqueness is enforced per (post_id, created_utc) rather than
-- as a plain UNIQUE constraint on post_id.
CREATE UNIQUE INDEX idx_social_posts_post_id ON social_media_posts (post_id, created_utc);
CREATE INDEX idx_social_posts_subreddit_time ON social_media_posts (subreddit, created_utc DESC);
CREATE INDEX idx_social_posts_tickers_gin ON social_media_posts USING GIN (tickers);
-- pgvectorscale's index access method is diskann (StreamingDiskANN)
CREATE INDEX idx_social_posts_title_embedding ON social_media_posts USING diskann (title_embedding vector_cosine_ops);
CREATE INDEX idx_social_posts_content_embedding ON social_media_posts USING diskann (content_embedding vector_cosine_ops);
CREATE INDEX idx_social_posts_sentiment ON social_media_posts ((sentiment_score->>'sentiment')) WHERE sentiment_score IS NOT NULL;

-- Constraints
ALTER TABLE social_media_posts ADD CONSTRAINT chk_sentiment_score CHECK (
    sentiment_score IS NULL OR ((sentiment_score->>'confidence')::float BETWEEN 0 AND 1)
);
ALTER TABLE social_media_posts ADD CONSTRAINT chk_created_utc CHECK (created_utc <= NOW());
```
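
As a post-migration smoke test, here is a minimal sketch, assuming the project's `DatabaseManager` exposes an async `get_session()` (the function name is illustrative, not part of the migration), of a cosine-distance query that the new embedding indexes should serve:

```python
# Hypothetical smoke test - names are illustrative, not part of the migration.
from sqlalchemy import text

async def smoke_test_vector_index(db_manager, query_embedding: list[float]) -> int:
    """Run a top-5 cosine-distance query against title_embedding."""
    async with db_manager.get_session() as session:
        result = await session.execute(
            text("""
                SELECT post_id, title_embedding <=> CAST(:emb AS vector) AS cosine_distance
                FROM social_media_posts
                WHERE title_embedding IS NOT NULL
                ORDER BY title_embedding <=> CAST(:emb AS vector)
                LIMIT 5
            """),
            {"emb": str(query_embedding)},  # pgvector parses the '[...]' text form
        )
        return len(result.fetchall())
```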

**Acceptance Criteria:**

- [ ] Migration script creates social_media_posts table
- [ ] TimescaleDB hypertable configured for time-series optimization
- [ ] pgvectorscale indexes for title_embedding and content_embedding
- [ ] All constraints and indexes properly created
- [ ] Migration runs successfully in test and development environments

**Dependencies:** PostgreSQL + TimescaleDB + pgvectorscale installed
**Risk:** Medium - Extension compatibility issues

---

### Task 1.2: SQLAlchemy Entity Implementation (3 hours)

**Priority: Blocking** | **Agent: Entity Specialist**

Create SocialMediaPostEntity with proper field mappings and domain transformations.

**File:** `tradingagents/domains/socialmedia/entities.py`

**Implementation:**
```python
from sqlalchemy import Column, String, Text, Integer, TIMESTAMP
from sqlalchemy.dialects.postgresql import UUID, ARRAY, JSONB
from sqlalchemy.sql import func
from pgvector.sqlalchemy import Vector  # pgvector's SQLAlchemy column type
from tradingagents.database.base import Base
from tradingagents.domains.socialmedia.models import SocialPost
import uuid

class SocialMediaPostEntity(Base):
    __tablename__ = 'social_media_posts'

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    post_id = Column(String(50), unique=True, nullable=False, index=True)
    title = Column(Text, nullable=False)
    content = Column(Text)
    author = Column(String(100), nullable=False)
    subreddit = Column(String(50), nullable=False, index=True)
    created_utc = Column(TIMESTAMP(timezone=True), nullable=False, index=True)
    upvotes = Column(Integer, nullable=False, default=0)
    downvotes = Column(Integer, nullable=False, default=0)
    comments_count = Column(Integer, nullable=False, default=0)
    url = Column(Text, nullable=False)

    # Enhanced fields
    sentiment_score = Column(JSONB)
    sentiment_label = Column(String(20))
    tickers = Column(ARRAY(String(10)), default=lambda: [])
    title_embedding = Column(Vector(1536))
    content_embedding = Column(Vector(1536))

    # Metadata
    inserted_at = Column(TIMESTAMP(timezone=True), server_default=func.now())
    updated_at = Column(TIMESTAMP(timezone=True), server_default=func.now(), onupdate=func.now())

    def to_domain(self) -> SocialPost:
        """Convert entity to domain model with proper field mapping"""
        sentiment_data = self.sentiment_score or {}
        return SocialPost(
            post_id=self.post_id,
            title=self.title,
            content=self.content,
            author=self.author,
            subreddit=self.subreddit,
            created_utc=self.created_utc,
            upvotes=self.upvotes,
            downvotes=self.downvotes,
            comments_count=self.comments_count,
            url=self.url,
            sentiment_score=sentiment_data.get('score'),
            sentiment_label=self.sentiment_label,
            sentiment_confidence=sentiment_data.get('confidence'),
            tickers=list(self.tickers) if self.tickers else [],
            # pgvector returns array-like values whose truthiness is ambiguous,
            # so compare against None explicitly
            title_embedding=list(self.title_embedding) if self.title_embedding is not None else None,
            content_embedding=list(self.content_embedding) if self.content_embedding is not None else None
        )

    @classmethod
    def from_domain(cls, post: SocialPost) -> 'SocialMediaPostEntity':
        """Create entity from domain model"""
        sentiment_data = None
        if post.sentiment_score is not None and post.sentiment_confidence is not None:
            sentiment_data = {
                'score': post.sentiment_score,
                'confidence': post.sentiment_confidence,
                'reasoning': getattr(post, 'sentiment_reasoning', None)
            }

        return cls(
            post_id=post.post_id,
            title=post.title,
            content=post.content,
            author=post.author,
            subreddit=post.subreddit,
            created_utc=post.created_utc,
            upvotes=post.upvotes,
            downvotes=post.downvotes,
            comments_count=post.comments_count,
            url=post.url,
            sentiment_score=sentiment_data,
            sentiment_label=post.sentiment_label,
            tickers=post.tickers or [],
            title_embedding=post.title_embedding,
            content_embedding=post.content_embedding
        )
```
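
A quick round-trip sketch of the mapping methods, assuming the `SocialPost` model from Task 1.3 (values are illustrative):

```python
# Illustrative round-trip check for the entity <-> domain mapping.
from datetime import datetime, timezone
from tradingagents.domains.socialmedia.models import SocialPost
from tradingagents.domains.socialmedia.entities import SocialMediaPostEntity

post = SocialPost(
    post_id='abc123',
    title='NVDA earnings beat expectations',
    author='example_user',
    subreddit='stocks',
    created_utc=datetime(2024, 1, 15, tzinfo=timezone.utc),
    upvotes=120,
    downvotes=4,
    comments_count=37,
    url='https://reddit.com/r/stocks/comments/abc123',
)

entity = SocialMediaPostEntity.from_domain(post)
assert entity.to_domain().post_id == post.post_id
```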

**Acceptance Criteria:**

- [ ] SocialMediaPostEntity properly maps all database fields
- [ ] to_domain() and from_domain() methods handle all field conversions
- [ ] Proper handling of vector fields and JSONB sentiment data
- [ ] Entity integrates with existing database session management
- [ ] All field types match database schema exactly

**Dependencies:** Task 1.1 (database schema)
**Risk:** Low - Standard SQLAlchemy patterns

---

### Task 1.3: Domain Model Enhancement (3 hours)

**Priority: Blocking** | **Agent: Domain Specialist**

Enhance SocialPost domain entity with comprehensive validation, transformations, and business rules.

**File:** `tradingagents/domains/socialmedia/models.py`

**Implementation:**
```python
# NOTE: this module uses the pydantic v1 API (validator/root_validator, regex=).
from pydantic import BaseModel, Field, validator, root_validator
from typing import Optional, List, Dict, Any, Literal
from datetime import datetime, timezone
import re

class SentimentScore(BaseModel):
    """Structured sentiment analysis result from OpenRouter LLM"""
    sentiment: Literal['positive', 'negative', 'neutral']
    confidence: float = Field(..., ge=0.0, le=1.0)
    reasoning: Optional[str] = None

    @validator('reasoning')
    def reasoning_not_empty(cls, v):
        if v is not None and len(v.strip()) == 0:
            return None
        return v

class SocialPost(BaseModel):
    """Core domain entity with business rules and transformations"""
    # Base fields from Reddit API
    post_id: str = Field(..., regex=r'^[a-zA-Z0-9_-]+$')
    title: str = Field(..., min_length=1, max_length=300)
    content: Optional[str] = None
    author: str = Field(..., min_length=1, max_length=100)
    subreddit: str = Field(..., min_length=1, max_length=50)
    created_utc: datetime
    upvotes: int = Field(..., ge=0)
    downvotes: int = Field(..., ge=0)
    comments_count: int = Field(..., ge=0)
    url: str = Field(..., min_length=1)

    # Enhanced fields
    sentiment_score: Optional[float] = Field(None, ge=-1.0, le=1.0)
    sentiment_label: Optional[str] = Field(None, regex=r'^(positive|negative|neutral)$')
    sentiment_confidence: Optional[float] = Field(None, ge=0.0, le=1.0)
    sentiment_reasoning: Optional[str] = None
    tickers: Optional[List[str]] = Field(default_factory=list)
    title_embedding: Optional[List[float]] = None
    content_embedding: Optional[List[float]] = None

    @validator('tickers')
    def validate_tickers(cls, v):
        """Validate ticker symbols format"""
        if v is None:
            return []
        # Ensure tickers are uppercase and valid format
        return [ticker.upper() for ticker in v if re.match(r'^[A-Z]{1,5}$', ticker.upper())]

    @validator('title_embedding', 'content_embedding')
    def validate_embedding_dimensions(cls, v):
        """Ensure embeddings have correct dimensions"""
        if v is not None and len(v) != 1536:
            raise ValueError('Embedding must be 1536 dimensions')
        return v

    @root_validator
    def validate_sentiment_consistency(cls, values):
        """Ensure sentiment fields are consistent"""
        score = values.get('sentiment_score')
        label = values.get('sentiment_label')
        confidence = values.get('sentiment_confidence')

        # All sentiment fields should be present or all None
        sentiment_fields = [score, label, confidence]
        non_none_count = sum(1 for field in sentiment_fields if field is not None)

        if 0 < non_none_count < 3:
            raise ValueError('All sentiment fields (score, label, confidence) must be provided together')

        return values

    @classmethod
    def from_praw_submission(cls, submission: Any) -> 'SocialPost':
        """Create SocialPost from PRAW Reddit submission"""
        return cls(
            post_id=submission.id,
            title=submission.title[:300],  # Truncate long titles
            content=submission.selftext if submission.selftext else None,
            author=str(submission.author) if submission.author else '[deleted]',
            subreddit=submission.subreddit.display_name,
            # created_utc is a UTC epoch timestamp, so convert with an explicit tz
            created_utc=datetime.fromtimestamp(submission.created_utc, tz=timezone.utc),
            upvotes=submission.ups if hasattr(submission, 'ups') else submission.score,
            downvotes=max(0, submission.score - submission.ups) if hasattr(submission, 'ups') else 0,
            comments_count=submission.num_comments,
            url=f"https://reddit.com{submission.permalink}"
        )

    def extract_tickers(self) -> List[str]:
        """Extract ticker symbols from title and content"""
        text = f"{self.title} {self.content or ''}"
        # Look for $TICKER or TICKER patterns; uppercasing the whole text makes
        # every short word a candidate, hence the stop-word list below.
        ticker_pattern = r'\b(?:\$)?([A-Z]{1,5})\b'
        potential_tickers = re.findall(ticker_pattern, text.upper())

        # Filter out common words that look like tickers
        excluded = {'THE', 'AND', 'OR', 'FOR', 'TO', 'OF', 'IN', 'ON', 'AT', 'BY', 'UP', 'IS', 'IT', 'BE', 'AS', 'ARE', 'WAS', 'HE', 'SHE', 'WE', 'YOU', 'THEY', 'ALL', 'ANY', 'CAN', 'HAD', 'HER', 'HIS', 'HOW', 'ITS', 'MAY', 'NEW', 'NOW', 'OLD', 'SEE', 'TWO', 'WHO', 'BOY', 'DID', 'HAS', 'LET', 'PUT', 'SAY', 'SIX', 'TEN', 'USE', 'WIN', 'YES'}

        tickers = [ticker for ticker in potential_tickers if ticker not in excluded]
        return list(set(tickers))  # Remove duplicates

    def has_reliable_sentiment(self) -> bool:
        """Check if sentiment analysis has sufficient confidence"""
        return (self.sentiment_confidence is not None and
                self.sentiment_confidence >= 0.5)

    def to_agent_context(self) -> Dict[str, Any]:
        """Format post for agent consumption"""
        sentiment_emoji = {"positive": "📈", "negative": "📉", "neutral": "➡️"}.get(self.sentiment_label, "❓")

        return {
            'post_id': self.post_id,
            'subreddit': self.subreddit,
            'title': self.title,
            'content': self.content[:200] + '...' if self.content and len(self.content) > 200 else self.content,
            'author': self.author,
            'created_utc': self.created_utc.isoformat(),
            'engagement': {
                'upvotes': self.upvotes,
                'comments_count': self.comments_count,
                'score': self.upvotes - self.downvotes
            },
            'sentiment': {
                'label': self.sentiment_label,
                'score': self.sentiment_score,
                'confidence': self.sentiment_confidence,
                'emoji': sentiment_emoji,
                'reliable': self.has_reliable_sentiment()
            },
            'tickers': self.tickers or [],
            'url': self.url
        }
```
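
A short usage sketch of validation plus ticker extraction (illustrative values):

```python
# Illustrative usage of SocialPost validation and ticker extraction.
from datetime import datetime, timezone
from tradingagents.domains.socialmedia.models import SocialPost

post = SocialPost(
    post_id='xyz789',
    title='$TSLA and AAPL earnings megathread discussion',
    author='example_user',
    subreddit='wallstreetbets',
    created_utc=datetime(2024, 1, 15, tzinfo=timezone.utc),
    upvotes=55,
    downvotes=3,
    comments_count=12,
    url='https://reddit.com/r/wallstreetbets/comments/xyz789',
)

print(post.extract_tickers())  # ['TSLA', 'AAPL'] - order not guaranteed
```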

**Acceptance Criteria:**

- [ ] SocialPost model handles all Reddit API fields properly
- [ ] Comprehensive validation for all fields including sentiment and embeddings
- [ ] from_praw_submission() creates valid domain objects from Reddit data
- [ ] extract_tickers() accurately finds ticker symbols in text
- [ ] to_agent_context() formats data for AI agent consumption
- [ ] Business rule validation prevents invalid state combinations

**Dependencies:** None (can run parallel with other tasks)
**Risk:** Low - Standard domain modeling

---

### Task 1.4: Repository Implementation (3 hours)

**Priority: Medium** | **Agent: Repository Specialist**

Implement SocialRepository with PostgreSQL operations, vector similarity search, and performance optimization.

**File:** `tradingagents/domains/socialmedia/repositories.py`

**Implementation:**
```python
from typing import List, Optional, Dict, Any, Tuple
from sqlalchemy import text
from sqlalchemy.exc import IntegrityError
from tradingagents.domains.socialmedia.entities import SocialMediaPostEntity
from tradingagents.domains.socialmedia.models import SocialPost
from tradingagents.database import DatabaseManager
from datetime import datetime, timedelta, timezone
import logging

logger = logging.getLogger(__name__)

class SocialRepository:
    """PostgreSQL repository for social media posts with vector search capabilities"""

    def __init__(self, db_manager: DatabaseManager):
        self.db_manager = db_manager

    async def upsert_batch(self, posts: List[SocialPost]) -> List[str]:
        """Batch insert social media posts, skipping post_ids that already exist"""
        async with self.db_manager.get_session() as session:
            saved_ids = []

            for post in posts:
                try:
                    # Check for existing post
                    existing = await session.execute(
                        text("SELECT id FROM social_media_posts WHERE post_id = :post_id"),
                        {"post_id": post.post_id}
                    )

                    if existing.first():
                        logger.debug(f"Skipping duplicate post: {post.post_id}")
                        continue

                    entity = SocialMediaPostEntity.from_domain(post)
                    # Flush inside a savepoint so a single failed row does not
                    # abort the rest of the batch.
                    async with session.begin_nested():
                        session.add(entity)
                        await session.flush()
                    saved_ids.append(post.post_id)

                except IntegrityError as e:
                    logger.warning(f"Integrity error saving post {post.post_id}: {e}")
                    continue

            await session.commit()
            logger.info(f"Saved {len(saved_ids)} new posts to database")
            return saved_ids

    async def find_by_ticker(self, ticker: str, days: int = 30, limit: int = 50) -> List[SocialPost]:
        """Find posts mentioning specific ticker symbol"""
        async with self.db_manager.get_session() as session:
            cutoff_date = datetime.now(timezone.utc) - timedelta(days=days)

            result = await session.execute(
                text("""
                    SELECT * FROM social_media_posts
                    WHERE :ticker = ANY(tickers)
                      AND created_utc >= :cutoff_date
                    ORDER BY created_utc DESC
                    LIMIT :limit
                """),
                {
                    "ticker": ticker.upper(),
                    "cutoff_date": cutoff_date,
                    "limit": limit
                }
            )

            entities = [SocialMediaPostEntity(**row) for row in result.mappings()]
            return [entity.to_domain() for entity in entities]

    async def find_by_subreddit(self, subreddit: str, hours: int = 24, limit: int = 100) -> List[SocialPost]:
        """Find recent posts from specific subreddit"""
        async with self.db_manager.get_session() as session:
            cutoff_date = datetime.now(timezone.utc) - timedelta(hours=hours)

            result = await session.execute(
                text("""
                    SELECT * FROM social_media_posts
                    WHERE subreddit = :subreddit
                      AND created_utc >= :cutoff_date
                    ORDER BY created_utc DESC
                    LIMIT :limit
                """),
                {
                    "subreddit": subreddit,
                    "cutoff_date": cutoff_date,
                    "limit": limit
                }
            )

            entities = [SocialMediaPostEntity(**row) for row in result.mappings()]
            return [entity.to_domain() for entity in entities]

    async def find_similar_posts(
        self,
        query_embedding: List[float],
        ticker: Optional[str] = None,
        limit: int = 10,
        similarity_threshold: float = 0.8
    ) -> List[Tuple[SocialPost, float]]:
        """Find similar posts using vector similarity search"""
        async with self.db_manager.get_session() as session:
            # pgvector parses the '[...]' text form; cast the parameter explicitly
            embedding_str = str(query_embedding)

            # GREATEST takes the better of the title/content similarities (it
            # ignores a NULL when only one embedding is present), matching the
            # "best match" semantics used elsewhere in the domain.
            base_query = """
                SELECT *,
                       GREATEST(
                           1 - (title_embedding <=> CAST(:embedding AS vector)),
                           1 - (content_embedding <=> CAST(:embedding AS vector))
                       ) as similarity_score
                FROM social_media_posts
                WHERE (title_embedding IS NOT NULL OR content_embedding IS NOT NULL)
            """

            params = {"embedding": embedding_str}

            if ticker:
                base_query += " AND :ticker = ANY(tickers)"
                params["ticker"] = ticker.upper()

            base_query += """
                AND GREATEST(
                    1 - (title_embedding <=> CAST(:embedding AS vector)),
                    1 - (content_embedding <=> CAST(:embedding AS vector))
                ) >= :threshold
                ORDER BY similarity_score DESC
                LIMIT :limit
            """

            params.update({
                "threshold": similarity_threshold,
                "limit": limit
            })

            result = await session.execute(text(base_query), params)

            posts_with_scores = []
            for row in result.mappings():
                entity = SocialMediaPostEntity(**{k: v for k, v in row.items() if k != 'similarity_score'})
                post = entity.to_domain()
                similarity = row['similarity_score']
                posts_with_scores.append((post, similarity))

            return posts_with_scores

    async def get_sentiment_summary(
        self,
        ticker: Optional[str] = None,
        subreddit: Optional[str] = None,
        hours: int = 24
    ) -> Dict[str, Any]:
        """Get aggregated sentiment analysis for ticker or subreddit"""
        async with self.db_manager.get_session() as session:
            cutoff_date = datetime.now(timezone.utc) - timedelta(hours=hours)

            base_query = """
                SELECT
                    sentiment_label,
                    COUNT(*) as count,
                    AVG((sentiment_score->>'score')::float) as avg_score,
                    AVG((sentiment_score->>'confidence')::float) as avg_confidence,
                    SUM(upvotes) as total_upvotes,
                    SUM(comments_count) as total_comments
                FROM social_media_posts
                WHERE created_utc >= :cutoff_date
                  AND sentiment_score IS NOT NULL
            """

            params = {"cutoff_date": cutoff_date}

            if ticker:
                base_query += " AND :ticker = ANY(tickers)"
                params["ticker"] = ticker.upper()

            if subreddit:
                base_query += " AND subreddit = :subreddit"
                params["subreddit"] = subreddit

            base_query += " GROUP BY sentiment_label"

            result = await session.execute(text(base_query), params)

            sentiment_counts = {}
            total_posts = 0
            weighted_score = 0.0
            total_engagement = 0

            for row in result.mappings():
                label = row['sentiment_label']
                count = row['count']
                avg_score = float(row['avg_score'] or 0)
                engagement = (row['total_upvotes'] or 0) + (row['total_comments'] or 0)

                sentiment_counts[label] = {
                    'count': count,
                    'avg_score': avg_score,
                    'avg_confidence': float(row['avg_confidence'] or 0),
                    'engagement': engagement
                }

                total_posts += count
                weighted_score += avg_score * count
                total_engagement += engagement

            return {
                'ticker': ticker,
                'subreddit': subreddit,
                'period_hours': hours,
                'total_posts': total_posts,
                'sentiment_breakdown': sentiment_counts,
                'overall_sentiment': weighted_score / total_posts if total_posts > 0 else 0.0,
                'total_engagement': total_engagement,
                'data_quality': {
                    'posts_with_sentiment': total_posts,
                    'period_start': cutoff_date.isoformat(),
                    'generated_at': datetime.now(timezone.utc).isoformat()
                }
            }

    async def cleanup_old_posts(self, days: int = 90) -> int:
        """Remove posts older than specified days"""
        async with self.db_manager.get_session() as session:
            cutoff_date = datetime.now(timezone.utc) - timedelta(days=days)

            result = await session.execute(
                text("DELETE FROM social_media_posts WHERE created_utc < :cutoff_date"),
                {"cutoff_date": cutoff_date}
            )

            deleted_count = result.rowcount
            await session.commit()

            logger.info(f"Cleaned up {deleted_count} posts older than {days} days")
            return deleted_count

    async def get_trending_tickers(self, hours: int = 24, min_mentions: int = 5) -> List[Dict[str, Any]]:
        """Find trending ticker symbols by mention frequency and sentiment"""
        async with self.db_manager.get_session() as session:
            cutoff_date = datetime.now(timezone.utc) - timedelta(hours=hours)

            # unnest via LATERAL so grouping and aggregates over the expanded
            # ticker rows are well-defined.
            result = await session.execute(
                text("""
                    SELECT
                        t.ticker,
                        COUNT(*) as mention_count,
                        AVG((sentiment_score->>'score')::float) as avg_sentiment,
                        SUM(upvotes) as total_upvotes,
                        SUM(comments_count) as total_comments
                    FROM social_media_posts, LATERAL unnest(tickers) AS t(ticker)
                    WHERE created_utc >= :cutoff_date
                      AND sentiment_score IS NOT NULL
                      AND array_length(tickers, 1) > 0
                    GROUP BY t.ticker
                    HAVING COUNT(*) >= :min_mentions
                    ORDER BY mention_count DESC, avg_sentiment DESC
                    LIMIT 20
                """),
                {
                    "cutoff_date": cutoff_date,
                    "min_mentions": min_mentions
                }
            )

            trending = []
            for row in result.mappings():
                trending.append({
                    'ticker': row['ticker'],
                    'mention_count': row['mention_count'],
                    'avg_sentiment': float(row['avg_sentiment'] or 0),
                    'total_upvotes': row['total_upvotes'] or 0,
                    'total_comments': row['total_comments'] or 0,
                    'engagement_score': (row['total_upvotes'] or 0) + (row['total_comments'] or 0)
                })

            return trending
```
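
A brief usage sketch of the repository, assuming a configured `DatabaseManager` (the function name is illustrative):

```python
# Illustrative wiring of SocialRepository - db_manager setup is assumed.
from tradingagents.database import DatabaseManager
from tradingagents.domains.socialmedia.repositories import SocialRepository

async def show_trending(db_manager: DatabaseManager) -> None:
    repo = SocialRepository(db_manager)

    summary = await repo.get_sentiment_summary(ticker='NVDA', hours=48)
    print(summary['overall_sentiment'], summary['total_posts'])

    for trend in await repo.get_trending_tickers(hours=24, min_mentions=5):
        print(trend['ticker'], trend['mention_count'], trend['avg_sentiment'])
```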

**Acceptance Criteria:**

- [ ] Batch upsert operations with proper deduplication
- [ ] Vector similarity search using pgvectorscale indexes
- [ ] Efficient ticker-based queries with TimescaleDB optimization
- [ ] Comprehensive sentiment aggregation with engagement metrics
- [ ] Data cleanup operations with configurable retention
- [ ] Trending ticker analysis with minimum mention thresholds
- [ ] Proper error handling and logging throughout

**Dependencies:** Task 1.1 (database schema), Task 1.2 (entity model)
**Risk:** Medium - Complex vector search queries

---

## Phase 2: API Integration & Processing (12 hours)

### Task 2.1: Reddit Client Implementation (4 hours)

**Priority: Blocking** | **Agent: API Integration Specialist**

Implement RedditClient using PRAW with comprehensive rate limiting, error handling, and financial subreddit focus.

**File:** `tradingagents/domains/socialmedia/clients.py`

**Implementation:**
```python
import praw
import asyncio
import logging
import time
from typing import List, Optional, Dict, Any
from datetime import datetime, timezone
from tradingagents.config import TradingAgentsConfig

logger = logging.getLogger(__name__)

class RedditClient:
    """PRAW-based Reddit client with rate limiting and error handling.

    Note: PRAW itself is synchronous; its calls block the event loop while
    they run, which is acceptable for batch collection jobs.
    """

    def __init__(self, config: TradingAgentsConfig):
        self.config = config
        self.reddit = None
        self.last_request_time = 0.0
        self.min_request_interval = 1.0  # 1 second between requests
        self.financial_subreddits = [
            'wallstreetbets', 'investing', 'stocks', 'SecurityAnalysis',
            'ValueInvesting', 'financialindependence', 'StockMarket',
            'options', 'dividends', 'pennystocks'
        ]

    async def __aenter__(self):
        """Async context manager entry"""
        self._initialize_reddit()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit"""
        pass

    def _initialize_reddit(self):
        """Initialize PRAW Reddit instance"""
        try:
            self.reddit = praw.Reddit(
                client_id=self.config.reddit_client_id,
                client_secret=self.config.reddit_client_secret,
                user_agent=self.config.reddit_user_agent,
                check_for_async=False
            )

            # With app-only credentials PRAW runs in read-only mode, where
            # user.me() returns None - so log the mode instead of treating
            # that call as an authentication test.
            logger.info(f"Reddit client initialized (read_only={self.reddit.read_only})")

        except Exception as e:
            logger.error(f"Failed to initialize Reddit client: {e}")
            raise

    async def _rate_limit_delay(self):
        """Implement rate limiting between requests"""
        current_time = time.time()
        time_since_last = current_time - self.last_request_time

        if time_since_last < self.min_request_interval:
            delay = self.min_request_interval - time_since_last
            await asyncio.sleep(delay)

        self.last_request_time = time.time()

    async def fetch_subreddit_posts(
        self,
        subreddit_name: str,
        time_filter: str = 'day',
        limit: int = 50,
        sort_type: str = 'hot'
    ) -> List[Dict[str, Any]]:
        """Fetch posts from a specific subreddit"""
        if not self.reddit:
            self._initialize_reddit()

        await self._rate_limit_delay()

        try:
            subreddit = self.reddit.subreddit(subreddit_name)

            # Get submissions based on sort type
            if sort_type == 'hot':
                submissions = subreddit.hot(limit=limit)
            elif sort_type == 'top':
                submissions = subreddit.top(time_filter=time_filter, limit=limit)
            elif sort_type == 'new':
                submissions = subreddit.new(limit=limit)
            else:
                submissions = subreddit.hot(limit=limit)

            posts = []
            for submission in submissions:
                # Skip removed or deleted posts
                if submission.selftext in ('[removed]', '[deleted]'):
                    continue

                post_data = self._extract_post_data(submission, subreddit_name)
                if post_data:  # _extract_post_data returns None on failure
                    posts.append(post_data)

            logger.info(f"Fetched {len(posts)} posts from r/{subreddit_name}")
            return posts

        except Exception as e:
            logger.error(f"Error fetching posts from r/{subreddit_name}: {e}")
            return []

    async def fetch_financial_posts_batch(
        self,
        subreddits: Optional[List[str]] = None,
        time_filter: str = 'day',
        posts_per_subreddit: int = 50
    ) -> Dict[str, List[Dict[str, Any]]]:
        """Fetch posts from multiple financial subreddits"""
        if not subreddits:
            subreddits = self.financial_subreddits

        results = {}

        for subreddit_name in subreddits:
            try:
                posts = await self.fetch_subreddit_posts(
                    subreddit_name=subreddit_name,
                    time_filter=time_filter,
                    limit=posts_per_subreddit
                )
                results[subreddit_name] = posts

            except Exception as e:
                logger.error(f"Failed to fetch from r/{subreddit_name}: {e}")
                results[subreddit_name] = []

        total_posts = sum(len(posts) for posts in results.values())
        logger.info(f"Fetched {total_posts} total posts from {len(subreddits)} subreddits")

        return results

    async def search_posts(
        self,
        query: str,
        subreddit_names: Optional[List[str]] = None,
        time_filter: str = 'week',
        limit: int = 25
    ) -> List[Dict[str, Any]]:
        """Search for posts containing specific terms"""
        if not self.reddit:
            self._initialize_reddit()

        if not subreddit_names:
            subreddit_names = self.financial_subreddits

        all_posts = []

        for subreddit_name in subreddit_names:
            await self._rate_limit_delay()

            try:
                subreddit = self.reddit.subreddit(subreddit_name)
                search_results = subreddit.search(
                    query=query,
                    time_filter=time_filter,
                    limit=limit,
                    sort='relevance'
                )

                for submission in search_results:
                    if submission.selftext not in ['[removed]', '[deleted]']:
                        post_data = self._extract_post_data(submission, subreddit_name)
                        if post_data:
                            all_posts.append(post_data)

            except Exception as e:
                logger.error(f"Search error in r/{subreddit_name}: {e}")
                continue

        logger.info(f"Found {len(all_posts)} posts matching query: {query}")
        return all_posts

    def _extract_post_data(self, submission: Any, subreddit_name: str) -> Optional[Dict[str, Any]]:
        """Extract structured data from PRAW submission"""
        try:
            return {
                'post_id': submission.id,
                'title': submission.title[:300],  # Limit title length
                'content': submission.selftext if submission.selftext else None,
                'author': str(submission.author) if submission.author else '[deleted]',
                'subreddit': subreddit_name,
                # created_utc is a UTC epoch timestamp
                'created_utc': datetime.fromtimestamp(submission.created_utc, tz=timezone.utc),
                'upvotes': getattr(submission, 'ups', submission.score),
                'downvotes': max(0, submission.score - getattr(submission, 'ups', submission.score)),
                'comments_count': submission.num_comments,
                'url': f"https://reddit.com{submission.permalink}",
                'reddit_score': submission.score,
                'upvote_ratio': getattr(submission, 'upvote_ratio', 0.5),
                'is_self': submission.is_self,
                'domain': submission.domain,
                'flair_text': getattr(submission, 'link_flair_text', None)
            }
        except Exception as e:
            logger.error(f"Error extracting post data: {e}")
            return None

    async def get_post_details(self, post_id: str) -> Optional[Dict[str, Any]]:
        """Get detailed information for a specific post"""
        if not self.reddit:
            self._initialize_reddit()

        await self._rate_limit_delay()

        try:
            submission = self.reddit.submission(id=post_id)
            return self._extract_post_data(submission, submission.subreddit.display_name)
        except Exception as e:
            logger.error(f"Error fetching post details for {post_id}: {e}")
            return None

    async def health_check(self) -> bool:
        """Check if Reddit API is accessible"""
        try:
            if not self.reddit:
                self._initialize_reddit()

            # PRAW listings are lazy generators; pull one item to force a
            # real network request.
            next(iter(self.reddit.subreddit('wallstreetbets').hot(limit=1)))
            return True
        except Exception as e:
            logger.error(f"Reddit health check failed: {e}")
            return False
```
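
A minimal usage sketch, assuming valid `reddit_*` credentials on `TradingAgentsConfig`:

```python
# Illustrative usage: pull a small batch of new posts for a sanity check.
import asyncio
from tradingagents.config import TradingAgentsConfig
from tradingagents.domains.socialmedia.clients import RedditClient

async def main():
    config = TradingAgentsConfig()  # assumes reddit_* credentials are set
    async with RedditClient(config) as client:
        posts = await client.fetch_subreddit_posts('stocks', sort_type='new', limit=5)
        for post in posts:
            print(post['post_id'], post['title'][:60])

asyncio.run(main())
```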

**Testing Implementation:**
```python
# tests/domains/socialmedia/test_reddit_client.py
import pytest

from tradingagents.domains.socialmedia.clients import RedditClient
from tradingagents.config import TradingAgentsConfig

# pytest-vcr records/replays HTTP via the `vcr` marker; the cassette is
# stored under a cassettes/ directory named after the test function.
@pytest.mark.vcr
@pytest.mark.asyncio
async def test_fetch_subreddit_posts(reddit_client, trading_config):
    """Test fetching posts from a specific subreddit"""
    async with reddit_client:
        posts = await reddit_client.fetch_subreddit_posts('wallstreetbets', limit=10)

        assert len(posts) > 0
        for post in posts:
            assert 'post_id' in post
            assert 'title' in post
            assert 'subreddit' in post
            assert post['subreddit'] == 'wallstreetbets'
```
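
The test above assumes `reddit_client` and `trading_config` fixtures; a hypothetical conftest sketch:

```python
# tests/domains/socialmedia/conftest.py - hypothetical fixtures assumed by
# the test above; adjust to the project's real config loading.
import pytest
from tradingagents.config import TradingAgentsConfig
from tradingagents.domains.socialmedia.clients import RedditClient

@pytest.fixture
def trading_config() -> TradingAgentsConfig:
    # Assumes credentials come from the environment / .env in test runs
    return TradingAgentsConfig()

@pytest.fixture
def reddit_client(trading_config) -> RedditClient:
    return RedditClient(trading_config)
```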

**Acceptance Criteria:**

- [ ] PRAW Reddit client properly authenticated and initialized
- [ ] Rate limiting implemented (1 request per second minimum)
- [ ] Comprehensive error handling for network issues and API limits
- [ ] Financial subreddit focus with configurable subreddit lists
- [ ] Structured data extraction from Reddit submissions
- [ ] Search functionality across multiple subreddits
- [ ] Health check capabilities for monitoring
- [ ] Test coverage with pytest-vcr cassettes

**Dependencies:** Reddit API credentials in TradingAgentsConfig
**Risk:** High - External API dependency, rate limiting complexity

---

### Task 2.2: OpenRouter LLM Sentiment Analysis (3 hours)

**Priority: Medium** | **Agent: LLM Integration Specialist**

Implement sentiment analysis using OpenRouter with social media-specific prompts and structured output parsing.

**File:** `tradingagents/domains/socialmedia/sentiment.py`

**Implementation:**
```python
from typing import Optional, Dict, Any, List
import json
import asyncio
import logging

from tradingagents.llm.openrouter_client import OpenRouterClient
from tradingagents.config import TradingAgentsConfig
from tradingagents.domains.socialmedia.models import SentimentScore

logger = logging.getLogger(__name__)

class SocialSentimentAnalyzer:
    """OpenRouter-based sentiment analysis for social media posts"""

    def __init__(self, config: TradingAgentsConfig):
        self.config = config
        self.client = OpenRouterClient(config)
        self.batch_size = 5  # Process posts in batches

    async def analyze_post_sentiment(self, post_text: str, ticker: Optional[str] = None) -> Optional[SentimentScore]:
        """Analyze sentiment of a single social media post"""
        prompt = self._create_sentiment_prompt(post_text, ticker)

        try:
            response = await self.client.generate_response(
                model=self.config.quick_think_llm,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=150,
                temperature=0.1,
                response_format={"type": "json_object"}
            )

            result = json.loads(response)

            return SentimentScore(
                sentiment=result.get('sentiment', 'neutral'),
                confidence=float(result.get('confidence', 0.0)),
                reasoning=result.get('reasoning')
            )

        except Exception as e:
            logger.error(f"Sentiment analysis failed: {e}")
            return None

    async def analyze_batch(
        self,
        posts: List[Dict[str, Any]],
        include_ticker: bool = True
    ) -> List[Optional[SentimentScore]]:
        """Analyze sentiment for multiple posts with rate limiting"""
        results = []

        for i in range(0, len(posts), self.batch_size):
            batch = posts[i:i + self.batch_size]
            batch_tasks = []

            for post in batch:
                text = self._combine_post_text(post)
                ticker = None

                if include_ticker and post.get('tickers'):
                    ticker = post['tickers'][0]  # Use first ticker if available

                task = self.analyze_post_sentiment(text, ticker)
                batch_tasks.append(task)

            # Process batch with concurrency limit
            batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)

            for result in batch_results:
                if isinstance(result, Exception):
                    logger.error(f"Batch sentiment analysis error: {result}")
                    results.append(None)
                else:
                    results.append(result)

            # Rate limiting between batches
            if i + self.batch_size < len(posts):
                await asyncio.sleep(1.0)

        successful_count = sum(1 for r in results if r is not None)
        logger.info(f"Sentiment analysis completed: {successful_count}/{len(posts)} successful")

        return results

    def _create_sentiment_prompt(self, text: str, ticker: Optional[str] = None) -> str:
        """Create social media-specific sentiment analysis prompt"""
        ticker_context = f" for ticker ${ticker}" if ticker else ""

        return f"""
Analyze the financial sentiment of this Reddit post{ticker_context}. Consider:
- Trading/investment sentiment (not general mood)
- Informal language, slang, and memes common in financial social media
- Context clues like "diamond hands", "to the moon", "bearish", etc.
- Overall market outlook expressed in the post

Post text: "{text}"

Respond with JSON only:
{{
    "sentiment": "positive|negative|neutral",
    "confidence": 0.0-1.0,
    "reasoning": "brief explanation of key factors"
}}

Guidelines:
- "positive": Bullish, optimistic about price/performance
- "negative": Bearish, pessimistic about price/performance
- "neutral": Mixed signals or no clear directional sentiment
- Confidence: How certain are you? (0.5+ for reliable sentiment)
- Reasoning: Key words/phrases that influenced the classification
""".strip()

    def _combine_post_text(self, post: Dict[str, Any]) -> str:
        """Combine title and content for sentiment analysis"""
        title = post.get('title', '')
        content = post.get('content', '')

        if content:
            # Limit total text length for efficient processing
            combined = f"{title} {content}"[:1000]
        else:
            combined = title

        return combined.strip()

    async def analyze_market_sentiment(
        self,
        posts: List[Dict[str, Any]],
        ticker: str
    ) -> Dict[str, Any]:
        """Analyze overall market sentiment for a ticker from multiple posts"""
        sentiments = await self.analyze_batch(posts, include_ticker=True)

        # Filter out failed analyses
        valid_sentiments = [s for s in sentiments if s is not None and s.confidence >= 0.5]

        if not valid_sentiments:
            return {
                'ticker': ticker,
                'overall_sentiment': 'neutral',
                'confidence': 0.0,
                'post_count': len(posts),
                'analysis_success_rate': 0.0,
                'sentiment_distribution': {'positive': 0, 'negative': 0, 'neutral': 0}
            }

        # Calculate sentiment distribution
        sentiment_counts = {'positive': 0, 'negative': 0, 'neutral': 0}
        confidence_sum = 0.0

        for sentiment in valid_sentiments:
            sentiment_counts[sentiment.sentiment] += 1
            confidence_sum += sentiment.confidence

        # Determine overall sentiment
        total_valid = len(valid_sentiments)
        positive_ratio = sentiment_counts['positive'] / total_valid
        negative_ratio = sentiment_counts['negative'] / total_valid

        if positive_ratio > 0.6:
            overall_sentiment = 'positive'
        elif negative_ratio > 0.6:
            overall_sentiment = 'negative'
        else:
            overall_sentiment = 'neutral'

        return {
            'ticker': ticker,
            'overall_sentiment': overall_sentiment,
            'confidence': confidence_sum / total_valid,
            'post_count': len(posts),
            'analyzed_posts': total_valid,
            'analysis_success_rate': total_valid / len(posts),
            'sentiment_distribution': sentiment_counts,
            'positive_ratio': positive_ratio,
            'negative_ratio': negative_ratio,
            'neutral_ratio': sentiment_counts['neutral'] / total_valid
        }
```
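
A usage sketch, assuming a configured `TradingAgentsConfig` with OpenRouter credentials:

```python
# Illustrative single-post analysis - config bootstrap is assumed.
import asyncio
from tradingagents.config import TradingAgentsConfig
from tradingagents.domains.socialmedia.sentiment import SocialSentimentAnalyzer

async def main():
    analyzer = SocialSentimentAnalyzer(TradingAgentsConfig())
    score = await analyzer.analyze_post_sentiment(
        "Diamond hands on NVDA, this dip is a gift", ticker="NVDA"
    )
    # Downstream code treats confidence >= 0.5 as reliable sentiment
    if score and score.confidence >= 0.5:
        print(score.sentiment, score.confidence, score.reasoning)

asyncio.run(main())
```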

**Acceptance Criteria:**

- [ ] OpenRouter integration for sentiment analysis with structured JSON output
- [ ] Social media-specific prompts handling informal language and financial slang
- [ ] Batch processing with rate limiting and error handling
- [ ] Confidence scoring for sentiment reliability
- [ ] Market sentiment aggregation across multiple posts
- [ ] Comprehensive error handling and logging
- [ ] Test coverage with mocked LLM responses

**Dependencies:** OpenRouter client implementation
**Risk:** Medium - LLM API reliability and cost management

---

### Task 2.3: Vector Embedding Generation (2 hours)

**Priority: Medium** | **Agent: ML Integration Specialist**

Implement vector embedding generation for semantic similarity search using OpenRouter embedding models.

**File:** `tradingagents/domains/socialmedia/embeddings.py`

**Implementation:**
```python
from typing import List, Optional, Dict, Any
import asyncio
import logging

import numpy as np

from tradingagents.llm.openrouter_client import OpenRouterClient
from tradingagents.config import TradingAgentsConfig

logger = logging.getLogger(__name__)

class SocialEmbeddingGenerator:
    """Generate vector embeddings for social media posts using OpenRouter"""

    def __init__(self, config: TradingAgentsConfig):
        self.config = config
        self.client = OpenRouterClient(config)
        # text-embedding-3-small produces 1536-d vectors, matching the
        # VECTOR(1536) columns (text-embedding-3-large is 3072-d unless
        # truncated via the `dimensions` parameter).
        self.embedding_model = "text-embedding-3-small"
        self.max_text_length = 8000  # Conservative character budget for the model
        self.batch_size = 10

    async def generate_post_embeddings(
        self,
        post: Dict[str, Any]
    ) -> Dict[str, Optional[List[float]]]:
        """Generate embeddings for post title and content separately"""
        embeddings = {
            'title_embedding': None,
            'content_embedding': None
        }

        # Generate title embedding ('title'/'content' may be None in post dicts)
        title = (post.get('title') or '').strip()
        if title:
            embeddings['title_embedding'] = await self._generate_embedding(title)

        # Generate content embedding if content exists
        content = (post.get('content') or '').strip()
        if content:
            # Combine title and content for content embedding
            combined_text = f"{title} {content}"[:self.max_text_length]
            embeddings['content_embedding'] = await self._generate_embedding(combined_text)

        return embeddings

    async def generate_batch_embeddings(
        self,
        posts: List[Dict[str, Any]]
    ) -> List[Dict[str, Optional[List[float]]]]:
        """Generate embeddings for multiple posts with batching"""
        results = []

        for i in range(0, len(posts), self.batch_size):
            batch = posts[i:i + self.batch_size]

            # Create tasks for concurrent processing
            tasks = [self.generate_post_embeddings(post) for post in batch]
            batch_results = await asyncio.gather(*tasks, return_exceptions=True)

            for result in batch_results:
                if isinstance(result, Exception):
                    logger.error(f"Embedding generation error: {result}")
                    results.append({'title_embedding': None, 'content_embedding': None})
                else:
                    results.append(result)

            # Rate limiting between batches
            if i + self.batch_size < len(posts):
                await asyncio.sleep(0.5)

        successful_count = sum(
            1 for r in results
            if r.get('title_embedding') is not None or r.get('content_embedding') is not None
        )
        logger.info(f"Embedding generation completed: {successful_count}/{len(posts)} successful")

        return results

    async def generate_query_embedding(self, query: str) -> Optional[List[float]]:
        """Generate embedding for search query"""
        return await self._generate_embedding(query[:self.max_text_length])

    async def _generate_embedding(self, text: str) -> Optional[List[float]]:
        """Generate single embedding using OpenRouter"""
        if not text.strip():
            return None

        try:
            response = await self.client.create_embeddings(
                model=self.embedding_model,
                input=[text],
                encoding_format="float"
            )

            if response and response.data:
                embedding = response.data[0].embedding

                # Validate embedding dimensions
                if len(embedding) != 1536:
                    logger.error(f"Unexpected embedding dimension: {len(embedding)}")
                    return None

                return embedding

        except Exception as e:
            logger.error(f"Embedding generation failed: {e}")

        return None

    def calculate_similarity(
        self,
        embedding1: List[float],
        embedding2: List[float]
    ) -> float:
        """Calculate cosine similarity between two embeddings"""
        try:
            # Convert to numpy arrays for efficient computation
            vec1 = np.array(embedding1)
            vec2 = np.array(embedding2)

            # Cosine similarity: dot product / (magnitude1 * magnitude2)
            dot_product = np.dot(vec1, vec2)
            magnitude1 = np.linalg.norm(vec1)
            magnitude2 = np.linalg.norm(vec2)

            if magnitude1 == 0 or magnitude2 == 0:
                return 0.0

            similarity = dot_product / (magnitude1 * magnitude2)
            return float(similarity)

        except Exception as e:
            logger.error(f"Similarity calculation error: {e}")
            return 0.0

    def find_most_similar(
        self,
        query_embedding: List[float],
        post_embeddings: List[Dict[str, Any]],
        top_k: int = 10
    ) -> List[Dict[str, Any]]:
        """Find most similar posts to query embedding"""
        similarities = []

        for i, post_data in enumerate(post_embeddings):
            max_similarity = 0.0
            best_embedding_type = None

            # Check title embedding similarity
            title_emb = post_data.get('title_embedding')
            if title_emb:
                title_sim = self.calculate_similarity(query_embedding, title_emb)
                if title_sim > max_similarity:
                    max_similarity = title_sim
                    best_embedding_type = 'title'

            # Check content embedding similarity
            content_emb = post_data.get('content_embedding')
            if content_emb:
                content_sim = self.calculate_similarity(query_embedding, content_emb)
                if content_sim > max_similarity:
                    max_similarity = content_sim
                    best_embedding_type = 'content'

            if max_similarity > 0:
                similarities.append({
                    'post_index': i,
                    'similarity_score': max_similarity,
                    'embedding_type': best_embedding_type,
                    'post_data': post_data
                })

        # Sort by similarity score and return top k
        similarities.sort(key=lambda x: x['similarity_score'], reverse=True)
        return similarities[:top_k]

    async def create_semantic_clusters(
        self,
        posts: List[Dict[str, Any]],
        similarity_threshold: float = 0.8
    ) -> List[List[Dict[str, Any]]]:
        """Group similar posts into semantic clusters"""
        if not posts:
            return []

        # Generate embeddings for all posts
        embeddings_data = await self.generate_batch_embeddings(posts)

        # Combine posts with their embeddings
        posts_with_embeddings = []
        for post, embeddings in zip(posts, embeddings_data):
            if embeddings.get('title_embedding') or embeddings.get('content_embedding'):
                posts_with_embeddings.append({**post, **embeddings})

        clusters = []
        processed = set()

        for i, post in enumerate(posts_with_embeddings):
            if i in processed:
                continue

            current_cluster = [post]
            processed.add(i)

            # Find similar posts for current cluster
            for j, other_post in enumerate(posts_with_embeddings):
                if j in processed or i == j:
                    continue

                # Calculate similarity between posts
                max_sim = 0.0

                # Compare all embedding combinations
                for emb1_type in ['title_embedding', 'content_embedding']:
                    for emb2_type in ['title_embedding', 'content_embedding']:
                        emb1 = post.get(emb1_type)
                        emb2 = other_post.get(emb2_type)

                        if emb1 and emb2:
                            sim = self.calculate_similarity(emb1, emb2)
                            max_sim = max(max_sim, sim)

                if max_sim >= similarity_threshold:
                    current_cluster.append(other_post)
                    processed.add(j)

            if len(current_cluster) > 1:  # Only include clusters with multiple posts
                clusters.append(current_cluster)

        logger.info(f"Created {len(clusters)} semantic clusters from {len(posts)} posts")
        return clusters
```
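
A toy check of the cosine similarity helper (3-d vectors for brevity; the real embeddings are 1536-d, but the math is identical):

```python
# Toy verification of cosine similarity: identical vectors score 1.0,
# orthogonal vectors score 0.0.
from tradingagents.config import TradingAgentsConfig
from tradingagents.domains.socialmedia.embeddings import SocialEmbeddingGenerator

gen = SocialEmbeddingGenerator(TradingAgentsConfig())
assert abs(gen.calculate_similarity([1.0, 0.0, 0.0], [1.0, 0.0, 0.0]) - 1.0) < 1e-9
assert abs(gen.calculate_similarity([1.0, 0.0, 0.0], [0.0, 1.0, 0.0])) < 1e-9
```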

**Acceptance Criteria:**

- [ ] Vector embedding generation for post titles and content separately
- [ ] Batch processing with rate limiting for efficiency
- [ ] Cosine similarity calculation for semantic search
- [ ] Query embedding generation for search functionality
- [ ] Semantic clustering capabilities for related post discovery
- [ ] Proper error handling and dimension validation
- [ ] Test coverage with mocked embedding responses

**Dependencies:** OpenRouter client with embedding support
**Risk:** Low - Standard embedding generation patterns

---

### Task 2.4: Service Layer Implementation (3 hours)

**Priority: Medium** | **Agent: Service Integration Specialist**

Implement SocialMediaService that orchestrates Reddit collection, sentiment analysis, and embedding generation.

**File:** `tradingagents/domains/socialmedia/services.py`

**Implementation:**
```python
|
|
from typing import List, Optional, Dict, Any, Tuple
|
|
import asyncio
|
|
import logging
|
|
from datetime import datetime, timedelta
|
|
|
|
from tradingagents.domains.socialmedia.clients import RedditClient
|
|
from tradingagents.domains.socialmedia.repositories import SocialRepository
|
|
from tradingagents.domains.socialmedia.sentiment import SocialSentimentAnalyzer
|
|
from tradingagents.domains.socialmedia.embeddings import SocialEmbeddingGenerator
|
|
from tradingagents.domains.socialmedia.models import SocialPost, SocialContext
|
|
from tradingagents.config import TradingAgentsConfig
|
|
from tradingagents.database import DatabaseManager
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
class SocialMediaService:
|
|
"""Orchestrates social media data collection, analysis, and storage"""
|
|
|
|
def __init__(self, config: TradingAgentsConfig, db_manager: DatabaseManager):
|
|
self.config = config
|
|
self.db_manager = db_manager
|
|
self.repository = SocialRepository(db_manager)
|
|
self.sentiment_analyzer = SocialSentimentAnalyzer(config)
|
|
self.embedding_generator = SocialEmbeddingGenerator(config)
|
|
|
|
# Configuration
|
|
self.financial_subreddits = [
|
|
'wallstreetbets', 'investing', 'stocks', 'SecurityAnalysis',
|
|
'ValueInvesting', 'financialindependence', 'StockMarket'
|
|
]
|
|
self.min_score_threshold = 10 # Minimum upvotes
|
|
self.max_posts_per_subreddit = 50
|
|
|
|
async def collect_and_process_posts(
|
|
self,
|
|
subreddits: Optional[List[str]] = None,
|
|
time_filter: str = 'day',
|
|
process_sentiment: bool = True,
|
|
generate_embeddings: bool = True
|
|
) -> Dict[str, Any]:
|
|
"""Main entry point for collecting and processing social media posts"""
|
|
if not subreddits:
|
|
subreddits = self.financial_subreddits
|
|
|
|
collection_start = datetime.now()
|
|
logger.info(f"Starting social media collection from {len(subreddits)} subreddits")
|
|
|
|
async with RedditClient(self.config) as reddit_client:
|
|
# Collect raw posts from Reddit
|
|
raw_posts_by_subreddit = await reddit_client.fetch_financial_posts_batch(
|
|
subreddits=subreddits,
|
|
time_filter=time_filter,
|
|
posts_per_subreddit=self.max_posts_per_subreddit
|
|
)
|
|
|
|
# Flatten and filter posts
|
|
all_raw_posts = []
|
|
for subreddit, posts in raw_posts_by_subreddit.items():
|
|
filtered_posts = [
|
|
post for post in posts
|
|
if post and post.get('reddit_score', 0) >= self.min_score_threshold
|
|
]
|
|
all_raw_posts.extend(filtered_posts)
|
|
|
|
logger.info(f"Collected {len(all_raw_posts)} posts meeting quality thresholds")
|
|
|
|
# Convert to domain objects and extract tickers
|
|
domain_posts = []
|
|
for raw_post in all_raw_posts:
|
|
try:
|
|
post = SocialPost(**raw_post)
|
|
post.tickers = post.extract_tickers() # Extract tickers from content
|
|
domain_posts.append(post)
|
|
except Exception as e:
|
|
logger.error(f"Error creating domain object: {e}")
|
|
continue
|
|
|
|
# Process sentiment analysis if requested
|
|
if process_sentiment and domain_posts:
|
|
await self._process_sentiment_analysis(domain_posts)
|
|
|
|
# Generate embeddings if requested
|
|
if generate_embeddings and domain_posts:
|
|
await self._process_embeddings(domain_posts)
|
|
|
|
# Save to database
|
|
saved_post_ids = await self.repository.upsert_batch(domain_posts)
|
|
|
|
collection_end = datetime.now()
|
|
processing_time = (collection_end - collection_start).total_seconds()
|
|
|
|
# Calculate success metrics
|
|
results = {
|
|
'collection_timestamp': collection_start.isoformat(),
|
|
'processing_time_seconds': processing_time,
|
|
'subreddits_processed': subreddits,
|
|
'total_posts_collected': len(all_raw_posts),
|
|
'posts_processed': len(domain_posts),
|
|
'posts_saved': len(saved_post_ids),
|
|
'sentiment_analysis_enabled': process_sentiment,
|
|
'embeddings_enabled': generate_embeddings,
|
|
'subreddit_breakdown': {}
|
|
}
|
|
|
|
# Add per-subreddit breakdown
|
|
for subreddit, posts in raw_posts_by_subreddit.items():
|
|
results['subreddit_breakdown'][subreddit] = {
|
|
'posts_collected': len(posts),
|
|
'posts_filtered': len([p for p in posts if p.get('reddit_score', 0) >= self.min_score_threshold])
|
|
}
|
|
|
|
logger.info(f"Collection completed: {len(saved_post_ids)} posts saved in {processing_time:.2f}s")
|
|
return results
|
|
|
|
async def get_social_context(
|
|
self,
|
|
ticker: str,
|
|
days: int = 7,
|
|
include_similar: bool = True,
|
|
similarity_query: Optional[str] = None
|
|
) -> SocialContext:
|
|
"""Get comprehensive social media context for a ticker"""
|
|
logger.info(f"Generating social context for {ticker} ({days} days)")
|
|
|
|
# Get direct ticker mentions
|
|
ticker_posts = await self.repository.find_by_ticker(ticker, days=days, limit=50)
|
|
|
|
similar_posts = []
|
|
if include_similar and ticker_posts:
|
|
# Use semantic search to find related discussions
|
|
if similarity_query:
|
|
        query_embedding = await self.embedding_generator.generate_query_embedding(similarity_query)

        if query_embedding:
            similar_results = await self.repository.find_similar_posts(
                query_embedding=query_embedding,
                ticker=ticker,
                limit=10
            )
            similar_posts = [post for post, score in similar_results]

        # Get sentiment summary
        sentiment_summary = await self.repository.get_sentiment_summary(
            ticker=ticker,
            hours=days * 24
        )

        # Find trending discussions
        trending_tickers = await self.repository.get_trending_tickers(
            hours=days * 24,
            min_mentions=3
        )
        ticker_trend = next(
            (trend for trend in trending_tickers if trend['ticker'] == ticker.upper()),
            None
        )

        return SocialContext(
            ticker=ticker,
            period_days=days,
            direct_mentions=ticker_posts,
            similar_posts=similar_posts,
            sentiment_summary=sentiment_summary,
            trending_info=ticker_trend,
            total_posts=len(ticker_posts) + len(similar_posts),
            data_quality_score=self._calculate_data_quality(ticker_posts + similar_posts)
        )

    async def search_posts_semantic(
        self,
        query: str,
        ticker: Optional[str] = None,
        limit: int = 10,
        min_similarity: float = 0.7
    ) -> List[Tuple[SocialPost, float]]:
        """Semantic search for social media posts"""
        query_embedding = await self.embedding_generator.generate_query_embedding(query)

        if not query_embedding:
            logger.error(f"Failed to generate query embedding for: {query}")
            return []

        return await self.repository.find_similar_posts(
            query_embedding=query_embedding,
            ticker=ticker,
            limit=limit,
            similarity_threshold=min_similarity
        )

    async def get_subreddit_analysis(
        self,
        subreddit: str,
        hours: int = 24
    ) -> Dict[str, Any]:
        """Get analysis of a specific subreddit's activity"""
        posts = await self.repository.find_by_subreddit(subreddit, hours=hours)

        if not posts:
            return {
                'subreddit': subreddit,
                'period_hours': hours,
                'total_posts': 0,
                'message': f'No posts found for r/{subreddit} in the last {hours} hours'
            }

        # Analyze ticker mentions
        ticker_counts = {}
        for post in posts:
            for ticker in post.tickers or []:
                ticker_counts[ticker] = ticker_counts.get(ticker, 0) + 1

        top_tickers = sorted(ticker_counts.items(), key=lambda x: x[1], reverse=True)[:10]

        # Analyze sentiment distribution (guard against unexpected labels)
        sentiment_counts = {'positive': 0, 'negative': 0, 'neutral': 0}
        reliable_sentiment_count = 0

        for post in posts:
            if post.sentiment_label in sentiment_counts:
                sentiment_counts[post.sentiment_label] += 1
                if post.has_reliable_sentiment():
                    reliable_sentiment_count += 1

        # Calculate engagement metrics (posts is non-empty past the early return)
        total_upvotes = sum(post.upvotes for post in posts)
        total_comments = sum(post.comments_count for post in posts)
        avg_score = total_upvotes / len(posts)

        return {
            'subreddit': subreddit,
            'period_hours': hours,
            'total_posts': len(posts),
            'engagement_metrics': {
                'total_upvotes': total_upvotes,
                'total_comments': total_comments,
                'avg_score': avg_score,
                'top_post_score': max(post.upvotes for post in posts)
            },
            'sentiment_analysis': {
                'distribution': sentiment_counts,
                'reliable_sentiment_posts': reliable_sentiment_count,
                'sentiment_reliability': reliable_sentiment_count / len(posts)
            },
            'ticker_mentions': {
                'top_tickers': top_tickers,
                'unique_tickers': len(ticker_counts),
                'total_mentions': sum(ticker_counts.values())
            },
            'data_quality': self._calculate_data_quality(posts)
        }

    async def _process_sentiment_analysis(self, posts: List[SocialPost]) -> None:
        """Process sentiment analysis for posts"""
        logger.info(f"Processing sentiment analysis for {len(posts)} posts")

        # Convert to dict format for the sentiment analyzer
        posts_data = [post.dict() for post in posts]

        # Analyze sentiment in batches
        sentiments = await self.sentiment_analyzer.analyze_batch(posts_data)

        # Update posts with sentiment results
        for post, sentiment in zip(posts, sentiments):
            if sentiment:
                post.sentiment_score = sentiment.score if hasattr(sentiment, 'score') else None
                post.sentiment_label = sentiment.sentiment
                post.sentiment_confidence = sentiment.confidence
                post.sentiment_reasoning = sentiment.reasoning

        successful_count = sum(1 for s in sentiments if s is not None)
        logger.info(f"Sentiment analysis completed: {successful_count}/{len(posts)} successful")

    async def _process_embeddings(self, posts: List[SocialPost]) -> None:
        """Process embedding generation for posts"""
        logger.info(f"Generating embeddings for {len(posts)} posts")

        # Convert to dict format for the embedding generator
        posts_data = [post.dict() for post in posts]

        # Generate embeddings in batches
        embeddings = await self.embedding_generator.generate_batch_embeddings(posts_data)

        # Update posts with embedding results
        for post, embedding_data in zip(posts, embeddings):
            post.title_embedding = embedding_data.get('title_embedding')
            post.content_embedding = embedding_data.get('content_embedding')

        successful_count = sum(
            1 for e in embeddings
            if e.get('title_embedding') is not None or e.get('content_embedding') is not None
        )
        logger.info(f"Embedding generation completed: {successful_count}/{len(posts)} successful")

    def _calculate_data_quality(self, posts: List[SocialPost]) -> Dict[str, float]:
        """Calculate data quality metrics for posts"""
        if not posts:
            return {'overall_score': 0.0}

        sentiment_coverage = sum(1 for p in posts if p.sentiment_label is not None) / len(posts)
        reliable_sentiment = sum(1 for p in posts if p.has_reliable_sentiment()) / len(posts)
        embedding_coverage = sum(
            1 for p in posts
            if p.title_embedding is not None or p.content_embedding is not None
        ) / len(posts)
        ticker_extraction = sum(1 for p in posts if p.tickers) / len(posts)

        overall_score = (sentiment_coverage + reliable_sentiment + embedding_coverage + ticker_extraction) / 4

        return {
            'overall_score': overall_score,
            'sentiment_coverage': sentiment_coverage,
            'reliable_sentiment_ratio': reliable_sentiment,
            'embedding_coverage': embedding_coverage,
            'ticker_extraction_ratio': ticker_extraction
        }
```
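
**Usage sketch (illustrative):** A minimal smoke-test wiring for the service, assuming the `TradingAgentsConfig.from_env()` and `DatabaseManager` conventions used elsewhere in this plan rather than a confirmed public API:

```python
# Minimal smoke-test wiring; the config/db setup follows this plan's
# conventions and is an assumption, not a confirmed public API.
import asyncio

from tradingagents.config import TradingAgentsConfig
from tradingagents.database import DatabaseManager
from tradingagents.domains.socialmedia.services import SocialMediaService

async def main() -> None:
    config = TradingAgentsConfig.from_env()
    db_manager = DatabaseManager(config)
    service = SocialMediaService(config, db_manager)
    try:
        # Collect, analyze, and persist one day's posts from a single subreddit
        results = await service.collect_and_process_posts(
            subreddits=["stocks"],
            time_filter="day",
            process_sentiment=True,
            generate_embeddings=True,
        )
        print(f"Saved {results['posts_saved']} posts")

        # Then pull the ticker-level context an agent would consume
        context = await service.get_social_context(ticker="AAPL", days=7)
        print(f"Data quality: {context.data_quality_score['overall_score']:.1%}")
    finally:
        await db_manager.close_all()

if __name__ == "__main__":
    asyncio.run(main())
```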

**Acceptance Criteria:**
- [ ] Orchestrates complete collection, analysis, and storage pipeline
- [ ] Integrates Reddit client, sentiment analyzer, and embedding generator
- [ ] Handles batch processing with proper error handling and logging
- [ ] Provides ticker-specific social context with sentiment and similarity
- [ ] Semantic search capabilities with configurable similarity thresholds
- [ ] Subreddit analysis with engagement and sentiment metrics
- [ ] Data quality scoring and monitoring
- [ ] Comprehensive test coverage with mocked dependencies

**Dependencies:** All Phase 2 tasks (clients, sentiment, embeddings)
**Risk:** Medium - Complex orchestration of multiple async services

---

## Phase 3: Integration & Validation (8 hours)

### Task 3.1: AgentToolkit Integration (3 hours)
**Priority: High** | **Agent: Agent Integration Specialist**

Add RAG-enhanced social media methods to AgentToolkit for AI agent consumption.

**File:** `tradingagents/agents/libs/agent_toolkit.py` (additions)

**Implementation:**
```python
# Additional methods for AgentToolkit class

async def get_reddit_sentiment(
    self,
    ticker: str,
    days: int = 7,
    include_context: bool = True
) -> str:
    """Get Reddit sentiment analysis for a specific ticker with RAG context"""
    try:
        if not hasattr(self, 'social_service'):
            self.social_service = SocialMediaService(self.config, self.db_manager)

        # Get comprehensive social context
        social_context = await self.social_service.get_social_context(
            ticker=ticker,
            days=days,
            include_similar=include_context
        )

        if not social_context.total_posts:
            return f"No Reddit sentiment data found for ${ticker} in the last {days} days."

        # Format for agent consumption
        sentiment_summary = social_context.sentiment_summary
        trending_info = social_context.trending_info

        context = f"Reddit Sentiment Analysis for ${ticker} ({days}-day period):\n\n"

        # Overall sentiment metrics
        if sentiment_summary:
            overall_score = sentiment_summary.get('overall_sentiment', 0.0)
            sentiment_emoji = "📈" if overall_score > 0.1 else "📉" if overall_score < -0.1 else "➡️"

            context += f"{sentiment_emoji} Overall Sentiment: {overall_score:.2f}/1.0\n"
            context += f"📊 Analysis Coverage: {social_context.total_posts} posts analyzed\n"

            # Sentiment breakdown
            breakdown = sentiment_summary.get('sentiment_breakdown', {})
            if breakdown:
                context += f" • Positive: {breakdown.get('positive', {}).get('count', 0)} posts\n"
                context += f" • Negative: {breakdown.get('negative', {}).get('count', 0)} posts\n"
                context += f" • Neutral: {breakdown.get('neutral', {}).get('count', 0)} posts\n"

        # Trending information
        if trending_info:
            context += "\n🔥 Trending Status:\n"
            context += f" • Mentions: {trending_info['mention_count']} posts\n"
            context += f" • Engagement: {trending_info['engagement_score']} (upvotes + comments)\n"
            context += f" • Avg Sentiment: {trending_info['avg_sentiment']:.2f}\n"

        # Top discussions (sample posts)
        if social_context.direct_mentions:
            context += "\n💬 Recent Discussions:\n"
            for i, post in enumerate(social_context.direct_mentions[:5]):
                sentiment_emoji = {"positive": "📈", "negative": "📉", "neutral": "➡️"}.get(
                    post.sentiment_label, "❓"
                )
                context += f"{i+1}. {sentiment_emoji} r/{post.subreddit}: {post.title[:100]}...\n"
                context += f"   Score: {post.upvotes} upvotes, {post.comments_count} comments\n"
                if post.has_reliable_sentiment():
                    context += f"   Sentiment: {post.sentiment_label} ({post.sentiment_confidence:.2f})\n"

        # Data quality indicators
        quality = social_context.data_quality_score
        context += f"\n📋 Data Quality: {quality.get('overall_score', 0):.1%} coverage\n"

        return context

    except Exception as e:
        logger.error(f"Error getting Reddit sentiment for {ticker}: {e}")
        return f"Error retrieving Reddit sentiment for ${ticker}: {str(e)}"

async def get_reddit_stock_info(
    self,
    ticker: str,
    query: Optional[str] = None,
    days: int = 7
) -> str:
    """Get Reddit stock information with optional semantic search"""
    try:
        if not hasattr(self, 'social_service'):
            self.social_service = SocialMediaService(self.config, self.db_manager)

        context = f"Reddit Stock Information for ${ticker}:\n\n"

        if query:
            # Semantic search for specific information
            search_results = await self.social_service.search_posts_semantic(
                query=query,
                ticker=ticker,
                limit=10,
                min_similarity=0.7
            )

            if search_results:
                context += f"🔍 Semantic Search Results for '{query}':\n"
                for i, (post, similarity) in enumerate(search_results[:5]):
                    context += f"{i+1}. (Similarity: {similarity:.2f}) r/{post.subreddit}\n"
                    context += f"   Title: {post.title}\n"
                    if post.content:
                        context += f"   Content: {post.content[:150]}...\n"
                    context += f"   Engagement: {post.upvotes} upvotes, {post.comments_count} comments\n\n"
            else:
                context += f"🔍 No relevant discussions found for '{query}' about ${ticker}\n\n"

        # Get general stock context
        social_context = await self.social_service.get_social_context(
            ticker=ticker,
            days=days,
            include_similar=False
        )

        if social_context.direct_mentions:
            context += f"📈 Recent Stock Discussions ({len(social_context.direct_mentions)} posts):\n"

            # Group by subreddit for better organization
            by_subreddit = {}
            for post in social_context.direct_mentions:
                if post.subreddit not in by_subreddit:
                    by_subreddit[post.subreddit] = []
                by_subreddit[post.subreddit].append(post)

            for subreddit, posts in by_subreddit.items():
                context += f"\nr/{subreddit} ({len(posts)} posts):\n"
                for post in posts[:3]:  # Top 3 per subreddit
                    sentiment_info = ""
                    if post.has_reliable_sentiment():
                        sentiment_emoji = {"positive": "📈", "negative": "📉", "neutral": "➡️"}
                        emoji = sentiment_emoji.get(post.sentiment_label, "❓")
                        sentiment_info = f" {emoji} {post.sentiment_label}"

                    context += f" • {post.title[:80]}...{sentiment_info}\n"
                    context += f"   {post.upvotes} upvotes, {post.comments_count} comments\n"

        # Add trending context if available
        if social_context.trending_info:
            trend = social_context.trending_info
            context += "\n📊 Trending Analysis:\n"
            context += f" • Market attention: {trend['mention_count']} recent mentions\n"
            context += f" • Community sentiment: {trend['avg_sentiment']:.2f}/1.0\n"
            context += f" • Total engagement: {trend['engagement_score']}\n"

        return context

    except Exception as e:
        logger.error(f"Error getting Reddit stock info for {ticker}: {e}")
        return f"Error retrieving Reddit stock information for ${ticker}: {str(e)}"

async def search_social_posts(
    self,
    query: str,
    ticker: Optional[str] = None,
    limit: int = 10,
    days: int = 30
) -> str:
    """Search social media posts using semantic similarity"""
    try:
        if not hasattr(self, 'social_service'):
            self.social_service = SocialMediaService(self.config, self.db_manager)

        # Perform semantic search
        search_results = await self.social_service.search_posts_semantic(
            query=query,
            ticker=ticker,
            limit=limit,
            min_similarity=0.6
        )

        if not search_results:
            ticker_context = f" about ${ticker}" if ticker else ""
            return f"No relevant social media posts found for '{query}'{ticker_context}."

        ticker_context = f" (${ticker})" if ticker else ""
        context = f"Social Media Search Results for '{query}'{ticker_context}:\n\n"
        context += f"Found {len(search_results)} relevant posts:\n\n"

        for i, (post, similarity) in enumerate(search_results):
            context += f"{i+1}. Relevance: {similarity:.2%} | r/{post.subreddit}\n"
            context += f"   Title: {post.title}\n"

            if post.content:
                # Show relevant snippet
                content_preview = post.content[:200] + "..." if len(post.content) > 200 else post.content
                context += f"   Content: {content_preview}\n"

            # Add sentiment if available
            if post.has_reliable_sentiment():
                sentiment_emoji = {"positive": "📈", "negative": "📉", "neutral": "➡️"}.get(
                    post.sentiment_label, "❓"
                )
                context += f"   Sentiment: {sentiment_emoji} {post.sentiment_label} ({post.sentiment_confidence:.2f})\n"

            # Add engagement metrics
            context += f"   Engagement: {post.upvotes} upvotes, {post.comments_count} comments\n"
            context += f"   Posted: {post.created_utc.strftime('%Y-%m-%d %H:%M')} UTC\n\n"

        return context

    except Exception as e:
        logger.error(f"Error searching social posts for '{query}': {e}")
        return f"Error searching social media posts: {str(e)}"

async def get_subreddit_analysis(
    self,
    subreddit: str,
    ticker: Optional[str] = None,
    hours: int = 24
) -> str:
    """Get analysis of activity in a specific financial subreddit"""
    try:
        if not hasattr(self, 'social_service'):
            self.social_service = SocialMediaService(self.config, self.db_manager)

        analysis = await self.social_service.get_subreddit_analysis(subreddit, hours=hours)

        if analysis['total_posts'] == 0:
            return f"No recent activity found in r/{subreddit} in the last {hours} hours."

        context = f"r/{subreddit} Analysis ({hours}-hour period):\n\n"

        # Activity overview
        context += "📊 Activity Overview:\n"
        context += f" • Total Posts: {analysis['total_posts']}\n"
        context += f" • Total Upvotes: {analysis['engagement_metrics']['total_upvotes']:,}\n"
        context += f" • Total Comments: {analysis['engagement_metrics']['total_comments']:,}\n"
        context += f" • Avg Score: {analysis['engagement_metrics']['avg_score']:.1f}\n"
        context += f" • Top Post Score: {analysis['engagement_metrics']['top_post_score']:,}\n\n"

        # Sentiment analysis
        sentiment_dist = analysis['sentiment_analysis']['distribution']
        reliable_ratio = analysis['sentiment_analysis']['sentiment_reliability']

        context += "😊 Sentiment Analysis:\n"
        context += f" • Positive: {sentiment_dist['positive']} posts\n"
        context += f" • Negative: {sentiment_dist['negative']} posts\n"
        context += f" • Neutral: {sentiment_dist['neutral']} posts\n"
        context += f" • Reliability: {reliable_ratio:.1%} of posts have confident sentiment scores\n\n"

        # Ticker mentions
        ticker_info = analysis['ticker_mentions']
        context += "💰 Stock Mentions:\n"
        context += f" • Unique Tickers: {ticker_info['unique_tickers']}\n"
        context += f" • Total Mentions: {ticker_info['total_mentions']}\n"

        if ticker_info['top_tickers']:
            context += " • Most Discussed:\n"
            for ticker_symbol, count in ticker_info['top_tickers'][:5]:
                context += f"   - ${ticker_symbol}: {count} mentions\n"

        # Filter for specific ticker if requested
        if ticker:
            ticker_mentions = next(
                (count for symbol, count in ticker_info['top_tickers'] if symbol == ticker.upper()),
                0
            )
            if ticker_mentions > 0:
                context += f"\n🎯 ${ticker} Activity: {ticker_mentions} mentions in this period\n"
            else:
                context += f"\n🎯 ${ticker}: No mentions found in r/{subreddit} during this period\n"

        # Data quality
        quality = analysis['data_quality']['overall_score']
        context += f"\n📋 Data Quality Score: {quality:.1%}\n"

        return context

    except Exception as e:
        logger.error(f"Error analyzing subreddit {subreddit}: {e}")
        return f"Error analyzing r/{subreddit}: {str(e)}"
```
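
**Call-site sketch (illustrative):** How an agent-side caller might combine the new methods; `toolkit` is an assumed, already-constructed AgentToolkit with `config` and `db_manager` attached:

```python
# Hypothetical call site; `toolkit` is an already-constructed AgentToolkit.
async def build_social_briefing(toolkit, ticker: str) -> str:
    sentiment = await toolkit.get_reddit_sentiment(ticker, days=7)
    earnings_chatter = await toolkit.get_reddit_stock_info(
        ticker, query="earnings guidance", days=7
    )
    # Each method returns a formatted string, so sections concatenate
    # directly into an agent prompt.
    return f"{sentiment}\n\n{earnings_chatter}"
```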

**Acceptance Criteria:**
- [ ] get_reddit_sentiment() provides comprehensive sentiment analysis with visual formatting
- [ ] get_reddit_stock_info() supports both general info and semantic search queries
- [ ] search_social_posts() enables semantic search across all social media content
- [ ] get_subreddit_analysis() provides detailed subreddit activity and ticker analysis
- [ ] All methods return human-readable formatted strings for AI agent consumption
- [ ] Proper error handling with fallback responses
- [ ] Methods integrate seamlessly with existing AgentToolkit patterns
- [ ] Test coverage with mocked service dependencies

**Dependencies:** Task 2.4 (SocialMediaService implementation)
**Risk:** Low - Standard AgentToolkit integration patterns

---

### Task 3.2: Dagster Pipeline Implementation (2 hours)
**Priority: Medium** | **Agent: Pipeline Specialist**

Implement Dagster assets for scheduled social media collection and processing.

**File:** `tradingagents/data/assets/social_media.py`

**Implementation:**
```python
from dagster import asset, AssetExecutionContext, Config, DailyPartitionsDefinition
from typing import Dict, Any, List
from datetime import datetime

from tradingagents.domains.socialmedia.services import SocialMediaService
from tradingagents.config import TradingAgentsConfig
from tradingagents.database import DatabaseManager

class SocialMediaCollectionConfig(Config):
    """Configuration for social media collection"""
    subreddits: List[str] = [
        'wallstreetbets', 'investing', 'stocks', 'SecurityAnalysis',
        'ValueInvesting', 'StockMarket', 'options'
    ]
    time_filter: str = 'day'
    process_sentiment: bool = True
    generate_embeddings: bool = True
    max_posts_per_subreddit: int = 50
    cleanup_old_data: bool = True
    retention_days: int = 90

@asset(
    partitions_def=DailyPartitionsDefinition(start_date="2024-01-01"),
    group_name="social_media",
    description="Daily collection of Reddit posts from financial subreddits with sentiment analysis and embeddings",
    compute_kind="python",
    tags={"domain": "socialmedia", "source": "reddit"}
)
async def reddit_financial_posts(
    context: AssetExecutionContext,
    config: SocialMediaCollectionConfig
) -> Dict[str, Any]:
    """Daily collection and processing of Reddit financial posts"""

    partition_date = context.partition_key
    context.log.info(f"Starting social media collection for partition: {partition_date}")

    # Initialize services
    trading_config = TradingAgentsConfig.from_env()
    db_manager = DatabaseManager(trading_config)
    social_service = SocialMediaService(trading_config, db_manager)

    collection_start = datetime.now()

    try:
        # Main collection and processing
        results = await social_service.collect_and_process_posts(
            subreddits=config.subreddits,
            time_filter=config.time_filter,
            process_sentiment=config.process_sentiment,
            generate_embeddings=config.generate_embeddings
        )

        # Log detailed results
        context.log.info("Collection completed successfully:")
        context.log.info(f" - Total posts collected: {results['total_posts_collected']}")
        context.log.info(f" - Posts processed: {results['posts_processed']}")
        context.log.info(f" - Posts saved: {results['posts_saved']}")
        context.log.info(f" - Processing time: {results['processing_time_seconds']:.2f}s")

        # Log per-subreddit breakdown
        for subreddit, breakdown in results['subreddit_breakdown'].items():
            context.log.info(f" - r/{subreddit}: {breakdown['posts_collected']} collected, "
                             f"{breakdown['posts_filtered']} after filtering")

        # Data quality check
        if results['posts_saved'] == 0:
            context.log.warning("No posts were saved - possible data quality issues")
        elif results['posts_saved'] < results['posts_processed'] * 0.5:
            context.log.warning(f"Low save rate: {results['posts_saved']}/{results['posts_processed']} posts saved")

        # Cleanup old data if configured
        if config.cleanup_old_data:
            try:
                deleted_count = await social_service.repository.cleanup_old_posts(
                    days=config.retention_days
                )
                context.log.info(f"Cleaned up {deleted_count} posts older than {config.retention_days} days")
                results['cleanup_deleted_count'] = deleted_count
            except Exception as e:
                context.log.error(f"Cleanup failed: {e}")
                results['cleanup_error'] = str(e)

        # Add partition metadata
        results.update({
            'partition_date': partition_date,
            'asset_name': 'reddit_financial_posts',
            'collection_success': True
        })

        return results

    except Exception as e:
        context.log.error(f"Social media collection failed: {e}")

        # Return error results for monitoring
        return {
            'partition_date': partition_date,
            'asset_name': 'reddit_financial_posts',
            'collection_success': False,
            'error_message': str(e),
            'processing_time_seconds': (datetime.now() - collection_start).total_seconds(),
            'total_posts_collected': 0,
            'posts_processed': 0,
            'posts_saved': 0
        }

    finally:
        # Always close database connections
        if 'db_manager' in locals():
            await db_manager.close_all()

@asset(
    deps=[reddit_financial_posts],
    group_name="social_media",
    description="Generate daily social media analytics and trending analysis",
    compute_kind="python",
    tags={"domain": "socialmedia", "analytics": "trending"}
)
async def social_media_analytics(context: AssetExecutionContext) -> Dict[str, Any]:
    """Generate analytics and trending analysis from collected social media data"""

    context.log.info("Generating social media analytics")

    # Initialize services
    trading_config = TradingAgentsConfig.from_env()
    db_manager = DatabaseManager(trading_config)
    social_service = SocialMediaService(trading_config, db_manager)

    try:
        # Get trending tickers analysis
        trending_tickers = await social_service.repository.get_trending_tickers(
            hours=24,
            min_mentions=5
        )

        context.log.info(f"Found {len(trending_tickers)} trending tickers")

        # Analyze top subreddits
        financial_subreddits = [
            'wallstreetbets', 'investing', 'stocks', 'SecurityAnalysis',
            'ValueInvesting', 'StockMarket'
        ]

        subreddit_analysis = {}
        for subreddit in financial_subreddits:
            analysis = await social_service.get_subreddit_analysis(subreddit, hours=24)
            subreddit_analysis[subreddit] = analysis

            if analysis['total_posts'] > 0:
                context.log.info(f"r/{subreddit}: {analysis['total_posts']} posts, "
                                 f"{analysis['ticker_mentions']['unique_tickers']} unique tickers")

        # Calculate overall sentiment trends
        overall_sentiment_summary = {}
        for ticker_info in trending_tickers[:10]:  # Top 10 trending
            ticker = ticker_info['ticker']
            sentiment_data = await social_service.repository.get_sentiment_summary(
                ticker=ticker,
                hours=24
            )
            overall_sentiment_summary[ticker] = sentiment_data

        analytics_results = {
            'generated_at': datetime.now().isoformat(),
            'period_hours': 24,
            'trending_tickers': trending_tickers,
            'subreddit_analysis': subreddit_analysis,
            'sentiment_trends': overall_sentiment_summary,
            'analytics_success': True
        }

        # Log key insights
        if trending_tickers:
            top_ticker = trending_tickers[0]
            context.log.info(f"Most trending ticker: ${top_ticker['ticker']} "
                             f"({top_ticker['mention_count']} mentions, "
                             f"{top_ticker['avg_sentiment']:.2f} sentiment)")

        return analytics_results

    except Exception as e:
        context.log.error(f"Analytics generation failed: {e}")
        return {
            'generated_at': datetime.now().isoformat(),
            'analytics_success': False,
            'error_message': str(e)
        }

    finally:
        if 'db_manager' in locals():
            await db_manager.close_all()

@asset(
    deps=[social_media_analytics],
    group_name="social_media",
    description="Data quality monitoring and validation for social media pipeline",
    compute_kind="python",
    tags={"domain": "socialmedia", "monitoring": "data_quality"}
)
async def social_media_quality_check(context: AssetExecutionContext) -> Dict[str, Any]:
    """Monitor data quality and pipeline health for social media assets"""

    context.log.info("Performing social media data quality checks")

    trading_config = TradingAgentsConfig.from_env()
    db_manager = DatabaseManager(trading_config)
    social_service = SocialMediaService(trading_config, db_manager)

    try:
        # Check recent data volume
        recent_posts = await social_service.repository.find_by_subreddit(
            'wallstreetbets',  # Use as representative subreddit
            hours=24,
            limit=1000
        )

        # Quality metrics
        total_posts = len(recent_posts)
        posts_with_sentiment = sum(1 for p in recent_posts if p.sentiment_label is not None)
        posts_with_embeddings = sum(
            1 for p in recent_posts
            if p.title_embedding is not None or p.content_embedding is not None
        )
        posts_with_tickers = sum(1 for p in recent_posts if p.tickers)

        # Calculate quality percentages
        sentiment_coverage = posts_with_sentiment / total_posts if total_posts > 0 else 0
        embedding_coverage = posts_with_embeddings / total_posts if total_posts > 0 else 0
        ticker_coverage = posts_with_tickers / total_posts if total_posts > 0 else 0

        # Quality thresholds
        quality_checks = {
            'data_volume_check': total_posts >= 100,  # Expect at least 100 posts per day
            'sentiment_coverage_check': sentiment_coverage >= 0.8,  # 80% should have sentiment
            'embedding_coverage_check': embedding_coverage >= 0.7,  # 70% should have embeddings
            'ticker_coverage_check': ticker_coverage >= 0.3  # 30% should have ticker mentions
        }

        overall_health = all(quality_checks.values())

        # Log quality results
        context.log.info("Data quality assessment:")
        context.log.info(f" - Total posts (24h): {total_posts}")
        context.log.info(f" - Sentiment coverage: {sentiment_coverage:.1%}")
        context.log.info(f" - Embedding coverage: {embedding_coverage:.1%}")
        context.log.info(f" - Ticker coverage: {ticker_coverage:.1%}")
        context.log.info(f" - Overall health: {'PASS' if overall_health else 'FAIL'}")

        # Alert on quality issues
        for check_name, passed in quality_checks.items():
            if not passed:
                context.log.warning(f"Quality check failed: {check_name}")

        return {
            'check_timestamp': datetime.now().isoformat(),
            'total_posts_24h': total_posts,
            'quality_metrics': {
                'sentiment_coverage': sentiment_coverage,
                'embedding_coverage': embedding_coverage,
                'ticker_coverage': ticker_coverage
            },
            'quality_checks': quality_checks,
            'overall_health': overall_health,
            'quality_check_success': True
        }

    except Exception as e:
        context.log.error(f"Quality check failed: {e}")
        return {
            'check_timestamp': datetime.now().isoformat(),
            'quality_check_success': False,
            'error_message': str(e)
        }

    finally:
        if 'db_manager' in locals():
            await db_manager.close_all()

# Schedule configuration for the social media pipeline
SOCIAL_MEDIA_SCHEDULE = {
    "reddit_financial_posts": "0 6,18 * * *",     # 6 AM and 6 PM UTC daily
    "social_media_analytics": "30 7,19 * * *",    # 90 minutes after each collection run
    "social_media_quality_check": "0 8,20 * * *"  # 2 hours after each collection run
}
```
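
**Schedule wiring (illustrative):** The cron map above is declarative only; below is a hedged sketch of one way to register it, assuming Dagster's standard `define_asset_job`, `ScheduleDefinition`, and `build_schedule_from_partitioned_job` APIs. The job and schedule names are illustrative, and the twice-daily collection cadence is simplified to a single daily partitioned run:

```python
# Illustrative wiring only: job/schedule names are not part of the codebase,
# and the twice-daily collection cadence is simplified to one partitioned run.
from dagster import (
    AssetSelection,
    DailyPartitionsDefinition,
    Definitions,
    ScheduleDefinition,
    build_schedule_from_partitioned_job,
    define_asset_job,
)

# The collection asset is daily-partitioned, so its schedule is derived from
# the partitions definition rather than a raw cron string.
collection_schedule = build_schedule_from_partitioned_job(
    define_asset_job(
        name="reddit_financial_posts_job",
        selection=AssetSelection.assets("reddit_financial_posts"),
        partitions_def=DailyPartitionsDefinition(start_date="2024-01-01"),
    ),
    hour_of_day=6,
)

# Downstream analytics and quality checks can share one cron-driven job.
analytics_schedule = ScheduleDefinition(
    name="social_media_analytics_schedule",
    job=define_asset_job(
        name="social_media_analytics_job",
        selection=AssetSelection.assets("social_media_analytics", "social_media_quality_check"),
    ),
    cron_schedule=SOCIAL_MEDIA_SCHEDULE["social_media_analytics"],
)

defs = Definitions(
    assets=[reddit_financial_posts, social_media_analytics, social_media_quality_check],
    schedules=[collection_schedule, analytics_schedule],
)
```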

**Acceptance Criteria:**
- [ ] Daily scheduled collection from financial subreddits
- [ ] Sentiment analysis and embedding generation in pipeline
- [ ] Analytics generation with trending ticker analysis
- [ ] Data quality monitoring with configurable thresholds
- [ ] Proper error handling and logging throughout pipeline
- [ ] Cleanup of old data based on retention policies
- [ ] Integration with existing Dagster infrastructure
- [ ] Monitoring and alerting on pipeline failures

**Dependencies:** Task 2.4 (SocialMediaService)
**Risk:** Low - Standard Dagster asset patterns

---

### Task 3.3: Comprehensive Testing Suite (3 hours)
**Priority: High** | **Agent: Testing Specialist**

Implement a comprehensive test suite covering all socialmedia domain components with >85% coverage.

**Test Structure:**
```
tests/domains/socialmedia/
├── conftest.py                  # Fixtures and test configuration
├── test_entities.py             # SQLAlchemy entity tests
├── test_models.py               # Domain model validation tests
├── test_reddit_client.py        # API integration with VCR
├── test_sentiment_analyzer.py   # LLM sentiment analysis
├── test_embedding_generator.py  # Vector embedding generation
├── test_social_repository.py    # Database operations
├── test_social_service.py       # Service orchestration
├── test_agent_toolkit.py        # AgentToolkit integration
├── test_dagster_assets.py       # Pipeline testing
└── fixtures/
    ├── reddit_responses.yaml    # VCR cassettes
    ├── sample_posts.json        # Test data
    └── embeddings.json          # Sample embeddings
```

**Implementation Samples:**

**conftest.py:**
```python
import pytest
import asyncio
from datetime import datetime
from unittest.mock import MagicMock, AsyncMock
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from tradingagents.config import TradingAgentsConfig
from tradingagents.domains.socialmedia.entities import SocialMediaPostEntity
from tradingagents.domains.socialmedia.models import SocialPost
from tradingagents.domains.socialmedia.repositories import SocialRepository
from tradingagents.domains.socialmedia.services import SocialMediaService

@pytest.fixture(scope="session")
def event_loop():
    """Create event loop for async tests"""
    loop = asyncio.new_event_loop()
    yield loop
    loop.close()

@pytest.fixture
def test_config():
    """Test configuration"""
    return TradingAgentsConfig(
        reddit_client_id="test_client_id",
        reddit_client_secret="test_secret",
        reddit_user_agent="test_agent",
        openrouter_api_key="test_openrouter_key",
        quick_think_llm="test/model",
        database_url="sqlite:///test.db"
    )

@pytest.fixture
def db_session(test_config):
    """Test database session"""
    engine = create_engine(test_config.database_url, echo=False)
    SocialMediaPostEntity.metadata.create_all(engine)
    SessionLocal = sessionmaker(bind=engine)
    session = SessionLocal()
    yield session
    session.close()
    SocialMediaPostEntity.metadata.drop_all(engine)

@pytest.fixture
def social_repository(db_session):
    """Repository under test (constructor signature assumed for this sample)"""
    return SocialRepository(db_session)

@pytest.fixture
def sample_social_post():
    """Sample SocialPost for testing"""
    return SocialPost(
        post_id="test123",
        title="AAPL to the moon! 🚀",
        content="Apple stock is going to explode higher after earnings!",
        author="test_user",
        subreddit="wallstreetbets",
        created_utc=datetime(2024, 1, 15, 10, 0, 0),
        upvotes=150,
        downvotes=25,
        comments_count=45,
        url="https://reddit.com/r/wallstreetbets/test123",
        tickers=["AAPL"],
        sentiment_score=0.8,
        sentiment_label="positive",
        sentiment_confidence=0.9
    )

@pytest.fixture
def mock_social_service(test_config):
    """Mocked SocialMediaService"""
    service = MagicMock(spec=SocialMediaService)
    service.config = test_config
    service.repository = AsyncMock()
    service.sentiment_analyzer = AsyncMock()
    service.embedding_generator = AsyncMock()
    return service
```

**test_models.py:**
```python
import pytest
from datetime import datetime
from tradingagents.domains.socialmedia.models import SocialPost, SentimentScore

def test_social_post_validation():
    """Test SocialPost validation rules"""
    # Valid post
    post = SocialPost(
        post_id="abc123",
        title="Test post",
        author="test_user",
        subreddit="stocks",
        created_utc=datetime.now(),
        upvotes=10,
        downvotes=2,
        comments_count=5,
        url="https://reddit.com/test"
    )
    assert post.post_id == "abc123"
    assert post.tickers == []

def test_extract_tickers():
    """Test ticker extraction from post content"""
    post = SocialPost(
        post_id="abc123",
        title="AAPL and $TSLA are great buys",
        content="I think MSFT will outperform this year",
        author="test_user",
        subreddit="investing",
        created_utc=datetime.now(),
        upvotes=10,
        downvotes=0,
        comments_count=3,
        url="https://reddit.com/test"
    )

    tickers = post.extract_tickers()
    assert "AAPL" in tickers
    assert "TSLA" in tickers
    assert "MSFT" in tickers
    assert len(tickers) == 3

def test_sentiment_validation():
    """Test sentiment score validation"""
    # Valid sentiment
    sentiment = SentimentScore(
        sentiment="positive",
        confidence=0.85,
        reasoning="Bullish language and positive outlook"
    )
    assert sentiment.confidence == 0.85

    # Invalid confidence
    with pytest.raises(ValueError):
        SentimentScore(
            sentiment="positive",
            confidence=1.5  # > 1.0
        )

@pytest.mark.parametrize("sentiment_score,sentiment_label,confidence,expected_reliable", [
    (0.8, "positive", 0.9, True),
    (0.3, "neutral", 0.4, False),
    (-0.6, "negative", 0.7, True),
    (None, None, None, False)
])
def test_has_reliable_sentiment(sentiment_score, sentiment_label, confidence, expected_reliable):
    """Test sentiment reliability check"""
    post = SocialPost(
        post_id="test",
        title="Test",
        author="user",
        subreddit="test",
        created_utc=datetime.now(),
        upvotes=1,
        downvotes=0,
        comments_count=0,
        url="test",
        sentiment_score=sentiment_score,
        sentiment_label=sentiment_label,
        sentiment_confidence=confidence
    )

    assert post.has_reliable_sentiment() == expected_reliable
```

**test_social_repository.py:**
```python
import pytest
from datetime import datetime, timedelta

@pytest.mark.asyncio
async def test_upsert_batch_deduplication(social_repository, sample_social_post):
    """Test batch upsert with deduplication"""
    posts = [sample_social_post, sample_social_post]  # Duplicate posts

    saved_ids = await social_repository.upsert_batch(posts)

    assert len(saved_ids) == 1  # Only one saved due to deduplication
    assert saved_ids[0] == sample_social_post.post_id

@pytest.mark.asyncio
async def test_find_by_ticker(social_repository, sample_social_post):
    """Test finding posts by ticker symbol"""
    await social_repository.upsert_batch([sample_social_post])

    posts = await social_repository.find_by_ticker("AAPL", days=7)

    assert len(posts) == 1
    assert posts[0].post_id == sample_social_post.post_id
    assert "AAPL" in posts[0].tickers

@pytest.mark.asyncio
async def test_vector_similarity_search(social_repository, sample_social_post):
    """Test vector similarity search"""
    # Add post with embedding
    sample_social_post.title_embedding = [0.1] * 1536  # Mock embedding
    await social_repository.upsert_batch([sample_social_post])

    # Search with similar embedding
    query_embedding = [0.1] * 1536
    results = await social_repository.find_similar_posts(
        query_embedding=query_embedding,
        limit=5
    )

    assert len(results) >= 0  # May be empty if similarity too low
    if results:
        post, similarity = results[0]
        assert isinstance(similarity, float)
        assert 0 <= similarity <= 1

@pytest.mark.asyncio
async def test_sentiment_summary(social_repository, sample_social_post):
    """Test sentiment aggregation"""
    await social_repository.upsert_batch([sample_social_post])

    summary = await social_repository.get_sentiment_summary(
        ticker="AAPL",
        hours=24
    )

    assert summary['ticker'] == "AAPL"
    assert summary['total_posts'] >= 0
    assert 'sentiment_breakdown' in summary
    assert 'overall_sentiment' in summary

@pytest.mark.asyncio
async def test_cleanup_old_posts(social_repository, sample_social_post):
    """Test cleanup of old posts"""
    # Create old post
    old_post = sample_social_post.copy()
    old_post.post_id = "old_post"
    old_post.created_utc = datetime.now() - timedelta(days=100)

    await social_repository.upsert_batch([old_post])

    deleted_count = await social_repository.cleanup_old_posts(days=90)

    assert deleted_count >= 1
```
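
**test_social_service.py (sketch):** Listed in the structure above but not sampled; a hedged sketch against the service code from Task 2.4 (the `AsyncMock` stand-in for the `db_manager` constructor argument is an assumption):

```python
# Sketch for test_social_service.py; the AsyncMock db_manager is an
# assumption, and the tested branches come from the Task 2.4 service code.
import pytest
from unittest.mock import AsyncMock

from tradingagents.domains.socialmedia.services import SocialMediaService

@pytest.mark.asyncio
async def test_search_posts_semantic_returns_empty_on_embedding_failure(test_config):
    """If no query embedding can be generated, the service short-circuits"""
    service = SocialMediaService(test_config, AsyncMock())  # db_manager mocked
    service.embedding_generator = AsyncMock()
    service.embedding_generator.generate_query_embedding.return_value = None
    service.repository = AsyncMock()

    results = await service.search_posts_semantic(query="rate hikes")

    assert results == []
    service.repository.find_similar_posts.assert_not_called()

def test_calculate_data_quality_handles_empty_input(test_config):
    """An empty post list should yield a zero overall score"""
    service = SocialMediaService(test_config, AsyncMock())
    assert service._calculate_data_quality([]) == {'overall_score': 0.0}
```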

**test_reddit_client.py (with VCR):**
```python
import pytest
import vcr
from tradingagents.domains.socialmedia.clients import RedditClient

@vcr.use_cassette('fixtures/reddit_fetch_posts.yaml')
@pytest.mark.asyncio
async def test_fetch_subreddit_posts(test_config):
    """Test fetching posts from Reddit API"""
    async with RedditClient(test_config) as client:
        posts = await client.fetch_subreddit_posts(
            subreddit_name="wallstreetbets",
            limit=10
        )

        assert len(posts) > 0
        for post in posts:
            assert 'post_id' in post
            assert 'title' in post
            assert 'subreddit' in post
            assert post['subreddit'] == 'wallstreetbets'

@vcr.use_cassette('fixtures/reddit_search.yaml')
@pytest.mark.asyncio
async def test_search_posts(test_config):
    """Test Reddit post search functionality"""
    async with RedditClient(test_config) as client:
        posts = await client.search_posts(
            query="AAPL",
            subreddit_names=["investing"],
            limit=5
        )

        assert isinstance(posts, list)
        if posts:  # May be empty in test
            for post in posts:
                assert 'post_id' in post
                assert 'title' in post

@pytest.mark.asyncio
async def test_health_check(test_config):
    """Test Reddit API health check"""
    async with RedditClient(test_config) as client:
        health = await client.health_check()
        assert isinstance(health, bool)
```
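
**test_dagster_assets.py (sketch):** A hedged sketch using Dagster's direct-invocation testing style; `build_asset_context` usage and the patch targets are assumptions derived from the asset code in Task 3.2:

```python
# Sketch for test_dagster_assets.py; patch targets follow the Task 3.2 asset
# module, and direct invocation of the async asset is an assumption.
import pytest
from unittest.mock import AsyncMock, patch
from dagster import build_asset_context

from tradingagents.data.assets.social_media import (
    SocialMediaCollectionConfig,
    reddit_financial_posts,
)

@pytest.mark.asyncio
async def test_reddit_financial_posts_reports_failure():
    """The asset should return an error payload instead of raising"""
    context = build_asset_context(partition_key="2024-01-15")
    config = SocialMediaCollectionConfig(process_sentiment=False, generate_embeddings=False)

    with patch("tradingagents.data.assets.social_media.TradingAgentsConfig"), \
         patch("tradingagents.data.assets.social_media.DatabaseManager", return_value=AsyncMock()), \
         patch("tradingagents.data.assets.social_media.SocialMediaService") as service_cls:
        service_cls.return_value.collect_and_process_posts = AsyncMock(
            side_effect=RuntimeError("reddit down")
        )
        result = await reddit_financial_posts(context, config)

    assert result["collection_success"] is False
    assert result["posts_saved"] == 0
```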

**Acceptance Criteria:**
- [ ] >85% test coverage across all socialmedia domain components
- [ ] Unit tests for all domain models with validation edge cases
- [ ] Integration tests for Reddit API client with VCR cassettes
- [ ] Repository tests with real PostgreSQL database operations
- [ ] Service layer tests with proper mocking of dependencies
- [ ] AgentToolkit integration tests
- [ ] Dagster pipeline asset tests with mocked data
- [ ] Performance benchmarks for vector similarity queries
- [ ] Error handling and edge case coverage
- [ ] Test fixtures and sample data for consistent testing

**Dependencies:** All implementation tasks
**Risk:** Low - Standard testing patterns

---

## Implementation Dependencies & Parallel Execution

### Phase 1 Dependencies
- Task 1.1 → Task 1.2 (Entity depends on database schema)
- Task 1.3 can run in parallel with 1.1 and 1.2
- Task 1.4 depends on 1.1 and 1.2

### Phase 2 Dependencies
- Tasks 2.1, 2.2, and 2.3 can run in parallel
- Task 2.4 depends on 2.1, 2.2, and 2.3

### Phase 3 Dependencies
- Task 3.1 depends on Task 2.4
- Task 3.2 depends on Task 2.4
- Task 3.3 can start as soon as any component is complete

### Risk Assessment

**High Risk Tasks:**
- Task 2.1 (Reddit Client) - External API complexity, rate limiting

**Medium Risk Tasks:**
- Task 1.1 (Database Migration) - Extension dependencies
- Task 1.4 (Repository) - Complex vector queries
- Task 2.2 (Sentiment Analysis) - LLM API reliability
- Task 2.4 (Service Layer) - Complex orchestration

**Low Risk Tasks:**
- Task 1.2 (Entity Implementation)
- Task 1.3 (Domain Models)
- Task 2.3 (Embedding Generation)
- Task 3.1 (AgentToolkit Integration)
- Task 3.2 (Dagster Pipeline)
- Task 3.3 (Testing Suite)

## Success Criteria Summary

### Functionality
- ✅ Complete Reddit data collection with PRAW integration
- ✅ OpenRouter LLM sentiment analysis with confidence scoring
- ✅ Vector embeddings for semantic similarity search
- ✅ PostgreSQL + TimescaleDB + pgvectorscale data persistence
- ✅ AgentToolkit RAG methods for AI agent integration
- ✅ Daily Dagster pipeline for automated collection
- ✅ Comprehensive error handling and resilience

### Performance
- ✅ <2 second social context queries for AI agents
- ✅ <1 second vector similarity search for the top 10 results (see the query sketch below)
- ✅ <5 seconds to batch-process 1,000 posts
- ✅ Efficient TimescaleDB time-series queries
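
A representative shape for the similarity lookup behind the <1 second target (a sketch assuming pgvector's cosine-distance operator `<=>` against the Phase 1 schema; `$1` is the query embedding parameter):

```sql
-- Representative similarity lookup; assumes the pgvector cosine-distance
-- operator (<=>) and the embedding indexes from the Phase 1 migration.
SELECT post_id,
       title,
       subreddit,
       1 - (title_embedding <=> $1) AS similarity
FROM social_media_posts
WHERE title_embedding IS NOT NULL
  AND created_utc > NOW() - INTERVAL '30 days'
ORDER BY title_embedding <=> $1
LIMIT 10;
```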

### Quality
- ✅ >85% test coverage across all components
- ✅ Data quality monitoring and validation
- ✅ Comprehensive logging and observability
- ✅ Best-effort processing with graceful degradation

### Integration
- ✅ Seamless integration with existing TradingAgents architecture
- ✅ Follows news domain patterns for consistency
- ✅ Compatible with multi-agent trading workflows
- ✅ Production-ready deployment capability

This comprehensive task breakdown enables efficient parallel development by multiple AI agents while ensuring complete coverage of the socialmedia domain implementation requirements.