diff --git a/examples/ts-react-chat/src/routes/api.tanchat.ts b/examples/ts-react-chat/src/routes/api.tanchat.ts index 39cf0d5f..eee05b63 100644 --- a/examples/ts-react-chat/src/routes/api.tanchat.ts +++ b/examples/ts-react-chat/src/routes/api.tanchat.ts @@ -36,7 +36,7 @@ Step 1: Call getGuitars() Step 2: Call recommendGuitar(id: "6") Step 3: Done - do NOT add any text after calling recommendGuitar ` -const addToCartToolServer = addToCartToolDef.server((args) => ({ +const addToCartToolServer = addToCartToolDef.server((args, options) => ({ success: true, cartId: 'CART_' + Date.now(), guitarId: args.guitarId, @@ -97,6 +97,13 @@ export const Route = createFileRoute('/api/tanchat')({ `[API Route] Using provider: ${provider}, model: ${selectedModel}`, ) + // Server-side context (e.g., database connections, user session) + // This is separate from client context and only used for server tools + const serverContext = { + // Add server-side context here if needed + // e.g., db, userId from session, etc. + } + const stream = chat({ adapter: adapter as any, model: selectedModel as any, @@ -112,6 +119,7 @@ export const Route = createFileRoute('/api/tanchat')({ messages, abortController, conversationId, + context: serverContext, }) return toStreamResponse(stream, { abortController }) } catch (error: any) { diff --git a/packages/typescript/ai-client/src/chat-client.ts b/packages/typescript/ai-client/src/chat-client.ts index 3b9e1787..f9d806e2 100644 --- a/packages/typescript/ai-client/src/chat-client.ts +++ b/packages/typescript/ai-client/src/chat-client.ts @@ -10,11 +10,19 @@ import type { ToolCallPart, UIMessage, } from './types' -import type { AnyClientTool, ModelMessage, StreamChunk } from '@tanstack/ai' +import type { + AnyClientTool, + ModelMessage, + StreamChunk, + ToolOptions, +} from '@tanstack/ai' import type { ConnectionAdapter } from './connection-adapters' import type { ChatClientEventEmitter } from './events' -export class ChatClient { +export class ChatClient< + TTools extends ReadonlyArray = any, + TContext = unknown, +> { private processor: StreamProcessor private connection: ConnectionAdapter private uniqueId: string @@ -26,6 +34,7 @@ export class ChatClient { private clientToolsRef: { current: Map } private currentStreamId: string | null = null private currentMessageId: string | null = null + private options: Partial> private callbacksRef: { current: { @@ -39,9 +48,10 @@ export class ChatClient { } } - constructor(options: ChatClientOptions) { + constructor(options: ChatClientOptions) { this.uniqueId = options.id || this.generateUniqueId('chat') this.body = options.body || {} + this.options = { context: options.context } this.connection = options.connection this.events = new DefaultChatClientEventEmitter(this.uniqueId) @@ -135,7 +145,9 @@ export class ChatClient { const clientTool = this.clientToolsRef.current.get(args.toolName) if (clientTool?.execute) { try { - const output = await clientTool.execute(args.input) + const output = await clientTool.execute(args.input, { + context: this.options.context, + }) await this.addToolResult({ toolCallId: args.toolCallId, tool: args.toolName, diff --git a/packages/typescript/ai-client/src/types.ts b/packages/typescript/ai-client/src/types.ts index b5e58962..cc513da5 100644 --- a/packages/typescript/ai-client/src/types.ts +++ b/packages/typescript/ai-client/src/types.ts @@ -133,6 +133,7 @@ export interface UIMessage = any> { export interface ChatClientOptions< TTools extends ReadonlyArray = any, + TContext = unknown, > { /** * Connection adapter for streaming @@ -208,6 +209,24 @@ export interface ChatClientOptions< */ chunkStrategy?: ChunkStrategy } + + /** + * Context object that is automatically passed to client-side tool execute functions. + * + * This allows client tools to access shared context (like user ID, local storage, + * browser APIs, etc.) without needing to capture them via closures. + * + * Note: This context is only used for client-side tools. Server tools should receive + * their own context from the server-side chat() function. + * + * @example + * const client = new ChatClient({ + * connection: fetchServerSentEvents('/api/chat'), + * context: { userId: '123', localStorage }, + * tools: [clientTool], + * }); + */ + context?: TContext } export interface ChatRequestBody { diff --git a/packages/typescript/ai-client/tests/chat-client.test.ts b/packages/typescript/ai-client/tests/chat-client.test.ts index 05ae21d3..9e5bb89e 100644 --- a/packages/typescript/ai-client/tests/chat-client.test.ts +++ b/packages/typescript/ai-client/tests/chat-client.test.ts @@ -1,4 +1,6 @@ import { describe, expect, it, vi } from 'vitest' +import { toolDefinition } from '@tanstack/ai' +import { z } from 'zod' import { ChatClient } from '../src/chat-client' import { createMockConnectionAdapter, @@ -6,6 +8,7 @@ import { createThinkingChunks, createToolCallChunks, } from './test-utils' +import type { ToolOptions } from '@tanstack/ai' import type { UIMessage } from '../src/types' describe('ChatClient', () => { @@ -515,7 +518,7 @@ describe('ChatClient', () => { // Should have at least one call for the assistant message const assistantAppendedCall = messageAppendedCalls.find(([, data]) => { const payload = data as Record - return payload && payload.role === 'assistant' + return payload.role === 'assistant' }) expect(assistantAppendedCall).toBeDefined() }) @@ -585,4 +588,146 @@ describe('ChatClient', () => { expect(thinkingCalls.length).toBeGreaterThan(0) }) }) + + describe('context support', () => { + it('should pass context to client tool execute functions', async () => { + interface TestContext { + userId: string + localStorage: { + setItem: (key: string, value: string) => void + getItem: (key: string) => string | null + } + } + + const mockStorage = { + setItem: vi.fn(), + getItem: vi.fn(() => null), + } + + const testContext: TestContext = { + userId: '123', + localStorage: mockStorage, + } + + const executeFn = vi.fn( + async ( + _args: any, + options: ToolOptions, + ) => { + const ctx = options.context as TestContext + ctx.localStorage.setItem( + `pref_${ctx.userId}_${_args.key}`, + _args.value, + ) + return { success: true } + }, + ) + + const toolDef = toolDefinition({ + name: 'savePreference', + description: 'Save user preference', + inputSchema: z.object({ + key: z.string(), + value: z.string(), + }), + outputSchema: z.object({ + success: z.boolean(), + }), + }) + + const tool = toolDef.client(executeFn) + + const chunks = createToolCallChunks([ + { + id: 'tool-1', + name: 'savePreference', + arguments: '{"key":"theme","value":"dark"}', + }, + ]) + const adapter = createMockConnectionAdapter({ chunks }) + + const client = new ChatClient({ + connection: adapter, + tools: [tool], + context: testContext, + }) + + await client.sendMessage('Save my preference') + + // Wait a bit for async tool execution + await new Promise((resolve) => setTimeout(resolve, 10)) + + // Tool should have been called with context + expect(executeFn).toHaveBeenCalled() + const lastCall = executeFn.mock.calls[0] + expect(lastCall?.[0]).toEqual({ key: 'theme', value: 'dark' }) + expect(lastCall?.[1]).toEqual({ context: testContext }) + + // localStorage should have been called + expect(mockStorage.setItem).toHaveBeenCalledWith('pref_123_theme', 'dark') + }) + + it('should not send context to server (context is only for client tools)', async () => { + const testContext = { + userId: '123', + sessionId: 'session-456', + } + + let capturedBody: any = null + const adapter = createMockConnectionAdapter({ + chunks: createTextChunks('Response'), + onConnect: (_messages, body) => { + capturedBody = body + }, + }) + + const client = new ChatClient({ + connection: adapter, + context: testContext, + }) + + await client.sendMessage('Hello') + + // Context should NOT be in the request body (only used for client tools) + expect(capturedBody).toBeDefined() + expect(capturedBody.context).toBeUndefined() + }) + + it('should work without context (context is optional)', async () => { + const executeFn = vi.fn(async (args: any) => { + return { result: args.value } + }) + + const toolDef = toolDefinition({ + name: 'simpleTool', + description: 'Simple tool', + inputSchema: z.object({ + value: z.string(), + }), + outputSchema: z.object({ + result: z.string(), + }), + }) + + const tool = toolDef.client(executeFn) + + const chunks = createToolCallChunks([ + { id: 'tool-1', name: 'simpleTool', arguments: '{"value":"test"}' }, + ]) + const adapter = createMockConnectionAdapter({ chunks }) + + const client = new ChatClient({ + connection: adapter, + tools: [tool], + }) + + await client.sendMessage('Test') + + // Tool should have been called without context + expect(executeFn).toHaveBeenCalledWith( + { value: 'test' }, + { context: undefined }, + ) + }) + }) }) diff --git a/packages/typescript/ai/src/core/chat.ts b/packages/typescript/ai/src/core/chat.ts index aebaf65a..19faaf62 100644 --- a/packages/typescript/ai/src/core/chat.ts +++ b/packages/typescript/ai/src/core/chat.ts @@ -16,6 +16,7 @@ import type { StreamChunk, Tool, ToolCall, + ToolOptions, } from '../types' interface ChatEngineConfig< @@ -45,6 +46,7 @@ class ChatEngine< private readonly streamId: string private readonly effectiveRequest?: Request | RequestInit private readonly effectiveSignal?: AbortSignal + private readonly options: Partial> private messages: Array private iterationCount = 0 @@ -75,6 +77,7 @@ class ChatEngine< ? { signal: config.params.abortController.signal } : undefined this.effectiveSignal = config.params.abortController?.signal + this.options = { context: config.params.context } } async *chat(): AsyncGenerator { @@ -381,6 +384,7 @@ class ChatEngine< this.tools, approvals, clientToolResults, + this.options, ) if ( @@ -449,6 +453,7 @@ class ChatEngine< this.tools, approvals, clientToolResults, + this.options, ) if ( diff --git a/packages/typescript/ai/src/tools/tool-calls.ts b/packages/typescript/ai/src/tools/tool-calls.ts index 48287245..e7b8a4b7 100644 --- a/packages/typescript/ai/src/tools/tool-calls.ts +++ b/packages/typescript/ai/src/tools/tool-calls.ts @@ -4,6 +4,7 @@ import type { SchemaInput, Tool, ToolCall, + ToolOptions, ToolResultStreamChunk, } from '../types' import type { z } from 'zod' @@ -42,7 +43,7 @@ function isZodSchema(schema: SchemaInput | undefined): schema is z.ZodType { * * // After stream completes, execute tools * if (manager.hasToolCalls()) { - * const toolResults = yield* manager.executeTools(doneChunk); + * const toolResults = yield* manager.executeTools(doneChunk, { context }); * messages = [...messages, ...toolResults]; * manager.clear(); * } @@ -120,6 +121,7 @@ export class ToolCallManager { */ async *executeTools( doneChunk: DoneStreamChunk, + options: Partial> = {}, ): AsyncGenerator, void> { const toolCallsArray = this.getToolCalls() const toolResults: Array = [] @@ -151,8 +153,10 @@ export class ToolCallManager { } } - // Execute the tool - let result = await tool.execute(args) + // Execute the tool with options + let result = await tool.execute(args, { + context: options.context, + }) // Validate output against outputSchema if provided (only for Zod schemas) if ( @@ -253,12 +257,14 @@ interface ExecuteToolCallsResult { * @param tools - Available tools with their configurations * @param approvals - Map of approval decisions (approval.id -> approved boolean) * @param clientResults - Map of client-side execution results (toolCallId -> result) + * @param options - Options object containing context to pass to tool execute functions */ export async function executeToolCalls( toolCalls: Array, tools: ReadonlyArray, approvals: Map = new Map(), clientResults: Map = new Map(), + options: Partial> = {}, ): Promise { const results: Array = [] const needsApproval: Array = [] @@ -390,7 +396,9 @@ export async function executeToolCalls( // Execute after approval const startTime = Date.now() try { - let result = await tool.execute(input) + let result = await tool.execute(input, { + context: options.context, + }) const duration = Date.now() - startTime // Validate output against outputSchema if provided (only for Zod schemas) @@ -453,7 +461,9 @@ export async function executeToolCalls( // CASE 3: Normal server tool - execute immediately const startTime = Date.now() try { - let result = await tool.execute(input) + let result = await tool.execute(input, { + context: options.context, + }) const duration = Date.now() - startTime // Validate output against outputSchema if provided (only for Zod schemas) diff --git a/packages/typescript/ai/src/tools/tool-definition.ts b/packages/typescript/ai/src/tools/tool-definition.ts index 7b6fb6bc..86ee983e 100644 --- a/packages/typescript/ai/src/tools/tool-definition.ts +++ b/packages/typescript/ai/src/tools/tool-definition.ts @@ -1,5 +1,11 @@ import type { z } from 'zod' -import type { InferSchemaType, JSONSchema, SchemaInput, Tool } from '../types' +import type { + InferSchemaType, + JSONSchema, + SchemaInput, + Tool, + ToolOptions, +} from '../types' /** * Marker type for server-side tools @@ -8,8 +14,10 @@ export interface ServerTool< TInput extends SchemaInput = z.ZodType, TOutput extends SchemaInput = z.ZodType, TName extends string = string, + TContext = unknown, > extends Tool { __toolSide: 'server' + __contextType?: TContext } /** @@ -19,8 +27,10 @@ export interface ClientTool< TInput extends SchemaInput = z.ZodType, TOutput extends SchemaInput = z.ZodType, TName extends string = string, + TContext = unknown, > { __toolSide: 'client' + __contextType?: TContext name: TName description: string inputSchema?: TInput @@ -29,6 +39,7 @@ export interface ClientTool< metadata?: Record execute?: ( args: InferSchemaType, + options: ToolOptions, ) => Promise> | InferSchemaType } @@ -47,7 +58,7 @@ export interface ToolDefinitionInstance< * Union type for any kind of client-side tool (client tool or definition) */ export type AnyClientTool = - | ClientTool + | ClientTool | ToolDefinitionInstance /** @@ -104,20 +115,22 @@ export interface ToolDefinition< /** * Create a server-side tool with execute function */ - server: ( + server: ( execute: ( args: InferSchemaType, + options: ToolOptions, ) => Promise> | InferSchemaType, - ) => ServerTool + ) => ServerTool /** * Create a client-side tool with optional execute function */ - client: ( + client: ( execute?: ( args: InferSchemaType, + options: ToolOptions, ) => Promise> | InferSchemaType, - ) => ClientTool + ) => ClientTool } /** @@ -181,27 +194,31 @@ export function toolDefinition< const definition: ToolDefinition = { __toolSide: 'definition', ...config, - server( + server( execute: ( args: InferSchemaType, + options: ToolOptions, ) => Promise> | InferSchemaType, - ): ServerTool { + ): ServerTool { return { __toolSide: 'server', + __contextType: undefined as TContext, ...config, - execute, + execute: execute as any, } }, - client( + client( execute?: ( args: InferSchemaType, + options: ToolOptions, ) => Promise> | InferSchemaType, - ): ClientTool { + ): ClientTool { return { __toolSide: 'client', + __contextType: undefined as TContext, ...config, - execute, + execute: execute as any, } }, } diff --git a/packages/typescript/ai/src/types.ts b/packages/typescript/ai/src/types.ts index 748526c2..5f8acc7f 100644 --- a/packages/typescript/ai/src/types.ts +++ b/packages/typescript/ai/src/types.ts @@ -68,6 +68,14 @@ export interface ToolCall { } } +/** + * Options object passed to tool execute functions + * @template TContext - The type of context object + */ +export interface ToolOptions { + context?: TContext +} + // ============================================================================ // Multimodal Content Types // ============================================================================ @@ -403,15 +411,20 @@ export interface Tool< * Can return any value - will be automatically stringified if needed. * * @param args - The arguments parsed from the model's tool call (validated against inputSchema) + * @param options - Optional options object passed from chat() options (if provided) * @returns Result to send back to the model (validated against outputSchema if provided) * * @example - * execute: async (args) => { + * execute: async (args, options) => { + * const user = await options.context?.db.users.find({ id: options.context.userId }); // Can access context * const weather = await fetchWeather(args.location); * return weather; // Can return object or string * } */ - execute?: (args: any) => Promise | any + execute?: ( + args: InferSchemaType, + options: ToolOptions, + ) => Promise> | InferSchemaType /** If true, tool execution requires user approval before running. Works with both server and client tools. */ needsApproval?: boolean @@ -550,6 +563,7 @@ export interface ChatOptions< TProviderOptionsSuperset extends Record = Record, TOutput extends ResponseFormat | undefined = undefined, TProviderOptionsForModel = TProviderOptionsSuperset, + TContext = unknown, > { model: TModel messages: Array @@ -579,6 +593,29 @@ export interface ChatOptions< * @see https://developer.mozilla.org/en-US/docs/Web/API/AbortController */ abortController?: AbortController + /** + * Context object that is automatically passed to all tool execute functions. + * + * This allows tools to access shared context (like user ID, database connections, + * request metadata, etc.) without needing to capture them via closures. + * Works for both server and client tools. + * + * @example + * const stream = chat({ + * adapter: openai(), + * model: 'gpt-4o', + * messages, + * context: { userId: '123', db }, + * tools: [getUserData], + * }); + * + * // In tool definition: + * const getUserData = getUserDataDef.server(async (args, options) => { + * // options.context.userId and options.context.db are available + * return await options.context.db.users.find({ userId: options.context.userId }); + * }); + */ + context?: TContext } export type StreamChunkType = diff --git a/packages/typescript/ai/tests/ai-chat.test.ts b/packages/typescript/ai/tests/ai-chat.test.ts index 2ec8a650..23c61327 100644 --- a/packages/typescript/ai/tests/ai-chat.test.ts +++ b/packages/typescript/ai/tests/ai-chat.test.ts @@ -5,7 +5,13 @@ import { chat } from '../src/core/chat' import { BaseAdapter } from '../src/base-adapter' import { aiEventClient } from '../src/event-client.js' import { maxIterations } from '../src/utilities/agent-loop-strategies' -import type { ChatOptions, ModelMessage, StreamChunk, Tool } from '../src/types' +import type { + ChatOptions, + ModelMessage, + StreamChunk, + Tool, + ToolOptions, +} from '../src/types' // Mock event client to track events const eventListeners = new Map) => void>>() @@ -452,7 +458,10 @@ describe('chat() - Comprehensive Logic Path Coverage', () => { }), ) - expect(tool.execute).toHaveBeenCalledWith({ location: 'Paris' }) + expect(tool.execute).toHaveBeenCalledWith( + { location: 'Paris' }, + { context: undefined }, + ) expect(adapter.chatStreamCallCount).toBeGreaterThanOrEqual(2) const toolResultChunks = chunks.filter((c) => c.type === 'tool_result') @@ -559,7 +568,10 @@ describe('chat() - Comprehensive Logic Path Coverage', () => { ) // Tool should be executed with complete arguments - expect(tool.execute).toHaveBeenCalledWith({ a: 10, b: 20 }) + expect(tool.execute).toHaveBeenCalledWith( + { a: 10, b: 20 }, + { context: undefined }, + ) const toolResultChunks = chunks.filter((c) => c.type === 'tool_result') expect(toolResultChunks.length).toBeGreaterThan(0) }) @@ -1489,7 +1501,10 @@ describe('chat() - Comprehensive Logic Path Coverage', () => { const chunks = await collectChunks(stream) expect(chunks[0]?.type).toBe('tool_result') - expect(toolExecute).toHaveBeenCalledWith({ path: '/tmp/test.txt' }) + expect(toolExecute).toHaveBeenCalledWith( + { path: '/tmp/test.txt' }, + { context: undefined }, + ) expect(adapter.chatStreamCallCount).toBe(1) }) }) @@ -2601,7 +2616,10 @@ describe('chat() - Comprehensive Logic Path Coverage', () => { await collectChunks(stream2) // Tool should have been executed because approval was provided - expect(tool.execute).toHaveBeenCalledWith({ path: '/tmp/test.txt' }) + expect(tool.execute).toHaveBeenCalledWith( + { path: '/tmp/test.txt' }, + { context: undefined }, + ) }) it('should extract client tool outputs from messages with parts', async () => { @@ -2998,4 +3016,236 @@ describe('chat() - Comprehensive Logic Path Coverage', () => { expect(hasSeventy).toBe(true) }) }) + + describe('Context Support', () => { + it('should pass context to tool execute functions', async () => { + interface TestContext { + userId: string + db: { + users: { + find: (id: string) => Promise<{ name: string; email: string }> + } + } + } + + const contextTool = { + name: 'get_user', + description: 'Get user by ID', + inputSchema: z.object({ + userId: z.string(), + }), + execute: vi.fn( + async ( + args: any, + options: ToolOptions, + ) => { + const testContext = options.context as TestContext + const user = await testContext.db.users.find(args.userId) + return JSON.stringify({ + name: user.name, + email: user.email, + fromContext: testContext.userId, + }) + }, + ), + } satisfies Tool + + const mockDb = { + users: { + find: vi.fn(async (id: string) => ({ + name: `User ${id}`, + email: `user${id}@example.com`, + })), + }, + } + + const testContext: TestContext = { + userId: '123', + db: mockDb, + } + + class ContextToolAdapter extends MockAdapter { + iteration = 0 + async *chatStream(options: ChatOptions): AsyncIterable { + this.trackStreamCall(options) + const baseId = 'test-id' + + if (this.chatStreamCallCount === 1) { + // First iteration: request tool call + yield { + type: 'content', + id: baseId, + model: 'test-model', + timestamp: Date.now(), + delta: 'Getting user...', + content: 'Getting user...', + role: 'assistant', + } + yield { + type: 'tool_call', + id: baseId, + model: 'test-model', + timestamp: Date.now(), + toolCall: { + id: 'call_123', + type: 'function', + function: { + name: 'get_user', + arguments: '{"userId":"456"}', + }, + }, + index: 0, + } + yield { + type: 'done', + id: baseId, + model: 'test-model', + timestamp: Date.now(), + finishReason: 'tool_calls', + } + this.iteration++ + } else { + // Second iteration: should receive tool result + const toolResults = options.messages.filter( + (m) => m.role === 'tool', + ) + expect(toolResults.length).toBeGreaterThan(0) + + yield { + type: 'content', + id: `${baseId}-2`, + model: 'test-model', + timestamp: Date.now(), + delta: 'User found', + content: 'User found', + role: 'assistant', + } + yield { + type: 'done', + id: `${baseId}-2`, + model: 'test-model', + timestamp: Date.now(), + finishReason: 'stop', + } + this.iteration++ + } + } + } + + const adapter = new ContextToolAdapter() + + const stream = chat({ + adapter, + model: 'test-model', + messages: [{ role: 'user', content: 'Get user 456' }], + tools: [contextTool], + context: testContext, + agentLoopStrategy: maxIterations(2), + }) + + const chunks = await collectChunks(stream) + + // Tool should have been called with context + expect(contextTool.execute).toHaveBeenCalledWith( + { userId: '456' }, + { context: testContext }, + ) + + // Database should have been called + expect(mockDb.users.find).toHaveBeenCalledWith('456') + + // Should have tool result chunks + const toolResultChunks = chunks.filter((c) => c.type === 'tool_result') + expect(toolResultChunks.length).toBeGreaterThan(0) + + // Verify result contains context data + const resultContent = toolResultChunks[0]?.content || '' + const result = JSON.parse(resultContent) + expect(result.name).toBe('User 456') + expect(result.fromContext).toBe('123') + }) + + it('should work without context (context is optional)', async () => { + const noContextTool = { + name: 'simple_tool', + description: 'Simple tool', + inputSchema: z.object({ + value: z.string(), + }), + execute: vi.fn(async (args: any) => { + return JSON.stringify({ result: args.value }) + }), + } + + class SimpleToolAdapter extends MockAdapter { + iteration = 0 + async *chatStream(options: ChatOptions): AsyncIterable { + this.trackStreamCall(options) + const baseId = 'test-id' + + if (this.chatStreamCallCount === 1) { + yield { + type: 'tool_call', + id: baseId, + model: 'test-model', + timestamp: Date.now(), + toolCall: { + id: 'call_123', + type: 'function', + function: { + name: 'simple_tool', + arguments: '{"value":"test"}', + }, + }, + index: 0, + } + yield { + type: 'done', + id: baseId, + model: 'test-model', + timestamp: Date.now(), + finishReason: 'tool_calls', + } + this.iteration++ + } else { + yield { + type: 'content', + id: `${baseId}-2`, + model: 'test-model', + timestamp: Date.now(), + delta: 'Done', + content: 'Done', + role: 'assistant', + } + yield { + type: 'done', + id: `${baseId}-2`, + model: 'test-model', + timestamp: Date.now(), + finishReason: 'stop', + } + this.iteration++ + } + } + } + + const adapter = new SimpleToolAdapter() + + const stream = chat({ + adapter, + model: 'test-model', + messages: [{ role: 'user', content: 'Test' }], + tools: [noContextTool], + agentLoopStrategy: maxIterations(2), + }) + + await collectChunks(stream) + + // Tool should have been called without context + expect(noContextTool.execute).toHaveBeenCalledWith( + { value: 'test' }, + { context: undefined }, + ) + }) + }) }) diff --git a/packages/typescript/ai/tests/tool-call-manager.test.ts b/packages/typescript/ai/tests/tool-call-manager.test.ts index 9d74205c..e24d94a2 100644 --- a/packages/typescript/ai/tests/tool-call-manager.test.ts +++ b/packages/typescript/ai/tests/tool-call-manager.test.ts @@ -121,7 +121,10 @@ describe('ToolCallManager', () => { expect(finalResult[0]?.toolCallId).toBe('call_123') // Tool execute should have been called - expect(mockWeatherTool.execute).toHaveBeenCalledWith({ location: 'Paris' }) + expect(mockWeatherTool.execute).toHaveBeenCalledWith( + { location: 'Paris' }, + { context: undefined }, + ) }) it('should handle tool execution errors gracefully', async () => { diff --git a/packages/typescript/ai/tests/tool-definition.test.ts b/packages/typescript/ai/tests/tool-definition.test.ts index 0cb3a91c..b6dfd178 100644 --- a/packages/typescript/ai/tests/tool-definition.test.ts +++ b/packages/typescript/ai/tests/tool-definition.test.ts @@ -1,4 +1,4 @@ -import { describe, it, expect, vi } from 'vitest' +import { describe, expect, it, vi } from 'vitest' import { z } from 'zod' import { toolDefinition } from '../src/tools/tool-definition' @@ -60,9 +60,15 @@ describe('toolDefinition', () => { expect(serverTool.execute).toBeDefined() if (serverTool.execute) { - const result = await serverTool.execute({ location: 'Paris' }) + const result = await serverTool.execute( + { location: 'Paris' }, + { context: undefined as any }, + ) expect(result).toEqual({ temperature: 72, conditions: 'sunny' }) - expect(executeFn).toHaveBeenCalledWith({ location: 'Paris' }) + expect(executeFn).toHaveBeenCalledWith( + { location: 'Paris' }, + { context: undefined as any }, + ) } }) @@ -90,9 +96,15 @@ describe('toolDefinition', () => { expect(clientTool.execute).toBeDefined() if (clientTool.execute) { - const result = await clientTool.execute({ key: 'test', value: 'data' }) + const result = await clientTool.execute( + { key: 'test', value: 'data' }, + { context: undefined as any }, + ) expect(result).toEqual({ success: true }) - expect(executeFn).toHaveBeenCalledWith({ key: 'test', value: 'data' }) + expect(executeFn).toHaveBeenCalledWith( + { key: 'test', value: 'data' }, + { context: undefined as any }, + ) } }) @@ -176,7 +188,10 @@ describe('toolDefinition', () => { }) if (serverTool.execute) { - const result = serverTool.execute({ value: 5 }) + const result = serverTool.execute( + { value: 5 }, + { context: undefined as any }, + ) expect(result).toEqual({ doubled: 10 }) } }) @@ -218,11 +233,14 @@ describe('toolDefinition', () => { }) // Verify it can be called - void serverTool.execute?.({ - orderId: '123', - items: [], - shipping: { address: '123 Main St', method: 'standard' }, - }) + void serverTool.execute?.( + { + orderId: '123', + items: [], + shipping: { address: '123 Main St', method: 'standard' }, + }, + { context: undefined as any }, + ) expect(serverTool.__toolSide).toBe('server') }) @@ -255,4 +273,67 @@ describe('toolDefinition', () => { expect(tool.__toolSide).toBe('definition') expect(tool.inputSchema).toBeDefined() }) + + describe('context support', () => { + it('should pass context to tool execute functions', async () => { + interface TestContext { + userId: string + db: { + users: { + find: (id: string) => Promise<{ name: string; email: string }> + } + } + } + + const tool = toolDefinition({ + name: 'getUser', + description: 'Get user by ID', + inputSchema: z.object({ + userId: z.string(), + }), + outputSchema: z.object({ + name: z.string(), + email: z.string(), + }), + }) + + const mockContext: TestContext = { + userId: '123', + db: { + users: { + find: vi.fn(async (_id: string) => ({ + name: 'John Doe', + email: 'john@example.com', + })), + }, + }, + } + + const executeFn = vi.fn( + async (args: any, options: { context: TestContext }) => { + const user = await options.context.db.users.find(args.userId) + return user + }, + ) + + const serverTool = tool.server(executeFn) + + if (serverTool.execute) { + const result = await serverTool.execute( + { userId: '123' }, + { context: mockContext }, + ) + + expect(result).toEqual({ + name: 'John Doe', + email: 'john@example.com', + }) + expect(executeFn).toHaveBeenCalledWith( + { userId: '123' }, + { context: mockContext }, + ) + expect(mockContext.db.users.find).toHaveBeenCalledWith('123') + } + }) + }) })