fix(agent): address review feedback on beforeRun/afterRun hooks
- Normalize stream done event to always yield AgentRunResult
- Move transitionTo('completed') after afterRun to fix state ordering
- Strip hook functions from BeforeRunHookContext.agent to avoid self-references
- Pass originalPrompt to applyHookContext to avoid redundant message scan
- Clarify afterRun JSDoc: not called when the run throws
- Add tests: error-path skip, outputSchema+afterRun, ctx.agent shape, multi-turn hooks
This commit is contained in:
parent
76e2d7c7fb
commit
9938b3d62d
|
|
@ -281,10 +281,9 @@ export class Agent {
|
|||
try {
|
||||
// --- beforeRun hook ---
|
||||
if (this.config.beforeRun) {
|
||||
const hookCtx = await this.config.beforeRun(
|
||||
this.buildBeforeRunHookContext(messages),
|
||||
)
|
||||
this.applyHookContext(messages, hookCtx)
|
||||
const hookCtx = this.buildBeforeRunHookContext(messages)
|
||||
const modified = await this.config.beforeRun(hookCtx)
|
||||
this.applyHookContext(messages, modified, hookCtx.prompt)
|
||||
}
|
||||
|
||||
const runner = await this.getRunner()
|
||||
|
|
@ -319,7 +318,6 @@ export class Agent {
|
|||
return validated
|
||||
}
|
||||
|
||||
this.transitionTo('completed')
|
||||
let agentResult = this.toAgentRunResult(result, true)
|
||||
|
||||
// --- afterRun hook ---
|
||||
|
|
@ -327,6 +325,7 @@ export class Agent {
|
|||
agentResult = await this.config.afterRun(agentResult)
|
||||
}
|
||||
|
||||
this.transitionTo('completed')
|
||||
this.emitAgentTrace(callerOptions, agentStartMs, agentResult)
|
||||
return agentResult
|
||||
} catch (err) {
|
||||
|
|
@ -461,10 +460,9 @@ export class Agent {
|
|||
try {
|
||||
// --- beforeRun hook ---
|
||||
if (this.config.beforeRun) {
|
||||
const hookCtx = await this.config.beforeRun(
|
||||
this.buildBeforeRunHookContext(messages),
|
||||
)
|
||||
this.applyHookContext(messages, hookCtx)
|
||||
const hookCtx = this.buildBeforeRunHookContext(messages)
|
||||
const modified = await this.config.beforeRun(hookCtx)
|
||||
this.applyHookContext(messages, modified, hookCtx.prompt)
|
||||
}
|
||||
|
||||
const runner = await this.getRunner()
|
||||
|
|
@ -473,15 +471,14 @@ export class Agent {
|
|||
if (event.type === 'done') {
|
||||
const result = event.data as import('./runner.js').RunResult
|
||||
this.state.tokenUsage = addUsage(this.state.tokenUsage, result.tokenUsage)
|
||||
this.transitionTo('completed')
|
||||
|
||||
// --- afterRun hook ---
|
||||
let agentResult = this.toAgentRunResult(result, true)
|
||||
if (this.config.afterRun) {
|
||||
const agentResult = this.toAgentRunResult(result, true)
|
||||
const modified = await this.config.afterRun(agentResult)
|
||||
yield { type: 'done', data: modified } satisfies StreamEvent
|
||||
continue
|
||||
agentResult = await this.config.afterRun(agentResult)
|
||||
}
|
||||
this.transitionTo('completed')
|
||||
yield { type: 'done', data: agentResult } satisfies StreamEvent
|
||||
continue
|
||||
} else if (event.type === 'error') {
|
||||
const error = event.data instanceof Error
|
||||
? event.data
|
||||
|
|
@ -514,7 +511,9 @@ export class Agent {
|
|||
break
|
||||
}
|
||||
}
|
||||
return { prompt, agent: this.config }
|
||||
// Strip hook functions to avoid circular self-references in the context
|
||||
const { beforeRun, afterRun, ...agentInfo } = this.config
|
||||
return { prompt, agent: agentInfo as AgentConfig }
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -525,9 +524,8 @@ export class Agent {
|
|||
* mutated in place) so that shallow copies of the original array (e.g. from
|
||||
* `prompt()`) are not affected.
|
||||
*/
|
||||
private applyHookContext(messages: LLMMessage[], ctx: BeforeRunHookContext): void {
|
||||
const original = this.buildBeforeRunHookContext(messages)
|
||||
if (ctx.prompt === original.prompt) return
|
||||
private applyHookContext(messages: LLMMessage[], ctx: BeforeRunHookContext, originalPrompt: string): void {
|
||||
if (ctx.prompt === originalPrompt) return
|
||||
|
||||
for (let i = messages.length - 1; i >= 0; i--) {
|
||||
if (messages[i]!.role === 'user') {
|
||||
|
|
|
|||
|
|
@ -222,8 +222,9 @@ export interface AgentConfig {
|
|||
*/
|
||||
readonly beforeRun?: (context: BeforeRunHookContext) => Promise<BeforeRunHookContext> | BeforeRunHookContext
|
||||
/**
|
||||
* Called after each agent run completes. Receives the run result.
|
||||
* Called after each agent run completes successfully. Receives the run result.
|
||||
* Return a (possibly modified) result, or throw to mark the run as failed.
|
||||
* Not called when the run throws. For error observation, handle errors at the call site.
|
||||
*/
|
||||
readonly afterRun?: (result: AgentRunResult) => Promise<AgentRunResult> | AgentRunResult
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import { describe, it, expect, vi } from 'vitest'
|
||||
import { z } from 'zod'
|
||||
import { Agent } from '../src/agent/agent.js'
|
||||
import { AgentRunner } from '../src/agent/runner.js'
|
||||
import { ToolRegistry } from '../src/tool/framework.js'
|
||||
|
|
@ -377,4 +378,96 @@ describe('Agent hooks — beforeRun / afterRun', () => {
|
|||
const firstUserInHistory = history.find(m => m.role === 'user')
|
||||
expect((firstUserInHistory!.content[0] as any).text).toBe('original message')
|
||||
})
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// afterRun NOT called on error
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
it('afterRun is not called when executeRun throws', async () => {
|
||||
const afterSpy = vi.fn((result) => result)
|
||||
|
||||
const config: AgentConfig = {
|
||||
...baseConfig,
|
||||
// Use beforeRun to trigger an error inside executeRun's try block,
|
||||
// before afterRun would normally run.
|
||||
beforeRun: () => { throw new Error('rejected by policy') },
|
||||
afterRun: afterSpy,
|
||||
}
|
||||
const { agent } = buildMockAgent(config, 'should not reach')
|
||||
const result = await agent.run('hi')
|
||||
|
||||
expect(result.success).toBe(false)
|
||||
expect(result.output).toContain('rejected by policy')
|
||||
expect(afterSpy).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// outputSchema + afterRun
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
it('afterRun fires after structured output validation', async () => {
|
||||
const schema = z.object({ answer: z.string() })
|
||||
|
||||
const config: AgentConfig = {
|
||||
...baseConfig,
|
||||
outputSchema: schema,
|
||||
afterRun: (result) => ({ ...result, output: '[post-processed] ' + result.output }),
|
||||
}
|
||||
// Return valid JSON matching the schema
|
||||
const { agent } = buildMockAgent(config, '{"answer":"42"}')
|
||||
const result = await agent.run('what is the answer?')
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
expect(result.output).toBe('[post-processed] {"answer":"42"}')
|
||||
expect(result.structured).toEqual({ answer: '42' })
|
||||
})
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// ctx.agent does not contain hook self-references
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
it('beforeRun context.agent has correct config without hook self-references', async () => {
|
||||
let receivedAgent: AgentConfig | undefined
|
||||
|
||||
const config: AgentConfig = {
|
||||
...baseConfig,
|
||||
beforeRun: (ctx) => {
|
||||
receivedAgent = ctx.agent
|
||||
return ctx
|
||||
},
|
||||
}
|
||||
const { agent } = buildMockAgent(config, 'ok')
|
||||
await agent.run('test')
|
||||
|
||||
expect(receivedAgent).toBeDefined()
|
||||
expect(receivedAgent!.name).toBe('test-agent')
|
||||
expect(receivedAgent!.model).toBe('mock-model')
|
||||
// Hook functions should be stripped to avoid circular references
|
||||
expect(receivedAgent!.beforeRun).toBeUndefined()
|
||||
expect(receivedAgent!.afterRun).toBeUndefined()
|
||||
})
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Multiple prompt() turns fire hooks each time
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
it('hooks fire on every prompt() call', async () => {
|
||||
const beforeSpy = vi.fn((ctx) => ctx)
|
||||
const afterSpy = vi.fn((result) => result)
|
||||
|
||||
const config: AgentConfig = {
|
||||
...baseConfig,
|
||||
beforeRun: beforeSpy,
|
||||
afterRun: afterSpy,
|
||||
}
|
||||
const { agent } = buildMockAgent(config, 'reply')
|
||||
|
||||
await agent.prompt('turn 1')
|
||||
await agent.prompt('turn 2')
|
||||
|
||||
expect(beforeSpy).toHaveBeenCalledTimes(2)
|
||||
expect(afterSpy).toHaveBeenCalledTimes(2)
|
||||
expect(beforeSpy.mock.calls[0]![0].prompt).toBe('turn 1')
|
||||
expect(beforeSpy.mock.calls[1]![0].prompt).toBe('turn 2')
|
||||
})
|
||||
})
|
||||
|
|
|
|||
Loading…
Reference in New Issue