209 lines
7.6 KiB
Python
209 lines
7.6 KiB
Python
"""Graph setup for the structured equity ranking pipeline.
|
|
|
|
Pipeline stages:
|
|
START → Validation → [veto gate] → Tier 1 (Macro+Liquidity parallel)
|
|
→ Tier 2 (8 agents parallel) → Scoring (Archetype+MasterScore)
|
|
→ Tier 3 (Bull+Bear parallel → Debate → Risk → FinalDecision)
|
|
→ END
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
from typing import Any, Dict, List
|
|
|
|
from langgraph.graph import END, START, StateGraph
|
|
|
|
from tradingagents.agents.utils.agent_states import PipelineState
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class StructuredGraphSetup:
|
|
"""Builds the structured equity ranking LangGraph."""
|
|
|
|
def __init__(self, quick_llm, deep_llm):
|
|
self.quick_llm = quick_llm
|
|
self.deep_llm = deep_llm
|
|
|
|
def setup_graph(self):
|
|
"""Build and compile the structured pipeline graph."""
|
|
from tradingagents.agents.structured import (
|
|
create_archetype_node,
|
|
create_backlog_node,
|
|
create_bear_case_node,
|
|
create_bull_case_node,
|
|
create_business_quality_node,
|
|
create_crowding_node,
|
|
create_debate_node,
|
|
create_earnings_revisions_node,
|
|
create_entry_timing_node,
|
|
create_final_decision_node,
|
|
create_institutional_flow_node,
|
|
create_liquidity_node,
|
|
create_macro_node,
|
|
create_position_replacement_node,
|
|
create_risk_node,
|
|
create_scoring_node,
|
|
create_sector_rotation_node,
|
|
create_theme_substitution_node,
|
|
create_validation_node,
|
|
create_valuation_node,
|
|
)
|
|
|
|
# Create node functions
|
|
# Tier 1: cheap model (or no LLM for validation)
|
|
validation_fn = create_validation_node()
|
|
macro_fn = create_macro_node(self.quick_llm)
|
|
liquidity_fn = create_liquidity_node(self.quick_llm)
|
|
|
|
# Tier 2: cheap model for analysis
|
|
bq_fn = create_business_quality_node(self.quick_llm)
|
|
inst_fn = create_institutional_flow_node(self.quick_llm)
|
|
val_fn = create_valuation_node(self.quick_llm)
|
|
et_fn = create_entry_timing_node(self.quick_llm)
|
|
er_fn = create_earnings_revisions_node(self.quick_llm)
|
|
sr_fn = create_sector_rotation_node(self.quick_llm)
|
|
bl_fn = create_backlog_node(self.quick_llm)
|
|
cr_fn = create_crowding_node(self.quick_llm)
|
|
arch_fn = create_archetype_node(self.quick_llm)
|
|
score_fn = create_scoring_node()
|
|
|
|
# Portfolio-level: deep model for theme/replacement analysis
|
|
theme_fn = create_theme_substitution_node(self.deep_llm)
|
|
replace_fn = create_position_replacement_node(self.deep_llm)
|
|
|
|
# Tier 3: deep model for reasoning/debate
|
|
bull_fn = create_bull_case_node(self.deep_llm)
|
|
bear_fn = create_bear_case_node(self.deep_llm)
|
|
debate_fn = create_debate_node(self.deep_llm)
|
|
risk_fn = create_risk_node(self.deep_llm)
|
|
final_fn = create_final_decision_node(self.deep_llm)
|
|
|
|
# Build parallel wrapper nodes
|
|
parallel_tier1 = _create_parallel_node(
|
|
[("macro", macro_fn), ("liquidity", liquidity_fn)],
|
|
"Tier 1",
|
|
)
|
|
parallel_tier2 = _create_parallel_node(
|
|
[
|
|
("business_quality", bq_fn),
|
|
("institutional_flow", inst_fn),
|
|
("valuation", val_fn),
|
|
("entry_timing", et_fn),
|
|
("earnings_revisions", er_fn),
|
|
("sector_rotation", sr_fn),
|
|
("backlog", bl_fn),
|
|
("crowding", cr_fn),
|
|
],
|
|
"Tier 2",
|
|
)
|
|
parallel_bull_bear = _create_parallel_node(
|
|
[("bull_case", bull_fn), ("bear_case", bear_fn)],
|
|
"Bull/Bear",
|
|
)
|
|
|
|
# Archetype + Score combined node
|
|
def archetype_and_score(state):
|
|
arch_result = arch_fn(state)
|
|
merged = {**state, **arch_result}
|
|
score_result = score_fn(merged)
|
|
return {**arch_result, **score_result}
|
|
|
|
# Theme + Replacement combined node (sequential: theme feeds replacement)
|
|
def theme_and_replacement(state):
|
|
theme_result = theme_fn(state)
|
|
merged = {**state, **theme_result}
|
|
replace_result = replace_fn(merged)
|
|
return {**theme_result, **replace_result}
|
|
|
|
# Risk + Final Decision combined node
|
|
def risk_and_decision(state):
|
|
risk_result = risk_fn(state)
|
|
merged = {**state, **risk_result}
|
|
final_result = final_fn(merged)
|
|
return {**risk_result, **final_result}
|
|
|
|
# Build graph
|
|
workflow = StateGraph(PipelineState)
|
|
|
|
workflow.add_node("Validation", validation_fn)
|
|
workflow.add_node("Tier 1 Analysis", parallel_tier1)
|
|
workflow.add_node("Tier 2 Analysis", parallel_tier2)
|
|
workflow.add_node("Scoring", archetype_and_score)
|
|
workflow.add_node("Portfolio Analysis", theme_and_replacement)
|
|
workflow.add_node("Debate", parallel_bull_bear)
|
|
workflow.add_node("Debate Referee", debate_fn)
|
|
workflow.add_node("Decision", risk_and_decision)
|
|
|
|
# Edges
|
|
workflow.add_edge(START, "Validation")
|
|
workflow.add_conditional_edges(
|
|
"Validation",
|
|
_veto_gate,
|
|
{END: END, "continue": "Tier 1 Analysis"},
|
|
)
|
|
workflow.add_edge("Tier 1 Analysis", "Tier 2 Analysis")
|
|
workflow.add_edge("Tier 2 Analysis", "Scoring")
|
|
workflow.add_edge("Scoring", "Portfolio Analysis")
|
|
workflow.add_edge("Portfolio Analysis", "Debate")
|
|
workflow.add_edge("Debate", "Debate Referee")
|
|
workflow.add_edge("Debate Referee", "Decision")
|
|
workflow.add_edge("Decision", END)
|
|
|
|
return workflow.compile()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _veto_gate(state: Dict[str, Any]) -> str:
|
|
"""Check if validation resulted in a hard veto."""
|
|
if state.get("hard_veto"):
|
|
return END
|
|
validation = state.get("validation") or {}
|
|
if validation.get("veto"):
|
|
return END
|
|
return "continue"
|
|
|
|
|
|
def _create_parallel_node(agent_fns: List[tuple], label: str):
|
|
"""Create an async node that runs multiple agent functions in parallel.
|
|
|
|
Args:
|
|
agent_fns: List of (name, fn) tuples.
|
|
label: Label for logging.
|
|
"""
|
|
|
|
async def parallel_node(state):
|
|
t0 = time.time()
|
|
|
|
async def run_one(name, fn):
|
|
logger.debug("[%s] %s starting", label, name)
|
|
result = await asyncio.to_thread(fn, state)
|
|
logger.debug("[%s] %s done (%.1fs)", label, name, time.time() - t0)
|
|
return result
|
|
|
|
tasks = [run_one(name, fn) for name, fn in agent_fns]
|
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
|
|
|
merged: Dict[str, Any] = {}
|
|
all_flags: list = []
|
|
for (name, _), result in zip(agent_fns, results):
|
|
if isinstance(result, Exception):
|
|
logger.error("[%s] %s failed: %s", label, name, result)
|
|
continue
|
|
flags = result.pop("global_flags", [])
|
|
all_flags.extend(flags)
|
|
merged.update(result)
|
|
if all_flags:
|
|
merged["global_flags"] = all_flags
|
|
|
|
logger.info("[%s] completed in %.1fs", label, time.time() - t0)
|
|
return merged
|
|
|
|
return parallel_node
|