# TradingAgents/tradingagents/dataflows/ticker_semantic_db.py
"""
Ticker Semantic Database
------------------------
Creates and maintains a database of ticker descriptions with embeddings
for semantic matching against news events.
This enables news-driven discovery by finding tickers semantically related
to breaking news, rather than waiting for social media buzz or price action.
"""
import json
import os
from datetime import datetime
from typing import Any, Dict, List, Optional
import chromadb
from dotenv import load_dotenv
from openai import OpenAI
from tqdm import tqdm
from tradingagents.dataflows.y_finance import get_ticker_info
from tradingagents.utils.logger import get_logger
# Load environment variables
load_dotenv()
logger = get_logger(__name__)
class TickerSemanticDB:
    """Manages ticker descriptions and embeddings for semantic search.

    Each ticker is stored in a persistent ChromaDB collection with its
    business description as the document, an embedding of that description,
    and fundamental metadata (sector, market cap, ...) usable as filters.
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize the ticker semantic database.

        Args:
            config: Configuration dict with:
                - project_dir: Base directory for storage
                - use_openai_embeddings: If True, use OpenAI; else use local HF model
                - embedding_model: Model name (default: text-embedding-3-small)

        Raises:
            ValueError: If OpenAI embeddings are requested but OPENAI_API_KEY
                is not set in the environment.
        """
        self.config = config
        self.use_openai = config.get("use_openai_embeddings", True)

        # Setup embedding backend
        if self.use_openai:
            self.embedding_model = config.get("embedding_model", "text-embedding-3-small")
            openai_api_key = os.getenv("OPENAI_API_KEY")
            if not openai_api_key:
                raise ValueError("OPENAI_API_KEY not found in environment")
            self.openai_client = OpenAI(api_key=openai_api_key)
            self.embedding_dim = 1536  # OpenAI text-embedding-3-small dimension
        else:
            # Local sentence-transformers backend; imported lazily so the
            # package is only required when local embeddings are selected.
            from sentence_transformers import SentenceTransformer

            self.embedding_model = config.get("embedding_model", "BAAI/bge-small-en-v1.5")
            self.local_model = SentenceTransformer(self.embedding_model)
            self.embedding_dim = self.local_model.get_sentence_embedding_dimension()

        # Setup ChromaDB for persistent storage. One directory per embedding
        # model so vectors produced by different models never mix.
        project_dir = config.get("project_dir", ".")
        embedding_model_safe = self.embedding_model.replace("/", "_").replace(" ", "_")
        db_dir = os.path.join(project_dir, "ticker_semantic_db", embedding_model_safe)
        os.makedirs(db_dir, exist_ok=True)
        self.chroma_client = chromadb.PersistentClient(path=db_dir)

        # get_or_create_collection is atomic and replaces the previous broad
        # try/except-around-get_collection, which could mask real errors
        # (e.g. a corrupted store) by silently creating a fresh collection.
        self.collection = self.chroma_client.get_or_create_collection(
            name="ticker_descriptions",
            metadata={"description": "Ticker descriptions with metadata for semantic search"},
        )
        logger.info(f"Ticker database ready: {self.collection.count()} tickers")

    def get_embedding(self, text: str) -> List[float]:
        """Generate an embedding vector for *text* using the configured backend."""
        if self.use_openai:
            response = self.openai_client.embeddings.create(model=self.embedding_model, input=text)
            return response.data[0].embedding
        # Local HuggingFace model returns a numpy array; convert to a plain
        # list so it can be handed to ChromaDB.
        embedding = self.local_model.encode(text, convert_to_numpy=True)
        return embedding.tolist()

    def fetch_ticker_info(self, symbol: str) -> Optional[Dict[str, Any]]:
        """
        Fetch ticker information from Yahoo Finance.

        Args:
            symbol: Stock ticker symbol

        Returns:
            Dict with ticker metadata or None if the fetch fails
        """
        try:
            info = get_ticker_info(symbol)

            # Prefer the long business summary; fall back to the short one.
            description = info.get("longBusinessSummary", "")
            if not description:
                description = info.get("description", f"{symbol} - No description available")

            # yfinance frequently returns an explicit None for fields that are
            # present but empty, so `info.get(key, default)` can still yield
            # None. `or <default>` covers both missing keys and None values,
            # keeping every metadata value ChromaDB-compatible (ChromaDB
            # rejects None metadata values on upsert).
            ticker_data = {
                "symbol": symbol.upper(),
                "name": info.get("longName") or info.get("shortName") or symbol,
                "description": description,
                "industry": info.get("industry") or "Unknown",
                "sector": info.get("sector") or "Unknown",
                "market_cap": info.get("marketCap") or 0,
                "revenue": info.get("totalRevenue") or 0,
                "country": info.get("country") or "US",
                "website": info.get("website") or "",
                "employees": info.get("fullTimeEmployees") or 0,
                "last_updated": datetime.now().isoformat(),
            }
            return ticker_data
        except Exception as e:
            # Best-effort: one bad symbol must not abort a bulk build.
            logger.warning(f"Error fetching {symbol}: {e}")
            return None

    def add_ticker(self, symbol: str, force_refresh: bool = False) -> bool:
        """
        Add a single ticker to the database.

        Args:
            symbol: Stock ticker symbol
            force_refresh: If True, refresh even if ticker exists

        Returns:
            True if added successfully (or already present), False otherwise
        """
        # Skip the network round-trip if the ticker is already stored.
        if not force_refresh:
            try:
                existing = self.collection.get(ids=[symbol.upper()])
                if existing and existing["ids"]:
                    return True  # Already exists
            except Exception:
                # Treat lookup failures as "not present" and attempt the add.
                pass

        # Fetch ticker info
        ticker_data = self.fetch_ticker_info(symbol)
        if not ticker_data:
            return False

        # Generate embedding from the business description.
        try:
            embedding = self.get_embedding(ticker_data["description"])
        except Exception as e:
            logger.error(f"Error generating embedding for {symbol}: {e}")
            return False

        # Store in ChromaDB: description as document, fundamentals as metadata.
        try:
            metadata_keys = (
                "symbol",
                "name",
                "industry",
                "sector",
                "market_cap",
                "revenue",
                "country",
                "website",
                "employees",
                "last_updated",
            )
            self.collection.upsert(
                ids=[symbol.upper()],
                documents=[ticker_data["description"]],
                embeddings=[embedding],
                metadatas=[{key: ticker_data[key] for key in metadata_keys}],
            )
            return True
        except Exception as e:
            logger.error(f"Error storing {symbol}: {e}")
            return False

    def build_database(
        self,
        ticker_file: str,
        max_tickers: Optional[int] = None,
        skip_existing: bool = True,
        batch_size: int = 100,
    ):
        """
        Build the ticker database from a file.

        Args:
            ticker_file: Path to file with ticker symbols (one per line)
            max_tickers: Maximum number of tickers to process (None = all)
            skip_existing: If True, skip tickers already in DB
            batch_size: Unused; retained for backward compatibility with
                existing callers.
        """
        # Read ticker file (blank lines are ignored, symbols upper-cased).
        with open(ticker_file, "r", encoding="utf-8") as f:
            tickers = [line.strip().upper() for line in f if line.strip()]
        if max_tickers:
            tickers = tickers[:max_tickers]

        logger.info("Building ticker semantic database...")
        logger.info(f"Source: {ticker_file}")
        logger.info(f"Total tickers: {len(tickers)}")
        logger.info(f"Embedding model: {self.embedding_model}")

        # Snapshot existing IDs once up front so the skip check is O(1)
        # instead of one DB round-trip per ticker.
        existing_tickers = set()
        if skip_existing:
            try:
                existing = self.collection.get(include=[])  # IDs only
                existing_tickers = set(existing["ids"])
                logger.info(f"Existing tickers in DB: {len(existing_tickers)}")
            except Exception:
                # Best-effort: fall back to add_ticker's own existence check.
                pass

        # Process tickers
        success_count = 0
        skip_count = 0
        fail_count = 0
        for symbol in tqdm(tickers, desc="Processing tickers"):
            if skip_existing and symbol in existing_tickers:
                skip_count += 1
                continue
            if self.add_ticker(symbol, force_refresh=not skip_existing):
                success_count += 1
            else:
                fail_count += 1

        logger.info("Database build complete!")
        logger.info(f"Success: {success_count}")
        logger.info(f"Skipped: {skip_count}")
        logger.info(f"Failed: {fail_count}")
        logger.info(f"Total in DB: {self.collection.count()}")

    def search_by_text(
        self, query_text: str, top_k: int = 10, filters: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """
        Search for tickers semantically related to query text.

        Args:
            query_text: Text to search for (e.g., news summary)
            top_k: Number of top matches to return
            filters: Optional metadata filters (e.g., {"sector": "Technology"})

        Returns:
            List of ticker matches with metadata and similarity scores,
            ordered from most to least similar.
        """
        # Generate embedding for query
        query_embedding = self.get_embedding(query_text)

        # Search ChromaDB
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k,
            where=filters,  # Apply metadata filters if provided
            include=["documents", "metadatas", "distances"],
        )

        # Format results; Chroma nests per-query lists, hence the [0] index.
        matches = []
        for i in range(len(results["ids"][0])):
            distance = results["distances"][0][i]
            # Map distance in [0, inf) to a similarity score in (0, 1].
            similarity = 1 / (1 + distance)
            matches.append(
                {
                    "symbol": results["ids"][0][i],
                    "description": results["documents"][0][i],
                    "metadata": results["metadatas"][0][i],
                    "similarity_score": similarity,
                }
            )
        return matches

    def get_ticker_info(self, symbol: str) -> Optional[Dict[str, Any]]:
        """Get stored information for a specific ticker, or None if absent."""
        try:
            result = self.collection.get(ids=[symbol.upper()], include=["documents", "metadatas"])
            if not result["ids"]:
                return None
            return {
                "symbol": result["ids"][0],
                "description": result["documents"][0],
                "metadata": result["metadatas"][0],
            }
        except Exception:
            return None

    def get_stats(self) -> Dict[str, Any]:
        """Get database statistics: totals plus sector/industry breakdowns."""
        try:
            count = self.collection.count()

            # Tally sector/industry frequencies across all stored tickers.
            all_data = self.collection.get(include=["metadatas"])
            sectors: Dict[str, int] = {}
            industries: Dict[str, int] = {}
            for metadata in all_data["metadatas"]:
                sector = metadata.get("sector", "Unknown")
                industry = metadata.get("industry", "Unknown")
                sectors[sector] = sectors.get(sector, 0) + 1
                industries[industry] = industries.get(industry, 0) + 1

            return {
                "total_tickers": count,
                "sectors": sectors,
                "industries": industries,
                "embedding_model": self.embedding_model,
                "embedding_dimension": self.embedding_dim,
            }
        except Exception as e:
            # Stats are informational only; report rather than raise.
            return {"error": str(e)}
