# TradingAgents/tradingagents/database/repositories/analysis.py
#
# 115 lines
# 3.4 KiB
# Python

from datetime import datetime
from sqlalchemy import and_
from sqlalchemy.orm import Session
from tradingagents.database.models.analysis import (
AnalysisSession,
AnalystReport,
InvestmentDebate,
RiskDebate,
)
from tradingagents.database.repositories.base import BaseRepository
class AnalysisSessionRepository(BaseRepository[AnalysisSession]):
    """Query and lifecycle helpers for ``AnalysisSession`` rows."""

    def __init__(self, session: Session):
        """Bind the repository to *session* for the ``AnalysisSession`` model."""
        super().__init__(session, AnalysisSession)

    def get_by_ticker_and_date(
        self, ticker: str, trade_date: str
    ) -> AnalysisSession | None:
        """Return a session matching both *ticker* and *trade_date*, or ``None``.

        No ordering is applied, so when several rows match, the database
        decides which one is returned.
        """
        return (
            self.session.query(AnalysisSession)
            .filter(
                and_(
                    AnalysisSession.ticker == ticker,
                    AnalysisSession.trade_date == trade_date,
                )
            )
            .first()
        )

    def get_latest_by_ticker(self, ticker: str) -> AnalysisSession | None:
        """Return the session for *ticker* with the newest ``created_at``.

        Returns ``None`` when no session exists for the ticker.
        """
        return (
            self.session.query(AnalysisSession)
            .filter(AnalysisSession.ticker == ticker)
            .order_by(AnalysisSession.created_at.desc())
            .first()
        )

    def get_completed_sessions(self, limit: int = 100) -> list[AnalysisSession]:
        """Return up to *limit* sessions whose status is ``"completed"``.

        Results are ordered by ``completed_at``, newest first.
        """
        return (
            self.session.query(AnalysisSession)
            .filter(AnalysisSession.status == "completed")
            .order_by(AnalysisSession.completed_at.desc())
            .limit(limit)
            .all()
        )

    def mark_completed(self, session_id: str) -> AnalysisSession | None:
        """Set the session's status to ``"completed"`` and stamp ``completed_at``.

        Returns the updated object, or ``None`` when *session_id* is not found.
        """
        return self._set_terminal_status(session_id, "completed")

    def mark_failed(self, session_id: str) -> AnalysisSession | None:
        """Set the session's status to ``"failed"`` and stamp ``completed_at``.

        Returns the updated object, or ``None`` when *session_id* is not found.
        """
        return self._set_terminal_status(session_id, "failed")

    def _set_terminal_status(
        self, session_id: str, status: str
    ) -> AnalysisSession | None:
        """Shared implementation of ``mark_completed`` / ``mark_failed``.

        Looks the row up through the inherited ``get``, writes *status* and
        the current UTC time, then flushes. No commit is issued here.
        """
        # Local import keeps this block self-contained; the module header
        # only imports ``datetime`` itself.
        from datetime import timezone

        obj = self.get(session_id)
        if obj is not None:
            obj.status = status
            # ``datetime.utcnow()`` is deprecated since Python 3.12. Dropping
            # tzinfo keeps the stored value a naive UTC timestamp, exactly as
            # before.
            obj.completed_at = datetime.now(timezone.utc).replace(tzinfo=None)
            self.session.flush()
        return obj
class AnalystReportRepository(BaseRepository[AnalystReport]):
    """Lookups for ``AnalystReport`` rows by owning session."""

    def __init__(self, session: Session):
        """Bind the repository to *session* for the ``AnalystReport`` model."""
        super().__init__(session, AnalystReport)

    def get_by_session_and_type(
        self, session_id: str, analyst_type: str
    ) -> AnalystReport | None:
        """Return a report matching *session_id* and *analyst_type*, or ``None``."""
        reports = self.session.query(AnalystReport)
        criteria = and_(
            AnalystReport.session_id == session_id,
            AnalystReport.analyst_type == analyst_type,
        )
        return reports.filter(criteria).first()

    def get_all_by_session(self, session_id: str) -> list[AnalystReport]:
        """Return every report whose ``session_id`` equals *session_id*."""
        reports = self.session.query(AnalystReport)
        belongs_to_session = AnalystReport.session_id == session_id
        return reports.filter(belongs_to_session).all()
class InvestmentDebateRepository(BaseRepository[InvestmentDebate]):
    """Lookups for ``InvestmentDebate`` rows by owning session."""

    def __init__(self, session: Session):
        """Bind the repository to *session* for the ``InvestmentDebate`` model."""
        super().__init__(session, InvestmentDebate)

    def get_by_session(self, session_id: str) -> InvestmentDebate | None:
        """Return the first debate whose ``session_id`` matches, or ``None``."""
        debates = self.session.query(InvestmentDebate)
        belongs_to_session = InvestmentDebate.session_id == session_id
        return debates.filter(belongs_to_session).first()
class RiskDebateRepository(BaseRepository[RiskDebate]):
    """Lookups for ``RiskDebate`` rows by owning session."""

    def __init__(self, session: Session):
        """Bind the repository to *session* for the ``RiskDebate`` model."""
        super().__init__(session, RiskDebate)

    def get_by_session(self, session_id: str) -> RiskDebate | None:
        """Return the first debate whose ``session_id`` matches, or ``None``."""
        debates = self.session.query(RiskDebate)
        belongs_to_session = RiskDebate.session_id == session_id
        return debates.filter(belongs_to_session).first()