Keep research degradation visible while bounding researcher nodes

Research provenance now rides with the debate state, cache metadata, live payloads, and trace dumps so degraded research no longer masquerades as a normal sample. Bull/Bear/Manager nodes also return explicit guarded fallbacks on timeout or exception, which gives the graph a real node budget boundary without rewriting the bull/bear output shape or removing debate.\n\nConstraint: Must preserve bull/bear debate structure and output shape while adding provenance and node guards\nRejected: Skip bull/bear debate in compact mode | would trade away analysis quality before A/B evidence exists\nConfidence: high\nScope-risk: moderate\nReversibility: clean\nDirective: Treat research_status and data_quality as rollout gates; do not collapse degraded research back into normal success samples\nTested: python -m pytest tradingagents/tests/test_research_guard.py orchestrator/tests/test_llm_runner.py orchestrator/tests/test_live_mode.py web_dashboard/backend/tests/test_executors.py web_dashboard/backend/tests/test_services_migration.py web_dashboard/backend/tests/test_api_smoke.py -q; python -m compileall tradingagents/graph/setup.py tradingagents/agents/utils/agent_states.py tradingagents/graph/propagation.py orchestrator/llm_runner.py orchestrator/live_mode.py orchestrator/profile_stage_chain.py; python orchestrator/profile_stage_chain.py --ticker 600519.SS --date 2026-04-10 --provider anthropic --model MiniMax-M2.7-highspeed --base-url https://api.minimaxi.com/anthropic --selected-analysts market --analysis-prompt-style compact --timeout 45 --max-retries 0 --overall-timeout 120 --dump-raw-on-failure\nNot-tested: Full successful live-provider completion through Portfolio Manager after the post-research connection failure
This commit is contained in:
陈少杰 2026-04-14 03:49:33 +08:00
parent baf67dbd58
commit addc4a1e9c
12 changed files with 443 additions and 12 deletions

View File

@ -45,6 +45,7 @@ class LiveMode:
def _serialize_signal(self, *, ticker: str, date: str, signal) -> dict: def _serialize_signal(self, *, ticker: str, date: str, signal) -> dict:
metadata = getattr(signal, "metadata", {}) or {} metadata = getattr(signal, "metadata", {}) or {}
data_quality = metadata.get("data_quality") data_quality = metadata.get("data_quality")
research = metadata.get("research")
degradation = self._serialize_degradation(signal, data_quality) degradation = self._serialize_degradation(signal, data_quality)
return { return {
"contract_version": self._contract_version(signal), "contract_version": self._contract_version(signal),
@ -55,6 +56,7 @@ class LiveMode:
"error": None, "error": None,
"degradation": degradation, "degradation": degradation,
"data_quality": data_quality, "data_quality": data_quality,
"research": research,
} }
@staticmethod @staticmethod
@ -64,6 +66,11 @@ class LiveMode:
reason_codes.append(ReasonCode.BOTH_SIGNALS_UNAVAILABLE.value) reason_codes.append(ReasonCode.BOTH_SIGNALS_UNAVAILABLE.value)
source_diagnostics = dict(getattr(exc, "source_diagnostics", {}) or {}) source_diagnostics = dict(getattr(exc, "source_diagnostics", {}) or {})
data_quality = getattr(exc, "data_quality", None) data_quality = getattr(exc, "data_quality", None)
research = None
for diagnostic in source_diagnostics.values():
if isinstance(diagnostic, dict) and diagnostic.get("research") is not None:
research = diagnostic["research"]
break
return { return {
"contract_version": CONTRACT_VERSION, "contract_version": CONTRACT_VERSION,
"ticker": ticker, "ticker": ticker,
@ -81,6 +88,7 @@ class LiveMode:
"source_diagnostics": source_diagnostics, "source_diagnostics": source_diagnostics,
}, },
"data_quality": data_quality, "data_quality": data_quality,
"research": research,
} }
async def run_once(self, tickers: List[str], date: Optional[str] = None) -> List[dict]: async def run_once(self, tickers: List[str], date: Optional[str] = None) -> List[dict]:

