159 lines
4.9 KiB
Python
159 lines
4.9 KiB
Python
import json
from datetime import datetime, timezone

from sqlalchemy.orm import Session

from tradingagents.database.models import (
    DiscoveryArticle,
    DiscoveryRun,
    TrendingStockResult,
)
from tradingagents.database.repositories.base import BaseRepository
|
|
|
|
|
|
class DiscoveryRunRepository(BaseRepository[DiscoveryRun]):
    """Query helpers for ``DiscoveryRun`` rows on top of ``BaseRepository``."""

    def __init__(self, session: Session):
        super().__init__(session, DiscoveryRun)

    def get_latest(self) -> DiscoveryRun | None:
        """Return the run with the newest ``created_at``, or ``None`` if none exist."""
        newest_first = self.session.query(DiscoveryRun).order_by(
            DiscoveryRun.created_at.desc()
        )
        return newest_first.first()

    def get_completed(self, limit: int = 10) -> list[DiscoveryRun]:
        """Return up to ``limit`` runs whose status is ``"completed"``, newest first."""
        query = self.session.query(DiscoveryRun)
        query = query.filter(DiscoveryRun.status == "completed")
        query = query.order_by(DiscoveryRun.created_at.desc())
        return query.limit(limit).all()
|
|
|
|
|
|
class TrendingStockResultRepository(BaseRepository[TrendingStockResult]):
    """Query helpers for ``TrendingStockResult`` rows on top of ``BaseRepository``."""

    def __init__(self, session: Session):
        super().__init__(session, TrendingStockResult)

    def get_by_run(self, run_id: str) -> list[TrendingStockResult]:
        """Return every result attached to run ``run_id``, highest ``trending_score`` first."""
        belongs_to_run = TrendingStockResult.discovery_run_id == run_id
        by_score = TrendingStockResult.trending_score.desc()
        query = self.session.query(TrendingStockResult).filter(belongs_to_run)
        return query.order_by(by_score).all()

    def get_by_ticker(self, ticker: str, limit: int = 10) -> list[TrendingStockResult]:
        """Return up to ``limit`` results whose ``ticker`` equals ``ticker``, newest first."""
        matching = self.session.query(TrendingStockResult).filter(
            TrendingStockResult.ticker == ticker
        )
        newest_first = matching.order_by(TrendingStockResult.created_at.desc())
        return newest_first.limit(limit).all()
|
|
|
|
|
|
class DiscoveryArticleRepository(BaseRepository[DiscoveryArticle]):
    """Query helpers for ``DiscoveryArticle`` rows on top of ``BaseRepository``."""

    def __init__(self, session: Session):
        super().__init__(session, DiscoveryArticle)

    def get_by_run(self, run_id: str) -> list[DiscoveryArticle]:
        """Return all articles attached to run ``run_id``.

        No ORDER BY is applied, so row order is whatever the database returns.
        """
        articles_for_run = self.session.query(DiscoveryArticle).filter(
            DiscoveryArticle.discovery_run_id == run_id
        )
        return articles_for_run.all()
|
|
|
|
|
|
class DiscoveryService:
    """High-level API for recording trending-stock discovery runs.

    Wraps the run, trending-stock and article repositories over one shared
    ``Session``. A run is created in ``"running"`` status, results and
    articles are attached to it by ``run_id``, and it is finished with either
    :meth:`complete_run` or :meth:`fail_run`.

    This class calls ``session.flush()`` itself but never ``commit()``.
    Whether the repositories' ``create`` commits is defined in
    ``BaseRepository`` (not shown here) — NOTE(review): confirm transaction
    ownership with callers.
    """

    def __init__(self, session: Session):
        self.session = session
        self.runs = DiscoveryRunRepository(session)
        self.stocks = TrendingStockResultRepository(session)
        self.articles = DiscoveryArticleRepository(session)

    def create_run(self, lookback_period: str, max_results: int) -> DiscoveryRun:
        """Create and return a new run in ``"running"`` status.

        Args:
            lookback_period: Stored verbatim on the run (not validated here).
            max_results: Stored verbatim on the run (not validated here).
        """
        return self.runs.create(
            {
                "lookback_period": lookback_period,
                "max_results": max_results,
                "status": "running",
            }
        )

    def save_trending_stock(
        self,
        run_id: str,
        ticker: str,
        company_name: str,
        trending_score: float,
        mention_count: int,
        sector: str,
        event_type: str,
        summary: str | None = None,
        source_articles: list[str] | None = None,
    ) -> TrendingStockResult:
        """Persist one trending-stock result for run ``run_id`` and return it.

        Args:
            source_articles: Stored as a JSON-encoded string; ``None`` is
                stored as ``"[]"``.
        """
        return self.stocks.create(
            {
                "discovery_run_id": run_id,
                "ticker": ticker,
                "company_name": company_name,
                "trending_score": trending_score,
                "mention_count": mention_count,
                "sector": sector,
                "event_type": event_type,
                "summary": summary,
                # The column holds text, so the list is serialized to JSON here.
                "source_articles": json.dumps(source_articles or []),
            }
        )

    def save_article(
        self,
        run_id: str,
        title: str,
        source: str,
        url: str | None = None,
        published_at: datetime | None = None,
        content_snippet: str | None = None,
    ) -> DiscoveryArticle:
        """Persist one source article for run ``run_id`` and return it."""
        return self.articles.create(
            {
                "discovery_run_id": run_id,
                "title": title,
                "source": source,
                "url": url,
                "published_at": published_at,
                "content_snippet": content_snippet,
            }
        )

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Same value as the deprecated ``datetime.utcnow()`` (deprecated since
        Python 3.12). The tzinfo is stripped so the stored value stays naive,
        exactly as the previous ``utcnow()`` calls produced.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    def _finish_run(
        self, run_id: str, status: str, extra: dict[str, object]
    ) -> DiscoveryRun | None:
        """Set the final ``status``, ``completed_at`` and ``extra`` columns on a run.

        Returns:
            The updated (flushed) run, or ``None`` if no run with ``run_id``
            exists, in which case nothing is flushed.
        """
        run = self.runs.get(run_id)
        if run is None:
            return None
        run.status = status
        run.completed_at = self._utcnow()
        for column, value in extra.items():
            setattr(run, column, value)
        # Flush so the update is sent to the DB within the caller's transaction.
        self.session.flush()
        return run

    def complete_run(self, run_id: str, stocks_found: int) -> DiscoveryRun | None:
        """Mark run ``run_id`` as ``"completed"`` and record ``stocks_found``.

        Returns:
            The updated run, or ``None`` if no run with that id exists.
        """
        return self._finish_run(run_id, "completed", {"stocks_found": stocks_found})

    def fail_run(self, run_id: str, error: str) -> DiscoveryRun | None:
        """Mark run ``run_id`` as ``"failed"`` and store ``error`` as its message.

        Returns:
            The updated run, or ``None`` if no run with that id exists.
        """
        return self._finish_run(run_id, "failed", {"error_message": error})

    def get_latest_run(self) -> DiscoveryRun | None:
        """Return the most recently created run, or ``None`` if there are none."""
        return self.runs.get_latest()

    def get_trending_by_run(self, run_id: str) -> list[TrendingStockResult]:
        """Return run ``run_id``'s results ordered by trending score, highest first."""
        return self.stocks.get_by_run(run_id)

    def get_stock_history(
        self, ticker: str, limit: int = 10
    ) -> list[TrendingStockResult]:
        """Return up to ``limit`` past results for ``ticker``, newest first."""
        return self.stocks.get_by_ticker(ticker, limit)
|