457 lines
17 KiB
Python
457 lines
17 KiB
Python
"""Analysis execution and state management."""
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
import threading
|
|
import traceback
|
|
import uuid
|
|
from concurrent.futures import Future, ThreadPoolExecutor
|
|
from datetime import datetime
|
|
from typing import Any, Callable, Dict, List, Optional
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
from api.database import Analysis, AnalysisLog, AnalysisReport, SessionLocal
|
|
from tradingagents.default_config import DEFAULT_CONFIG
|
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
|
|
|
# Module-level logger; AnalysisExecutor uses it for analysis lifecycle and error messages.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class AnalysisExecutor:
    """Manages analysis execution with a thread pool and state tracking.

    Analyses are submitted to a ``ThreadPoolExecutor``. While one runs, its
    status, log entries and report sections are written through short-lived
    ``SessionLocal`` sessions, and every status change is pushed to the
    callbacks registered for that analysis id.

    ``self._lock`` is a plain (non-reentrant) ``threading.Lock`` guarding
    ``active_analyses`` and ``status_callbacks``. No method may call another
    lock-taking method while holding it.
    """

    # Report keys looked up on every streamed chunk, one per analyst agent.
    _ANALYST_REPORT_TYPES = (
        "market_report",
        "sentiment_report",
        "news_report",
        "fundamentals_report",
    )

    # Human-readable agent names keyed by the report each agent produces.
    _AGENT_NAMES = {
        "market_report": "Market Analyst",
        "sentiment_report": "Social Analyst",
        "news_report": "News Analyst",
        "fundamentals_report": "Fundamentals Analyst",
    }

    def __init__(self, max_workers: Optional[int] = None):
        """
        Initialize the executor.

        Args:
            max_workers: Maximum concurrent analyses. When omitted, the value
                is read from the ``MAX_CONCURRENT_ANALYSES`` environment
                variable, falling back to 4.

        Raises:
            ValueError: If ``MAX_CONCURRENT_ANALYSES`` is set to a
                non-integer value.
        """
        if max_workers is None:
            max_workers = int(os.getenv("MAX_CONCURRENT_ANALYSES", "4"))

        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.active_analyses: Dict[str, Future] = {}
        self.status_callbacks: Dict[str, List[Callable]] = {}
        self._lock = threading.Lock()

    def register_status_callback(self, analysis_id: str, callback: Callable):
        """Register a callback that receives each status update dict."""
        with self._lock:
            self.status_callbacks.setdefault(analysis_id, []).append(callback)

    def unregister_status_callbacks(self, analysis_id: str):
        """Remove all callbacks for an analysis (no-op if none are registered)."""
        with self._lock:
            self.status_callbacks.pop(analysis_id, None)

    def _notify_callbacks(self, analysis_id: str, status_data: Dict[str, Any]):
        """Invoke every callback registered for ``analysis_id``.

        A failing callback is logged and does not stop the remaining ones.
        """
        # Copy under the lock so callbacks run without holding it and a
        # concurrent register/unregister cannot affect this iteration.
        with self._lock:
            callbacks = list(self.status_callbacks.get(analysis_id, ()))

        for callback in callbacks:
            try:
                callback(status_data)
            except Exception:
                logger.exception(
                    "Error in status callback for analysis %s", analysis_id
                )

    def start_analysis(
        self,
        analysis_id: str,
        ticker: str,
        analysis_date: str,
        selected_analysts: List[str],
        config: Dict[str, Any],
    ) -> str:
        """
        Start a new analysis in the background.

        Args:
            analysis_id: Unique analysis ID
            ticker: Ticker symbol
            analysis_date: Analysis date
            selected_analysts: List of analyst types
            config: Trading agents configuration

        Returns:
            analysis_id
        """
        future = self.executor.submit(
            self._run_analysis,
            analysis_id,
            ticker,
            analysis_date,
            selected_analysts,
            config,
        )

        with self._lock:
            self.active_analyses[analysis_id] = future

        # Registered after the bookkeeping above: if the future is already
        # done, the callback fires immediately and still finds the entry.
        future.add_done_callback(lambda _future: self._cleanup_analysis(analysis_id))

        return analysis_id

    def _cleanup_analysis(self, analysis_id: str):
        """Drop the bookkeeping for an analysis once its future is done."""
        with self._lock:
            self.active_analyses.pop(analysis_id, None)
        # Must run outside the lock: unregister_status_callbacks acquires it.
        self.unregister_status_callbacks(analysis_id)

    def cancel_analysis(self, analysis_id: str) -> bool:
        """
        Attempt to cancel an analysis.

        Cancellation relies on ``Future.cancel()``, so an analysis that is
        already executing cannot be cancelled.

        Returns:
            True if cancelled, False if not found, already completed, or
            already running.
        """
        with self._lock:
            future = self.active_analyses.get(analysis_id)

        if future is None or future.done():
            return False

        # Called outside the lock: a successful cancel runs the done-callback
        # (_cleanup_analysis) in this thread, and that takes the lock.
        if not future.cancel():
            return False

        # Record the cancellation in the database.
        db = SessionLocal()
        try:
            analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first()
            if analysis:
                analysis.status = "cancelled"
                analysis.updated_at = datetime.utcnow()
                db.commit()
        finally:
            db.close()
        return True

    def get_status(self, analysis_id: str) -> Optional[Dict[str, Any]]:
        """Return the stored status of an analysis, or None if it does not exist."""
        db = SessionLocal()
        try:
            analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first()
            if not analysis:
                return None

            return {
                "id": analysis.id,
                "status": analysis.status,
                "progress_percentage": analysis.progress_percentage,
                "current_agent": analysis.current_agent,
                "updated_at": analysis.updated_at,
            }
        finally:
            db.close()

    def _update_status(
        self,
        analysis_id: str,
        status: Optional[str] = None,
        progress: Optional[int] = None,
        current_agent: Optional[str] = None,
        error_message: Optional[str] = None,
    ):
        """Update analysis status in the database and notify callbacks.

        Only the arguments that are provided are written. Setting ``status``
        to ``"completed"`` also stamps ``completed_at`` and forces progress
        to 100. Does nothing when the analysis row is missing.
        """
        db = SessionLocal()
        try:
            analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first()
            if not analysis:
                return

            if status:
                analysis.status = status
            if progress is not None:
                analysis.progress_percentage = progress
            if current_agent:
                analysis.current_agent = current_agent
            if error_message:
                analysis.error_message = error_message

            analysis.updated_at = datetime.utcnow()

            if status == "completed":
                analysis.completed_at = datetime.utcnow()
                analysis.progress_percentage = 100

            db.commit()
            db.refresh(analysis)

            # Built from the refreshed row so callbacks see what was persisted.
            status_data = {
                "type": "status_update",
                "analysis_id": analysis.id,
                "status": analysis.status,
                "progress_percentage": analysis.progress_percentage,
                "current_agent": analysis.current_agent,
                "timestamp": analysis.updated_at.isoformat(),
            }
            self._notify_callbacks(analysis_id, status_data)
        finally:
            db.close()

    def _store_log(self, analysis_id: str, log_type: str, content: str):
        """Persist a single log entry for an analysis."""
        db = SessionLocal()
        try:
            db.add(
                AnalysisLog(
                    analysis_id=analysis_id,
                    log_type=log_type,
                    content=content,
                    timestamp=datetime.utcnow(),
                )
            )
            db.commit()
        finally:
            db.close()

    def _store_report(self, analysis_id: str, report_type: str, content: str):
        """Store a report section, overwriting an existing one of the same type."""
        db = SessionLocal()
        try:
            report = (
                db.query(AnalysisReport)
                .filter(
                    AnalysisReport.analysis_id == analysis_id,
                    AnalysisReport.report_type == report_type,
                )
                .first()
            )

            if report:
                # Update existing (created_at is refreshed on overwrite).
                report.content = content
                report.created_at = datetime.utcnow()
            else:
                db.add(
                    AnalysisReport(
                        analysis_id=analysis_id,
                        report_type=report_type,
                        content=content,
                    )
                )

            db.commit()
        finally:
            db.close()

    def _log_message(self, analysis_id: str, message: Any):
        """Log a streamed message's content and any tool calls it carries."""
        if not hasattr(message, "content"):
            return

        self._store_log(
            analysis_id, "Reasoning", self._extract_content(message.content)
        )

        for tool_call in getattr(message, "tool_calls", None) or ():
            # Tool calls arrive either as plain dicts or as objects.
            if isinstance(tool_call, dict):
                tool_name = tool_call["name"]
                tool_args = tool_call["args"]
            else:
                tool_name = tool_call.name
                tool_args = tool_call.args

            args_str = ", ".join(f"{k}={v}" for k, v in (tool_args or {}).items())
            self._store_log(analysis_id, "Tool Call", f"{tool_name}({args_str})")

    def _store_report_once(
        self,
        analysis_id: str,
        report_type: str,
        content: Any,
        stored_reports: Dict[str, Any],
    ) -> bool:
        """Store ``content`` if it changed; return True on its first appearance."""
        first_time = report_type not in stored_reports
        if first_time or stored_reports[report_type] != content:
            self._store_report(analysis_id, report_type, content)
            stored_reports[report_type] = content
        return first_time

    def _run_analysis(
        self,
        analysis_id: str,
        ticker: str,
        analysis_date: str,
        selected_analysts: List[str],
        config: Dict[str, Any],
    ):
        """Execute the analysis (runs in thread pool).

        Streams the trading graph, logging each message and storing report
        sections as they appear. Any exception marks the analysis as
        ``"failed"`` and is recorded in the log table. Memory collections
        are always cleaned up afterwards.
        """
        logger.info(
            "Starting analysis %s for %s on %s", analysis_id, ticker, analysis_date
        )
        graph = None
        try:
            self._update_status(analysis_id, status="running", progress=0)
            logger.info("Analysis %s: Initializing trading graph...", analysis_id)

            # The analysis_id is passed through for memory isolation.
            graph = TradingAgentsGraph(
                selected_analysts=selected_analysts,
                config=config,
                debug=False,
                analysis_id=analysis_id,
            )

            init_agent_state = graph.propagator.create_initial_state(ticker, analysis_date)
            init_agent_state["asset_class"] = config.get("asset_class", "equity")
            args = graph.propagator.get_graph_args()

            total_agents = len(self._get_agent_order(selected_analysts))
            completed_agents = 0

            # NOTE(review): the stream may repeat already-finished reports in
            # later chunks (depends on the graph's stream mode; confirm).
            # Remembering what was stored keeps progress from being counted
            # twice and avoids rewriting identical content.
            stored_reports: Dict[str, Any] = {}
            final_state = None

            for chunk in graph.graph.stream(init_agent_state, **args):
                messages = chunk.get("messages", [])
                if not messages:
                    continue

                self._log_message(analysis_id, messages[-1])

                # Analyst reports: each one counts as a finished agent once.
                for report_type in self._ANALYST_REPORT_TYPES:
                    report = chunk.get(report_type)
                    if not report:
                        continue
                    if self._store_report_once(
                        analysis_id, report_type, report, stored_reports
                    ):
                        completed_agents += 1
                        progress = int((completed_agents / total_agents) * 100)
                        self._update_status(
                            analysis_id,
                            progress=min(progress, 95),
                            current_agent=self._get_agent_name(report_type),
                        )

                # Research manager's verdict on the investment debate.
                debate_state = chunk.get("investment_debate_state")
                judge_decision = (
                    debate_state.get("judge_decision") if debate_state else None
                )
                if judge_decision and self._store_report_once(
                    analysis_id, "investment_plan", judge_decision, stored_reports
                ):
                    completed_agents += 1
                    progress = int((completed_agents / total_agents) * 100)
                    self._update_status(
                        analysis_id,
                        progress=min(progress, 98),
                        current_agent="Research Manager",
                    )

                # Trader plan.
                trader_plan = chunk.get("trader_investment_plan")
                if trader_plan and self._store_report_once(
                    analysis_id, "trader_investment_plan", trader_plan, stored_reports
                ):
                    self._update_status(
                        analysis_id,
                        progress=99,
                        current_agent="Trader",
                    )

                # Only the last processed chunk is needed after the loop.
                final_state = chunk

            if final_state is not None and "final_trade_decision" in final_state:
                final_decision = final_state["final_trade_decision"]
                decision = graph.process_signal(final_decision)
                self._store_report(analysis_id, "final_trade_decision", final_decision)
                self._store_log(analysis_id, "System", f"Final decision: {decision}")

            logger.info("Analysis %s completed successfully", analysis_id)
            self._update_status(analysis_id, status="completed", progress=100)

        except Exception as e:
            # Top-level boundary of the worker thread: record the failure
            # instead of letting it vanish inside the Future.
            error_msg = str(e)
            error_trace = traceback.format_exc()
            logger.error("Analysis %s failed: %s", analysis_id, error_msg)
            logger.error("Traceback:\n%s", error_trace)
            self._update_status(
                analysis_id, status="failed", error_message=error_msg
            )
            self._store_log(
                analysis_id,
                "System",
                f"Error: {error_msg}\n\nTraceback:\n{error_trace}",
            )
        finally:
            # Clean up memory collections to prevent leaks. Best effort: a
            # cleanup failure must not mask the analysis outcome.
            if graph is not None:
                try:
                    graph.cleanup_memories()
                    logger.info(
                        "Analysis %s: Cleaned up memory collections", analysis_id
                    )
                except Exception as cleanup_error:
                    logger.warning(
                        "Analysis %s: Failed to cleanup memories: %s",
                        analysis_id,
                        cleanup_error,
                    )

    def _get_agent_order(self, selected_analysts: List[str]) -> List[str]:
        """Return the selected analysts followed by the fixed downstream agents.

        The input list is not modified.
        """
        return [
            *selected_analysts,
            "bull_researcher",
            "bear_researcher",
            "research_manager",
            "trader",
            "risk",
            "portfolio",
        ]

    def _get_agent_name(self, report_type: str) -> str:
        """Get human-readable agent name from report type ("Unknown" if unmapped)."""
        return self._AGENT_NAMES.get(report_type, "Unknown")

    def _extract_content(self, content: Any) -> str:
        """Extract string content from various message formats.

        Strings are returned unchanged. For lists, ``text`` items contribute
        their text, ``tool_use`` items a ``[Tool: name]`` marker, other dict
        items are skipped and non-dict items are stringified; the parts are
        joined with spaces. Anything else is passed through ``str``.
        """
        if isinstance(content, str):
            return content
        if not isinstance(content, list):
            return str(content)

        text_parts = []
        for item in content:
            if not isinstance(item, dict):
                text_parts.append(str(item))
            elif item.get("type") == "text":
                text_parts.append(item.get("text", ""))
            elif item.get("type") == "tool_use":
                text_parts.append(f"[Tool: {item.get('name', 'unknown')}]")
        return " ".join(text_parts)

    def shutdown(self):
        """Cancel analyses that can still be cancelled, then shut the pool down.

        Blocks until analyses that are already running have finished.
        """
        # Snapshot under the lock, cancel outside it: cancel_analysis (and the
        # cleanup callback a successful cancel triggers) take the same
        # non-reentrant lock, so calling it while holding the lock deadlocks.
        with self._lock:
            pending_ids = list(self.active_analyses)

        for analysis_id in pending_ids:
            self.cancel_analysis(analysis_id)

        self.executor.shutdown(wait=True)
|
|
|
|
|
|
# Global executor instance
|
|
# Process-wide instance: created lazily by get_executor() and reset to None by shutdown_executor().
_executor: Optional[AnalysisExecutor] = None
|
|
|
|
|
|
def get_executor() -> AnalysisExecutor:
    """Return the global executor, creating it on first use."""
    global _executor
    instance = _executor
    if instance is None:
        # First call (or first call after shutdown_executor): build a fresh one.
        instance = AnalysisExecutor()
        _executor = instance
    return instance
|
|
|
|
|
|
def shutdown_executor():
    """Shut down the global executor, if one exists, and forget it."""
    global _executor
    if not _executor:
        # Nothing was ever created (or it was already shut down).
        return
    _executor.shutdown()
    _executor = None
|
|
|