115 lines
3.4 KiB
Python
115 lines
3.4 KiB
Python
from datetime import datetime
|
|
|
|
from sqlalchemy import and_
|
|
from sqlalchemy.orm import Session
|
|
|
|
from tradingagents.database.models.analysis import (
|
|
AnalysisSession,
|
|
AnalystReport,
|
|
InvestmentDebate,
|
|
RiskDebate,
|
|
)
|
|
from tradingagents.database.repositories.base import BaseRepository
|
|
|
|
|
|
class AnalysisSessionRepository(BaseRepository[AnalysisSession]):
|
|
def __init__(self, session: Session):
|
|
super().__init__(session, AnalysisSession)
|
|
|
|
def get_by_ticker_and_date(
|
|
self, ticker: str, trade_date: str
|
|
) -> AnalysisSession | None:
|
|
return (
|
|
self.session.query(AnalysisSession)
|
|
.filter(
|
|
and_(
|
|
AnalysisSession.ticker == ticker,
|
|
AnalysisSession.trade_date == trade_date,
|
|
)
|
|
)
|
|
.first()
|
|
)
|
|
|
|
def get_latest_by_ticker(self, ticker: str) -> AnalysisSession | None:
|
|
return (
|
|
self.session.query(AnalysisSession)
|
|
.filter(AnalysisSession.ticker == ticker)
|
|
.order_by(AnalysisSession.created_at.desc())
|
|
.first()
|
|
)
|
|
|
|
def get_completed_sessions(self, limit: int = 100) -> list[AnalysisSession]:
|
|
return (
|
|
self.session.query(AnalysisSession)
|
|
.filter(AnalysisSession.status == "completed")
|
|
.order_by(AnalysisSession.completed_at.desc())
|
|
.limit(limit)
|
|
.all()
|
|
)
|
|
|
|
def mark_completed(self, session_id: str) -> AnalysisSession | None:
|
|
obj = self.get(session_id)
|
|
if obj:
|
|
obj.status = "completed"
|
|
obj.completed_at = datetime.utcnow()
|
|
self.session.flush()
|
|
return obj
|
|
|
|
def mark_failed(self, session_id: str) -> AnalysisSession | None:
|
|
obj = self.get(session_id)
|
|
if obj:
|
|
obj.status = "failed"
|
|
obj.completed_at = datetime.utcnow()
|
|
self.session.flush()
|
|
return obj
|
|
|
|
|
|
class AnalystReportRepository(BaseRepository[AnalystReport]):
|
|
def __init__(self, session: Session):
|
|
super().__init__(session, AnalystReport)
|
|
|
|
def get_by_session_and_type(
|
|
self, session_id: str, analyst_type: str
|
|
) -> AnalystReport | None:
|
|
return (
|
|
self.session.query(AnalystReport)
|
|
.filter(
|
|
and_(
|
|
AnalystReport.session_id == session_id,
|
|
AnalystReport.analyst_type == analyst_type,
|
|
)
|
|
)
|
|
.first()
|
|
)
|
|
|
|
def get_all_by_session(self, session_id: str) -> list[AnalystReport]:
|
|
return (
|
|
self.session.query(AnalystReport)
|
|
.filter(AnalystReport.session_id == session_id)
|
|
.all()
|
|
)
|
|
|
|
|
|
class InvestmentDebateRepository(BaseRepository[InvestmentDebate]):
|
|
def __init__(self, session: Session):
|
|
super().__init__(session, InvestmentDebate)
|
|
|
|
def get_by_session(self, session_id: str) -> InvestmentDebate | None:
|
|
return (
|
|
self.session.query(InvestmentDebate)
|
|
.filter(InvestmentDebate.session_id == session_id)
|
|
.first()
|
|
)
|
|
|
|
|
|
class RiskDebateRepository(BaseRepository[RiskDebate]):
|
|
def __init__(self, session: Session):
|
|
super().__init__(session, RiskDebate)
|
|
|
|
def get_by_session(self, session_id: str) -> RiskDebate | None:
|
|
return (
|
|
self.session.query(RiskDebate)
|
|
.filter(RiskDebate.session_id == session_id)
|
|
.first()
|
|
)
|