feat: integrate database persistence with TradingAgentsGraph

Add database services layer with AnalysisService, TradingService, and DiscoveryService
for persisting analysis sessions, trading decisions, and discovery runs.

Integration with TradingAgentsGraph:
- Add config options: database_enabled, database_path
- Persist analysis state to database in _log_state when enabled
- Persist discovery results to database when enabled
- Save analyst reports, debates, and trading decisions

Database integration is opt-in via config setting.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Joseph O'Brien 2025-12-03 11:31:21 -05:00
parent c39f9aab36
commit 1db81e1fc6
6 changed files with 447 additions and 1 deletion

View File

@ -1,5 +1,6 @@
from .base import Base
from .engine import get_db_session, get_engine, init_database, reset_engine
from .services import AnalysisService, DiscoveryService, TradingService
__all__ = [
"Base",
@ -7,4 +8,7 @@ __all__ = [
"get_engine",
"init_database",
"reset_engine",
"AnalysisService",
"DiscoveryService",
"TradingService",
]

View File

@ -0,0 +1,9 @@
"""Service-layer facades for database persistence.

Re-exports the analysis, discovery, and trading services so callers
can import them directly from this subpackage.
"""

from .analysis import AnalysisService
from .discovery import DiscoveryService
from .trading import TradingService

# Public API of the services subpackage.
__all__ = [
    "AnalysisService",
    "DiscoveryService",
    "TradingService",
]

View File

@ -0,0 +1,116 @@
import json
from datetime import datetime
from typing import Any
from sqlalchemy.orm import Session
from tradingagents.database.models import (
AnalysisSession,
AnalystReport,
InvestmentDebate,
RiskDebate,
)
from tradingagents.database.repositories import (
AnalysisSessionRepository,
AnalystReportRepository,
InvestmentDebateRepository,
RiskDebateRepository,
)
class AnalysisService:
    """Facade for persisting a full analysis session.

    Bundles the session, analyst-report, and debate repositories so
    graph code can save results without touching repositories directly.
    """

    def __init__(self, session: Session):
        # A single SQLAlchemy session shared by every repository.
        self.session = session
        self.sessions = AnalysisSessionRepository(session)
        self.reports = AnalystReportRepository(session)
        self.investment_debates = InvestmentDebateRepository(session)
        self.risk_debates = RiskDebateRepository(session)

    def create_session(self, ticker: str, trade_date: str) -> AnalysisSession:
        """Open a new session row for *ticker* on *trade_date* in the 'running' state."""
        row = {
            "ticker": ticker,
            "trade_date": trade_date,
            "status": "running",
        }
        return self.sessions.create(row)

    def save_analyst_report(
        self, session_id: str, analyst_type: str, report_content: str
    ) -> AnalystReport:
        """Attach one analyst's report text to an existing session."""
        row = {
            "session_id": session_id,
            "analyst_type": analyst_type,
            "report_content": report_content,
        }
        return self.reports.create(row)

    def save_investment_debate(
        self, session_id: str, debate_state: dict[str, Any]
    ) -> InvestmentDebate:
        """Persist the bull/bear investment debate attached to a session."""
        shared_history = debate_state.get("history", [])
        row = {
            "session_id": session_id,
            "bull_history": json.dumps(debate_state.get("bull_history", [])),
            "bear_history": json.dumps(debate_state.get("bear_history", [])),
            "debate_history": json.dumps(shared_history),
            "judge_decision": debate_state.get("judge_decision", ""),
            "investment_plan": debate_state.get("current_response", ""),
            # NOTE(review): assumes "history" holds one entry per debate
            # round — confirm against the graph's debate state shape.
            "debate_rounds": len(shared_history),
        }
        return self.investment_debates.create(row)

    def save_risk_debate(
        self, session_id: str, debate_state: dict[str, Any]
    ) -> RiskDebate:
        """Persist the risky/safe/neutral risk debate attached to a session."""
        shared_history = debate_state.get("history", [])
        row = {
            "session_id": session_id,
            "risky_history": json.dumps(debate_state.get("risky_history", [])),
            "safe_history": json.dumps(debate_state.get("safe_history", [])),
            "neutral_history": json.dumps(debate_state.get("neutral_history", [])),
            "debate_history": json.dumps(shared_history),
            "judge_decision": debate_state.get("judge_decision", ""),
            "debate_rounds": len(shared_history),
        }
        return self.risk_debates.create(row)

    def save_full_state(
        self, ticker: str, trade_date: str, final_state: dict[str, Any]
    ) -> AnalysisSession:
        """Persist an entire final graph state: session, reports, debates.

        Creates the session, saves every non-empty analyst report and
        debate found in *final_state*, then marks the session completed.
        """
        record = self.create_session(ticker, trade_date)
        # Maps the stored analyst_type to the key it is read from in the state.
        report_sources = {
            "market": "market_report",
            "sentiment": "sentiment_report",
            "news": "news_report",
            "fundamentals": "fundamentals_report",
        }
        for analyst_type, state_key in report_sources.items():
            content = final_state.get(state_key, "")
            if content:
                self.save_analyst_report(record.id, analyst_type, content)
        investment_state = final_state.get("investment_debate_state", {})
        if investment_state:
            self.save_investment_debate(record.id, investment_state)
        risk_state = final_state.get("risk_debate_state", {})
        if risk_state:
            self.save_risk_debate(record.id, risk_state)
        self.sessions.mark_completed(record.id)
        return record

    def get_session_by_ticker_date(
        self, ticker: str, trade_date: str
    ) -> AnalysisSession | None:
        """Look up the session for an exact ticker/date pair, if any."""
        return self.sessions.get_by_ticker_and_date(ticker, trade_date)

    def get_latest_session(self, ticker: str) -> AnalysisSession | None:
        """Return the most recent session recorded for *ticker*, if any."""
        return self.sessions.get_latest_by_ticker(ticker)

View File

@ -0,0 +1,158 @@
import json
from datetime import datetime, timezone
from sqlalchemy.orm import Session
from tradingagents.database.models import (
DiscoveryArticle,
DiscoveryRun,
TrendingStockResult,
)
from tradingagents.database.repositories.base import BaseRepository
class DiscoveryRunRepository(BaseRepository[DiscoveryRun]):
    """Data access for DiscoveryRun rows."""

    def __init__(self, session: Session):
        super().__init__(session, DiscoveryRun)

    def get_latest(self) -> DiscoveryRun | None:
        """Return the most recently created run, or None if none exist."""
        newest_first = self.session.query(DiscoveryRun).order_by(
            DiscoveryRun.created_at.desc()
        )
        return newest_first.first()

    def get_completed(self, limit: int = 10) -> list[DiscoveryRun]:
        """Return the newest runs with status 'completed', capped at *limit*."""
        completed = self.session.query(DiscoveryRun).filter(
            DiscoveryRun.status == "completed"
        )
        return completed.order_by(DiscoveryRun.created_at.desc()).limit(limit).all()
class TrendingStockResultRepository(BaseRepository[TrendingStockResult]):
    """Data access for TrendingStockResult rows."""

    def __init__(self, session: Session):
        super().__init__(session, TrendingStockResult)

    def get_by_run(self, run_id: str) -> list[TrendingStockResult]:
        """Return one run's results, highest trending score first."""
        for_run = self.session.query(TrendingStockResult).filter(
            TrendingStockResult.discovery_run_id == run_id
        )
        return for_run.order_by(TrendingStockResult.trending_score.desc()).all()

    def get_by_ticker(self, ticker: str, limit: int = 10) -> list[TrendingStockResult]:
        """Return the most recent results for *ticker*, capped at *limit*."""
        for_ticker = self.session.query(TrendingStockResult).filter(
            TrendingStockResult.ticker == ticker
        )
        for_ticker = for_ticker.order_by(TrendingStockResult.created_at.desc())
        return for_ticker.limit(limit).all()
class DiscoveryArticleRepository(BaseRepository[DiscoveryArticle]):
    """Data access for DiscoveryArticle rows."""

    def __init__(self, session: Session):
        super().__init__(session, DiscoveryArticle)

    def get_by_run(self, run_id: str) -> list[DiscoveryArticle]:
        """Return every article captured during one discovery run."""
        matches = self.session.query(DiscoveryArticle).filter(
            DiscoveryArticle.discovery_run_id == run_id
        )
        return matches.all()
class DiscoveryService:
    """Facade for persisting discovery runs, trending stocks, and articles."""

    def __init__(self, session: Session):
        # A single SQLAlchemy session shared by every repository.
        self.session = session
        self.runs = DiscoveryRunRepository(session)
        self.stocks = TrendingStockResultRepository(session)
        self.articles = DiscoveryArticleRepository(session)

    def create_run(self, lookback_period: str, max_results: int) -> DiscoveryRun:
        """Open a new discovery run in the 'running' state."""
        return self.runs.create(
            {
                "lookback_period": lookback_period,
                "max_results": max_results,
                "status": "running",
            }
        )

    def save_trending_stock(
        self,
        run_id: str,
        ticker: str,
        company_name: str,
        trending_score: float,
        mention_count: int,
        sector: str,
        event_type: str,
        summary: str | None = None,
        source_articles: list[str] | None = None,
    ) -> TrendingStockResult:
        """Record one trending-stock result for a run.

        *source_articles* (a list of article titles) is JSON-encoded for
        storage; None is stored as an empty JSON list.
        """
        return self.stocks.create(
            {
                "discovery_run_id": run_id,
                "ticker": ticker,
                "company_name": company_name,
                "trending_score": trending_score,
                "mention_count": mention_count,
                "sector": sector,
                "event_type": event_type,
                "summary": summary,
                "source_articles": json.dumps(source_articles or []),
            }
        )

    def save_article(
        self,
        run_id: str,
        title: str,
        source: str,
        url: str | None = None,
        published_at: datetime | None = None,
        content_snippet: str | None = None,
    ) -> DiscoveryArticle:
        """Record one source article discovered during a run."""
        return self.articles.create(
            {
                "discovery_run_id": run_id,
                "title": title,
                "source": source,
                "url": url,
                "published_at": published_at,
                "content_snippet": content_snippet,
            }
        )

    def complete_run(self, run_id: str, stocks_found: int) -> DiscoveryRun | None:
        """Mark a run completed and stamp its finish time.

        Returns the updated run, or None if *run_id* is unknown.
        """
        run = self.runs.get(run_id)
        if run:
            run.status = "completed"
            # datetime.utcnow() is deprecated since Python 3.12; this
            # produces the identical naive-UTC timestamp.
            run.completed_at = datetime.now(timezone.utc).replace(tzinfo=None)
            run.stocks_found = stocks_found
            self.session.flush()
        return run

    def fail_run(self, run_id: str, error: str) -> DiscoveryRun | None:
        """Mark a run failed with an error message.

        Returns the updated run, or None if *run_id* is unknown.
        """
        run = self.runs.get(run_id)
        if run:
            run.status = "failed"
            # Same naive-UTC timestamp as complete_run (utcnow is deprecated).
            run.completed_at = datetime.now(timezone.utc).replace(tzinfo=None)
            run.error_message = error
            self.session.flush()
        return run

    def get_latest_run(self) -> DiscoveryRun | None:
        """Return the most recently created run, if any."""
        return self.runs.get_latest()

    def get_trending_by_run(self, run_id: str) -> list[TrendingStockResult]:
        """Return one run's trending-stock results."""
        return self.stocks.get_by_run(run_id)

    def get_stock_history(
        self, ticker: str, limit: int = 10
    ) -> list[TrendingStockResult]:
        """Return the most recent results for *ticker*, capped at *limit*."""
        return self.stocks.get_by_ticker(ticker, limit)

View File

@ -0,0 +1,91 @@
from typing import Any
from sqlalchemy.orm import Session
from tradingagents.database.models import (
TradeExecution,
TradeReflection,
TradingDecision,
)
from tradingagents.database.repositories import (
TradeExecutionRepository,
TradeReflectionRepository,
TradingDecisionRepository,
)
class TradingService:
    """Facade for persisting trading decisions, executions, and reflections."""

    def __init__(self, session: Session):
        # A single SQLAlchemy session shared by every repository.
        self.session = session
        self.decisions = TradingDecisionRepository(session)
        self.executions = TradeExecutionRepository(session)
        self.reflections = TradeReflectionRepository(session)

    def save_trading_decision(
        self,
        session_id: str,
        ticker: str,
        final_state: dict[str, Any],
        signal: str,
    ) -> TradingDecision:
        """Store the final trade decision extracted from *final_state*.

        Any signal other than BUY/SELL/HOLD (case-insensitive) is
        recorded as "hold".
        """
        normalized = {"BUY": "buy", "SELL": "sell", "HOLD": "hold"}.get(
            signal.upper(), "hold"
        )
        row = {
            "session_id": session_id,
            "ticker": ticker,
            "decision": normalized,
            "trader_plan": final_state.get("trader_investment_plan", ""),
            "investment_plan": final_state.get("investment_plan", ""),
            "final_decision_text": final_state.get("final_trade_decision", ""),
        }
        return self.decisions.create(row)

    def record_execution(
        self,
        decision_id: str,
        ticker: str,
        action: str,
        quantity: int,
        price: float,
    ) -> TradeExecution:
        """Record a fill for a saved decision with status 'executed'."""
        row = {
            "decision_id": decision_id,
            "ticker": ticker,
            "action": action,
            "quantity": quantity,
            "price": price,
            "status": "executed",
        }
        return self.executions.create(row)

    def save_reflection(
        self,
        ticker: str,
        trade_date: str,
        decision_id: str | None,
        returns_losses: float,
        reflection_content: str,
    ) -> TradeReflection:
        """Store a post-trade reflection, optionally linked to a decision."""
        row = {
            "ticker": ticker,
            "trade_date": trade_date,
            "decision_id": decision_id,
            "actual_return": returns_losses,
            "reflection_content": reflection_content,
        }
        return self.reflections.create(row)

    def get_decision_by_session(self, session_id: str) -> TradingDecision | None:
        """Return the decision tied to one analysis session, if any."""
        return self.decisions.get_by_session(session_id)

    def get_decisions_by_ticker(
        self, ticker: str, limit: int = 100
    ) -> list[TradingDecision]:
        """Return up to *limit* decisions recorded for *ticker*."""
        return self.decisions.get_by_ticker(ticker, limit)

    def get_pending_executions(self) -> list[TradeExecution]:
        """Return pending executions, as defined by the executions repository."""
        return self.executions.get_pending()

View File

@ -4,7 +4,7 @@ import os
import threading
from datetime import date, datetime
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
from typing import Any
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
@ -35,6 +35,13 @@ from tradingagents.agents.utils.agent_utils import (
get_stock_data,
)
from tradingagents.agents.utils.memory import FinancialSituationMemory
from tradingagents.database import (
AnalysisService,
DiscoveryService,
TradingService,
get_db_session,
init_database,
)
from tradingagents.dataflows.config import get_config, set_config
from tradingagents.dataflows.interface import get_bulk_news
from tradingagents.validation import validate_date, validate_ticker
@ -73,6 +80,11 @@ class TradingAgentsGraph:
exist_ok=True,
)
self.db_enabled = self.config.get("database_enabled", False)
if self.db_enabled:
db_path = self.config.get("database_path")
init_database(db_path)
if (
self.config["llm_provider"].lower() == "openai"
or self.config["llm_provider"] == "ollama"
@ -235,6 +247,23 @@ class TradingAgentsGraph:
) as f:
json.dump(self.log_states_dict, f, indent=4)
if self.db_enabled:
self._persist_to_database(trade_date, final_state)
def _persist_to_database(self, trade_date, final_state: dict[str, Any]) -> None:
    """Write the finished analysis state and its trading decision to the DB.

    Both writes happen inside one database session from get_db_session,
    so the analysis rows and the decision row share a unit of work.
    """
    with get_db_session() as db:
        analysis = AnalysisService(db)
        trading = TradingService(db)
        saved_session = analysis.save_full_state(
            self.ticker, str(trade_date), final_state
        )
        # Turn the free-text final decision into a BUY/SELL/HOLD signal,
        # then attach the decision to the session just saved.
        decision_signal = self.process_signal(
            final_state.get("final_trade_decision", "")
        )
        trading.save_trading_decision(
            saved_session.id, self.ticker, final_state, decision_signal
        )
