diff --git a/orchestrator/profile_ab.py b/orchestrator/profile_ab.py new file mode 100644 index 00000000..11917425 --- /dev/null +++ b/orchestrator/profile_ab.py @@ -0,0 +1,164 @@ +from __future__ import annotations + +import argparse +import json +from collections import Counter +from pathlib import Path +from statistics import median + +AB_SCHEMA_VERSION = "tradingagents.profile_ab.v1alpha1" + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Compare TradingAgents stage-profile traces for a minimal A/B workflow.", + ) + parser.add_argument("--a", nargs="+", required=True, help="Trace file(s) or directories for cohort A") + parser.add_argument("--b", nargs="+", required=True, help="Trace file(s) or directories for cohort B") + parser.add_argument("--label-a", default="A") + parser.add_argument("--label-b", default="B") + parser.add_argument("--output", help="Optional path to write the comparison JSON") + return parser + + +def _expand_inputs(items: list[str]) -> list[Path]: + files: list[Path] = [] + for item in items: + path = Path(item) + if path.is_dir(): + files.extend(sorted(candidate for candidate in path.glob("*.json") if candidate.is_file())) + elif path.is_file(): + files.append(path) + return files + + +def _load_trace(path: Path) -> dict: + data = json.loads(path.read_text()) + if not isinstance(data, dict): + raise ValueError(f"trace at {path} must be a JSON object") + payload = dict(data) + payload.setdefault("_source_path", str(path)) + return payload + + +def _phase_totals_ms(trace: dict) -> dict[str, int]: + summary = trace.get("summary") or {} + phase_totals = summary.get("phase_totals_seconds") or trace.get("phase_totals_seconds") or {} + return {str(key): int(round(float(value) * 1000)) for key, value in phase_totals.items()} + + +def summarize_traces(traces: list[dict], label: str) -> dict: + run_count = len(traces) + ok_runs = [trace for trace in traces if trace.get("status") == "ok"] + degraded_runs = [ + trace for trace in traces + if ((trace.get("summary") or {}).get("final_research_status") not in (None, "full")) + ] + total_elapsed = [int((trace.get("summary") or {}).get("total_elapsed_ms", 0)) for trace in ok_runs] + event_counts = [int((trace.get("summary") or {}).get("event_count", 0)) for trace in ok_runs] + status_counts = Counter(str(trace.get("status") or "unknown") for trace in traces) + schema_versions = sorted({str(trace.get("trace_schema_version") or "unknown") for trace in traces}) + source_files = sorted(str(trace.get("_source_path")) for trace in traces if trace.get("_source_path")) + + phase_values: dict[str, list[int]] = {} + for trace in ok_runs: + for phase, elapsed_ms in _phase_totals_ms(trace).items(): + phase_values.setdefault(phase, []).append(elapsed_ms) + + phase_medians = {phase: int(median(values)) for phase, values in sorted(phase_values.items()) if values} + variants = sorted({str(trace.get("variant_label") or label) for trace in traces}) + return { + "label": label, + "run_count": run_count, + "ok_count": len(ok_runs), + "error_count": run_count - len(ok_runs), + "degraded_run_count": len(degraded_runs), + "variants": variants, + "status_counts": dict(sorted(status_counts.items())), + "trace_schema_versions": schema_versions, + "source_files": source_files, + "median_total_elapsed_ms": int(median(total_elapsed)) if total_elapsed else None, + "median_event_count": int(median(event_counts)) if event_counts else None, + "median_phase_elapsed_ms": phase_medians, + } + + +def compare_summaries(summary_a: dict, summary_b: dict) -> dict: + total_a = summary_a.get("median_total_elapsed_ms") + total_b = summary_b.get("median_total_elapsed_ms") + degraded_a = summary_a.get("degraded_run_count", 0) + degraded_b = summary_b.get("degraded_run_count", 0) + error_a = summary_a.get("error_count", 0) + error_b = summary_b.get("error_count", 0) + + faster = None + if total_a is not None and total_b is not None: + if total_a < total_b: + faster = summary_a["label"] + elif total_b < total_a: + faster = summary_b["label"] + + lower_degradation = None + if degraded_a < degraded_b: + lower_degradation = summary_a["label"] + elif degraded_b < degraded_a: + lower_degradation = summary_b["label"] + + lower_error_rate = None + if error_a < error_b: + lower_error_rate = summary_a["label"] + elif error_b < error_a: + lower_error_rate = summary_b["label"] + + recommended = None + if faster == summary_a["label"] and lower_degradation in (None, summary_a["label"]) and lower_error_rate in (None, summary_a["label"]): + recommended = summary_a["label"] + elif faster == summary_b["label"] and lower_degradation in (None, summary_b["label"]) and lower_error_rate in (None, summary_b["label"]): + recommended = summary_b["label"] + elif lower_degradation == summary_a["label"] and total_a == total_b and lower_error_rate in (None, summary_a["label"]): + recommended = summary_a["label"] + elif lower_degradation == summary_b["label"] and total_a == total_b and lower_error_rate in (None, summary_b["label"]): + recommended = summary_b["label"] + + return { + "faster_label": faster, + "lower_degradation_label": lower_degradation, + "lower_error_rate_label": lower_error_rate, + "recommended_label": recommended, + } + + +def build_comparison(traces_a: list[dict], traces_b: list[dict], *, label_a: str, label_b: str) -> dict: + summary_a = summarize_traces(traces_a, label_a) + summary_b = summarize_traces(traces_b, label_b) + return { + "schema_version": AB_SCHEMA_VERSION, + "cohorts": { + label_a: summary_a, + label_b: summary_b, + }, + "comparison": compare_summaries(summary_a, summary_b), + } + + +def main() -> None: + args = build_parser().parse_args() + files_a = _expand_inputs(args.a) + files_b = _expand_inputs(args.b) + if not files_a: + raise SystemExit("no trace files found for cohort A") + if not files_b: + raise SystemExit("no trace files found for cohort B") + + traces_a = [_load_trace(path) for path in files_a] + traces_b = [_load_trace(path) for path in files_b] + payload = build_comparison(traces_a, traces_b, label_a=args.label_a, label_b=args.label_b) + + rendered = json.dumps(payload, ensure_ascii=False, indent=2) + if args.output: + Path(args.output).write_text(rendered) + print(rendered) + + +if __name__ == "__main__": + main() diff --git a/orchestrator/tests/test_profile_ab.py b/orchestrator/tests/test_profile_ab.py new file mode 100644 index 00000000..34fbdada --- /dev/null +++ b/orchestrator/tests/test_profile_ab.py @@ -0,0 +1,58 @@ +from orchestrator.profile_ab import build_comparison +from orchestrator.profile_trace_utils import build_trace_summary + + +def test_build_trace_summary_counts_degraded_events(): + summary = build_trace_summary( + [ + {"nodes": ["Market Analyst"], "elapsed_ms": 110, "research_status": None, "degraded_reason": None}, + {"nodes": ["Bull Researcher"], "elapsed_ms": 220, "research_status": "degraded", "degraded_reason": "bull_timeout"}, + ], + {"analyst": 0.11, "research": 0.22}, + ) + + assert summary["event_count"] == 2 + assert summary["total_elapsed_ms"] == 330 + assert summary["degraded_event_count"] == 1 + assert summary["final_research_status"] == "degraded" + assert summary["node_hit_count"]["Bull Researcher"] == 1 + + +def test_build_comparison_prefers_faster_less_degraded_cohort(): + traces_a = [ + { + "status": "ok", + "trace_schema_version": "tradingagents.profile_trace.v1alpha1", + "_source_path": "/tmp/a.json", + "variant_label": "compact", + "summary": { + "total_elapsed_ms": 450, + "event_count": 4, + "final_research_status": "full", + "phase_totals_seconds": {"research": 0.22, "risk": 0.10}, + }, + } + ] + traces_b = [ + { + "status": "ok", + "trace_schema_version": "tradingagents.profile_trace.v1alpha1", + "_source_path": "/tmp/b.json", + "variant_label": "verbose", + "summary": { + "total_elapsed_ms": 700, + "event_count": 5, + "final_research_status": "degraded", + "phase_totals_seconds": {"research": 0.45, "risk": 0.15}, + }, + } + ] + + payload = build_comparison(traces_a, traces_b, label_a="A", label_b="B") + + assert payload["cohorts"]["A"]["median_total_elapsed_ms"] == 450 + assert payload["cohorts"]["A"]["trace_schema_versions"] == ["tradingagents.profile_trace.v1alpha1"] + assert payload["cohorts"]["B"]["degraded_run_count"] == 1 + assert payload["comparison"]["faster_label"] == "A" + assert payload["comparison"]["lower_error_rate_label"] is None + assert payload["comparison"]["recommended_label"] == "A"