feat: integrate database persistence with TradingAgentsGraph
Add database services layer with AnalysisService, TradingService, and DiscoveryService for persisting analysis sessions, trading decisions, and discovery runs. Integration with TradingAgentsGraph: - Add config options: database_enabled, database_path - Persist analysis state to database in _log_state when enabled - Persist discovery results to database when enabled - Save analyst reports, debates, and trading decisions Database integration is opt-in via config setting. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
c39f9aab36
commit
1db81e1fc6
|
|
@ -1,5 +1,6 @@
|
|||
from .base import Base
|
||||
from .engine import get_db_session, get_engine, init_database, reset_engine
|
||||
from .services import AnalysisService, DiscoveryService, TradingService
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
|
|
@ -7,4 +8,7 @@ __all__ = [
|
|||
"get_engine",
|
||||
"init_database",
|
||||
"reset_engine",
|
||||
"AnalysisService",
|
||||
"DiscoveryService",
|
||||
"TradingService",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,9 @@
|
|||
from .analysis import AnalysisService
|
||||
from .discovery import DiscoveryService
|
||||
from .trading import TradingService
|
||||
|
||||
__all__ = [
|
||||
"AnalysisService",
|
||||
"DiscoveryService",
|
||||
"TradingService",
|
||||
]
|
||||
|
|
@ -0,0 +1,116 @@
|
|||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from tradingagents.database.models import (
|
||||
AnalysisSession,
|
||||
AnalystReport,
|
||||
InvestmentDebate,
|
||||
RiskDebate,
|
||||
)
|
||||
from tradingagents.database.repositories import (
|
||||
AnalysisSessionRepository,
|
||||
AnalystReportRepository,
|
||||
InvestmentDebateRepository,
|
||||
RiskDebateRepository,
|
||||
)
|
||||
|
||||
|
||||
class AnalysisService:
|
||||
def __init__(self, session: Session):
|
||||
self.session = session
|
||||
self.sessions = AnalysisSessionRepository(session)
|
||||
self.reports = AnalystReportRepository(session)
|
||||
self.investment_debates = InvestmentDebateRepository(session)
|
||||
self.risk_debates = RiskDebateRepository(session)
|
||||
|
||||
def create_session(self, ticker: str, trade_date: str) -> AnalysisSession:
|
||||
return self.sessions.create(
|
||||
{
|
||||
"ticker": ticker,
|
||||
"trade_date": trade_date,
|
||||
"status": "running",
|
||||
}
|
||||
)
|
||||
|
||||
def save_analyst_report(
|
||||
self, session_id: str, analyst_type: str, report_content: str
|
||||
) -> AnalystReport:
|
||||
return self.reports.create(
|
||||
{
|
||||
"session_id": session_id,
|
||||
"analyst_type": analyst_type,
|
||||
"report_content": report_content,
|
||||
}
|
||||
)
|
||||
|
||||
def save_investment_debate(
|
||||
self, session_id: str, debate_state: dict[str, Any]
|
||||
) -> InvestmentDebate:
|
||||
return self.investment_debates.create(
|
||||
{
|
||||
"session_id": session_id,
|
||||
"bull_history": json.dumps(debate_state.get("bull_history", [])),
|
||||
"bear_history": json.dumps(debate_state.get("bear_history", [])),
|
||||
"debate_history": json.dumps(debate_state.get("history", [])),
|
||||
"judge_decision": debate_state.get("judge_decision", ""),
|
||||
"investment_plan": debate_state.get("current_response", ""),
|
||||
"debate_rounds": len(debate_state.get("history", [])),
|
||||
}
|
||||
)
|
||||
|
||||
def save_risk_debate(
|
||||
self, session_id: str, debate_state: dict[str, Any]
|
||||
) -> RiskDebate:
|
||||
return self.risk_debates.create(
|
||||
{
|
||||
"session_id": session_id,
|
||||
"risky_history": json.dumps(debate_state.get("risky_history", [])),
|
||||
"safe_history": json.dumps(debate_state.get("safe_history", [])),
|
||||
"neutral_history": json.dumps(debate_state.get("neutral_history", [])),
|
||||
"debate_history": json.dumps(debate_state.get("history", [])),
|
||||
"judge_decision": debate_state.get("judge_decision", ""),
|
||||
"debate_rounds": len(debate_state.get("history", [])),
|
||||
}
|
||||
)
|
||||
|
||||
def save_full_state(
|
||||
self, ticker: str, trade_date: str, final_state: dict[str, Any]
|
||||
) -> AnalysisSession:
|
||||
analysis_session = self.create_session(ticker, trade_date)
|
||||
|
||||
analyst_mappings = [
|
||||
("market", "market_report"),
|
||||
("sentiment", "sentiment_report"),
|
||||
("news", "news_report"),
|
||||
("fundamentals", "fundamentals_report"),
|
||||
]
|
||||
|
||||
for analyst_type, state_key in analyst_mappings:
|
||||
report_content = final_state.get(state_key, "")
|
||||
if report_content:
|
||||
self.save_analyst_report(
|
||||
analysis_session.id, analyst_type, report_content
|
||||
)
|
||||
|
||||
investment_debate_state = final_state.get("investment_debate_state", {})
|
||||
if investment_debate_state:
|
||||
self.save_investment_debate(analysis_session.id, investment_debate_state)
|
||||
|
||||
risk_debate_state = final_state.get("risk_debate_state", {})
|
||||
if risk_debate_state:
|
||||
self.save_risk_debate(analysis_session.id, risk_debate_state)
|
||||
|
||||
self.sessions.mark_completed(analysis_session.id)
|
||||
|
||||
return analysis_session
|
||||
|
||||
def get_session_by_ticker_date(
|
||||
self, ticker: str, trade_date: str
|
||||
) -> AnalysisSession | None:
|
||||
return self.sessions.get_by_ticker_and_date(ticker, trade_date)
|
||||
|
||||
def get_latest_session(self, ticker: str) -> AnalysisSession | None:
|
||||
return self.sessions.get_latest_by_ticker(ticker)
|
||||
|
|
@ -0,0 +1,158 @@
|
|||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from tradingagents.database.models import (
|
||||
DiscoveryArticle,
|
||||
DiscoveryRun,
|
||||
TrendingStockResult,
|
||||
)
|
||||
from tradingagents.database.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class DiscoveryRunRepository(BaseRepository[DiscoveryRun]):
|
||||
def __init__(self, session: Session):
|
||||
super().__init__(session, DiscoveryRun)
|
||||
|
||||
def get_latest(self) -> DiscoveryRun | None:
|
||||
return (
|
||||
self.session.query(DiscoveryRun)
|
||||
.order_by(DiscoveryRun.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
def get_completed(self, limit: int = 10) -> list[DiscoveryRun]:
|
||||
return (
|
||||
self.session.query(DiscoveryRun)
|
||||
.filter(DiscoveryRun.status == "completed")
|
||||
.order_by(DiscoveryRun.created_at.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
class TrendingStockResultRepository(BaseRepository[TrendingStockResult]):
|
||||
def __init__(self, session: Session):
|
||||
super().__init__(session, TrendingStockResult)
|
||||
|
||||
def get_by_run(self, run_id: str) -> list[TrendingStockResult]:
|
||||
return (
|
||||
self.session.query(TrendingStockResult)
|
||||
.filter(TrendingStockResult.discovery_run_id == run_id)
|
||||
.order_by(TrendingStockResult.trending_score.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
def get_by_ticker(self, ticker: str, limit: int = 10) -> list[TrendingStockResult]:
|
||||
return (
|
||||
self.session.query(TrendingStockResult)
|
||||
.filter(TrendingStockResult.ticker == ticker)
|
||||
.order_by(TrendingStockResult.created_at.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
class DiscoveryArticleRepository(BaseRepository[DiscoveryArticle]):
|
||||
def __init__(self, session: Session):
|
||||
super().__init__(session, DiscoveryArticle)
|
||||
|
||||
def get_by_run(self, run_id: str) -> list[DiscoveryArticle]:
|
||||
return (
|
||||
self.session.query(DiscoveryArticle)
|
||||
.filter(DiscoveryArticle.discovery_run_id == run_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
class DiscoveryService:
|
||||
def __init__(self, session: Session):
|
||||
self.session = session
|
||||
self.runs = DiscoveryRunRepository(session)
|
||||
self.stocks = TrendingStockResultRepository(session)
|
||||
self.articles = DiscoveryArticleRepository(session)
|
||||
|
||||
def create_run(self, lookback_period: str, max_results: int) -> DiscoveryRun:
|
||||
return self.runs.create(
|
||||
{
|
||||
"lookback_period": lookback_period,
|
||||
"max_results": max_results,
|
||||
"status": "running",
|
||||
}
|
||||
)
|
||||
|
||||
def save_trending_stock(
|
||||
self,
|
||||
run_id: str,
|
||||
ticker: str,
|
||||
company_name: str,
|
||||
trending_score: float,
|
||||
mention_count: int,
|
||||
sector: str,
|
||||
event_type: str,
|
||||
summary: str | None = None,
|
||||
source_articles: list[str] | None = None,
|
||||
) -> TrendingStockResult:
|
||||
return self.stocks.create(
|
||||
{
|
||||
"discovery_run_id": run_id,
|
||||
"ticker": ticker,
|
||||
"company_name": company_name,
|
||||
"trending_score": trending_score,
|
||||
"mention_count": mention_count,
|
||||
"sector": sector,
|
||||
"event_type": event_type,
|
||||
"summary": summary,
|
||||
"source_articles": json.dumps(source_articles or []),
|
||||
}
|
||||
)
|
||||
|
||||
def save_article(
|
||||
self,
|
||||
run_id: str,
|
||||
title: str,
|
||||
source: str,
|
||||
url: str | None = None,
|
||||
published_at: datetime | None = None,
|
||||
content_snippet: str | None = None,
|
||||
) -> DiscoveryArticle:
|
||||
return self.articles.create(
|
||||
{
|
||||
"discovery_run_id": run_id,
|
||||
"title": title,
|
||||
"source": source,
|
||||
"url": url,
|
||||
"published_at": published_at,
|
||||
"content_snippet": content_snippet,
|
||||
}
|
||||
)
|
||||
|
||||
def complete_run(self, run_id: str, stocks_found: int) -> DiscoveryRun | None:
|
||||
run = self.runs.get(run_id)
|
||||
if run:
|
||||
run.status = "completed"
|
||||
run.completed_at = datetime.utcnow()
|
||||
run.stocks_found = stocks_found
|
||||
self.session.flush()
|
||||
return run
|
||||
|
||||
def fail_run(self, run_id: str, error: str) -> DiscoveryRun | None:
|
||||
run = self.runs.get(run_id)
|
||||
if run:
|
||||
run.status = "failed"
|
||||
run.completed_at = datetime.utcnow()
|
||||
run.error_message = error
|
||||
self.session.flush()
|
||||
return run
|
||||
|
||||
def get_latest_run(self) -> DiscoveryRun | None:
|
||||
return self.runs.get_latest()
|
||||
|
||||
def get_trending_by_run(self, run_id: str) -> list[TrendingStockResult]:
|
||||
return self.stocks.get_by_run(run_id)
|
||||
|
||||
def get_stock_history(
|
||||
self, ticker: str, limit: int = 10
|
||||
) -> list[TrendingStockResult]:
|
||||
return self.stocks.get_by_ticker(ticker, limit)
|
||||
|
|
@ -0,0 +1,91 @@
|
|||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from tradingagents.database.models import (
|
||||
TradeExecution,
|
||||
TradeReflection,
|
||||
TradingDecision,
|
||||
)
|
||||
from tradingagents.database.repositories import (
|
||||
TradeExecutionRepository,
|
||||
TradeReflectionRepository,
|
||||
TradingDecisionRepository,
|
||||
)
|
||||
|
||||
|
||||
class TradingService:
|
||||
def __init__(self, session: Session):
|
||||
self.session = session
|
||||
self.decisions = TradingDecisionRepository(session)
|
||||
self.executions = TradeExecutionRepository(session)
|
||||
self.reflections = TradeReflectionRepository(session)
|
||||
|
||||
def save_trading_decision(
|
||||
self,
|
||||
session_id: str,
|
||||
ticker: str,
|
||||
final_state: dict[str, Any],
|
||||
signal: str,
|
||||
) -> TradingDecision:
|
||||
decision_map = {"BUY": "buy", "SELL": "sell", "HOLD": "hold"}
|
||||
decision = decision_map.get(signal.upper(), "hold")
|
||||
|
||||
return self.decisions.create(
|
||||
{
|
||||
"session_id": session_id,
|
||||
"ticker": ticker,
|
||||
"decision": decision,
|
||||
"trader_plan": final_state.get("trader_investment_plan", ""),
|
||||
"investment_plan": final_state.get("investment_plan", ""),
|
||||
"final_decision_text": final_state.get("final_trade_decision", ""),
|
||||
}
|
||||
)
|
||||
|
||||
def record_execution(
|
||||
self,
|
||||
decision_id: str,
|
||||
ticker: str,
|
||||
action: str,
|
||||
quantity: int,
|
||||
price: float,
|
||||
) -> TradeExecution:
|
||||
return self.executions.create(
|
||||
{
|
||||
"decision_id": decision_id,
|
||||
"ticker": ticker,
|
||||
"action": action,
|
||||
"quantity": quantity,
|
||||
"price": price,
|
||||
"status": "executed",
|
||||
}
|
||||
)
|
||||
|
||||
def save_reflection(
|
||||
self,
|
||||
ticker: str,
|
||||
trade_date: str,
|
||||
decision_id: str | None,
|
||||
returns_losses: float,
|
||||
reflection_content: str,
|
||||
) -> TradeReflection:
|
||||
return self.reflections.create(
|
||||
{
|
||||
"ticker": ticker,
|
||||
"trade_date": trade_date,
|
||||
"decision_id": decision_id,
|
||||
"actual_return": returns_losses,
|
||||
"reflection_content": reflection_content,
|
||||
}
|
||||
)
|
||||
|
||||
def get_decision_by_session(self, session_id: str) -> TradingDecision | None:
|
||||
return self.decisions.get_by_session(session_id)
|
||||
|
||||
def get_decisions_by_ticker(
|
||||
self, ticker: str, limit: int = 100
|
||||
) -> list[TradingDecision]:
|
||||
return self.decisions.get_by_ticker(ticker, limit)
|
||||
|
||||
def get_pending_executions(self) -> list[TradeExecution]:
|
||||
return self.executions.get_pending()
|
||||
|
|
@ -4,7 +4,7 @@ import os
|
|||
import threading
|
||||
from datetime import date, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from typing import Any
|
||||
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
|
|
@ -35,6 +35,13 @@ from tradingagents.agents.utils.agent_utils import (
|
|||
get_stock_data,
|
||||
)
|
||||
from tradingagents.agents.utils.memory import FinancialSituationMemory
|
||||
from tradingagents.database import (
|
||||
AnalysisService,
|
||||
DiscoveryService,
|
||||
TradingService,
|
||||
get_db_session,
|
||||
init_database,
|
||||
)
|
||||
from tradingagents.dataflows.config import get_config, set_config
|
||||
from tradingagents.dataflows.interface import get_bulk_news
|
||||
from tradingagents.validation import validate_date, validate_ticker
|
||||
|
|
@ -73,6 +80,11 @@ class TradingAgentsGraph:
|
|||
exist_ok=True,
|
||||
)
|
||||
|
||||
self.db_enabled = self.config.get("database_enabled", False)
|
||||
if self.db_enabled:
|
||||
db_path = self.config.get("database_path")
|
||||
init_database(db_path)
|
||||
|
||||
if (
|
||||
self.config["llm_provider"].lower() == "openai"
|
||||
or self.config["llm_provider"] == "ollama"
|
||||
|
|
@ -235,6 +247,23 @@ class TradingAgentsGraph:
|
|||
) as f:
|
||||
json.dump(self.log_states_dict, f, indent=4)
|
||||
|
||||
if self.db_enabled:
|
||||
self._persist_to_database(trade_date, final_state)
|
||||
|
||||
def _persist_to_database(self, trade_date, final_state: dict[str, Any]) -> None:
|
||||
with get_db_session() as session:
|
||||
analysis_service = AnalysisService(session)
|
||||
trading_service = TradingService(session)
|
||||
|
||||
analysis_session = analysis_service.save_full_state(
|
||||
self.ticker, str(trade_date), final_state
|
||||
)
|
||||
|
||||
signal = self.process_signal(final_state.get("final_trade_decision", ""))
|
||||
trading_service.save_trading_decision(
|
||||
analysis_session.id, self.ticker, final_state, signal
|
||||
)
|
||||
|
||||
def reflect_and_remember(self, returns_losses) -> None:
|
||||
self.reflector.reflect_bull_researcher(
|
||||
self.curr_state, returns_losses, self.bull_memory
|
||||
|
|
@ -351,8 +380,47 @@ class TradingAgentsGraph:
|
|||
result.status = DiscoveryStatus.COMPLETED
|
||||
result.completed_at = datetime.now()
|
||||
|
||||
if self.db_enabled:
|
||||
self._persist_discovery_to_database(request, result)
|
||||
|
||||
return result
|
||||
|
||||
def _persist_discovery_to_database(
|
||||
self, request: DiscoveryRequest, result: DiscoveryResult
|
||||
) -> None:
|
||||
with get_db_session() as session:
|
||||
discovery_service = DiscoveryService(session)
|
||||
|
||||
run = discovery_service.create_run(
|
||||
request.lookback_period, request.max_results or 20
|
||||
)
|
||||
|
||||
for stock in result.trending_stocks:
|
||||
discovery_service.save_trending_stock(
|
||||
run.id,
|
||||
stock.ticker,
|
||||
stock.company_name,
|
||||
stock.trending_score,
|
||||
stock.mention_count,
|
||||
stock.sector.value
|
||||
if hasattr(stock.sector, "value")
|
||||
else str(stock.sector),
|
||||
stock.event_type.value
|
||||
if hasattr(stock.event_type, "value")
|
||||
else str(stock.event_type),
|
||||
stock.summary,
|
||||
[a.get("title", "") for a in stock.source_articles]
|
||||
if stock.source_articles
|
||||
else None,
|
||||
)
|
||||
|
||||
if result.status == DiscoveryStatus.COMPLETED:
|
||||
discovery_service.complete_run(run.id, len(result.trending_stocks))
|
||||
else:
|
||||
discovery_service.fail_run(
|
||||
run.id, result.error_message or "Unknown error"
|
||||
)
|
||||
|
||||
def analyze_trending(
|
||||
self,
|
||||
trending_stock: TrendingStock,
|
||||
|
|
|
|||
Loading…
Reference in New Issue