463 lines
18 KiB
Python
463 lines
18 KiB
Python
# TradingAgents/graph/trading_graph.py
|
||
|
||
import os
|
||
import sqlite3
|
||
import hashlib
|
||
import logging
|
||
from pathlib import Path
|
||
import json
|
||
from datetime import date
|
||
from typing import Dict, Any, Tuple, List, Optional
|
||
|
||
# Module-level logger; stream_propagate() emits its skip/fallback warnings here.
logger = logging.getLogger(__name__)
|
||
|
||
from langgraph.prebuilt import ToolNode
|
||
from langgraph.checkpoint.sqlite import SqliteSaver
|
||
|
||
from tradingagents.llm_clients import create_llm_client
|
||
|
||
from tradingagents.agents import *
|
||
from tradingagents.default_config import DEFAULT_CONFIG
|
||
from tradingagents.agents.utils.memory import FinancialSituationMemory
|
||
from tradingagents.agents.utils.agent_states import (
|
||
AgentState,
|
||
InvestDebateState,
|
||
RiskDebateState,
|
||
)
|
||
from tradingagents.dataflows.config import set_config
|
||
|
||
# Import the new abstract tool methods from agent_utils
|
||
from tradingagents.agents.utils.agent_utils import (
|
||
get_stock_data,
|
||
get_indicators,
|
||
get_fundamentals,
|
||
get_balance_sheet,
|
||
get_cashflow,
|
||
get_income_statement,
|
||
get_news,
|
||
get_insider_transactions,
|
||
get_global_news
|
||
)
|
||
|
||
from .conditional_logic import ConditionalLogic
|
||
from .setup import GraphSetup
|
||
from .propagation import Propagator
|
||
from .reflection import Reflector
|
||
from .signal_processing import SignalProcessor
|
||
|
||
|
||
# Graph node display name -> step key yielded by stream_propagate().
# Every step key is the display name lower-cased with spaces turned into
# underscores (e.g. "Market Analyst" -> "market_analyst"). Listing order:
# analysts, research debate + manager, trader, risk debate, risk judge.
_NODE_TO_STEP = {
    display_name: display_name.lower().replace(" ", "_")
    for display_name in (
        "Market Analyst",
        "News Analyst",
        "Fundamentals Analyst",
        "Social Analyst",
        "Bull Researcher",
        "Bear Researcher",
        "Research Manager",
        "Trader",
        "Aggressive Analyst",
        "Conservative Analyst",
        "Neutral Analyst",
        "Risk Judge",
    )
}

