feat(agent): add beforeRun / afterRun lifecycle hooks (#45)

* feat(agent): add beforeRun / afterRun lifecycle hooks (#31)

Add optional hook callbacks to AgentConfig for cross-cutting concerns
(guardrails, logging, token budgets) without modifying framework internals.

- beforeRun: receives prompt + agent config, can modify or throw to abort
- afterRun: receives AgentRunResult, can modify or throw to fail
- Works with all three execution modes: run(), prompt(), stream()
- 15 test cases covering modify, throw, async, composition, and history integrity

* fix(agent): preserve non-text content blocks in beforeRun hook

- applyHookContext now replaces only text blocks, keeping images and
  tool results intact (was silently stripping them)
- Use backward loop instead of reverse() + find() for efficiency
- Clarify JSDoc that only `prompt` is applied from hook return value
- Add test for mixed-content user messages

* fix(agent): address review feedback on beforeRun/afterRun hooks

- Normalize stream done event to always yield AgentRunResult
- Move transitionTo('completed') after afterRun to fix state ordering
- Strip hook functions from BeforeRunHookContext.agent to avoid self-references
- Pass originalPrompt to applyHookContext to avoid redundant message scan
- Clarify afterRun JSDoc: not called when the run throws
- Add tests: error-path skip, outputSchema+afterRun, ctx.agent shape, multi-turn hooks
This commit is contained in:
JackChen 2026-04-05 00:41:21 +08:00 committed by GitHub
parent 25b144acf3
commit a1ccbfea61
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 572 additions and 2 deletions

View File

@ -27,6 +27,7 @@ import type {
AgentConfig, AgentConfig,
AgentState, AgentState,
AgentRunResult, AgentRunResult,
BeforeRunHookContext,
LLMMessage, LLMMessage,
StreamEvent, StreamEvent,
TokenUsage, TokenUsage,
@ -278,6 +279,13 @@ export class Agent {
const agentStartMs = Date.now() const agentStartMs = Date.now()
try { try {
// --- beforeRun hook ---
if (this.config.beforeRun) {
const hookCtx = this.buildBeforeRunHookContext(messages)
const modified = await this.config.beforeRun(hookCtx)
this.applyHookContext(messages, modified, hookCtx.prompt)
}
const runner = await this.getRunner() const runner = await this.getRunner()
const internalOnMessage = (msg: LLMMessage) => { const internalOnMessage = (msg: LLMMessage) => {
this.state.messages.push(msg) this.state.messages.push(msg)
@ -296,18 +304,28 @@ export class Agent {
// --- Structured output validation --- // --- Structured output validation ---
if (this.config.outputSchema) { if (this.config.outputSchema) {
const validated = await this.validateStructuredOutput( let validated = await this.validateStructuredOutput(
messages, messages,
result, result,
runner, runner,
runOptions, runOptions,
) )
// --- afterRun hook ---
if (this.config.afterRun) {
validated = await this.config.afterRun(validated)
}
this.emitAgentTrace(callerOptions, agentStartMs, validated) this.emitAgentTrace(callerOptions, agentStartMs, validated)
return validated return validated
} }
let agentResult = this.toAgentRunResult(result, true)
// --- afterRun hook ---
if (this.config.afterRun) {
agentResult = await this.config.afterRun(agentResult)
}
this.transitionTo('completed') this.transitionTo('completed')
const agentResult = this.toAgentRunResult(result, true)
this.emitAgentTrace(callerOptions, agentStartMs, agentResult) this.emitAgentTrace(callerOptions, agentStartMs, agentResult)
return agentResult return agentResult
} catch (err) { } catch (err) {
@ -440,13 +458,27 @@ export class Agent {
this.transitionTo('running') this.transitionTo('running')
try { try {
// --- beforeRun hook ---
if (this.config.beforeRun) {
const hookCtx = this.buildBeforeRunHookContext(messages)
const modified = await this.config.beforeRun(hookCtx)
this.applyHookContext(messages, modified, hookCtx.prompt)
}
const runner = await this.getRunner() const runner = await this.getRunner()
for await (const event of runner.stream(messages)) { for await (const event of runner.stream(messages)) {
if (event.type === 'done') { if (event.type === 'done') {
const result = event.data as import('./runner.js').RunResult const result = event.data as import('./runner.js').RunResult
this.state.tokenUsage = addUsage(this.state.tokenUsage, result.tokenUsage) this.state.tokenUsage = addUsage(this.state.tokenUsage, result.tokenUsage)
let agentResult = this.toAgentRunResult(result, true)
if (this.config.afterRun) {
agentResult = await this.config.afterRun(agentResult)
}
this.transitionTo('completed') this.transitionTo('completed')
yield { type: 'done', data: agentResult } satisfies StreamEvent
continue
} else if (event.type === 'error') { } else if (event.type === 'error') {
const error = event.data instanceof Error const error = event.data instanceof Error
? event.data ? event.data
@ -463,6 +495,50 @@ export class Agent {
} }
} }
// -------------------------------------------------------------------------
// Hook helpers
// -------------------------------------------------------------------------
/** Extract the prompt text from the last user message to build hook context. */
private buildBeforeRunHookContext(messages: LLMMessage[]): BeforeRunHookContext {
let prompt = ''
for (let i = messages.length - 1; i >= 0; i--) {
if (messages[i]!.role === 'user') {
prompt = messages[i]!.content
.filter((b): b is import('../types.js').TextBlock => b.type === 'text')
.map(b => b.text)
.join('')
break
}
}
// Strip hook functions to avoid circular self-references in the context
const { beforeRun, afterRun, ...agentInfo } = this.config
return { prompt, agent: agentInfo as AgentConfig }
}
/**
* Apply a (possibly modified) hook context back to the messages array.
*
* Only text blocks in the last user message are replaced; non-text content
* (images, tool results) is preserved. The array element is replaced (not
* mutated in place) so that shallow copies of the original array (e.g. from
* `prompt()`) are not affected.
*/
private applyHookContext(messages: LLMMessage[], ctx: BeforeRunHookContext, originalPrompt: string): void {
if (ctx.prompt === originalPrompt) return
for (let i = messages.length - 1; i >= 0; i--) {
if (messages[i]!.role === 'user') {
const nonTextBlocks = messages[i]!.content.filter(b => b.type !== 'text')
messages[i] = {
role: 'user',
content: [{ type: 'text', text: ctx.prompt }, ...nonTextBlocks],
}
break
}
}
}
// ------------------------------------------------------------------------- // -------------------------------------------------------------------------
// State transition helpers // State transition helpers
// ------------------------------------------------------------------------- // -------------------------------------------------------------------------

View File

@ -147,6 +147,7 @@ export type {
AgentConfig, AgentConfig,
AgentState, AgentState,
AgentRunResult, AgentRunResult,
BeforeRunHookContext,
ToolCallRecord, ToolCallRecord,
// Team // Team

View File

@ -182,6 +182,14 @@ export interface ToolDefinition<TInput = Record<string, unknown>> {
// Agent // Agent
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
/** Context passed to the {@link AgentConfig.beforeRun} hook. */
export interface BeforeRunHookContext {
  /** The user prompt text (all text blocks of the last user message, concatenated). */
  readonly prompt: string
  /**
   * The agent's static configuration, informational only. The `beforeRun` /
   * `afterRun` functions are stripped from it to avoid self-references.
   */
  readonly agent: AgentConfig
}
/** Static configuration for a single agent. */ /** Static configuration for a single agent. */
export interface AgentConfig { export interface AgentConfig {
readonly name: string readonly name: string
@ -207,6 +215,18 @@ export interface AgentConfig {
* retry with error feedback is attempted on validation failure. * retry with error feedback is attempted on validation failure.
*/ */
readonly outputSchema?: ZodSchema readonly outputSchema?: ZodSchema
/**
* Called before each agent run. Receives the prompt and agent config.
* Return a (possibly modified) context to continue, or throw to abort the run.
* Only `prompt` from the returned context is applied; `agent` is read-only informational.
*/
readonly beforeRun?: (context: BeforeRunHookContext) => Promise<BeforeRunHookContext> | BeforeRunHookContext
/**
* Called after each agent run completes successfully. Receives the run result.
* Return a (possibly modified) result, or throw to mark the run as failed.
* Not called when the run throws. For error observation, handle errors at the call site.
*/
readonly afterRun?: (result: AgentRunResult) => Promise<AgentRunResult> | AgentRunResult
} }
/** Lifecycle state tracked during an agent run. */ /** Lifecycle state tracked during an agent run. */

473
tests/agent-hooks.test.ts Normal file
View File

@ -0,0 +1,473 @@
import { describe, it, expect, vi } from 'vitest'
import { z } from 'zod'
import { Agent } from '../src/agent/agent.js'
import { AgentRunner } from '../src/agent/runner.js'
import { ToolRegistry } from '../src/tool/framework.js'
import { ToolExecutor } from '../src/tool/executor.js'
import type { AgentConfig, AgentRunResult, LLMAdapter, LLMMessage, LLMResponse } from '../src/types.js'
// ---------------------------------------------------------------------------
// Mock helpers
// ---------------------------------------------------------------------------
/**
 * Build a stub LLM adapter whose `chat()` always answers with the given
 * fixed text, recording a snapshot of the messages from every call so tests
 * can inspect exactly what was sent to the model.
 */
function mockAdapter(responseText: string) {
  const calls: LLMMessage[][] = []
  const adapter: LLMAdapter = {
    name: 'mock',
    chat: async (messages) => {
      // Snapshot the array so later mutations by the agent don't leak in.
      calls.push(messages.slice())
      const response = {
        id: 'mock-1',
        content: [{ type: 'text' as const, text: responseText }],
        model: 'mock-model',
        stop_reason: 'end_turn',
        usage: { input_tokens: 10, output_tokens: 20 },
      } satisfies LLMResponse
      return response
    },
    stream: async function* () {
      /* unused */
    },
  }
  return { adapter, calls }
}
/** Build an Agent with a mocked LLM, bypassing createAdapter. */
function buildMockAgent(config: AgentConfig, responseText: string) {
const { adapter, calls } = mockAdapter(responseText)
const registry = new ToolRegistry()
const executor = new ToolExecutor(registry)
const agent = new Agent(config, registry, executor)
const runner = new AgentRunner(adapter, registry, executor, {
model: config.model,
systemPrompt: config.systemPrompt,
maxTurns: config.maxTurns,
maxTokens: config.maxTokens,
temperature: config.temperature,
agentName: config.name,
})
;(agent as any).runner = runner
return { agent, calls }
}
// Minimal shared config; individual tests spread in their own hooks.
const baseConfig: AgentConfig = {
  name: 'test-agent',
  model: 'mock-model',
  systemPrompt: 'You are a test agent.',
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
// End-to-end coverage of the beforeRun/afterRun lifecycle hooks across all
// three execution modes: run(), prompt(), and stream().
describe('Agent hooks — beforeRun / afterRun', () => {
  // -----------------------------------------------------------------------
  // Baseline — no hooks
  // -----------------------------------------------------------------------
  it('works normally without hooks', async () => {
    const { agent } = buildMockAgent(baseConfig, 'hello')
    const result = await agent.run('ping')
    expect(result.success).toBe(true)
    expect(result.output).toBe('hello')
  })

  // -----------------------------------------------------------------------
  // beforeRun
  // -----------------------------------------------------------------------
  it('beforeRun can modify the prompt', async () => {
    const config: AgentConfig = {
      ...baseConfig,
      beforeRun: (ctx) => ({ ...ctx, prompt: 'modified prompt' }),
    }
    const { agent, calls } = buildMockAgent(config, 'response')
    await agent.run('original prompt')
    // The adapter should have received the modified prompt.
    const lastUserMsg = calls[0]!.find(m => m.role === 'user')
    const textBlock = lastUserMsg!.content.find(b => b.type === 'text')
    expect((textBlock as any).text).toBe('modified prompt')
  })

  it('beforeRun that returns context unchanged does not alter prompt', async () => {
    const config: AgentConfig = {
      ...baseConfig,
      beforeRun: (ctx) => ctx,
    }
    const { agent, calls } = buildMockAgent(config, 'response')
    await agent.run('keep this')
    const lastUserMsg = calls[0]!.find(m => m.role === 'user')
    const textBlock = lastUserMsg!.content.find(b => b.type === 'text')
    expect((textBlock as any).text).toBe('keep this')
  })

  it('beforeRun throwing aborts the run with failure', async () => {
    const config: AgentConfig = {
      ...baseConfig,
      beforeRun: () => { throw new Error('budget exceeded') },
    }
    const { agent, calls } = buildMockAgent(config, 'should not reach')
    const result = await agent.run('hi')
    expect(result.success).toBe(false)
    expect(result.output).toContain('budget exceeded')
    // No LLM call should have been made.
    expect(calls).toHaveLength(0)
  })

  it('async beforeRun works', async () => {
    const config: AgentConfig = {
      ...baseConfig,
      beforeRun: async (ctx) => {
        await Promise.resolve()
        return { ...ctx, prompt: 'async modified' }
      },
    }
    const { agent, calls } = buildMockAgent(config, 'ok')
    await agent.run('original')
    const lastUserMsg = calls[0]!.find(m => m.role === 'user')
    const textBlock = lastUserMsg!.content.find(b => b.type === 'text')
    expect((textBlock as any).text).toBe('async modified')
  })

  // -----------------------------------------------------------------------
  // afterRun
  // -----------------------------------------------------------------------
  it('afterRun can modify the result', async () => {
    const config: AgentConfig = {
      ...baseConfig,
      afterRun: (result) => ({ ...result, output: 'modified output' }),
    }
    const { agent } = buildMockAgent(config, 'original output')
    const result = await agent.run('hi')
    expect(result.success).toBe(true)
    expect(result.output).toBe('modified output')
  })

  it('afterRun throwing marks run as failed', async () => {
    const config: AgentConfig = {
      ...baseConfig,
      afterRun: () => { throw new Error('content violation') },
    }
    const { agent } = buildMockAgent(config, 'bad content')
    const result = await agent.run('hi')
    expect(result.success).toBe(false)
    expect(result.output).toContain('content violation')
  })

  it('async afterRun works', async () => {
    const config: AgentConfig = {
      ...baseConfig,
      afterRun: async (result) => {
        await Promise.resolve()
        return { ...result, output: result.output.toUpperCase() }
      },
    }
    const { agent } = buildMockAgent(config, 'hello')
    const result = await agent.run('hi')
    expect(result.output).toBe('HELLO')
  })

  // -----------------------------------------------------------------------
  // Both hooks together
  // -----------------------------------------------------------------------
  it('beforeRun and afterRun compose correctly', async () => {
    const hookOrder: string[] = []
    const config: AgentConfig = {
      ...baseConfig,
      beforeRun: (ctx) => {
        hookOrder.push('before')
        return { ...ctx, prompt: 'injected prompt' }
      },
      afterRun: (result) => {
        hookOrder.push('after')
        return { ...result, output: `[processed] ${result.output}` }
      },
    }
    const { agent, calls } = buildMockAgent(config, 'raw output')
    const result = await agent.run('original')
    expect(hookOrder).toEqual(['before', 'after'])
    const lastUserMsg = calls[0]!.find(m => m.role === 'user')
    const textBlock = lastUserMsg!.content.find(b => b.type === 'text')
    expect((textBlock as any).text).toBe('injected prompt')
    expect(result.output).toBe('[processed] raw output')
  })

  // -----------------------------------------------------------------------
  // prompt() multi-turn mode
  // -----------------------------------------------------------------------
  it('hooks fire on prompt() calls', async () => {
    const beforeSpy = vi.fn((ctx) => ctx)
    const afterSpy = vi.fn((result) => result)
    const config: AgentConfig = {
      ...baseConfig,
      beforeRun: beforeSpy,
      afterRun: afterSpy,
    }
    const { agent } = buildMockAgent(config, 'reply')
    await agent.prompt('hello')
    expect(beforeSpy).toHaveBeenCalledOnce()
    expect(afterSpy).toHaveBeenCalledOnce()
    expect(beforeSpy.mock.calls[0]![0].prompt).toBe('hello')
  })

  // -----------------------------------------------------------------------
  // stream() mode
  // -----------------------------------------------------------------------
  it('beforeRun fires in stream mode', async () => {
    const config: AgentConfig = {
      ...baseConfig,
      beforeRun: (ctx) => ({ ...ctx, prompt: 'stream modified' }),
    }
    const { agent, calls } = buildMockAgent(config, 'streamed')
    const events = []
    for await (const event of agent.stream('original')) {
      events.push(event)
    }
    const lastUserMsg = calls[0]!.find(m => m.role === 'user')
    const textBlock = lastUserMsg!.content.find(b => b.type === 'text')
    expect((textBlock as any).text).toBe('stream modified')
    // Should have at least a text event and a done event.
    expect(events.some(e => e.type === 'done')).toBe(true)
  })

  it('afterRun fires in stream mode and modifies done event', async () => {
    const config: AgentConfig = {
      ...baseConfig,
      afterRun: (result) => ({ ...result, output: 'stream modified output' }),
    }
    const { agent } = buildMockAgent(config, 'original')
    const events = []
    for await (const event of agent.stream('hi')) {
      events.push(event)
    }
    const doneEvent = events.find(e => e.type === 'done')
    expect(doneEvent).toBeDefined()
    expect((doneEvent!.data as AgentRunResult).output).toBe('stream modified output')
  })

  it('beforeRun throwing in stream mode yields error event', async () => {
    const config: AgentConfig = {
      ...baseConfig,
      beforeRun: () => { throw new Error('stream abort') },
    }
    const { agent } = buildMockAgent(config, 'unreachable')
    const events = []
    for await (const event of agent.stream('hi')) {
      events.push(event)
    }
    const errorEvent = events.find(e => e.type === 'error')
    expect(errorEvent).toBeDefined()
    expect((errorEvent!.data as Error).message).toContain('stream abort')
  })

  it('afterRun throwing in stream mode yields error event', async () => {
    const config: AgentConfig = {
      ...baseConfig,
      afterRun: () => { throw new Error('stream content violation') },
    }
    const { agent } = buildMockAgent(config, 'streamed output')
    const events = []
    for await (const event of agent.stream('hi')) {
      events.push(event)
    }
    // Text events may have been yielded before the error.
    const errorEvent = events.find(e => e.type === 'error')
    expect(errorEvent).toBeDefined()
    expect((errorEvent!.data as Error).message).toContain('stream content violation')
    // No done event should be present since afterRun rejected it.
    expect(events.find(e => e.type === 'done')).toBeUndefined()
  })

  // -----------------------------------------------------------------------
  // prompt() history integrity
  // -----------------------------------------------------------------------
  it('beforeRun modifying prompt preserves non-text content blocks', async () => {
    // Simulate a multi-turn message where the last user message has mixed content
    // (text + tool_result). beforeRun should only replace text, not strip other blocks.
    const config: AgentConfig = {
      ...baseConfig,
      beforeRun: (ctx) => ({ ...ctx, prompt: 'modified' }),
    }
    const { adapter, calls } = mockAdapter('ok')
    const registry = new ToolRegistry()
    const executor = new ToolExecutor(registry)
    const agent = new Agent(config, registry, executor)
    const runner = new AgentRunner(adapter, registry, executor, {
      model: config.model,
      agentName: config.name,
    })
    ;(agent as any).runner = runner
    // Directly call run which creates a single text-only user message.
    // To test mixed content, we need to go through the private executeRun.
    // Instead, we test via prompt() after injecting history with mixed content.
    ;(agent as any).messageHistory = [
      {
        role: 'user' as const,
        content: [
          { type: 'text' as const, text: 'original' },
          { type: 'image' as const, source: { type: 'base64' as const, media_type: 'image/png', data: 'abc' } },
        ],
      },
    ]
    // prompt() appends a new user message then calls executeRun with full history
    await agent.prompt('follow up')
    // The last user message sent to the LLM should have modified text
    const sentMessages = calls[0]!
    const lastUser = [...sentMessages].reverse().find(m => m.role === 'user')!
    const textBlock = lastUser.content.find(b => b.type === 'text')
    expect((textBlock as any).text).toBe('modified')
    // The earlier user message (with the image) should be untouched
    const firstUser = sentMessages.find(m => m.role === 'user')!
    const imageBlock = firstUser.content.find(b => b.type === 'image')
    expect(imageBlock).toBeDefined()
  })

  it('beforeRun modifying prompt does not corrupt messageHistory', async () => {
    const config: AgentConfig = {
      ...baseConfig,
      beforeRun: (ctx) => ({ ...ctx, prompt: 'hook-modified' }),
    }
    const { agent, calls } = buildMockAgent(config, 'reply')
    await agent.prompt('original message')
    // The LLM should have received the modified prompt.
    const lastUserMsg = calls[0]!.find(m => m.role === 'user')
    expect((lastUserMsg!.content[0] as any).text).toBe('hook-modified')
    // But the persistent history should retain the original message.
    const history = agent.getHistory()
    const firstUserInHistory = history.find(m => m.role === 'user')
    expect((firstUserInHistory!.content[0] as any).text).toBe('original message')
  })

  // -----------------------------------------------------------------------
  // afterRun NOT called on error
  // -----------------------------------------------------------------------
  it('afterRun is not called when executeRun throws', async () => {
    const afterSpy = vi.fn((result) => result)
    const config: AgentConfig = {
      ...baseConfig,
      // Use beforeRun to trigger an error inside executeRun's try block,
      // before afterRun would normally run.
      beforeRun: () => { throw new Error('rejected by policy') },
      afterRun: afterSpy,
    }
    const { agent } = buildMockAgent(config, 'should not reach')
    const result = await agent.run('hi')
    expect(result.success).toBe(false)
    expect(result.output).toContain('rejected by policy')
    expect(afterSpy).not.toHaveBeenCalled()
  })

  // -----------------------------------------------------------------------
  // outputSchema + afterRun
  // -----------------------------------------------------------------------
  it('afterRun fires after structured output validation', async () => {
    const schema = z.object({ answer: z.string() })
    const config: AgentConfig = {
      ...baseConfig,
      outputSchema: schema,
      afterRun: (result) => ({ ...result, output: '[post-processed] ' + result.output }),
    }
    // Return valid JSON matching the schema
    const { agent } = buildMockAgent(config, '{"answer":"42"}')
    const result = await agent.run('what is the answer?')
    expect(result.success).toBe(true)
    expect(result.output).toBe('[post-processed] {"answer":"42"}')
    expect(result.structured).toEqual({ answer: '42' })
  })

  // -----------------------------------------------------------------------
  // ctx.agent does not contain hook self-references
  // -----------------------------------------------------------------------
  it('beforeRun context.agent has correct config without hook self-references', async () => {
    let receivedAgent: AgentConfig | undefined
    const config: AgentConfig = {
      ...baseConfig,
      beforeRun: (ctx) => {
        receivedAgent = ctx.agent
        return ctx
      },
    }
    const { agent } = buildMockAgent(config, 'ok')
    await agent.run('test')
    expect(receivedAgent).toBeDefined()
    expect(receivedAgent!.name).toBe('test-agent')
    expect(receivedAgent!.model).toBe('mock-model')
    // Hook functions should be stripped to avoid circular references
    expect(receivedAgent!.beforeRun).toBeUndefined()
    expect(receivedAgent!.afterRun).toBeUndefined()
  })

  // -----------------------------------------------------------------------
  // Multiple prompt() turns fire hooks each time
  // -----------------------------------------------------------------------
  it('hooks fire on every prompt() call', async () => {
    const beforeSpy = vi.fn((ctx) => ctx)
    const afterSpy = vi.fn((result) => result)
    const config: AgentConfig = {
      ...baseConfig,
      beforeRun: beforeSpy,
      afterRun: afterSpy,
    }
    const { agent } = buildMockAgent(config, 'reply')
    await agent.prompt('turn 1')
    await agent.prompt('turn 2')
    expect(beforeSpy).toHaveBeenCalledTimes(2)
    expect(afterSpy).toHaveBeenCalledTimes(2)
    expect(beforeSpy.mock.calls[0]![0].prompt).toBe('turn 1')
    expect(beforeSpy.mock.calls[1]![0].prompt).toBe('turn 2')
  })
})