View File

@ -16,6 +16,24 @@ def _build_data_quality(state: str, **details):
return payload return payload
def _extract_research_metadata(final_state: dict | None) -> dict | None:
if not isinstance(final_state, dict):
return None
debate_state = final_state.get("investment_debate_state") or {}
if not isinstance(debate_state, dict):
return None
keys = (
"research_status",
"research_mode",
"timed_out_nodes",
"degraded_reason",
"covered_dimensions",
"manager_confidence",
)
metadata = {key: debate_state.get(key) for key in keys if key in debate_state}
return metadata or None
class LLMRunner: class LLMRunner:
def __init__(self, config: OrchestratorConfig): def __init__(self, config: OrchestratorConfig):
self._config = config self._config = config
@ -91,6 +109,17 @@ class LLMRunner:
rating = processed_signal if isinstance(processed_signal, str) else str(processed_signal) rating = processed_signal if isinstance(processed_signal, str) else str(processed_signal)
direction, confidence = self._map_rating(rating) direction, confidence = self._map_rating(rating)
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
research_metadata = _extract_research_metadata(_final_state)
if research_metadata and research_metadata.get("research_status") != "full":
data_quality = _build_data_quality(
"research_degraded",
research_status=research_metadata.get("research_status"),
research_mode=research_metadata.get("research_mode"),
degraded_reason=research_metadata.get("degraded_reason"),
timed_out_nodes=research_metadata.get("timed_out_nodes"),
)
else:
data_quality = _build_data_quality("ok")
cache_data = { cache_data = {
"rating": rating, "rating": rating,
@ -99,7 +128,13 @@ class LLMRunner:
"timestamp": now.isoformat(), "timestamp": now.isoformat(),
"ticker": ticker, "ticker": ticker,
"date": date, "date": date,
"data_quality": _build_data_quality("ok"), "data_quality": data_quality,
"research": research_metadata,
"sample_quality": (
"degraded_research"
if research_metadata and research_metadata.get("research_status") != "full"
else "full_research"
),
} }
with open(cache_path, "w", encoding="utf-8") as f: with open(cache_path, "w", encoding="utf-8") as f:
json.dump(cache_data, f, ensure_ascii=False, indent=2) json.dump(cache_data, f, ensure_ascii=False, indent=2)

View File

@ -113,6 +113,8 @@ class TradingOrchestrator:
metadata["source_diagnostics"] = source_diagnostics metadata["source_diagnostics"] = source_diagnostics
if data_quality: if data_quality:
metadata["data_quality"] = data_quality metadata["data_quality"] = data_quality
if llm_sig is not None and llm_sig.metadata.get("research") is not None:
metadata["research"] = llm_sig.metadata.get("research")
final_signal.metadata = metadata final_signal.metadata = metadata
return final_signal return final_signal
@ -125,6 +127,9 @@ class TradingOrchestrator:
error = signal.metadata.get("error") error = signal.metadata.get("error")
if error: if error:
diagnostic["error"] = error diagnostic["error"] = error
research = signal.metadata.get("research")
if research is not None:
diagnostic["research"] = research
return diagnostic return diagnostic
@staticmethod @staticmethod

View File