# Tool-execution nodes (one per analyst data source); stream_propagate()
# skips them because they carry no report of their own.
_SKIP_NODES = {f"tools_{source}" for source in ("market", "news", "fundamentals", "social")}
|
||
|
||
|
||
class TradingAgentsGraph:
    """Main class that orchestrates the trading agents framework.

    An instance owns the quick/deep LLMs, the per-role memories, the tool
    nodes, the compiled graph and the SQLite connection backing the graph's
    checkpointer. Release that connection with ``close()`` or by using the
    instance as a context manager (``with TradingAgentsGraph(...) as g:``).
    """

    # Analysts used when the caller does not pass ``selected_analysts``.
    _DEFAULT_ANALYSTS = ("market", "social", "news", "fundamentals")

    # Decisions stream_propagate() accepts after normalization; others -> HOLD.
    _VALID_DECISIONS = frozenset({"BUY", "SELL", "HOLD"})

    # step_key -> callable pulling that step's report text out of a node's
    # state update. Built once here instead of on every _extract_report() call.
    _REPORT_EXTRACTORS = {
        "market_analyst": lambda u: u.get("market_report", ""),
        "news_analyst": lambda u: u.get("news_report", ""),
        "fundamentals_analyst": lambda u: u.get("fundamentals_report", ""),
        "social_analyst": lambda u: u.get("sentiment_report", ""),
        "bull_researcher": lambda u: (u.get("investment_debate_state") or {}).get(
            "bull_history", ""
        ),
        "bear_researcher": lambda u: (u.get("investment_debate_state") or {}).get(
            "bear_history", ""
        ),
        "research_manager": lambda u: u.get("investment_plan", ""),
        "trader": lambda u: u.get("trader_investment_plan", ""),
        "aggressive_analyst": lambda u: (u.get("risk_debate_state") or {}).get(
            "current_aggressive_response", ""
        ),
        "conservative_analyst": lambda u: (u.get("risk_debate_state") or {}).get(
            "current_conservative_response", ""
        ),
        "neutral_analyst": lambda u: (u.get("risk_debate_state") or {}).get(
            "current_neutral_response", ""
        ),
        "risk_judge": lambda u: (u.get("risk_debate_state") or {}).get(
            "judge_decision", ""
        ),
    }

    def __init__(
        self,
        selected_analysts: Optional[List[str]] = None,
        debug=False,
        config: Dict[str, Any] = None,
        callbacks: Optional[List] = None,
    ):
        """Initialize the trading agents graph and components.

        Args:
            selected_analysts: Analyst types to include. None (the default)
                means market, social, news and fundamentals.
            debug: Whether to run in debug mode (propagate() streams and
                pretty-prints messages).
            config: Configuration dictionary. If None or empty, uses the default
                config. A shallow copy is stored, so editing ``self.config``
                never mutates the module-level DEFAULT_CONFIG.
            callbacks: Optional list of callback handlers (e.g., for tracking
                LLM/tool stats); forwarded to both LLM clients.
        """
        if selected_analysts is None:
            selected_analysts = self._DEFAULT_ANALYSTS
        self.debug = debug
        self.config = dict(config or DEFAULT_CONFIG)
        self.callbacks = callbacks or []
        self.selected_analysts = list(selected_analysts)

        # Update the interface's config
        set_config(self.config)

        # Create necessary directories
        os.makedirs(
            os.path.join(self.config["project_dir"], "dataflows/data_cache"),
            exist_ok=True,
        )

        # Initialize LLMs with provider-specific thinking configuration
        llm_kwargs = self._get_provider_kwargs()
        # Add callbacks to kwargs if provided (passed to LLM constructor)
        if self.callbacks:
            llm_kwargs["callbacks"] = self.callbacks

        deep_client = create_llm_client(
            provider=self.config["llm_provider"],
            model=self.config["deep_think_llm"],
            base_url=self.config.get("backend_url"),
            **llm_kwargs,
        )
        quick_client = create_llm_client(
            provider=self.config["llm_provider"],
            model=self.config["quick_think_llm"],
            base_url=self.config.get("backend_url"),
            **llm_kwargs,
        )
        self.deep_thinking_llm = deep_client.get_llm()
        self.quick_thinking_llm = quick_client.get_llm()

        # Initialize memories
        self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
        self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
        self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
        self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config)
        self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config)

        # Create tool nodes
        self.tool_nodes = self._create_tool_nodes()

        # Initialize components
        self.conditional_logic = ConditionalLogic(
            max_debate_rounds=self.config["max_debate_rounds"],
            max_risk_discuss_rounds=self.config["max_risk_discuss_rounds"],
        )
        self.graph_setup = GraphSetup(
            self.quick_thinking_llm,
            self.deep_thinking_llm,
            self.tool_nodes,
            self.bull_memory,
            self.bear_memory,
            self.trader_memory,
            self.invest_judge_memory,
            self.risk_manager_memory,
            self.conditional_logic,
        )

        self.propagator = Propagator()
        self.reflector = Reflector(self.quick_thinking_llm)
        self.signal_processor = SignalProcessor(self.quick_thinking_llm)

        # State tracking
        self.curr_state = None
        self.ticker = None
        self.log_states_dict = {}  # date to full state dict
        self._last_decision = None  # set by stream_propagate()

        self._sqlite_conn, self.checkpointer = self._create_sqlite_checkpointer(self.config)
        # Set up the graph (durable checkpoints for resume after crash). If graph
        # construction fails, release the freshly opened connection before
        # re-raising instead of waiting for garbage collection.
        try:
            self.graph = self.graph_setup.setup_graph(
                self.selected_analysts, checkpointer=self.checkpointer
            )
        except Exception:
            self.close()
            raise

    @staticmethod
    def _create_sqlite_checkpointer(
        config: Dict[str, Any],
    ) -> Tuple[sqlite3.Connection, "SqliteSaver"]:
        """Open the checkpoint store at ``<results_dir>/.checkpoints/langgraph.sqlite``.

        ``results_dir`` is read from ``config`` (default ``./results``), then
        user-expanded and resolved; the ``.checkpoints`` directory is created if
        missing. The connection uses ``check_same_thread=False`` so it may be
        used from threads other than the one that opened it.

        Returns:
            (conn, checkpointer) – caller must close conn when done.
        """
        results_dir = Path(config.get("results_dir", "./results")).expanduser().resolve()
        checkpoint_dir = results_dir / ".checkpoints"
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        db_path = checkpoint_dir / "langgraph.sqlite"
        conn = sqlite3.connect(str(db_path), check_same_thread=False)
        return conn, SqliteSaver(conn)

    def close(self) -> None:
        """Close the SQLite connection held by the checkpointer.

        Idempotent: safe to call repeatedly and on a partially constructed
        instance (e.g. when ``__init__`` failed before opening the connection).
        A failure while closing is logged rather than raised.
        """
        conn = getattr(self, "_sqlite_conn", None)
        if conn is None:
            return
        self._sqlite_conn = None
        try:
            conn.close()
        except sqlite3.Error:
            logger.warning("Failed to close checkpoint SQLite connection", exc_info=True)

    def __enter__(self) -> "TradingAgentsGraph":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Returning None lets any in-flight exception propagate.
        self.close()

    def __del__(self) -> None:
        # Best-effort safety net only; prefer close() or a `with` block. Errors
        # here cannot be handled meaningfully (e.g. during interpreter shutdown).
        try:
            self.close()
        except Exception:
            pass

    def _get_provider_kwargs(self) -> Dict[str, Any]:
        """Get provider-specific kwargs for LLM client creation.

        Google gets ``thinking_level`` from ``google_thinking_level`` and OpenAI
        gets ``reasoning_effort`` from ``openai_reasoning_effort``, each only
        when set. The provider name is matched case-insensitively; a missing or
        None provider yields no extra kwargs.
        """
        kwargs = {}
        provider = str(self.config.get("llm_provider") or "").lower()

        if provider == "google":
            thinking_level = self.config.get("google_thinking_level")
            if thinking_level:
                kwargs["thinking_level"] = thinking_level

        elif provider == "openai":
            reasoning_effort = self.config.get("openai_reasoning_effort")
            if reasoning_effort:
                kwargs["reasoning_effort"] = reasoning_effort

        return kwargs

    def _create_tool_nodes(self) -> Dict[str, "ToolNode"]:
        """Create tool nodes for different data sources using abstract methods."""
        return {
            "market": ToolNode(
                [
                    # Core stock data tools
                    get_stock_data,
                    # Technical indicators
                    get_indicators,
                ]
            ),
            "social": ToolNode(
                [
                    # News tools for social media analysis
                    get_news,
                ]
            ),
            "news": ToolNode(
                [
                    # News and insider information
                    get_news,
                    get_global_news,
                    get_insider_transactions,
                ]
            ),
            "fundamentals": ToolNode(
                [
                    # Fundamental analysis tools
                    get_fundamentals,
                    get_balance_sheet,
                    get_cashflow,
                    get_income_statement,
                ]
            ),
        }

    def _default_thread_id(self, company_name, trade_date) -> str:
        """Derive a deterministic checkpoint thread id for a run.

        Hashes the normalized ticker, trade date, selected analysts and the
        model/debate settings, so re-running the same configuration for the
        same date maps to the same thread (and can resume its checkpoint).
        Keep the payload layout stable: changing it orphans existing checkpoints.
        """
        payload = json.dumps(
            {
                "ticker": company_name.strip().upper(),
                "trade_date": str(trade_date),
                "analysts": sorted(self.selected_analysts),
                "llm_provider": self.config.get("llm_provider"),
                "deep_think_llm": self.config.get("deep_think_llm"),
                "quick_think_llm": self.config.get("quick_think_llm"),
                "max_debate_rounds": self.config.get("max_debate_rounds"),
                "max_risk_discuss_rounds": self.config.get("max_risk_discuss_rounds"),
            },
            sort_keys=True,
        ).encode()
        return "ta_prog_" + hashlib.sha256(payload).hexdigest()[:24]

    def _resolve_stream_input(self, thread_id: str, init_agent_state):
        """Pick the graph input for ``thread_id``.

        Returns None when the thread's latest checkpoint still has pending
        nodes (``snapshot.next``), which tells LangGraph to resume from that
        checkpoint; otherwise returns ``init_agent_state`` to start fresh.
        """
        snapshot = self.graph.get_state({"configurable": {"thread_id": thread_id}})
        # NOTE(review): when a previous run on this thread already completed,
        # the fresh state is fed into the same thread on top of its saved
        # state; confirm reducer-merged fields (e.g. messages) behave as intended.
        return None if snapshot.next else init_agent_state

    def propagate(self, company_name, trade_date, thread_id: Optional[str] = None):
        """Run the trading agents graph for a company on a specific date.

        Args:
            company_name: Ticker / company identifier handed to the agents.
            trade_date: Analysis date; stringified for the thread id and log.
            thread_id: Checkpoint thread to run on. If None, a deterministic id
                is derived from the ticker, date and run settings.

        Returns:
            ``(final_state, decision)`` where ``decision`` is
            ``process_signal(final_state["final_trade_decision"])``.

        Raises:
            RuntimeError: in debug mode, when no streamed chunk had messages.
        """
        self.ticker = company_name
        if thread_id is None:
            thread_id = self._default_thread_id(company_name, trade_date)

        # Initialize state
        init_agent_state = self.propagator.create_initial_state(company_name, trade_date)
        args = self.propagator.get_graph_args(thread_id=thread_id)

        # Resume automatically if an incomplete run exists (no user prompt in API mode).
        stream_input = self._resolve_stream_input(thread_id, init_agent_state)

        if self.debug:
            # Debug mode with tracing: keep only chunks that carry messages.
            trace = []
            for chunk in self.graph.stream(stream_input, **args):
                messages = chunk.get("messages")
                if not messages:
                    continue
                messages[-1].pretty_print()
                trace.append(chunk)

            if not trace:
                raise RuntimeError(
                    "Graph stream produced no output — all chunks had empty messages. "
                    f"ticker={company_name}, trade_date={trade_date}, thread_id={thread_id}"
                )
            final_state = trace[-1]
        else:
            # Standard mode without tracing
            final_state = self.graph.invoke(stream_input, **args)

        # Store current state for reflection, then log it.
        self.curr_state = final_state
        self._log_state(trade_date, final_state)

        # Return decision and processed signal
        return final_state, self.process_signal(final_state["final_trade_decision"])

    @staticmethod
    def _extract_report(step_key: str, update: dict) -> str:
        """Extract the relevant report string from a node's state update.

        A None ``update`` (a node that returned no changes) is treated as an
        empty update; a missing or falsy report yields ``""``.

        Raises:
            KeyError: if ``step_key`` is not a known step.
        """
        extractor = TradingAgentsGraph._REPORT_EXTRACTORS[step_key]
        return extractor(update or {}) or ""

    def stream_propagate(self, company_name: str, trade_date: str, thread_id=None):
        """Stream trading analysis events as each agent node completes.

        Yields:
            (step_key, report) tuples for each meaningful node completion.

        After the generator is exhausted, self._last_decision is set to the
        normalized decision string ("BUY", "SELL", or "HOLD"), self.curr_state
        holds the final state (so reflect_and_remember() works as after
        propagate()), and the state has been logged. Exceptions from
        process_signal() propagate to the caller.
        """
        self.ticker = company_name
        self._last_decision = None
        if thread_id is None:
            thread_id = self._default_thread_id(company_name, trade_date)

        init_agent_state = self.propagator.create_initial_state(company_name, trade_date)
        args = self.propagator.get_graph_args(thread_id=thread_id)
        args["stream_mode"] = "updates"  # stream per-node deltas, not full state snapshots
        stream_input = self._resolve_stream_input(thread_id, init_agent_state)

        for chunk in self.graph.stream(stream_input, **args):
            # Walk every entry: a chunk is a {node_name: update} mapping and is
            # not guaranteed to hold exactly one node.
            for node_name, update in chunk.items():
                # Filter: skip list first, then known nodes, else warn and skip
                if node_name in _SKIP_NODES or node_name.startswith("Msg Clear"):
                    continue
                step_key = _NODE_TO_STEP.get(node_name)
                if step_key is None:
                    logger.warning("stream_propagate: unknown node '%s' — skipping", node_name)
                    continue
                yield step_key, self._extract_report(step_key, update)

        # "updates" mode only delivered deltas; get_state() supplies the full
        # final state needed by _log_state and process_signal.
        final_snapshot = self.graph.get_state({"configurable": {"thread_id": thread_id}})
        final_state = getattr(final_snapshot, "values", None) or {}

        raw_decision = self.process_signal(final_state.get("final_trade_decision", ""))
        decision = str(raw_decision or "").strip().upper()
        if decision not in self._VALID_DECISIONS:
            logger.warning(
                "stream_propagate: unexpected decision '%s' — defaulting to HOLD", decision
            )
            decision = "HOLD"

        self._last_decision = decision
        self.curr_state = final_state
        self._log_state(trade_date, final_state)

    @staticmethod
    def _safe_path_component(value) -> str:
        """Render ``value`` as a single path component.

        Path separators become ``_`` and ``""``, ``"."`` and ``".."`` become
        ``"_"``, so a ticker or date can never climb out of ``eval_results``.
        Ordinary tickers and ISO dates are returned unchanged.
        """
        text = str(value).replace("/", "_").replace("\\", "_")
        return "_" if text in {"", ".", ".."} else text

    def _log_state(self, trade_date, final_state):
        """Log the final state to a JSON file.

        Stores the report fields of ``final_state`` in ``self.log_states_dict``
        under ``str(trade_date)``, then writes the whole dict (all dates logged
        by this instance) to
        ``eval_results/<ticker>/TradingAgentsStrategy_logs/full_states_log_<trade_date>.json``.

        Raises:
            KeyError: if ``final_state`` lacks any expected field.
        """
        debate = final_state["investment_debate_state"]
        risk = final_state["risk_debate_state"]
        self.log_states_dict[str(trade_date)] = {
            "company_of_interest": final_state["company_of_interest"],
            "trade_date": final_state["trade_date"],
            "market_report": final_state["market_report"],
            "sentiment_report": final_state["sentiment_report"],
            "news_report": final_state["news_report"],
            "fundamentals_report": final_state["fundamentals_report"],
            "investment_debate_state": {
                "bull_history": debate["bull_history"],
                "bear_history": debate["bear_history"],
                "history": debate["history"],
                "current_response": debate["current_response"],
                "judge_decision": debate["judge_decision"],
            },
            "trader_investment_decision": final_state["trader_investment_plan"],
            "risk_debate_state": {
                "aggressive_history": risk["aggressive_history"],
                "conservative_history": risk["conservative_history"],
                "neutral_history": risk["neutral_history"],
                "history": risk["history"],
                "judge_decision": risk["judge_decision"],
            },
            "investment_plan": final_state["investment_plan"],
            "final_trade_decision": final_state["final_trade_decision"],
        }

        # Save to file; ticker/date may come from API input, so sanitize them.
        directory = (
            Path("eval_results")
            / self._safe_path_component(self.ticker)
            / "TradingAgentsStrategy_logs"
        )
        directory.mkdir(parents=True, exist_ok=True)
        log_path = directory / f"full_states_log_{self._safe_path_component(trade_date)}.json"
        with log_path.open("w", encoding="utf-8") as f:
            json.dump(self.log_states_dict, f, indent=4)

    def reflect_and_remember(self, returns_losses):
        """Reflect on the last run's decisions and update memory based on returns.

        Raises:
            RuntimeError: if no run (propagate/stream_propagate) has completed yet.
        """
        if self.curr_state is None:
            raise RuntimeError(
                "reflect_and_remember() requires a completed propagate() or "
                "stream_propagate() run"
            )
        self.reflector.reflect_bull_researcher(
            self.curr_state, returns_losses, self.bull_memory
        )
        self.reflector.reflect_bear_researcher(
            self.curr_state, returns_losses, self.bear_memory
        )
        self.reflector.reflect_trader(
            self.curr_state, returns_losses, self.trader_memory
        )
        self.reflector.reflect_invest_judge(
            self.curr_state, returns_losses, self.invest_judge_memory
        )
        self.reflector.reflect_risk_manager(
            self.curr_state, returns_losses, self.risk_manager_memory
        )

    def process_signal(self, full_signal):
        """Process a signal to extract the core decision."""
        return self.signal_processor.process_signal(full_signal)