# TradingAgents/tradingagents/prediction_market/graph/pm_trading_graph.py
# (292 lines, 11 KiB, Python)
import os
from pathlib import Path
import json
from datetime import date
from typing import Dict, Any, Tuple, List, Optional
from langgraph.prebuilt import ToolNode
from tradingagents.llm_clients import create_llm_client
from tradingagents.prediction_market.agents import *
from tradingagents.prediction_market.pm_config import PM_DEFAULT_CONFIG
from tradingagents.agents.utils.memory import FinancialSituationMemory
from tradingagents.prediction_market.agents.utils.pm_agent_states import (
PMAgentState,
PMInvestDebateState,
PMRiskDebateState,
)
# Import PM tool functions
from tradingagents.prediction_market.agents.utils.pm_agent_utils import (
get_market_info,
get_market_price_history,
get_order_book,
get_resolution_criteria,
get_event_context,
get_related_markets,
search_markets,
get_news,
get_global_news,
)
from .conditional_logic import PMConditionalLogic
from .setup import PMGraphSetup
from .propagation import PMPropagator
from .reflection import PMReflector
from .signal_processing import PMSignalProcessor
class PMTradingAgentsGraph:
    """Main class that orchestrates the prediction market trading agents framework.

    Construction builds the LLM clients, the per-role memories, the tool
    nodes and the helper components, then compiles the graph. ``propagate``
    runs the compiled graph for one market/date and logs the resulting state;
    ``reflect_and_remember`` feeds the last state back into the memories.
    """

    # Analysts used when the caller does not pass ``selected_analysts``.
    # Kept as an immutable tuple; each instance receives its own list copy.
    _DEFAULT_ANALYSTS = ("event", "odds", "information", "sentiment")

    def __init__(
        self,
        selected_analysts: Optional[List[str]] = None,
        debug=False,
        config: Optional[Dict[str, Any]] = None,
        callbacks: Optional[List] = None,
    ):
        """Initialize the prediction market trading agents graph and components.

        Args:
            selected_analysts: List of analyst types to include. If None, uses
                "event", "odds", "information" and "sentiment".
            debug: Whether to run in debug mode (stream and print each step).
            config: Configuration dictionary. If None (or empty), uses the PM
                default config. The keys "project_dir", "llm_provider",
                "deep_think_llm", "quick_think_llm", "max_debate_rounds" and
                "max_risk_discuss_rounds" are required; "backend_url" is
                optional.
            callbacks: Optional list of callback handlers (e.g., for tracking
                LLM/tool stats).
        """
        # A fresh list per instance: a mutable default argument would be
        # shared by every instance and by whatever setup_graph does with it.
        if selected_analysts is None:
            selected_analysts = list(self._DEFAULT_ANALYSTS)

        self.debug = debug
        self.config = config or PM_DEFAULT_CONFIG
        self.callbacks = callbacks or []

        # Create necessary directories
        os.makedirs(
            os.path.join(self.config["project_dir"], "dataflows/data_cache"),
            exist_ok=True,
        )

        # Initialize LLMs with provider-specific thinking configuration
        llm_kwargs = self._get_provider_kwargs()

        # Add callbacks to kwargs if provided (passed to LLM constructor)
        if self.callbacks:
            llm_kwargs["callbacks"] = self.callbacks

        deep_client = create_llm_client(
            provider=self.config["llm_provider"],
            model=self.config["deep_think_llm"],
            base_url=self.config.get("backend_url"),
            **llm_kwargs,
        )
        quick_client = create_llm_client(
            provider=self.config["llm_provider"],
            model=self.config["quick_think_llm"],
            base_url=self.config.get("backend_url"),
            **llm_kwargs,
        )
        self.deep_thinking_llm = deep_client.get_llm()
        self.quick_thinking_llm = quick_client.get_llm()

        # Initialize memories (one per role that is reflected on later)
        self.yes_memory = FinancialSituationMemory("yes_memory", self.config)
        self.no_memory = FinancialSituationMemory("no_memory", self.config)
        self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
        self.invest_judge_memory = FinancialSituationMemory(
            "invest_judge_memory", self.config
        )
        self.risk_manager_memory = FinancialSituationMemory(
            "risk_manager_memory", self.config
        )

        # Create tool nodes
        self.tool_nodes = self._create_tool_nodes()

        # Initialize components
        self.conditional_logic = PMConditionalLogic(
            max_debate_rounds=self.config["max_debate_rounds"],
            max_risk_discuss_rounds=self.config["max_risk_discuss_rounds"],
        )
        self.graph_setup = PMGraphSetup(
            self.quick_thinking_llm,
            self.deep_thinking_llm,
            self.tool_nodes,
            self.yes_memory,
            self.no_memory,
            self.trader_memory,
            self.invest_judge_memory,
            self.risk_manager_memory,
            self.conditional_logic,
        )
        self.propagator = PMPropagator()
        self.reflector = PMReflector(self.quick_thinking_llm)
        self.signal_processor = PMSignalProcessor(self.quick_thinking_llm)

        # State tracking
        self.curr_state = None
        self.market_id = None
        self.log_states_dict = {}  # date to full state dict

        # Set up the graph
        self.graph = self.graph_setup.setup_graph(selected_analysts)

    def _get_provider_kwargs(self) -> Dict[str, Any]:
        """Get provider-specific kwargs for LLM client creation.

        Returns:
            ``{"thinking_level": ...}`` for the "google" provider when
            ``google_thinking_level`` is set, ``{"reasoning_effort": ...}``
            for the "openai" provider when ``openai_reasoning_effort`` is set,
            and an empty dict otherwise. The provider name is matched
            case-insensitively.
        """
        # provider -> (config key to read, kwarg name to emit)
        provider_options = {
            "google": ("google_thinking_level", "thinking_level"),
            "openai": ("openai_reasoning_effort", "reasoning_effort"),
        }
        provider = self.config.get("llm_provider", "").lower()
        if provider not in provider_options:
            return {}

        config_key, kwarg_name = provider_options[provider]
        value = self.config.get(config_key)
        # Unset/empty values are omitted so the client keeps its own default.
        return {kwarg_name: value} if value else {}

    def _create_tool_nodes(self) -> Dict[str, "ToolNode"]:
        """Create tool nodes for different prediction market data sources.

        Returns:
            A dict mapping each analyst type ("event", "odds", "information",
            "sentiment") to the ``ToolNode`` wrapping that analyst's tools.
        """
        return {
            "event": ToolNode(
                [
                    # Event context and resolution
                    get_market_info,
                    get_resolution_criteria,
                    get_event_context,
                ]
            ),
            "odds": ToolNode(
                [
                    # Price, order book, and market data
                    get_market_info,
                    get_market_price_history,
                    get_order_book,
                ]
            ),
            "information": ToolNode(
                [
                    # News and related markets
                    get_news,
                    get_global_news,
                    get_related_markets,
                    search_markets,
                ]
            ),
            "sentiment": ToolNode(
                [
                    # News for sentiment analysis
                    get_news,
                    get_global_news,
                ]
            ),
        }

    def propagate(self, market_id, trade_date, market_question=""):
        """Run the prediction market trading agents graph for a market on a specific date.

        Args:
            market_id: The Polymarket condition ID or market identifier
            trade_date: The date of analysis
            market_question: Optional full text of the market question

        Returns:
            A ``(final_state, signal)`` tuple, where ``signal`` is the result
            of ``process_signal`` applied to
            ``final_state["final_trade_decision"]``.
        """
        self.market_id = market_id

        # Initialize state
        init_agent_state = self.propagator.create_initial_state(
            market_id, trade_date, market_question
        )
        args = self.propagator.get_graph_args()

        if self.debug:
            # Debug mode with tracing: print the newest message of every
            # streamed chunk and treat the last chunk as the final state.
            trace = []
            for chunk in self.graph.stream(init_agent_state, **args):
                if chunk["messages"]:
                    chunk["messages"][-1].pretty_print()
                    trace.append(chunk)
            final_state = trace[-1]
        else:
            # Standard mode without tracing
            final_state = self.graph.invoke(init_agent_state, **args)

        # Store current state for reflection
        self.curr_state = final_state

        # Log state
        self._log_state(trade_date, final_state)

        # Return decision and processed signal
        return final_state, self.process_signal(final_state["final_trade_decision"])

    def _log_state(self, trade_date, final_state):
        """Log the final state to a JSON file.

        The selected fields of ``final_state`` are stored in
        ``self.log_states_dict`` under ``str(trade_date)`` (a later run for
        the same date replaces the earlier entry), and the whole dict is then
        written to
        ``eval_results/<market_id>/PMTradingAgentsStrategy_logs/full_states_log_<trade_date>.json``
        relative to the current working directory.

        Args:
            trade_date: The date of analysis; used as dict key and in the
                file name.
            final_state: The final graph state to extract the fields from.
        """
        invest_debate = final_state["investment_debate_state"]
        risk_debate = final_state["risk_debate_state"]

        self.log_states_dict[str(trade_date)] = {
            "market_id": final_state["market_id"],
            "market_question": final_state["market_question"],
            "trade_date": final_state["trade_date"],
            "event_report": final_state["event_report"],
            "odds_report": final_state["odds_report"],
            "information_report": final_state["information_report"],
            "sentiment_report": final_state["sentiment_report"],
            "investment_debate_state": {
                "yes_history": invest_debate["yes_history"],
                "no_history": invest_debate["no_history"],
                "history": invest_debate["history"],
                "current_response": invest_debate["current_response"],
                "judge_decision": invest_debate["judge_decision"],
            },
            # The log key differs from the state key on purpose: existing
            # log files use "trader_investment_decision".
            "trader_investment_decision": final_state["trader_investment_plan"],
            "risk_debate_state": {
                "aggressive_history": risk_debate["aggressive_history"],
                "conservative_history": risk_debate["conservative_history"],
                "neutral_history": risk_debate["neutral_history"],
                "history": risk_debate["history"],
                "judge_decision": risk_debate["judge_decision"],
            },
            "investment_plan": final_state["investment_plan"],
            "final_trade_decision": final_state["final_trade_decision"],
        }

        # Save to file. The directory is defined once and reused for the file
        # path so the two can never get out of sync.
        directory = Path(f"eval_results/{self.market_id}/PMTradingAgentsStrategy_logs/")
        directory.mkdir(parents=True, exist_ok=True)
        log_path = directory / f"full_states_log_{trade_date}.json"
        with open(log_path, "w", encoding="utf-8") as f:
            json.dump(self.log_states_dict, f, indent=4)

    def reflect_and_remember(self, returns_losses):
        """Reflect on decisions and update memory based on returns.

        Runs the reflector for each role, in order (yes researcher, no
        researcher, trader, invest judge, risk manager), against the state
        stored by the most recent ``propagate`` call.

        Args:
            returns_losses: The realized returns/losses passed through to
                each reflection step.
        """
        reflections = (
            (self.reflector.reflect_yes_researcher, self.yes_memory),
            (self.reflector.reflect_no_researcher, self.no_memory),
            (self.reflector.reflect_trader, self.trader_memory),
            (self.reflector.reflect_invest_judge, self.invest_judge_memory),
            (self.reflector.reflect_risk_manager, self.risk_manager_memory),
        )
        for reflect, memory in reflections:
            reflect(self.curr_state, returns_losses, memory)

    def process_signal(self, full_signal):
        """Process a signal to extract the core decision.

        Args:
            full_signal: The full decision text to hand to the signal
                processor.

        Returns:
            Whatever ``self.signal_processor.process_signal`` returns.
        """
        return self.signal_processor.process_signal(full_signal)