# Source: TradingAgents/tradingagents/graph/trading_graph.py
# NOTE: The original capture included GitHub page chrome ("463 lines",
# "18 KiB", "Raw Blame History") and a Unicode-confusables viewer warning;
# those were viewer artifacts, not file content, and are recorded here as
# comments so the module remains valid Python.
# TradingAgents/graph/trading_graph.py
import os
import sqlite3
import hashlib
import logging
from pathlib import Path
import json
from datetime import date
from typing import Dict, Any, Tuple, List, Optional
logger = logging.getLogger(__name__)
from langgraph.prebuilt import ToolNode
from langgraph.checkpoint.sqlite import SqliteSaver
from tradingagents.llm_clients import create_llm_client
from tradingagents.agents import *
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.agents.utils.memory import FinancialSituationMemory
from tradingagents.agents.utils.agent_states import (
AgentState,
InvestDebateState,
RiskDebateState,
)
from tradingagents.dataflows.config import set_config
# Import the new abstract tool methods from agent_utils
from tradingagents.agents.utils.agent_utils import (
get_stock_data,
get_indicators,
get_fundamentals,
get_balance_sheet,
get_cashflow,
get_income_statement,
get_news,
get_insider_transactions,
get_global_news
)
from .conditional_logic import ConditionalLogic
from .setup import GraphSetup
from .propagation import Propagator
from .reflection import Reflector
from .signal_processing import SignalProcessor
# Maps each LangGraph node's display name to the snake_case step key that
# stream_propagate() yields to callers; _extract_report() uses the same keys
# to pick the relevant report field out of a node's state update.
_NODE_TO_STEP = {
    "Market Analyst": "market_analyst",
    "News Analyst": "news_analyst",
    "Fundamentals Analyst": "fundamentals_analyst",
    "Social Analyst": "social_analyst",
    "Bull Researcher": "bull_researcher",
    "Bear Researcher": "bear_researcher",
    "Research Manager": "research_manager",
    "Trader": "trader",
    "Aggressive Analyst": "aggressive_analyst",
    "Conservative Analyst": "conservative_analyst",
    "Neutral Analyst": "neutral_analyst",
    "Risk Judge": "risk_judge",
}
# Tool-executor nodes produce no user-facing report; stream_propagate()
# silently skips them (along with "Msg Clear*" cleanup nodes).
_SKIP_NODES = {"tools_market", "tools_news", "tools_fundamentals", "tools_social"}
class TradingAgentsGraph:
"""Main class that orchestrates the trading agents framework."""
def __init__(
    self,
    selected_analysts=("market", "social", "news", "fundamentals"),
    debug=False,
    config: Optional[Dict[str, Any]] = None,
    callbacks: Optional[List] = None,
):
    """Initialize the trading agents graph and components.

    Args:
        selected_analysts: Iterable of analyst types to include. The default
            is a tuple (not a list) to avoid the shared-mutable-default
            pitfall; it is copied into a list per instance below.
        debug: Whether to run in debug (tracing) mode.
        config: Configuration dictionary. If None, uses DEFAULT_CONFIG.
        callbacks: Optional list of callback handlers (e.g., for tracking
            LLM/tool stats); forwarded to the LLM constructors.
    """
    self.debug = debug
    self.config = config or DEFAULT_CONFIG
    self.callbacks = callbacks or []
    self.selected_analysts = list(selected_analysts)

    # Update the dataflows interface's process-wide config.
    set_config(self.config)

    # Ensure the on-disk data cache directory exists.
    os.makedirs(
        os.path.join(self.config["project_dir"], "dataflows/data_cache"),
        exist_ok=True,
    )

    # Initialize LLMs with provider-specific thinking configuration.
    llm_kwargs = self._get_provider_kwargs()
    # Callbacks, when present, are passed straight to the LLM constructor.
    if self.callbacks:
        llm_kwargs["callbacks"] = self.callbacks
    deep_client = create_llm_client(
        provider=self.config["llm_provider"],
        model=self.config["deep_think_llm"],
        base_url=self.config.get("backend_url"),
        **llm_kwargs,
    )
    quick_client = create_llm_client(
        provider=self.config["llm_provider"],
        model=self.config["quick_think_llm"],
        base_url=self.config.get("backend_url"),
        **llm_kwargs,
    )
    self.deep_thinking_llm = deep_client.get_llm()
    self.quick_thinking_llm = quick_client.get_llm()

    # Per-role long-term memories, updated later by reflect_and_remember().
    self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
    self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
    self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
    self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config)
    self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config)

    # Tool nodes grouped per analyst type.
    self.tool_nodes = self._create_tool_nodes()

    # Graph components.
    self.conditional_logic = ConditionalLogic(
        max_debate_rounds=self.config["max_debate_rounds"],
        max_risk_discuss_rounds=self.config["max_risk_discuss_rounds"],
    )
    self.graph_setup = GraphSetup(
        self.quick_thinking_llm,
        self.deep_thinking_llm,
        self.tool_nodes,
        self.bull_memory,
        self.bear_memory,
        self.trader_memory,
        self.invest_judge_memory,
        self.risk_manager_memory,
        self.conditional_logic,
    )
    self.propagator = Propagator()
    self.reflector = Reflector(self.quick_thinking_llm)
    self.signal_processor = SignalProcessor(self.quick_thinking_llm)

    # State tracking.
    self.curr_state = None
    self.ticker = None
    self.log_states_dict = {}  # maps trade date -> full final-state dict

    # Durable SQLite checkpoints so an interrupted run can resume after a crash.
    self._sqlite_conn, self.checkpointer = self._create_sqlite_checkpointer(self.config)
    self.graph = self.graph_setup.setup_graph(
        selected_analysts, checkpointer=self.checkpointer
    )
