# Source: TradingAgents/backend/tradingagents/graph/trading_graph.py
# (310 lines, 13 KiB, Python)

# TradingAgents/graph/trading_graph.py
import os
from pathlib import Path
import json
from datetime import date
from typing import Dict, Any, Tuple, List, Optional
import logging
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.prebuilt import ToolNode
# Module-level logger used by everything in this file.
logger = logging.getLogger(__name__)
# NOTE(review): calling basicConfig at import time configures the *root*
# logger for any program that imports this module. basicConfig does nothing
# if the root logger already has handlers, so applications that configure
# logging before importing this module are unaffected.
logging.basicConfig(level=logging.INFO)
from tradingagents.agents import *
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.agents.utils.memory import FinancialSituationMemory
from tradingagents.agents.utils.agent_states import (
AgentState,
InvestDebateState,
RiskDebateState,
)
from tradingagents.dataflows.interface import set_config
from .conditional_logic import ConditionalLogic
from .setup import GraphSetup
from .propagation import Propagator
from .reflection import Reflector
from .signal_processing import SignalProcessor
class TradingAgentsGraph:
    """Main class that orchestrates the trading agents framework.

    Builds the LLMs, toolkit, per-role memories and analyst tool nodes from a
    config dict, and wires them into a graph via ``GraphSetup``. It exposes
    ``propagate`` (run one company/date), ``reflect_and_remember`` (update the
    memories from returns) and ``process_signal``.
    """

    # Analysts used when the caller does not choose any.
    _DEFAULT_ANALYSTS = ("market", "social", "news", "fundamentals")
    # Providers served through the ChatOpenAI client (compared case-insensitively).
    _OPENAI_COMPATIBLE_PROVIDERS = ("openai", "ollama", "openrouter")
    # Per-analyst message channels inspected when tracing in debug mode.
    _ANALYST_MESSAGE_CHANNELS = (
        "market_messages",
        "social_messages",
        "news_messages",
        "fundamentals_messages",
    )
    # Analyst report fields logged when they appear in a streamed chunk.
    _REPORT_KEYS = (
        "market_report",
        "sentiment_report",
        "news_report",
        "fundamentals_report",
    )

    def __init__(
        self,
        selected_analysts: Optional[List[str]] = None,
        debug: bool = False,
        config: Optional[Dict[str, Any]] = None,
    ):
        """Initialize the trading agents graph and components.

        Args:
            selected_analysts: Analyst types to include. Defaults to
                ``["market", "social", "news", "fundamentals"]``. A fresh list
                is built on every call, so instances never share it.
            debug: If True, ``propagate`` streams the graph and logs each chunk.
            config: Configuration dictionary. If falsy, a shallow copy of
                ``DEFAULT_CONFIG`` is used, so instance-level changes cannot
                leak into the module-wide default.

        Raises:
            ValueError: If ``config["llm_provider"]`` is not supported.
        """
        if selected_analysts is None:
            selected_analysts = list(self._DEFAULT_ANALYSTS)

        self.debug = debug
        self.config = config or dict(DEFAULT_CONFIG)

        # Update the interface's config
        set_config(self.config)

        # Create necessary directories
        os.makedirs(
            os.path.join(self.config["project_dir"], "dataflows/data_cache"),
            exist_ok=True,
        )

        # Initialize LLMs
        self.deep_thinking_llm, self.quick_thinking_llm = self._create_llms()

        self.toolkit = Toolkit(config=self.config)

        # Initialize memories
        self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
        self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
        self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
        self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config)
        self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config)

        # Create tool nodes
        self.tool_nodes = self._create_tool_nodes()

        # Initialize components
        self.conditional_logic = ConditionalLogic()
        self.graph_setup = GraphSetup(
            self.quick_thinking_llm,
            self.deep_thinking_llm,
            self.toolkit,
            self.tool_nodes,
            self.bull_memory,
            self.bear_memory,
            self.trader_memory,
            self.invest_judge_memory,
            self.risk_manager_memory,
            self.conditional_logic,
        )
        self.propagator = Propagator()
        self.reflector = Reflector(self.quick_thinking_llm)
        self.signal_processor = SignalProcessor(self.quick_thinking_llm)

        # State tracking
        self.curr_state = None
        self.ticker = None
        self.log_states_dict = {}  # date to full state dict

        # Set up the graph
        self.graph = self.graph_setup.setup_graph(selected_analysts)

    def _create_llms(self) -> Tuple[Any, Any]:
        """Build the (deep-thinking, quick-thinking) chat models for the configured provider.

        The provider name is matched case-insensitively. Model names (and
        ``backend_url`` where used) are read only inside the matching branch,
        so an unsupported provider raises ValueError rather than KeyError.

        Raises:
            ValueError: If the provider is not openai/ollama/openrouter/anthropic/google.
        """
        provider_name = self.config["llm_provider"]
        provider = provider_name.lower()

        if provider in self._OPENAI_COMPATIBLE_PROVIDERS:
            return (
                ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"]),
                ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"]),
            )
        if provider == "anthropic":
            return (
                ChatAnthropic(model=self.config["deep_think_llm"], base_url=self.config["backend_url"]),
                ChatAnthropic(model=self.config["quick_think_llm"], base_url=self.config["backend_url"]),
            )
        if provider == "google":
            return (
                ChatGoogleGenerativeAI(model=self.config["deep_think_llm"]),
                ChatGoogleGenerativeAI(model=self.config["quick_think_llm"]),
            )
        raise ValueError(f"Unsupported LLM provider: {provider_name}")

    def _create_tool_nodes(self) -> "Dict[str, ToolNode]":
        """Create one ToolNode per analyst type, each bound to its own message channel.

        Every node wraps that analyst's online and offline toolkit methods and
        uses ``messages_key="<analyst>_messages"``.

        Returns:
            Mapping of analyst type ("market", "social", "news", "fundamentals")
            to its ToolNode.
        """
        logger.info("🔧 Creating tool nodes with message channels")
        tool_nodes = {
            "market": ToolNode(
                [
                    # online tools
                    self.toolkit.get_YFin_data_online,
                    self.toolkit.get_stockstats_indicators_report_online,
                    # offline tools
                    self.toolkit.get_YFin_data,
                    self.toolkit.get_stockstats_indicators_report,
                ],
                messages_key="market_messages",
            ),
            "social": ToolNode(
                [
                    # online tools
                    self.toolkit.get_stock_news_openai,
                    # offline tools
                    self.toolkit.get_reddit_stock_info,
                ],
                messages_key="social_messages",
            ),
            "news": ToolNode(
                [
                    # online tools
                    self.toolkit.get_global_news_openai,
                    self.toolkit.get_google_news,
                    # offline tools
                    self.toolkit.get_finnhub_news,
                    self.toolkit.get_reddit_news,
                ],
                messages_key="news_messages",
            ),
            "fundamentals": ToolNode(
                [
                    # online tools
                    self.toolkit.get_fundamentals_openai,
                    # offline tools
                    self.toolkit.get_finnhub_company_insider_sentiment,
                    self.toolkit.get_finnhub_company_insider_transactions,
                    self.toolkit.get_simfin_balance_sheet,
                    self.toolkit.get_simfin_cashflow,
                    self.toolkit.get_simfin_income_stmt,
                ],
                messages_key="fundamentals_messages",
            ),
        }

        for tool_type, node in tool_nodes.items():
            logger.info(f"{tool_type}: {len(node.tools_by_name)} tools")

        return tool_nodes

    def propagate(self, company_name, trade_date):
        """Run the trading agents graph for a company on a specific date.

        Args:
            company_name: Company/ticker to analyse. It is also stored as
                ``self.ticker`` and used in the state-log path.
            trade_date: Date handed to the propagator; ``str(trade_date)`` keys
                the entry in ``self.log_states_dict``.

        Returns:
            Tuple ``(final_state, processed_signal)``.

        Raises:
            Any exception from ``self.graph.invoke`` in standard mode (logged, then re-raised).
        """
        self.ticker = company_name

        # Initialize state
        init_agent_state = self.propagator.create_initial_state(company_name, trade_date)
        args = self.propagator.get_graph_args()

        if self.debug:
            final_state = self._run_with_tracing(init_agent_state, args)
        else:
            final_state = self._run_standard(init_agent_state, args)

        # Store current state for reflection
        self.curr_state = final_state

        logger.info("💾 Logging final state")
        self._log_state(trade_date, final_state)

        final_decision = final_state.get("final_trade_decision", "No decision made")
        processed_signal = self.process_signal(final_decision)

        logger.info(f"🎯 Analysis complete for {company_name}")
        # str() guards against a non-string decision value before slicing.
        logger.info(f"📊 Final decision: {str(final_decision)[:100]}...")
        logger.info(f"🔄 Processed signal: {processed_signal}")

        return final_state, processed_signal

    def _run_with_tracing(self, init_agent_state, args):
        """Stream the graph, log every chunk, and return the last chunk.

        Returns ``init_agent_state`` if the stream yields nothing.
        """
        logger.info("🐛 Running in debug mode with full tracing")
        last_chunk = None
        chunk_count = 0
        for chunk in self.graph.stream(init_agent_state, **args):
            chunk_count += 1
            self._log_debug_chunk(chunk_count, chunk)
            last_chunk = chunk
        logger.info(f"✅ Debug execution complete. Processed {chunk_count} chunks")
        # NOTE(review): the last chunk is the full final state only if
        # get_graph_args() selects stream_mode="values" — confirm in Propagator.
        return last_chunk if chunk_count else init_agent_state

    def _run_standard(self, init_agent_state, args):
        """Invoke the graph once and return its result; log and re-raise on failure."""
        logger.info("🏃 Running in standard mode")
        try:
            final_state = self.graph.invoke(init_agent_state, **args)
        except Exception as e:
            logger.exception(f"❌ Error during graph execution: {str(e)}")
            logger.error(f"❌ Error type: {type(e).__name__}")
            raise
        logger.info("✅ Standard execution complete")
        return final_state

    def _log_debug_chunk(self, chunk_count, chunk):
        """Log a streamed chunk's keys, analyst message updates and generated reports."""
        logger.info("🔄 Processing chunk %d", chunk_count)
        logger.info("📋 Chunk keys: %s", list(chunk.keys()))

        # Check for any message updates in analyst channels
        for channel in self._ANALYST_MESSAGE_CHANNELS:
            messages = chunk.get(channel)
            if not messages:
                continue
            logger.info("💬 Updated %s: %d messages", channel, len(messages))
            last_msg = messages[-1]
            logger.info("📝 Last %s message type: %s", channel, type(last_msg).__name__)
            if hasattr(last_msg, "content"):
                logger.info("📝 Content preview: %s...", str(last_msg.content)[:200])
            tool_calls = getattr(last_msg, "tool_calls", None)
            if tool_calls:
                logger.info("🔧 Tool calls: %s", [self._tool_call_name(tc) for tc in tool_calls])

        # Check for report updates
        for report_key in self._REPORT_KEYS:
            report = chunk.get(report_key)
            if report:
                logger.info("📊 Report generated: %s (%d chars)", report_key, len(report))

    @staticmethod
    def _tool_call_name(tool_call):
        """Return a readable name for a tool call given as a dict or as an object with ``.name``."""
        if isinstance(tool_call, dict):
            return str(tool_call.get("name", tool_call))
        return getattr(tool_call, "name", str(tool_call))

    def _log_state(self, trade_date, final_state):
        """Record key fields of ``final_state`` and rewrite the ticker's JSON log.

        The entry is stored in ``self.log_states_dict`` under ``str(trade_date)``.
        The whole dict is then written to
        ``eval_results/<ticker>/TradingAgentsStrategy_logs/full_states_log.json``.

        Raises:
            KeyError: If ``final_state`` lacks any logged field.
        """
        invest_debate = final_state["investment_debate_state"]
        risk_debate = final_state["risk_debate_state"]
        self.log_states_dict[str(trade_date)] = {
            "company_of_interest": final_state["company_of_interest"],
            "trade_date": final_state["trade_date"],
            "market_report": final_state["market_report"],
            "sentiment_report": final_state["sentiment_report"],
            "news_report": final_state["news_report"],
            "fundamentals_report": final_state["fundamentals_report"],
            "investment_debate_state": {
                "bull_history": invest_debate["bull_history"],
                "bear_history": invest_debate["bear_history"],
                "history": invest_debate["history"],
                "current_response": invest_debate["current_response"],
                "judge_decision": invest_debate["judge_decision"],
            },
            "trader_investment_decision": final_state["trader_investment_plan"],
            "risk_debate_state": {
                "risky_history": risk_debate["risky_history"],
                "safe_history": risk_debate["safe_history"],
                "neutral_history": risk_debate["neutral_history"],
                "history": risk_debate["history"],
                "judge_decision": risk_debate["judge_decision"],
            },
            "investment_plan": final_state["investment_plan"],
            "final_trade_decision": final_state["final_trade_decision"],
        }

        # NOTE(review): log_states_dict is keyed by date only and is not reset
        # when self.ticker changes. Reusing one instance for several tickers
        # therefore writes earlier tickers' entries into the later ticker's file.

        # Serialise first, so a non-JSON value cannot leave a truncated log file.
        payload = json.dumps(self.log_states_dict, indent=4)
        directory = Path(f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/")
        directory.mkdir(parents=True, exist_ok=True)
        (directory / "full_states_log.json").write_text(payload, encoding="utf-8")

    def reflect_and_remember(self, returns_losses):
        """Reflect on the last propagated state and update each role's memory.

        Args:
            returns_losses: Outcome value passed unchanged to every Reflector method.

        Raises:
            RuntimeError: If called before ``propagate`` has produced a state.
        """
        if self.curr_state is None:
            raise RuntimeError("No state to reflect on; call propagate() first.")

        # (reflector method, memory it updates), applied in this fixed order.
        reflections = (
            (self.reflector.reflect_bull_researcher, self.bull_memory),
            (self.reflector.reflect_bear_researcher, self.bear_memory),
            (self.reflector.reflect_trader, self.trader_memory),
            (self.reflector.reflect_invest_judge, self.invest_judge_memory),
            (self.reflector.reflect_risk_manager, self.risk_manager_memory),
        )
        for reflect, memory in reflections:
            reflect(self.curr_state, returns_losses, memory)

    def process_signal(self, full_signal):
        """Extract the core decision from the full decision text via ``self.signal_processor``."""
        return self.signal_processor.process_signal(full_signal)