diff --git a/tradingagents/graph/parallel_analysts.py b/tradingagents/graph/parallel_analysts.py index b9e33635..b5c23d0d 100644 --- a/tradingagents/graph/parallel_analysts.py +++ b/tradingagents/graph/parallel_analysts.py @@ -83,31 +83,62 @@ def create_parallel_analyst_node(analyst_fns, tool_nodes, selected_analysts): return parallel_analysts_node +def _snapshot_research_state(state): + """Extract research-relevant fields into a plain dict.""" + return { + "investment_debate_state": dict(state.get("investment_debate_state", {})), + "market_report": state.get("market_report", ""), + "sentiment_report": state.get("sentiment_report", ""), + "news_report": state.get("news_report", ""), + "fundamentals_report": state.get("fundamentals_report", ""), + } + + +def _snapshot_risk_state(state): + """Extract risk-relevant fields into a plain dict.""" + return { + "risk_debate_state": dict(state.get("risk_debate_state", {})), + "market_report": state.get("market_report", ""), + "sentiment_report": state.get("sentiment_report", ""), + "news_report": state.get("news_report", ""), + "fundamentals_report": state.get("fundamentals_report", ""), + "trader_investment_plan": state.get("trader_investment_plan", ""), + } + + def create_parallel_research_node(bull_fn, bear_fn): """Create a node that runs Bull and Bear researchers in parallel. - Both agents receive the same state (reports + empty debate state) and - produce independent arguments. Results are merged into a single - investment_debate_state with both histories and count=2. + Uses a sync function with ThreadPoolExecutor.submit() to avoid any + asyncio event-loop interaction issues. LangGraph handles running sync + nodes in its own thread, and from there we spawn our own pool. """ - async def parallel_research_node(state): - # Snapshot into plain dicts — LangGraph state proxies serialize - # concurrent dict access, which would force sequential execution. - state_snap = { - "investment_debate_state": dict(state.get("investment_debate_state", {})), - "market_report": state.get("market_report", ""), - "sentiment_report": state.get("sentiment_report", ""), - "news_report": state.get("news_report", ""), - "fundamentals_report": state.get("fundamentals_report", ""), - } + def parallel_research_node(state): + import time + + state_snap = _snapshot_research_state(state) + t0 = time.time() + + def run_bull(): + logger.info("Bull researcher starting") + result = bull_fn(state_snap) + logger.info("Bull researcher done in %.1fs", time.time() - t0) + return result + + def run_bear(): + logger.info("Bear researcher starting") + result = bear_fn(state_snap) + logger.info("Bear researcher done in %.1fs", time.time() - t0) + return result - loop = asyncio.get_running_loop() with ThreadPoolExecutor(max_workers=2) as pool: - bull_result, bear_result = await asyncio.gather( - loop.run_in_executor(pool, bull_fn, state_snap), - loop.run_in_executor(pool, bear_fn, state_snap), - ) + bull_future = pool.submit(run_bull) + bear_future = pool.submit(run_bear) + bull_result = bull_future.result() + bear_result = bear_future.result() + + logger.info("Parallel research total: %.1fs", time.time() - t0) bull_debate = bull_result["investment_debate_state"] bear_debate = bear_result["investment_debate_state"] @@ -130,30 +161,44 @@ def create_parallel_research_node(bull_fn, bear_fn): def create_parallel_risk_node(aggressive_fn, conservative_fn, neutral_fn): """Create a node that runs all 3 risk analysts in parallel. - All agents receive the same state (trader plan + empty risk debate state) - and produce independent arguments. Results are merged into a single - risk_debate_state with all histories and count=3. + Uses a sync function with ThreadPoolExecutor.submit() to avoid any + asyncio event-loop interaction issues. LangGraph handles running sync + nodes in its own thread, and from there we spawn our own pool. """ - async def parallel_risk_node(state): - # Snapshot into plain dicts — LangGraph state proxies serialize - # concurrent dict access, which would force sequential execution. - state_snap = { - "risk_debate_state": dict(state.get("risk_debate_state", {})), - "market_report": state.get("market_report", ""), - "sentiment_report": state.get("sentiment_report", ""), - "news_report": state.get("news_report", ""), - "fundamentals_report": state.get("fundamentals_report", ""), - "trader_investment_plan": state.get("trader_investment_plan", ""), - } + def parallel_risk_node(state): + import time + + state_snap = _snapshot_risk_state(state) + t0 = time.time() + + def run_agg(): + logger.info("Aggressive analyst starting") + result = aggressive_fn(state_snap) + logger.info("Aggressive analyst done in %.1fs", time.time() - t0) + return result + + def run_con(): + logger.info("Conservative analyst starting") + result = conservative_fn(state_snap) + logger.info("Conservative analyst done in %.1fs", time.time() - t0) + return result + + def run_neu(): + logger.info("Neutral analyst starting") + result = neutral_fn(state_snap) + logger.info("Neutral analyst done in %.1fs", time.time() - t0) + return result - loop = asyncio.get_running_loop() with ThreadPoolExecutor(max_workers=3) as pool: - agg_result, con_result, neu_result = await asyncio.gather( - loop.run_in_executor(pool, aggressive_fn, state_snap), - loop.run_in_executor(pool, conservative_fn, state_snap), - loop.run_in_executor(pool, neutral_fn, state_snap), - ) + agg_future = pool.submit(run_agg) + con_future = pool.submit(run_con) + neu_future = pool.submit(run_neu) + agg_result = agg_future.result() + con_result = con_future.result() + neu_result = neu_future.result() + + logger.info("Parallel risk total: %.1fs", time.time() - t0) agg_debate = agg_result["risk_debate_state"] con_debate = con_result["risk_debate_state"]