def main():
    """CLI for building/managing the ticker database."""
    import argparse

    arg_parser = argparse.ArgumentParser(description="Build ticker semantic database")
    arg_parser.add_argument("--ticker-file", default="data/tickers.txt", help="Path to ticker file")
    arg_parser.add_argument(
        "--max-tickers",
        type=int,
        default=None,
        help="Maximum tickers to process (default: all)",
    )
    arg_parser.add_argument(
        "--use-local",
        action="store_true",
        help="Use local HuggingFace embeddings instead of OpenAI",
    )
    arg_parser.add_argument(
        "--force-refresh",
        action="store_true",
        help="Refresh all tickers even if they exist",
    )
    arg_parser.add_argument("--stats", action="store_true", help="Show database statistics")
    arg_parser.add_argument("--search", type=str, help="Search for tickers by text query")
    cli_args = arg_parser.parse_args()

    # Assemble the runtime config from project defaults and CLI flags.
    from tradingagents.default_config import DEFAULT_CONFIG

    db = TickerSemanticDB(
        {
            "project_dir": DEFAULT_CONFIG["project_dir"],
            "use_openai_embeddings": not cli_args.use_local,
        }
    )

    # Dispatch: stats and search are read-only; the default action builds.
    if cli_args.stats:
        logger.info("📊 Database Statistics:")
        logger.info(json.dumps(db.get_stats(), indent=2))
    elif cli_args.search:
        logger.info(f"🔍 Searching for: {cli_args.search}")
        hits = db.search_by_text(cli_args.search, top_k=10)
        logger.info("Top matches:")
        for rank, hit in enumerate(hits, 1):
            logger.info(f"{rank}. {hit['symbol']} - {hit['metadata']['name']}")
            logger.debug(f" Sector: {hit['metadata']['sector']}")
            logger.debug(f" Similarity: {hit['similarity_score']:.3f}")
            logger.debug(f" Description: {hit['description'][:150]}...")
    else:
        db.build_database(
            ticker_file=cli_args.ticker_file,
            max_tickers=cli_args.max_tickers,
            skip_existing=not cli_args.force_refresh,
        )


if __name__ == "__main__":
    main()