@staticmethod
def _create_sqlite_checkpointer(
    config: Dict[str, Any],
) -> Tuple[sqlite3.Connection, SqliteSaver]:
    """Open the durable SQLite checkpoint store.

    The database lives at ``<results_dir>/.checkpoints/langgraph.sqlite``;
    missing directories are created on demand.

    Returns:
        (conn, checkpointer) — the caller must close ``conn`` when done.
    """
    base_dir = Path(config.get("results_dir", "./results")).expanduser().resolve()
    db_file = base_dir / ".checkpoints" / "langgraph.sqlite"
    db_file.parent.mkdir(parents=True, exist_ok=True)
    # check_same_thread=False lifts sqlite3's default same-thread restriction
    # so the saver can be used from other threads.
    connection = sqlite3.connect(str(db_file), check_same_thread=False)
    return connection, SqliteSaver(connection)
def close(self) -> None:
    """Close the underlying SQLite connection held by the checkpointer."""
    conn = getattr(self, "_sqlite_conn", None)
    if conn is None:
        return
    try:
        conn.close()
    except Exception:
        # Best-effort teardown: close() is also invoked from __del__ and
        # must never raise.
        pass
def __del__(self) -> None:
    # Finalizer: best-effort release of the SQLite connection. close()
    # swallows all exceptions, so garbage collection never raises.
    self.close()
def _get_provider_kwargs(self) -> Dict[str, Any]:
    """Get provider-specific kwargs for LLM client creation.

    Forwards a setting only when it is present and truthy in the config:
    ``google_thinking_level`` for Google, ``openai_reasoning_effort`` for
    OpenAI. Any other provider gets no extra kwargs.
    """
    # provider -> (config key to read, kwarg name to emit)
    per_provider = {
        "google": ("google_thinking_level", "thinking_level"),
        "openai": ("openai_reasoning_effort", "reasoning_effort"),
    }
    provider = self.config.get("llm_provider", "").lower()
    extra: Dict[str, Any] = {}
    if provider in per_provider:
        config_key, kwarg_name = per_provider[provider]
        value = self.config.get(config_key)
        if value:
            extra[kwarg_name] = value
    return extra
def _create_tool_nodes(self) -> Dict[str, ToolNode]:
    """Create tool nodes for different data sources using abstract methods."""
    # Analyst type -> the data tools that analyst's tool node may call.
    toolkits = {
        # Core stock data plus technical indicators
        "market": [get_stock_data, get_indicators],
        # News tools for social media analysis
        "social": [get_news],
        # News and insider information
        "news": [get_news, get_global_news, get_insider_transactions],
        # Fundamental analysis tools
        "fundamentals": [
            get_fundamentals,
            get_balance_sheet,
            get_cashflow,
            get_income_statement,
        ],
    }
    return {name: ToolNode(tools) for name, tools in toolkits.items()}
