feat(orchestrator): add onApproval callback for human-in-the-loop (#32)
* feat(orchestrator): add onApproval callback for human-in-the-loop (#32) Add an optional `onApproval` callback to OrchestratorConfig that gates between task execution rounds. After each batch of parallel tasks completes, the callback receives the completed tasks and the tasks about to start, returning true to continue or false to abort gracefully. Key changes: - Add 'skipped' to TaskStatus for user-initiated abort (distinct from 'failed') - Add skip(), skipRemaining(), cascadeSkip() to TaskQueue - Add 'task_skipped' to OrchestratorEvent for progress monitoring - Approval gate in executeQueue() with try/catch for callback errors - Synthesis prompt now includes skipped tasks section - 17 new tests covering queue skip operations and orchestrator integration Closes #32 * docs: clarify onApproval contract and add missing test scenarios - Document skip() cascade semantics, skipRemaining() in-flight constraint, and onApproval trigger conditions / mutation warning - Add concurrency safety comment on completedThisRound - Note task_skipped as breaking union addition on OrchestratorEvent - Add 3 test scenarios: single-batch no-callback, mixed success/failure batch, and onProgress task_skipped event relay
This commit is contained in:
parent
d327acb89b
commit
9f5afb10f5
|
|
@ -283,6 +283,17 @@ async function executeQueue(
|
|||
): Promise<void> {
|
||||
const { team, pool, scheduler, config } = ctx
|
||||
|
||||
// Relay queue-level skip events to the orchestrator's onProgress callback.
|
||||
const unsubSkipped = config.onProgress
|
||||
? queue.on('task:skipped', (task) => {
|
||||
config.onProgress!({
|
||||
type: 'task_skipped',
|
||||
task: task.id,
|
||||
data: task,
|
||||
} satisfies OrchestratorEvent)
|
||||
})
|
||||
: undefined
|
||||
|
||||
while (true) {
|
||||
// Re-run auto-assignment each iteration so tasks that were unblocked since
|
||||
// the last round (and thus have no assignee yet) get assigned before dispatch.
|
||||
|
|
@ -294,6 +305,11 @@ async function executeQueue(
|
|||
break
|
||||
}
|
||||
|
||||
// Track tasks that complete successfully in this round for the approval gate.
|
||||
// Safe to push from concurrent promises: JS is single-threaded, so
|
||||
// Array.push calls from resolved microtasks never interleave.
|
||||
const completedThisRound: Task[] = []
|
||||
|
||||
// Dispatch all currently-pending tasks as a parallel batch.
|
||||
const dispatchPromises = pending.map(async (task): Promise<void> => {
|
||||
// Mark in-progress
|
||||
|
|
@ -390,7 +406,8 @@ async function executeQueue(
|
|||
await sharedMem.write(assignee, `task:${task.id}:result`, result.output)
|
||||
}
|
||||
|
||||
queue.complete(task.id, result.output)
|
||||
const completedTask = queue.complete(task.id, result.output)
|
||||
completedThisRound.push(completedTask)
|
||||
|
||||
config.onProgress?.({
|
||||
type: 'task_complete',
|
||||
|
|
@ -418,7 +435,32 @@ async function executeQueue(
|
|||
|
||||
// Wait for the entire parallel batch before checking for newly-unblocked tasks.
|
||||
await Promise.all(dispatchPromises)
|
||||
|
||||
// --- Approval gate ---
|
||||
// After the batch completes, check if the caller wants to approve
|
||||
// the next round before it starts.
|
||||
if (config.onApproval && completedThisRound.length > 0) {
|
||||
scheduler.autoAssign(queue, team.getAgents())
|
||||
const nextPending = queue.getByStatus('pending')
|
||||
|
||||
if (nextPending.length > 0) {
|
||||
let approved: boolean
|
||||
try {
|
||||
approved = await config.onApproval(completedThisRound, nextPending)
|
||||
} catch (err) {
|
||||
const reason = `Skipped: approval callback error — ${err instanceof Error ? err.message : String(err)}`
|
||||
queue.skipRemaining(reason)
|
||||
break
|
||||
}
|
||||
if (!approved) {
|
||||
queue.skipRemaining('Skipped: approval rejected.')
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsubSkipped?.()
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -471,8 +513,8 @@ async function buildTaskPrompt(task: Task, team: Team): Promise<string> {
|
|||
*/
|
||||
export class OpenMultiAgent {
|
||||
private readonly config: Required<
|
||||
Omit<OrchestratorConfig, 'onProgress' | 'onTrace' | 'defaultBaseURL' | 'defaultApiKey'>
|
||||
> & Pick<OrchestratorConfig, 'onProgress' | 'onTrace' | 'defaultBaseURL' | 'defaultApiKey'>
|
||||
Omit<OrchestratorConfig, 'onApproval' | 'onProgress' | 'onTrace' | 'defaultBaseURL' | 'defaultApiKey'>
|
||||
> & Pick<OrchestratorConfig, 'onApproval' | 'onProgress' | 'onTrace' | 'defaultBaseURL' | 'defaultApiKey'>
|
||||
|
||||
private readonly teams: Map<string, Team> = new Map()
|
||||
private completedTaskCount = 0
|
||||
|
|
@ -492,6 +534,7 @@ export class OpenMultiAgent {
|
|||
defaultProvider: config.defaultProvider ?? 'anthropic',
|
||||
defaultBaseURL: config.defaultBaseURL,
|
||||
defaultApiKey: config.defaultApiKey,
|
||||
onApproval: config.onApproval,
|
||||
onProgress: config.onProgress,
|
||||
onTrace: config.onTrace,
|
||||
}
|
||||
|
|
@ -854,6 +897,7 @@ export class OpenMultiAgent {
|
|||
): Promise<string> {
|
||||
const completedTasks = tasks.filter((t) => t.status === 'completed')
|
||||
const failedTasks = tasks.filter((t) => t.status === 'failed')
|
||||
const skippedTasks = tasks.filter((t) => t.status === 'skipped')
|
||||
|
||||
const resultSections = completedTasks.map((t) => {
|
||||
const assignee = t.assignee ?? 'unknown'
|
||||
|
|
@ -864,6 +908,10 @@ export class OpenMultiAgent {
|
|||
(t) => `### ${t.title} (FAILED)\nError: ${t.result ?? 'unknown error'}`,
|
||||
)
|
||||
|
||||
const skippedSections = skippedTasks.map(
|
||||
(t) => `### ${t.title} (SKIPPED)\nReason: ${t.result ?? 'approval rejected'}`,
|
||||
)
|
||||
|
||||
// Also include shared memory summary for additional context
|
||||
let memorySummary = ''
|
||||
const sharedMem = team.getSharedMemoryInstance()
|
||||
|
|
@ -878,11 +926,12 @@ export class OpenMultiAgent {
|
|||
`## Task Results`,
|
||||
...resultSections,
|
||||
...(failureSections.length > 0 ? ['', '## Failed Tasks', ...failureSections] : []),
|
||||
...(skippedSections.length > 0 ? ['', '## Skipped Tasks', ...skippedSections] : []),
|
||||
...(memorySummary ? ['', memorySummary] : []),
|
||||
'',
|
||||
'## Your Task',
|
||||
'Synthesise the above results into a comprehensive final answer that addresses the original goal.',
|
||||
'If some tasks failed, note any gaps in the result.',
|
||||
'If some tasks failed or were skipped, note any gaps in the result.',
|
||||
].join('\n')
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ export type TaskQueueEvent =
|
|||
| 'task:ready'
|
||||
| 'task:complete'
|
||||
| 'task:failed'
|
||||
| 'task:skipped'
|
||||
| 'all:complete'
|
||||
|
||||
/** Handler for `'task:ready' | 'task:complete' | 'task:failed'` events. */
|
||||
|
|
@ -156,6 +157,51 @@ export class TaskQueue {
|
|||
return failed
|
||||
}
|
||||
|
||||
/**
|
||||
* Marks `taskId` as `'skipped'` and records `reason` in the `result` field.
|
||||
*
|
||||
* Fires `'task:skipped'` for the skipped task and cascades to every
|
||||
* downstream task that transitively depended on it — even if the dependent
|
||||
* has other dependencies that are still pending or completed. A skipped
|
||||
* upstream is treated as permanently unsatisfiable, mirroring `fail()`.
|
||||
*
|
||||
* @throws {Error} when `taskId` is not found.
|
||||
*/
|
||||
skip(taskId: string, reason: string): Task {
|
||||
const skipped = this.update(taskId, { status: 'skipped', result: reason })
|
||||
this.emit('task:skipped', skipped)
|
||||
this.cascadeSkip(taskId)
|
||||
if (this.isComplete()) {
|
||||
this.emitAllComplete()
|
||||
}
|
||||
return skipped
|
||||
}
|
||||
|
||||
/**
|
||||
* Marks all non-terminal tasks as `'skipped'`.
|
||||
*
|
||||
* Used when an approval gate rejects continuation — every pending, blocked,
|
||||
* or in-progress task is skipped with the given reason.
|
||||
*
|
||||
* **Important:** Call only when no tasks are actively executing. The
|
||||
* orchestrator invokes this after `await Promise.all()`, so no tasks are
|
||||
* in-flight. Calling while agents are running may mark an in-progress task
|
||||
* as skipped while its agent continues executing.
|
||||
*/
|
||||
skipRemaining(reason = 'Skipped: approval rejected.'): void {
|
||||
// Snapshot first — update() mutates the live map, which is unsafe to
|
||||
// iterate over during modification.
|
||||
const snapshot = Array.from(this.tasks.values())
|
||||
for (const task of snapshot) {
|
||||
if (task.status === 'completed' || task.status === 'failed' || task.status === 'skipped') continue
|
||||
const skipped = this.update(task.id, { status: 'skipped', result: reason })
|
||||
this.emit('task:skipped', skipped)
|
||||
}
|
||||
if (this.isComplete()) {
|
||||
this.emitAllComplete()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Recursively marks all tasks that (transitively) depend on `failedTaskId`
|
||||
* as `'failed'` with an informative message, firing `'task:failed'` for each.
|
||||
|
|
@ -178,6 +224,24 @@ export class TaskQueue {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Recursively marks all tasks that (transitively) depend on `skippedTaskId`
|
||||
* as `'skipped'`, firing `'task:skipped'` for each.
|
||||
*/
|
||||
private cascadeSkip(skippedTaskId: string): void {
|
||||
for (const task of this.tasks.values()) {
|
||||
if (task.status !== 'blocked' && task.status !== 'pending') continue
|
||||
if (!task.dependsOn?.includes(skippedTaskId)) continue
|
||||
|
||||
const cascaded = this.update(task.id, {
|
||||
status: 'skipped',
|
||||
result: `Skipped: dependency "${skippedTaskId}" was skipped.`,
|
||||
})
|
||||
this.emit('task:skipped', cascaded)
|
||||
this.cascadeSkip(task.id)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Queries
|
||||
// ---------------------------------------------------------------------------
|
||||
|
|
@ -227,11 +291,11 @@ export class TaskQueue {
|
|||
|
||||
/**
|
||||
* Returns `true` when every task in the queue has reached a terminal state
|
||||
* (`'completed'` or `'failed'`), **or** the queue is empty.
|
||||
* (`'completed'`, `'failed'`, or `'skipped'`), **or** the queue is empty.
|
||||
*/
|
||||
isComplete(): boolean {
|
||||
for (const task of this.tasks.values()) {
|
||||
if (task.status !== 'completed' && task.status !== 'failed') return false
|
||||
if (task.status !== 'completed' && task.status !== 'failed' && task.status !== 'skipped') return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
|
@ -249,12 +313,14 @@ export class TaskQueue {
|
|||
total: number
|
||||
completed: number
|
||||
failed: number
|
||||
skipped: number
|
||||
inProgress: number
|
||||
pending: number
|
||||
blocked: number
|
||||
} {
|
||||
let completed = 0
|
||||
let failed = 0
|
||||
let skipped = 0
|
||||
let inProgress = 0
|
||||
let pending = 0
|
||||
let blocked = 0
|
||||
|
|
@ -267,6 +333,9 @@ export class TaskQueue {
|
|||
case 'failed':
|
||||
failed++
|
||||
break
|
||||
case 'skipped':
|
||||
skipped++
|
||||
break
|
||||
case 'in_progress':
|
||||
inProgress++
|
||||
break
|
||||
|
|
@ -283,6 +352,7 @@ export class TaskQueue {
|
|||
total: this.tasks.size,
|
||||
completed,
|
||||
failed,
|
||||
skipped,
|
||||
inProgress,
|
||||
pending,
|
||||
blocked,
|
||||
|
|
@ -370,7 +440,7 @@ export class TaskQueue {
|
|||
}
|
||||
}
|
||||
|
||||
private emit(event: 'task:ready' | 'task:complete' | 'task:failed', task: Task): void {
|
||||
private emit(event: 'task:ready' | 'task:complete' | 'task:failed' | 'task:skipped', task: Task): void {
|
||||
const map = this.listeners.get(event)
|
||||
if (!map) return
|
||||
for (const handler of map.values()) {
|
||||
|
|
|
|||
27
src/types.ts
27
src/types.ts
|
|
@ -286,7 +286,7 @@ export interface TeamRunResult {
|
|||
// ---------------------------------------------------------------------------
|
||||
|
||||
/** Valid states for a {@link Task}. */
|
||||
export type TaskStatus = 'pending' | 'in_progress' | 'completed' | 'failed' | 'blocked'
|
||||
export type TaskStatus = 'pending' | 'in_progress' | 'completed' | 'failed' | 'blocked' | 'skipped'
|
||||
|
||||
/** A discrete unit of work tracked by the orchestrator. */
|
||||
export interface Task {
|
||||
|
|
@ -313,13 +313,19 @@ export interface Task {
|
|||
// Orchestrator
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/** Progress event emitted by the orchestrator during a run. */
|
||||
/**
|
||||
* Progress event emitted by the orchestrator during a run.
|
||||
*
|
||||
* **v0.3 addition:** `'task_skipped'` — consumers with exhaustive switches
|
||||
* on `type` will need to add a case for this variant.
|
||||
*/
|
||||
export interface OrchestratorEvent {
|
||||
readonly type:
|
||||
| 'agent_start'
|
||||
| 'agent_complete'
|
||||
| 'task_start'
|
||||
| 'task_complete'
|
||||
| 'task_skipped'
|
||||
| 'task_retry'
|
||||
| 'message'
|
||||
| 'error'
|
||||
|
|
@ -337,6 +343,23 @@ export interface OrchestratorConfig {
|
|||
readonly defaultApiKey?: string
|
||||
readonly onProgress?: (event: OrchestratorEvent) => void
|
||||
readonly onTrace?: (event: TraceEvent) => void | Promise<void>
|
||||
/**
|
||||
* Optional approval gate called between task execution rounds.
|
||||
*
|
||||
* After a batch of tasks completes, this callback receives all
|
||||
* completed {@link Task}s from that round and the list of tasks about
|
||||
* to start next. Return `true` to continue or `false` to abort —
|
||||
* remaining tasks will be marked `'skipped'`.
|
||||
*
|
||||
* Not called when:
|
||||
* - No tasks succeeded in the round (all failed).
|
||||
* - No pending tasks remain after the round (final batch).
|
||||
*
|
||||
* **Note:** Do not mutate the {@link Task} objects passed to this
|
||||
* callback — they are live references to queue state. Mutation is
|
||||
* undefined behavior.
|
||||
*/
|
||||
readonly onApproval?: (completedTasks: readonly Task[], nextTasks: readonly Task[]) => Promise<boolean>
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -0,0 +1,464 @@
|
|||
import { describe, it, expect, vi } from 'vitest'
|
||||
import { TaskQueue } from '../src/task/queue.js'
|
||||
import { createTask } from '../src/task/task.js'
|
||||
import { OpenMultiAgent } from '../src/orchestrator/orchestrator.js'
|
||||
import { Agent } from '../src/agent/agent.js'
|
||||
import { AgentRunner } from '../src/agent/runner.js'
|
||||
import { ToolRegistry } from '../src/tool/framework.js'
|
||||
import { ToolExecutor } from '../src/tool/executor.js'
|
||||
import { AgentPool } from '../src/agent/pool.js'
|
||||
import type { AgentConfig, LLMAdapter, LLMResponse, Task } from '../src/types.js'
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function task(id: string, opts: { dependsOn?: string[]; assignee?: string } = {}) {
|
||||
const t = createTask({ title: id, description: `task ${id}`, assignee: opts.assignee })
|
||||
return { ...t, id, dependsOn: opts.dependsOn } as ReturnType<typeof createTask>
|
||||
}
|
||||
|
||||
function mockAdapter(responseText: string): LLMAdapter {
|
||||
return {
|
||||
name: 'mock',
|
||||
async chat() {
|
||||
return {
|
||||
id: 'mock-1',
|
||||
content: [{ type: 'text' as const, text: responseText }],
|
||||
model: 'mock-model',
|
||||
stop_reason: 'end_turn',
|
||||
usage: { input_tokens: 10, output_tokens: 20 },
|
||||
} satisfies LLMResponse
|
||||
},
|
||||
async *stream() {
|
||||
/* unused */
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
function buildMockAgent(config: AgentConfig, responseText: string): Agent {
|
||||
const registry = new ToolRegistry()
|
||||
const executor = new ToolExecutor(registry)
|
||||
const agent = new Agent(config, registry, executor)
|
||||
const runner = new AgentRunner(mockAdapter(responseText), registry, executor, {
|
||||
model: config.model,
|
||||
systemPrompt: config.systemPrompt,
|
||||
maxTurns: config.maxTurns,
|
||||
maxTokens: config.maxTokens,
|
||||
temperature: config.temperature,
|
||||
agentName: config.name,
|
||||
})
|
||||
;(agent as any).runner = runner
|
||||
return agent
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TaskQueue: skip / skipRemaining
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
describe('TaskQueue — skip', () => {
|
||||
it('marks a task as skipped', () => {
|
||||
const q = new TaskQueue()
|
||||
q.add(task('a'))
|
||||
q.skip('a', 'user rejected')
|
||||
expect(q.list()[0].status).toBe('skipped')
|
||||
expect(q.list()[0].result).toBe('user rejected')
|
||||
})
|
||||
|
||||
it('fires task:skipped event with updated task object', () => {
|
||||
const q = new TaskQueue()
|
||||
const handler = vi.fn()
|
||||
q.on('task:skipped', handler)
|
||||
|
||||
q.add(task('a'))
|
||||
q.skip('a', 'rejected')
|
||||
|
||||
expect(handler).toHaveBeenCalledTimes(1)
|
||||
const emitted = handler.mock.calls[0][0]
|
||||
expect(emitted.id).toBe('a')
|
||||
expect(emitted.status).toBe('skipped')
|
||||
expect(emitted.result).toBe('rejected')
|
||||
})
|
||||
|
||||
it('cascades skip to dependent tasks', () => {
|
||||
const q = new TaskQueue()
|
||||
q.add(task('a'))
|
||||
q.add(task('b', { dependsOn: ['a'] }))
|
||||
q.add(task('c', { dependsOn: ['b'] }))
|
||||
|
||||
q.skip('a', 'rejected')
|
||||
|
||||
expect(q.list().find((t) => t.id === 'a')!.status).toBe('skipped')
|
||||
expect(q.list().find((t) => t.id === 'b')!.status).toBe('skipped')
|
||||
expect(q.list().find((t) => t.id === 'c')!.status).toBe('skipped')
|
||||
})
|
||||
|
||||
it('does not cascade to independent tasks', () => {
|
||||
const q = new TaskQueue()
|
||||
q.add(task('a'))
|
||||
q.add(task('b'))
|
||||
q.add(task('c', { dependsOn: ['a'] }))
|
||||
|
||||
q.skip('a', 'rejected')
|
||||
|
||||
expect(q.list().find((t) => t.id === 'b')!.status).toBe('pending')
|
||||
expect(q.list().find((t) => t.id === 'c')!.status).toBe('skipped')
|
||||
})
|
||||
|
||||
it('throws when skipping a non-existent task', () => {
|
||||
const q = new TaskQueue()
|
||||
expect(() => q.skip('nope', 'reason')).toThrow('not found')
|
||||
})
|
||||
|
||||
it('isComplete() treats skipped as terminal', () => {
|
||||
const q = new TaskQueue()
|
||||
q.add(task('a'))
|
||||
q.add(task('b'))
|
||||
|
||||
q.complete('a', 'done')
|
||||
expect(q.isComplete()).toBe(false)
|
||||
|
||||
q.skip('b', 'rejected')
|
||||
expect(q.isComplete()).toBe(true)
|
||||
})
|
||||
|
||||
it('getProgress() counts skipped tasks', () => {
|
||||
const q = new TaskQueue()
|
||||
q.add(task('a'))
|
||||
q.add(task('b'))
|
||||
q.add(task('c'))
|
||||
|
||||
q.complete('a', 'done')
|
||||
q.skip('b', 'rejected')
|
||||
|
||||
const progress = q.getProgress()
|
||||
expect(progress.completed).toBe(1)
|
||||
expect(progress.skipped).toBe(1)
|
||||
expect(progress.pending).toBe(1)
|
||||
})
|
||||
})
|
||||
|
||||
describe('TaskQueue — skipRemaining', () => {
|
||||
it('marks all non-terminal tasks as skipped', () => {
|
||||
const q = new TaskQueue()
|
||||
q.add(task('a'))
|
||||
q.add(task('b'))
|
||||
q.add(task('c', { dependsOn: ['a'] }))
|
||||
|
||||
q.complete('a', 'done')
|
||||
q.skipRemaining('approval rejected')
|
||||
|
||||
expect(q.list().find((t) => t.id === 'a')!.status).toBe('completed')
|
||||
expect(q.list().find((t) => t.id === 'b')!.status).toBe('skipped')
|
||||
expect(q.list().find((t) => t.id === 'c')!.status).toBe('skipped')
|
||||
})
|
||||
|
||||
it('leaves failed tasks untouched', () => {
|
||||
const q = new TaskQueue()
|
||||
q.add(task('a'))
|
||||
q.add(task('b'))
|
||||
|
||||
q.fail('a', 'error')
|
||||
q.skipRemaining()
|
||||
|
||||
expect(q.list().find((t) => t.id === 'a')!.status).toBe('failed')
|
||||
expect(q.list().find((t) => t.id === 'b')!.status).toBe('skipped')
|
||||
})
|
||||
|
||||
it('emits task:skipped with the updated task object (not stale)', () => {
|
||||
const q = new TaskQueue()
|
||||
const handler = vi.fn()
|
||||
q.on('task:skipped', handler)
|
||||
|
||||
q.add(task('a'))
|
||||
q.add(task('b'))
|
||||
|
||||
q.skipRemaining('reason')
|
||||
|
||||
expect(handler).toHaveBeenCalledTimes(2)
|
||||
// Every emitted task must have status 'skipped'
|
||||
for (const call of handler.mock.calls) {
|
||||
expect(call[0].status).toBe('skipped')
|
||||
expect(call[0].result).toBe('reason')
|
||||
}
|
||||
})
|
||||
|
||||
it('fires all:complete after skipRemaining', () => {
|
||||
const q = new TaskQueue()
|
||||
const handler = vi.fn()
|
||||
q.on('all:complete', handler)
|
||||
|
||||
q.add(task('a'))
|
||||
q.add(task('b'))
|
||||
|
||||
q.complete('a', 'done')
|
||||
expect(handler).not.toHaveBeenCalled()
|
||||
|
||||
q.skipRemaining()
|
||||
expect(handler).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Orchestrator: onApproval integration
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
describe('onApproval integration', () => {
|
||||
function patchPool(orchestrator: OpenMultiAgent, agents: Map<string, Agent>) {
|
||||
;(orchestrator as any).buildPool = () => {
|
||||
const pool = new AgentPool(5)
|
||||
for (const [, agent] of agents) {
|
||||
pool.add(agent)
|
||||
}
|
||||
return pool
|
||||
}
|
||||
}
|
||||
|
||||
function setup(onApproval?: (tasks: readonly Task[], next: readonly Task[]) => Promise<boolean>) {
|
||||
const agentA: AgentConfig = { name: 'agent-a', model: 'mock', systemPrompt: 'You are agent A.' }
|
||||
const agentB: AgentConfig = { name: 'agent-b', model: 'mock', systemPrompt: 'You are agent B.' }
|
||||
|
||||
const orchestrator = new OpenMultiAgent({
|
||||
defaultModel: 'mock',
|
||||
...(onApproval ? { onApproval } : {}),
|
||||
})
|
||||
|
||||
const team = orchestrator.createTeam('test', {
|
||||
name: 'test',
|
||||
agents: [agentA, agentB],
|
||||
})
|
||||
|
||||
const mockAgents = new Map<string, Agent>()
|
||||
mockAgents.set('agent-a', buildMockAgent(agentA, 'result from A'))
|
||||
mockAgents.set('agent-b', buildMockAgent(agentB, 'result from B'))
|
||||
patchPool(orchestrator, mockAgents)
|
||||
|
||||
return { orchestrator, team }
|
||||
}
|
||||
|
||||
it('approve all — all tasks complete normally', async () => {
|
||||
const approvalSpy = vi.fn().mockResolvedValue(true)
|
||||
const { orchestrator, team } = setup(approvalSpy)
|
||||
|
||||
const result = await orchestrator.runTasks(team, [
|
||||
{ title: 'task-1', description: 'first', assignee: 'agent-a' },
|
||||
{ title: 'task-2', description: 'second', assignee: 'agent-b', dependsOn: ['task-1'] },
|
||||
])
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
expect(result.agentResults.has('agent-a')).toBe(true)
|
||||
expect(result.agentResults.has('agent-b')).toBe(true)
|
||||
// onApproval called once (between round 1 and round 2)
|
||||
expect(approvalSpy).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('reject mid-pipeline — remaining tasks skipped', async () => {
|
||||
const approvalSpy = vi.fn().mockResolvedValue(false)
|
||||
const { orchestrator, team } = setup(approvalSpy)
|
||||
|
||||
const result = await orchestrator.runTasks(team, [
|
||||
{ title: 'task-1', description: 'first', assignee: 'agent-a' },
|
||||
{ title: 'task-2', description: 'second', assignee: 'agent-b', dependsOn: ['task-1'] },
|
||||
])
|
||||
|
||||
expect(approvalSpy).toHaveBeenCalledTimes(1)
|
||||
// Only agent-a's output present (task-2 was skipped, never ran)
|
||||
expect(result.agentResults.has('agent-a')).toBe(true)
|
||||
expect(result.agentResults.has('agent-b')).toBe(false)
|
||||
})
|
||||
|
||||
it('no callback — tasks flow without interruption', async () => {
|
||||
const { orchestrator, team } = setup(/* no onApproval */)
|
||||
|
||||
const result = await orchestrator.runTasks(team, [
|
||||
{ title: 'task-1', description: 'first', assignee: 'agent-a' },
|
||||
{ title: 'task-2', description: 'second', assignee: 'agent-b', dependsOn: ['task-1'] },
|
||||
])
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
expect(result.agentResults.has('agent-a')).toBe(true)
|
||||
expect(result.agentResults.has('agent-b')).toBe(true)
|
||||
})
|
||||
|
||||
it('callback receives correct arguments — completedTasks array and nextTasks', async () => {
|
||||
const approvalSpy = vi.fn().mockResolvedValue(true)
|
||||
const { orchestrator, team } = setup(approvalSpy)
|
||||
|
||||
await orchestrator.runTasks(team, [
|
||||
{ title: 'task-1', description: 'first', assignee: 'agent-a' },
|
||||
{ title: 'task-2', description: 'second', assignee: 'agent-b', dependsOn: ['task-1'] },
|
||||
])
|
||||
|
||||
// First arg: array of completed tasks from this round
|
||||
const completedTasks = approvalSpy.mock.calls[0][0]
|
||||
expect(completedTasks).toHaveLength(1)
|
||||
expect(completedTasks[0].title).toBe('task-1')
|
||||
expect(completedTasks[0].status).toBe('completed')
|
||||
|
||||
// Second arg: the next tasks about to run
|
||||
const nextTasks = approvalSpy.mock.calls[0][1]
|
||||
expect(nextTasks).toHaveLength(1)
|
||||
expect(nextTasks[0].title).toBe('task-2')
|
||||
})
|
||||
|
||||
it('callback throwing an error skips remaining tasks gracefully', async () => {
|
||||
const approvalSpy = vi.fn().mockRejectedValue(new Error('network timeout'))
|
||||
const { orchestrator, team } = setup(approvalSpy)
|
||||
|
||||
// Should not throw — error is caught and remaining tasks are skipped
|
||||
const result = await orchestrator.runTasks(team, [
|
||||
{ title: 'task-1', description: 'first', assignee: 'agent-a' },
|
||||
{ title: 'task-2', description: 'second', assignee: 'agent-b', dependsOn: ['task-1'] },
|
||||
])
|
||||
|
||||
expect(approvalSpy).toHaveBeenCalledTimes(1)
|
||||
expect(result.agentResults.has('agent-a')).toBe(true)
|
||||
expect(result.agentResults.has('agent-b')).toBe(false)
|
||||
})
|
||||
|
||||
it('parallel batch — completedTasks contains all tasks from the round', async () => {
|
||||
const approvalSpy = vi.fn().mockResolvedValue(true)
|
||||
const agentA: AgentConfig = { name: 'agent-a', model: 'mock', systemPrompt: 'A' }
|
||||
const agentB: AgentConfig = { name: 'agent-b', model: 'mock', systemPrompt: 'B' }
|
||||
const agentC: AgentConfig = { name: 'agent-c', model: 'mock', systemPrompt: 'C' }
|
||||
|
||||
const orchestrator = new OpenMultiAgent({
|
||||
defaultModel: 'mock',
|
||||
onApproval: approvalSpy,
|
||||
})
|
||||
|
||||
const team = orchestrator.createTeam('test', {
|
||||
name: 'test',
|
||||
agents: [agentA, agentB, agentC],
|
||||
})
|
||||
|
||||
const mockAgents = new Map<string, Agent>()
|
||||
mockAgents.set('agent-a', buildMockAgent(agentA, 'A done'))
|
||||
mockAgents.set('agent-b', buildMockAgent(agentB, 'B done'))
|
||||
mockAgents.set('agent-c', buildMockAgent(agentC, 'C done'))
|
||||
patchPool(orchestrator, mockAgents)
|
||||
|
||||
// task-1 and task-2 are independent (run in parallel), task-3 depends on both
|
||||
await orchestrator.runTasks(team, [
|
||||
{ title: 'task-1', description: 'first', assignee: 'agent-a' },
|
||||
{ title: 'task-2', description: 'second', assignee: 'agent-b' },
|
||||
{ title: 'task-3', description: 'third', assignee: 'agent-c', dependsOn: ['task-1', 'task-2'] },
|
||||
])
|
||||
|
||||
// Approval called once between the parallel batch and task-3
|
||||
expect(approvalSpy).toHaveBeenCalledTimes(1)
|
||||
const completedTasks = approvalSpy.mock.calls[0][0] as Task[]
|
||||
// Both task-1 and task-2 completed in the same round
|
||||
expect(completedTasks).toHaveLength(2)
|
||||
const titles = completedTasks.map((t: Task) => t.title).sort()
|
||||
expect(titles).toEqual(['task-1', 'task-2'])
|
||||
})
|
||||
|
||||
it('single batch with no second round — callback never fires', async () => {
|
||||
const approvalSpy = vi.fn().mockResolvedValue(true)
|
||||
const { orchestrator, team } = setup(approvalSpy)
|
||||
|
||||
const result = await orchestrator.runTasks(team, [
|
||||
{ title: 'task-1', description: 'first', assignee: 'agent-a' },
|
||||
{ title: 'task-2', description: 'second', assignee: 'agent-b' },
|
||||
])
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
// No second round → callback never called
|
||||
expect(approvalSpy).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('mixed success/failure in batch — completedTasks only contains succeeded tasks', async () => {
|
||||
const approvalSpy = vi.fn().mockResolvedValue(true)
|
||||
const agentA: AgentConfig = { name: 'agent-a', model: 'mock', systemPrompt: 'A' }
|
||||
const agentB: AgentConfig = { name: 'agent-b', model: 'mock', systemPrompt: 'B' }
|
||||
const agentC: AgentConfig = { name: 'agent-c', model: 'mock', systemPrompt: 'C' }
|
||||
|
||||
const orchestrator = new OpenMultiAgent({
|
||||
defaultModel: 'mock',
|
||||
onApproval: approvalSpy,
|
||||
})
|
||||
|
||||
const team = orchestrator.createTeam('test', {
|
||||
name: 'test',
|
||||
agents: [agentA, agentB, agentC],
|
||||
})
|
||||
|
||||
const mockAgents = new Map<string, Agent>()
|
||||
mockAgents.set('agent-a', buildMockAgent(agentA, 'A done'))
|
||||
mockAgents.set('agent-b', buildMockAgent(agentB, 'B done'))
|
||||
mockAgents.set('agent-c', buildMockAgent(agentC, 'C done'))
|
||||
|
||||
// Patch buildPool so that pool.run for agent-b returns a failure result
|
||||
;(orchestrator as any).buildPool = () => {
|
||||
const pool = new AgentPool(5)
|
||||
for (const [, agent] of mockAgents) pool.add(agent)
|
||||
const originalRun = pool.run.bind(pool)
|
||||
pool.run = async (agentName: string, prompt: string, opts?: any) => {
|
||||
if (agentName === 'agent-b') {
|
||||
return {
|
||||
success: false,
|
||||
output: 'simulated failure',
|
||||
messages: [],
|
||||
tokenUsage: { input_tokens: 0, output_tokens: 0 },
|
||||
toolCalls: [],
|
||||
}
|
||||
}
|
||||
return originalRun(agentName, prompt, opts)
|
||||
}
|
||||
return pool
|
||||
}
|
||||
|
||||
// task-1 (success) and task-2 (fail) run in parallel, task-3 depends on task-1
|
||||
await orchestrator.runTasks(team, [
|
||||
{ title: 'task-1', description: 'first', assignee: 'agent-a' },
|
||||
{ title: 'task-2', description: 'second', assignee: 'agent-b' },
|
||||
{ title: 'task-3', description: 'third', assignee: 'agent-c', dependsOn: ['task-1'] },
|
||||
])
|
||||
|
||||
expect(approvalSpy).toHaveBeenCalledTimes(1)
|
||||
const completedTasks = approvalSpy.mock.calls[0][0] as Task[]
|
||||
// Only task-1 succeeded — task-2 failed, so it should not appear
|
||||
expect(completedTasks).toHaveLength(1)
|
||||
expect(completedTasks[0].title).toBe('task-1')
|
||||
expect(completedTasks[0].status).toBe('completed')
|
||||
})
|
||||
|
||||
it('onProgress receives task_skipped events when approval is rejected', async () => {
|
||||
const progressSpy = vi.fn()
|
||||
const agentA: AgentConfig = { name: 'agent-a', model: 'mock', systemPrompt: 'A' }
|
||||
const agentB: AgentConfig = { name: 'agent-b', model: 'mock', systemPrompt: 'B' }
|
||||
|
||||
const orchestrator = new OpenMultiAgent({
|
||||
defaultModel: 'mock',
|
||||
onApproval: vi.fn().mockResolvedValue(false),
|
||||
onProgress: progressSpy,
|
||||
})
|
||||
|
||||
const team = orchestrator.createTeam('test', {
|
||||
name: 'test',
|
||||
agents: [agentA, agentB],
|
||||
})
|
||||
|
||||
const mockAgents = new Map<string, Agent>()
|
||||
mockAgents.set('agent-a', buildMockAgent(agentA, 'A done'))
|
||||
mockAgents.set('agent-b', buildMockAgent(agentB, 'B done'))
|
||||
;(orchestrator as any).buildPool = () => {
|
||||
const pool = new AgentPool(5)
|
||||
for (const [, agent] of mockAgents) pool.add(agent)
|
||||
return pool
|
||||
}
|
||||
|
||||
await orchestrator.runTasks(team, [
|
||||
{ title: 'task-1', description: 'first', assignee: 'agent-a' },
|
||||
{ title: 'task-2', description: 'second', assignee: 'agent-b', dependsOn: ['task-1'] },
|
||||
])
|
||||
|
||||
const skippedEvents = progressSpy.mock.calls
|
||||
.map((c: any) => c[0])
|
||||
.filter((e: any) => e.type === 'task_skipped')
|
||||
|
||||
expect(skippedEvents).toHaveLength(1)
|
||||
expect(skippedEvents[0].data.status).toBe('skipped')
|
||||
})
|
||||
})
|
||||
Loading…
Reference in New Issue