def reflect_and_remember(self, returns_losses) -> None:
self.reflector.reflect_bull_researcher(
self.curr_state, returns_losses, self.bull_memory
@ -351,8 +380,47 @@ class TradingAgentsGraph:
result.status = DiscoveryStatus.COMPLETED
result.completed_at = datetime.now()
if self.db_enabled:
self._persist_discovery_to_database(request, result)
return result
def _persist_discovery_to_database(
    self, request: DiscoveryRequest, result: DiscoveryResult
) -> None:
    """Record a discovery run, its trending stocks, and its final status."""
    with get_db_session() as db:
        service = DiscoveryService(db)
        run = service.create_run(request.lookback_period, request.max_results or 20)
        for stock in result.trending_stocks:
            # sector/event_type may be enums or plain strings; store the
            # enum's value when present, else the string form.
            if hasattr(stock.sector, "value"):
                sector = stock.sector.value
            else:
                sector = str(stock.sector)
            if hasattr(stock.event_type, "value"):
                event_type = stock.event_type.value
            else:
                event_type = str(stock.event_type)
            article_titles = None
            if stock.source_articles:
                article_titles = [a.get("title", "") for a in stock.source_articles]
            service.save_trending_stock(
                run.id,
                stock.ticker,
                stock.company_name,
                stock.trending_score,
                stock.mention_count,
                sector,
                event_type,
                stock.summary,
                article_titles,
            )
        if result.status == DiscoveryStatus.COMPLETED:
            service.complete_run(run.id, len(result.trending_stocks))
        else:
            service.fail_run(run.id, result.error_message or "Unknown error")
def analyze_trending(
self,
trending_stock: TrendingStock,