From 9abd570750b83e5f48b1a2bd4b71757490adbb98 Mon Sep 17 00:00:00 2001 From: JackChen Date: Sun, 5 Apr 2026 00:21:03 +0800 Subject: [PATCH] feat(agent): add beforeRun / afterRun lifecycle hooks (#31) Add optional hook callbacks to AgentConfig for cross-cutting concerns (guardrails, logging, token budgets) without modifying framework internals. - beforeRun: receives prompt + agent config, can modify or throw to abort - afterRun: receives AgentRunResult, can modify or throw to fail - Works with all three execution modes: run(), prompt(), stream() - 15 test cases covering modify, throw, async, composition, and history integrity --- src/agent/agent.ts | 72 +++++++- src/index.ts | 1 + src/types.ts | 18 ++ tests/agent-hooks.test.ts | 334 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 423 insertions(+), 2 deletions(-) create mode 100644 tests/agent-hooks.test.ts diff --git a/src/agent/agent.ts b/src/agent/agent.ts index 58a1df3..0ae28f0 100644 --- a/src/agent/agent.ts +++ b/src/agent/agent.ts @@ -27,6 +27,7 @@ import type { AgentConfig, AgentState, AgentRunResult, + BeforeRunHookContext, LLMMessage, StreamEvent, TokenUsage, @@ -278,6 +279,14 @@ export class Agent { const agentStartMs = Date.now() try { + // --- beforeRun hook --- + if (this.config.beforeRun) { + const hookCtx = await this.config.beforeRun( + this.buildBeforeRunHookContext(messages), + ) + this.applyHookContext(messages, hookCtx) + } + const runner = await this.getRunner() const internalOnMessage = (msg: LLMMessage) => { this.state.messages.push(msg) @@ -296,18 +305,28 @@ export class Agent { // --- Structured output validation --- if (this.config.outputSchema) { - const validated = await this.validateStructuredOutput( + let validated = await this.validateStructuredOutput( messages, result, runner, runOptions, ) + // --- afterRun hook --- + if (this.config.afterRun) { + validated = await this.config.afterRun(validated) + } this.emitAgentTrace(callerOptions, agentStartMs, validated) return validated } this.transitionTo('completed') - const agentResult = this.toAgentRunResult(result, true) + let agentResult = this.toAgentRunResult(result, true) + + // --- afterRun hook --- + if (this.config.afterRun) { + agentResult = await this.config.afterRun(agentResult) + } + this.emitAgentTrace(callerOptions, agentStartMs, agentResult) return agentResult } catch (err) { @@ -440,6 +459,14 @@ export class Agent { this.transitionTo('running') try { + // --- beforeRun hook --- + if (this.config.beforeRun) { + const hookCtx = await this.config.beforeRun( + this.buildBeforeRunHookContext(messages), + ) + this.applyHookContext(messages, hookCtx) + } + const runner = await this.getRunner() for await (const event of runner.stream(messages)) { @@ -447,6 +474,14 @@ export class Agent { const result = event.data as import('./runner.js').RunResult this.state.tokenUsage = addUsage(this.state.tokenUsage, result.tokenUsage) this.transitionTo('completed') + + // --- afterRun hook --- + if (this.config.afterRun) { + const agentResult = this.toAgentRunResult(result, true) + const modified = await this.config.afterRun(agentResult) + yield { type: 'done', data: modified } satisfies StreamEvent + continue + } } else if (event.type === 'error') { const error = event.data instanceof Error ? event.data @@ -463,6 +498,39 @@ export class Agent { } } + // ------------------------------------------------------------------------- + // Hook helpers + // ------------------------------------------------------------------------- + + /** Extract the prompt text from the last user message to build hook context. */ + private buildBeforeRunHookContext(messages: LLMMessage[]): BeforeRunHookContext { + const lastUser = [...messages].reverse().find(m => m.role === 'user') + const prompt = lastUser + ? lastUser.content + .filter((b): b is import('../types.js').TextBlock => b.type === 'text') + .map(b => b.text) + .join('') + : '' + return { prompt, agent: this.config } + } + + /** Apply a (possibly modified) hook context back to the messages array. */ + private applyHookContext(messages: LLMMessage[], ctx: BeforeRunHookContext): void { + const original = this.buildBeforeRunHookContext(messages) + if (ctx.prompt === original.prompt) return + + // Find the last user message and replace its text content. + for (let i = messages.length - 1; i >= 0; i--) { + if (messages[i]!.role === 'user') { + messages[i] = { + role: 'user', + content: [{ type: 'text', text: ctx.prompt }], + } + break + } + } + } + // ------------------------------------------------------------------------- // State transition helpers // ------------------------------------------------------------------------- diff --git a/src/index.ts b/src/index.ts index 312f852..cc0fff3 100644 --- a/src/index.ts +++ b/src/index.ts @@ -147,6 +147,7 @@ export type { AgentConfig, AgentState, AgentRunResult, + BeforeRunHookContext, ToolCallRecord, // Team diff --git a/src/types.ts b/src/types.ts index 0989df4..738e6d6 100644 --- a/src/types.ts +++ b/src/types.ts @@ -182,6 +182,14 @@ export interface ToolDefinition> { // Agent // --------------------------------------------------------------------------- +/** Context passed to the {@link AgentConfig.beforeRun} hook. */ +export interface BeforeRunHookContext { + /** The user prompt text. */ + readonly prompt: string + /** The agent's static configuration. */ + readonly agent: AgentConfig +} + /** Static configuration for a single agent. */ export interface AgentConfig { readonly name: string @@ -207,6 +215,16 @@ export interface AgentConfig { * retry with error feedback is attempted on validation failure. */ readonly outputSchema?: ZodSchema + /** + * Called before each agent run. Receives the prompt and agent config. + * Return a (possibly modified) context to continue, or throw to abort the run. + */ + readonly beforeRun?: (context: BeforeRunHookContext) => Promise | BeforeRunHookContext + /** + * Called after each agent run completes. Receives the run result. + * Return a (possibly modified) result, or throw to mark the run as failed. + */ + readonly afterRun?: (result: AgentRunResult) => Promise | AgentRunResult } /** Lifecycle state tracked during an agent run. */ diff --git a/tests/agent-hooks.test.ts b/tests/agent-hooks.test.ts new file mode 100644 index 0000000..10b88ba --- /dev/null +++ b/tests/agent-hooks.test.ts @@ -0,0 +1,334 @@ +import { describe, it, expect, vi } from 'vitest' +import { Agent } from '../src/agent/agent.js' +import { AgentRunner } from '../src/agent/runner.js' +import { ToolRegistry } from '../src/tool/framework.js' +import { ToolExecutor } from '../src/tool/executor.js' +import type { AgentConfig, AgentRunResult, LLMAdapter, LLMMessage, LLMResponse } from '../src/types.js' + +// --------------------------------------------------------------------------- +// Mock helpers +// --------------------------------------------------------------------------- + +/** + * Create a mock adapter that records every `chat()` call's messages + * and returns a fixed text response. + */ +function mockAdapter(responseText: string) { + const calls: LLMMessage[][] = [] + const adapter: LLMAdapter = { + name: 'mock', + async chat(messages) { + calls.push([...messages]) + return { + id: 'mock-1', + content: [{ type: 'text' as const, text: responseText }], + model: 'mock-model', + stop_reason: 'end_turn', + usage: { input_tokens: 10, output_tokens: 20 }, + } satisfies LLMResponse + }, + async *stream() { + /* unused */ + }, + } + return { adapter, calls } +} + +/** Build an Agent with a mocked LLM, bypassing createAdapter. */ +function buildMockAgent(config: AgentConfig, responseText: string) { + const { adapter, calls } = mockAdapter(responseText) + const registry = new ToolRegistry() + const executor = new ToolExecutor(registry) + const agent = new Agent(config, registry, executor) + + const runner = new AgentRunner(adapter, registry, executor, { + model: config.model, + systemPrompt: config.systemPrompt, + maxTurns: config.maxTurns, + maxTokens: config.maxTokens, + temperature: config.temperature, + agentName: config.name, + }) + ;(agent as any).runner = runner + + return { agent, calls } +} + +const baseConfig: AgentConfig = { + name: 'test-agent', + model: 'mock-model', + systemPrompt: 'You are a test agent.', +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +describe('Agent hooks — beforeRun / afterRun', () => { + // ----------------------------------------------------------------------- + // Baseline — no hooks + // ----------------------------------------------------------------------- + + it('works normally without hooks', async () => { + const { agent } = buildMockAgent(baseConfig, 'hello') + const result = await agent.run('ping') + + expect(result.success).toBe(true) + expect(result.output).toBe('hello') + }) + + // ----------------------------------------------------------------------- + // beforeRun + // ----------------------------------------------------------------------- + + it('beforeRun can modify the prompt', async () => { + const config: AgentConfig = { + ...baseConfig, + beforeRun: (ctx) => ({ ...ctx, prompt: 'modified prompt' }), + } + const { agent, calls } = buildMockAgent(config, 'response') + await agent.run('original prompt') + + // The adapter should have received the modified prompt. + const lastUserMsg = calls[0]!.find(m => m.role === 'user') + const textBlock = lastUserMsg!.content.find(b => b.type === 'text') + expect((textBlock as any).text).toBe('modified prompt') + }) + + it('beforeRun that returns context unchanged does not alter prompt', async () => { + const config: AgentConfig = { + ...baseConfig, + beforeRun: (ctx) => ctx, + } + const { agent, calls } = buildMockAgent(config, 'response') + await agent.run('keep this') + + const lastUserMsg = calls[0]!.find(m => m.role === 'user') + const textBlock = lastUserMsg!.content.find(b => b.type === 'text') + expect((textBlock as any).text).toBe('keep this') + }) + + it('beforeRun throwing aborts the run with failure', async () => { + const config: AgentConfig = { + ...baseConfig, + beforeRun: () => { throw new Error('budget exceeded') }, + } + const { agent, calls } = buildMockAgent(config, 'should not reach') + const result = await agent.run('hi') + + expect(result.success).toBe(false) + expect(result.output).toContain('budget exceeded') + // No LLM call should have been made. + expect(calls).toHaveLength(0) + }) + + it('async beforeRun works', async () => { + const config: AgentConfig = { + ...baseConfig, + beforeRun: async (ctx) => { + await Promise.resolve() + return { ...ctx, prompt: 'async modified' } + }, + } + const { agent, calls } = buildMockAgent(config, 'ok') + await agent.run('original') + + const lastUserMsg = calls[0]!.find(m => m.role === 'user') + const textBlock = lastUserMsg!.content.find(b => b.type === 'text') + expect((textBlock as any).text).toBe('async modified') + }) + + // ----------------------------------------------------------------------- + // afterRun + // ----------------------------------------------------------------------- + + it('afterRun can modify the result', async () => { + const config: AgentConfig = { + ...baseConfig, + afterRun: (result) => ({ ...result, output: 'modified output' }), + } + const { agent } = buildMockAgent(config, 'original output') + const result = await agent.run('hi') + + expect(result.success).toBe(true) + expect(result.output).toBe('modified output') + }) + + it('afterRun throwing marks run as failed', async () => { + const config: AgentConfig = { + ...baseConfig, + afterRun: () => { throw new Error('content violation') }, + } + const { agent } = buildMockAgent(config, 'bad content') + const result = await agent.run('hi') + + expect(result.success).toBe(false) + expect(result.output).toContain('content violation') + }) + + it('async afterRun works', async () => { + const config: AgentConfig = { + ...baseConfig, + afterRun: async (result) => { + await Promise.resolve() + return { ...result, output: result.output.toUpperCase() } + }, + } + const { agent } = buildMockAgent(config, 'hello') + const result = await agent.run('hi') + + expect(result.output).toBe('HELLO') + }) + + // ----------------------------------------------------------------------- + // Both hooks together + // ----------------------------------------------------------------------- + + it('beforeRun and afterRun compose correctly', async () => { + const hookOrder: string[] = [] + + const config: AgentConfig = { + ...baseConfig, + beforeRun: (ctx) => { + hookOrder.push('before') + return { ...ctx, prompt: 'injected prompt' } + }, + afterRun: (result) => { + hookOrder.push('after') + return { ...result, output: `[processed] ${result.output}` } + }, + } + const { agent, calls } = buildMockAgent(config, 'raw output') + const result = await agent.run('original') + + expect(hookOrder).toEqual(['before', 'after']) + + const lastUserMsg = calls[0]!.find(m => m.role === 'user') + const textBlock = lastUserMsg!.content.find(b => b.type === 'text') + expect((textBlock as any).text).toBe('injected prompt') + + expect(result.output).toBe('[processed] raw output') + }) + + // ----------------------------------------------------------------------- + // prompt() multi-turn mode + // ----------------------------------------------------------------------- + + it('hooks fire on prompt() calls', async () => { + const beforeSpy = vi.fn((ctx) => ctx) + const afterSpy = vi.fn((result) => result) + + const config: AgentConfig = { + ...baseConfig, + beforeRun: beforeSpy, + afterRun: afterSpy, + } + const { agent } = buildMockAgent(config, 'reply') + await agent.prompt('hello') + + expect(beforeSpy).toHaveBeenCalledOnce() + expect(afterSpy).toHaveBeenCalledOnce() + expect(beforeSpy.mock.calls[0]![0].prompt).toBe('hello') + }) + + // ----------------------------------------------------------------------- + // stream() mode + // ----------------------------------------------------------------------- + + it('beforeRun fires in stream mode', async () => { + const config: AgentConfig = { + ...baseConfig, + beforeRun: (ctx) => ({ ...ctx, prompt: 'stream modified' }), + } + const { agent, calls } = buildMockAgent(config, 'streamed') + + const events = [] + for await (const event of agent.stream('original')) { + events.push(event) + } + + const lastUserMsg = calls[0]!.find(m => m.role === 'user') + const textBlock = lastUserMsg!.content.find(b => b.type === 'text') + expect((textBlock as any).text).toBe('stream modified') + + // Should have at least a text event and a done event. + expect(events.some(e => e.type === 'done')).toBe(true) + }) + + it('afterRun fires in stream mode and modifies done event', async () => { + const config: AgentConfig = { + ...baseConfig, + afterRun: (result) => ({ ...result, output: 'stream modified output' }), + } + const { agent } = buildMockAgent(config, 'original') + + const events = [] + for await (const event of agent.stream('hi')) { + events.push(event) + } + + const doneEvent = events.find(e => e.type === 'done') + expect(doneEvent).toBeDefined() + expect((doneEvent!.data as AgentRunResult).output).toBe('stream modified output') + }) + + it('beforeRun throwing in stream mode yields error event', async () => { + const config: AgentConfig = { + ...baseConfig, + beforeRun: () => { throw new Error('stream abort') }, + } + const { agent } = buildMockAgent(config, 'unreachable') + + const events = [] + for await (const event of agent.stream('hi')) { + events.push(event) + } + + const errorEvent = events.find(e => e.type === 'error') + expect(errorEvent).toBeDefined() + expect((errorEvent!.data as Error).message).toContain('stream abort') + }) + + it('afterRun throwing in stream mode yields error event', async () => { + const config: AgentConfig = { + ...baseConfig, + afterRun: () => { throw new Error('stream content violation') }, + } + const { agent } = buildMockAgent(config, 'streamed output') + + const events = [] + for await (const event of agent.stream('hi')) { + events.push(event) + } + + // Text events may have been yielded before the error. + const errorEvent = events.find(e => e.type === 'error') + expect(errorEvent).toBeDefined() + expect((errorEvent!.data as Error).message).toContain('stream content violation') + // No done event should be present since afterRun rejected it. + expect(events.find(e => e.type === 'done')).toBeUndefined() + }) + + // ----------------------------------------------------------------------- + // prompt() history integrity + // ----------------------------------------------------------------------- + + it('beforeRun modifying prompt does not corrupt messageHistory', async () => { + const config: AgentConfig = { + ...baseConfig, + beforeRun: (ctx) => ({ ...ctx, prompt: 'hook-modified' }), + } + const { agent, calls } = buildMockAgent(config, 'reply') + + await agent.prompt('original message') + + // The LLM should have received the modified prompt. + const lastUserMsg = calls[0]!.find(m => m.role === 'user') + expect((lastUserMsg!.content[0] as any).text).toBe('hook-modified') + + // But the persistent history should retain the original message. + const history = agent.getHistory() + const firstUserInHistory = history.find(m => m.role === 'user') + expect((firstUserInHistory!.content[0] as any).text).toBe('original message') + }) +})