diff --git a/tradingagents/database/__init__.py b/tradingagents/database/__init__.py index 70ae3227..566427af 100644 --- a/tradingagents/database/__init__.py +++ b/tradingagents/database/__init__.py @@ -1,5 +1,6 @@ from .base import Base from .engine import get_db_session, get_engine, init_database, reset_engine +from .services import AnalysisService, DiscoveryService, TradingService __all__ = [ "Base", @@ -7,4 +8,7 @@ __all__ = [ "get_engine", "init_database", "reset_engine", + "AnalysisService", + "DiscoveryService", + "TradingService", ] diff --git a/tradingagents/database/services/__init__.py b/tradingagents/database/services/__init__.py new file mode 100644 index 00000000..1571974c --- /dev/null +++ b/tradingagents/database/services/__init__.py @@ -0,0 +1,9 @@ +from .analysis import AnalysisService +from .discovery import DiscoveryService +from .trading import TradingService + +__all__ = [ + "AnalysisService", + "DiscoveryService", + "TradingService", +] diff --git a/tradingagents/database/services/analysis.py b/tradingagents/database/services/analysis.py new file mode 100644 index 00000000..6ef0f0a3 --- /dev/null +++ b/tradingagents/database/services/analysis.py @@ -0,0 +1,116 @@ +import json +from datetime import datetime +from typing import Any + +from sqlalchemy.orm import Session + +from tradingagents.database.models import ( + AnalysisSession, + AnalystReport, + InvestmentDebate, + RiskDebate, +) +from tradingagents.database.repositories import ( + AnalysisSessionRepository, + AnalystReportRepository, + InvestmentDebateRepository, + RiskDebateRepository, +) + + +class AnalysisService: + def __init__(self, session: Session): + self.session = session + self.sessions = AnalysisSessionRepository(session) + self.reports = AnalystReportRepository(session) + self.investment_debates = InvestmentDebateRepository(session) + self.risk_debates = RiskDebateRepository(session) + + def create_session(self, ticker: str, trade_date: str) -> AnalysisSession: + return self.sessions.create( + { + "ticker": ticker, + "trade_date": trade_date, + "status": "running", + } + ) + + def save_analyst_report( + self, session_id: str, analyst_type: str, report_content: str + ) -> AnalystReport: + return self.reports.create( + { + "session_id": session_id, + "analyst_type": analyst_type, + "report_content": report_content, + } + ) + + def save_investment_debate( + self, session_id: str, debate_state: dict[str, Any] + ) -> InvestmentDebate: + return self.investment_debates.create( + { + "session_id": session_id, + "bull_history": json.dumps(debate_state.get("bull_history", [])), + "bear_history": json.dumps(debate_state.get("bear_history", [])), + "debate_history": json.dumps(debate_state.get("history", [])), + "judge_decision": debate_state.get("judge_decision", ""), + "investment_plan": debate_state.get("current_response", ""), + "debate_rounds": len(debate_state.get("history", [])), + } + ) + + def save_risk_debate( + self, session_id: str, debate_state: dict[str, Any] + ) -> RiskDebate: + return self.risk_debates.create( + { + "session_id": session_id, + "risky_history": json.dumps(debate_state.get("risky_history", [])), + "safe_history": json.dumps(debate_state.get("safe_history", [])), + "neutral_history": json.dumps(debate_state.get("neutral_history", [])), + "debate_history": json.dumps(debate_state.get("history", [])), + "judge_decision": debate_state.get("judge_decision", ""), + "debate_rounds": len(debate_state.get("history", [])), + } + ) + + def save_full_state( + self, ticker: str, trade_date: str, final_state: dict[str, Any] + ) -> AnalysisSession: + analysis_session = self.create_session(ticker, trade_date) + + analyst_mappings = [ + ("market", "market_report"), + ("sentiment", "sentiment_report"), + ("news", "news_report"), + ("fundamentals", "fundamentals_report"), + ] + + for analyst_type, state_key in analyst_mappings: + report_content = final_state.get(state_key, "") + if report_content: + self.save_analyst_report( + analysis_session.id, analyst_type, report_content + ) + + investment_debate_state = final_state.get("investment_debate_state", {}) + if investment_debate_state: + self.save_investment_debate(analysis_session.id, investment_debate_state) + + risk_debate_state = final_state.get("risk_debate_state", {}) + if risk_debate_state: + self.save_risk_debate(analysis_session.id, risk_debate_state) + + self.sessions.mark_completed(analysis_session.id) + + return analysis_session + + def get_session_by_ticker_date( + self, ticker: str, trade_date: str + ) -> AnalysisSession | None: + return self.sessions.get_by_ticker_and_date(ticker, trade_date) + + def get_latest_session(self, ticker: str) -> AnalysisSession | None: + return self.sessions.get_latest_by_ticker(ticker) diff --git a/tradingagents/database/services/discovery.py b/tradingagents/database/services/discovery.py new file mode 100644 index 00000000..20968f2e --- /dev/null +++ b/tradingagents/database/services/discovery.py @@ -0,0 +1,158 @@ +import json +from datetime import datetime + +from sqlalchemy.orm import Session + +from tradingagents.database.models import ( + DiscoveryArticle, + DiscoveryRun, + TrendingStockResult, +) +from tradingagents.database.repositories.base import BaseRepository + + +class DiscoveryRunRepository(BaseRepository[DiscoveryRun]): + def __init__(self, session: Session): + super().__init__(session, DiscoveryRun) + + def get_latest(self) -> DiscoveryRun | None: + return ( + self.session.query(DiscoveryRun) + .order_by(DiscoveryRun.created_at.desc()) + .first() + ) + + def get_completed(self, limit: int = 10) -> list[DiscoveryRun]: + return ( + self.session.query(DiscoveryRun) + .filter(DiscoveryRun.status == "completed") + .order_by(DiscoveryRun.created_at.desc()) + .limit(limit) + .all() + ) + + +class TrendingStockResultRepository(BaseRepository[TrendingStockResult]): + def __init__(self, session: Session): + super().__init__(session, TrendingStockResult) + + def get_by_run(self, run_id: str) -> list[TrendingStockResult]: + return ( + self.session.query(TrendingStockResult) + .filter(TrendingStockResult.discovery_run_id == run_id) + .order_by(TrendingStockResult.trending_score.desc()) + .all() + ) + + def get_by_ticker(self, ticker: str, limit: int = 10) -> list[TrendingStockResult]: + return ( + self.session.query(TrendingStockResult) + .filter(TrendingStockResult.ticker == ticker) + .order_by(TrendingStockResult.created_at.desc()) + .limit(limit) + .all() + ) + + +class DiscoveryArticleRepository(BaseRepository[DiscoveryArticle]): + def __init__(self, session: Session): + super().__init__(session, DiscoveryArticle) + + def get_by_run(self, run_id: str) -> list[DiscoveryArticle]: + return ( + self.session.query(DiscoveryArticle) + .filter(DiscoveryArticle.discovery_run_id == run_id) + .all() + ) + + +class DiscoveryService: + def __init__(self, session: Session): + self.session = session + self.runs = DiscoveryRunRepository(session) + self.stocks = TrendingStockResultRepository(session) + self.articles = DiscoveryArticleRepository(session) + + def create_run(self, lookback_period: str, max_results: int) -> DiscoveryRun: + return self.runs.create( + { + "lookback_period": lookback_period, + "max_results": max_results, + "status": "running", + } + ) + + def save_trending_stock( + self, + run_id: str, + ticker: str, + company_name: str, + trending_score: float, + mention_count: int, + sector: str, + event_type: str, + summary: str | None = None, + source_articles: list[str] | None = None, + ) -> TrendingStockResult: + return self.stocks.create( + { + "discovery_run_id": run_id, + "ticker": ticker, + "company_name": company_name, + "trending_score": trending_score, + "mention_count": mention_count, + "sector": sector, + "event_type": event_type, + "summary": summary, + "source_articles": json.dumps(source_articles or []), + } + ) + + def save_article( + self, + run_id: str, + title: str, + source: str, + url: str | None = None, + published_at: datetime | None = None, + content_snippet: str | None = None, + ) -> DiscoveryArticle: + return self.articles.create( + { + "discovery_run_id": run_id, + "title": title, + "source": source, + "url": url, + "published_at": published_at, + "content_snippet": content_snippet, + } + ) + + def complete_run(self, run_id: str, stocks_found: int) -> DiscoveryRun | None: + run = self.runs.get(run_id) + if run: + run.status = "completed" + run.completed_at = datetime.utcnow() + run.stocks_found = stocks_found + self.session.flush() + return run + + def fail_run(self, run_id: str, error: str) -> DiscoveryRun | None: + run = self.runs.get(run_id) + if run: + run.status = "failed" + run.completed_at = datetime.utcnow() + run.error_message = error + self.session.flush() + return run + + def get_latest_run(self) -> DiscoveryRun | None: + return self.runs.get_latest() + + def get_trending_by_run(self, run_id: str) -> list[TrendingStockResult]: + return self.stocks.get_by_run(run_id) + + def get_stock_history( + self, ticker: str, limit: int = 10 + ) -> list[TrendingStockResult]: + return self.stocks.get_by_ticker(ticker, limit) diff --git a/tradingagents/database/services/trading.py b/tradingagents/database/services/trading.py new file mode 100644 index 00000000..0db4fb62 --- /dev/null +++ b/tradingagents/database/services/trading.py @@ -0,0 +1,91 @@ +from typing import Any + +from sqlalchemy.orm import Session + +from tradingagents.database.models import ( + TradeExecution, + TradeReflection, + TradingDecision, +) +from tradingagents.database.repositories import ( + TradeExecutionRepository, + TradeReflectionRepository, + TradingDecisionRepository, +) + + +class TradingService: + def __init__(self, session: Session): + self.session = session + self.decisions = TradingDecisionRepository(session) + self.executions = TradeExecutionRepository(session) + self.reflections = TradeReflectionRepository(session) + + def save_trading_decision( + self, + session_id: str, + ticker: str, + final_state: dict[str, Any], + signal: str, + ) -> TradingDecision: + decision_map = {"BUY": "buy", "SELL": "sell", "HOLD": "hold"} + decision = decision_map.get(signal.upper(), "hold") + + return self.decisions.create( + { + "session_id": session_id, + "ticker": ticker, + "decision": decision, + "trader_plan": final_state.get("trader_investment_plan", ""), + "investment_plan": final_state.get("investment_plan", ""), + "final_decision_text": final_state.get("final_trade_decision", ""), + } + ) + + def record_execution( + self, + decision_id: str, + ticker: str, + action: str, + quantity: int, + price: float, + ) -> TradeExecution: + return self.executions.create( + { + "decision_id": decision_id, + "ticker": ticker, + "action": action, + "quantity": quantity, + "price": price, + "status": "executed", + } + ) + + def save_reflection( + self, + ticker: str, + trade_date: str, + decision_id: str | None, + returns_losses: float, + reflection_content: str, + ) -> TradeReflection: + return self.reflections.create( + { + "ticker": ticker, + "trade_date": trade_date, + "decision_id": decision_id, + "actual_return": returns_losses, + "reflection_content": reflection_content, + } + ) + + def get_decision_by_session(self, session_id: str) -> TradingDecision | None: + return self.decisions.get_by_session(session_id) + + def get_decisions_by_ticker( + self, ticker: str, limit: int = 100 + ) -> list[TradingDecision]: + return self.decisions.get_by_ticker(ticker, limit) + + def get_pending_executions(self) -> list[TradeExecution]: + return self.executions.get_pending() diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 34595c25..38074098 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -4,7 +4,7 @@ import os import threading from datetime import date, datetime from pathlib import Path -from typing import Any, Dict, Optional, Tuple +from typing import Any from langchain_anthropic import ChatAnthropic from langchain_google_genai import ChatGoogleGenerativeAI @@ -35,6 +35,13 @@ from tradingagents.agents.utils.agent_utils import ( get_stock_data, ) from tradingagents.agents.utils.memory import FinancialSituationMemory +from tradingagents.database import ( + AnalysisService, + DiscoveryService, + TradingService, + get_db_session, + init_database, +) from tradingagents.dataflows.config import get_config, set_config from tradingagents.dataflows.interface import get_bulk_news from tradingagents.validation import validate_date, validate_ticker @@ -73,6 +80,11 @@ class TradingAgentsGraph: exist_ok=True, ) + self.db_enabled = self.config.get("database_enabled", False) + if self.db_enabled: + db_path = self.config.get("database_path") + init_database(db_path) + if ( self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" @@ -235,6 +247,23 @@ class TradingAgentsGraph: ) as f: json.dump(self.log_states_dict, f, indent=4) + if self.db_enabled: + self._persist_to_database(trade_date, final_state) + + def _persist_to_database(self, trade_date, final_state: dict[str, Any]) -> None: + with get_db_session() as session: + analysis_service = AnalysisService(session) + trading_service = TradingService(session) + + analysis_session = analysis_service.save_full_state( + self.ticker, str(trade_date), final_state + ) + + signal = self.process_signal(final_state.get("final_trade_decision", "")) + trading_service.save_trading_decision( + analysis_session.id, self.ticker, final_state, signal + ) + def reflect_and_remember(self, returns_losses) -> None: self.reflector.reflect_bull_researcher( self.curr_state, returns_losses, self.bull_memory @@ -351,8 +380,47 @@ class TradingAgentsGraph: result.status = DiscoveryStatus.COMPLETED result.completed_at = datetime.now() + if self.db_enabled: + self._persist_discovery_to_database(request, result) + return result + def _persist_discovery_to_database( + self, request: DiscoveryRequest, result: DiscoveryResult + ) -> None: + with get_db_session() as session: + discovery_service = DiscoveryService(session) + + run = discovery_service.create_run( + request.lookback_period, request.max_results or 20 + ) + + for stock in result.trending_stocks: + discovery_service.save_trending_stock( + run.id, + stock.ticker, + stock.company_name, + stock.trending_score, + stock.mention_count, + stock.sector.value + if hasattr(stock.sector, "value") + else str(stock.sector), + stock.event_type.value + if hasattr(stock.event_type, "value") + else str(stock.event_type), + stock.summary, + [a.get("title", "") for a in stock.source_articles] + if stock.source_articles + else None, + ) + + if result.status == DiscoveryStatus.COMPLETED: + discovery_service.complete_run(run.id, len(result.trending_stocks)) + else: + discovery_service.fail_run( + run.id, result.error_message or "Unknown error" + ) + def analyze_trending( self, trending_stock: TrendingStock,