fix(agent): address review feedback on beforeRun/afterRun hooks
- Normalize stream done event to always yield AgentRunResult
- Move transitionTo('completed') after afterRun to fix state ordering
- Strip hook functions from BeforeRunHookContext.agent to avoid self-references
- Pass originalPrompt to applyHookContext to avoid redundant message scan
- Clarify afterRun JSDoc: not called when the run throws
- Add tests: error-path skip, outputSchema+afterRun, ctx.agent shape, multi-turn hooks
This commit is contained in:
parent
76e2d7c7fb
commit
9938b3d62d
|
|
@ -281,10 +281,9 @@ export class Agent {
|
||||||
try {
|
try {
|
||||||
// --- beforeRun hook ---
|
// --- beforeRun hook ---
|
||||||
if (this.config.beforeRun) {
|
if (this.config.beforeRun) {
|
||||||
const hookCtx = await this.config.beforeRun(
|
const hookCtx = this.buildBeforeRunHookContext(messages)
|
||||||
this.buildBeforeRunHookContext(messages),
|
const modified = await this.config.beforeRun(hookCtx)
|
||||||
)
|
this.applyHookContext(messages, modified, hookCtx.prompt)
|
||||||
this.applyHookContext(messages, hookCtx)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const runner = await this.getRunner()
|
const runner = await this.getRunner()
|
||||||
|
|
@ -319,7 +318,6 @@ export class Agent {
|
||||||
return validated
|
return validated
|
||||||
}
|
}
|
||||||
|
|
||||||
this.transitionTo('completed')
|
|
||||||
let agentResult = this.toAgentRunResult(result, true)
|
let agentResult = this.toAgentRunResult(result, true)
|
||||||
|
|
||||||
// --- afterRun hook ---
|
// --- afterRun hook ---
|
||||||
|
|
@ -327,6 +325,7 @@ export class Agent {
|
||||||
agentResult = await this.config.afterRun(agentResult)
|
agentResult = await this.config.afterRun(agentResult)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
this.transitionTo('completed')
|
||||||
this.emitAgentTrace(callerOptions, agentStartMs, agentResult)
|
this.emitAgentTrace(callerOptions, agentStartMs, agentResult)
|
||||||
return agentResult
|
return agentResult
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
|
|
@ -461,10 +460,9 @@ export class Agent {
|
||||||
try {
|
try {
|
||||||
// --- beforeRun hook ---
|
// --- beforeRun hook ---
|
||||||
if (this.config.beforeRun) {
|
if (this.config.beforeRun) {
|
||||||
const hookCtx = await this.config.beforeRun(
|
const hookCtx = this.buildBeforeRunHookContext(messages)
|
||||||
this.buildBeforeRunHookContext(messages),
|
const modified = await this.config.beforeRun(hookCtx)
|
||||||
)
|
this.applyHookContext(messages, modified, hookCtx.prompt)
|
||||||
this.applyHookContext(messages, hookCtx)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const runner = await this.getRunner()
|
const runner = await this.getRunner()
|
||||||
|
|
@ -473,15 +471,14 @@ export class Agent {
|
||||||
if (event.type === 'done') {
|
if (event.type === 'done') {
|
||||||
const result = event.data as import('./runner.js').RunResult
|
const result = event.data as import('./runner.js').RunResult
|
||||||
this.state.tokenUsage = addUsage(this.state.tokenUsage, result.tokenUsage)
|
this.state.tokenUsage = addUsage(this.state.tokenUsage, result.tokenUsage)
|
||||||
this.transitionTo('completed')
|
|
||||||
|
|
||||||
// --- afterRun hook ---
|
let agentResult = this.toAgentRunResult(result, true)
|
||||||
if (this.config.afterRun) {
|
if (this.config.afterRun) {
|
||||||
const agentResult = this.toAgentRunResult(result, true)
|
agentResult = await this.config.afterRun(agentResult)
|
||||||
const modified = await this.config.afterRun(agentResult)
|
|
||||||
yield { type: 'done', data: modified } satisfies StreamEvent
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
this.transitionTo('completed')
|
||||||
|
yield { type: 'done', data: agentResult } satisfies StreamEvent
|
||||||
|
continue
|
||||||
} else if (event.type === 'error') {
|
} else if (event.type === 'error') {
|
||||||
const error = event.data instanceof Error
|
const error = event.data instanceof Error
|
||||||
? event.data
|
? event.data
|
||||||
|
|
@ -514,7 +511,9 @@ export class Agent {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return { prompt, agent: this.config }
|
// Strip hook functions to avoid circular self-references in the context
|
||||||
|
const { beforeRun, afterRun, ...agentInfo } = this.config
|
||||||
|
return { prompt, agent: agentInfo as AgentConfig }
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -525,9 +524,8 @@ export class Agent {
|
||||||
* mutated in place) so that shallow copies of the original array (e.g. from
|
* mutated in place) so that shallow copies of the original array (e.g. from
|
||||||
* `prompt()`) are not affected.
|
* `prompt()`) are not affected.
|
||||||
*/
|
*/
|
||||||
private applyHookContext(messages: LLMMessage[], ctx: BeforeRunHookContext): void {
|
private applyHookContext(messages: LLMMessage[], ctx: BeforeRunHookContext, originalPrompt: string): void {
|
||||||
const original = this.buildBeforeRunHookContext(messages)
|
if (ctx.prompt === originalPrompt) return
|
||||||
if (ctx.prompt === original.prompt) return
|
|
||||||
|
|
||||||
for (let i = messages.length - 1; i >= 0; i--) {
|
for (let i = messages.length - 1; i >= 0; i--) {
|
||||||
if (messages[i]!.role === 'user') {
|
if (messages[i]!.role === 'user') {
|
||||||
|
|
|
||||||
|
|
@ -222,8 +222,9 @@ export interface AgentConfig {
|
||||||
*/
|
*/
|
||||||
readonly beforeRun?: (context: BeforeRunHookContext) => Promise<BeforeRunHookContext> | BeforeRunHookContext
|
readonly beforeRun?: (context: BeforeRunHookContext) => Promise<BeforeRunHookContext> | BeforeRunHookContext
|
||||||
/**
|
/**
|
||||||
* Called after each agent run completes. Receives the run result.
|
* Called after each agent run completes successfully. Receives the run result.
|
||||||
* Return a (possibly modified) result, or throw to mark the run as failed.
|
* Return a (possibly modified) result, or throw to mark the run as failed.
|
||||||
|
* Not called when the run throws. For error observation, handle errors at the call site.
|
||||||
*/
|
*/
|
||||||
readonly afterRun?: (result: AgentRunResult) => Promise<AgentRunResult> | AgentRunResult
|
readonly afterRun?: (result: AgentRunResult) => Promise<AgentRunResult> | AgentRunResult
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
import { describe, it, expect, vi } from 'vitest'
|
import { describe, it, expect, vi } from 'vitest'
|
||||||
|
import { z } from 'zod'
|
||||||
import { Agent } from '../src/agent/agent.js'
|
import { Agent } from '../src/agent/agent.js'
|
||||||
import { AgentRunner } from '../src/agent/runner.js'
|
import { AgentRunner } from '../src/agent/runner.js'
|
||||||
import { ToolRegistry } from '../src/tool/framework.js'
|
import { ToolRegistry } from '../src/tool/framework.js'
|
||||||
|
|
@ -377,4 +378,96 @@ describe('Agent hooks — beforeRun / afterRun', () => {
|
||||||
const firstUserInHistory = history.find(m => m.role === 'user')
|
const firstUserInHistory = history.find(m => m.role === 'user')
|
||||||
expect((firstUserInHistory!.content[0] as any).text).toBe('original message')
|
expect((firstUserInHistory!.content[0] as any).text).toBe('original message')
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// afterRun NOT called on error
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
it('afterRun is not called when executeRun throws', async () => {
|
||||||
|
const afterSpy = vi.fn((result) => result)
|
||||||
|
|
||||||
|
const config: AgentConfig = {
|
||||||
|
...baseConfig,
|
||||||
|
// Use beforeRun to trigger an error inside executeRun's try block,
|
||||||
|
// before afterRun would normally run.
|
||||||
|
beforeRun: () => { throw new Error('rejected by policy') },
|
||||||
|
afterRun: afterSpy,
|
||||||
|
}
|
||||||
|
const { agent } = buildMockAgent(config, 'should not reach')
|
||||||
|
const result = await agent.run('hi')
|
||||||
|
|
||||||
|
expect(result.success).toBe(false)
|
||||||
|
expect(result.output).toContain('rejected by policy')
|
||||||
|
expect(afterSpy).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// outputSchema + afterRun
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
it('afterRun fires after structured output validation', async () => {
|
||||||
|
const schema = z.object({ answer: z.string() })
|
||||||
|
|
||||||
|
const config: AgentConfig = {
|
||||||
|
...baseConfig,
|
||||||
|
outputSchema: schema,
|
||||||
|
afterRun: (result) => ({ ...result, output: '[post-processed] ' + result.output }),
|
||||||
|
}
|
||||||
|
// Return valid JSON matching the schema
|
||||||
|
const { agent } = buildMockAgent(config, '{"answer":"42"}')
|
||||||
|
const result = await agent.run('what is the answer?')
|
||||||
|
|
||||||
|
expect(result.success).toBe(true)
|
||||||
|
expect(result.output).toBe('[post-processed] {"answer":"42"}')
|
||||||
|
expect(result.structured).toEqual({ answer: '42' })
|
||||||
|
})
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// ctx.agent does not contain hook self-references
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
it('beforeRun context.agent has correct config without hook self-references', async () => {
|
||||||
|
let receivedAgent: AgentConfig | undefined
|
||||||
|
|
||||||
|
const config: AgentConfig = {
|
||||||
|
...baseConfig,
|
||||||
|
beforeRun: (ctx) => {
|
||||||
|
receivedAgent = ctx.agent
|
||||||
|
return ctx
|
||||||
|
},
|
||||||
|
}
|
||||||
|
const { agent } = buildMockAgent(config, 'ok')
|
||||||
|
await agent.run('test')
|
||||||
|
|
||||||
|
expect(receivedAgent).toBeDefined()
|
||||||
|
expect(receivedAgent!.name).toBe('test-agent')
|
||||||
|
expect(receivedAgent!.model).toBe('mock-model')
|
||||||
|
// Hook functions should be stripped to avoid circular references
|
||||||
|
expect(receivedAgent!.beforeRun).toBeUndefined()
|
||||||
|
expect(receivedAgent!.afterRun).toBeUndefined()
|
||||||
|
})
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// Multiple prompt() turns fire hooks each time
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
it('hooks fire on every prompt() call', async () => {
|
||||||
|
const beforeSpy = vi.fn((ctx) => ctx)
|
||||||
|
const afterSpy = vi.fn((result) => result)
|
||||||
|
|
||||||
|
const config: AgentConfig = {
|
||||||
|
...baseConfig,
|
||||||
|
beforeRun: beforeSpy,
|
||||||
|
afterRun: afterSpy,
|
||||||
|
}
|
||||||
|
const { agent } = buildMockAgent(config, 'reply')
|
||||||
|
|
||||||
|
await agent.prompt('turn 1')
|
||||||
|
await agent.prompt('turn 2')
|
||||||
|
|
||||||
|
expect(beforeSpy).toHaveBeenCalledTimes(2)
|
||||||
|
expect(afterSpy).toHaveBeenCalledTimes(2)
|
||||||
|
expect(beforeSpy.mock.calls[0]![0].prompt).toBe('turn 1')
|
||||||
|
expect(beforeSpy.mock.calls[1]![0].prompt).toBe('turn 2')
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue