TradingAgents/tradingagents/graph/setup.py

209 lines
7.6 KiB
Python

"""Graph setup for the structured equity ranking pipeline.
Pipeline stages:
START → Validation → [veto gate] → Tier 1 (Macro+Liquidity parallel)
→ Tier 2 (8 agents parallel) → Scoring (Archetype+MasterScore)
→ Tier 3 (Bull+Bear parallel → Debate → Risk → FinalDecision)
→ END
"""
from __future__ import annotations
import asyncio
import logging
import time
from typing import Any, Dict, List
from langgraph.graph import END, START, StateGraph
from tradingagents.agents.utils.agent_states import PipelineState
logger = logging.getLogger(__name__)
class StructuredGraphSetup:
"""Builds the structured equity ranking LangGraph."""
def __init__(self, quick_llm, deep_llm):
self.quick_llm = quick_llm
self.deep_llm = deep_llm
def setup_graph(self):
"""Build and compile the structured pipeline graph."""
from tradingagents.agents.structured import (
create_archetype_node,
create_backlog_node,
create_bear_case_node,
create_bull_case_node,
create_business_quality_node,
create_crowding_node,
create_debate_node,
create_earnings_revisions_node,
create_entry_timing_node,
create_final_decision_node,
create_institutional_flow_node,
create_liquidity_node,
create_macro_node,
create_position_replacement_node,
create_risk_node,
create_scoring_node,
create_sector_rotation_node,
create_theme_substitution_node,
create_validation_node,
create_valuation_node,
)
# Create node functions
# Tier 1: cheap model (or no LLM for validation)
validation_fn = create_validation_node()
macro_fn = create_macro_node(self.quick_llm)
liquidity_fn = create_liquidity_node(self.quick_llm)
# Tier 2: cheap model for analysis
bq_fn = create_business_quality_node(self.quick_llm)
inst_fn = create_institutional_flow_node(self.quick_llm)
val_fn = create_valuation_node(self.quick_llm)
et_fn = create_entry_timing_node(self.quick_llm)
er_fn = create_earnings_revisions_node(self.quick_llm)
sr_fn = create_sector_rotation_node(self.quick_llm)
bl_fn = create_backlog_node(self.quick_llm)
cr_fn = create_crowding_node(self.quick_llm)
arch_fn = create_archetype_node(self.quick_llm)
score_fn = create_scoring_node()
# Portfolio-level: deep model for theme/replacement analysis
theme_fn = create_theme_substitution_node(self.deep_llm)
replace_fn = create_position_replacement_node(self.deep_llm)
# Tier 3: deep model for reasoning/debate
bull_fn = create_bull_case_node(self.deep_llm)
bear_fn = create_bear_case_node(self.deep_llm)
debate_fn = create_debate_node(self.deep_llm)
risk_fn = create_risk_node(self.deep_llm)
final_fn = create_final_decision_node(self.deep_llm)
# Build parallel wrapper nodes
parallel_tier1 = _create_parallel_node(
[("macro", macro_fn), ("liquidity", liquidity_fn)],
"Tier 1",
)
parallel_tier2 = _create_parallel_node(
[
("business_quality", bq_fn),
("institutional_flow", inst_fn),
("valuation", val_fn),
("entry_timing", et_fn),
("earnings_revisions", er_fn),
("sector_rotation", sr_fn),
("backlog", bl_fn),
("crowding", cr_fn),
],
"Tier 2",
)
parallel_bull_bear = _create_parallel_node(
[("bull_case", bull_fn), ("bear_case", bear_fn)],
"Bull/Bear",
)
# Archetype + Score combined node
def archetype_and_score(state):
arch_result = arch_fn(state)
merged = {**state, **arch_result}
score_result = score_fn(merged)
return {**arch_result, **score_result}
# Theme + Replacement combined node (sequential: theme feeds replacement)
def theme_and_replacement(state):
theme_result = theme_fn(state)
merged = {**state, **theme_result}
replace_result = replace_fn(merged)
return {**theme_result, **replace_result}
# Risk + Final Decision combined node
def risk_and_decision(state):
risk_result = risk_fn(state)
merged = {**state, **risk_result}
final_result = final_fn(merged)
return {**risk_result, **final_result}
# Build graph
workflow = StateGraph(PipelineState)
workflow.add_node("Validation", validation_fn)
workflow.add_node("Tier 1 Analysis", parallel_tier1)
workflow.add_node("Tier 2 Analysis", parallel_tier2)
workflow.add_node("Scoring", archetype_and_score)
workflow.add_node("Portfolio Analysis", theme_and_replacement)
workflow.add_node("Debate", parallel_bull_bear)
workflow.add_node("Debate Referee", debate_fn)
workflow.add_node("Decision", risk_and_decision)
# Edges
workflow.add_edge(START, "Validation")
workflow.add_conditional_edges(
"Validation",
_veto_gate,
{END: END, "continue": "Tier 1 Analysis"},
)
workflow.add_edge("Tier 1 Analysis", "Tier 2 Analysis")
workflow.add_edge("Tier 2 Analysis", "Scoring")
workflow.add_edge("Scoring", "Portfolio Analysis")
workflow.add_edge("Portfolio Analysis", "Debate")
workflow.add_edge("Debate", "Debate Referee")
workflow.add_edge("Debate Referee", "Decision")
workflow.add_edge("Decision", END)
return workflow.compile()
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _veto_gate(state: Dict[str, Any]) -> str:
"""Check if validation resulted in a hard veto."""
if state.get("hard_veto"):
return END
validation = state.get("validation") or {}
if validation.get("veto"):
return END
return "continue"
def _create_parallel_node(agent_fns: List[tuple], label: str):
"""Create an async node that runs multiple agent functions in parallel.
Args:
agent_fns: List of (name, fn) tuples.
label: Label for logging.
"""
async def parallel_node(state):
t0 = time.time()
async def run_one(name, fn):
logger.debug("[%s] %s starting", label, name)
result = await asyncio.to_thread(fn, state)
logger.debug("[%s] %s done (%.1fs)", label, name, time.time() - t0)
return result
tasks = [run_one(name, fn) for name, fn in agent_fns]
results = await asyncio.gather(*tasks, return_exceptions=True)
merged: Dict[str, Any] = {}
all_flags: list = []
for (name, _), result in zip(agent_fns, results):
if isinstance(result, Exception):
logger.error("[%s] %s failed: %s", label, name, result)
continue
flags = result.pop("global_flags", [])
all_flags.extend(flags)
merged.update(result)
if all_flags:
merged["global_flags"] = all_flags
logger.info("[%s] completed in %.1fs", label, time.time() - t0)
return merged
return parallel_node