@ -23,6 +23,18 @@ _PHASE_MAP = {
"Portfolio Manager": "portfolio", "Portfolio Manager": "portfolio",
} }
_LLM_KIND_MAP = {
"Market Analyst": "quick",
"Bull Researcher": "quick",
"Bear Researcher": "quick",
"Research Manager": "deep",
"Trader": "quick",
"Aggressive Analyst": "quick",
"Conservative Analyst": "quick",
"Neutral Analyst": "quick",
"Portfolio Manager": "deep",
}
def build_parser() -> argparse.ArgumentParser: def build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="Profile TradingAgents graph stage timings.") parser = argparse.ArgumentParser(description="Profile TradingAgents graph stage timings.")
@ -37,6 +49,7 @@ def build_parser() -> argparse.ArgumentParser:
parser.add_argument("--selected-analysts", default="market") parser.add_argument("--selected-analysts", default="market")
parser.add_argument("--overall-timeout", type=int, default=120) parser.add_argument("--overall-timeout", type=int, default=120)
parser.add_argument("--dump-dir", default="orchestrator/profile_runs") parser.add_argument("--dump-dir", default="orchestrator/profile_runs")
parser.add_argument("--dump-raw-on-failure", action="store_true")
return parser return parser
@ -44,6 +57,33 @@ class _ProfileTimeout(Exception):
pass pass
def _jsonable(value):
if isinstance(value, (str, int, float, bool)) or value is None:
return value
if isinstance(value, dict):
return {str(k): _jsonable(v) for k, v in value.items()}
if isinstance(value, (list, tuple)):
return [_jsonable(item) for item in value]
return repr(value)
def _extract_research_state(event: dict) -> tuple[str | None, str | None, int | None, int | None]:
node_payload = next(iter(event.values()), {})
if not isinstance(node_payload, dict):
return None, None, None, None
debate_state = node_payload.get("investment_debate_state") or {}
if not isinstance(debate_state, dict):
return None, None, None, None
history = debate_state.get("history") or ""
current = debate_state.get("current_response") or ""
return (
debate_state.get("research_status"),
debate_state.get("degraded_reason"),
len(history),
len(current),
)
def main() -> None: def main() -> None:
args = build_parser().parse_args() args = build_parser().parse_args()
selected_analysts = [item.strip() for item in args.selected_analysts.split(",") if item.strip()] selected_analysts = [item.strip() for item in args.selected_analysts.split(",") if item.strip()]
@ -66,11 +106,12 @@ def main() -> None:
node_timings = [] node_timings = []
phase_totals = defaultdict(float) phase_totals = defaultdict(float)
raw_events = []
started_at = time.monotonic() started_at = time.monotonic()
last_at = started_at last_at = started_at
run_id = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
dump_dir = Path(args.dump_dir) dump_dir = Path(args.dump_dir)
dump_dir.mkdir(parents=True, exist_ok=True) dump_dir.mkdir(parents=True, exist_ok=True)
run_id = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
dump_path = dump_dir / f"{args.ticker.replace('/', '_')}_{args.date}_{run_id}.json" dump_path = dump_dir / f"{args.ticker.replace('/', '_')}_{args.date}_{run_id}.json"
def alarm_handler(signum, frame): def alarm_handler(signum, frame):
@ -84,14 +125,26 @@ def main() -> None:
now = time.monotonic() now = time.monotonic()
nodes = list(event.keys()) nodes = list(event.keys())
phases = sorted({_PHASE_MAP.get(node, "unknown") for node in nodes}) phases = sorted({_PHASE_MAP.get(node, "unknown") for node in nodes})
llm_kinds = sorted({_LLM_KIND_MAP.get(node, "unknown") for node in nodes})
delta = round(now - last_at, 3) delta = round(now - last_at, 3)
research_status, degraded_reason, history_len, response_len = _extract_research_state(event)
entry = { entry = {
"run_id": run_id,
"nodes": nodes, "nodes": nodes,
"phases": phases, "phases": phases,
"delta_seconds": delta, "llm_kinds": llm_kinds,
"elapsed_seconds": round(now - started_at, 3), "start_at": round(last_at - started_at, 3),
"end_at": round(now - started_at, 3),
"elapsed_ms": int(delta * 1000),
"selected_analysts": selected_analysts,
"analysis_prompt_style": args.analysis_prompt_style,
"research_status": research_status,
"degraded_reason": degraded_reason,
"history_len": history_len,
"response_len": response_len,
} }
node_timings.append(entry) node_timings.append(entry)
raw_events.append(_jsonable(event))
for phase in phases: for phase in phases:
phase_totals[phase] += delta phase_totals[phase] += delta
last_at = now last_at = now
@ -105,18 +158,22 @@ def main() -> None:
"node_timings": node_timings, "node_timings": node_timings,
"phase_totals_seconds": {key: round(value, 3) for key, value in phase_totals.items()}, "phase_totals_seconds": {key: round(value, 3) for key, value in phase_totals.items()},
"dump_path": str(dump_path), "dump_path": str(dump_path),
"raw_events": raw_events if args.dump_raw_on_failure else [],
} }
except Exception as exc: except Exception as exc:
payload = { payload = {
"run_id": run_id,
"status": "error", "status": "error",
"ticker": args.ticker, "ticker": args.ticker,
"date": args.date, "date": args.date,
"selected_analysts": selected_analysts, "selected_analysts": selected_analysts,
"analysis_prompt_style": args.analysis_prompt_style, "analysis_prompt_style": args.analysis_prompt_style,
"error": str(exc), "error": str(exc),
"exception_type": type(exc).__name__,
"node_timings": node_timings, "node_timings": node_timings,
"phase_totals_seconds": {key: round(value, 3) for key, value in phase_totals.items()}, "phase_totals_seconds": {key: round(value, 3) for key, value in phase_totals.items()},
"dump_path": str(dump_path), "dump_path": str(dump_path),
"raw_events": raw_events,
} }
finally: finally:
signal.alarm(0) signal.alarm(0)

