diff --git a/tradingagents/graph/parallel_analysts.py b/tradingagents/graph/parallel_analysts.py index 97837b85..35ea0cda 100644 --- a/tradingagents/graph/parallel_analysts.py +++ b/tradingagents/graph/parallel_analysts.py @@ -8,7 +8,6 @@ Provides parallel wrappers for: import asyncio import logging -from concurrent.futures import ThreadPoolExecutor from langchain_core.messages import HumanMessage, RemoveMessage @@ -109,36 +108,32 @@ def _snapshot_risk_state(state): def create_parallel_research_node(bull_fn, bear_fn): """Create a node that runs Bull and Bear researchers in parallel. - Uses a sync function with ThreadPoolExecutor.submit() to avoid any - asyncio event-loop interaction issues. LangGraph handles running sync - nodes in its own thread, and from there we spawn our own pool. + Uses async + asyncio.to_thread + asyncio.gather — the same pattern + that works for create_parallel_analyst_node. """ - def parallel_research_node(state): + async def parallel_research_node(state): import time + import sys state_snap = _snapshot_research_state(state) t0 = time.time() - def run_bull(): - print(f"[PARALLEL] Bull starting at +{time.time()-t0:.1f}s", flush=True) - result = bull_fn(state_snap) - print(f"[PARALLEL] Bull done at +{time.time()-t0:.1f}s", flush=True) + async def run_bull(): + print(f"[PARALLEL] Bull starting at +{time.time()-t0:.1f}s", file=sys.stderr, flush=True) + result = await asyncio.to_thread(bull_fn, state_snap) + print(f"[PARALLEL] Bull done at +{time.time()-t0:.1f}s", file=sys.stderr, flush=True) return result - def run_bear(): - print(f"[PARALLEL] Bear starting at +{time.time()-t0:.1f}s", flush=True) - result = bear_fn(state_snap) - print(f"[PARALLEL] Bear done at +{time.time()-t0:.1f}s", flush=True) + async def run_bear(): + print(f"[PARALLEL] Bear starting at +{time.time()-t0:.1f}s", file=sys.stderr, flush=True) + result = await asyncio.to_thread(bear_fn, state_snap) + print(f"[PARALLEL] Bear done at +{time.time()-t0:.1f}s", file=sys.stderr, flush=True) return result - with ThreadPoolExecutor(max_workers=2) as pool: - bull_future = pool.submit(run_bull) - bear_future = pool.submit(run_bear) - bull_result = bull_future.result() - bear_result = bear_future.result() + bull_result, bear_result = await asyncio.gather(run_bull(), run_bear()) - print(f"[PARALLEL] Research total: {time.time()-t0:.1f}s", flush=True) + print(f"[PARALLEL] Research total: {time.time()-t0:.1f}s", file=sys.stderr, flush=True) bull_debate = bull_result["investment_debate_state"] bear_debate = bear_result["investment_debate_state"] @@ -161,44 +156,40 @@ def create_parallel_research_node(bull_fn, bear_fn): def create_parallel_risk_node(aggressive_fn, conservative_fn, neutral_fn): """Create a node that runs all 3 risk analysts in parallel. - Uses a sync function with ThreadPoolExecutor.submit() to avoid any - asyncio event-loop interaction issues. LangGraph handles running sync - nodes in its own thread, and from there we spawn our own pool. + Uses async + asyncio.to_thread + asyncio.gather — the same pattern + that works for create_parallel_analyst_node. """ - def parallel_risk_node(state): + async def parallel_risk_node(state): import time + import sys state_snap = _snapshot_risk_state(state) t0 = time.time() - def run_agg(): - print(f"[PARALLEL] Aggressive starting at +{time.time()-t0:.1f}s", flush=True) - result = aggressive_fn(state_snap) - print(f"[PARALLEL] Aggressive done at +{time.time()-t0:.1f}s", flush=True) + async def run_agg(): + print(f"[PARALLEL] Aggressive starting at +{time.time()-t0:.1f}s", file=sys.stderr, flush=True) + result = await asyncio.to_thread(aggressive_fn, state_snap) + print(f"[PARALLEL] Aggressive done at +{time.time()-t0:.1f}s", file=sys.stderr, flush=True) return result - def run_con(): - print(f"[PARALLEL] Conservative starting at +{time.time()-t0:.1f}s", flush=True) - result = conservative_fn(state_snap) - print(f"[PARALLEL] Conservative done at +{time.time()-t0:.1f}s", flush=True) + async def run_con(): + print(f"[PARALLEL] Conservative starting at +{time.time()-t0:.1f}s", file=sys.stderr, flush=True) + result = await asyncio.to_thread(conservative_fn, state_snap) + print(f"[PARALLEL] Conservative done at +{time.time()-t0:.1f}s", file=sys.stderr, flush=True) return result - def run_neu(): - print(f"[PARALLEL] Neutral starting at +{time.time()-t0:.1f}s", flush=True) - result = neutral_fn(state_snap) - print(f"[PARALLEL] Neutral done at +{time.time()-t0:.1f}s", flush=True) + async def run_neu(): + print(f"[PARALLEL] Neutral starting at +{time.time()-t0:.1f}s", file=sys.stderr, flush=True) + result = await asyncio.to_thread(neutral_fn, state_snap) + print(f"[PARALLEL] Neutral done at +{time.time()-t0:.1f}s", file=sys.stderr, flush=True) return result - with ThreadPoolExecutor(max_workers=3) as pool: - agg_future = pool.submit(run_agg) - con_future = pool.submit(run_con) - neu_future = pool.submit(run_neu) - agg_result = agg_future.result() - con_result = con_future.result() - neu_result = neu_future.result() + agg_result, con_result, neu_result = await asyncio.gather( + run_agg(), run_con(), run_neu() + ) - logger.info("Parallel risk total: %.1fs", time.time() - t0) + print(f"[PARALLEL] Risk total: {time.time()-t0:.1f}s", file=sys.stderr, flush=True) agg_debate = agg_result["risk_debate_state"] con_debate = con_result["risk_debate_state"]