418 lines
15 KiB
Python
418 lines
15 KiB
Python
# TradingAgents/graph/trading_graph.py
|
|
|
|
import os
|
|
from copy import deepcopy
|
|
import json
|
|
from pathlib import Path
|
|
from typing import Dict, Any, List, Optional
|
|
|
|
from langgraph.prebuilt import ToolNode
|
|
|
|
from tradingagents.llm_clients import create_llm_client
|
|
|
|
from tradingagents.agents import *
|
|
from tradingagents.default_config import DEFAULT_CONFIG, normalize_llm_routing
|
|
from tradingagents.agents.utils.memory import FinancialSituationMemory
|
|
from tradingagents.agents.utils.agent_states import (
|
|
AgentState,
|
|
InvestDebateState,
|
|
RiskDebateState,
|
|
)
|
|
from tradingagents.dataflows.config import set_config
|
|
|
|
# Import the new abstract tool methods from agent_utils
|
|
from tradingagents.agents.utils.agent_utils import (
|
|
get_stock_data,
|
|
get_indicators,
|
|
get_fundamentals,
|
|
get_balance_sheet,
|
|
get_cashflow,
|
|
get_income_statement,
|
|
get_economic_indicators,
|
|
get_fed_calendar,
|
|
get_news,
|
|
get_insider_transactions,
|
|
get_global_news,
|
|
get_yield_curve,
|
|
)
|
|
|
|
from .conditional_logic import ConditionalLogic
|
|
from .setup import GraphSetup
|
|
from .propagation import Propagator
|
|
from .reflection import Reflector
|
|
from .signal_processing import SignalProcessor
|
|
|
|
|
|
class TradingAgentsGraph:
    """Main class that orchestrates the trading agents framework.

    Construction builds the merged configuration, the LLM clients (two
    "legacy" quick/deep models plus optional per-role overrides taken from
    ``config["llm_routing"]``), the per-agent memories, the tool nodes and
    finally the compiled graph. ``propagate`` then runs that graph for one
    ticker/date and ``reflect_and_remember`` feeds the outcome back into the
    memories.
    """

    # Analysts used when the caller does not pass ``selected_analysts``.
    # Kept as an immutable tuple so the default can never be mutated in place.
    DEFAULT_ANALYSTS = ("market", "social", "news", "fundamentals", "macro")

    # Roles that are part of every run, independent of the selected analysts.
    ALWAYS_ON_ROLES = {
        "bull_researcher",
        "bear_researcher",
        "research_manager",
        "trader",
        "aggressive_analyst",
        "neutral_analyst",
        "conservative_analyst",
        "portfolio_manager",
    }
    QUICK_THINKING_ROLES = {
        "market",
        "social",
        "news",
        "fundamentals",
        "factor_rules",
        "macro",
        "bull_researcher",
        "bear_researcher",
        "trader",
        "aggressive_analyst",
        "neutral_analyst",
        "conservative_analyst",
    }
    # Roles resolved against ``deep_think_llm``; every other role is "quick".
    DEEP_THINKING_ROLES = {
        "research_manager",
        "portfolio_manager",
    }

    # provider name -> (config key holding the setting, kwarg name passed to
    # the LLM client factory).
    _PROVIDER_OPTIONS = {
        "google": ("google_thinking_level", "thinking_level"),
        "openai": ("openai_reasoning_effort", "reasoning_effort"),
        "anthropic": ("anthropic_effort", "effort"),
    }

    def __init__(
        self,
        selected_analysts=None,
        debug=False,
        config: Optional[Dict[str, Any]] = None,
        callbacks: Optional[List] = None,
    ):
        """Initialize the trading agents graph and components.

        Args:
            selected_analysts: List of analyst types to include. If None,
                a fresh list built from ``DEFAULT_ANALYSTS`` is used.
            debug: Whether to run in debug mode
            config: Configuration dictionary. If None, uses default config
            callbacks: Optional list of callback handlers (e.g., for tracking LLM/tool stats)
        """
        # A ``None`` sentinel replaces the former mutable list default, which
        # was a single list object shared by every call.
        if selected_analysts is None:
            selected_analysts = list(self.DEFAULT_ANALYSTS)

        self.debug = debug
        self.config = self._build_config(config)
        self.callbacks = callbacks or []

        # Update the interface's config
        set_config(self.config)

        # Create necessary directories
        os.makedirs(
            os.path.join(self.config["project_dir"], "dataflows/data_cache"),
            exist_ok=True,
        )

        self.quick_thinking_llm = self._create_legacy_llm("quick")
        self.deep_thinking_llm = self._create_legacy_llm("deep")
        # Only roles whose route differs from the legacy quick/deep models
        # get an entry here.
        self.role_llms = self._create_role_llms(selected_analysts)

        # Initialize memories
        self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
        self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
        self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
        self.invest_judge_memory = FinancialSituationMemory(
            "invest_judge_memory", self.config
        )
        self.portfolio_manager_memory = FinancialSituationMemory(
            "portfolio_manager_memory", self.config
        )

        # Create tool nodes
        self.tool_nodes = self._create_tool_nodes()

        # Initialize components
        self.conditional_logic = ConditionalLogic(
            max_debate_rounds=self.config["max_debate_rounds"],
            max_risk_discuss_rounds=self.config["max_risk_discuss_rounds"],
        )
        self.graph_setup = GraphSetup(
            self.quick_thinking_llm,
            self.deep_thinking_llm,
            self.tool_nodes,
            self.bull_memory,
            self.bear_memory,
            self.trader_memory,
            self.invest_judge_memory,
            self.portfolio_manager_memory,
            self.conditional_logic,
            role_llms=self.role_llms,
        )

        self.propagator = Propagator()
        self.reflector = Reflector(self.quick_thinking_llm)
        self.signal_processor = SignalProcessor(self.quick_thinking_llm)

        # State tracking
        self.curr_state = None
        self.ticker = None
        self.log_states_dict = {}  # date to full state dict

        # Set up the graph
        self.graph = self.graph_setup.setup_graph(selected_analysts)

    def _build_config(self, config: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """Merge user config over defaults without mutating the shared defaults."""
        return normalize_llm_routing(self._deep_merge_dicts(DEFAULT_CONFIG, config or {}))

    def _normalize_provider(self, provider: Optional[str]) -> str:
        """Return ``provider`` lower-cased; ``None`` and ``""`` both become ``""``."""
        return (provider or "").lower()

    @staticmethod
    def _model_key(thinker_depth: str) -> str:
        """Return the config key naming the model for ``thinker_depth``.

        ``"deep"`` maps to ``"deep_think_llm"``; any other value maps to
        ``"quick_think_llm"``.
        """
        return "deep_think_llm" if thinker_depth == "deep" else "quick_think_llm"

    def _deep_merge_dicts(
        self,
        base: Dict[str, Any],
        override: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Return a new dict with ``override`` recursively merged over ``base``.

        Nested dicts present on both sides are merged key by key; any other
        value from ``override`` replaces the one in ``base``. Both inputs are
        deep-copied, so neither is mutated and the result shares no objects
        with them.
        """
        merged = deepcopy(base)
        for key, value in override.items():
            if isinstance(value, dict) and isinstance(merged.get(key), dict):
                merged[key] = self._deep_merge_dicts(merged[key], value)
            else:
                merged[key] = deepcopy(value)
        return merged

    def _create_legacy_llm(self, thinker_depth: str):
        """Create the LLM configured by the top-level provider/model settings.

        Args:
            thinker_depth: ``"deep"`` selects ``config["deep_think_llm"]``;
                anything else selects ``config["quick_think_llm"]``.

        Returns:
            The object returned by the client's ``get_llm()``.
        """
        # Delegates to the shared factory so the legacy and routed paths
        # build clients in exactly the same way.
        return self._create_llm_from_config(
            {
                "provider": self._normalize_provider(self.config["llm_provider"]),
                "model": self.config[self._model_key(thinker_depth)],
                "base_url": self.config.get("backend_url"),
            }
        )

    def _create_role_llms(self, selected_analysts: List[str]) -> Dict[str, Any]:
        """Build per-role LLMs for roles routed away from the legacy models.

        Roles whose resolved route equals the legacy quick/deep settings are
        omitted from the result. Roles that resolve to the same
        (provider, model, base_url) triple share a single LLM instance.

        Args:
            selected_analysts: Analyst roles chosen for this graph.

        Returns:
            Mapping of role name to LLM instance.
        """
        role_llms = {}
        llm_cache = {}
        for role in self._get_required_roles(selected_analysts):
            thinker_depth = "deep" if role in self.DEEP_THINKING_ROLES else "quick"
            llm_config = self._resolve_llm_config(role, thinker_depth)
            if self._uses_legacy_llm(llm_config, thinker_depth):
                continue
            cache_key = (
                llm_config["provider"],
                llm_config["model"],
                llm_config.get("base_url"),
            )
            if cache_key not in llm_cache:
                llm_cache[cache_key] = self._create_llm_from_config(llm_config)
            role_llms[role] = llm_cache[cache_key]
        return role_llms

    def _get_required_roles(self, selected_analysts: List[str]) -> set[str]:
        """Return the always-on roles combined with the selected analysts."""
        return self.ALWAYS_ON_ROLES | set(selected_analysts)

    def _create_llm_from_config(self, llm_config: Dict[str, Any]):
        """Create an LLM from a route dict.

        Args:
            llm_config: Dict with ``"provider"`` and ``"model"`` keys and an
                optional ``"base_url"`` key.

        Returns:
            The object returned by the client's ``get_llm()``.
        """
        llm_kwargs = self._get_provider_kwargs(llm_config["provider"])
        # Callbacks are forwarded only when at least one was supplied.
        if self.callbacks:
            llm_kwargs["callbacks"] = self.callbacks

        client = create_llm_client(
            provider=llm_config["provider"],
            model=llm_config["model"],
            base_url=llm_config.get("base_url"),
            **llm_kwargs,
        )
        return client.get_llm()

    def _uses_legacy_llm(self, llm_config: Dict[str, Any], thinker_depth: str) -> bool:
        """Return True when ``llm_config`` equals the legacy route for this depth.

        The comparison covers provider, model and base URL.
        """
        return (
            llm_config["provider"] == self._normalize_provider(self.config["llm_provider"])
            and llm_config["model"] == self.config[self._model_key(thinker_depth)]
            and llm_config.get("base_url") == self.config.get("backend_url")
        )

    def _resolve_llm_config(
        self,
        role: str,
        thinker_depth: str,
    ) -> Dict[str, Any]:
        """Resolve the provider/model/base_url route for ``role``.

        Precedence, lowest to highest: the legacy top-level settings,
        ``llm_routing["default"]``, then ``llm_routing["roles"][role]``.

        Args:
            role: Role name to resolve.
            thinker_depth: ``"deep"`` or ``"quick"``; picks the legacy model.

        Returns:
            Route dict with a lower-cased ``"provider"``, a ``"model"`` and a
            ``"base_url"`` (possibly ``None``).
        """
        routing = self.config.get("llm_routing") or {}
        role_routes = routing.get("roles") or {}
        legacy_provider = self._normalize_provider(self.config["llm_provider"])
        legacy_route = {
            "provider": legacy_provider,
            "model": self.config[self._model_key(thinker_depth)],
            "base_url": self.config.get("backend_url"),
        }
        default_route = routing.get("default") or {}
        role_route = role_routes.get(role) or {}
        route = self._deep_merge_dicts(legacy_route, default_route)
        route = self._deep_merge_dicts(route, role_route)
        route["provider"] = self._normalize_provider(route.get("provider"))
        # The legacy backend_url belongs to the legacy provider: when routing
        # switches provider without naming its own base_url, drop the
        # inherited one instead of sending it to a different provider.
        explicit_routed_base_url = "base_url" in default_route or "base_url" in role_route
        if route["provider"] != legacy_provider and not explicit_routed_base_url:
            route["base_url"] = None
        return route

    def _get_provider_kwargs(self, provider: Optional[str] = None) -> Dict[str, Any]:
        """Get provider-specific kwargs for LLM client creation.

        Args:
            provider: Provider name (any case). When falsy, the configured
                ``llm_provider`` is used instead.

        Returns:
            A dict holding at most one entry: the provider's tuning option
            when it is set to a truthy value in the config, else empty.
        """
        # Use the shared normalizer so a missing or None provider yields ""
        # rather than raising AttributeError on ``None.lower()``.
        provider = self._normalize_provider(provider or self.config.get("llm_provider"))

        kwargs = {}
        option = self._PROVIDER_OPTIONS.get(provider)
        if option is not None:
            config_key, kwarg_name = option
            value = self.config.get(config_key)
            if value:
                kwargs[kwarg_name] = value
        return kwargs

    def _create_tool_nodes(self) -> "Dict[str, ToolNode]":
        """Create tool nodes for different data sources using abstract methods."""
        return {
            "market": ToolNode(
                [
                    # Core stock data tools
                    get_stock_data,
                    # Technical indicators
                    get_indicators,
                ]
            ),
            "social": ToolNode(
                [
                    # News tools for social media analysis
                    get_news,
                ]
            ),
            "news": ToolNode(
                [
                    # News and insider information
                    get_news,
                    get_global_news,
                    get_insider_transactions,
                ]
            ),
            "fundamentals": ToolNode(
                [
                    # Fundamental analysis tools
                    get_fundamentals,
                    get_balance_sheet,
                    get_cashflow,
                    get_income_statement,
                ]
            ),
            "macro": ToolNode(
                [
                    # Macroeconomic analysis tools
                    get_economic_indicators,
                    get_yield_curve,
                    get_fed_calendar,
                ]
            ),
        }

    def propagate(self, company_name, trade_date):
        """Run the trading agents graph for a company on a specific date.

        Args:
            company_name: Ticker/company identifier; stored as ``self.ticker``
                and used in the log path.
            trade_date: Date of the run; used as the log key and file suffix.

        Returns:
            Tuple of (final state, processed signal derived from the state's
            ``"final_trade_decision"``).

        Raises:
            IndexError: In debug mode, when the stream yields no chunk that
                contains messages.
        """
        self.ticker = company_name

        # Initialize state
        init_agent_state = self.propagator.create_initial_state(
            company_name, trade_date
        )
        args = self.propagator.get_graph_args()

        if self.debug:
            # Debug mode with tracing: print the newest message of every
            # chunk that has messages and keep those chunks.
            trace = []
            for chunk in self.graph.stream(init_agent_state, **args):
                if chunk["messages"]:
                    chunk["messages"][-1].pretty_print()
                    trace.append(chunk)

            final_state = trace[-1]
        else:
            # Standard mode without tracing
            final_state = self.graph.invoke(init_agent_state, **args)

        # Store current state for reflection
        self.curr_state = final_state

        # Log state
        self._log_state(trade_date, final_state)

        # Return decision and processed signal
        return final_state, self.process_signal(final_state["final_trade_decision"])

    def _log_state(self, trade_date, final_state):
        """Log the final state to a JSON file.

        The state is recorded in ``self.log_states_dict`` under
        ``str(trade_date)``, and the whole dict (every date logged so far on
        this instance) is written to
        ``eval_results/<ticker>/TradingAgentsStrategy_logs/full_states_log_<trade_date>.json``.
        """
        invest_state = final_state["investment_debate_state"]
        risk_state = final_state["risk_debate_state"]

        self.log_states_dict[str(trade_date)] = {
            "company_of_interest": final_state["company_of_interest"],
            "trade_date": final_state["trade_date"],
            "market_report": final_state["market_report"],
            "sentiment_report": final_state["sentiment_report"],
            "news_report": final_state["news_report"],
            "fundamentals_report": final_state["fundamentals_report"],
            # These two reports are optional in the state; default to "".
            "factor_rules_report": final_state.get("factor_rules_report", ""),
            "macro_report": final_state.get("macro_report", ""),
            "investment_debate_state": {
                "bull_history": invest_state["bull_history"],
                "bear_history": invest_state["bear_history"],
                "history": invest_state["history"],
                "current_response": invest_state["current_response"],
                "judge_decision": invest_state["judge_decision"],
            },
            "trader_investment_decision": final_state["trader_investment_plan"],
            "risk_debate_state": {
                "aggressive_history": risk_state["aggressive_history"],
                "conservative_history": risk_state["conservative_history"],
                "neutral_history": risk_state["neutral_history"],
                "history": risk_state["history"],
                "judge_decision": risk_state["judge_decision"],
            },
            "investment_plan": final_state["investment_plan"],
            "final_trade_decision": final_state["final_trade_decision"],
        }

        # Save to file; the file path is derived from the directory so the
        # two can never drift apart.
        directory = Path(f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/")
        directory.mkdir(parents=True, exist_ok=True)

        log_path = directory / f"full_states_log_{trade_date}.json"
        with open(log_path, "w", encoding="utf-8") as f:
            json.dump(self.log_states_dict, f, indent=4)

    def reflect_and_remember(self, returns_losses):
        """Reflect on decisions and update memory based on returns.

        Passes ``self.curr_state`` (set by ``propagate``) together with
        ``returns_losses`` to each reflector routine and its matching memory.
        """
        self.reflector.reflect_bull_researcher(
            self.curr_state, returns_losses, self.bull_memory
        )
        self.reflector.reflect_bear_researcher(
            self.curr_state, returns_losses, self.bear_memory
        )
        self.reflector.reflect_trader(
            self.curr_state, returns_losses, self.trader_memory
        )
        self.reflector.reflect_invest_judge(
            self.curr_state, returns_losses, self.invest_judge_memory
        )
        self.reflector.reflect_portfolio_manager(
            self.curr_state, returns_losses, self.portfolio_manager_memory
        )

    def process_signal(self, full_signal):
        """Process a signal to extract the core decision."""
        return self.signal_processor.process_signal(full_signal)
|