View File

@ -42,6 +42,14 @@ def test_live_mode_serializes_degraded_contract_shape():
metadata={ metadata={
"contract_version": "v1alpha1", "contract_version": "v1alpha1",
"data_quality": {"state": "stale_data", "source": "quant"}, "data_quality": {"state": "stale_data", "source": "quant"},
"research": {
"research_status": "degraded",
"research_mode": "degraded_synthesis",
"timed_out_nodes": ["Bull Researcher"],
"degraded_reason": "bull_researcher_timeout",
"covered_dimensions": ["market"],
"manager_confidence": None,
},
"source_diagnostics": { "source_diagnostics": {
"quant": {"reason_code": ReasonCode.STALE_DATA.value} "quant": {"reason_code": ReasonCode.STALE_DATA.value}
}, },
@ -75,6 +83,14 @@ def test_live_mode_serializes_degraded_contract_shape():
}, },
}, },
"data_quality": {"state": "stale_data", "source": "quant"}, "data_quality": {"state": "stale_data", "source": "quant"},
"research": {
"research_status": "degraded",
"research_mode": "degraded_synthesis",
"timed_out_nodes": ["Bull Researcher"],
"degraded_reason": "bull_researcher_timeout",
"covered_dimensions": ["market"],
"manager_confidence": None,
},
} }
] ]
@ -86,7 +102,19 @@ def test_live_mode_serializes_failure_contract_shape():
("AAPL", "2026-04-11"): CombinedSignalFailure( ("AAPL", "2026-04-11"): CombinedSignalFailure(
"both quant and llm signals are None", "both quant and llm signals are None",
reason_codes=(ReasonCode.BOTH_SIGNALS_UNAVAILABLE.value, ReasonCode.PROVIDER_MISMATCH.value), reason_codes=(ReasonCode.BOTH_SIGNALS_UNAVAILABLE.value, ReasonCode.PROVIDER_MISMATCH.value),
source_diagnostics={"llm": {"reason_code": ReasonCode.PROVIDER_MISMATCH.value}}, source_diagnostics={
"llm": {
"reason_code": ReasonCode.PROVIDER_MISMATCH.value,
"research": {
"research_status": "failed",
"research_mode": "degraded_synthesis",
"timed_out_nodes": ["Bull Researcher"],
"degraded_reason": "bull_researcher_connectionerror",
"covered_dimensions": ["market"],
"manager_confidence": None,
},
}
},
data_quality={"state": "provider_mismatch", "source": "llm"}, data_quality={"state": "provider_mismatch", "source": "llm"},
) )
} }
@ -114,9 +142,27 @@ def test_live_mode_serializes_failure_contract_shape():
ReasonCode.PROVIDER_MISMATCH.value, ReasonCode.PROVIDER_MISMATCH.value,
], ],
"source_diagnostics": { "source_diagnostics": {
"llm": {"reason_code": ReasonCode.PROVIDER_MISMATCH.value}, "llm": {
"reason_code": ReasonCode.PROVIDER_MISMATCH.value,
"research": {
"research_status": "failed",
"research_mode": "degraded_synthesis",
"timed_out_nodes": ["Bull Researcher"],
"degraded_reason": "bull_researcher_connectionerror",
"covered_dimensions": ["market"],
"manager_confidence": None,
},
},
}, },
}, },
"data_quality": {"state": "provider_mismatch", "source": "llm"}, "data_quality": {"state": "provider_mismatch", "source": "llm"},
"research": {
"research_status": "failed",
"research_mode": "degraded_synthesis",
"timed_out_nodes": ["Bull Researcher"],
"degraded_reason": "bull_researcher_connectionerror",
"covered_dimensions": ["market"],
"manager_confidence": None,
},
} }
] ]

