feat(agent): add beforeRun / afterRun lifecycle hooks (#45)
* feat(agent): add beforeRun / afterRun lifecycle hooks (#31) Add optional hook callbacks to AgentConfig for cross-cutting concerns (guardrails, logging, token budgets) without modifying framework internals. - beforeRun: receives prompt + agent config, can modify or throw to abort - afterRun: receives AgentRunResult, can modify or throw to fail - Works with all three execution modes: run(), prompt(), stream() - 15 test cases covering modify, throw, async, composition, and history integrity * fix(agent): preserve non-text content blocks in beforeRun hook - applyHookContext now replaces only text blocks, keeping images and tool results intact (was silently stripping them) - Use backward loop instead of reverse() + find() for efficiency - Clarify JSDoc that only `prompt` is applied from hook return value - Add test for mixed-content user messages * fix(agent): address review feedback on beforeRun/afterRun hooks - Normalize stream done event to always yield AgentRunResult - Move transitionTo('completed') after afterRun to fix state ordering - Strip hook functions from BeforeRunHookContext.agent to avoid self-references - Pass originalPrompt to applyHookContext to avoid redundant message scan - Clarify afterRun JSDoc: not called when the run throws - Add tests: error-path skip, outputSchema+afterRun, ctx.agent shape, multi-turn hooks
This commit is contained in:
parent
25b144acf3
commit
a1ccbfea61
|
|
@ -27,6 +27,7 @@ import type {
|
|||
AgentConfig,
|
||||
AgentState,
|
||||
AgentRunResult,
|
||||
BeforeRunHookContext,
|
||||
LLMMessage,
|
||||
StreamEvent,
|
||||
TokenUsage,
|
||||
|
|
@ -278,6 +279,13 @@ export class Agent {
|
|||
const agentStartMs = Date.now()
|
||||
|
||||
try {
|
||||
// --- beforeRun hook ---
|
||||
if (this.config.beforeRun) {
|
||||
const hookCtx = this.buildBeforeRunHookContext(messages)
|
||||
const modified = await this.config.beforeRun(hookCtx)
|
||||
this.applyHookContext(messages, modified, hookCtx.prompt)
|
||||
}
|
||||
|
||||
const runner = await this.getRunner()
|
||||
const internalOnMessage = (msg: LLMMessage) => {
|
||||
this.state.messages.push(msg)
|
||||
|
|
@ -296,18 +304,28 @@ export class Agent {
|
|||
|
||||
// --- Structured output validation ---
|
||||
if (this.config.outputSchema) {
|
||||
const validated = await this.validateStructuredOutput(
|
||||
let validated = await this.validateStructuredOutput(
|
||||
messages,
|
||||
result,
|
||||
runner,
|
||||
runOptions,
|
||||
)
|
||||
// --- afterRun hook ---
|
||||
if (this.config.afterRun) {
|
||||
validated = await this.config.afterRun(validated)
|
||||
}
|
||||
this.emitAgentTrace(callerOptions, agentStartMs, validated)
|
||||
return validated
|
||||
}
|
||||
|
||||
let agentResult = this.toAgentRunResult(result, true)
|
||||
|
||||
// --- afterRun hook ---
|
||||
if (this.config.afterRun) {
|
||||
agentResult = await this.config.afterRun(agentResult)
|
||||
}
|
||||
|
||||
this.transitionTo('completed')
|
||||
const agentResult = this.toAgentRunResult(result, true)
|
||||
this.emitAgentTrace(callerOptions, agentStartMs, agentResult)
|
||||
return agentResult
|
||||
} catch (err) {
|
||||
|
|
@ -440,13 +458,27 @@ export class Agent {
|
|||
this.transitionTo('running')
|
||||
|
||||
try {
|
||||
// --- beforeRun hook ---
|
||||
if (this.config.beforeRun) {
|
||||
const hookCtx = this.buildBeforeRunHookContext(messages)
|
||||
const modified = await this.config.beforeRun(hookCtx)
|
||||
this.applyHookContext(messages, modified, hookCtx.prompt)
|
||||
}
|
||||
|
||||
const runner = await this.getRunner()
|
||||
|
||||
for await (const event of runner.stream(messages)) {
|
||||
if (event.type === 'done') {
|
||||
const result = event.data as import('./runner.js').RunResult
|
||||
this.state.tokenUsage = addUsage(this.state.tokenUsage, result.tokenUsage)
|
||||
|
||||
let agentResult = this.toAgentRunResult(result, true)
|
||||
if (this.config.afterRun) {
|
||||
agentResult = await this.config.afterRun(agentResult)
|
||||
}
|
||||
this.transitionTo('completed')
|
||||
yield { type: 'done', data: agentResult } satisfies StreamEvent
|
||||
continue
|
||||
} else if (event.type === 'error') {
|
||||
const error = event.data instanceof Error
|
||||
? event.data
|
||||
|
|
@ -463,6 +495,50 @@ export class Agent {
|
|||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Hook helpers
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
/** Extract the prompt text from the last user message to build hook context. */
|
||||
private buildBeforeRunHookContext(messages: LLMMessage[]): BeforeRunHookContext {
|
||||
let prompt = ''
|
||||
for (let i = messages.length - 1; i >= 0; i--) {
|
||||
if (messages[i]!.role === 'user') {
|
||||
prompt = messages[i]!.content
|
||||
.filter((b): b is import('../types.js').TextBlock => b.type === 'text')
|
||||
.map(b => b.text)
|
||||
.join('')
|
||||
break
|
||||
}
|
||||
}
|
||||
// Strip hook functions to avoid circular self-references in the context
|
||||
const { beforeRun, afterRun, ...agentInfo } = this.config
|
||||
return { prompt, agent: agentInfo as AgentConfig }
|
||||
}
|
||||
|
||||
/**
|
||||
* Apply a (possibly modified) hook context back to the messages array.
|
||||
*
|
||||
* Only text blocks in the last user message are replaced; non-text content
|
||||
* (images, tool results) is preserved. The array element is replaced (not
|
||||
* mutated in place) so that shallow copies of the original array (e.g. from
|
||||
* `prompt()`) are not affected.
|
||||
*/
|
||||
private applyHookContext(messages: LLMMessage[], ctx: BeforeRunHookContext, originalPrompt: string): void {
|
||||
if (ctx.prompt === originalPrompt) return
|
||||
|
||||
for (let i = messages.length - 1; i >= 0; i--) {
|
||||
if (messages[i]!.role === 'user') {
|
||||
const nonTextBlocks = messages[i]!.content.filter(b => b.type !== 'text')
|
||||
messages[i] = {
|
||||
role: 'user',
|
||||
content: [{ type: 'text', text: ctx.prompt }, ...nonTextBlocks],
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// State transition helpers
|
||||
// -------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -147,6 +147,7 @@ export type {
|
|||
AgentConfig,
|
||||
AgentState,
|
||||
AgentRunResult,
|
||||
BeforeRunHookContext,
|
||||
ToolCallRecord,
|
||||
|
||||
// Team
|
||||
|
|
|
|||
20
src/types.ts
20
src/types.ts
|
|
@ -182,6 +182,14 @@ export interface ToolDefinition<TInput = Record<string, unknown>> {
|
|||
// Agent
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/** Context passed to the {@link AgentConfig.beforeRun} hook. */
|
||||
export interface BeforeRunHookContext {
|
||||
/** The user prompt text. */
|
||||
readonly prompt: string
|
||||
/** The agent's static configuration. */
|
||||
readonly agent: AgentConfig
|
||||
}
|
||||
|
||||
/** Static configuration for a single agent. */
|
||||
export interface AgentConfig {
|
||||
readonly name: string
|
||||
|
|
@ -207,6 +215,18 @@ export interface AgentConfig {
|
|||
* retry with error feedback is attempted on validation failure.
|
||||
*/
|
||||
readonly outputSchema?: ZodSchema
|
||||
/**
|
||||
* Called before each agent run. Receives the prompt and agent config.
|
||||
* Return a (possibly modified) context to continue, or throw to abort the run.
|
||||
* Only `prompt` from the returned context is applied; `agent` is read-only informational.
|
||||
*/
|
||||
readonly beforeRun?: (context: BeforeRunHookContext) => Promise<BeforeRunHookContext> | BeforeRunHookContext
|
||||
/**
|
||||
* Called after each agent run completes successfully. Receives the run result.
|
||||
* Return a (possibly modified) result, or throw to mark the run as failed.
|
||||
* Not called when the run throws. For error observation, handle errors at the call site.
|
||||
*/
|
||||
readonly afterRun?: (result: AgentRunResult) => Promise<AgentRunResult> | AgentRunResult
|
||||
}
|
||||
|
||||
/** Lifecycle state tracked during an agent run. */
|
||||
|
|
|
|||
|
|
@ -0,0 +1,473 @@
|
|||
import { describe, it, expect, vi } from 'vitest'
|
||||
import { z } from 'zod'
|
||||
import { Agent } from '../src/agent/agent.js'
|
||||
import { AgentRunner } from '../src/agent/runner.js'
|
||||
import { ToolRegistry } from '../src/tool/framework.js'
|
||||
import { ToolExecutor } from '../src/tool/executor.js'
|
||||
import type { AgentConfig, AgentRunResult, LLMAdapter, LLMMessage, LLMResponse } from '../src/types.js'
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Mock helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Create a mock adapter that records every `chat()` call's messages
|
||||
* and returns a fixed text response.
|
||||
*/
|
||||
function mockAdapter(responseText: string) {
|
||||
const calls: LLMMessage[][] = []
|
||||
const adapter: LLMAdapter = {
|
||||
name: 'mock',
|
||||
async chat(messages) {
|
||||
calls.push([...messages])
|
||||
return {
|
||||
id: 'mock-1',
|
||||
content: [{ type: 'text' as const, text: responseText }],
|
||||
model: 'mock-model',
|
||||
stop_reason: 'end_turn',
|
||||
usage: { input_tokens: 10, output_tokens: 20 },
|
||||
} satisfies LLMResponse
|
||||
},
|
||||
async *stream() {
|
||||
/* unused */
|
||||
},
|
||||
}
|
||||
return { adapter, calls }
|
||||
}
|
||||
|
||||
/** Build an Agent with a mocked LLM, bypassing createAdapter. */
|
||||
function buildMockAgent(config: AgentConfig, responseText: string) {
|
||||
const { adapter, calls } = mockAdapter(responseText)
|
||||
const registry = new ToolRegistry()
|
||||
const executor = new ToolExecutor(registry)
|
||||
const agent = new Agent(config, registry, executor)
|
||||
|
||||
const runner = new AgentRunner(adapter, registry, executor, {
|
||||
model: config.model,
|
||||
systemPrompt: config.systemPrompt,
|
||||
maxTurns: config.maxTurns,
|
||||
maxTokens: config.maxTokens,
|
||||
temperature: config.temperature,
|
||||
agentName: config.name,
|
||||
})
|
||||
;(agent as any).runner = runner
|
||||
|
||||
return { agent, calls }
|
||||
}
|
||||
|
||||
const baseConfig: AgentConfig = {
|
||||
name: 'test-agent',
|
||||
model: 'mock-model',
|
||||
systemPrompt: 'You are a test agent.',
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
describe('Agent hooks — beforeRun / afterRun', () => {
|
||||
// -----------------------------------------------------------------------
|
||||
// Baseline — no hooks
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
it('works normally without hooks', async () => {
|
||||
const { agent } = buildMockAgent(baseConfig, 'hello')
|
||||
const result = await agent.run('ping')
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
expect(result.output).toBe('hello')
|
||||
})
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// beforeRun
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
it('beforeRun can modify the prompt', async () => {
|
||||
const config: AgentConfig = {
|
||||
...baseConfig,
|
||||
beforeRun: (ctx) => ({ ...ctx, prompt: 'modified prompt' }),
|
||||
}
|
||||
const { agent, calls } = buildMockAgent(config, 'response')
|
||||
await agent.run('original prompt')
|
||||
|
||||
// The adapter should have received the modified prompt.
|
||||
const lastUserMsg = calls[0]!.find(m => m.role === 'user')
|
||||
const textBlock = lastUserMsg!.content.find(b => b.type === 'text')
|
||||
expect((textBlock as any).text).toBe('modified prompt')
|
||||
})
|
||||
|
||||
it('beforeRun that returns context unchanged does not alter prompt', async () => {
|
||||
const config: AgentConfig = {
|
||||
...baseConfig,
|
||||
beforeRun: (ctx) => ctx,
|
||||
}
|
||||
const { agent, calls } = buildMockAgent(config, 'response')
|
||||
await agent.run('keep this')
|
||||
|
||||
const lastUserMsg = calls[0]!.find(m => m.role === 'user')
|
||||
const textBlock = lastUserMsg!.content.find(b => b.type === 'text')
|
||||
expect((textBlock as any).text).toBe('keep this')
|
||||
})
|
||||
|
||||
it('beforeRun throwing aborts the run with failure', async () => {
|
||||
const config: AgentConfig = {
|
||||
...baseConfig,
|
||||
beforeRun: () => { throw new Error('budget exceeded') },
|
||||
}
|
||||
const { agent, calls } = buildMockAgent(config, 'should not reach')
|
||||
const result = await agent.run('hi')
|
||||
|
||||
expect(result.success).toBe(false)
|
||||
expect(result.output).toContain('budget exceeded')
|
||||
// No LLM call should have been made.
|
||||
expect(calls).toHaveLength(0)
|
||||
})
|
||||
|
||||
it('async beforeRun works', async () => {
|
||||
const config: AgentConfig = {
|
||||
...baseConfig,
|
||||
beforeRun: async (ctx) => {
|
||||
await Promise.resolve()
|
||||
return { ...ctx, prompt: 'async modified' }
|
||||
},
|
||||
}
|
||||
const { agent, calls } = buildMockAgent(config, 'ok')
|
||||
await agent.run('original')
|
||||
|
||||
const lastUserMsg = calls[0]!.find(m => m.role === 'user')
|
||||
const textBlock = lastUserMsg!.content.find(b => b.type === 'text')
|
||||
expect((textBlock as any).text).toBe('async modified')
|
||||
})
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// afterRun
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
it('afterRun can modify the result', async () => {
|
||||
const config: AgentConfig = {
|
||||
...baseConfig,
|
||||
afterRun: (result) => ({ ...result, output: 'modified output' }),
|
||||
}
|
||||
const { agent } = buildMockAgent(config, 'original output')
|
||||
const result = await agent.run('hi')
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
expect(result.output).toBe('modified output')
|
||||
})
|
||||
|
||||
it('afterRun throwing marks run as failed', async () => {
|
||||
const config: AgentConfig = {
|
||||
...baseConfig,
|
||||
afterRun: () => { throw new Error('content violation') },
|
||||
}
|
||||
const { agent } = buildMockAgent(config, 'bad content')
|
||||
const result = await agent.run('hi')
|
||||
|
||||
expect(result.success).toBe(false)
|
||||
expect(result.output).toContain('content violation')
|
||||
})
|
||||
|
||||
it('async afterRun works', async () => {
|
||||
const config: AgentConfig = {
|
||||
...baseConfig,
|
||||
afterRun: async (result) => {
|
||||
await Promise.resolve()
|
||||
return { ...result, output: result.output.toUpperCase() }
|
||||
},
|
||||
}
|
||||
const { agent } = buildMockAgent(config, 'hello')
|
||||
const result = await agent.run('hi')
|
||||
|
||||
expect(result.output).toBe('HELLO')
|
||||
})
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Both hooks together
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
it('beforeRun and afterRun compose correctly', async () => {
|
||||
const hookOrder: string[] = []
|
||||
|
||||
const config: AgentConfig = {
|
||||
...baseConfig,
|
||||
beforeRun: (ctx) => {
|
||||
hookOrder.push('before')
|
||||
return { ...ctx, prompt: 'injected prompt' }
|
||||
},
|
||||
afterRun: (result) => {
|
||||
hookOrder.push('after')
|
||||
return { ...result, output: `[processed] ${result.output}` }
|
||||
},
|
||||
}
|
||||
const { agent, calls } = buildMockAgent(config, 'raw output')
|
||||
const result = await agent.run('original')
|
||||
|
||||
expect(hookOrder).toEqual(['before', 'after'])
|
||||
|
||||
const lastUserMsg = calls[0]!.find(m => m.role === 'user')
|
||||
const textBlock = lastUserMsg!.content.find(b => b.type === 'text')
|
||||
expect((textBlock as any).text).toBe('injected prompt')
|
||||
|
||||
expect(result.output).toBe('[processed] raw output')
|
||||
})
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// prompt() multi-turn mode
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
it('hooks fire on prompt() calls', async () => {
|
||||
const beforeSpy = vi.fn((ctx) => ctx)
|
||||
const afterSpy = vi.fn((result) => result)
|
||||
|
||||
const config: AgentConfig = {
|
||||
...baseConfig,
|
||||
beforeRun: beforeSpy,
|
||||
afterRun: afterSpy,
|
||||
}
|
||||
const { agent } = buildMockAgent(config, 'reply')
|
||||
await agent.prompt('hello')
|
||||
|
||||
expect(beforeSpy).toHaveBeenCalledOnce()
|
||||
expect(afterSpy).toHaveBeenCalledOnce()
|
||||
expect(beforeSpy.mock.calls[0]![0].prompt).toBe('hello')
|
||||
})
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// stream() mode
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
it('beforeRun fires in stream mode', async () => {
|
||||
const config: AgentConfig = {
|
||||
...baseConfig,
|
||||
beforeRun: (ctx) => ({ ...ctx, prompt: 'stream modified' }),
|
||||
}
|
||||
const { agent, calls } = buildMockAgent(config, 'streamed')
|
||||
|
||||
const events = []
|
||||
for await (const event of agent.stream('original')) {
|
||||
events.push(event)
|
||||
}
|
||||
|
||||
const lastUserMsg = calls[0]!.find(m => m.role === 'user')
|
||||
const textBlock = lastUserMsg!.content.find(b => b.type === 'text')
|
||||
expect((textBlock as any).text).toBe('stream modified')
|
||||
|
||||
// Should have at least a text event and a done event.
|
||||
expect(events.some(e => e.type === 'done')).toBe(true)
|
||||
})
|
||||
|
||||
it('afterRun fires in stream mode and modifies done event', async () => {
|
||||
const config: AgentConfig = {
|
||||
...baseConfig,
|
||||
afterRun: (result) => ({ ...result, output: 'stream modified output' }),
|
||||
}
|
||||
const { agent } = buildMockAgent(config, 'original')
|
||||
|
||||
const events = []
|
||||
for await (const event of agent.stream('hi')) {
|
||||
events.push(event)
|
||||
}
|
||||
|
||||
const doneEvent = events.find(e => e.type === 'done')
|
||||
expect(doneEvent).toBeDefined()
|
||||
expect((doneEvent!.data as AgentRunResult).output).toBe('stream modified output')
|
||||
})
|
||||
|
||||
it('beforeRun throwing in stream mode yields error event', async () => {
|
||||
const config: AgentConfig = {
|
||||
...baseConfig,
|
||||
beforeRun: () => { throw new Error('stream abort') },
|
||||
}
|
||||
const { agent } = buildMockAgent(config, 'unreachable')
|
||||
|
||||
const events = []
|
||||
for await (const event of agent.stream('hi')) {
|
||||
events.push(event)
|
||||
}
|
||||
|
||||
const errorEvent = events.find(e => e.type === 'error')
|
||||
expect(errorEvent).toBeDefined()
|
||||
expect((errorEvent!.data as Error).message).toContain('stream abort')
|
||||
})
|
||||
|
||||
it('afterRun throwing in stream mode yields error event', async () => {
|
||||
const config: AgentConfig = {
|
||||
...baseConfig,
|
||||
afterRun: () => { throw new Error('stream content violation') },
|
||||
}
|
||||
const { agent } = buildMockAgent(config, 'streamed output')
|
||||
|
||||
const events = []
|
||||
for await (const event of agent.stream('hi')) {
|
||||
events.push(event)
|
||||
}
|
||||
|
||||
// Text events may have been yielded before the error.
|
||||
const errorEvent = events.find(e => e.type === 'error')
|
||||
expect(errorEvent).toBeDefined()
|
||||
expect((errorEvent!.data as Error).message).toContain('stream content violation')
|
||||
// No done event should be present since afterRun rejected it.
|
||||
expect(events.find(e => e.type === 'done')).toBeUndefined()
|
||||
})
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// prompt() history integrity
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
it('beforeRun modifying prompt preserves non-text content blocks', async () => {
|
||||
// Simulate a multi-turn message where the last user message has mixed content
|
||||
// (text + tool_result). beforeRun should only replace text, not strip other blocks.
|
||||
const config: AgentConfig = {
|
||||
...baseConfig,
|
||||
beforeRun: (ctx) => ({ ...ctx, prompt: 'modified' }),
|
||||
}
|
||||
const { adapter, calls } = mockAdapter('ok')
|
||||
const registry = new ToolRegistry()
|
||||
const executor = new ToolExecutor(registry)
|
||||
const agent = new Agent(config, registry, executor)
|
||||
|
||||
const runner = new AgentRunner(adapter, registry, executor, {
|
||||
model: config.model,
|
||||
agentName: config.name,
|
||||
})
|
||||
;(agent as any).runner = runner
|
||||
|
||||
// Directly call run which creates a single text-only user message.
|
||||
// To test mixed content, we need to go through the private executeRun.
|
||||
// Instead, we test via prompt() after injecting history with mixed content.
|
||||
;(agent as any).messageHistory = [
|
||||
{
|
||||
role: 'user' as const,
|
||||
content: [
|
||||
{ type: 'text' as const, text: 'original' },
|
||||
{ type: 'image' as const, source: { type: 'base64' as const, media_type: 'image/png', data: 'abc' } },
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
// prompt() appends a new user message then calls executeRun with full history
|
||||
await agent.prompt('follow up')
|
||||
|
||||
// The last user message sent to the LLM should have modified text
|
||||
const sentMessages = calls[0]!
|
||||
const lastUser = [...sentMessages].reverse().find(m => m.role === 'user')!
|
||||
const textBlock = lastUser.content.find(b => b.type === 'text')
|
||||
expect((textBlock as any).text).toBe('modified')
|
||||
|
||||
// The earlier user message (with the image) should be untouched
|
||||
const firstUser = sentMessages.find(m => m.role === 'user')!
|
||||
const imageBlock = firstUser.content.find(b => b.type === 'image')
|
||||
expect(imageBlock).toBeDefined()
|
||||
})
|
||||
|
||||
it('beforeRun modifying prompt does not corrupt messageHistory', async () => {
|
||||
const config: AgentConfig = {
|
||||
...baseConfig,
|
||||
beforeRun: (ctx) => ({ ...ctx, prompt: 'hook-modified' }),
|
||||
}
|
||||
const { agent, calls } = buildMockAgent(config, 'reply')
|
||||
|
||||
await agent.prompt('original message')
|
||||
|
||||
// The LLM should have received the modified prompt.
|
||||
const lastUserMsg = calls[0]!.find(m => m.role === 'user')
|
||||
expect((lastUserMsg!.content[0] as any).text).toBe('hook-modified')
|
||||
|
||||
// But the persistent history should retain the original message.
|
||||
const history = agent.getHistory()
|
||||
const firstUserInHistory = history.find(m => m.role === 'user')
|
||||
expect((firstUserInHistory!.content[0] as any).text).toBe('original message')
|
||||
})
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// afterRun NOT called on error
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
it('afterRun is not called when executeRun throws', async () => {
|
||||
const afterSpy = vi.fn((result) => result)
|
||||
|
||||
const config: AgentConfig = {
|
||||
...baseConfig,
|
||||
// Use beforeRun to trigger an error inside executeRun's try block,
|
||||
// before afterRun would normally run.
|
||||
beforeRun: () => { throw new Error('rejected by policy') },
|
||||
afterRun: afterSpy,
|
||||
}
|
||||
const { agent } = buildMockAgent(config, 'should not reach')
|
||||
const result = await agent.run('hi')
|
||||
|
||||
expect(result.success).toBe(false)
|
||||
expect(result.output).toContain('rejected by policy')
|
||||
expect(afterSpy).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// outputSchema + afterRun
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
it('afterRun fires after structured output validation', async () => {
|
||||
const schema = z.object({ answer: z.string() })
|
||||
|
||||
const config: AgentConfig = {
|
||||
...baseConfig,
|
||||
outputSchema: schema,
|
||||
afterRun: (result) => ({ ...result, output: '[post-processed] ' + result.output }),
|
||||
}
|
||||
// Return valid JSON matching the schema
|
||||
const { agent } = buildMockAgent(config, '{"answer":"42"}')
|
||||
const result = await agent.run('what is the answer?')
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
expect(result.output).toBe('[post-processed] {"answer":"42"}')
|
||||
expect(result.structured).toEqual({ answer: '42' })
|
||||
})
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// ctx.agent does not contain hook self-references
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
it('beforeRun context.agent has correct config without hook self-references', async () => {
|
||||
let receivedAgent: AgentConfig | undefined
|
||||
|
||||
const config: AgentConfig = {
|
||||
...baseConfig,
|
||||
beforeRun: (ctx) => {
|
||||
receivedAgent = ctx.agent
|
||||
return ctx
|
||||
},
|
||||
}
|
||||
const { agent } = buildMockAgent(config, 'ok')
|
||||
await agent.run('test')
|
||||
|
||||
expect(receivedAgent).toBeDefined()
|
||||
expect(receivedAgent!.name).toBe('test-agent')
|
||||
expect(receivedAgent!.model).toBe('mock-model')
|
||||
// Hook functions should be stripped to avoid circular references
|
||||
expect(receivedAgent!.beforeRun).toBeUndefined()
|
||||
expect(receivedAgent!.afterRun).toBeUndefined()
|
||||
})
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Multiple prompt() turns fire hooks each time
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
it('hooks fire on every prompt() call', async () => {
|
||||
const beforeSpy = vi.fn((ctx) => ctx)
|
||||
const afterSpy = vi.fn((result) => result)
|
||||
|
||||
const config: AgentConfig = {
|
||||
...baseConfig,
|
||||
beforeRun: beforeSpy,
|
||||
afterRun: afterSpy,
|
||||
}
|
||||
const { agent } = buildMockAgent(config, 'reply')
|
||||
|
||||
await agent.prompt('turn 1')
|
||||
await agent.prompt('turn 2')
|
||||
|
||||
expect(beforeSpy).toHaveBeenCalledTimes(2)
|
||||
expect(afterSpy).toHaveBeenCalledTimes(2)
|
||||
expect(beforeSpy.mock.calls[0]![0].prompt).toBe('turn 1')
|
||||
expect(beforeSpy.mock.calls[1]![0].prompt).toBe('turn 2')
|
||||
})
|
||||
})
|
||||
Loading…
Reference in New Issue