From a1ccbfea6125e302a51b99daef17dbe124b07965 Mon Sep 17 00:00:00 2001 From: JackChen <26346076+JackChen-me@users.noreply.github.com> Date: Sun, 5 Apr 2026 00:41:21 +0800 Subject: [PATCH] feat(agent): add beforeRun / afterRun lifecycle hooks (#45) * feat(agent): add beforeRun / afterRun lifecycle hooks (#31) Add optional hook callbacks to AgentConfig for cross-cutting concerns (guardrails, logging, token budgets) without modifying framework internals. - beforeRun: receives prompt + agent config, can modify or throw to abort - afterRun: receives AgentRunResult, can modify or throw to fail - Works with all three execution modes: run(), prompt(), stream() - 15 test cases covering modify, throw, async, composition, and history integrity * fix(agent): preserve non-text content blocks in beforeRun hook - applyHookContext now replaces only text blocks, keeping images and tool results intact (was silently stripping them) - Use backward loop instead of reverse() + find() for efficiency - Clarify JSDoc that only `prompt` is applied from hook return value - Add test for mixed-content user messages * fix(agent): address review feedback on beforeRun/afterRun hooks - Normalize stream done event to always yield AgentRunResult - Move transitionTo('completed') after afterRun to fix state ordering - Strip hook functions from BeforeRunHookContext.agent to avoid self-references - Pass originalPrompt to applyHookContext to avoid redundant message scan - Clarify afterRun JSDoc: not called when the run throws - Add tests: error-path skip, outputSchema+afterRun, ctx.agent shape, multi-turn hooks --- src/agent/agent.ts | 80 ++++++- src/index.ts | 1 + src/types.ts | 20 ++ tests/agent-hooks.test.ts | 473 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 572 insertions(+), 2 deletions(-) create mode 100644 tests/agent-hooks.test.ts diff --git a/src/agent/agent.ts b/src/agent/agent.ts index 58a1df3..caf5a9c 100644 --- a/src/agent/agent.ts +++ b/src/agent/agent.ts @@ -27,6 +27,7 @@ import type { AgentConfig, AgentState, AgentRunResult, + BeforeRunHookContext, LLMMessage, StreamEvent, TokenUsage, @@ -278,6 +279,13 @@ export class Agent { const agentStartMs = Date.now() try { + // --- beforeRun hook --- + if (this.config.beforeRun) { + const hookCtx = this.buildBeforeRunHookContext(messages) + const modified = await this.config.beforeRun(hookCtx) + this.applyHookContext(messages, modified, hookCtx.prompt) + } + const runner = await this.getRunner() const internalOnMessage = (msg: LLMMessage) => { this.state.messages.push(msg) @@ -296,18 +304,28 @@ export class Agent { // --- Structured output validation --- if (this.config.outputSchema) { - const validated = await this.validateStructuredOutput( + let validated = await this.validateStructuredOutput( messages, result, runner, runOptions, ) + // --- afterRun hook --- + if (this.config.afterRun) { + validated = await this.config.afterRun(validated) + } this.emitAgentTrace(callerOptions, agentStartMs, validated) return validated } + let agentResult = this.toAgentRunResult(result, true) + + // --- afterRun hook --- + if (this.config.afterRun) { + agentResult = await this.config.afterRun(agentResult) + } + this.transitionTo('completed') - const agentResult = this.toAgentRunResult(result, true) this.emitAgentTrace(callerOptions, agentStartMs, agentResult) return agentResult } catch (err) { @@ -440,13 +458,27 @@ export class Agent { this.transitionTo('running') try { + // --- beforeRun hook --- + if (this.config.beforeRun) { + const hookCtx = this.buildBeforeRunHookContext(messages) + const modified = await this.config.beforeRun(hookCtx) + this.applyHookContext(messages, modified, hookCtx.prompt) + } + const runner = await this.getRunner() for await (const event of runner.stream(messages)) { if (event.type === 'done') { const result = event.data as import('./runner.js').RunResult this.state.tokenUsage = addUsage(this.state.tokenUsage, result.tokenUsage) + + let agentResult = this.toAgentRunResult(result, true) + if (this.config.afterRun) { + agentResult = await this.config.afterRun(agentResult) + } this.transitionTo('completed') + yield { type: 'done', data: agentResult } satisfies StreamEvent + continue } else if (event.type === 'error') { const error = event.data instanceof Error ? event.data @@ -463,6 +495,50 @@ export class Agent { } } + // ------------------------------------------------------------------------- + // Hook helpers + // ------------------------------------------------------------------------- + + /** Extract the prompt text from the last user message to build hook context. */ + private buildBeforeRunHookContext(messages: LLMMessage[]): BeforeRunHookContext { + let prompt = '' + for (let i = messages.length - 1; i >= 0; i--) { + if (messages[i]!.role === 'user') { + prompt = messages[i]!.content + .filter((b): b is import('../types.js').TextBlock => b.type === 'text') + .map(b => b.text) + .join('') + break + } + } + // Strip hook functions to avoid circular self-references in the context + const { beforeRun, afterRun, ...agentInfo } = this.config + return { prompt, agent: agentInfo as AgentConfig } + } + + /** + * Apply a (possibly modified) hook context back to the messages array. + * + * Only text blocks in the last user message are replaced; non-text content + * (images, tool results) is preserved. The array element is replaced (not + * mutated in place) so that shallow copies of the original array (e.g. from + * `prompt()`) are not affected. + */ + private applyHookContext(messages: LLMMessage[], ctx: BeforeRunHookContext, originalPrompt: string): void { + if (ctx.prompt === originalPrompt) return + + for (let i = messages.length - 1; i >= 0; i--) { + if (messages[i]!.role === 'user') { + const nonTextBlocks = messages[i]!.content.filter(b => b.type !== 'text') + messages[i] = { + role: 'user', + content: [{ type: 'text', text: ctx.prompt }, ...nonTextBlocks], + } + break + } + } + } + // ------------------------------------------------------------------------- // State transition helpers // ------------------------------------------------------------------------- diff --git a/src/index.ts b/src/index.ts index 312f852..cc0fff3 100644 --- a/src/index.ts +++ b/src/index.ts @@ -147,6 +147,7 @@ export type { AgentConfig, AgentState, AgentRunResult, + BeforeRunHookContext, ToolCallRecord, // Team diff --git a/src/types.ts b/src/types.ts index 0989df4..6695d81 100644 --- a/src/types.ts +++ b/src/types.ts @@ -182,6 +182,14 @@ export interface ToolDefinition> { // Agent // --------------------------------------------------------------------------- +/** Context passed to the {@link AgentConfig.beforeRun} hook. */ +export interface BeforeRunHookContext { + /** The user prompt text. */ + readonly prompt: string + /** The agent's static configuration. */ + readonly agent: AgentConfig +} + /** Static configuration for a single agent. */ export interface AgentConfig { readonly name: string @@ -207,6 +215,18 @@ export interface AgentConfig { * retry with error feedback is attempted on validation failure. */ readonly outputSchema?: ZodSchema + /** + * Called before each agent run. Receives the prompt and agent config. + * Return a (possibly modified) context to continue, or throw to abort the run. + * Only `prompt` from the returned context is applied; `agent` is read-only informational. + */ + readonly beforeRun?: (context: BeforeRunHookContext) => Promise | BeforeRunHookContext + /** + * Called after each agent run completes successfully. Receives the run result. + * Return a (possibly modified) result, or throw to mark the run as failed. + * Not called when the run throws. For error observation, handle errors at the call site. + */ + readonly afterRun?: (result: AgentRunResult) => Promise | AgentRunResult } /** Lifecycle state tracked during an agent run. */ diff --git a/tests/agent-hooks.test.ts b/tests/agent-hooks.test.ts new file mode 100644 index 0000000..13044a3 --- /dev/null +++ b/tests/agent-hooks.test.ts @@ -0,0 +1,473 @@ +import { describe, it, expect, vi } from 'vitest' +import { z } from 'zod' +import { Agent } from '../src/agent/agent.js' +import { AgentRunner } from '../src/agent/runner.js' +import { ToolRegistry } from '../src/tool/framework.js' +import { ToolExecutor } from '../src/tool/executor.js' +import type { AgentConfig, AgentRunResult, LLMAdapter, LLMMessage, LLMResponse } from '../src/types.js' + +// --------------------------------------------------------------------------- +// Mock helpers +// --------------------------------------------------------------------------- + +/** + * Create a mock adapter that records every `chat()` call's messages + * and returns a fixed text response. + */ +function mockAdapter(responseText: string) { + const calls: LLMMessage[][] = [] + const adapter: LLMAdapter = { + name: 'mock', + async chat(messages) { + calls.push([...messages]) + return { + id: 'mock-1', + content: [{ type: 'text' as const, text: responseText }], + model: 'mock-model', + stop_reason: 'end_turn', + usage: { input_tokens: 10, output_tokens: 20 }, + } satisfies LLMResponse + }, + async *stream() { + /* unused */ + }, + } + return { adapter, calls } +} + +/** Build an Agent with a mocked LLM, bypassing createAdapter. */ +function buildMockAgent(config: AgentConfig, responseText: string) { + const { adapter, calls } = mockAdapter(responseText) + const registry = new ToolRegistry() + const executor = new ToolExecutor(registry) + const agent = new Agent(config, registry, executor) + + const runner = new AgentRunner(adapter, registry, executor, { + model: config.model, + systemPrompt: config.systemPrompt, + maxTurns: config.maxTurns, + maxTokens: config.maxTokens, + temperature: config.temperature, + agentName: config.name, + }) + ;(agent as any).runner = runner + + return { agent, calls } +} + +const baseConfig: AgentConfig = { + name: 'test-agent', + model: 'mock-model', + systemPrompt: 'You are a test agent.', +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +describe('Agent hooks — beforeRun / afterRun', () => { + // ----------------------------------------------------------------------- + // Baseline — no hooks + // ----------------------------------------------------------------------- + + it('works normally without hooks', async () => { + const { agent } = buildMockAgent(baseConfig, 'hello') + const result = await agent.run('ping') + + expect(result.success).toBe(true) + expect(result.output).toBe('hello') + }) + + // ----------------------------------------------------------------------- + // beforeRun + // ----------------------------------------------------------------------- + + it('beforeRun can modify the prompt', async () => { + const config: AgentConfig = { + ...baseConfig, + beforeRun: (ctx) => ({ ...ctx, prompt: 'modified prompt' }), + } + const { agent, calls } = buildMockAgent(config, 'response') + await agent.run('original prompt') + + // The adapter should have received the modified prompt. + const lastUserMsg = calls[0]!.find(m => m.role === 'user') + const textBlock = lastUserMsg!.content.find(b => b.type === 'text') + expect((textBlock as any).text).toBe('modified prompt') + }) + + it('beforeRun that returns context unchanged does not alter prompt', async () => { + const config: AgentConfig = { + ...baseConfig, + beforeRun: (ctx) => ctx, + } + const { agent, calls } = buildMockAgent(config, 'response') + await agent.run('keep this') + + const lastUserMsg = calls[0]!.find(m => m.role === 'user') + const textBlock = lastUserMsg!.content.find(b => b.type === 'text') + expect((textBlock as any).text).toBe('keep this') + }) + + it('beforeRun throwing aborts the run with failure', async () => { + const config: AgentConfig = { + ...baseConfig, + beforeRun: () => { throw new Error('budget exceeded') }, + } + const { agent, calls } = buildMockAgent(config, 'should not reach') + const result = await agent.run('hi') + + expect(result.success).toBe(false) + expect(result.output).toContain('budget exceeded') + // No LLM call should have been made. + expect(calls).toHaveLength(0) + }) + + it('async beforeRun works', async () => { + const config: AgentConfig = { + ...baseConfig, + beforeRun: async (ctx) => { + await Promise.resolve() + return { ...ctx, prompt: 'async modified' } + }, + } + const { agent, calls } = buildMockAgent(config, 'ok') + await agent.run('original') + + const lastUserMsg = calls[0]!.find(m => m.role === 'user') + const textBlock = lastUserMsg!.content.find(b => b.type === 'text') + expect((textBlock as any).text).toBe('async modified') + }) + + // ----------------------------------------------------------------------- + // afterRun + // ----------------------------------------------------------------------- + + it('afterRun can modify the result', async () => { + const config: AgentConfig = { + ...baseConfig, + afterRun: (result) => ({ ...result, output: 'modified output' }), + } + const { agent } = buildMockAgent(config, 'original output') + const result = await agent.run('hi') + + expect(result.success).toBe(true) + expect(result.output).toBe('modified output') + }) + + it('afterRun throwing marks run as failed', async () => { + const config: AgentConfig = { + ...baseConfig, + afterRun: () => { throw new Error('content violation') }, + } + const { agent } = buildMockAgent(config, 'bad content') + const result = await agent.run('hi') + + expect(result.success).toBe(false) + expect(result.output).toContain('content violation') + }) + + it('async afterRun works', async () => { + const config: AgentConfig = { + ...baseConfig, + afterRun: async (result) => { + await Promise.resolve() + return { ...result, output: result.output.toUpperCase() } + }, + } + const { agent } = buildMockAgent(config, 'hello') + const result = await agent.run('hi') + + expect(result.output).toBe('HELLO') + }) + + // ----------------------------------------------------------------------- + // Both hooks together + // ----------------------------------------------------------------------- + + it('beforeRun and afterRun compose correctly', async () => { + const hookOrder: string[] = [] + + const config: AgentConfig = { + ...baseConfig, + beforeRun: (ctx) => { + hookOrder.push('before') + return { ...ctx, prompt: 'injected prompt' } + }, + afterRun: (result) => { + hookOrder.push('after') + return { ...result, output: `[processed] ${result.output}` } + }, + } + const { agent, calls } = buildMockAgent(config, 'raw output') + const result = await agent.run('original') + + expect(hookOrder).toEqual(['before', 'after']) + + const lastUserMsg = calls[0]!.find(m => m.role === 'user') + const textBlock = lastUserMsg!.content.find(b => b.type === 'text') + expect((textBlock as any).text).toBe('injected prompt') + + expect(result.output).toBe('[processed] raw output') + }) + + // ----------------------------------------------------------------------- + // prompt() multi-turn mode + // ----------------------------------------------------------------------- + + it('hooks fire on prompt() calls', async () => { + const beforeSpy = vi.fn((ctx) => ctx) + const afterSpy = vi.fn((result) => result) + + const config: AgentConfig = { + ...baseConfig, + beforeRun: beforeSpy, + afterRun: afterSpy, + } + const { agent } = buildMockAgent(config, 'reply') + await agent.prompt('hello') + + expect(beforeSpy).toHaveBeenCalledOnce() + expect(afterSpy).toHaveBeenCalledOnce() + expect(beforeSpy.mock.calls[0]![0].prompt).toBe('hello') + }) + + // ----------------------------------------------------------------------- + // stream() mode + // ----------------------------------------------------------------------- + + it('beforeRun fires in stream mode', async () => { + const config: AgentConfig = { + ...baseConfig, + beforeRun: (ctx) => ({ ...ctx, prompt: 'stream modified' }), + } + const { agent, calls } = buildMockAgent(config, 'streamed') + + const events = [] + for await (const event of agent.stream('original')) { + events.push(event) + } + + const lastUserMsg = calls[0]!.find(m => m.role === 'user') + const textBlock = lastUserMsg!.content.find(b => b.type === 'text') + expect((textBlock as any).text).toBe('stream modified') + + // Should have at least a text event and a done event. + expect(events.some(e => e.type === 'done')).toBe(true) + }) + + it('afterRun fires in stream mode and modifies done event', async () => { + const config: AgentConfig = { + ...baseConfig, + afterRun: (result) => ({ ...result, output: 'stream modified output' }), + } + const { agent } = buildMockAgent(config, 'original') + + const events = [] + for await (const event of agent.stream('hi')) { + events.push(event) + } + + const doneEvent = events.find(e => e.type === 'done') + expect(doneEvent).toBeDefined() + expect((doneEvent!.data as AgentRunResult).output).toBe('stream modified output') + }) + + it('beforeRun throwing in stream mode yields error event', async () => { + const config: AgentConfig = { + ...baseConfig, + beforeRun: () => { throw new Error('stream abort') }, + } + const { agent } = buildMockAgent(config, 'unreachable') + + const events = [] + for await (const event of agent.stream('hi')) { + events.push(event) + } + + const errorEvent = events.find(e => e.type === 'error') + expect(errorEvent).toBeDefined() + expect((errorEvent!.data as Error).message).toContain('stream abort') + }) + + it('afterRun throwing in stream mode yields error event', async () => { + const config: AgentConfig = { + ...baseConfig, + afterRun: () => { throw new Error('stream content violation') }, + } + const { agent } = buildMockAgent(config, 'streamed output') + + const events = [] + for await (const event of agent.stream('hi')) { + events.push(event) + } + + // Text events may have been yielded before the error. + const errorEvent = events.find(e => e.type === 'error') + expect(errorEvent).toBeDefined() + expect((errorEvent!.data as Error).message).toContain('stream content violation') + // No done event should be present since afterRun rejected it. + expect(events.find(e => e.type === 'done')).toBeUndefined() + }) + + // ----------------------------------------------------------------------- + // prompt() history integrity + // ----------------------------------------------------------------------- + + it('beforeRun modifying prompt preserves non-text content blocks', async () => { + // Simulate a multi-turn message where the last user message has mixed content + // (text + tool_result). beforeRun should only replace text, not strip other blocks. + const config: AgentConfig = { + ...baseConfig, + beforeRun: (ctx) => ({ ...ctx, prompt: 'modified' }), + } + const { adapter, calls } = mockAdapter('ok') + const registry = new ToolRegistry() + const executor = new ToolExecutor(registry) + const agent = new Agent(config, registry, executor) + + const runner = new AgentRunner(adapter, registry, executor, { + model: config.model, + agentName: config.name, + }) + ;(agent as any).runner = runner + + // Directly call run which creates a single text-only user message. + // To test mixed content, we need to go through the private executeRun. + // Instead, we test via prompt() after injecting history with mixed content. + ;(agent as any).messageHistory = [ + { + role: 'user' as const, + content: [ + { type: 'text' as const, text: 'original' }, + { type: 'image' as const, source: { type: 'base64' as const, media_type: 'image/png', data: 'abc' } }, + ], + }, + ] + + // prompt() appends a new user message then calls executeRun with full history + await agent.prompt('follow up') + + // The last user message sent to the LLM should have modified text + const sentMessages = calls[0]! + const lastUser = [...sentMessages].reverse().find(m => m.role === 'user')! + const textBlock = lastUser.content.find(b => b.type === 'text') + expect((textBlock as any).text).toBe('modified') + + // The earlier user message (with the image) should be untouched + const firstUser = sentMessages.find(m => m.role === 'user')! + const imageBlock = firstUser.content.find(b => b.type === 'image') + expect(imageBlock).toBeDefined() + }) + + it('beforeRun modifying prompt does not corrupt messageHistory', async () => { + const config: AgentConfig = { + ...baseConfig, + beforeRun: (ctx) => ({ ...ctx, prompt: 'hook-modified' }), + } + const { agent, calls } = buildMockAgent(config, 'reply') + + await agent.prompt('original message') + + // The LLM should have received the modified prompt. + const lastUserMsg = calls[0]!.find(m => m.role === 'user') + expect((lastUserMsg!.content[0] as any).text).toBe('hook-modified') + + // But the persistent history should retain the original message. + const history = agent.getHistory() + const firstUserInHistory = history.find(m => m.role === 'user') + expect((firstUserInHistory!.content[0] as any).text).toBe('original message') + }) + + // ----------------------------------------------------------------------- + // afterRun NOT called on error + // ----------------------------------------------------------------------- + + it('afterRun is not called when executeRun throws', async () => { + const afterSpy = vi.fn((result) => result) + + const config: AgentConfig = { + ...baseConfig, + // Use beforeRun to trigger an error inside executeRun's try block, + // before afterRun would normally run. + beforeRun: () => { throw new Error('rejected by policy') }, + afterRun: afterSpy, + } + const { agent } = buildMockAgent(config, 'should not reach') + const result = await agent.run('hi') + + expect(result.success).toBe(false) + expect(result.output).toContain('rejected by policy') + expect(afterSpy).not.toHaveBeenCalled() + }) + + // ----------------------------------------------------------------------- + // outputSchema + afterRun + // ----------------------------------------------------------------------- + + it('afterRun fires after structured output validation', async () => { + const schema = z.object({ answer: z.string() }) + + const config: AgentConfig = { + ...baseConfig, + outputSchema: schema, + afterRun: (result) => ({ ...result, output: '[post-processed] ' + result.output }), + } + // Return valid JSON matching the schema + const { agent } = buildMockAgent(config, '{"answer":"42"}') + const result = await agent.run('what is the answer?') + + expect(result.success).toBe(true) + expect(result.output).toBe('[post-processed] {"answer":"42"}') + expect(result.structured).toEqual({ answer: '42' }) + }) + + // ----------------------------------------------------------------------- + // ctx.agent does not contain hook self-references + // ----------------------------------------------------------------------- + + it('beforeRun context.agent has correct config without hook self-references', async () => { + let receivedAgent: AgentConfig | undefined + + const config: AgentConfig = { + ...baseConfig, + beforeRun: (ctx) => { + receivedAgent = ctx.agent + return ctx + }, + } + const { agent } = buildMockAgent(config, 'ok') + await agent.run('test') + + expect(receivedAgent).toBeDefined() + expect(receivedAgent!.name).toBe('test-agent') + expect(receivedAgent!.model).toBe('mock-model') + // Hook functions should be stripped to avoid circular references + expect(receivedAgent!.beforeRun).toBeUndefined() + expect(receivedAgent!.afterRun).toBeUndefined() + }) + + // ----------------------------------------------------------------------- + // Multiple prompt() turns fire hooks each time + // ----------------------------------------------------------------------- + + it('hooks fire on every prompt() call', async () => { + const beforeSpy = vi.fn((ctx) => ctx) + const afterSpy = vi.fn((result) => result) + + const config: AgentConfig = { + ...baseConfig, + beforeRun: beforeSpy, + afterRun: afterSpy, + } + const { agent } = buildMockAgent(config, 'reply') + + await agent.prompt('turn 1') + await agent.prompt('turn 2') + + expect(beforeSpy).toHaveBeenCalledTimes(2) + expect(afterSpy).toHaveBeenCalledTimes(2) + expect(beforeSpy.mock.calls[0]![0].prompt).toBe('turn 1') + expect(beforeSpy.mock.calls[1]![0].prompt).toBe('turn 2') + }) +})