View File

@ -99,3 +99,29 @@ def test_get_signal_returns_provider_mismatch_before_graph_init(tmp_path):
assert signal.degraded is True assert signal.degraded is True
assert signal.reason_code == ReasonCode.PROVIDER_MISMATCH.value assert signal.reason_code == ReasonCode.PROVIDER_MISMATCH.value
assert signal.metadata["data_quality"]["state"] == "provider_mismatch" assert signal.metadata["data_quality"]["state"] == "provider_mismatch"
def test_get_signal_persists_research_provenance_on_success(monkeypatch, tmp_path):
class SuccessfulGraph:
def propagate(self, ticker, date):
return {
"investment_debate_state": {
"research_status": "degraded",
"research_mode": "degraded_synthesis",
"timed_out_nodes": ["Bull Researcher"],
"degraded_reason": "bull_researcher_timeout",
"covered_dimensions": ["market"],
"manager_confidence": None,
}
}, "BUY"
cfg = OrchestratorConfig(cache_dir=str(tmp_path))
runner = LLMRunner(cfg)
monkeypatch.setattr(runner, "_get_graph", lambda: SuccessfulGraph())
signal = runner.get_signal("AAPL", "2024-01-02")
assert signal.degraded is False
assert signal.metadata["research"]["research_status"] == "degraded"
assert signal.metadata["sample_quality"] == "degraded_research"
assert signal.metadata["data_quality"]["state"] == "research_degraded"

View File

