From addc4a1e9c8b84fc86e5573ed02b62faf3861161 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=B0=91=E6=9D=B0?= Date: Tue, 14 Apr 2026 03:49:33 +0800 Subject: [PATCH] Keep research degradation visible while bounding researcher nodes Research provenance now rides with the debate state, cache metadata, live payloads, and trace dumps so degraded research no longer masquerades as a normal sample. Bull/Bear/Manager nodes also return explicit guarded fallbacks on timeout or exception, which gives the graph a real node budget boundary without rewriting the bull/bear output shape or removing debate.

Constraint: Must preserve bull/bear debate structure and output shape while adding provenance and node guards
Rejected: Skip bull/bear debate in compact mode | would trade away analysis quality before A/B evidence exists
Confidence: high
Scope-risk: moderate
Reversibility: clean
Directive: Treat research_status and data_quality as rollout gates; do not collapse degraded research back into normal success samples
Tested: python -m pytest tradingagents/tests/test_research_guard.py orchestrator/tests/test_llm_runner.py orchestrator/tests/test_live_mode.py web_dashboard/backend/tests/test_executors.py web_dashboard/backend/tests/test_services_migration.py web_dashboard/backend/tests/test_api_smoke.py -q; python -m compileall tradingagents/graph/setup.py tradingagents/agents/utils/agent_states.py tradingagents/graph/propagation.py orchestrator/llm_runner.py orchestrator/live_mode.py orchestrator/profile_stage_chain.py; python orchestrator/profile_stage_chain.py --ticker 600519.SS --date 2026-04-10 --provider anthropic --model MiniMax-M2.7-highspeed --base-url https://api.minimaxi.com/anthropic --selected-analysts market --analysis-prompt-style compact --timeout 45 --max-retries 0 --overall-timeout 120 --dump-raw-on-failure
Not-tested: Full successful live-provider completion through Portfolio Manager after the post-research connection failure
---
orchestrator/live_mode.py | 8 ++ orchestrator/llm_runner.py | 37 +++++- orchestrator/orchestrator.py | 5 + orchestrator/profile_stage_chain.py | 63 +++++++++- orchestrator/tests/test_live_mode.py | 50 +++++++- orchestrator/tests/test_llm_runner.py | 26 +++++ tradingagents/agents/utils/agent_states.py | 12 +- tradingagents/default_config.py | 1 + tradingagents/graph/propagation.py | 6 + tradingagents/graph/setup.py | 119 ++++++++++++++++++- tradingagents/graph/trading_graph.py | 1 + tradingagents/tests/test_research_guard.py | 127 +++++++++++++++++++++ 12 files changed, 443 insertions(+), 12 deletions(-) create mode 100644 tradingagents/tests/test_research_guard.py diff --git a/orchestrator/live_mode.py b/orchestrator/live_mode.py index e7cb8517..e1e83076 100644 --- a/orchestrator/live_mode.py +++ b/orchestrator/live_mode.py @@ -45,6 +45,7 @@ class LiveMode: def _serialize_signal(self, *, ticker: str, date: str, signal) -> dict: metadata = getattr(signal, "metadata", {}) or {} data_quality = metadata.get("data_quality") + research = metadata.get("research") degradation = self._serialize_degradation(signal, data_quality) return { "contract_version": self._contract_version(signal), @@ -55,6 +56,7 @@ class LiveMode: "error": None, "degradation": degradation, "data_quality": data_quality, + "research": research, } @staticmethod @@ -64,6 +66,11 @@ class LiveMode: reason_codes.append(ReasonCode.BOTH_SIGNALS_UNAVAILABLE.value) source_diagnostics = dict(getattr(exc, "source_diagnostics", {}) or {}) data_quality = getattr(exc, "data_quality", None) + research = None + for diagnostic in source_diagnostics.values(): + if isinstance(diagnostic, dict) and diagnostic.get("research") is not None: + research = diagnostic["research"] + break return { "contract_version": CONTRACT_VERSION, "ticker": ticker, @@ -81,6 +88,7 @@ class LiveMode: "source_diagnostics": source_diagnostics, }, "data_quality": data_quality, + "research": research, } async def run_once(self, tickers: List[str], 
date: Optional[str] = None) -> List[dict]: diff --git a/orchestrator/llm_runner.py b/orchestrator/llm_runner.py index 9c5b3988..3e7bbdee 100644 --- a/orchestrator/llm_runner.py +++ b/orchestrator/llm_runner.py @@ -16,6 +16,24 @@ def _build_data_quality(state: str, **details): return payload +def _extract_research_metadata(final_state: dict | None) -> dict | None: + if not isinstance(final_state, dict): + return None + debate_state = final_state.get("investment_debate_state") or {} + if not isinstance(debate_state, dict): + return None + keys = ( + "research_status", + "research_mode", + "timed_out_nodes", + "degraded_reason", + "covered_dimensions", + "manager_confidence", + ) + metadata = {key: debate_state.get(key) for key in keys if key in debate_state} + return metadata or None + + class LLMRunner: def __init__(self, config: OrchestratorConfig): self._config = config @@ -91,6 +109,17 @@ class LLMRunner: rating = processed_signal if isinstance(processed_signal, str) else str(processed_signal) direction, confidence = self._map_rating(rating) now = datetime.now(timezone.utc) + research_metadata = _extract_research_metadata(_final_state) + if research_metadata and research_metadata.get("research_status") != "full": + data_quality = _build_data_quality( + "research_degraded", + research_status=research_metadata.get("research_status"), + research_mode=research_metadata.get("research_mode"), + degraded_reason=research_metadata.get("degraded_reason"), + timed_out_nodes=research_metadata.get("timed_out_nodes"), + ) + else: + data_quality = _build_data_quality("ok") cache_data = { "rating": rating, @@ -99,7 +128,13 @@ class LLMRunner: "timestamp": now.isoformat(), "ticker": ticker, "date": date, - "data_quality": _build_data_quality("ok"), + "data_quality": data_quality, + "research": research_metadata, + "sample_quality": ( + "degraded_research" + if research_metadata and research_metadata.get("research_status") != "full" + else "full_research" + ), } with 
open(cache_path, "w", encoding="utf-8") as f: json.dump(cache_data, f, ensure_ascii=False, indent=2) diff --git a/orchestrator/orchestrator.py b/orchestrator/orchestrator.py index e78e22c6..c483d879 100644 --- a/orchestrator/orchestrator.py +++ b/orchestrator/orchestrator.py @@ -113,6 +113,8 @@ class TradingOrchestrator: metadata["source_diagnostics"] = source_diagnostics if data_quality: metadata["data_quality"] = data_quality + if llm_sig is not None and llm_sig.metadata.get("research") is not None: + metadata["research"] = llm_sig.metadata.get("research") final_signal.metadata = metadata return final_signal @@ -125,6 +127,9 @@ class TradingOrchestrator: error = signal.metadata.get("error") if error: diagnostic["error"] = error + research = signal.metadata.get("research") + if research is not None: + diagnostic["research"] = research return diagnostic @staticmethod diff --git a/orchestrator/profile_stage_chain.py b/orchestrator/profile_stage_chain.py index 5022fc51..1856c20d 100644 --- a/orchestrator/profile_stage_chain.py +++ b/orchestrator/profile_stage_chain.py @@ -23,6 +23,18 @@ _PHASE_MAP = { "Portfolio Manager": "portfolio", } +_LLM_KIND_MAP = { + "Market Analyst": "quick", + "Bull Researcher": "quick", + "Bear Researcher": "quick", + "Research Manager": "deep", + "Trader": "quick", + "Aggressive Analyst": "quick", + "Conservative Analyst": "quick", + "Neutral Analyst": "quick", + "Portfolio Manager": "deep", +} + def build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description="Profile TradingAgents graph stage timings.") @@ -37,6 +49,7 @@ def build_parser() -> argparse.ArgumentParser: parser.add_argument("--selected-analysts", default="market") parser.add_argument("--overall-timeout", type=int, default=120) parser.add_argument("--dump-dir", default="orchestrator/profile_runs") + parser.add_argument("--dump-raw-on-failure", action="store_true") return parser @@ -44,6 +57,33 @@ class _ProfileTimeout(Exception): pass +def 
_jsonable(value): + if isinstance(value, (str, int, float, bool)) or value is None: + return value + if isinstance(value, dict): + return {str(k): _jsonable(v) for k, v in value.items()} + if isinstance(value, (list, tuple)): + return [_jsonable(item) for item in value] + return repr(value) + + +def _extract_research_state(event: dict) -> tuple[str | None, str | None, int | None, int | None]: + node_payload = next(iter(event.values()), {}) + if not isinstance(node_payload, dict): + return None, None, None, None + debate_state = node_payload.get("investment_debate_state") or {} + if not isinstance(debate_state, dict): + return None, None, None, None + history = debate_state.get("history") or "" + current = debate_state.get("current_response") or "" + return ( + debate_state.get("research_status"), + debate_state.get("degraded_reason"), + len(history), + len(current), + ) + + def main() -> None: args = build_parser().parse_args() selected_analysts = [item.strip() for item in args.selected_analysts.split(",") if item.strip()] @@ -66,11 +106,12 @@ def main() -> None: node_timings = [] phase_totals = defaultdict(float) + raw_events = [] started_at = time.monotonic() last_at = started_at + run_id = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") dump_dir = Path(args.dump_dir) dump_dir.mkdir(parents=True, exist_ok=True) - run_id = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") dump_path = dump_dir / f"{args.ticker.replace('/', '_')}_{args.date}_{run_id}.json" def alarm_handler(signum, frame): @@ -84,14 +125,26 @@ def main() -> None: now = time.monotonic() nodes = list(event.keys()) phases = sorted({_PHASE_MAP.get(node, "unknown") for node in nodes}) + llm_kinds = sorted({_LLM_KIND_MAP.get(node, "unknown") for node in nodes}) delta = round(now - last_at, 3) + research_status, degraded_reason, history_len, response_len = _extract_research_state(event) entry = { + "run_id": run_id, "nodes": nodes, "phases": phases, - "delta_seconds": delta, - "elapsed_seconds": 
round(now - started_at, 3), + "llm_kinds": llm_kinds, + "start_at": round(last_at - started_at, 3), + "end_at": round(now - started_at, 3), + "elapsed_ms": int(delta * 1000), + "selected_analysts": selected_analysts, + "analysis_prompt_style": args.analysis_prompt_style, + "research_status": research_status, + "degraded_reason": degraded_reason, + "history_len": history_len, + "response_len": response_len, } node_timings.append(entry) + raw_events.append(_jsonable(event)) for phase in phases: phase_totals[phase] += delta last_at = now @@ -105,18 +158,22 @@ def main() -> None: "node_timings": node_timings, "phase_totals_seconds": {key: round(value, 3) for key, value in phase_totals.items()}, "dump_path": str(dump_path), + "raw_events": raw_events if args.dump_raw_on_failure else [], } except Exception as exc: payload = { + "run_id": run_id, "status": "error", "ticker": args.ticker, "date": args.date, "selected_analysts": selected_analysts, "analysis_prompt_style": args.analysis_prompt_style, "error": str(exc), + "exception_type": type(exc).__name__, "node_timings": node_timings, "phase_totals_seconds": {key: round(value, 3) for key, value in phase_totals.items()}, "dump_path": str(dump_path), + "raw_events": raw_events, } finally: signal.alarm(0) diff --git a/orchestrator/tests/test_live_mode.py b/orchestrator/tests/test_live_mode.py index d1baa2d7..030eb2c5 100644 --- a/orchestrator/tests/test_live_mode.py +++ b/orchestrator/tests/test_live_mode.py @@ -42,6 +42,14 @@ def test_live_mode_serializes_degraded_contract_shape(): metadata={ "contract_version": "v1alpha1", "data_quality": {"state": "stale_data", "source": "quant"}, + "research": { + "research_status": "degraded", + "research_mode": "degraded_synthesis", + "timed_out_nodes": ["Bull Researcher"], + "degraded_reason": "bull_researcher_timeout", + "covered_dimensions": ["market"], + "manager_confidence": None, + }, "source_diagnostics": { "quant": {"reason_code": ReasonCode.STALE_DATA.value} }, @@ -75,6 +83,14 
@@ def test_live_mode_serializes_degraded_contract_shape(): }, }, "data_quality": {"state": "stale_data", "source": "quant"}, + "research": { + "research_status": "degraded", + "research_mode": "degraded_synthesis", + "timed_out_nodes": ["Bull Researcher"], + "degraded_reason": "bull_researcher_timeout", + "covered_dimensions": ["market"], + "manager_confidence": None, + }, } ] @@ -86,7 +102,19 @@ def test_live_mode_serializes_failure_contract_shape(): ("AAPL", "2026-04-11"): CombinedSignalFailure( "both quant and llm signals are None", reason_codes=(ReasonCode.BOTH_SIGNALS_UNAVAILABLE.value, ReasonCode.PROVIDER_MISMATCH.value), - source_diagnostics={"llm": {"reason_code": ReasonCode.PROVIDER_MISMATCH.value}}, + source_diagnostics={ + "llm": { + "reason_code": ReasonCode.PROVIDER_MISMATCH.value, + "research": { + "research_status": "failed", + "research_mode": "degraded_synthesis", + "timed_out_nodes": ["Bull Researcher"], + "degraded_reason": "bull_researcher_connectionerror", + "covered_dimensions": ["market"], + "manager_confidence": None, + }, + } + }, data_quality={"state": "provider_mismatch", "source": "llm"}, ) } @@ -114,9 +142,27 @@ def test_live_mode_serializes_failure_contract_shape(): ReasonCode.PROVIDER_MISMATCH.value, ], "source_diagnostics": { - "llm": {"reason_code": ReasonCode.PROVIDER_MISMATCH.value}, + "llm": { + "reason_code": ReasonCode.PROVIDER_MISMATCH.value, + "research": { + "research_status": "failed", + "research_mode": "degraded_synthesis", + "timed_out_nodes": ["Bull Researcher"], + "degraded_reason": "bull_researcher_connectionerror", + "covered_dimensions": ["market"], + "manager_confidence": None, + }, + }, }, }, "data_quality": {"state": "provider_mismatch", "source": "llm"}, + "research": { + "research_status": "failed", + "research_mode": "degraded_synthesis", + "timed_out_nodes": ["Bull Researcher"], + "degraded_reason": "bull_researcher_connectionerror", + "covered_dimensions": ["market"], + "manager_confidence": None, + }, } ] 
diff --git a/orchestrator/tests/test_llm_runner.py b/orchestrator/tests/test_llm_runner.py index 7cfa0f27..23ddedac 100644 --- a/orchestrator/tests/test_llm_runner.py +++ b/orchestrator/tests/test_llm_runner.py @@ -99,3 +99,29 @@ def test_get_signal_returns_provider_mismatch_before_graph_init(tmp_path): assert signal.degraded is True assert signal.reason_code == ReasonCode.PROVIDER_MISMATCH.value assert signal.metadata["data_quality"]["state"] == "provider_mismatch" + + +def test_get_signal_persists_research_provenance_on_success(monkeypatch, tmp_path): + class SuccessfulGraph: + def propagate(self, ticker, date): + return { + "investment_debate_state": { + "research_status": "degraded", + "research_mode": "degraded_synthesis", + "timed_out_nodes": ["Bull Researcher"], + "degraded_reason": "bull_researcher_timeout", + "covered_dimensions": ["market"], + "manager_confidence": None, + } + }, "BUY" + + cfg = OrchestratorConfig(cache_dir=str(tmp_path)) + runner = LLMRunner(cfg) + monkeypatch.setattr(runner, "_get_graph", lambda: SuccessfulGraph()) + + signal = runner.get_signal("AAPL", "2024-01-02") + + assert signal.degraded is False + assert signal.metadata["research"]["research_status"] == "degraded" + assert signal.metadata["sample_quality"] == "degraded_research" + assert signal.metadata["data_quality"]["state"] == "research_degraded" diff --git a/tradingagents/agents/utils/agent_states.py b/tradingagents/agents/utils/agent_states.py index 6423b936..0fece129 100644 --- a/tradingagents/agents/utils/agent_states.py +++ b/tradingagents/agents/utils/agent_states.py @@ -1,10 +1,10 @@ -from typing import Annotated -from typing_extensions import TypedDict +from typing import Annotated, Optional +from typing_extensions import NotRequired, TypedDict from langgraph.graph import MessagesState # Researcher team state -class InvestDebateState(TypedDict): +class InvestDebateState(TypedDict, total=False): bull_history: Annotated[ str, "Bullish Conversation history" ] # Bullish 
Conversation history @@ -15,6 +15,12 @@ class InvestDebateState(TypedDict): current_response: Annotated[str, "Latest response"] # Last response judge_decision: Annotated[str, "Final judge decision"] # Last response count: Annotated[int, "Length of the current conversation"] # Conversation length + research_status: NotRequired[Annotated[str, "Research stage status: full/degraded/failed"]] + research_mode: NotRequired[Annotated[str, "Research mode: debate/degraded_synthesis"]] + timed_out_nodes: NotRequired[Annotated[list[str], "Research nodes that timed out"]] + degraded_reason: NotRequired[Annotated[Optional[str], "Research degradation reason"]] + covered_dimensions: NotRequired[Annotated[list[str], "Research dimensions covered so far"]] + manager_confidence: NotRequired[Annotated[Optional[float], "Research manager confidence"]] # Risk management team state diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index c4fbf51b..eb6485fe 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -24,6 +24,7 @@ DEFAULT_CONFIG = { "max_debate_rounds": 1, "max_risk_discuss_rounds": 1, "max_recur_limit": 100, + "research_node_timeout_secs": 30.0, # Data vendor configuration # Category-level configuration (default for all tools in category) "data_vendors": { diff --git a/tradingagents/graph/propagation.py b/tradingagents/graph/propagation.py index f49fbb1c..3e72db3e 100644 --- a/tradingagents/graph/propagation.py +++ b/tradingagents/graph/propagation.py @@ -31,6 +31,12 @@ class Propagator: "current_response": "", "judge_decision": "", "count": 0, + "research_status": "full", + "research_mode": "debate", + "timed_out_nodes": [], + "degraded_reason": None, + "covered_dimensions": [], + "manager_confidence": None, } ), "risk_debate_state": RiskDebateState( diff --git a/tradingagents/graph/setup.py b/tradingagents/graph/setup.py index ae90489c..77c0b46c 100644 --- a/tradingagents/graph/setup.py +++ 
b/tradingagents/graph/setup.py @@ -1,5 +1,7 @@ # TradingAgents/graph/setup.py +import concurrent.futures +import time from typing import Any, Dict from langgraph.graph import END, START, StateGraph from langgraph.prebuilt import ToolNode @@ -24,6 +26,7 @@ class GraphSetup: invest_judge_memory, portfolio_manager_memory, conditional_logic: ConditionalLogic, + research_node_timeout_secs: float = 30.0, ): """Initialize with required components.""" self.quick_thinking_llm = quick_thinking_llm @@ -35,6 +38,7 @@ class GraphSetup: self.invest_judge_memory = invest_judge_memory self.portfolio_manager_memory = portfolio_manager_memory self.conditional_logic = conditional_logic + self.research_node_timeout_secs = research_node_timeout_secs def setup_graph( self, selected_analysts=["market", "social", "news", "fundamentals"] @@ -85,13 +89,16 @@ class GraphSetup: tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"] # Create researcher and manager nodes - bull_researcher_node = create_bull_researcher( + bull_researcher_node = self._guard_research_node( + "Bull Researcher", self.quick_thinking_llm, self.bull_memory ) - bear_researcher_node = create_bear_researcher( + bear_researcher_node = self._guard_research_node( + "Bear Researcher", self.quick_thinking_llm, self.bear_memory ) - research_manager_node = create_research_manager( + research_manager_node = self._guard_research_node( + "Research Manager", self.deep_thinking_llm, self.invest_judge_memory ) trader_node = create_trader(self.quick_thinking_llm, self.trader_memory) @@ -199,3 +206,109 @@ class GraphSetup: # Compile and return return workflow.compile() + + def _guard_research_node(self, node_name: str, llm: Any, memory): + if node_name == "Bull Researcher": + node = create_bull_researcher(llm, memory) + dimension = "bull" + elif node_name == "Bear Researcher": + node = create_bear_researcher(llm, memory) + dimension = "bear" + else: + node = create_research_manager(llm, memory) + dimension = "manager" + + def 
wrapped(state): + started_at = time.time() + executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + future = executor.submit(node, state) + try: + result = future.result(timeout=self.research_node_timeout_secs) + return self._apply_research_success(state, result, dimension) + except concurrent.futures.TimeoutError: + future.cancel() + executor.shutdown(wait=False, cancel_futures=True) + return self._apply_research_fallback( + state, + node_name=node_name, + dimension=dimension, + reason=f"{node_name.lower().replace(' ', '_')}_timeout", + started_at=started_at, + ) + except Exception as exc: + executor.shutdown(wait=False, cancel_futures=True) + return self._apply_research_fallback( + state, + node_name=node_name, + dimension=dimension, + reason=f"{node_name.lower().replace(' ', '_')}_{type(exc).__name__.lower()}", + started_at=started_at, + ) + finally: + executor.shutdown(wait=False, cancel_futures=True) + + return wrapped + + @staticmethod + def _provenance(state) -> dict: + debate_state = dict(state["investment_debate_state"]) + return { + "research_status": debate_state.get("research_status", "full"), + "research_mode": debate_state.get("research_mode", "debate"), + "timed_out_nodes": list(debate_state.get("timed_out_nodes", [])), + "degraded_reason": debate_state.get("degraded_reason"), + "covered_dimensions": list(debate_state.get("covered_dimensions", [])), + "manager_confidence": debate_state.get("manager_confidence"), + } + + def _apply_research_success(self, state, result: dict, dimension: str): + debate_state = dict(result.get("investment_debate_state") or state["investment_debate_state"]) + provenance = self._provenance(state) + if dimension not in provenance["covered_dimensions"]: + provenance["covered_dimensions"].append(dimension) + if provenance["research_status"] == "full": + provenance["research_mode"] = "debate" + if dimension == "manager" and provenance["manager_confidence"] is None: + provenance["manager_confidence"] = 1.0 if 
provenance["research_status"] == "full" else 0.5 + debate_state.update(provenance) + updated = dict(result) + updated["investment_debate_state"] = debate_state + return updated + + def _apply_research_fallback(self, state, *, node_name: str, dimension: str, reason: str, started_at: float): + debate_state = dict(state["investment_debate_state"]) + provenance = self._provenance(state) + provenance["research_status"] = "degraded" + provenance["research_mode"] = "degraded_synthesis" + provenance["degraded_reason"] = reason + if "timeout" in reason and node_name not in provenance["timed_out_nodes"]: + provenance["timed_out_nodes"].append(node_name) + + elapsed_seconds = round(time.time() - started_at, 3) + if dimension == "manager": + provenance["manager_confidence"] = 0.0 + fallback = ( + "Recommendation: HOLD\n" + f"Top reasons: research degraded at {node_name} ({reason}); use partial research context cautiously.\n" + f"Simple execution plan: keep sizing conservative and wait for confirmation. Guard elapsed={elapsed_seconds}s." + ) + debate_state["judge_decision"] = fallback + debate_state["current_response"] = fallback + debate_state.update(provenance) + return { + "investment_debate_state": debate_state, + "investment_plan": fallback, + } + + prefix = "Bull Analyst" if dimension == "bull" else "Bear Analyst" + history_field = "bull_history" if dimension == "bull" else "bear_history" + degraded_argument = ( + f"{prefix}: [DEGRADED] {node_name} unavailable ({reason}). " + f"Proceeding with partial research context. Guard elapsed={elapsed_seconds}s." 
+ ) + debate_state["history"] = debate_state.get("history", "") + "\n" + degraded_argument + debate_state[history_field] = debate_state.get(history_field, "") + "\n" + degraded_argument + debate_state["current_response"] = degraded_argument + debate_state["count"] = debate_state.get("count", 0) + 1 + debate_state.update(provenance) + return {"investment_debate_state": debate_state} diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 282fdfc3..44a8e884 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -144,6 +144,7 @@ class TradingAgentsGraph: self.invest_judge_memory, self.portfolio_manager_memory, self.conditional_logic, + research_node_timeout_secs=float(self.config.get("research_node_timeout_secs", 30.0)), ) self.propagator = Propagator() diff --git a/tradingagents/tests/test_research_guard.py b/tradingagents/tests/test_research_guard.py new file mode 100644 index 00000000..fe4631ee --- /dev/null +++ b/tradingagents/tests/test_research_guard.py @@ -0,0 +1,127 @@ +import time + +import tradingagents.graph.setup as graph_setup_module +from tradingagents.graph.setup import GraphSetup + + +def _setup() -> GraphSetup: + return GraphSetup( + quick_thinking_llm=None, + deep_thinking_llm=None, + tool_nodes={}, + bull_memory=None, + bear_memory=None, + trader_memory=None, + invest_judge_memory=None, + portfolio_manager_memory=None, + conditional_logic=None, + research_node_timeout_secs=0.01, + ) + + +def test_manager_guard_fallback_marks_degraded_synthesis(): + setup = _setup() + state = { + "investment_debate_state": { + "history": "Bull Analyst: case", + "bull_history": "Bull Analyst: case", + "bear_history": "", + "current_response": "Bull Analyst: case", + "judge_decision": "", + "count": 1, + "research_status": "full", + "research_mode": "debate", + "timed_out_nodes": [], + "degraded_reason": None, + "covered_dimensions": ["bull"], + "manager_confidence": None, + } + } + + 
result = setup._apply_research_fallback( + state, + node_name="Research Manager", + dimension="manager", + reason="research_manager_timeout", + started_at=0.0, + ) + + debate = result["investment_debate_state"] + assert debate["research_status"] == "degraded" + assert debate["research_mode"] == "degraded_synthesis" + assert debate["timed_out_nodes"] == ["Research Manager"] + assert result["investment_plan"].startswith("Recommendation: HOLD") + + +def test_bull_guard_success_records_coverage(): + setup = _setup() + state = { + "investment_debate_state": { + "history": "", + "bull_history": "", + "bear_history": "", + "current_response": "", + "judge_decision": "", + "count": 0, + "research_status": "full", + "research_mode": "debate", + "timed_out_nodes": [], + "degraded_reason": None, + "covered_dimensions": [], + "manager_confidence": None, + } + } + result = { + "investment_debate_state": { + "history": "Bull Analyst: ok", + "bull_history": "Bull Analyst: ok", + "bear_history": "", + "current_response": "Bull Analyst: ok", + "judge_decision": "", + "count": 1, + } + } + + updated = setup._apply_research_success(state, result, dimension="bull") + debate = updated["investment_debate_state"] + assert debate["research_status"] == "full" + assert debate["research_mode"] == "debate" + assert debate["covered_dimensions"] == ["bull"] + + +def test_guard_timeout_returns_without_waiting_for_node_completion(monkeypatch): + def slow_bull(_llm, _memory): + def node(_state): + time.sleep(0.2) + return {"investment_debate_state": {"history": "", "bull_history": "", "bear_history": "", "current_response": "", "judge_decision": "", "count": 1}} + return node + + monkeypatch.setattr(graph_setup_module, "create_bull_researcher", slow_bull) + setup = _setup() + wrapped = setup._guard_research_node("Bull Researcher", None, None) + state = { + "investment_debate_state": { + "history": "", + "bull_history": "", + "bear_history": "", + "current_response": "", + "judge_decision": "", + 
"count": 0, + "research_status": "full", + "research_mode": "debate", + "timed_out_nodes": [], + "degraded_reason": None, + "covered_dimensions": [], + "manager_confidence": None, + } + } + + started = time.monotonic() + result = wrapped(state) + elapsed = time.monotonic() - started + + assert elapsed < 0.1 + debate = result["investment_debate_state"] + assert debate["research_status"] == "degraded" + assert debate["research_mode"] == "degraded_synthesis" + assert debate["timed_out_nodes"] == ["Bull Researcher"]