def propagate(self, company_name, trade_date, thread_id: Optional[str] = None):
    """Run the trading agents graph for a company on a specific date.

    Returns:
        (final_state, processed_signal) — the complete final graph state and
        the decision extracted from its ``final_trade_decision`` field.
    """
    self.ticker = company_name

    if thread_id is None:
        # Deterministic thread id: identical runs (same ticker/date/analyst
        # set/model config) map to the same checkpoint thread and can
        # therefore resume one another.
        fingerprint = {
            "ticker": company_name.strip().upper(),
            "trade_date": str(trade_date),
            "analysts": sorted(self.selected_analysts),
            "llm_provider": self.config.get("llm_provider"),
            "deep_think_llm": self.config.get("deep_think_llm"),
            "quick_think_llm": self.config.get("quick_think_llm"),
            "max_debate_rounds": self.config.get("max_debate_rounds"),
            "max_risk_discuss_rounds": self.config.get("max_risk_discuss_rounds"),
        }
        digest = hashlib.sha256(
            json.dumps(fingerprint, sort_keys=True).encode()
        ).hexdigest()
        thread_id = "ta_prog_" + digest[:24]

    init_agent_state = self.propagator.create_initial_state(
        company_name, trade_date
    )
    args = self.propagator.get_graph_args(thread_id=thread_id)

    # Determine stream input: resume from checkpoint if an incomplete run
    # exists for this thread, otherwise start fresh. Passing None tells
    # LangGraph to resume from the last saved checkpoint for this thread_id
    # (no user prompt in API mode).
    thread_config = {"configurable": {"thread_id": thread_id}}
    snapshot = self.graph.get_state(thread_config)
    stream_input = None if snapshot.next else init_agent_state

    if not self.debug:
        # Standard mode: run to completion without tracing.
        final_state = self.graph.invoke(stream_input, **args)
    else:
        # Debug mode: pretty-print and record every chunk carrying messages.
        trace = []
        for chunk in self.graph.stream(stream_input, **args):
            if chunk["messages"]:
                chunk["messages"][-1].pretty_print()
                trace.append(chunk)
        if not trace:
            raise RuntimeError(
                "Graph stream produced no output — all chunks had empty messages. "
                f"ticker={company_name}, trade_date={trade_date}, thread_id={thread_id}"
            )
        final_state = trace[-1]

    # Keep the final state for later reflection, and persist a JSON log.
    self.curr_state = final_state
    self._log_state(trade_date, final_state)

    # Return decision and processed signal.
    return final_state, self.process_signal(final_state["final_trade_decision"])
@staticmethod
def _extract_report(step_key: str, update: dict) -> str:
    """Extract the relevant report string from a node's state update.

    Raises:
        KeyError: if ``step_key`` is not a recognized step.
    """
    # Steps whose report is a top-level field of the update.
    flat_fields = {
        "market_analyst": "market_report",
        "news_analyst": "news_report",
        "fundamentals_analyst": "fundamentals_report",
        "social_analyst": "sentiment_report",
        "research_manager": "investment_plan",
        "trader": "trader_investment_plan",
    }
    # Steps whose report lives inside a nested debate-state dict.
    nested_fields = {
        "bull_researcher": ("investment_debate_state", "bull_history"),
        "bear_researcher": ("investment_debate_state", "bear_history"),
        "aggressive_analyst": ("risk_debate_state", "current_aggressive_response"),
        "conservative_analyst": ("risk_debate_state", "current_conservative_response"),
        "neutral_analyst": ("risk_debate_state", "current_neutral_response"),
        "risk_judge": ("risk_debate_state", "judge_decision"),
    }
    if step_key in flat_fields:
        report = update.get(flat_fields[step_key], "")
    elif step_key in nested_fields:
        parent, field = nested_fields[step_key]
        report = (update.get(parent) or {}).get(field, "")
    else:
        # Mirror the original dict-dispatch behavior for unknown keys.
        raise KeyError(step_key)
    return report or ""
