From cdeeba91d204c331787be36aa6292d5da7817c00 Mon Sep 17 00:00:00 2001 From: JackChen Date: Sun, 5 Apr 2026 01:45:35 +0800 Subject: [PATCH] feat(orchestrator): add onApproval callback for human-in-the-loop (#32) Add an optional `onApproval` callback to OrchestratorConfig that gates between task execution rounds. After each batch of parallel tasks completes, the callback receives the completed tasks and the tasks about to start, returning true to continue or false to abort gracefully. Key changes: - Add 'skipped' to TaskStatus for user-initiated abort (distinct from 'failed') - Add skip(), skipRemaining(), cascadeSkip() to TaskQueue - Add 'task_skipped' to OrchestratorEvent for progress monitoring - Approval gate in executeQueue() with try/catch for callback errors - Synthesis prompt now includes skipped tasks section - 17 new tests covering queue skip operations and orchestrator integration Closes #32 --- src/orchestrator/orchestrator.ts | 55 ++++- src/task/queue.ts | 69 +++++- src/types.ts | 14 +- tests/approval.test.ts | 356 +++++++++++++++++++++++++++++++ 4 files changed, 486 insertions(+), 8 deletions(-) create mode 100644 tests/approval.test.ts diff --git a/src/orchestrator/orchestrator.ts b/src/orchestrator/orchestrator.ts index 86f16c0..e04aa2c 100644 --- a/src/orchestrator/orchestrator.ts +++ b/src/orchestrator/orchestrator.ts @@ -283,6 +283,17 @@ async function executeQueue( ): Promise { const { team, pool, scheduler, config } = ctx + // Relay queue-level skip events to the orchestrator's onProgress callback. + const unsubSkipped = config.onProgress + ? queue.on('task:skipped', (task) => { + config.onProgress!({ + type: 'task_skipped', + task: task.id, + data: task, + } satisfies OrchestratorEvent) + }) + : undefined + while (true) { // Re-run auto-assignment each iteration so tasks that were unblocked since // the last round (and thus have no assignee yet) get assigned before dispatch. 
@@ -294,6 +305,9 @@ async function executeQueue( break } + // Track tasks that complete successfully in this round for the approval gate. + const completedThisRound: Task[] = [] + // Dispatch all currently-pending tasks as a parallel batch. const dispatchPromises = pending.map(async (task): Promise => { // Mark in-progress @@ -390,7 +404,8 @@ async function executeQueue( await sharedMem.write(assignee, `task:${task.id}:result`, result.output) } - queue.complete(task.id, result.output) + const completedTask = queue.complete(task.id, result.output) + completedThisRound.push(completedTask) config.onProgress?.({ type: 'task_complete', @@ -418,7 +433,32 @@ async function executeQueue( // Wait for the entire parallel batch before checking for newly-unblocked tasks. await Promise.all(dispatchPromises) + + // --- Approval gate --- + // After the batch completes, check if the caller wants to approve + // the next round before it starts. + if (config.onApproval && completedThisRound.length > 0) { + scheduler.autoAssign(queue, team.getAgents()) + const nextPending = queue.getByStatus('pending') + + if (nextPending.length > 0) { + let approved: boolean + try { + approved = await config.onApproval(completedThisRound, nextPending) + } catch (err) { + const reason = `Skipped: approval callback error — ${err instanceof Error ? err.message : String(err)}` + queue.skipRemaining(reason) + break + } + if (!approved) { + queue.skipRemaining('Skipped: approval rejected.') + break + } + } + } } + + unsubSkipped?.() } /** @@ -471,8 +511,8 @@ async function buildTaskPrompt(task: Task, team: Team): Promise { */ export class OpenMultiAgent { private readonly config: Required< - Omit - > & Pick + Omit + > & Pick private readonly teams: Map = new Map() private completedTaskCount = 0 @@ -492,6 +532,7 @@ export class OpenMultiAgent { defaultProvider: config.defaultProvider ?? 
'anthropic', defaultBaseURL: config.defaultBaseURL, defaultApiKey: config.defaultApiKey, + onApproval: config.onApproval, onProgress: config.onProgress, onTrace: config.onTrace, } @@ -854,6 +895,7 @@ export class OpenMultiAgent { ): Promise { const completedTasks = tasks.filter((t) => t.status === 'completed') const failedTasks = tasks.filter((t) => t.status === 'failed') + const skippedTasks = tasks.filter((t) => t.status === 'skipped') const resultSections = completedTasks.map((t) => { const assignee = t.assignee ?? 'unknown' @@ -864,6 +906,10 @@ export class OpenMultiAgent { (t) => `### ${t.title} (FAILED)\nError: ${t.result ?? 'unknown error'}`, ) + const skippedSections = skippedTasks.map( + (t) => `### ${t.title} (SKIPPED)\nReason: ${t.result ?? 'approval rejected'}`, + ) + // Also include shared memory summary for additional context let memorySummary = '' const sharedMem = team.getSharedMemoryInstance() @@ -878,11 +924,12 @@ export class OpenMultiAgent { `## Task Results`, ...resultSections, ...(failureSections.length > 0 ? ['', '## Failed Tasks', ...failureSections] : []), + ...(skippedSections.length > 0 ? ['', '## Skipped Tasks', ...skippedSections] : []), ...(memorySummary ? ['', memorySummary] : []), '', '## Your Task', 'Synthesise the above results into a comprehensive final answer that addresses the original goal.', - 'If some tasks failed, note any gaps in the result.', + 'If some tasks failed or were skipped, note any gaps in the result.', ].join('\n') } diff --git a/src/task/queue.ts b/src/task/queue.ts index 8888c09..c5f6a17 100644 --- a/src/task/queue.ts +++ b/src/task/queue.ts @@ -18,6 +18,7 @@ export type TaskQueueEvent = | 'task:ready' | 'task:complete' | 'task:failed' + | 'task:skipped' | 'all:complete' /** Handler for `'task:ready' | 'task:complete' | 'task:failed'` events. */ @@ -156,6 +157,44 @@ export class TaskQueue { return failed } + /** + * Marks `taskId` as `'skipped'` and records `reason` in the `result` field. 
+ * + * Fires `'task:skipped'` for the skipped task and cascades to every + * downstream task that transitively depended on it. + * + * @throws {Error} when `taskId` is not found. + */ + skip(taskId: string, reason: string): Task { + const skipped = this.update(taskId, { status: 'skipped', result: reason }) + this.emit('task:skipped', skipped) + this.cascadeSkip(taskId) + if (this.isComplete()) { + this.emitAllComplete() + } + return skipped + } + + /** + * Marks all non-terminal tasks as `'skipped'`. + * + * Used when an approval gate rejects continuation — every pending, blocked, + * or in-progress task is skipped with the given reason. + */ + skipRemaining(reason = 'Skipped: approval rejected.'): void { + // Snapshot first — update() mutates the live map, which is unsafe to + // iterate over during modification. + const snapshot = Array.from(this.tasks.values()) + for (const task of snapshot) { + if (task.status === 'completed' || task.status === 'failed' || task.status === 'skipped') continue + const skipped = this.update(task.id, { status: 'skipped', result: reason }) + this.emit('task:skipped', skipped) + } + if (this.isComplete()) { + this.emitAllComplete() + } + } + /** * Recursively marks all tasks that (transitively) depend on `failedTaskId` * as `'failed'` with an informative message, firing `'task:failed'` for each. @@ -178,6 +217,24 @@ export class TaskQueue { } } + /** + * Recursively marks all tasks that (transitively) depend on `skippedTaskId` + * as `'skipped'`, firing `'task:skipped'` for each. 
+ */ + private cascadeSkip(skippedTaskId: string): void { + for (const task of this.tasks.values()) { + if (task.status !== 'blocked' && task.status !== 'pending') continue + if (!task.dependsOn?.includes(skippedTaskId)) continue + + const cascaded = this.update(task.id, { + status: 'skipped', + result: `Skipped: dependency "${skippedTaskId}" was skipped.`, + }) + this.emit('task:skipped', cascaded) + this.cascadeSkip(task.id) + } + } + // --------------------------------------------------------------------------- // Queries // --------------------------------------------------------------------------- @@ -227,11 +284,11 @@ export class TaskQueue { /** * Returns `true` when every task in the queue has reached a terminal state - * (`'completed'` or `'failed'`), **or** the queue is empty. + * (`'completed'`, `'failed'`, or `'skipped'`), **or** the queue is empty. */ isComplete(): boolean { for (const task of this.tasks.values()) { - if (task.status !== 'completed' && task.status !== 'failed') return false + if (task.status !== 'completed' && task.status !== 'failed' && task.status !== 'skipped') return false } return true } @@ -249,12 +306,14 @@ export class TaskQueue { total: number completed: number failed: number + skipped: number inProgress: number pending: number blocked: number } { let completed = 0 let failed = 0 + let skipped = 0 let inProgress = 0 let pending = 0 let blocked = 0 @@ -267,6 +326,9 @@ export class TaskQueue { case 'failed': failed++ break + case 'skipped': + skipped++ + break case 'in_progress': inProgress++ break @@ -283,6 +345,7 @@ export class TaskQueue { total: this.tasks.size, completed, failed, + skipped, inProgress, pending, blocked, @@ -370,7 +433,7 @@ export class TaskQueue { } } - private emit(event: 'task:ready' | 'task:complete' | 'task:failed', task: Task): void { + private emit(event: 'task:ready' | 'task:complete' | 'task:failed' | 'task:skipped', task: Task): void { const map = this.listeners.get(event) if (!map) return for 
(const handler of map.values()) { diff --git a/src/types.ts b/src/types.ts index 6695d81..e49f8a0 100644 --- a/src/types.ts +++ b/src/types.ts @@ -286,7 +286,7 @@ export interface TeamRunResult { // --------------------------------------------------------------------------- /** Valid states for a {@link Task}. */ -export type TaskStatus = 'pending' | 'in_progress' | 'completed' | 'failed' | 'blocked' +export type TaskStatus = 'pending' | 'in_progress' | 'completed' | 'failed' | 'blocked' | 'skipped' /** A discrete unit of work tracked by the orchestrator. */ export interface Task { @@ -320,6 +320,7 @@ export interface OrchestratorEvent { | 'agent_complete' | 'task_start' | 'task_complete' + | 'task_skipped' | 'task_retry' | 'message' | 'error' @@ -337,6 +338,17 @@ export interface OrchestratorConfig { readonly defaultApiKey?: string readonly onProgress?: (event: OrchestratorEvent) => void readonly onTrace?: (event: TraceEvent) => void | Promise<void> + /** + * Optional approval gate called between task execution rounds. + * + * After a batch of tasks completes, this callback receives all + * completed {@link Task}s from that round and the list of tasks about + * to start next. Return `true` to continue or `false` to abort — + * remaining tasks will be marked `'skipped'`. + * + * Not called after the final round (when no tasks remain to start).
+ */ + readonly onApproval?: (completedTasks: readonly Task[], nextTasks: readonly Task[]) => Promise<boolean> } // --------------------------------------------------------------------------- diff --git a/tests/approval.test.ts b/tests/approval.test.ts new file mode 100644 index 0000000..a9ddfb4 --- /dev/null +++ b/tests/approval.test.ts @@ -0,0 +1,356 @@ +import { describe, it, expect, vi } from 'vitest' +import { TaskQueue } from '../src/task/queue.js' +import { createTask } from '../src/task/task.js' +import { OpenMultiAgent } from '../src/orchestrator/orchestrator.js' +import { Agent } from '../src/agent/agent.js' +import { AgentRunner } from '../src/agent/runner.js' +import { ToolRegistry } from '../src/tool/framework.js' +import { ToolExecutor } from '../src/tool/executor.js' +import { AgentPool } from '../src/agent/pool.js' +import type { AgentConfig, LLMAdapter, LLMResponse, Task } from '../src/types.js' + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +function task(id: string, opts: { dependsOn?: string[]; assignee?: string } = {}) { + const t = createTask({ title: id, description: `task ${id}`, assignee: opts.assignee }) + return { ...t, id, dependsOn: opts.dependsOn } as ReturnType<typeof createTask> +} + +function mockAdapter(responseText: string): LLMAdapter { + return { + name: 'mock', + async chat() { + return { + id: 'mock-1', + content: [{ type: 'text' as const, text: responseText }], + model: 'mock-model', + stop_reason: 'end_turn', + usage: { input_tokens: 10, output_tokens: 20 }, + } satisfies LLMResponse + }, + async *stream() { + /* unused */ + }, + } +} + +function buildMockAgent(config: AgentConfig, responseText: string): Agent { + const registry = new ToolRegistry() + const executor = new ToolExecutor(registry) + const agent = new Agent(config, registry, executor) + const runner = new AgentRunner(mockAdapter(responseText), registry, executor,
{ + model: config.model, + systemPrompt: config.systemPrompt, + maxTurns: config.maxTurns, + maxTokens: config.maxTokens, + temperature: config.temperature, + agentName: config.name, + }) + ;(agent as any).runner = runner + return agent +} + +// --------------------------------------------------------------------------- +// TaskQueue: skip / skipRemaining +// --------------------------------------------------------------------------- + +describe('TaskQueue — skip', () => { + it('marks a task as skipped', () => { + const q = new TaskQueue() + q.add(task('a')) + q.skip('a', 'user rejected') + expect(q.list()[0].status).toBe('skipped') + expect(q.list()[0].result).toBe('user rejected') + }) + + it('fires task:skipped event with updated task object', () => { + const q = new TaskQueue() + const handler = vi.fn() + q.on('task:skipped', handler) + + q.add(task('a')) + q.skip('a', 'rejected') + + expect(handler).toHaveBeenCalledTimes(1) + const emitted = handler.mock.calls[0][0] + expect(emitted.id).toBe('a') + expect(emitted.status).toBe('skipped') + expect(emitted.result).toBe('rejected') + }) + + it('cascades skip to dependent tasks', () => { + const q = new TaskQueue() + q.add(task('a')) + q.add(task('b', { dependsOn: ['a'] })) + q.add(task('c', { dependsOn: ['b'] })) + + q.skip('a', 'rejected') + + expect(q.list().find((t) => t.id === 'a')!.status).toBe('skipped') + expect(q.list().find((t) => t.id === 'b')!.status).toBe('skipped') + expect(q.list().find((t) => t.id === 'c')!.status).toBe('skipped') + }) + + it('does not cascade to independent tasks', () => { + const q = new TaskQueue() + q.add(task('a')) + q.add(task('b')) + q.add(task('c', { dependsOn: ['a'] })) + + q.skip('a', 'rejected') + + expect(q.list().find((t) => t.id === 'b')!.status).toBe('pending') + expect(q.list().find((t) => t.id === 'c')!.status).toBe('skipped') + }) + + it('throws when skipping a non-existent task', () => { + const q = new TaskQueue() + expect(() => q.skip('nope', 
'reason')).toThrow('not found') + }) + + it('isComplete() treats skipped as terminal', () => { + const q = new TaskQueue() + q.add(task('a')) + q.add(task('b')) + + q.complete('a', 'done') + expect(q.isComplete()).toBe(false) + + q.skip('b', 'rejected') + expect(q.isComplete()).toBe(true) + }) + + it('getProgress() counts skipped tasks', () => { + const q = new TaskQueue() + q.add(task('a')) + q.add(task('b')) + q.add(task('c')) + + q.complete('a', 'done') + q.skip('b', 'rejected') + + const progress = q.getProgress() + expect(progress.completed).toBe(1) + expect(progress.skipped).toBe(1) + expect(progress.pending).toBe(1) + }) +}) + +describe('TaskQueue — skipRemaining', () => { + it('marks all non-terminal tasks as skipped', () => { + const q = new TaskQueue() + q.add(task('a')) + q.add(task('b')) + q.add(task('c', { dependsOn: ['a'] })) + + q.complete('a', 'done') + q.skipRemaining('approval rejected') + + expect(q.list().find((t) => t.id === 'a')!.status).toBe('completed') + expect(q.list().find((t) => t.id === 'b')!.status).toBe('skipped') + expect(q.list().find((t) => t.id === 'c')!.status).toBe('skipped') + }) + + it('leaves failed tasks untouched', () => { + const q = new TaskQueue() + q.add(task('a')) + q.add(task('b')) + + q.fail('a', 'error') + q.skipRemaining() + + expect(q.list().find((t) => t.id === 'a')!.status).toBe('failed') + expect(q.list().find((t) => t.id === 'b')!.status).toBe('skipped') + }) + + it('emits task:skipped with the updated task object (not stale)', () => { + const q = new TaskQueue() + const handler = vi.fn() + q.on('task:skipped', handler) + + q.add(task('a')) + q.add(task('b')) + + q.skipRemaining('reason') + + expect(handler).toHaveBeenCalledTimes(2) + // Every emitted task must have status 'skipped' + for (const call of handler.mock.calls) { + expect(call[0].status).toBe('skipped') + expect(call[0].result).toBe('reason') + } + }) + + it('fires all:complete after skipRemaining', () => { + const q = new TaskQueue() + const 
handler = vi.fn() + q.on('all:complete', handler) + + q.add(task('a')) + q.add(task('b')) + + q.complete('a', 'done') + expect(handler).not.toHaveBeenCalled() + + q.skipRemaining() + expect(handler).toHaveBeenCalledTimes(1) + }) +}) + +// --------------------------------------------------------------------------- +// Orchestrator: onApproval integration +// --------------------------------------------------------------------------- + +describe('onApproval integration', () => { + function patchPool(orchestrator: OpenMultiAgent, agents: Map<string, Agent>) { + ;(orchestrator as any).buildPool = () => { + const pool = new AgentPool(5) + for (const [, agent] of agents) { + pool.add(agent) + } + return pool + } + } + + function setup(onApproval?: (tasks: readonly Task[], next: readonly Task[]) => Promise<boolean>) { + const agentA: AgentConfig = { name: 'agent-a', model: 'mock', systemPrompt: 'You are agent A.' } + const agentB: AgentConfig = { name: 'agent-b', model: 'mock', systemPrompt: 'You are agent B.' } + + const orchestrator = new OpenMultiAgent({ + defaultModel: 'mock', + ...(onApproval ?
{ onApproval } : {}), + }) + + const team = orchestrator.createTeam('test', { + name: 'test', + agents: [agentA, agentB], + }) + + const mockAgents = new Map() + mockAgents.set('agent-a', buildMockAgent(agentA, 'result from A')) + mockAgents.set('agent-b', buildMockAgent(agentB, 'result from B')) + patchPool(orchestrator, mockAgents) + + return { orchestrator, team } + } + + it('approve all — all tasks complete normally', async () => { + const approvalSpy = vi.fn().mockResolvedValue(true) + const { orchestrator, team } = setup(approvalSpy) + + const result = await orchestrator.runTasks(team, [ + { title: 'task-1', description: 'first', assignee: 'agent-a' }, + { title: 'task-2', description: 'second', assignee: 'agent-b', dependsOn: ['task-1'] }, + ]) + + expect(result.success).toBe(true) + expect(result.agentResults.has('agent-a')).toBe(true) + expect(result.agentResults.has('agent-b')).toBe(true) + // onApproval called once (between round 1 and round 2) + expect(approvalSpy).toHaveBeenCalledTimes(1) + }) + + it('reject mid-pipeline — remaining tasks skipped', async () => { + const approvalSpy = vi.fn().mockResolvedValue(false) + const { orchestrator, team } = setup(approvalSpy) + + const result = await orchestrator.runTasks(team, [ + { title: 'task-1', description: 'first', assignee: 'agent-a' }, + { title: 'task-2', description: 'second', assignee: 'agent-b', dependsOn: ['task-1'] }, + ]) + + expect(approvalSpy).toHaveBeenCalledTimes(1) + // Only agent-a's output present (task-2 was skipped, never ran) + expect(result.agentResults.has('agent-a')).toBe(true) + expect(result.agentResults.has('agent-b')).toBe(false) + }) + + it('no callback — tasks flow without interruption', async () => { + const { orchestrator, team } = setup(/* no onApproval */) + + const result = await orchestrator.runTasks(team, [ + { title: 'task-1', description: 'first', assignee: 'agent-a' }, + { title: 'task-2', description: 'second', assignee: 'agent-b', dependsOn: ['task-1'] }, + ]) + + 
expect(result.success).toBe(true) + expect(result.agentResults.has('agent-a')).toBe(true) + expect(result.agentResults.has('agent-b')).toBe(true) + }) + + it('callback receives correct arguments — completedTasks array and nextTasks', async () => { + const approvalSpy = vi.fn().mockResolvedValue(true) + const { orchestrator, team } = setup(approvalSpy) + + await orchestrator.runTasks(team, [ + { title: 'task-1', description: 'first', assignee: 'agent-a' }, + { title: 'task-2', description: 'second', assignee: 'agent-b', dependsOn: ['task-1'] }, + ]) + + // First arg: array of completed tasks from this round + const completedTasks = approvalSpy.mock.calls[0][0] + expect(completedTasks).toHaveLength(1) + expect(completedTasks[0].title).toBe('task-1') + expect(completedTasks[0].status).toBe('completed') + + // Second arg: the next tasks about to run + const nextTasks = approvalSpy.mock.calls[0][1] + expect(nextTasks).toHaveLength(1) + expect(nextTasks[0].title).toBe('task-2') + }) + + it('callback throwing an error skips remaining tasks gracefully', async () => { + const approvalSpy = vi.fn().mockRejectedValue(new Error('network timeout')) + const { orchestrator, team } = setup(approvalSpy) + + // Should not throw — error is caught and remaining tasks are skipped + const result = await orchestrator.runTasks(team, [ + { title: 'task-1', description: 'first', assignee: 'agent-a' }, + { title: 'task-2', description: 'second', assignee: 'agent-b', dependsOn: ['task-1'] }, + ]) + + expect(approvalSpy).toHaveBeenCalledTimes(1) + expect(result.agentResults.has('agent-a')).toBe(true) + expect(result.agentResults.has('agent-b')).toBe(false) + }) + + it('parallel batch — completedTasks contains all tasks from the round', async () => { + const approvalSpy = vi.fn().mockResolvedValue(true) + const agentA: AgentConfig = { name: 'agent-a', model: 'mock', systemPrompt: 'A' } + const agentB: AgentConfig = { name: 'agent-b', model: 'mock', systemPrompt: 'B' } + const agentC: 
AgentConfig = { name: 'agent-c', model: 'mock', systemPrompt: 'C' } + + const orchestrator = new OpenMultiAgent({ + defaultModel: 'mock', + onApproval: approvalSpy, + }) + + const team = orchestrator.createTeam('test', { + name: 'test', + agents: [agentA, agentB, agentC], + }) + + const mockAgents = new Map() + mockAgents.set('agent-a', buildMockAgent(agentA, 'A done')) + mockAgents.set('agent-b', buildMockAgent(agentB, 'B done')) + mockAgents.set('agent-c', buildMockAgent(agentC, 'C done')) + patchPool(orchestrator, mockAgents) + + // task-1 and task-2 are independent (run in parallel), task-3 depends on both + await orchestrator.runTasks(team, [ + { title: 'task-1', description: 'first', assignee: 'agent-a' }, + { title: 'task-2', description: 'second', assignee: 'agent-b' }, + { title: 'task-3', description: 'third', assignee: 'agent-c', dependsOn: ['task-1', 'task-2'] }, + ]) + + // Approval called once between the parallel batch and task-3 + expect(approvalSpy).toHaveBeenCalledTimes(1) + const completedTasks = approvalSpy.mock.calls[0][0] as Task[] + // Both task-1 and task-2 completed in the same round + expect(completedTasks).toHaveLength(2) + const titles = completedTasks.map((t: Task) => t.title).sort() + expect(titles).toEqual(['task-1', 'task-2']) + }) +})