@ -1,10 +1,10 @@
from typing import Annotated from typing import Annotated, Optional
from typing_extensions import TypedDict from typing_extensions import NotRequired, TypedDict
from langgraph.graph import MessagesState from langgraph.graph import MessagesState
# Researcher team state # Researcher team state
class InvestDebateState(TypedDict): class InvestDebateState(TypedDict, total=False):
bull_history: Annotated[ bull_history: Annotated[
str, "Bullish Conversation history" str, "Bullish Conversation history"
] # Bullish Conversation history ] # Bullish Conversation history
@ -15,6 +15,12 @@ class InvestDebateState(TypedDict):
current_response: Annotated[str, "Latest response"] # Last response current_response: Annotated[str, "Latest response"] # Last response
judge_decision: Annotated[str, "Final judge decision"] # Last response judge_decision: Annotated[str, "Final judge decision"] # Last response
count: Annotated[int, "Length of the current conversation"] # Conversation length count: Annotated[int, "Length of the current conversation"] # Conversation length
research_status: NotRequired[Annotated[str, "Research stage status: full/degraded/failed"]]
research_mode: NotRequired[Annotated[str, "Research mode: debate/degraded_synthesis"]]
timed_out_nodes: NotRequired[Annotated[list[str], "Research nodes that timed out"]]
degraded_reason: NotRequired[Annotated[Optional[str], "Research degradation reason"]]
covered_dimensions: NotRequired[Annotated[list[str], "Research dimensions covered so far"]]
manager_confidence: NotRequired[Annotated[Optional[float], "Research manager confidence"]]
# Risk management team state # Risk management team state

View File

@ -24,6 +24,7 @@ DEFAULT_CONFIG = {
"max_debate_rounds": 1, "max_debate_rounds": 1,
"max_risk_discuss_rounds": 1, "max_risk_discuss_rounds": 1,
"max_recur_limit": 100, "max_recur_limit": 100,
"research_node_timeout_secs": 30.0,
# Data vendor configuration # Data vendor configuration
# Category-level configuration (default for all tools in category) # Category-level configuration (default for all tools in category)
"data_vendors": { "data_vendors": {

View File

@ -31,6 +31,12 @@ class Propagator:
"current_response": "", "current_response": "",
"judge_decision": "", "judge_decision": "",
"count": 0, "count": 0,
"research_status": "full",
"research_mode": "debate",
"timed_out_nodes": [],
"degraded_reason": None,
"covered_dimensions": [],
"manager_confidence": None,
} }
), ),
"risk_debate_state": RiskDebateState( "risk_debate_state": RiskDebateState(

View File

@ -1,5 +1,7 @@
# TradingAgents/graph/setup.py # TradingAgents/graph/setup.py
import concurrent.futures
import time
from typing import Any, Dict from typing import Any, Dict
from langgraph.graph import END, START, StateGraph from langgraph.graph import END, START, StateGraph
from langgraph.prebuilt import ToolNode from langgraph.prebuilt import ToolNode
@ -24,6 +26,7 @@ class GraphSetup:
invest_judge_memory, invest_judge_memory,
portfolio_manager_memory, portfolio_manager_memory,
conditional_logic: ConditionalLogic, conditional_logic: ConditionalLogic,
research_node_timeout_secs: float = 30.0,
): ):
"""Initialize with required components.""" """Initialize with required components."""
self.quick_thinking_llm = quick_thinking_llm self.quick_thinking_llm = quick_thinking_llm
@ -35,6 +38,7 @@ class GraphSetup:
self.invest_judge_memory = invest_judge_memory self.invest_judge_memory = invest_judge_memory
self.portfolio_manager_memory = portfolio_manager_memory self.portfolio_manager_memory = portfolio_manager_memory
self.conditional_logic = conditional_logic self.conditional_logic = conditional_logic
self.research_node_timeout_secs = research_node_timeout_secs
def setup_graph( def setup_graph(
self, selected_analysts=["market", "social", "news", "fundamentals"] self, selected_analysts=["market", "social", "news", "fundamentals"]
@ -85,13 +89,16 @@ class GraphSetup:
tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"] tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"]
# Create researcher and manager nodes # Create researcher and manager nodes
bull_researcher_node = create_bull_researcher( bull_researcher_node = self._guard_research_node(
"Bull Researcher",
self.quick_thinking_llm, self.bull_memory self.quick_thinking_llm, self.bull_memory
) )
bear_researcher_node = create_bear_researcher( bear_researcher_node = self._guard_research_node(
"Bear Researcher",
self.quick_thinking_llm, self.bear_memory self.quick_thinking_llm, self.bear_memory
) )
research_manager_node = create_research_manager( research_manager_node = self._guard_research_node(
"Research Manager",
self.deep_thinking_llm, self.invest_judge_memory self.deep_thinking_llm, self.invest_judge_memory
) )
trader_node = create_trader(self.quick_thinking_llm, self.trader_memory) trader_node = create_trader(self.quick_thinking_llm, self.trader_memory)
@ -199,3 +206,109 @@ class GraphSetup:
# Compile and return # Compile and return
return workflow.compile() return workflow.compile()
def _guard_research_node(self, node_name: str, llm: Any, memory):
if node_name == "Bull Researcher":
node = create_bull_researcher(llm, memory)
dimension = "bull"
elif node_name == "Bear Researcher":
node = create_bear_researcher(llm, memory)
dimension = "bear"
else:
node = create_research_manager(llm, memory)
dimension = "manager"
def wrapped(state):
started_at = time.time()
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
future = executor.submit(node, state)
try:
result = future.result(timeout=self.research_node_timeout_secs)
return self._apply_research_success(state, result, dimension)
except concurrent.futures.TimeoutError:
future.cancel()
executor.shutdown(wait=False, cancel_futures=True)
return self._apply_research_fallback(
state,
node_name=node_name,
dimension=dimension,
reason=f"{node_name.lower().replace(' ', '_')}_timeout",
started_at=started_at,
)
except Exception as exc:
executor.shutdown(wait=False, cancel_futures=True)
return self._apply_research_fallback(
state,
node_name=node_name,
dimension=dimension,
reason=f"{node_name.lower().replace(' ', '_')}_{type(exc).__name__.lower()}",
started_at=started_at,
)
finally:
executor.shutdown(wait=False, cancel_futures=True)
return wrapped
@staticmethod
def _provenance(state) -> dict:
debate_state = dict(state["investment_debate_state"])
return {
"research_status": debate_state.get("research_status", "full"),
"research_mode": debate_state.get("research_mode", "debate"),
"timed_out_nodes": list(debate_state.get("timed_out_nodes", [])),
"degraded_reason": debate_state.get("degraded_reason"),
"covered_dimensions": list(debate_state.get("covered_dimensions", [])),
"manager_confidence": debate_state.get("manager_confidence"),
}
def _apply_research_success(self, state, result: dict, dimension: str):
debate_state = dict(result.get("investment_debate_state") or state["investment_debate_state"])
provenance = self._provenance(state)
if dimension not in provenance["covered_dimensions"]:
provenance["covered_dimensions"].append(dimension)
if provenance["research_status"] == "full":
provenance["research_mode"] = "debate"
if dimension == "manager" and provenance["manager_confidence"] is None:
provenance["manager_confidence"] = 1.0 if provenance["research_status"] == "full" else 0.5
debate_state.update(provenance)
updated = dict(result)
updated["investment_debate_state"] = debate_state
return updated
def _apply_research_fallback(self, state, *, node_name: str, dimension: str, reason: str, started_at: float):
debate_state = dict(state["investment_debate_state"])
provenance = self._provenance(state)
provenance["research_status"] = "degraded"
provenance["research_mode"] = "degraded_synthesis"
provenance["degraded_reason"] = reason
if "timeout" in reason and node_name not in provenance["timed_out_nodes"]:
provenance["timed_out_nodes"].append(node_name)
elapsed_seconds = round(time.time() - started_at, 3)
if dimension == "manager":
provenance["manager_confidence"] = 0.0
fallback = (
"Recommendation: HOLD\n"
f"Top reasons: research degraded at {node_name} ({reason}); use partial research context cautiously.\n"
f"Simple execution plan: keep sizing conservative and wait for confirmation. Guard elapsed={elapsed_seconds}s."
)
debate_state["judge_decision"] = fallback
debate_state["current_response"] = fallback
debate_state.update(provenance)
return {
"investment_debate_state": debate_state,
"investment_plan": fallback,
}
prefix = "Bull Analyst" if dimension == "bull" else "Bear Analyst"
history_field = "bull_history" if dimension == "bull" else "bear_history"
degraded_argument = (
f"{prefix}: [DEGRADED] {node_name} unavailable ({reason}). "
f"Proceeding with partial research context. Guard elapsed={elapsed_seconds}s."
)
debate_state["history"] = debate_state.get("history", "") + "\n" + degraded_argument
debate_state[history_field] = debate_state.get(history_field, "") + "\n" + degraded_argument
debate_state["current_response"] = degraded_argument
debate_state["count"] = debate_state.get("count", 0) + 1
debate_state.update(provenance)
return {"investment_debate_state": debate_state}

View File

@ -144,6 +144,7 @@ class TradingAgentsGraph:
self.invest_judge_memory, self.invest_judge_memory,
self.portfolio_manager_memory, self.portfolio_manager_memory,
self.conditional_logic, self.conditional_logic,
research_node_timeout_secs=float(self.config.get("research_node_timeout_secs", 30.0)),
) )
self.propagator = Propagator() self.propagator = Propagator()