def stream_propagate(self, company_name: str, trade_date: str, thread_id=None):
    """Stream trading analysis events as each agent node completes.

    Yields:
        (step_key, report) tuples for each meaningful node completion.

    After the generator is exhausted, ``self._last_decision`` is set to the
    normalized decision string ("BUY", "SELL", or "HOLD"). If signal
    processing raises, the exception propagates to the caller (run:error
    handling) and ``self._last_decision`` stays None.
    """
    self.ticker = company_name
    self._last_decision = None

    if thread_id is None:
        # Deterministic thread id: same ticker/date/config -> same thread,
        # so an interrupted run can be resumed (mirrors propagate()).
        payload = json.dumps(
            {
                "ticker": company_name.strip().upper(),
                "trade_date": str(trade_date),
                "analysts": sorted(self.selected_analysts),
                "llm_provider": self.config.get("llm_provider"),
                "deep_think_llm": self.config.get("deep_think_llm"),
                "quick_think_llm": self.config.get("quick_think_llm"),
                "max_debate_rounds": self.config.get("max_debate_rounds"),
                "max_risk_discuss_rounds": self.config.get("max_risk_discuss_rounds"),
            },
            sort_keys=True,
        ).encode()
        thread_id = "ta_prog_" + hashlib.sha256(payload).hexdigest()[:24]

    init_agent_state = self.propagator.create_initial_state(company_name, trade_date)
    args = self.propagator.get_graph_args(thread_id=thread_id)
    args["stream_mode"] = "updates"  # stream per-node deltas, not full state snapshots

    # Resume from checkpoint (input None) when an incomplete run exists.
    thread_config = {"configurable": {"thread_id": thread_id}}
    snap = self.graph.get_state(thread_config)
    stream_input = None if snap.next else init_agent_state

    for chunk in self.graph.stream(stream_input, **args):
        node_name, update = next(iter(chunk.items()))
        # Filter: skip list first, then known nodes, else warn and skip.
        if node_name in _SKIP_NODES or node_name.startswith("Msg Clear"):
            continue
        if node_name not in _NODE_TO_STEP:
            logger.warning("stream_propagate: unknown node '%s' — skipping", node_name)
            continue
        step_key = _NODE_TO_STEP[node_name]
        report = TradingAgentsGraph._extract_report(step_key, update)
        yield step_key, report

    # Post-loop: fetch the complete final state snapshot (all fields populated).
    # stream_mode="updates" gives only deltas — use get_state() for the full
    # picture needed by _log_state and process_signal.
    final_snap = self.graph.get_state(thread_config)
    final_state = final_snap.values if hasattr(final_snap, "values") else {}

    raw_signal = final_state.get("final_trade_decision", "")
    # Exceptions here propagate to the caller for run:error handling. (The
    # previous ``try: ... except Exception: raise`` was a no-op wrapper and
    # has been removed — behavior is unchanged.)
    decision = self.process_signal(raw_signal).strip().upper()
    if decision not in {"BUY", "SELL", "HOLD"}:
        logger.warning("stream_propagate: unexpected decision '%s' — defaulting to HOLD", decision)
        decision = "HOLD"

    self._last_decision = decision
    self._log_state(trade_date, final_state)
def _log_state(self, trade_date, final_state):
    """Log the final state to a JSON file.

    Records a trimmed copy of ``final_state`` under
    ``self.log_states_dict[str(trade_date)]`` and rewrites the cumulative
    per-ticker log at
    ``eval_results/<ticker>/TradingAgentsStrategy_logs/full_states_log_<date>.json``.
    """
    invest_debate = final_state["investment_debate_state"]
    risk_debate = final_state["risk_debate_state"]
    self.log_states_dict[str(trade_date)] = {
        "company_of_interest": final_state["company_of_interest"],
        "trade_date": final_state["trade_date"],
        "market_report": final_state["market_report"],
        "sentiment_report": final_state["sentiment_report"],
        "news_report": final_state["news_report"],
        "fundamentals_report": final_state["fundamentals_report"],
        "investment_debate_state": {
            "bull_history": invest_debate["bull_history"],
            "bear_history": invest_debate["bear_history"],
            "history": invest_debate["history"],
            "current_response": invest_debate["current_response"],
            "judge_decision": invest_debate["judge_decision"],
        },
        "trader_investment_decision": final_state["trader_investment_plan"],
        "risk_debate_state": {
            "aggressive_history": risk_debate["aggressive_history"],
            "conservative_history": risk_debate["conservative_history"],
            "neutral_history": risk_debate["neutral_history"],
            "history": risk_debate["history"],
            "judge_decision": risk_debate["judge_decision"],
        },
        "investment_plan": final_state["investment_plan"],
        "final_trade_decision": final_state["final_trade_decision"],
    }

    # Save to file. Reuse the Path built here instead of re-hardcoding the
    # same directory string in the open() call (previous duplication defect).
    directory = Path(f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/")
    directory.mkdir(parents=True, exist_ok=True)
    log_file = directory / f"full_states_log_{trade_date}.json"
    with open(log_file, "w", encoding="utf-8") as f:
        json.dump(self.log_states_dict, f, indent=4)
def reflect_and_remember(self, returns_losses):
    """Reflect on decisions and update memory based on returns."""
    # Run each role-specific reflection against its dedicated memory store,
    # in the same order as before: bull, bear, trader, invest judge, risk.
    reflections = (
        (self.reflector.reflect_bull_researcher, self.bull_memory),
        (self.reflector.reflect_bear_researcher, self.bear_memory),
        (self.reflector.reflect_trader, self.trader_memory),
        (self.reflector.reflect_invest_judge, self.invest_judge_memory),
        (self.reflector.reflect_risk_manager, self.risk_manager_memory),
    )
    for reflect, memory in reflections:
        reflect(self.curr_state, returns_losses, memory)
def process_signal(self, full_signal):
    """Process a signal to extract the core decision.

    Thin delegation to ``self.signal_processor.process_signal``.
    """
    return self.signal_processor.process_signal(full_signal)