Make A/B trace comparisons easier to trust during profiling

The minimal offline harness now carries forward source-file and trace-schema
metadata, and it can break ties using error counts instead of only elapsed
runtime and degraded-research totals. This keeps Phase 1-4 profile comparisons
self-describing when multiple dumps are aggregated.

Constraint: Keep the harness offline and avoid changing the default runtime path
Rejected: Add a live dual-run executor | would couple profiling to external LLM calls and increase risk
Confidence: high
Scope-risk: narrow
Directive: Preserve the trace dump shape as the source of truth for future comparison tooling
Tested: uv run python inline assertions for orchestrator.tests.test_profile_ab
Tested: uv run python CLI smoke test for orchestrator.profile_ab with temp traces
Tested: uv run python -m compileall orchestrator/profile_stage_chain.py orchestrator/profile_trace_utils.py orchestrator/profile_ab.py orchestrator/tests/test_profile_ab.py
This commit is contained in:
陈少杰 2026-04-14 04:50:56 +08:00
parent 5aa0091773
commit a81f825203
2 changed files with 222 additions and 0 deletions

164
orchestrator/profile_ab.py Normal file
View File

@ -0,0 +1,164 @@
from __future__ import annotations
import argparse
import json
from collections import Counter
from pathlib import Path
from statistics import median
# Schema identifier stamped into every comparison payload so downstream
# tooling can detect incompatible comparison-dump formats.
AB_SCHEMA_VERSION = "tradingagents.profile_ab.v1alpha1"
def build_parser() -> argparse.ArgumentParser:
    """Return the CLI parser for the minimal A/B trace comparison workflow.

    Flags: ``--a``/``--b`` take one or more trace files or directories per
    cohort (both required); ``--label-a``/``--label-b`` set display labels;
    ``--output`` optionally names a file for the comparison JSON.
    """
    ab_parser = argparse.ArgumentParser(
        description="Compare TradingAgents stage-profile traces for a minimal A/B workflow.",
    )
    ab_parser.add_argument(
        "--a",
        nargs="+",
        required=True,
        help="Trace file(s) or directories for cohort A",
    )
    ab_parser.add_argument(
        "--b",
        nargs="+",
        required=True,
        help="Trace file(s) or directories for cohort B",
    )
    ab_parser.add_argument("--label-a", default="A")
    ab_parser.add_argument("--label-b", default="B")
    ab_parser.add_argument("--output", help="Optional path to write the comparison JSON")
    return ab_parser
def _expand_inputs(items: list[str]) -> list[Path]:
files: list[Path] = []
for item in items:
path = Path(item)
if path.is_dir():
files.extend(sorted(candidate for candidate in path.glob("*.json") if candidate.is_file()))
elif path.is_file():
files.append(path)
return files
def _load_trace(path: Path) -> dict:
data = json.loads(path.read_text())
if not isinstance(data, dict):
raise ValueError(f"trace at {path} must be a JSON object")
payload = dict(data)
payload.setdefault("_source_path", str(path))
return payload
def _phase_totals_ms(trace: dict) -> dict[str, int]:
summary = trace.get("summary") or {}
phase_totals = summary.get("phase_totals_seconds") or trace.get("phase_totals_seconds") or {}
return {str(key): int(round(float(value) * 1000)) for key, value in phase_totals.items()}
def summarize_traces(traces: list[dict], label: str) -> dict:
    """Aggregate one cohort of trace dumps into a summary dict.

    Medians (total elapsed, event count, per-phase elapsed) are computed
    only over runs whose ``status`` is ``"ok"``; status, degradation, and
    error counts look at every run in the cohort.
    """
    def _summary(trace: dict) -> dict:
        return trace.get("summary") or {}

    successes = [trace for trace in traces if trace.get("status") == "ok"]
    # A run counts as degraded when its final research status is anything
    # other than missing or the full-quality marker.
    degraded = [
        trace for trace in traces
        if _summary(trace).get("final_research_status") not in (None, "full")
    ]
    elapsed_samples = [int(_summary(trace).get("total_elapsed_ms", 0)) for trace in successes]
    event_samples = [int(_summary(trace).get("event_count", 0)) for trace in successes]
    statuses = Counter(str(trace.get("status") or "unknown") for trace in traces)
    schemas = sorted({str(trace.get("trace_schema_version") or "unknown") for trace in traces})
    sources = sorted(str(trace.get("_source_path")) for trace in traces if trace.get("_source_path"))
    per_phase: dict[str, list[int]] = {}
    for trace in successes:
        seconds_by_phase = (
            _summary(trace).get("phase_totals_seconds")
            or trace.get("phase_totals_seconds")
            or {}
        )
        for phase_name, seconds in seconds_by_phase.items():
            per_phase.setdefault(str(phase_name), []).append(int(round(float(seconds) * 1000)))
    phase_median_ms = {phase: int(median(samples)) for phase, samples in sorted(per_phase.items()) if samples}
    variant_labels = sorted({str(trace.get("variant_label") or label) for trace in traces})
    return {
        "label": label,
        "run_count": len(traces),
        "ok_count": len(successes),
        "error_count": len(traces) - len(successes),
        "degraded_run_count": len(degraded),
        "variants": variant_labels,
        "status_counts": dict(sorted(statuses.items())),
        "trace_schema_versions": schemas,
        "source_files": sources,
        "median_total_elapsed_ms": int(median(elapsed_samples)) if elapsed_samples else None,
        "median_event_count": int(median(event_samples)) if event_samples else None,
        "median_phase_elapsed_ms": phase_median_ms,
    }
def compare_summaries(summary_a: dict, summary_b: dict) -> dict:
    """Decide which cohort wins on speed, degradation, and error count.

    ``faster_label`` is only awarded when both cohorts have a median total
    (i.e. at least one ok run each). ``recommended_label`` is set when one
    cohort is faster and no worse on degradation/errors, or — on a speed
    tie — strictly less degraded and no worse on errors; otherwise None.
    """
    name_a = summary_a["label"]
    name_b = summary_b["label"]

    def _lower(value_a, value_b):
        # Label owning the strictly smaller value; None on a tie.
        if value_a < value_b:
            return name_a
        if value_b < value_a:
            return name_b
        return None

    elapsed_a = summary_a.get("median_total_elapsed_ms")
    elapsed_b = summary_b.get("median_total_elapsed_ms")
    faster = None
    if elapsed_a is not None and elapsed_b is not None:
        faster = _lower(elapsed_a, elapsed_b)
    lower_degradation = _lower(
        summary_a.get("degraded_run_count", 0), summary_b.get("degraded_run_count", 0)
    )
    lower_error_rate = _lower(
        summary_a.get("error_count", 0), summary_b.get("error_count", 0)
    )

    recommended = None
    # Primary rule: the faster cohort wins if it is no worse elsewhere.
    for name in (name_a, name_b):
        if faster == name and lower_degradation in (None, name) and lower_error_rate in (None, name):
            recommended = name
            break
    # Tie-break on equal elapsed medians: less degradation wins if errors
    # are no worse.
    if recommended is None and elapsed_a == elapsed_b:
        for name in (name_a, name_b):
            if lower_degradation == name and lower_error_rate in (None, name):
                recommended = name
                break
    return {
        "faster_label": faster,
        "lower_degradation_label": lower_degradation,
        "lower_error_rate_label": lower_error_rate,
        "recommended_label": recommended,
    }
def build_comparison(traces_a: list[dict], traces_b: list[dict], *, label_a: str, label_b: str) -> dict:
    """Summarize both cohorts and attach a head-to-head comparison.

    Returns the versioned payload shape consumed by downstream tooling:
    ``schema_version`` / ``cohorts`` (keyed by label) / ``comparison``.
    """
    summary_for_a = summarize_traces(traces_a, label_a)
    summary_for_b = summarize_traces(traces_b, label_b)
    return {
        "schema_version": AB_SCHEMA_VERSION,
        "cohorts": {label_a: summary_for_a, label_b: summary_for_b},
        "comparison": compare_summaries(summary_for_a, summary_for_b),
    }
def main() -> None:
    """CLI entry point: load traces for both cohorts and print the comparison.

    Exits with an error message when either cohort yields no trace files;
    always prints the rendered JSON, and also writes it when ``--output``
    names a file.
    """
    args = build_parser().parse_args()
    files_for_a = _expand_inputs(args.a)
    if not files_for_a:
        raise SystemExit("no trace files found for cohort A")
    files_for_b = _expand_inputs(args.b)
    if not files_for_b:
        raise SystemExit("no trace files found for cohort B")
    comparison_payload = build_comparison(
        [_load_trace(path) for path in files_for_a],
        [_load_trace(path) for path in files_for_b],
        label_a=args.label_a,
        label_b=args.label_b,
    )
    rendered = json.dumps(comparison_payload, ensure_ascii=False, indent=2)
    if args.output:
        Path(args.output).write_text(rendered)
    print(rendered)


if __name__ == "__main__":
    main()

58
orchestrator/tests/test_profile_ab.py Normal file
View File

@ -0,0 +1,58 @@
from orchestrator.profile_ab import build_comparison
from orchestrator.profile_trace_utils import build_trace_summary
def test_build_trace_summary_counts_degraded_events():
    """One degraded researcher event is counted and drives the final status."""
    analyst_event = {"nodes": ["Market Analyst"], "elapsed_ms": 110, "research_status": None, "degraded_reason": None}
    degraded_event = {"nodes": ["Bull Researcher"], "elapsed_ms": 220, "research_status": "degraded", "degraded_reason": "bull_timeout"}
    phase_totals = {"analyst": 0.11, "research": 0.22}
    summary = build_trace_summary([analyst_event, degraded_event], phase_totals)
    assert summary["event_count"] == 2
    assert summary["total_elapsed_ms"] == 330
    assert summary["degraded_event_count"] == 1
    assert summary["final_research_status"] == "degraded"
    assert summary["node_hit_count"]["Bull Researcher"] == 1
def test_build_comparison_prefers_faster_less_degraded_cohort():
    """The faster, non-degraded cohort is recommended in a clean head-to-head."""
    def _trace(source, variant, elapsed_ms, events, research_status, phases):
        # Minimal ok-status trace dump in the profile_trace schema shape.
        return {
            "status": "ok",
            "trace_schema_version": "tradingagents.profile_trace.v1alpha1",
            "_source_path": source,
            "variant_label": variant,
            "summary": {
                "total_elapsed_ms": elapsed_ms,
                "event_count": events,
                "final_research_status": research_status,
                "phase_totals_seconds": phases,
            },
        }

    fast_compact = _trace("/tmp/a.json", "compact", 450, 4, "full", {"research": 0.22, "risk": 0.10})
    slow_verbose = _trace("/tmp/b.json", "verbose", 700, 5, "degraded", {"research": 0.45, "risk": 0.15})
    payload = build_comparison([fast_compact], [slow_verbose], label_a="A", label_b="B")
    cohorts = payload["cohorts"]
    assert cohorts["A"]["median_total_elapsed_ms"] == 450
    assert cohorts["A"]["trace_schema_versions"] == ["tradingagents.profile_trace.v1alpha1"]
    assert cohorts["B"]["degraded_run_count"] == 1
    verdict = payload["comparison"]
    assert verdict["faster_label"] == "A"
    assert verdict["lower_error_rate_label"] is None
    assert verdict["recommended_label"] == "A"