Fix parallel research/risk: use async+asyncio.gather instead of ThreadPoolExecutor
Sync ThreadPoolExecutor doesn't truly parallelize inside LangGraph nodes. Switched to async functions with asyncio.to_thread() + asyncio.gather() — the same pattern that works for the parallel analyst node. Result: Research (Bull+Bear) and Risk (Agg+Con+Neu) now run concurrently. Total analysis time reduced from ~450s to ~280s (~38% faster). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
12e0d507c2
commit
ba39a81e82
|
|
@ -8,7 +8,6 @@ Provides parallel wrappers for:
|
|||
|
||||
import asyncio
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from langchain_core.messages import HumanMessage, RemoveMessage
|
||||
|
||||
|
|
@ -109,36 +108,32 @@ def _snapshot_risk_state(state):
|
|||
def create_parallel_research_node(bull_fn, bear_fn):
|
||||
"""Create a node that runs Bull and Bear researchers in parallel.
|
||||
|
||||
Uses a sync function with ThreadPoolExecutor.submit() to avoid any
|
||||
asyncio event-loop interaction issues. LangGraph handles running sync
|
||||
nodes in its own thread, and from there we spawn our own pool.
|
||||
Uses async + asyncio.to_thread + asyncio.gather — the same pattern
|
||||
that works for create_parallel_analyst_node.
|
||||
"""
|
||||
|
||||
def parallel_research_node(state):
|
||||
async def parallel_research_node(state):
|
||||
import time
|
||||
import sys
|
||||
|
||||
state_snap = _snapshot_research_state(state)
|
||||
t0 = time.time()
|
||||
|
||||
def run_bull():
|
||||
print(f"[PARALLEL] Bull starting at +{time.time()-t0:.1f}s", flush=True)
|
||||
result = bull_fn(state_snap)
|
||||
print(f"[PARALLEL] Bull done at +{time.time()-t0:.1f}s", flush=True)
|
||||
async def run_bull():
|
||||
print(f"[PARALLEL] Bull starting at +{time.time()-t0:.1f}s", file=sys.stderr, flush=True)
|
||||
result = await asyncio.to_thread(bull_fn, state_snap)
|
||||
print(f"[PARALLEL] Bull done at +{time.time()-t0:.1f}s", file=sys.stderr, flush=True)
|
||||
return result
|
||||
|
||||
def run_bear():
|
||||
print(f"[PARALLEL] Bear starting at +{time.time()-t0:.1f}s", flush=True)
|
||||
result = bear_fn(state_snap)
|
||||
print(f"[PARALLEL] Bear done at +{time.time()-t0:.1f}s", flush=True)
|
||||
async def run_bear():
|
||||
print(f"[PARALLEL] Bear starting at +{time.time()-t0:.1f}s", file=sys.stderr, flush=True)
|
||||
result = await asyncio.to_thread(bear_fn, state_snap)
|
||||
print(f"[PARALLEL] Bear done at +{time.time()-t0:.1f}s", file=sys.stderr, flush=True)
|
||||
return result
|
||||
|
||||
with ThreadPoolExecutor(max_workers=2) as pool:
|
||||
bull_future = pool.submit(run_bull)
|
||||
bear_future = pool.submit(run_bear)
|
||||
bull_result = bull_future.result()
|
||||
bear_result = bear_future.result()
|
||||
bull_result, bear_result = await asyncio.gather(run_bull(), run_bear())
|
||||
|
||||
print(f"[PARALLEL] Research total: {time.time()-t0:.1f}s", flush=True)
|
||||
print(f"[PARALLEL] Research total: {time.time()-t0:.1f}s", file=sys.stderr, flush=True)
|
||||
|
||||
bull_debate = bull_result["investment_debate_state"]
|
||||
bear_debate = bear_result["investment_debate_state"]
|
||||
|
|
@ -161,44 +156,40 @@ def create_parallel_research_node(bull_fn, bear_fn):
|
|||
def create_parallel_risk_node(aggressive_fn, conservative_fn, neutral_fn):
|
||||
"""Create a node that runs all 3 risk analysts in parallel.
|
||||
|
||||
Uses a sync function with ThreadPoolExecutor.submit() to avoid any
|
||||
asyncio event-loop interaction issues. LangGraph handles running sync
|
||||
nodes in its own thread, and from there we spawn our own pool.
|
||||
Uses async + asyncio.to_thread + asyncio.gather — the same pattern
|
||||
that works for create_parallel_analyst_node.
|
||||
"""
|
||||
|
||||
def parallel_risk_node(state):
|
||||
async def parallel_risk_node(state):
|
||||
import time
|
||||
import sys
|
||||
|
||||
state_snap = _snapshot_risk_state(state)
|
||||
t0 = time.time()
|
||||
|
||||
def run_agg():
|
||||
print(f"[PARALLEL] Aggressive starting at +{time.time()-t0:.1f}s", flush=True)
|
||||
result = aggressive_fn(state_snap)
|
||||
print(f"[PARALLEL] Aggressive done at +{time.time()-t0:.1f}s", flush=True)
|
||||
async def run_agg():
|
||||
print(f"[PARALLEL] Aggressive starting at +{time.time()-t0:.1f}s", file=sys.stderr, flush=True)
|
||||
result = await asyncio.to_thread(aggressive_fn, state_snap)
|
||||
print(f"[PARALLEL] Aggressive done at +{time.time()-t0:.1f}s", file=sys.stderr, flush=True)
|
||||
return result
|
||||
|
||||
def run_con():
|
||||
print(f"[PARALLEL] Conservative starting at +{time.time()-t0:.1f}s", flush=True)
|
||||
result = conservative_fn(state_snap)
|
||||
print(f"[PARALLEL] Conservative done at +{time.time()-t0:.1f}s", flush=True)
|
||||
async def run_con():
|
||||
print(f"[PARALLEL] Conservative starting at +{time.time()-t0:.1f}s", file=sys.stderr, flush=True)
|
||||
result = await asyncio.to_thread(conservative_fn, state_snap)
|
||||
print(f"[PARALLEL] Conservative done at +{time.time()-t0:.1f}s", file=sys.stderr, flush=True)
|
||||
return result
|
||||
|
||||
def run_neu():
|
||||
print(f"[PARALLEL] Neutral starting at +{time.time()-t0:.1f}s", flush=True)
|
||||
result = neutral_fn(state_snap)
|
||||
print(f"[PARALLEL] Neutral done at +{time.time()-t0:.1f}s", flush=True)
|
||||
async def run_neu():
|
||||
print(f"[PARALLEL] Neutral starting at +{time.time()-t0:.1f}s", file=sys.stderr, flush=True)
|
||||
result = await asyncio.to_thread(neutral_fn, state_snap)
|
||||
print(f"[PARALLEL] Neutral done at +{time.time()-t0:.1f}s", file=sys.stderr, flush=True)
|
||||
return result
|
||||
|
||||
with ThreadPoolExecutor(max_workers=3) as pool:
|
||||
agg_future = pool.submit(run_agg)
|
||||
con_future = pool.submit(run_con)
|
||||
neu_future = pool.submit(run_neu)
|
||||
agg_result = agg_future.result()
|
||||
con_result = con_future.result()
|
||||
neu_result = neu_future.result()
|
||||
agg_result, con_result, neu_result = await asyncio.gather(
|
||||
run_agg(), run_con(), run_neu()
|
||||
)
|
||||
|
||||
logger.info("Parallel risk total: %.1fs", time.time() - t0)
|
||||
print(f"[PARALLEL] Risk total: {time.time()-t0:.1f}s", file=sys.stderr, flush=True)
|
||||
|
||||
agg_debate = agg_result["risk_debate_state"]
|
||||
con_debate = con_result["risk_debate_state"]
|
||||
|
|
|
|||
Loading…
Reference in New Issue