diff --git a/orchestrator/profile_stage_chain.py b/orchestrator/profile_stage_chain.py new file mode 100644 index 00000000..68fad753 --- /dev/null +++ b/orchestrator/profile_stage_chain.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +import argparse +import json +import signal +import time +from collections import defaultdict + +from tradingagents.graph.propagation import Propagator +from tradingagents.graph.trading_graph import TradingAgentsGraph + +_PHASE_MAP = { + "Market Analyst": "analyst", + "Bull Researcher": "research", + "Bear Researcher": "research", + "Research Manager": "research", + "Trader": "trading", + "Aggressive Analyst": "risk", + "Conservative Analyst": "risk", + "Neutral Analyst": "risk", + "Portfolio Manager": "portfolio", +} + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Profile TradingAgents graph stage timings.") + parser.add_argument("--ticker", required=True) + parser.add_argument("--date", required=True) + parser.add_argument("--provider", default="anthropic") + parser.add_argument("--model", default="MiniMax-M2.7-highspeed") + parser.add_argument("--base-url", default="https://api.minimaxi.com/anthropic") + parser.add_argument("--timeout", type=float, default=45.0) + parser.add_argument("--max-retries", type=int, default=0) + parser.add_argument("--analysis-prompt-style", default="compact") + parser.add_argument("--selected-analysts", default="market") + parser.add_argument("--overall-timeout", type=int, default=120) + return parser + + +class _ProfileTimeout(Exception): + pass + + +def main() -> None: + args = build_parser().parse_args() + selected_analysts = [item.strip() for item in args.selected_analysts.split(",") if item.strip()] + config = { + "llm_provider": args.provider, + "deep_think_llm": args.model, + "quick_think_llm": args.model, + "backend_url": args.base_url, + "selected_analysts": selected_analysts, + "analysis_prompt_style": args.analysis_prompt_style, + "llm_timeout": args.timeout, + "llm_max_retries": args.max_retries, + "max_debate_rounds": 1, + "max_risk_discuss_rounds": 1, + } + + graph = TradingAgentsGraph(selected_analysts=selected_analysts, config=config) + state = Propagator().create_initial_state(args.ticker, args.date) + config_kwargs = {"recursion_limit": 100, "max_concurrency": 1} + + node_timings = [] + phase_totals = defaultdict(float) + started_at = time.monotonic() + last_at = started_at + + def alarm_handler(signum, frame): + raise _ProfileTimeout(f"profiling timeout after {args.overall_timeout}s") + + signal.signal(signal.SIGALRM, alarm_handler) + signal.alarm(args.overall_timeout) + + try: + for event in graph.graph.stream(state, stream_mode="updates", config=config_kwargs): + now = time.monotonic() + nodes = list(event.keys()) + phases = sorted({_PHASE_MAP.get(node, "unknown") for node in nodes}) + delta = round(now - last_at, 3) + entry = { + "nodes": nodes, + "phases": phases, + "delta_seconds": delta, + "elapsed_seconds": round(now - started_at, 3), + } + node_timings.append(entry) + for phase in phases: + phase_totals[phase] += delta + last_at = now + + payload = { + "status": "ok", + "ticker": args.ticker, + "date": args.date, + "selected_analysts": selected_analysts, + "analysis_prompt_style": args.analysis_prompt_style, + "node_timings": node_timings, + "phase_totals_seconds": {key: round(value, 3) for key, value in phase_totals.items()}, + } + except Exception as exc: + payload = { + "status": "error", + "ticker": args.ticker, + "date": args.date, + "selected_analysts": selected_analysts, + "analysis_prompt_style": args.analysis_prompt_style, + "error": str(exc), + "node_timings": node_timings, + "phase_totals_seconds": {key: round(value, 3) for key, value in phase_totals.items()}, + } + finally: + signal.alarm(0) + + print(json.dumps(payload, ensure_ascii=False, indent=2)) + + +if __name__ == "__main__": + main() diff --git a/web_dashboard/backend/main.py b/web_dashboard/backend/main.py index a741b77e..9746c475 100644 --- a/web_dashboard/backend/main.py +++ b/web_dashboard/backend/main.py @@ -228,11 +228,20 @@ def _resolve_analysis_runtime_settings() -> dict: or os.environ.get("TRADINGAGENTS_MODEL") or defaults.get("quick_think_llm") ) + selected_analysts_raw = os.environ.get("TRADINGAGENTS_SELECTED_ANALYSTS", "market") + selected_analysts = [item.strip() for item in selected_analysts_raw.split(",") if item.strip()] + analysis_prompt_style = os.environ.get("TRADINGAGENTS_ANALYSIS_PROMPT_STYLE", "compact") + llm_timeout = float(os.environ.get("TRADINGAGENTS_LLM_TIMEOUT", "45")) + llm_max_retries = int(os.environ.get("TRADINGAGENTS_LLM_MAX_RETRIES", "0")) return { "llm_provider": provider, "backend_url": backend_url, "deep_think_llm": deep_model, "quick_think_llm": quick_model, + "selected_analysts": selected_analysts, + "analysis_prompt_style": analysis_prompt_style, + "llm_timeout": llm_timeout, + "llm_max_retries": llm_max_retries, "provider_api_key": _get_analysis_provider_api_key(provider, saved.get("api_key")), } @@ -247,6 +256,10 @@ def _build_analysis_request_context(request: Request, auth_key: Optional[str]): backend_url=settings["backend_url"], deep_think_llm=settings["deep_think_llm"], quick_think_llm=settings["quick_think_llm"], + selected_analysts=settings["selected_analysts"], + analysis_prompt_style=settings["analysis_prompt_style"], + llm_timeout=settings["llm_timeout"], + llm_max_retries=settings["llm_max_retries"], ) diff --git a/web_dashboard/backend/services/executor.py b/web_dashboard/backend/services/executor.py index 84431cbf..45c70ae2 100644 --- a/web_dashboard/backend/services/executor.py +++ b/web_dashboard/backend/services/executor.py @@ -60,7 +60,16 @@ if os.environ.get("TRADINGAGENTS_DEEP_MODEL"): trading_config["deep_think_llm"] = os.environ["TRADINGAGENTS_DEEP_MODEL"] if os.environ.get("TRADINGAGENTS_QUICK_MODEL"): trading_config["quick_think_llm"] = os.environ["TRADINGAGENTS_QUICK_MODEL"] - +if os.environ.get("TRADINGAGENTS_SELECTED_ANALYSTS"): + trading_config["selected_analysts"] = [ + item.strip() for item in os.environ["TRADINGAGENTS_SELECTED_ANALYSTS"].split(",") if item.strip() + ] +if os.environ.get("TRADINGAGENTS_ANALYSIS_PROMPT_STYLE"): + trading_config["analysis_prompt_style"] = os.environ["TRADINGAGENTS_ANALYSIS_PROMPT_STYLE"] +if os.environ.get("TRADINGAGENTS_LLM_TIMEOUT"): + trading_config["llm_timeout"] = float(os.environ["TRADINGAGENTS_LLM_TIMEOUT"]) +if os.environ.get("TRADINGAGENTS_LLM_MAX_RETRIES"): + trading_config["llm_max_retries"] = int(os.environ["TRADINGAGENTS_LLM_MAX_RETRIES"]) print("STAGE:analysts", flush=True) print("STAGE:research", flush=True) @@ -305,6 +314,14 @@ class LegacySubprocessAnalysisExecutor: clean_env["TRADINGAGENTS_DEEP_MODEL"] = request_context.deep_think_llm if request_context.quick_think_llm: clean_env["TRADINGAGENTS_QUICK_MODEL"] = request_context.quick_think_llm + if request_context.selected_analysts: + clean_env["TRADINGAGENTS_SELECTED_ANALYSTS"] = ",".join(request_context.selected_analysts) + if request_context.analysis_prompt_style: + clean_env["TRADINGAGENTS_ANALYSIS_PROMPT_STYLE"] = request_context.analysis_prompt_style + if request_context.llm_timeout is not None: + clean_env["TRADINGAGENTS_LLM_TIMEOUT"] = str(request_context.llm_timeout) + if request_context.llm_max_retries is not None: + clean_env["TRADINGAGENTS_LLM_MAX_RETRIES"] = str(request_context.llm_max_retries) for env_name in self._provider_api_env_names(llm_provider): if analysis_api_key: clean_env[env_name] = analysis_api_key diff --git a/web_dashboard/backend/services/request_context.py b/web_dashboard/backend/services/request_context.py index b3824701..b06d25db 100644 --- a/web_dashboard/backend/services/request_context.py +++ b/web_dashboard/backend/services/request_context.py @@ -24,6 +24,10 @@ class RequestContext: backend_url: Optional[str] = None deep_think_llm: Optional[str] = None quick_think_llm: Optional[str] = None + selected_analysts: tuple[str, ...] = () + analysis_prompt_style: Optional[str] = None + llm_timeout: Optional[float] = None + llm_max_retries: Optional[int] = None client_host: Optional[str] = None is_local: bool = False metadata: dict[str, str] = field(default_factory=dict) @@ -38,6 +42,10 @@ def build_request_context( backend_url: Optional[str] = None, deep_think_llm: Optional[str] = None, quick_think_llm: Optional[str] = None, + selected_analysts: Optional[list[str] | tuple[str, ...]] = None, + analysis_prompt_style: Optional[str] = None, + llm_timeout: Optional[float] = None, + llm_max_retries: Optional[int] = None, request_id: Optional[str] = None, contract_version: str = CONTRACT_VERSION, executor_type: str = DEFAULT_EXECUTOR_TYPE, @@ -56,6 +64,10 @@ def build_request_context( backend_url=backend_url, deep_think_llm=deep_think_llm, quick_think_llm=quick_think_llm, + selected_analysts=tuple(selected_analysts or ()), + analysis_prompt_style=analysis_prompt_style, + llm_timeout=llm_timeout, + llm_max_retries=llm_max_retries, client_host=client_host, is_local=is_local, metadata=dict(metadata or {}), diff --git a/web_dashboard/backend/tests/test_executors.py b/web_dashboard/backend/tests/test_executors.py index fe6b4df1..623a9210 100644 --- a/web_dashboard/backend/tests/test_executors.py +++ b/web_dashboard/backend/tests/test_executors.py @@ -197,6 +197,10 @@ def test_executor_injects_provider_specific_env(monkeypatch): backend_url="https://api.openai.com/v1", deep_think_llm="gpt-5.4", quick_think_llm="gpt-5.4-mini", + selected_analysts=["market"], + analysis_prompt_style="compact", + llm_timeout=45, + llm_max_retries=0, ), ) @@ -205,6 +209,10 @@ def test_executor_injects_provider_specific_env(monkeypatch): assert captured["env"]["TRADINGAGENTS_LLM_PROVIDER"] == "openai" assert captured["env"]["TRADINGAGENTS_BACKEND_URL"] == "https://api.openai.com/v1" assert captured["env"]["OPENAI_API_KEY"] == "provider-key" + assert captured["env"]["TRADINGAGENTS_SELECTED_ANALYSTS"] == "market" + assert captured["env"]["TRADINGAGENTS_ANALYSIS_PROMPT_STYLE"] == "compact" + assert captured["env"]["TRADINGAGENTS_LLM_TIMEOUT"] == "45" + assert captured["env"]["TRADINGAGENTS_LLM_MAX_RETRIES"] == "0" assert "ANTHROPIC_API_KEY" not in captured["env"] diff --git a/web_dashboard/backend/tests/test_services_migration.py b/web_dashboard/backend/tests/test_services_migration.py index f2e9df30..35bfa9d9 100644 --- a/web_dashboard/backend/tests/test_services_migration.py +++ b/web_dashboard/backend/tests/test_services_migration.py @@ -52,6 +52,10 @@ def test_build_request_context_defaults(): backend_url="https://api.minimaxi.com/anthropic", deep_think_llm="MiniMax-M2.7-highspeed", quick_think_llm="MiniMax-M2.7-highspeed", + selected_analysts=["market"], + analysis_prompt_style="compact", + llm_timeout=45, + llm_max_retries=0, metadata={"source": "test"}, ) @@ -59,6 +63,10 @@ def test_build_request_context_defaults(): assert context.provider_api_key == "provider-secret" assert context.llm_provider == "anthropic" assert context.backend_url == "https://api.minimaxi.com/anthropic" + assert context.selected_analysts == ("market",) + assert context.analysis_prompt_style == "compact" + assert context.llm_timeout == 45 + assert context.llm_max_retries == 0 assert context.request_id assert context.contract_version == "v1alpha1" assert context.executor_type == "legacy_subprocess" @@ -225,6 +233,10 @@ def test_analysis_service_start_analysis_uses_executor(tmp_path): provider_api_key="provider-secret", llm_provider="anthropic", backend_url="https://api.minimaxi.com/anthropic", + selected_analysts=["market"], + analysis_prompt_style="compact", + llm_timeout=45, + llm_max_retries=0, ), broadcast_progress=_broadcast, )