TradingAgents/app/core/services/trading_analysis.py

129 lines
5.0 KiB
Python

import asyncio
import datetime
import json
from typing import Dict, List, Optional
from sqlmodel import Session, select
from app.domain.models import User, AnalysisSession, AnalysisStatus
from app.core.schemas.analysis import AnalysisSessionCreate
from app.core.config import settings
from cli.models import AnalystType
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
from app.api.deps import get_db
from app.core.websocket_manager import WebSocketManager
class TradingAnalysisService:
    """Per-user service for creating, listing and running analysis sessions.

    Every query filters on ``self.user``, so a service instance only ever
    reads, creates or runs sessions owned by that user.
    """

    def __init__(self, user: User, db: Session):
        """Bind the service to a user and a database session.

        Args:
            user: Owner of every session this service reads or writes.
            db: SQLModel session used for all persistence.
        """
        self.user = user
        self.db = db
        # NOTE(review): a fresh WebSocketManager is created per service instance.
        # Confirm it shares connection state across instances (e.g. a singleton);
        # otherwise send_to_user may have no open connection to deliver to.
        self.websocket_manager = WebSocketManager()

    @staticmethod
    def _build_graph_config(session: AnalysisSession) -> Dict:
        """Return a copy of DEFAULT_CONFIG overlaid with this session's LLM settings."""
        config = DEFAULT_CONFIG.copy()
        config.update({
            'openai_api_key': settings.OPENAI_API_KEY,
            'llm_provider': session.llm_provider,
            'backend_url': session.backend_url,
            'shallow_thinking_model': session.shallow_thinker,
            'deep_thinking_model': session.deep_thinker,
            # NOTE(review): upstream TradingAgents' DEFAULT_CONFIG names the model
            # keys 'quick_think_llm' / 'deep_think_llm'; set them too so the
            # user's model choices are not silently ignored. The legacy keys
            # above are kept in case other code reads them — confirm against
            # the installed tradingagents version.
            'quick_think_llm': session.shallow_thinker,
            'deep_think_llm': session.deep_thinker,
        })
        return config

    @staticmethod
    def _run_trading_graph(config: Dict, analysts, ticker: str, trade_date: str):
        """Build a TradingAgentsGraph and run ``propagate`` (blocking; call off the event loop).

        Returns whatever ``TradingAgentsGraph.propagate`` returns; the caller
        unpacks it as ``(final_state, result)``.
        """
        trading_graph = TradingAgentsGraph(
            config=config,
            selected_analysts=analysts,
        )
        return trading_graph.propagate(ticker, trade_date)

    async def run_analysis(self, session_id: int):
        """Run the analysis for one of this user's sessions.

        Marks the session RUNNING, runs the trading graph in a worker thread,
        then persists COMPLETED with the JSON-serialized final state, or FAILED
        with the error message. Each transition is pushed to the user through
        the websocket manager.

        Args:
            session_id: Id of the AnalysisSession to run. If it does not exist
                or belongs to another user, this returns without doing anything.
        """
        session = self.get_session(session_id=session_id)
        if not session:
            return
        try:
            session.status = AnalysisStatus.RUNNING
            session.started_at = datetime.datetime.utcnow()
            self.db.add(session)
            self.db.commit()
            self.db.refresh(session)
            await self.websocket_manager.send_to_user(
                self.user.id,
                {
                    'type': 'analysis_started',
                    'session_id': session.id,
                    'message': '분석을 시작합니다...'
                }
            )

            config = self._build_graph_config(session)
            # Snapshot plain values here so the worker thread never touches the
            # ORM object or the (non-thread-safe) DB session. Graph construction
            # also happens in the worker so it cannot block the event loop.
            final_state, result = await asyncio.to_thread(
                self._run_trading_graph,
                config,
                session.analysts_selected,
                session.ticker,
                session.analysis_date.strftime('%Y-%m-%d'),
            )

            session.status = AnalysisStatus.COMPLETED
            session.completed_at = datetime.datetime.utcnow()
            # Store the full graph state as JSON. default=str is a deliberate
            # fallback: the state may hold objects json cannot encode natively
            # (TODO confirm which), and a serialization error here would
            # otherwise flag a finished analysis as FAILED.
            session.final_report = json.dumps(final_state, default=str)
            self.db.add(session)
            self.db.commit()
            await self.websocket_manager.send_to_user(
                self.user.id,
                {
                    'type': 'analysis_completed',
                    'session_id': session.id,
                    'message': '분석이 완료되었습니다.',
                    'result': result
                }
            )
        except Exception as e:
            # If the failure came from a flush/commit, the DB session must be
            # rolled back before it can be used again. Rolling back also drops
            # half-applied in-memory changes (e.g. COMPLETED fields) so they
            # are not committed together with FAILED.
            self.db.rollback()
            session.status = AnalysisStatus.FAILED
            session.error_message = str(e)
            self.db.add(session)
            self.db.commit()
            await self.websocket_manager.send_to_user(
                self.user.id,
                {
                    'type': 'analysis_failed',
                    'session_id': session.id,
                    'message': f'분석 중 오류가 발생했습니다: {str(e)}'
                }
            )

    def create_session(self, *, analysis_in: AnalysisSessionCreate) -> AnalysisSession:
        """Persist a new AnalysisSession for this user dated today and return it (refreshed)."""
        session = AnalysisSession(
            **analysis_in.dict(),
            user_id=self.user.id,
            analysis_date=datetime.date.today()
        )
        self.db.add(session)
        self.db.commit()
        self.db.refresh(session)
        return session

    def get_session(self, *, session_id: int) -> Optional[AnalysisSession]:
        """Return this user's session with ``session_id``, or None if absent or not owned."""
        statement = select(AnalysisSession).where(
            AnalysisSession.id == session_id,
            AnalysisSession.user_id == self.user.id,
        )
        return self.db.exec(statement).first()

    def get_user_sessions(self, *, skip: int = 0, limit: int = 100) -> List[AnalysisSession]:
        """Return up to ``limit`` of this user's sessions, newest first, skipping ``skip``."""
        statement = (
            select(AnalysisSession)
            .where(AnalysisSession.user_id == self.user.id)
            .order_by(AnalysisSession.created_at.desc())
            .offset(skip)
            .limit(limit)
        )
        return self.db.exec(statement).all()