feat: integrate database persistence with TradingAgentsGraph
Add database services layer with AnalysisService, TradingService, and DiscoveryService for persisting analysis sessions, trading decisions, and discovery runs. Integration with TradingAgentsGraph: - Add config options: database_enabled, database_path - Persist analysis state to database in _log_state when enabled - Persist discovery results to database when enabled - Save analyst reports, debates, and trading decisions Database integration is opt-in via config setting. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
c39f9aab36
commit
1db81e1fc6
|
|
@ -1,5 +1,6 @@
|
||||||
from .base import Base
|
from .base import Base
|
||||||
from .engine import get_db_session, get_engine, init_database, reset_engine
|
from .engine import get_db_session, get_engine, init_database, reset_engine
|
||||||
|
from .services import AnalysisService, DiscoveryService, TradingService
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Base",
|
"Base",
|
||||||
|
|
@ -7,4 +8,7 @@ __all__ = [
|
||||||
"get_engine",
|
"get_engine",
|
||||||
"init_database",
|
"init_database",
|
||||||
"reset_engine",
|
"reset_engine",
|
||||||
|
"AnalysisService",
|
||||||
|
"DiscoveryService",
|
||||||
|
"TradingService",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,9 @@
|
||||||
|
from .analysis import AnalysisService
|
||||||
|
from .discovery import DiscoveryService
|
||||||
|
from .trading import TradingService
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AnalysisService",
|
||||||
|
"DiscoveryService",
|
||||||
|
"TradingService",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,116 @@
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from tradingagents.database.models import (
|
||||||
|
AnalysisSession,
|
||||||
|
AnalystReport,
|
||||||
|
InvestmentDebate,
|
||||||
|
RiskDebate,
|
||||||
|
)
|
||||||
|
from tradingagents.database.repositories import (
|
||||||
|
AnalysisSessionRepository,
|
||||||
|
AnalystReportRepository,
|
||||||
|
InvestmentDebateRepository,
|
||||||
|
RiskDebateRepository,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AnalysisService:
|
||||||
|
def __init__(self, session: Session):
|
||||||
|
self.session = session
|
||||||
|
self.sessions = AnalysisSessionRepository(session)
|
||||||
|
self.reports = AnalystReportRepository(session)
|
||||||
|
self.investment_debates = InvestmentDebateRepository(session)
|
||||||
|
self.risk_debates = RiskDebateRepository(session)
|
||||||
|
|
||||||
|
def create_session(self, ticker: str, trade_date: str) -> AnalysisSession:
|
||||||
|
return self.sessions.create(
|
||||||
|
{
|
||||||
|
"ticker": ticker,
|
||||||
|
"trade_date": trade_date,
|
||||||
|
"status": "running",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def save_analyst_report(
|
||||||
|
self, session_id: str, analyst_type: str, report_content: str
|
||||||
|
) -> AnalystReport:
|
||||||
|
return self.reports.create(
|
||||||
|
{
|
||||||
|
"session_id": session_id,
|
||||||
|
"analyst_type": analyst_type,
|
||||||
|
"report_content": report_content,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def save_investment_debate(
|
||||||
|
self, session_id: str, debate_state: dict[str, Any]
|
||||||
|
) -> InvestmentDebate:
|
||||||
|
return self.investment_debates.create(
|
||||||
|
{
|
||||||
|
"session_id": session_id,
|
||||||
|
"bull_history": json.dumps(debate_state.get("bull_history", [])),
|
||||||
|
"bear_history": json.dumps(debate_state.get("bear_history", [])),
|
||||||
|
"debate_history": json.dumps(debate_state.get("history", [])),
|
||||||
|
"judge_decision": debate_state.get("judge_decision", ""),
|
||||||
|
"investment_plan": debate_state.get("current_response", ""),
|
||||||
|
"debate_rounds": len(debate_state.get("history", [])),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def save_risk_debate(
|
||||||
|
self, session_id: str, debate_state: dict[str, Any]
|
||||||
|
) -> RiskDebate:
|
||||||
|
return self.risk_debates.create(
|
||||||
|
{
|
||||||
|
"session_id": session_id,
|
||||||
|
"risky_history": json.dumps(debate_state.get("risky_history", [])),
|
||||||
|
"safe_history": json.dumps(debate_state.get("safe_history", [])),
|
||||||
|
"neutral_history": json.dumps(debate_state.get("neutral_history", [])),
|
||||||
|
"debate_history": json.dumps(debate_state.get("history", [])),
|
||||||
|
"judge_decision": debate_state.get("judge_decision", ""),
|
||||||
|
"debate_rounds": len(debate_state.get("history", [])),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def save_full_state(
|
||||||
|
self, ticker: str, trade_date: str, final_state: dict[str, Any]
|
||||||
|
) -> AnalysisSession:
|
||||||
|
analysis_session = self.create_session(ticker, trade_date)
|
||||||
|
|
||||||
|
analyst_mappings = [
|
||||||
|
("market", "market_report"),
|
||||||
|
("sentiment", "sentiment_report"),
|
||||||
|
("news", "news_report"),
|
||||||
|
("fundamentals", "fundamentals_report"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for analyst_type, state_key in analyst_mappings:
|
||||||
|
report_content = final_state.get(state_key, "")
|
||||||
|
if report_content:
|
||||||
|
self.save_analyst_report(
|
||||||
|
analysis_session.id, analyst_type, report_content
|
||||||
|
)
|
||||||
|
|
||||||
|
investment_debate_state = final_state.get("investment_debate_state", {})
|
||||||
|
if investment_debate_state:
|
||||||
|
self.save_investment_debate(analysis_session.id, investment_debate_state)
|
||||||
|
|
||||||
|
risk_debate_state = final_state.get("risk_debate_state", {})
|
||||||
|
if risk_debate_state:
|
||||||
|
self.save_risk_debate(analysis_session.id, risk_debate_state)
|
||||||
|
|
||||||
|
self.sessions.mark_completed(analysis_session.id)
|
||||||
|
|
||||||
|
return analysis_session
|
||||||
|
|
||||||
|
def get_session_by_ticker_date(
|
||||||
|
self, ticker: str, trade_date: str
|
||||||
|
) -> AnalysisSession | None:
|
||||||
|
return self.sessions.get_by_ticker_and_date(ticker, trade_date)
|
||||||
|
|
||||||
|
def get_latest_session(self, ticker: str) -> AnalysisSession | None:
|
||||||
|
return self.sessions.get_latest_by_ticker(ticker)
|
||||||
|
|
@ -0,0 +1,158 @@
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from tradingagents.database.models import (
|
||||||
|
DiscoveryArticle,
|
||||||
|
DiscoveryRun,
|
||||||
|
TrendingStockResult,
|
||||||
|
)
|
||||||
|
from tradingagents.database.repositories.base import BaseRepository
|
||||||
|
|
||||||
|
|
||||||
|
class DiscoveryRunRepository(BaseRepository[DiscoveryRun]):
|
||||||
|
def __init__(self, session: Session):
|
||||||
|
super().__init__(session, DiscoveryRun)
|
||||||
|
|
||||||
|
def get_latest(self) -> DiscoveryRun | None:
|
||||||
|
return (
|
||||||
|
self.session.query(DiscoveryRun)
|
||||||
|
.order_by(DiscoveryRun.created_at.desc())
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_completed(self, limit: int = 10) -> list[DiscoveryRun]:
|
||||||
|
return (
|
||||||
|
self.session.query(DiscoveryRun)
|
||||||
|
.filter(DiscoveryRun.status == "completed")
|
||||||
|
.order_by(DiscoveryRun.created_at.desc())
|
||||||
|
.limit(limit)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TrendingStockResultRepository(BaseRepository[TrendingStockResult]):
|
||||||
|
def __init__(self, session: Session):
|
||||||
|
super().__init__(session, TrendingStockResult)
|
||||||
|
|
||||||
|
def get_by_run(self, run_id: str) -> list[TrendingStockResult]:
|
||||||
|
return (
|
||||||
|
self.session.query(TrendingStockResult)
|
||||||
|
.filter(TrendingStockResult.discovery_run_id == run_id)
|
||||||
|
.order_by(TrendingStockResult.trending_score.desc())
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_by_ticker(self, ticker: str, limit: int = 10) -> list[TrendingStockResult]:
|
||||||
|
return (
|
||||||
|
self.session.query(TrendingStockResult)
|
||||||
|
.filter(TrendingStockResult.ticker == ticker)
|
||||||
|
.order_by(TrendingStockResult.created_at.desc())
|
||||||
|
.limit(limit)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DiscoveryArticleRepository(BaseRepository[DiscoveryArticle]):
|
||||||
|
def __init__(self, session: Session):
|
||||||
|
super().__init__(session, DiscoveryArticle)
|
||||||
|
|
||||||
|
def get_by_run(self, run_id: str) -> list[DiscoveryArticle]:
|
||||||
|
return (
|
||||||
|
self.session.query(DiscoveryArticle)
|
||||||
|
.filter(DiscoveryArticle.discovery_run_id == run_id)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DiscoveryService:
|
||||||
|
def __init__(self, session: Session):
|
||||||
|
self.session = session
|
||||||
|
self.runs = DiscoveryRunRepository(session)
|
||||||
|
self.stocks = TrendingStockResultRepository(session)
|
||||||
|
self.articles = DiscoveryArticleRepository(session)
|
||||||
|
|
||||||
|
def create_run(self, lookback_period: str, max_results: int) -> DiscoveryRun:
|
||||||
|
return self.runs.create(
|
||||||
|
{
|
||||||
|
"lookback_period": lookback_period,
|
||||||
|
"max_results": max_results,
|
||||||
|
"status": "running",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def save_trending_stock(
|
||||||
|
self,
|
||||||
|
run_id: str,
|
||||||
|
ticker: str,
|
||||||
|
company_name: str,
|
||||||
|
trending_score: float,
|
||||||
|
mention_count: int,
|
||||||
|
sector: str,
|
||||||
|
event_type: str,
|
||||||
|
summary: str | None = None,
|
||||||
|
source_articles: list[str] | None = None,
|
||||||
|
) -> TrendingStockResult:
|
||||||
|
return self.stocks.create(
|
||||||
|
{
|
||||||
|
"discovery_run_id": run_id,
|
||||||
|
"ticker": ticker,
|
||||||
|
"company_name": company_name,
|
||||||
|
"trending_score": trending_score,
|
||||||
|
"mention_count": mention_count,
|
||||||
|
"sector": sector,
|
||||||
|
"event_type": event_type,
|
||||||
|
"summary": summary,
|
||||||
|
"source_articles": json.dumps(source_articles or []),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def save_article(
|
||||||
|
self,
|
||||||
|
run_id: str,
|
||||||
|
title: str,
|
||||||
|
source: str,
|
||||||
|
url: str | None = None,
|
||||||
|
published_at: datetime | None = None,
|
||||||
|
content_snippet: str | None = None,
|
||||||
|
) -> DiscoveryArticle:
|
||||||
|
return self.articles.create(
|
||||||
|
{
|
||||||
|
"discovery_run_id": run_id,
|
||||||
|
"title": title,
|
||||||
|
"source": source,
|
||||||
|
"url": url,
|
||||||
|
"published_at": published_at,
|
||||||
|
"content_snippet": content_snippet,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def complete_run(self, run_id: str, stocks_found: int) -> DiscoveryRun | None:
|
||||||
|
run = self.runs.get(run_id)
|
||||||
|
if run:
|
||||||
|
run.status = "completed"
|
||||||
|
run.completed_at = datetime.utcnow()
|
||||||
|
run.stocks_found = stocks_found
|
||||||
|
self.session.flush()
|
||||||
|
return run
|
||||||
|
|
||||||
|
def fail_run(self, run_id: str, error: str) -> DiscoveryRun | None:
|
||||||
|
run = self.runs.get(run_id)
|
||||||
|
if run:
|
||||||
|
run.status = "failed"
|
||||||
|
run.completed_at = datetime.utcnow()
|
||||||
|
run.error_message = error
|
||||||
|
self.session.flush()
|
||||||
|
return run
|
||||||
|
|
||||||
|
def get_latest_run(self) -> DiscoveryRun | None:
|
||||||
|
return self.runs.get_latest()
|
||||||
|
|
||||||
|
def get_trending_by_run(self, run_id: str) -> list[TrendingStockResult]:
|
||||||
|
return self.stocks.get_by_run(run_id)
|
||||||
|
|
||||||
|
def get_stock_history(
|
||||||
|
self, ticker: str, limit: int = 10
|
||||||
|
) -> list[TrendingStockResult]:
|
||||||
|
return self.stocks.get_by_ticker(ticker, limit)
|
||||||
|
|
@ -0,0 +1,91 @@
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from tradingagents.database.models import (
|
||||||
|
TradeExecution,
|
||||||
|
TradeReflection,
|
||||||
|
TradingDecision,
|
||||||
|
)
|
||||||
|
from tradingagents.database.repositories import (
|
||||||
|
TradeExecutionRepository,
|
||||||
|
TradeReflectionRepository,
|
||||||
|
TradingDecisionRepository,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TradingService:
|
||||||
|
def __init__(self, session: Session):
|
||||||
|
self.session = session
|
||||||
|
self.decisions = TradingDecisionRepository(session)
|
||||||
|
self.executions = TradeExecutionRepository(session)
|
||||||
|
self.reflections = TradeReflectionRepository(session)
|
||||||
|
|
||||||
|
def save_trading_decision(
|
||||||
|
self,
|
||||||
|
session_id: str,
|
||||||
|
ticker: str,
|
||||||
|
final_state: dict[str, Any],
|
||||||
|
signal: str,
|
||||||
|
) -> TradingDecision:
|
||||||
|
decision_map = {"BUY": "buy", "SELL": "sell", "HOLD": "hold"}
|
||||||
|
decision = decision_map.get(signal.upper(), "hold")
|
||||||
|
|
||||||
|
return self.decisions.create(
|
||||||
|
{
|
||||||
|
"session_id": session_id,
|
||||||
|
"ticker": ticker,
|
||||||
|
"decision": decision,
|
||||||
|
"trader_plan": final_state.get("trader_investment_plan", ""),
|
||||||
|
"investment_plan": final_state.get("investment_plan", ""),
|
||||||
|
"final_decision_text": final_state.get("final_trade_decision", ""),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def record_execution(
|
||||||
|
self,
|
||||||
|
decision_id: str,
|
||||||
|
ticker: str,
|
||||||
|
action: str,
|
||||||
|
quantity: int,
|
||||||
|
price: float,
|
||||||
|
) -> TradeExecution:
|
||||||
|
return self.executions.create(
|
||||||
|
{
|
||||||
|
"decision_id": decision_id,
|
||||||
|
"ticker": ticker,
|
||||||
|
"action": action,
|
||||||
|
"quantity": quantity,
|
||||||
|
"price": price,
|
||||||
|
"status": "executed",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def save_reflection(
|
||||||
|
self,
|
||||||
|
ticker: str,
|
||||||
|
trade_date: str,
|
||||||
|
decision_id: str | None,
|
||||||
|
returns_losses: float,
|
||||||
|
reflection_content: str,
|
||||||
|
) -> TradeReflection:
|
||||||
|
return self.reflections.create(
|
||||||
|
{
|
||||||
|
"ticker": ticker,
|
||||||
|
"trade_date": trade_date,
|
||||||
|
"decision_id": decision_id,
|
||||||
|
"actual_return": returns_losses,
|
||||||
|
"reflection_content": reflection_content,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_decision_by_session(self, session_id: str) -> TradingDecision | None:
|
||||||
|
return self.decisions.get_by_session(session_id)
|
||||||
|
|
||||||
|
def get_decisions_by_ticker(
|
||||||
|
self, ticker: str, limit: int = 100
|
||||||
|
) -> list[TradingDecision]:
|
||||||
|
return self.decisions.get_by_ticker(ticker, limit)
|
||||||
|
|
||||||
|
def get_pending_executions(self) -> list[TradeExecution]:
|
||||||
|
return self.executions.get_pending()
|
||||||
|
|
@ -4,7 +4,7 @@ import os
|
||||||
import threading
|
import threading
|
||||||
from datetime import date, datetime
|
from datetime import date, datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Optional, Tuple
|
from typing import Any
|
||||||
|
|
||||||
from langchain_anthropic import ChatAnthropic
|
from langchain_anthropic import ChatAnthropic
|
||||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||||
|
|
@ -35,6 +35,13 @@ from tradingagents.agents.utils.agent_utils import (
|
||||||
get_stock_data,
|
get_stock_data,
|
||||||
)
|
)
|
||||||
from tradingagents.agents.utils.memory import FinancialSituationMemory
|
from tradingagents.agents.utils.memory import FinancialSituationMemory
|
||||||
|
from tradingagents.database import (
|
||||||
|
AnalysisService,
|
||||||
|
DiscoveryService,
|
||||||
|
TradingService,
|
||||||
|
get_db_session,
|
||||||
|
init_database,
|
||||||
|
)
|
||||||
from tradingagents.dataflows.config import get_config, set_config
|
from tradingagents.dataflows.config import get_config, set_config
|
||||||
from tradingagents.dataflows.interface import get_bulk_news
|
from tradingagents.dataflows.interface import get_bulk_news
|
||||||
from tradingagents.validation import validate_date, validate_ticker
|
from tradingagents.validation import validate_date, validate_ticker
|
||||||
|
|
@ -73,6 +80,11 @@ class TradingAgentsGraph:
|
||||||
exist_ok=True,
|
exist_ok=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.db_enabled = self.config.get("database_enabled", False)
|
||||||
|
if self.db_enabled:
|
||||||
|
db_path = self.config.get("database_path")
|
||||||
|
init_database(db_path)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.config["llm_provider"].lower() == "openai"
|
self.config["llm_provider"].lower() == "openai"
|
||||||
or self.config["llm_provider"] == "ollama"
|
or self.config["llm_provider"] == "ollama"
|
||||||
|
|
@ -235,6 +247,23 @@ class TradingAgentsGraph:
|
||||||
) as f:
|
) as f:
|
||||||
json.dump(self.log_states_dict, f, indent=4)
|
json.dump(self.log_states_dict, f, indent=4)
|
||||||
|
|
||||||
|
if self.db_enabled:
|
||||||
|
self._persist_to_database(trade_date, final_state)
|
||||||
|
|
||||||
|
def _persist_to_database(self, trade_date, final_state: dict[str, Any]) -> None:
|
||||||
|
with get_db_session() as session:
|
||||||
|
analysis_service = AnalysisService(session)
|
||||||
|
trading_service = TradingService(session)
|
||||||
|
|
||||||
|
analysis_session = analysis_service.save_full_state(
|
||||||
|
self.ticker, str(trade_date), final_state
|
||||||
|
)
|
||||||
|
|
||||||
|
signal = self.process_signal(final_state.get("final_trade_decision", ""))
|
||||||
|
trading_service.save_trading_decision(
|
||||||
|
analysis_session.id, self.ticker, final_state, signal
|
||||||
|
)
|
||||||
|
|
||||||
def reflect_and_remember(self, returns_losses) -> None:
|
def reflect_and_remember(self, returns_losses) -> None:
|
||||||
self.reflector.reflect_bull_researcher(
|
self.reflector.reflect_bull_researcher(
|
||||||
self.curr_state, returns_losses, self.bull_memory
|
self.curr_state, returns_losses, self.bull_memory
|
||||||
|
|
@ -351,8 +380,47 @@ class TradingAgentsGraph:
|
||||||
result.status = DiscoveryStatus.COMPLETED
|
result.status = DiscoveryStatus.COMPLETED
|
||||||
result.completed_at = datetime.now()
|
result.completed_at = datetime.now()
|
||||||
|
|
||||||
|
if self.db_enabled:
|
||||||
|
self._persist_discovery_to_database(request, result)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def _persist_discovery_to_database(
|
||||||
|
self, request: DiscoveryRequest, result: DiscoveryResult
|
||||||
|
) -> None:
|
||||||
|
with get_db_session() as session:
|
||||||
|
discovery_service = DiscoveryService(session)
|
||||||
|
|
||||||
|
run = discovery_service.create_run(
|
||||||
|
request.lookback_period, request.max_results or 20
|
||||||
|
)
|
||||||
|
|
||||||
|
for stock in result.trending_stocks:
|
||||||
|
discovery_service.save_trending_stock(
|
||||||
|
run.id,
|
||||||
|
stock.ticker,
|
||||||
|
stock.company_name,
|
||||||
|
stock.trending_score,
|
||||||
|
stock.mention_count,
|
||||||
|
stock.sector.value
|
||||||
|
if hasattr(stock.sector, "value")
|
||||||
|
else str(stock.sector),
|
||||||
|
stock.event_type.value
|
||||||
|
if hasattr(stock.event_type, "value")
|
||||||
|
else str(stock.event_type),
|
||||||
|
stock.summary,
|
||||||
|
[a.get("title", "") for a in stock.source_articles]
|
||||||
|
if stock.source_articles
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.status == DiscoveryStatus.COMPLETED:
|
||||||
|
discovery_service.complete_run(run.id, len(result.trending_stocks))
|
||||||
|
else:
|
||||||
|
discovery_service.fail_run(
|
||||||
|
run.id, result.error_message or "Unknown error"
|
||||||
|
)
|
||||||
|
|
||||||
def analyze_trending(
|
def analyze_trending(
|
||||||
self,
|
self,
|
||||||
trending_stock: TrendingStock,
|
trending_stock: TrendingStock,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue