Middleware Implementation
Middleware allows you to enhance the behavior of language models by intercepting and modifying calls. You can implement custom middleware for logging, caching, guardrails, RAG, and more. Implementing middleware requires understanding of the language model specification.
Middleware Specification
Middleware implements the LanguageModelV3Middleware interface:
import type { LanguageModelV3Middleware } from '@ai-sdk/provider';

/**
 * Minimal pass-through middleware skeleton.
 *
 * Each hook is optional; implement only the ones you need:
 * - transformParams: adjust call options before they reach the model.
 * - wrapGenerate: intercept non-streaming calls.
 * - wrapStream: intercept streaming calls.
 */
export const myMiddleware: LanguageModelV3Middleware = {
  specificationVersion: 'v3',

  // Runs first: receives the call options and may return a modified copy.
  transformParams: async ({ params }) => params,

  // Runs around doGenerate for non-streaming calls.
  wrapGenerate: async ({ doGenerate, params }) => doGenerate(),

  // Runs around doStream for streaming calls.
  wrapStream: async ({ doStream, params }) => doStream(),
};
Middleware Methods
transformParams
Transforms parameters before they're passed to the model:
import type {
LanguageModelV3Middleware,
LanguageModelV3CallOptions,
} from '@ai-sdk/provider';
/**
 * Middleware that prepends a fixed system message to every prompt.
 */
export const parameterMiddleware: LanguageModelV3Middleware = {
  specificationVersion: 'v3',
  transformParams: async ({ params }) => {
    const systemMessage = {
      role: 'system',
      content: 'You are a helpful assistant.',
    };
    // Rebuild the prompt with the system message first; the assertion is
    // needed because the object spread widens the message union type.
    const prompt = {
      ...params.prompt,
      messages: [systemMessage, ...params.prompt.messages],
    };
    return { ...params, prompt } as LanguageModelV3CallOptions;
  },
};
wrapGenerate
Wraps the doGenerate method for non-streaming calls:
import type { LanguageModelV3Middleware } from '@ai-sdk/provider';

/**
 * Middleware that logs the prompt, call latency, and output text for
 * non-streaming generate calls, then forwards the result unchanged.
 */
export const generateMiddleware: LanguageModelV3Middleware = {
  specificationVersion: 'v3',
  wrapGenerate: async ({ doGenerate, params }) => {
    console.log('Generate called with:', params.prompt);

    const startedAt = Date.now();
    const result = await doGenerate();
    const elapsedMs = Date.now() - startedAt;

    console.log('Generate completed in', elapsedMs, 'ms');
    console.log('Generated text:', result.text);
    return result;
  },
};
wrapStream
Wraps the doStream method for streaming calls:
import type {
  LanguageModelV3Middleware,
  LanguageModelV3StreamPart,
} from '@ai-sdk/provider';

/**
 * Middleware that accumulates streamed text deltas and logs the
 * complete text once the stream finishes. Chunks pass through unchanged.
 */
export const streamMiddleware: LanguageModelV3Middleware = {
  specificationVersion: 'v3',
  wrapStream: async ({ doStream, params }) => {
    const { stream, ...metadata } = await doStream();

    let accumulatedText = '';

    // Forward every chunk while collecting text deltas on the side.
    const collector = new TransformStream<
      LanguageModelV3StreamPart,
      LanguageModelV3StreamPart
    >({
      transform(part, controller) {
        if (part.type === 'text-delta') {
          accumulatedText += part.textDelta;
        }
        controller.enqueue(part);
      },
      flush() {
        console.log('Stream complete. Full text:', accumulatedText);
      },
    });

    return { stream: stream.pipeThrough(collector), ...metadata };
  },
};
Real-World Middleware Examples
Logging Middleware
Log all model interactions:
import type {
LanguageModelV3Middleware,
LanguageModelV3StreamPart,
} from '@ai-sdk/provider';
/**
 * Middleware that logs every model interaction (request, response,
 * error, and stream completion), correlated by a per-call request id.
 */
export const loggingMiddleware: LanguageModelV3Middleware = {
  specificationVersion: 'v3',

  wrapGenerate: async ({ doGenerate, params }) => {
    // One id per call so request/response/error lines can be matched up.
    const requestId = crypto.randomUUID();
    console.log(`[${requestId}] Generate request:`, {
      messages: params.prompt.messages.length,
      maxTokens: params.maxOutputTokens,
    });

    const startedAt = Date.now();
    try {
      const result = await doGenerate();
      console.log(`[${requestId}] Generate response:`, {
        duration: Date.now() - startedAt,
        tokens: result.usage,
        finishReason: result.finishReason,
      });
      return result;
    } catch (error) {
      console.error(`[${requestId}] Generate error:`, error);
      throw error;
    }
  },

  wrapStream: async ({ doStream, params }) => {
    const requestId = crypto.randomUUID();
    console.log(`[${requestId}] Stream request:`, {
      messages: params.prompt.messages.length,
    });

    const { stream, ...rest } = await doStream();

    // Count chunks as they pass through; log the total on completion.
    let chunkCount = 0;
    const counter = new TransformStream<
      LanguageModelV3StreamPart,
      LanguageModelV3StreamPart
    >({
      transform(part, controller) {
        chunkCount += 1;
        controller.enqueue(part);
      },
      flush() {
        console.log(`[${requestId}] Stream complete:`, { chunks: chunkCount });
      },
    });

    return { stream: stream.pipeThrough(counter), ...rest };
  },
};
Caching Middleware
Cache model responses:
import type { LanguageModelV3Middleware } from '@ai-sdk/provider';
interface CacheEntry {
  result: any; // full doGenerate result, returned verbatim on a hit
  timestamp: number; // insertion time (ms epoch), used for TTL expiry
}

/**
 * Creates middleware that memoizes non-streaming generate results in an
 * in-memory Map. Only wrapGenerate is cached; streaming calls pass through.
 *
 * @param ttl - entry lifetime in milliseconds (default: 1 hour). Expired
 *   entries are evicted lazily after each cache miss.
 */
export function createCachingMiddleware({
  ttl = 3600000, // 1 hour
}: {
  ttl?: number;
} = {}): LanguageModelV3Middleware {
  const cache = new Map<string, CacheEntry>();

  // The key must include every parameter that can change the output.
  // The original key covered only temperature/maxOutputTokens, so calls
  // differing in e.g. topP or seed collided and returned stale results.
  // JSON.stringify drops undefined values, so absent settings do not
  // change the key.
  const getCacheKey = (params: any): string => {
    return JSON.stringify({
      messages: params.prompt.messages,
      settings: {
        temperature: params.temperature,
        maxTokens: params.maxOutputTokens,
        topP: params.topP,
        topK: params.topK,
        seed: params.seed,
        stopSequences: params.stopSequences,
      },
    });
  };

  return {
    specificationVersion: 'v3',
    wrapGenerate: async ({ doGenerate, params }) => {
      const cacheKey = getCacheKey(params);
      const cached = cache.get(cacheKey);
      if (cached && Date.now() - cached.timestamp < ttl) {
        console.log('Cache hit:', cacheKey.slice(0, 50));
        return cached.result;
      }

      const result = await doGenerate();
      const now = Date.now();
      cache.set(cacheKey, { result, timestamp: now });

      // Lazy eviction of expired entries. The clock read is hoisted so a
      // long scan uses one consistent notion of "now".
      for (const [key, entry] of cache.entries()) {
        if (now - entry.timestamp > ttl) {
          cache.delete(key);
        }
      }
      return result;
    },
  };
}
RAG Middleware
Add context from a vector database:
import type { LanguageModelV3Middleware } from '@ai-sdk/provider';
/**
 * Creates middleware that augments the latest user message with context
 * retrieved from an external source (e.g. a vector database).
 *
 * @param retrieveContext - given the user's query text, resolves to a list
 *   of context snippets; an empty list leaves the prompt untouched.
 */
export function createRAGMiddleware({
  retrieveContext,
}: {
  retrieveContext: (query: string) => Promise<string[]>;
}): LanguageModelV3Middleware {
  return {
    specificationVersion: 'v3',
    transformParams: async ({ params }) => {
      const messages = params.prompt.messages;
      const lastMessage = messages[messages.length - 1];

      // Only augment when the most recent message came from the user.
      if (lastMessage?.role !== 'user') {
        return params;
      }

      // Flatten the message to plain text: content is either a string or
      // an array of typed parts, of which only text parts are used.
      let userText: string;
      if (typeof lastMessage.content === 'string') {
        userText = lastMessage.content;
      } else {
        userText = lastMessage.content
          .filter((part: any) => part.type === 'text')
          .map((part: any) => part.text)
          .join(' ');
      }

      const contextChunks = await retrieveContext(userText);
      if (contextChunks.length === 0) {
        return params;
      }

      const contextText =
        '\n\nRelevant context:\n' + contextChunks.join('\n\n');

      // Append the context in the same shape the message already uses.
      const enhancedContent =
        typeof lastMessage.content === 'string'
          ? lastMessage.content + contextText
          : [...lastMessage.content, { type: 'text', text: contextText }];

      const enhancedMessage = { ...lastMessage, content: enhancedContent };
      return {
        ...params,
        prompt: {
          ...params.prompt,
          messages: [...messages.slice(0, -1), enhancedMessage],
        },
      };
    },
  };
}
Guardrail Middleware
Filter sensitive information:
import type {
LanguageModelV3Middleware,
LanguageModelV3StreamPart,
} from '@ai-sdk/provider';
/**
 * Creates middleware that scrubs sensitive content from model output by
 * applying regex replacements to generated text, in both non-streaming
 * and streaming modes.
 *
 * @param filters - ordered list of { pattern, replacement } pairs; each
 *   is applied with String.prototype.replace, so patterns should use the
 *   global flag to replace every occurrence.
 */
export function createGuardrailMiddleware({
  filters,
}: {
  filters: Array<{
    pattern: RegExp;
    replacement: string;
  }>;
}): LanguageModelV3Middleware {
  // Applies every filter in order to the given text.
  const applyFilters = (text: string): string => {
    let filtered = text;
    for (const { pattern, replacement } of filters) {
      filtered = filtered.replace(pattern, replacement);
    }
    return filtered;
  };

  return {
    specificationVersion: 'v3',
    wrapGenerate: async ({ doGenerate }) => {
      const result = await doGenerate();
      // Null check instead of truthiness: an empty string is a valid
      // result and must not be collapsed to undefined (the original
      // `result.text ? ... : undefined` dropped empty-string text).
      return {
        ...result,
        text: result.text != null ? applyFilters(result.text) : result.text,
      };
    },
    wrapStream: async ({ doStream }) => {
      const { stream, ...rest } = await doStream();
      // NOTE(review): filters run per delta, so a sensitive pattern that
      // spans a chunk boundary will not be caught. Buffering across
      // deltas would close that gap at the cost of added latency.
      const transformStream = new TransformStream<
        LanguageModelV3StreamPart,
        LanguageModelV3StreamPart
      >({
        transform(chunk, controller) {
          if (chunk.type === 'text-delta') {
            controller.enqueue({
              ...chunk,
              textDelta: applyFilters(chunk.textDelta),
            });
          } else {
            controller.enqueue(chunk);
          }
        },
      });
      return {
        stream: stream.pipeThrough(transformStream),
        ...rest,
      };
    },
  };
}
// Usage
// Wrap a base model so both generate and stream output are scrubbed before
// reaching callers. (Requires `wrapLanguageModel` from 'ai' and `openai`
// from '@ai-sdk/openai' to be in scope.)
const model = wrapLanguageModel({
  model: openai('gpt-4'),
  middleware: createGuardrailMiddleware({
    filters: [
      // Redact US Social Security numbers (###-##-####).
      { pattern: /\b\d{3}-\d{2}-\d{4}\b/g, replacement: '[SSN REDACTED]' },
      // Redact email addresses (case-insensitive).
      { pattern: /\b[A-Z0-9._%+-]+@[A-Z0-9.-]+\.[A-Z]{2,}\b/gi, replacement: '[EMAIL REDACTED]' },
    ],
  }),
});
Retry Middleware
Add custom retry logic:
import type { LanguageModelV3Middleware } from '@ai-sdk/provider';
/**
 * Creates middleware that retries failed model calls with linear backoff.
 *
 * @param maxRetries - total number of attempts before giving up (default 3).
 * @param delay - base wait in milliseconds; attempt N waits delay * N
 *   before the next try.
 */
