165 lines
6.4 KiB
Python
165 lines
6.4 KiB
Python
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
from collections import Counter
|
|
from pathlib import Path
|
|
from statistics import median
|
|
|
|
# Schema identifier written as "schema_version" into every comparison payload built by build_comparison.
AB_SCHEMA_VERSION = "tradingagents.profile_ab.v1alpha1"
|
|
|
|
|
|
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the A/B trace comparison.

    Options, in help order: ``--a`` / ``--b`` (one or more trace files or
    directories per cohort, required), ``--label-a`` / ``--label-b`` (cohort
    labels, defaulting to ``A`` / ``B``) and ``--output`` (optional JSON path).
    """
    cli = argparse.ArgumentParser(
        description="Compare TradingAgents stage-profile traces for a minimal A/B workflow.",
    )
    cohort_names = ("A", "B")
    # Input options first, then label options, so --help keeps the same order.
    for name in cohort_names:
        cli.add_argument(
            f"--{name.lower()}",
            nargs="+",
            required=True,
            help=f"Trace file(s) or directories for cohort {name}",
        )
    for name in cohort_names:
        cli.add_argument(f"--label-{name.lower()}", default=name)
    cli.add_argument("--output", help="Optional path to write the comparison JSON")
    return cli
|
|
|
|
|
|
def _expand_inputs(items: list[str]) -> list[Path]:
    """Turn CLI path arguments into a flat list of trace files.

    A directory contributes its immediate ``*.json`` entries that are regular
    files (non-recursive, sorted). A path that is itself a regular file is kept
    as given, whatever its extension. Anything else, such as a path that does
    not exist, is skipped without error. Input order is preserved across items.
    """
    expanded: list[Path] = []
    for raw in items:
        candidate = Path(raw)
        if candidate.is_dir():
            json_files = [entry for entry in candidate.glob("*.json") if entry.is_file()]
            json_files.sort()
            expanded += json_files
            continue
        if candidate.is_file():
            expanded.append(candidate)
    return expanded
|
|
|
|
|
|
def _load_trace(path: Path) -> dict:
    """Read one trace file and return it as a new dict tagged with its origin.

    Args:
        path: File containing a single JSON object.

    Returns:
        A shallow copy of the decoded object. ``_source_path`` is set to
        ``str(path)`` unless the trace already carries that key.

    Raises:
        OSError: if the file cannot be read.
        ValueError: if the content is not valid JSON (``json.JSONDecodeError``)
            or the top-level JSON value is not an object.
    """
    # Hand json.loads the raw bytes: it detects UTF-8/16/32 (tolerating a UTF-8
    # BOM) itself, so decoding no longer depends on the platform's locale.
    data = json.loads(path.read_bytes())
    if not isinstance(data, dict):
        raise ValueError(f"trace at {path} must be a JSON object")
    payload = dict(data)
    payload.setdefault("_source_path", str(path))
    return payload
|
|
|
|
|
|
def _phase_totals_ms(trace: dict) -> dict[str, int]:
    """Return the per-phase elapsed times of *trace* in whole milliseconds.

    Phase totals (seconds) come from ``trace["summary"]["phase_totals_seconds"]``.
    When that is missing or empty, a top-level ``trace["phase_totals_seconds"]``
    is used instead. Phases whose value is ``None`` carry no measurement and are
    skipped rather than raising ``TypeError`` from ``float(None)``.
    """
    summary = trace.get("summary") or {}
    phase_totals = summary.get("phase_totals_seconds") or trace.get("phase_totals_seconds") or {}
    # round() on a float already returns an int (ties round to even).
    return {
        str(phase): round(float(seconds) * 1000)
        for phase, seconds in phase_totals.items()
        if seconds is not None
    }


def _reported_ints(traces: list[dict], key: str) -> list[int]:
    """Collect ``summary[key]`` as ints, skipping traces that omit the key or report null."""
    values: list[int] = []
    for trace in traces:
        raw = (trace.get("summary") or {}).get(key)
        if raw is not None:
            values.append(int(raw))
    return values


def summarize_traces(traces: list[dict], label: str) -> dict:
    """Aggregate one cohort of decoded traces into a JSON-serializable summary.

    Args:
        traces: Decoded trace objects for one cohort.
        label: Cohort label. Traces without ``variant_label`` are reported
            under this label in ``variants``.

    Returns:
        A dict with:
        - run/ok/error counts (error = any status other than ``"ok"``);
        - the degraded-run count (``summary.final_research_status`` set to
          anything other than ``"full"``);
        - sorted variants, status histogram, schema versions and source files;
        - median total elapsed ms, median event count and per-phase median ms.

        Medians use only ``status == "ok"`` runs that actually report the metric
        and are truncated to int. A median is ``None`` (``{}`` for phases) when
        no such run reports it.
    """
    run_count = len(traces)
    ok_runs = [trace for trace in traces if trace.get("status") == "ok"]
    degraded_runs = [
        trace
        for trace in traces
        if (trace.get("summary") or {}).get("final_research_status") not in (None, "full")
    ]
    # Treating a missing metric as 0 would drag the median down and make a
    # cohort look spuriously fast, so runs without a value are left out.
    total_elapsed = _reported_ints(ok_runs, "total_elapsed_ms")
    event_counts = _reported_ints(ok_runs, "event_count")
    status_counts = Counter(str(trace.get("status") or "unknown") for trace in traces)
    schema_versions = sorted({str(trace.get("trace_schema_version") or "unknown") for trace in traces})
    source_files = sorted(str(trace.get("_source_path")) for trace in traces if trace.get("_source_path"))

    phase_values: dict[str, list[int]] = {}
    for trace in ok_runs:
        for phase, elapsed_ms in _phase_totals_ms(trace).items():
            phase_values.setdefault(phase, []).append(elapsed_ms)
    phase_medians = {phase: int(median(values)) for phase, values in sorted(phase_values.items()) if values}

    variants = sorted({str(trace.get("variant_label") or label) for trace in traces})
    return {
        "label": label,
        "run_count": run_count,
        "ok_count": len(ok_runs),
        "error_count": run_count - len(ok_runs),
        "degraded_run_count": len(degraded_runs),
        "variants": variants,
        "status_counts": dict(sorted(status_counts.items())),
        "trace_schema_versions": schema_versions,
        "source_files": source_files,
        "median_total_elapsed_ms": int(median(total_elapsed)) if total_elapsed else None,
        "median_event_count": int(median(event_counts)) if event_counts else None,
        "median_phase_elapsed_ms": phase_medians,
    }
|
|
|
|
|
|
def _lower_rate_label(
    count_a: int,
    runs_a: int,
    count_b: int,
    runs_b: int,
    label_a: str,
    label_b: str,
) -> str | None:
    """Return the label with the strictly lower ``count / runs`` ratio, or ``None`` on a tie.

    Ratios are compared by cross-multiplication (``count_a * runs_b`` vs
    ``count_b * runs_a``), so no float division is needed. If either run count
    is missing or zero the ratio is undefined, and the raw counts are compared
    instead. That matches how summaries without ``run_count`` were always handled.
    """
    if runs_a > 0 and runs_b > 0:
        score_a, score_b = count_a * runs_b, count_b * runs_a
    else:
        score_a, score_b = count_a, count_b
    if score_a < score_b:
        return label_a
    if score_b < score_a:
        return label_b
    return None


def compare_summaries(summary_a: dict, summary_b: dict) -> dict:
    """Pick per-metric winners between two cohort summaries, plus a recommendation.

    Args:
        summary_a, summary_b: Cohort summaries such as those built by
            ``summarize_traces``. ``label`` is required. Counts default to 0 and
            the median to ``None`` when absent.

    Returns:
        A dict with four keys:
        - ``faster_label``: cohort with the strictly lower median total elapsed
          time. ``None`` on a tie or when either median is missing.
        - ``lower_degradation_label`` / ``lower_error_rate_label``: cohort with
          the strictly lower degraded-run / error-run *rate* (count divided by
          ``run_count``), so cohorts of different sizes compare fairly.
        - ``recommended_label``: the faster cohort when it is not worse on
          either rate. Otherwise, when the medians are equal (including both
          missing), the less-degraded cohort provided it is not worse on
          errors. Otherwise ``None``.
    """
    label_a = summary_a["label"]
    label_b = summary_b["label"]
    total_a = summary_a.get("median_total_elapsed_ms")
    total_b = summary_b.get("median_total_elapsed_ms")
    # `or 0` also tolerates an explicit null run_count.
    runs_a = summary_a.get("run_count") or 0
    runs_b = summary_b.get("run_count") or 0

    faster = None
    if total_a is not None and total_b is not None:
        if total_a < total_b:
            faster = label_a
        elif total_b < total_a:
            faster = label_b

    lower_degradation = _lower_rate_label(
        summary_a.get("degraded_run_count", 0), runs_a,
        summary_b.get("degraded_run_count", 0), runs_b,
        label_a, label_b,
    )
    lower_error_rate = _lower_rate_label(
        summary_a.get("error_count", 0), runs_a,
        summary_b.get("error_count", 0), runs_b,
        label_a, label_b,
    )

    recommended = None
    if faster == label_a and lower_degradation in (None, label_a) and lower_error_rate in (None, label_a):
        recommended = label_a
    elif faster == label_b and lower_degradation in (None, label_b) and lower_error_rate in (None, label_b):
        recommended = label_b
    elif lower_degradation == label_a and total_a == total_b and lower_error_rate in (None, label_a):
        recommended = label_a
    elif lower_degradation == label_b and total_a == total_b and lower_error_rate in (None, label_b):
        recommended = label_b

    return {
        "faster_label": faster,
        "lower_degradation_label": lower_degradation,
        "lower_error_rate_label": lower_error_rate,
        "recommended_label": recommended,
    }
|
|
|
|
|
|
def build_comparison(traces_a: list[dict], traces_b: list[dict], *, label_a: str, label_b: str) -> dict:
    """Summarize two cohorts of traces and compare them.

    Args:
        traces_a, traces_b: Decoded trace objects for cohorts A and B.
        label_a, label_b: Labels for the two cohorts. They must differ.

    Returns:
        ``{"schema_version": AB_SCHEMA_VERSION, "cohorts": {label: summary},
        "comparison": {...}}``.

    Raises:
        ValueError: if ``label_a == label_b``. The ``cohorts`` mapping is keyed
            by label, so identical labels would silently replace cohort A's
            summary with cohort B's.
    """
    if label_a == label_b:
        raise ValueError(f"cohort labels must differ, got {label_a!r} for both")
    summary_a = summarize_traces(traces_a, label_a)
    summary_b = summarize_traces(traces_b, label_b)
    return {
        "schema_version": AB_SCHEMA_VERSION,
        "cohorts": {
            label_a: summary_a,
            label_b: summary_b,
        },
        "comparison": compare_summaries(summary_a, summary_b),
    }
|
|
|
|
|
|
def _load_cohort(paths: list[Path]) -> list[dict]:
    """Load every trace in *paths*, turning a per-file failure into a CLI error naming the file."""
    traces: list[dict] = []
    for path in paths:
        try:
            traces.append(_load_trace(path))
        except (OSError, ValueError) as exc:
            # ValueError covers json.JSONDecodeError, UnicodeDecodeError and the
            # "must be a JSON object" check raised by _load_trace.
            raise SystemExit(f"failed to load trace {path}: {exc}") from exc
    return traces


def main() -> None:
    """CLI entry point: load both cohorts, compare them and emit the result as JSON.

    The comparison is always printed to stdout. It is also written to
    ``--output`` (UTF-8) when that option is given.

    Raises:
        SystemExit: if the cohort labels are identical, a cohort resolves to no
            trace files, or a trace file cannot be read or parsed.
    """
    args = build_parser().parse_args()
    # Cohorts are keyed by label in the output, so identical labels would collide.
    if args.label_a == args.label_b:
        raise SystemExit("--label-a and --label-b must differ")

    files_a = _expand_inputs(args.a)
    files_b = _expand_inputs(args.b)
    if not files_a:
        raise SystemExit("no trace files found for cohort A")
    if not files_b:
        raise SystemExit("no trace files found for cohort B")

    traces_a = _load_cohort(files_a)
    traces_b = _load_cohort(files_b)
    payload = build_comparison(traces_a, traces_b, label_a=args.label_a, label_b=args.label_b)

    rendered = json.dumps(payload, ensure_ascii=False, indent=2)
    if args.output:
        # ensure_ascii=False may emit non-ASCII text; write UTF-8 explicitly
        # instead of relying on the platform's default encoding.
        Path(args.output).write_text(rendered, encoding="utf-8")
    print(rendered)
|
|
|
|
|
|
# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|