From 9938b3d62d3cdc6c4f2f0b043a93bcfe3a31253d Mon Sep 17 00:00:00 2001 From: JackChen Date: Sun, 5 Apr 2026 00:40:17 +0800 Subject: [PATCH] fix(agent): address review feedback on beforeRun/afterRun hooks - Normalize stream done event to always yield AgentRunResult - Move transitionTo('completed') after afterRun to fix state ordering - Strip hook functions from BeforeRunHookContext.agent to avoid self-references - Pass originalPrompt to applyHookContext to avoid redundant message scan - Clarify afterRun JSDoc: not called when the run throws - Add tests: error-path skip, outputSchema+afterRun, ctx.agent shape, multi-turn hooks --- src/agent/agent.ts | 36 +++++++-------- src/types.ts | 3 +- tests/agent-hooks.test.ts | 93 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 112 insertions(+), 20 deletions(-) diff --git a/src/agent/agent.ts b/src/agent/agent.ts index e22c613..caf5a9c 100644 --- a/src/agent/agent.ts +++ b/src/agent/agent.ts @@ -281,10 +281,9 @@ export class Agent { try { // --- beforeRun hook --- if (this.config.beforeRun) { - const hookCtx = await this.config.beforeRun( - this.buildBeforeRunHookContext(messages), - ) - this.applyHookContext(messages, hookCtx) + const hookCtx = this.buildBeforeRunHookContext(messages) + const modified = await this.config.beforeRun(hookCtx) + this.applyHookContext(messages, modified, hookCtx.prompt) } const runner = await this.getRunner() @@ -319,7 +318,6 @@ export class Agent { return validated } - this.transitionTo('completed') let agentResult = this.toAgentRunResult(result, true) // --- afterRun hook --- @@ -327,6 +325,7 @@ export class Agent { agentResult = await this.config.afterRun(agentResult) } + this.transitionTo('completed') this.emitAgentTrace(callerOptions, agentStartMs, agentResult) return agentResult } catch (err) { @@ -461,10 +460,9 @@ export class Agent { try { // --- beforeRun hook --- if (this.config.beforeRun) { - const hookCtx = await this.config.beforeRun( - this.buildBeforeRunHookContext(messages), - ) - this.applyHookContext(messages, hookCtx) + const hookCtx = this.buildBeforeRunHookContext(messages) + const modified = await this.config.beforeRun(hookCtx) + this.applyHookContext(messages, modified, hookCtx.prompt) } const runner = await this.getRunner() @@ -473,15 +471,14 @@ export class Agent { if (event.type === 'done') { const result = event.data as import('./runner.js').RunResult this.state.tokenUsage = addUsage(this.state.tokenUsage, result.tokenUsage) - this.transitionTo('completed') - // --- afterRun hook --- + let agentResult = this.toAgentRunResult(result, true) if (this.config.afterRun) { - const agentResult = this.toAgentRunResult(result, true) - const modified = await this.config.afterRun(agentResult) - yield { type: 'done', data: modified } satisfies StreamEvent - continue + agentResult = await this.config.afterRun(agentResult) } + this.transitionTo('completed') + yield { type: 'done', data: agentResult } satisfies StreamEvent + continue } else if (event.type === 'error') { const error = event.data instanceof Error ? event.data @@ -514,7 +511,9 @@ export class Agent { break } } - return { prompt, agent: this.config } + // Strip hook functions to avoid circular self-references in the context + const { beforeRun, afterRun, ...agentInfo } = this.config + return { prompt, agent: agentInfo as AgentConfig } } /** @@ -525,9 +524,8 @@ export class Agent { * mutated in place) so that shallow copies of the original array (e.g. from * `prompt()`) are not affected. */ - private applyHookContext(messages: LLMMessage[], ctx: BeforeRunHookContext): void { - const original = this.buildBeforeRunHookContext(messages) - if (ctx.prompt === original.prompt) return + private applyHookContext(messages: LLMMessage[], ctx: BeforeRunHookContext, originalPrompt: string): void { + if (ctx.prompt === originalPrompt) return for (let i = messages.length - 1; i >= 0; i--) { if (messages[i]!.role === 'user') { diff --git a/src/types.ts b/src/types.ts index 9dcbd6f..6695d81 100644 --- a/src/types.ts +++ b/src/types.ts @@ -222,8 +222,9 @@ export interface AgentConfig { */ readonly beforeRun?: (context: BeforeRunHookContext) => Promise | BeforeRunHookContext /** - * Called after each agent run completes. Receives the run result. + * Called after each agent run completes successfully. Receives the run result. * Return a (possibly modified) result, or throw to mark the run as failed. + * Not called when the run throws. For error observation, handle errors at the call site. */ readonly afterRun?: (result: AgentRunResult) => Promise | AgentRunResult } diff --git a/tests/agent-hooks.test.ts b/tests/agent-hooks.test.ts index a3cbf30..13044a3 100644 --- a/tests/agent-hooks.test.ts +++ b/tests/agent-hooks.test.ts @@ -1,4 +1,5 @@ import { describe, it, expect, vi } from 'vitest' +import { z } from 'zod' import { Agent } from '../src/agent/agent.js' import { AgentRunner } from '../src/agent/runner.js' import { ToolRegistry } from '../src/tool/framework.js' @@ -377,4 +378,96 @@ describe('Agent hooks — beforeRun / afterRun', () => { const firstUserInHistory = history.find(m => m.role === 'user') expect((firstUserInHistory!.content[0] as any).text).toBe('original message') }) + + // ----------------------------------------------------------------------- + // afterRun NOT called on error + // ----------------------------------------------------------------------- + + it('afterRun is not called when executeRun throws', async () => { + const afterSpy = vi.fn((result) => result) + + const config: AgentConfig = { + ...baseConfig, + // Use beforeRun to trigger an error inside executeRun's try block, + // before afterRun would normally run. + beforeRun: () => { throw new Error('rejected by policy') }, + afterRun: afterSpy, + } + const { agent } = buildMockAgent(config, 'should not reach') + const result = await agent.run('hi') + + expect(result.success).toBe(false) + expect(result.output).toContain('rejected by policy') + expect(afterSpy).not.toHaveBeenCalled() + }) + + // ----------------------------------------------------------------------- + // outputSchema + afterRun + // ----------------------------------------------------------------------- + + it('afterRun fires after structured output validation', async () => { + const schema = z.object({ answer: z.string() }) + + const config: AgentConfig = { + ...baseConfig, + outputSchema: schema, + afterRun: (result) => ({ ...result, output: '[post-processed] ' + result.output }), + } + // Return valid JSON matching the schema + const { agent } = buildMockAgent(config, '{"answer":"42"}') + const result = await agent.run('what is the answer?') + + expect(result.success).toBe(true) + expect(result.output).toBe('[post-processed] {"answer":"42"}') + expect(result.structured).toEqual({ answer: '42' }) + }) + + // ----------------------------------------------------------------------- + // ctx.agent does not contain hook self-references + // ----------------------------------------------------------------------- + + it('beforeRun context.agent has correct config without hook self-references', async () => { + let receivedAgent: AgentConfig | undefined + + const config: AgentConfig = { + ...baseConfig, + beforeRun: (ctx) => { + receivedAgent = ctx.agent + return ctx + }, + } + const { agent } = buildMockAgent(config, 'ok') + await agent.run('test') + + expect(receivedAgent).toBeDefined() + expect(receivedAgent!.name).toBe('test-agent') + expect(receivedAgent!.model).toBe('mock-model') + // Hook functions should be stripped to avoid circular references + expect(receivedAgent!.beforeRun).toBeUndefined() + expect(receivedAgent!.afterRun).toBeUndefined() + }) + + // ----------------------------------------------------------------------- + // Multiple prompt() turns fire hooks each time + // ----------------------------------------------------------------------- + + it('hooks fire on every prompt() call', async () => { + const beforeSpy = vi.fn((ctx) => ctx) + const afterSpy = vi.fn((result) => result) + + const config: AgentConfig = { + ...baseConfig, + beforeRun: beforeSpy, + afterRun: afterSpy, + } + const { agent } = buildMockAgent(config, 'reply') + + await agent.prompt('turn 1') + await agent.prompt('turn 2') + + expect(beforeSpy).toHaveBeenCalledTimes(2) + expect(afterSpy).toHaveBeenCalledTimes(2) + expect(beforeSpy.mock.calls[0]![0].prompt).toBe('turn 1') + expect(beforeSpy.mock.calls[1]![0].prompt).toBe('turn 2') + }) })