# File-viewer metadata: Python source, 129 lines, 5.0 KiB.
import asyncio
import datetime
import json
import logging
from typing import Dict, List, Optional

from cli.models import AnalystType
from sqlmodel import Session, select
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.graph.trading_graph import TradingAgentsGraph

from app.api.deps import get_db
from app.core.config import settings
from app.core.schemas.analysis import AnalysisSessionCreate
from app.core.websocket_manager import WebSocketManager
from app.domain.models import User, AnalysisSession, AnalysisStatus


class TradingAnalysisService:
    """Create, look up and run trading analysis sessions for one user.

    Every query and mutation is scoped to ``self.user``; lifecycle
    notifications are pushed to that user through ``self.websocket_manager``.
    """

    def __init__(self, user: User, db: Session):
        """Bind the service to ``user`` and to the database session ``db``."""
        self.user = user
        self.db = db
        # NOTE(review): a new WebSocketManager is built for every service
        # instance. If the manager keeps its connection registry on the
        # instance rather than in shared state, messages sent from here
        # cannot reach any client -- confirm against WebSocketManager.
        self.websocket_manager = WebSocketManager()

    async def run_analysis(self, session_id: int):
        """Run the analysis for ``session_id`` and record its outcome.

        The session is marked RUNNING, ``TradingAgentsGraph.propagate`` is
        executed in a worker thread, and on success the final state is stored
        as JSON and the session is marked COMPLETED.  Any exception raised
        along the way is logged and the session is marked FAILED with the
        error text.  Each transition is announced to the user over the
        websocket.

        No per-step progress events are emitted: ``propagate`` is invoked
        without any progress hook.

        Args:
            session_id: Primary key of the analysis session to run.

        Returns:
            None.  Nothing happens when no session with that id belongs to
            the current user.
        """
        session = self.get_session(session_id=session_id)
        if not session:
            return

        # Captured up front so the failure path can report the id without
        # reloading an ORM object that a rollback has expired.
        analysis_id = session.id

        try:
            session.status = AnalysisStatus.RUNNING
            # NOTE(review): utcnow() yields a naive datetime and is deprecated
            # since Python 3.12; kept as-is because the column's timezone
            # handling is not visible here.
            session.started_at = datetime.datetime.utcnow()
            self.db.add(session)
            self.db.commit()
            self.db.refresh(session)

            await self.websocket_manager.send_to_user(
                self.user.id,
                {
                    'type': 'analysis_started',
                    'session_id': session.id,
                    'message': '분석을 시작합니다...',
                },
            )

            # Per-session overrides layered on top of the library defaults.
            config = DEFAULT_CONFIG.copy()
            config.update({
                'openai_api_key': settings.OPENAI_API_KEY,
                'llm_provider': session.llm_provider,
                'backend_url': session.backend_url,
                'shallow_thinking_model': session.shallow_thinker,
                'deep_thinking_model': session.deep_thinker,
            })

            trading_graph = TradingAgentsGraph(
                config=config,
                selected_analysts=session.analysts_selected,
            )

            # propagate() is a synchronous call; running it in a worker
            # thread keeps the event loop free while the analysis runs.
            final_state, result = await asyncio.to_thread(
                trading_graph.propagate,
                session.ticker,
                session.analysis_date.strftime('%Y-%m-%d'),
            )

            session.status = AnalysisStatus.COMPLETED
            session.completed_at = datetime.datetime.utcnow()
            # Store the full state as JSON.  default=str keeps a finished
            # analysis from being flagged FAILED merely because the state
            # holds a value json cannot encode natively.
            # NOTE(review): presumably the graph state carries such objects
            # (e.g. message instances) -- confirm and replace with an
            # explicit serializer if a structured report is required.
            session.final_report = json.dumps(final_state, default=str)
            self.db.add(session)
            self.db.commit()

            await self.websocket_manager.send_to_user(
                self.user.id,
                {
                    'type': 'analysis_completed',
                    'session_id': session.id,
                    'message': '분석이 완료되었습니다.',
                    'result': result,
                },
            )

        except Exception as exc:
            error_text = str(exc)
            # Keep the traceback; only the message text is persisted below.
            logging.getLogger(__name__).exception(
                "Analysis session %s failed", analysis_id
            )

            # If the error came from a flush/commit the session refuses any
            # further work until rolled back; rolling back also discards a
            # half-applied COMPLETED state before FAILED is recorded.
            self.db.rollback()

            session.status = AnalysisStatus.FAILED
            session.error_message = error_text
            self.db.add(session)
            self.db.commit()

            await self.websocket_manager.send_to_user(
                self.user.id,
                {
                    'type': 'analysis_failed',
                    'session_id': analysis_id,
                    'message': f'분석 중 오류가 발생했습니다: {error_text}',
                },
            )

    def create_session(self, *, analysis_in: AnalysisSessionCreate) -> AnalysisSession:
        """Persist and return a new analysis session owned by the current user.

        Args:
            analysis_in: Creation payload; its fields are copied onto the new
                row.  ``user_id`` is set to the current user's id and
                ``analysis_date`` to today's local date.

        Returns:
            The stored session, refreshed from the database.
        """
        session = AnalysisSession(
            **analysis_in.dict(),
            user_id=self.user.id,
            analysis_date=datetime.date.today(),
        )
        self.db.add(session)
        self.db.commit()
        self.db.refresh(session)
        return session

    def get_session(self, *, session_id: int) -> Optional[AnalysisSession]:
        """Return the current user's session with ``session_id``, or ``None``.

        Sessions owned by other users are never returned.
        """
        statement = select(AnalysisSession).where(
            AnalysisSession.id == session_id,
            AnalysisSession.user_id == self.user.id,
        )
        return self.db.exec(statement).first()

    def get_user_sessions(self, *, skip: int = 0, limit: int = 100) -> List[AnalysisSession]:
        """Return one page of the current user's sessions, newest first.

        Args:
            skip: Number of rows to skip (applied as the query offset).
            limit: Maximum number of rows to return.
        """
        statement = (
            select(AnalysisSession)
            .where(AnalysisSession.user_id == self.user.id)
            .order_by(AnalysisSession.created_at.desc())
            .offset(skip)
            .limit(limit)
        )
        return self.db.exec(statement).all()