export function createRetryMiddleware({
  maxRetries = 3,
  delay = 1000,
}: {
  maxRetries?: number;
  delay?: number;
} = {}): LanguageModelV3Middleware {
  // Iterative retry loop: rethrows the last error once maxRetries
  // attempts are exhausted.
  const retry = async <T>(fn: () => Promise<T>): Promise<T> => {
    for (let attempt = 1; ; attempt++) {
      try {
        return await fn();
      } catch (error) {
        if (attempt >= maxRetries) {
          throw error;
        }
        await new Promise(resolve => setTimeout(resolve, delay * attempt));
      }
    }
  };

  return {
    specificationVersion: 'v3',
    wrapGenerate: async ({ doGenerate }) => retry(() => doGenerate()),
    // Note: only the initial doStream call is retried; errors emitted
    // after the stream has started are not.
    wrapStream: async ({ doStream }) => retry(() => doStream()),
  };
}
Composing Middleware
Combine multiple middleware:
import { wrapLanguageModel } from 'ai';
import { openai } from '@ai-sdk/openai';
// Compose multiple middleware by passing an array: the first entry is the
// outermost wrapper, so requests flow top-to-bottom and responses bubble
// back in reverse order. (Assumes the middleware defined in the earlier
// examples are in scope.)
const model = wrapLanguageModel({
  model: openai('gpt-4'),
  middleware: [
    loggingMiddleware,
    createCachingMiddleware({ ttl: 3600000 }), // cache entries live 1 hour
    createRAGMiddleware({ retrieveContext }),
    createGuardrailMiddleware({ filters }),
  ],
});
// Applied as: logging(caching(rag(guardrail(model))))
Accessing Provider Metadata
Pass custom metadata to middleware:
import { generateText, wrapLanguageModel } from 'ai';
import { openai } from '@ai-sdk/openai';
import type { LanguageModelV3Middleware } from '@ai-sdk/provider';
// Middleware that reads request-scoped metadata supplied by the caller.
export const metadataMiddleware: LanguageModelV3Middleware = {
  specificationVersion: 'v3',
  wrapGenerate: async ({ doGenerate, params }) => {
    // NOTE(review): this reads params.providerMetadata while the caller
    // below passes `providerOptions` — presumably the SDK maps
    // providerOptions into providerMetadata on the call options; confirm
    // against the provider specification.
    const metadata = params?.providerMetadata?.myMiddleware;
    console.log('Request metadata:', metadata);
    return doGenerate();
  },
};
// The caller attaches metadata under a key matching the middleware's name
// ('myMiddleware'), which the middleware looks up above.
const { text } = await generateText({
  model: wrapLanguageModel({
    model: openai('gpt-4'),
    middleware: metadataMiddleware,
  }),
  prompt: 'Hello',
  providerOptions: {
    myMiddleware: {
      userId: '12345',
      requestId: 'abc-def',
    },
  },
});
Testing Middleware
import { describe, it, expect, vi } from 'vitest';
import { generateText, wrapLanguageModel } from 'ai';
import { createMockLanguageModelV3 } from 'ai/test';
// Verifies the logging middleware emits console output without calling a
// real provider: the mock model returns a canned generate result.
describe('Logging Middleware', () => {
  it('should log requests and responses', async () => {
    // Spy on console.log so the test can assert the middleware logged.
    const consoleLogSpy = vi.spyOn(console, 'log');
    const mockModel = createMockLanguageModelV3({
      doGenerate: async () => ({
        text: 'Hello!',
        finishReason: 'stop',
        usage: { inputTokens: 10, outputTokens: 5 },
      }),
    });
    const model = wrapLanguageModel({
      model: mockModel,
      middleware: loggingMiddleware,
    });
    await generateText({
      model,
      prompt: 'Hi',
    });
    // The middleware logs both the request and the response lines.
    expect(consoleLogSpy).toHaveBeenCalled();
  });
});
Best Practices
- Always handle both generate and stream: Implement both
wrapGenerate and wrapStream for consistent behavior
- Preserve stream characteristics: Don't buffer entire streams in memory
- Handle errors gracefully: Wrap operations in try-catch blocks
- Document side effects: Clearly document any caching, logging, or external calls
- Make middleware configurable: Accept options for customization
- Test thoroughly: Test both success and error paths
Next Steps
- Learn about error recovery strategies
- Explore custom providers
- Review built-in middleware