feat: integrate database persistence with TradingAgentsGraph

Add database services layer with AnalysisService, TradingService, and DiscoveryService
for persisting analysis sessions, trading decisions, and discovery runs.

Integration with TradingAgentsGraph:
- Add config options: database_enabled, database_path
- Persist analysis state to database in _log_state when enabled
- Persist discovery results to database when enabled
- Save analyst reports, debates, and trading decisions

Database integration is opt-in via config setting.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Joseph O'Brien 2025-12-03 11:31:21 -05:00
parent c39f9aab36
commit 1db81e1fc6
6 changed files with 447 additions and 1 deletions

View File

@ -1,5 +1,6 @@
from .base import Base from .base import Base
from .engine import get_db_session, get_engine, init_database, reset_engine from .engine import get_db_session, get_engine, init_database, reset_engine
from .services import AnalysisService, DiscoveryService, TradingService
__all__ = [ __all__ = [
"Base", "Base",
@ -7,4 +8,7 @@ __all__ = [
"get_engine", "get_engine",
"init_database", "init_database",
"reset_engine", "reset_engine",
"AnalysisService",
"DiscoveryService",
"TradingService",
] ]

View File

@ -0,0 +1,9 @@
from .analysis import AnalysisService
from .discovery import DiscoveryService
from .trading import TradingService
__all__ = [
"AnalysisService",
"DiscoveryService",
"TradingService",
]

View File

@ -0,0 +1,116 @@
import json
from datetime import datetime
from typing import Any
from sqlalchemy.orm import Session
from tradingagents.database.models import (
AnalysisSession,
AnalystReport,
InvestmentDebate,
RiskDebate,
)
from tradingagents.database.repositories import (
AnalysisSessionRepository,
AnalystReportRepository,
InvestmentDebateRepository,
RiskDebateRepository,
)
class AnalysisService:
def __init__(self, session: Session):
self.session = session
self.sessions = AnalysisSessionRepository(session)
self.reports = AnalystReportRepository(session)
self.investment_debates = InvestmentDebateRepository(session)
self.risk_debates = RiskDebateRepository(session)
def create_session(self, ticker: str, trade_date: str) -> AnalysisSession:
return self.sessions.create(
{
"ticker": ticker,
"trade_date": trade_date,
"status": "running",
}
)
def save_analyst_report(
self, session_id: str, analyst_type: str, report_content: str
) -> AnalystReport:
return self.reports.create(
{
"session_id": session_id,
"analyst_type": analyst_type,
"report_content": report_content,
}
)
def save_investment_debate(
self, session_id: str, debate_state: dict[str, Any]
) -> InvestmentDebate:
return self.investment_debates.create(
{
"session_id": session_id,
"bull_history": json.dumps(debate_state.get("bull_history", [])),
"bear_history": json.dumps(debate_state.get("bear_history", [])),
"debate_history": json.dumps(debate_state.get("history", [])),
"judge_decision": debate_state.get("judge_decision", ""),
"investment_plan": debate_state.get("current_response", ""),
"debate_rounds": len(debate_state.get("history", [])),
}
)
def save_risk_debate(
self, session_id: str, debate_state: dict[str, Any]
) -> RiskDebate:
return self.risk_debates.create(
{
"session_id": session_id,
"risky_history": json.dumps(debate_state.get("risky_history", [])),
"safe_history": json.dumps(debate_state.get("safe_history", [])),
"neutral_history": json.dumps(debate_state.get("neutral_history", [])),
"debate_history": json.dumps(debate_state.get("history", [])),
"judge_decision": debate_state.get("judge_decision", ""),
"debate_rounds": len(debate_state.get("history", [])),
}
)
def save_full_state(
self, ticker: str, trade_date: str, final_state: dict[str, Any]
) -> AnalysisSession:
analysis_session = self.create_session(ticker, trade_date)
analyst_mappings = [
("market", "market_report"),
("sentiment", "sentiment_report"),
("news", "news_report"),
("fundamentals", "fundamentals_report"),
]
for analyst_type, state_key in analyst_mappings:
report_content = final_state.get(state_key, "")
if report_content:
self.save_analyst_report(
analysis_session.id, analyst_type, report_content
)
investment_debate_state = final_state.get("investment_debate_state", {})
if investment_debate_state:
self.save_investment_debate(analysis_session.id, investment_debate_state)
risk_debate_state = final_state.get("risk_debate_state", {})
if risk_debate_state:
self.save_risk_debate(analysis_session.id, risk_debate_state)
self.sessions.mark_completed(analysis_session.id)
return analysis_session
def get_session_by_ticker_date(
self, ticker: str, trade_date: str
) -> AnalysisSession | None:
return self.sessions.get_by_ticker_and_date(ticker, trade_date)
def get_latest_session(self, ticker: str) -> AnalysisSession | None:
return self.sessions.get_latest_by_ticker(ticker)

View File

@ -0,0 +1,158 @@
import json
from datetime import datetime
from sqlalchemy.orm import Session
from tradingagents.database.models import (
DiscoveryArticle,
DiscoveryRun,
TrendingStockResult,
)
from tradingagents.database.repositories.base import BaseRepository
class DiscoveryRunRepository(BaseRepository[DiscoveryRun]):
def __init__(self, session: Session):
super().__init__(session, DiscoveryRun)
def get_latest(self) -> DiscoveryRun | None:
return (
self.session.query(DiscoveryRun)
.order_by(DiscoveryRun.created_at.desc())
.first()
)
def get_completed(self, limit: int = 10) -> list[DiscoveryRun]:
return (
self.session.query(DiscoveryRun)
.filter(DiscoveryRun.status == "completed")
.order_by(DiscoveryRun.created_at.desc())
.limit(limit)
.all()
)
class TrendingStockResultRepository(BaseRepository[TrendingStockResult]):
def __init__(self, session: Session):
super().__init__(session, TrendingStockResult)
def get_by_run(self, run_id: str) -> list[TrendingStockResult]:
return (
self.session.query(TrendingStockResult)
.filter(TrendingStockResult.discovery_run_id == run_id)
.order_by(TrendingStockResult.trending_score.desc())
.all()
)
def get_by_ticker(self, ticker: str, limit: int = 10) -> list[TrendingStockResult]:
return (
self.session.query(TrendingStockResult)
.filter(TrendingStockResult.ticker == ticker)
.order_by(TrendingStockResult.created_at.desc())
.limit(limit)
.all()
)
class DiscoveryArticleRepository(BaseRepository[DiscoveryArticle]):
def __init__(self, session: Session):
super().__init__(session, DiscoveryArticle)
def get_by_run(self, run_id: str) -> list[DiscoveryArticle]:
return (
self.session.query(DiscoveryArticle)
.filter(DiscoveryArticle.discovery_run_id == run_id)
.all()
)
class DiscoveryService:
def __init__(self, session: Session):
self.session = session
self.runs = DiscoveryRunRepository(session)
self.stocks = TrendingStockResultRepository(session)
self.articles = DiscoveryArticleRepository(session)
def create_run(self, lookback_period: str, max_results: int) -> DiscoveryRun:
return self.runs.create(
{
"lookback_period": lookback_period,
"max_results": max_results,
"status": "running",
}
)
def save_trending_stock(
self,
run_id: str,
ticker: str,
company_name: str,
trending_score: float,
mention_count: int,
sector: str,
event_type: str,
summary: str | None = None,
source_articles: list[str] | None = None,
) -> TrendingStockResult:
return self.stocks.create(
{
"discovery_run_id": run_id,
"ticker": ticker,
"company_name": company_name,
"trending_score": trending_score,
"mention_count": mention_count,
"sector": sector,
"event_type": event_type,
"summary": summary,
"source_articles": json.dumps(source_articles or []),
}
)
def save_article(
self,
run_id: str,
title: str,
source: str,
url: str | None = None,
published_at: datetime | None = None,
content_snippet: str | None = None,
) -> DiscoveryArticle:
return self.articles.create(
{
"discovery_run_id": run_id,
"title": title,
"source": source,
"url": url,
"published_at": published_at,
"content_snippet": content_snippet,
}
)
def complete_run(self, run_id: str, stocks_found: int) -> DiscoveryRun | None:
run = self.runs.get(run_id)
if run:
run.status = "completed"
run.completed_at = datetime.utcnow()
run.stocks_found = stocks_found
self.session.flush()
return run
def fail_run(self, run_id: str, error: str) -> DiscoveryRun | None:
run = self.runs.get(run_id)
if run:
run.status = "failed"
run.completed_at = datetime.utcnow()
run.error_message = error
self.session.flush()
return run
def get_latest_run(self) -> DiscoveryRun | None:
return self.runs.get_latest()
def get_trending_by_run(self, run_id: str) -> list[TrendingStockResult]:
return self.stocks.get_by_run(run_id)
def get_stock_history(
self, ticker: str, limit: int = 10
) -> list[TrendingStockResult]:
return self.stocks.get_by_ticker(ticker, limit)

View File

@ -0,0 +1,91 @@
from typing import Any
from sqlalchemy.orm import Session
from tradingagents.database.models import (
TradeExecution,
TradeReflection,
TradingDecision,
)
from tradingagents.database.repositories import (
TradeExecutionRepository,
TradeReflectionRepository,
TradingDecisionRepository,
)
class TradingService:
def __init__(self, session: Session):
self.session = session
self.decisions = TradingDecisionRepository(session)
self.executions = TradeExecutionRepository(session)
self.reflections = TradeReflectionRepository(session)
def save_trading_decision(
self,
session_id: str,
ticker: str,
final_state: dict[str, Any],
signal: str,
) -> TradingDecision:
decision_map = {"BUY": "buy", "SELL": "sell", "HOLD": "hold"}
decision = decision_map.get(signal.upper(), "hold")
return self.decisions.create(
{
"session_id": session_id,
"ticker": ticker,
"decision": decision,
"trader_plan": final_state.get("trader_investment_plan", ""),
"investment_plan": final_state.get("investment_plan", ""),
"final_decision_text": final_state.get("final_trade_decision", ""),
}
)
def record_execution(
self,
decision_id: str,
ticker: str,
action: str,
quantity: int,
price: float,
) -> TradeExecution:
return self.executions.create(
{
"decision_id": decision_id,
"ticker": ticker,
"action": action,
"quantity": quantity,
"price": price,
"status": "executed",
}
)
def save_reflection(
self,
ticker: str,
trade_date: str,
decision_id: str | None,
returns_losses: float,
reflection_content: str,
) -> TradeReflection:
return self.reflections.create(
{
"ticker": ticker,
"trade_date": trade_date,
"decision_id": decision_id,
"actual_return": returns_losses,
"reflection_content": reflection_content,
}
)
def get_decision_by_session(self, session_id: str) -> TradingDecision | None:
return self.decisions.get_by_session(session_id)
def get_decisions_by_ticker(
self, ticker: str, limit: int = 100
) -> list[TradingDecision]:
return self.decisions.get_by_ticker(ticker, limit)
def get_pending_executions(self) -> list[TradeExecution]:
return self.executions.get_pending()

View File

@ -4,7 +4,7 @@ import os
import threading import threading
from datetime import date, datetime from datetime import date, datetime
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional, Tuple from typing import Any
from langchain_anthropic import ChatAnthropic from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI from langchain_google_genai import ChatGoogleGenerativeAI
@ -35,6 +35,13 @@ from tradingagents.agents.utils.agent_utils import (
get_stock_data, get_stock_data,
) )
from tradingagents.agents.utils.memory import FinancialSituationMemory from tradingagents.agents.utils.memory import FinancialSituationMemory
from tradingagents.database import (
AnalysisService,
DiscoveryService,
TradingService,
get_db_session,
init_database,
)
from tradingagents.dataflows.config import get_config, set_config from tradingagents.dataflows.config import get_config, set_config
from tradingagents.dataflows.interface import get_bulk_news from tradingagents.dataflows.interface import get_bulk_news
from tradingagents.validation import validate_date, validate_ticker from tradingagents.validation import validate_date, validate_ticker
@ -73,6 +80,11 @@ class TradingAgentsGraph:
exist_ok=True, exist_ok=True,
) )
self.db_enabled = self.config.get("database_enabled", False)
if self.db_enabled:
db_path = self.config.get("database_path")
init_database(db_path)
if ( if (
self.config["llm_provider"].lower() == "openai" self.config["llm_provider"].lower() == "openai"
or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "ollama"
@ -235,6 +247,23 @@ class TradingAgentsGraph:
) as f: ) as f:
json.dump(self.log_states_dict, f, indent=4) json.dump(self.log_states_dict, f, indent=4)
if self.db_enabled:
self._persist_to_database(trade_date, final_state)
def _persist_to_database(self, trade_date, final_state: dict[str, Any]) -> None:
with get_db_session() as session:
analysis_service = AnalysisService(session)
trading_service = TradingService(session)
analysis_session = analysis_service.save_full_state(
self.ticker, str(trade_date), final_state
)
signal = self.process_signal(final_state.get("final_trade_decision", ""))
trading_service.save_trading_decision(
analysis_session.id, self.ticker, final_state, signal
)
def reflect_and_remember(self, returns_losses) -> None: def reflect_and_remember(self, returns_losses) -> None:
self.reflector.reflect_bull_researcher( self.reflector.reflect_bull_researcher(
self.curr_state, returns_losses, self.bull_memory self.curr_state, returns_losses, self.bull_memory
@ -351,8 +380,47 @@ class TradingAgentsGraph:
result.status = DiscoveryStatus.COMPLETED result.status = DiscoveryStatus.COMPLETED
result.completed_at = datetime.now() result.completed_at = datetime.now()
if self.db_enabled:
self._persist_discovery_to_database(request, result)
return result return result
def _persist_discovery_to_database(
self, request: DiscoveryRequest, result: DiscoveryResult
) -> None:
with get_db_session() as session:
discovery_service = DiscoveryService(session)
run = discovery_service.create_run(
request.lookback_period, request.max_results or 20
)
for stock in result.trending_stocks:
discovery_service.save_trending_stock(
run.id,
stock.ticker,
stock.company_name,
stock.trending_score,
stock.mention_count,
stock.sector.value
if hasattr(stock.sector, "value")
else str(stock.sector),
stock.event_type.value
if hasattr(stock.event_type, "value")
else str(stock.event_type),
stock.summary,
[a.get("title", "") for a in stock.source_articles]
if stock.source_articles
else None,
)
if result.status == DiscoveryStatus.COMPLETED:
discovery_service.complete_run(run.id, len(result.trending_stocks))
else:
discovery_service.fail_run(
run.id, result.error_message or "Unknown error"
)
def analyze_trending( def analyze_trending(
self, self,
trending_stock: TrendingStock, trending_stock: TrendingStock,