View File

@ -0,0 +1,127 @@
import time
import tradingagents.graph.setup as graph_setup_module
from tradingagents.graph.setup import GraphSetup
def _setup() -> GraphSetup:
return GraphSetup(
quick_thinking_llm=None,
deep_thinking_llm=None,
tool_nodes={},
bull_memory=None,
bear_memory=None,
trader_memory=None,
invest_judge_memory=None,
portfolio_manager_memory=None,
conditional_logic=None,
research_node_timeout_secs=0.01,
)
def test_manager_guard_fallback_marks_degraded_synthesis():
setup = _setup()
state = {
"investment_debate_state": {
"history": "Bull Analyst: case",
"bull_history": "Bull Analyst: case",
"bear_history": "",
"current_response": "Bull Analyst: case",
"judge_decision": "",
"count": 1,
"research_status": "full",
"research_mode": "debate",
"timed_out_nodes": [],
"degraded_reason": None,
"covered_dimensions": ["bull"],
"manager_confidence": None,
}
}
result = setup._apply_research_fallback(
state,
node_name="Research Manager",
dimension="manager",
reason="research_manager_timeout",
started_at=0.0,
)
debate = result["investment_debate_state"]
assert debate["research_status"] == "degraded"
assert debate["research_mode"] == "degraded_synthesis"
assert debate["timed_out_nodes"] == ["Research Manager"]
assert result["investment_plan"].startswith("Recommendation: HOLD")
def test_bull_guard_success_records_coverage():
setup = _setup()
state = {
"investment_debate_state": {
"history": "",
"bull_history": "",
"bear_history": "",
"current_response": "",
"judge_decision": "",
"count": 0,
"research_status": "full",
"research_mode": "debate",
"timed_out_nodes": [],
"degraded_reason": None,
"covered_dimensions": [],
"manager_confidence": None,
}
}
result = {
"investment_debate_state": {
"history": "Bull Analyst: ok",
"bull_history": "Bull Analyst: ok",
"bear_history": "",
"current_response": "Bull Analyst: ok",
"judge_decision": "",
"count": 1,
}
}
updated = setup._apply_research_success(state, result, dimension="bull")
debate = updated["investment_debate_state"]
assert debate["research_status"] == "full"
assert debate["research_mode"] == "debate"
assert debate["covered_dimensions"] == ["bull"]
def test_guard_timeout_returns_without_waiting_for_node_completion(monkeypatch):
def slow_bull(_llm, _memory):
def node(_state):
time.sleep(0.2)
return {"investment_debate_state": {"history": "", "bull_history": "", "bear_history": "", "current_response": "", "judge_decision": "", "count": 1}}
return node
monkeypatch.setattr(graph_setup_module, "create_bull_researcher", slow_bull)
setup = _setup()
wrapped = setup._guard_research_node("Bull Researcher", None, None)
state = {
"investment_debate_state": {
"history": "",
"bull_history": "",
"bear_history": "",
"current_response": "",
"judge_decision": "",
"count": 0,
"research_status": "full",
"research_mode": "debate",
"timed_out_nodes": [],
"degraded_reason": None,
"covered_dimensions": [],
"manager_confidence": None,
}
}
started = time.monotonic()
result = wrapped(state)
elapsed = time.monotonic() - started
assert elapsed < 0.1
debate = result["investment_debate_state"]
assert debate["research_status"] == "degraded"
assert debate["research_mode"] == "degraded_synthesis"
assert debate["timed_out_nodes"] == ["Bull Researcher"]