# TradingAgents/tradingagents/database/services/discovery.py
#
# Source-viewer metadata: 159 lines, 4.9 KiB, Python

import json
from datetime import datetime, timezone

from sqlalchemy.orm import Session

from tradingagents.database.models import (
    DiscoveryArticle,
    DiscoveryRun,
    TrendingStockResult,
)
from tradingagents.database.repositories.base import BaseRepository


class DiscoveryRunRepository(BaseRepository[DiscoveryRun]):
    """Query helpers for ``DiscoveryRun`` rows on a shared SQLAlchemy session."""

    def __init__(self, session: Session):
        super().__init__(session, DiscoveryRun)

    def get_latest(self) -> DiscoveryRun | None:
        """Return the run with the newest ``created_at`` (any status), or ``None``."""
        newest_first = DiscoveryRun.created_at.desc()
        query = self.session.query(DiscoveryRun).order_by(newest_first)
        return query.first()

    def get_completed(self, limit: int = 10) -> list[DiscoveryRun]:
        """Return at most ``limit`` runs whose status is ``"completed"``, newest first."""
        query = self.session.query(DiscoveryRun)
        query = query.filter(DiscoveryRun.status == "completed")
        query = query.order_by(DiscoveryRun.created_at.desc())
        return query.limit(limit).all()


class TrendingStockResultRepository(BaseRepository[TrendingStockResult]):
    """Query helpers for ``TrendingStockResult`` rows on a shared SQLAlchemy session."""

    def __init__(self, session: Session):
        super().__init__(session, TrendingStockResult)

    def get_by_run(self, run_id: str) -> list[TrendingStockResult]:
        """Return every result of discovery run ``run_id``, highest ``trending_score`` first."""
        in_run = TrendingStockResult.discovery_run_id == run_id
        by_score = TrendingStockResult.trending_score.desc()
        query = self.session.query(TrendingStockResult).filter(in_run)
        return query.order_by(by_score).all()

    def get_by_ticker(self, ticker: str, limit: int = 10) -> list[TrendingStockResult]:
        """Return up to ``limit`` results for ``ticker`` (exact match), newest first."""
        same_ticker = TrendingStockResult.ticker == ticker
        newest_first = TrendingStockResult.created_at.desc()
        query = self.session.query(TrendingStockResult).filter(same_ticker)
        return query.order_by(newest_first).limit(limit).all()


class DiscoveryArticleRepository(BaseRepository[DiscoveryArticle]):
    """Query helpers for ``DiscoveryArticle`` rows on a shared SQLAlchemy session."""

    def __init__(self, session: Session):
        super().__init__(session, DiscoveryArticle)

    def get_by_run(self, run_id: str) -> list[DiscoveryArticle]:
        """Return every article attached to discovery run ``run_id``.

        No ORDER BY is applied, so row order is whatever the database returns.
        """
        belongs_to_run = DiscoveryArticle.discovery_run_id == run_id
        return self.session.query(DiscoveryArticle).filter(belongs_to_run).all()


class DiscoveryService:
    """Facade over the discovery repositories for recording and reading discovery runs.

    The three repositories share the ``Session`` passed to the constructor.
    This class calls ``Session.flush()`` when finishing a run but never commits;
    whether ``BaseRepository.create`` flushes or commits is not visible here.
    """

    def __init__(self, session: Session):
        self.session = session
        self.runs = DiscoveryRunRepository(session)
        self.stocks = TrendingStockResultRepository(session)
        self.articles = DiscoveryArticleRepository(session)

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Produces the same naive-UTC value as the deprecated ``datetime.utcnow()``
        (deprecated since Python 3.12), so the stored ``completed_at`` values
        are unchanged.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    def create_run(self, lookback_period: str, max_results: int) -> DiscoveryRun:
        """Create a new run in status ``"running"`` and return it.

        Args:
            lookback_period: Stored verbatim on the run.
            max_results: Stored verbatim on the run.
        """
        return self.runs.create(
            {
                "lookback_period": lookback_period,
                "max_results": max_results,
                "status": "running",
            }
        )

    def save_trending_stock(
        self,
        run_id: str,
        ticker: str,
        company_name: str,
        trending_score: float,
        mention_count: int,
        sector: str,
        event_type: str,
        summary: str | None = None,
        source_articles: list[str] | None = None,
    ) -> TrendingStockResult:
        """Create a trending-stock result linked to discovery run ``run_id``.

        ``source_articles`` is stored as a JSON-encoded list string; ``None``
        (or an empty list) is stored as ``"[]"``.
        """
        return self.stocks.create(
            {
                "discovery_run_id": run_id,
                "ticker": ticker,
                "company_name": company_name,
                "trending_score": trending_score,
                "mention_count": mention_count,
                "sector": sector,
                "event_type": event_type,
                "summary": summary,
                "source_articles": json.dumps(source_articles or []),
            }
        )

    def save_article(
        self,
        run_id: str,
        title: str,
        source: str,
        url: str | None = None,
        published_at: datetime | None = None,
        content_snippet: str | None = None,
    ) -> DiscoveryArticle:
        """Create an article record linked to discovery run ``run_id`` and return it."""
        return self.articles.create(
            {
                "discovery_run_id": run_id,
                "title": title,
                "source": source,
                "url": url,
                "published_at": published_at,
                "content_snippet": content_snippet,
            }
        )

    def _finish_run(
        self, run_id: str, status: str, **fields: object
    ) -> DiscoveryRun | None:
        """Stamp a run as finished with ``status``, set extra ``fields``, and flush.

        Returns the run, or ``None`` (writing nothing) when ``self.runs.get``
        finds no run for ``run_id``.
        """
        run = self.runs.get(run_id)
        if run is None:
            return None
        run.status = status
        run.completed_at = self._utcnow()
        for name, value in fields.items():
            setattr(run, name, value)
        # Push the UPDATE to the database inside the current transaction;
        # committing is left to whoever owns the session.
        self.session.flush()
        return run

    def complete_run(self, run_id: str, stocks_found: int) -> DiscoveryRun | None:
        """Mark run ``run_id`` as ``"completed"`` and record ``stocks_found``.

        Returns the updated run, or ``None`` if the run does not exist.
        """
        return self._finish_run(run_id, "completed", stocks_found=stocks_found)

    def fail_run(self, run_id: str, error: str) -> DiscoveryRun | None:
        """Mark run ``run_id`` as ``"failed"`` and store ``error`` as its error message.

        Returns the updated run, or ``None`` if the run does not exist.
        """
        return self._finish_run(run_id, "failed", error_message=error)

    def get_latest_run(self) -> DiscoveryRun | None:
        """Return the most recently created run of any status, or ``None``."""
        return self.runs.get_latest()

    def get_trending_by_run(self, run_id: str) -> list[TrendingStockResult]:
        """Return the results of run ``run_id``, highest trending score first."""
        return self.stocks.get_by_run(run_id)

    def get_stock_history(
        self, ticker: str, limit: int = 10
    ) -> list[TrendingStockResult]:
        """Return up to ``limit`` past results for ``ticker``, newest first."""
        return self.stocks.get_by_ticker(ticker, limit)