diff --git a/README.md b/README.md index 93d12c6..db04b87 100644 --- a/README.md +++ b/README.md @@ -86,66 +86,85 @@ const provider = createProvider('claude', { apiKey: 'your-key' }); ## 🎨 Structured Response Types -Define custom response types for type-safe, structured AI outputs: +Define custom response types for type-safe, structured AI outputs. The library automatically parses the AI's response into your desired type. ```typescript -import { createResponseType, validateResponseType } from 'simple-ai-provider'; +import { createResponseType, createClaudeProvider } from 'simple-ai-provider'; -// Define your response type -interface UserProfile { - name: string; - age: number; - email: string; - preferences: { - theme: 'light' | 'dark'; - notifications: boolean; - }; +// 1. Define your response type +interface ProductAnalysis { + productName: string; + priceRange: 'budget' | 'mid-range' | 'premium'; + pros: string[]; + cons: string[]; + overallRating: number; // 1-10 scale + recommendation: 'buy' | 'consider' | 'avoid'; } -const userProfileType = createResponseType( - `{ - name: string; - age: number; - email: string; - preferences: { - theme: 'light' | 'dark'; - notifications: boolean; - }; - }`, - 'A user profile with personal information and preferences', - { - name: 'John Doe', - age: 30, - email: 'john@example.com', - preferences: { theme: 'dark', notifications: true } - } +// 2. Create a ResponseType object +const productAnalysisType = createResponseType( + 'A comprehensive product analysis with pros, cons, rating, and recommendation' ); -// Use with any provider -const response = await claude.complete({ +// 3. Use with any provider +const claude = createClaudeProvider({ apiKey: 'your-key' }); +await claude.initialize(); + +const response = await claude.complete({ messages: [ - { role: 'user', content: 'Generate a user profile for a software developer' } + { role: 'user', content: 'Analyze the iPhone 15 Pro from a consumer perspective.' } ], - responseType: userProfileType, - maxTokens: 500 + responseType: productAnalysisType, + maxTokens: 800 }); -// Validate and get typed response -const validation = validateResponseType(response.content, userProfileType); -if (validation.isValid) { - const userProfile = validation.data as UserProfile; - console.log(`Name: ${userProfile.name}`); - console.log(`Theme: ${userProfile.preferences.theme}`); -} +// 4. Get the fully typed and parsed response +const analysis = response.content; +console.log(`Product: ${analysis.productName}`); +console.log(`Recommendation: ${analysis.recommendation}`); +console.log(`Rating: ${analysis.overallRating}/10`); ``` ### Key Benefits -- **Type Safety**: Get fully typed responses from AI providers -- **Automatic Prompting**: System prompts are automatically generated -- **Validation**: Built-in response validation and parsing -- **Consistency**: Ensures AI outputs match your expected format -- **Developer Experience**: IntelliSense and compile-time type checking +- **Automatic Parsing**: The AI's JSON response is automatically parsed into your specified type. +- **Type Safety**: Get fully typed responses from AI providers with IntelliSense. +- **Automatic Prompting**: System prompts are automatically generated to guide the AI. +- **Validation**: Built-in response validation and parsing logic. +- **Consistency**: Ensures AI outputs match your expected format. +- **Developer Experience**: Catch errors at compile-time instead of runtime. + +### Streaming with Response Types + +You can also use response types with streaming. The raw stream provides real-time text, and you can parse the final string once the stream is complete. + +```typescript +import { parseAndValidateResponseType } from 'simple-ai-provider'; + +const stream = claude.stream({ + messages: [{ role: 'user', content: 'Analyze the Tesla Model 3.' }], + responseType: productAnalysisType, + maxTokens: 600 +}); + +let fullResponse = ''; +for await (const chunk of stream) { + if (!chunk.isComplete) { + process.stdout.write(chunk.content); + fullResponse += chunk.content; + } else { + console.log('\n\nStream complete!'); + // Validate the complete streamed response + try { + const analysis = parseAndValidateResponseType(fullResponse, productAnalysisType); + console.log('Validation successful!'); + console.log(`Product: ${analysis.productName}`); + } catch (e) { + console.error('Validation failed:', (e as Error).message); + } + } +} +``` ## 📝 Environment Variables diff --git a/bun.lockb b/bun.lockb new file mode 100755 index 0000000..0fe08a2 Binary files /dev/null and b/bun.lockb differ diff --git a/examples/structured-response-types.ts b/examples/structured-response-types.ts index 3c2430d..857f8c4 100644 --- a/examples/structured-response-types.ts +++ b/examples/structured-response-types.ts @@ -9,7 +9,7 @@ import { createClaudeProvider, createResponseType, - validateResponseType, + parseAndValidateResponseType, AIProviderError, AIErrorType } from '../src/index.js'; @@ -31,16 +31,6 @@ interface UserProfile { } const userProfileType = createResponseType( - `{ - name: string; - age: number; - email: string; - preferences: { - theme: 'light' | 'dark'; - notifications: boolean; - }; - skills: string[]; - }`, 'A user profile with personal information, preferences, and skills', { name: 'John Doe', @@ -64,16 +54,6 @@ interface ProductAnalysis { } const productAnalysisType = createResponseType( - `{ - productName: string; - category: string; - priceRange: 'budget' | 'mid-range' | 'premium'; - pros: string[]; - cons: string[]; - overallRating: number; - recommendation: 'buy' | 'consider' | 'avoid'; - reasoning: string; - }`, 'A comprehensive product analysis with pros, cons, rating, and recommendation', { productName: 'Example Product', @@ -102,18 +82,6 @@ interface CodeReview { } const codeReviewType = createResponseType( - `{ - overallScore: number; - issues: Array<{ - type: 'error' | 'warning' | 'suggestion'; - line?: number; - message: string; - severity: 'low' | 'medium' | 'high'; - }>; - strengths: string[]; - improvements: string[]; - summary: string; - }`, 'A comprehensive code review with scoring, issues, and recommendations', { overallScore: 8, @@ -137,7 +105,7 @@ async function demonstrateUserProfileGeneration() { await claude.initialize(); try { - const response = await claude.complete({ + const response = await claude.complete({ messages: [ { role: 'user', @@ -149,23 +117,16 @@ async function demonstrateUserProfileGeneration() { temperature: 0.7 }); - console.log('Raw response:', response.content); + console.log('Raw response:', response.rawContent); - // Validate and parse the response - const validation = validateResponseType(response.content, userProfileType); - - if (validation.isValid) { - const userProfile = validation.data as UserProfile; - console.log('\nParsed User Profile:'); - console.log(`Name: ${userProfile.name}`); - console.log(`Age: ${userProfile.age}`); - console.log(`Email: ${userProfile.email}`); - console.log(`Theme Preference: ${userProfile.preferences.theme}`); - console.log(`Notifications: ${userProfile.preferences.notifications}`); - console.log(`Skills: ${userProfile.skills.join(', ')}`); - } else { - console.error('Validation failed:', validation.error); - } + const userProfile = response.content; + console.log('\nParsed User Profile:'); + console.log(`Name: ${userProfile.name}`); + console.log(`Age: ${userProfile.age}`); + console.log(`Email: ${userProfile.email}`); + console.log(`Theme Preference: ${userProfile.preferences.theme}`); + console.log(`Notifications: ${userProfile.preferences.notifications}`); + console.log(`Skills: ${userProfile.skills.join(', ')}`); } catch (error) { if (error instanceof AIProviderError) { @@ -183,7 +144,7 @@ async function demonstrateProductAnalysis() { await claude.initialize(); try { - const response = await claude.complete({ + const response = await claude.complete({ messages: [ { role: 'user', @@ -195,24 +156,18 @@ async function demonstrateProductAnalysis() { temperature: 0.5 }); - console.log('Raw response:', response.content); + console.log('Raw response:', response.rawContent); - const validation = validateResponseType(response.content, productAnalysisType); - - if (validation.isValid) { - const analysis = validation.data as ProductAnalysis; - console.log('\nProduct Analysis:'); - console.log(`Product: ${analysis.productName}`); - console.log(`Category: ${analysis.category}`); - console.log(`Price Range: ${analysis.priceRange}`); - console.log(`Overall Rating: ${analysis.overallRating}/10`); - console.log(`Recommendation: ${analysis.recommendation}`); - console.log(`Pros: ${analysis.pros.join(', ')}`); - console.log(`Cons: ${analysis.cons.join(', ')}`); - console.log(`Reasoning: ${analysis.reasoning}`); - } else { - console.error('Validation failed:', validation.error); - } + const analysis = response.content; + console.log('\nProduct Analysis:'); + console.log(`Product: ${analysis.productName}`); + console.log(`Category: ${analysis.category}`); + console.log(`Price Range: ${analysis.priceRange}`); + console.log(`Overall Rating: ${analysis.overallRating}/10`); + console.log(`Recommendation: ${analysis.recommendation}`); + console.log(`Pros: ${analysis.pros.join(', ')}`); + console.log(`Cons: ${analysis.cons.join(', ')}`); + console.log(`Reasoning: ${analysis.reasoning}`); } catch (error) { if (error instanceof AIProviderError) { @@ -248,39 +203,36 @@ function processPayment(amount, cardNumber) { `; try { - const response = await claude.complete({ + const response = await claude.complete({ messages: [ { role: 'user', - content: `Please review this JavaScript code and provide a comprehensive analysis:\n\n\`\`\`javascript\n${sampleCode}\n\`\`\`` - } + content: `Please review this JavaScript code and provide a comprehensive analysis:\n\n\ +`${sampleCode} +\ +``` ` + } ], responseType: codeReviewType, maxTokens: 1000, temperature: 0.3 }); - console.log('Raw response:', response.content); + console.log('Raw response:', response.rawContent); - const validation = validateResponseType(response.content, codeReviewType); - - if (validation.isValid) { - const review = validation.data as CodeReview; - console.log('\nCode Review:'); - console.log(`Overall Score: ${review.overallScore}/10`); - console.log(`Summary: ${review.summary}`); - console.log(`\nStrengths:`); - review.strengths.forEach(strength => console.log(` • ${strength}`)); - console.log(`\nImprovements:`); - review.improvements.forEach(improvement => console.log(` • ${improvement}`)); - console.log(`\nIssues:`); - review.issues.forEach(issue => { - const lineInfo = issue.line ? ` (line ${issue.line})` : ''; - console.log(` • [${issue.severity.toUpperCase()}] ${issue.type}: ${issue.message}${lineInfo}`); - }); - } else { - console.error('Validation failed:', validation.error); - } + const review = response.content; + console.log('\nCode Review:'); + console.log(`Overall Score: ${review.overallScore}/10`); + console.log(`Summary: ${review.summary}`); + console.log(`\nStrengths:`); + review.strengths.forEach(strength => console.log(` • ${strength}`)); + console.log(`\nImprovements:`); + review.improvements.forEach(improvement => console.log(` • ${improvement}`)); + console.log(`\nIssues:`); + review.issues.forEach(issue => { + const lineInfo = issue.line ? ` (line ${issue.line})` : ''; + console.log(` • [${issue.severity.toUpperCase()}] ${issue.type}: ${issue.message}${lineInfo}`); + }); } catch (error) { if (error instanceof AIProviderError) { @@ -319,11 +271,12 @@ async function demonstrateStreamingWithResponseType() { console.log('\n\nStream completed. Usage:', chunk.usage); // Validate the complete streamed response - const validation = validateResponseType(fullResponse, productAnalysisType); - if (validation.isValid) { + try { + const analysis = parseAndValidateResponseType(fullResponse, productAnalysisType); console.log('\nStreamed response validation: SUCCESS'); - } else { - console.log('\nStreamed response validation: FAILED -', validation.error); + console.log(`Product: ${analysis.productName}`); + } catch (e) { + console.log('\nStreamed response validation: FAILED -', (e as Error).message); } } } diff --git a/package.json b/package.json index cb2a720..c803cf6 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "simple-ai-provider", - "version": "1.2.0", + "version": "1.3.0", "description": "A simple and extensible AI provider package for easy integration of multiple AI services", "main": "dist/index.js", "module": "dist/index.mjs", diff --git a/src/index.ts b/src/index.ts index 4124249..a3e9332 100644 --- a/src/index.ts +++ b/src/index.ts @@ -27,7 +27,7 @@ export { AIProviderError, AIErrorType } from './types/index.js'; export { createResponseType, generateResponseTypePrompt, - validateResponseType + parseAndValidateResponseType } from './types/index.js'; // Base provider diff --git a/src/providers/base.ts b/src/providers/base.ts index 4a7bf9d..d52b1a8 100644 --- a/src/providers/base.ts +++ b/src/providers/base.ts @@ -22,7 +22,7 @@ import type { ProviderInfo, ResponseType } from '../types/index.js'; -import { AIProviderError, AIErrorType, generateResponseTypePrompt } from '../types/index.js'; +import { AIProviderError, AIErrorType, generateResponseTypePrompt, parseAndValidateResponseType } from '../types/index.js'; // ============================================================================ // ABSTRACT BASE PROVIDER CLASS @@ -148,7 +148,9 @@ export abstract class BaseAIProvider { * console.log(response.content); * ``` */ - public async complete(params: CompletionParams): Promise { + public async complete(params: CompletionParams): Promise>; + public async complete(params: CompletionParams): Promise>; + public async complete(params: CompletionParams): Promise> { // Ensure provider is ready for use this.ensureInitialized(); @@ -160,7 +162,23 @@ export abstract class BaseAIProvider { const processedParams = this.processResponseType(params); // Delegate to provider-specific implementation - return await this.doComplete(processedParams); + const response = await this.doComplete(processedParams); + + // If a responseType is defined, parse and validate the response + if (params.responseType) { + const parsedData = parseAndValidateResponseType(response.content, params.responseType); + return { + ...response, + content: parsedData, + rawContent: response.content, + }; + } + + // Otherwise, return the raw string content + return { + ...response, + content: response.content, + }; } catch (error) { // Normalize error to our standard format throw this.normalizeError(error as Error); @@ -249,7 +267,7 @@ export abstract class BaseAIProvider { * @returns Promise resolving to completion response * @throws {Error} If completion fails (will be normalized to AIProviderError) */ - protected abstract doComplete(params: CompletionParams): Promise; + protected abstract doComplete(params: CompletionParams): Promise>; /** * Provider-specific streaming implementation. diff --git a/src/providers/claude.ts b/src/providers/claude.ts index ba0ba77..167cabc 100644 --- a/src/providers/claude.ts +++ b/src/providers/claude.ts @@ -220,7 +220,7 @@ export class ClaudeProvider extends BaseAIProvider { * @returns Promise resolving to formatted completion response * @throws {Error} If API request fails */ - protected async doComplete(params: CompletionParams): Promise { + protected async doComplete(params: CompletionParams): Promise> { if (!this.client) { throw new AIProviderError('Claude client not initialized', AIErrorType.INVALID_REQUEST); } @@ -533,7 +533,7 @@ export class ClaudeProvider extends BaseAIProvider { * @returns Formatted completion response * @throws {AIProviderError} If response format is unexpected */ - private formatCompletionResponse(response: any): CompletionResponse { + private formatCompletionResponse(response: any): CompletionResponse { // Extract text content from response blocks const content = response.content ?.filter((block: any) => block.type === 'text') diff --git a/src/providers/gemini.ts b/src/providers/gemini.ts index ec55a58..8e3b6dd 100644 --- a/src/providers/gemini.ts +++ b/src/providers/gemini.ts @@ -268,7 +268,7 @@ export class GeminiProvider extends BaseAIProvider { * @returns Promise resolving to formatted completion response * @throws {Error} If API request fails */ - protected async doComplete(params: CompletionParams): Promise { + protected async doComplete(params: CompletionParams): Promise> { if (!this.client || !this.model) { throw new AIProviderError('Gemini client not initialized', AIErrorType.INVALID_REQUEST); } @@ -617,7 +617,7 @@ export class GeminiProvider extends BaseAIProvider { * @returns Formatted completion response * @throws {AIProviderError} If response format is unexpected */ - private formatCompletionResponse(response: any, model: string): CompletionResponse { + private formatCompletionResponse(response: any, model: string): CompletionResponse { // Handle multiple text parts in the response const candidate = response.candidates?.[0]; if (!candidate) { diff --git a/src/providers/openai.ts b/src/providers/openai.ts index 4bf8c9f..135157d 100644 --- a/src/providers/openai.ts +++ b/src/providers/openai.ts @@ -244,7 +244,7 @@ export class OpenAIProvider extends BaseAIProvider { * @returns Promise resolving to formatted completion response * @throws {Error} If API request fails */ - protected async doComplete(params: CompletionParams): Promise { + protected async doComplete(params: CompletionParams): Promise> { if (!this.client) { throw new AIProviderError('OpenAI client not initialized', AIErrorType.INVALID_REQUEST); } @@ -537,7 +537,7 @@ export class OpenAIProvider extends BaseAIProvider { * @returns Formatted completion response * @throws {AIProviderError} If response format is unexpected */ - private formatCompletionResponse(response: OpenAI.Chat.Completions.ChatCompletion): CompletionResponse { + private formatCompletionResponse(response: OpenAI.Chat.Completions.ChatCompletion): CompletionResponse { const choice = response.choices[0]; if (!choice || !choice.message.content) { throw new AIProviderError( diff --git a/src/providers/openwebui.ts b/src/providers/openwebui.ts index aa2322e..6a2fc40 100644 --- a/src/providers/openwebui.ts +++ b/src/providers/openwebui.ts @@ -524,7 +524,7 @@ export class OpenWebUIProvider extends BaseAIProvider { * @param response - Raw OpenWebUI response * @returns Formatted completion response */ - private formatChatResponse(response: OpenWebUIChatResponse): CompletionResponse { + private formatChatResponse(response: OpenWebUIChatResponse): CompletionResponse { const choice = response.choices[0]; if (!choice || !choice.message.content) { throw new AIProviderError( @@ -556,7 +556,7 @@ export class OpenWebUIProvider extends BaseAIProvider { * @param response - Raw Ollama response * @returns Formatted completion response */ - private formatOllamaResponse(response: OllamaGenerateResponse): CompletionResponse { + private formatOllamaResponse(response: OllamaGenerateResponse): CompletionResponse { return { content: response.response, model: response.model, @@ -962,7 +962,7 @@ export class OpenWebUIProvider extends BaseAIProvider { * @returns Promise resolving to formatted completion response * @throws {Error} If API request fails */ - protected async doComplete(params: CompletionParams): Promise { + protected async doComplete(params: CompletionParams): Promise> { if (this.useOllamaProxy) { return this.completeWithOllama(params); } else { diff --git a/src/types/index.ts b/src/types/index.ts index 0a882e1..c3d650f 100644 --- a/src/types/index.ts +++ b/src/types/index.ts @@ -41,7 +41,7 @@ export interface AIMessage { */ export interface ResponseType { /** The TypeScript type definition as a string */ - typeDefinition: string; + typeDefinition?: string; /** Human-readable description of the expected response format */ description: string; /** Example of the expected response structure */ @@ -89,9 +89,11 @@ export interface TokenUsage { /** * Response from an AI completion request */ -export interface CompletionResponse { - /** Generated content */ - content: string; +export interface CompletionResponse { + /** Generated content, either as a string or a typed object */ + content: T; + /** Raw response from the provider, if content is a typed object */ + rawContent?: string; /** Model used for generation */ model: string; /** Token usage statistics */ @@ -162,14 +164,13 @@ export interface ProviderInfo { capabilities?: Record; } -// ============================================================================ +// ============================================================================ // RESPONSE TYPE UTILITIES -// ============================================================================ +// ============================================================================ /** * Creates a response type definition for structured AI outputs * - * @param typeDefinition - TypeScript type definition as a string * @param description - Human-readable description of the expected format * @param example - Optional example of the expected response structure * @param strictJson - Whether to enforce strict JSON formatting (default: true) @@ -178,15 +179,6 @@ export interface ProviderInfo { * @example * ```typescript * const userType = createResponseType( - * `{ - * name: string; - * age: number; - * email: string; - * preferences: { - * theme: 'light' | 'dark'; - * notifications: boolean; - * }; - * }`, * 'A user profile with personal information and preferences', * { * name: 'John Doe', @@ -198,13 +190,11 @@ export interface ProviderInfo { * ``` */ export function createResponseType( - typeDefinition: string, description: string, example?: T, strictJson: boolean = true ): ResponseType { return { - typeDefinition: typeDefinition.trim(), description, example, strictJson @@ -220,55 +210,64 @@ export function createResponseType( export function generateResponseTypePrompt(responseType: ResponseType): string { const { typeDefinition, description, example, strictJson } = responseType; - let prompt = `You must respond with a JSON object that matches the following TypeScript type definition:\n\n`; - prompt += `Type Definition:\n\`\`\`typescript\n${typeDefinition}\n\`\`\`\n\n`; - prompt += `Description: ${description}\n\n`; + let prompt = 'You are an AI assistant that must respond with a JSON object.'; + + if (typeDefinition) { + prompt += ' The JSON object must strictly adhere to the following TypeScript type definition:\n\n'; + prompt += 'Type Definition:\n```typescript\n' + typeDefinition + '\n```\n\n'; + } else { + prompt += '\n\n'; + } + + prompt += 'Description: ' + description + '\n\n'; if (example) { - prompt += `Example:\n\`\`\`json\n${JSON.stringify(example, null, 2)}\n\`\`\`\n\n`; + prompt += 'Example of the expected JSON output:\n```json\n' + JSON.stringify(example, null, 2) + '\n```\n\n'; } if (strictJson) { - prompt += `IMPORTANT: Your response must be valid JSON only. Do not include any text before or after the JSON object. Do not use markdown formatting.`; + prompt += 'IMPORTANT: Your entire response must be a single, valid JSON object. Do not include any additional text, explanations, or markdown formatting before or after the JSON object.'; } else { - prompt += `Your response should follow the structure defined above.`; + prompt += 'Your response should contain a JSON object that follows the structure defined above.'; } return prompt; } /** - * Validates that a response matches the expected type structure + * Parses and validates that a response matches the expected type structure * * @param response - The response content to validate * @param responseType - The expected response type - * @returns Object with validation result and parsed data + * @returns The parsed and validated data object + * @throws {AIProviderError} If parsing or validation fails */ -export function validateResponseType( +export function parseAndValidateResponseType( response: string, responseType: ResponseType -): { isValid: boolean; data?: T; error?: string } { +): T { try { - // Parse JSON response + // Attempt to parse the JSON response const parsed = JSON.parse(response); - // Basic validation - in a real implementation, you might want to use - // a schema validation library like zod or ajv for more thorough validation + // Basic validation: ensure it's a non-null object. + // For more robust validation, a library like Zod or Ajv could be used + // based on the typeDefinition, but that's beyond the current scope. if (typeof parsed !== 'object' || parsed === null) { - return { - isValid: false, - error: 'Response must be a JSON object' - }; + throw new Error('Response must be a JSON object'); } - return { - isValid: true, - data: parsed as T - }; + // Here you could add more sophisticated validation if needed, e.g., + // checking keys against the responseType.typeDefinition + + return parsed as T; } catch (error) { - return { - isValid: false, - error: `Invalid JSON: ${(error as Error).message}` - }; + // Wrap parsing/validation errors in a standardized AIProviderError + throw new AIProviderError( + `Failed to parse or validate structured response: ${(error as Error).message}`, + AIErrorType.INVALID_REQUEST, + undefined, + error as Error + ); } -} \ No newline at end of file +} diff --git a/tests/claude.test.ts b/tests/claude.test.ts index 726566c..68241ec 100644 --- a/tests/claude.test.ts +++ b/tests/claude.test.ts @@ -2,8 +2,8 @@ * Tests for Claude Provider */ -import { describe, it, expect, beforeEach } from 'bun:test'; -import { ClaudeProvider, AIProviderError, AIErrorType } from '../src/index.js'; +import { describe, it, expect, beforeEach, jest } from 'bun:test'; +import { ClaudeProvider, createClaudeProvider, createResponseType, parseAndValidateResponseType, AIProviderError, AIErrorType } from '../src/index.js'; describe('ClaudeProvider', () => { let provider: ClaudeProvider; @@ -135,4 +135,32 @@ describe('ClaudeProvider', () => { expect(result.messages).toHaveLength(2); }); }); -}); \ No newline at end of file + + describe('structured responses', () => { + it('should handle structured responses correctly', async () => { + const provider = new ClaudeProvider({ apiKey: 'test-key' }); + const mockResponse = { + content: JSON.stringify({ name: 'John Doe', age: 30 }), + model: 'claude-3-5-sonnet-20241022', + usage: { promptTokens: 10, completionTokens: 20, totalTokens: 30 }, + id: 'test-id' + }; + (provider as any).doComplete = jest.fn().mockResolvedValue(mockResponse); + (provider as any).initialized = true; + + const responseType = createResponseType<{ name: string; age: number }>( + `{ name: string; age: number }`, + 'A user profile' + ); + + const response = await provider.complete<{ name: string; age: number }>({ + messages: [{ role: 'user', content: 'test' }], + responseType + }); + + expect(response.content).toEqual({ name: 'John Doe', age: 30 }); + expect(response.rawContent).toBe(JSON.stringify({ name: 'John Doe', age: 30 })); + expect((provider as any).doComplete).toHaveBeenCalled(); + }); + }); +}); \ No newline at end of file diff --git a/tests/gemini.test.ts b/tests/gemini.test.ts index b8c1631..b0b546d 100644 --- a/tests/gemini.test.ts +++ b/tests/gemini.test.ts @@ -2,8 +2,8 @@ * Tests for Gemini Provider */ -import { describe, it, expect, beforeEach } from 'bun:test'; -import { GeminiProvider, AIProviderError, AIErrorType } from '../src/index.js'; +import { describe, it, expect, beforeEach, jest } from 'bun:test'; +import { GeminiProvider, createResponseType, AIProviderError, AIErrorType } from '../src/index.js'; describe('GeminiProvider', () => { let provider: GeminiProvider; @@ -353,4 +353,32 @@ describe('GeminiProvider', () => { }).toThrow('No candidates found in Gemini response'); }); }); -}); \ No newline at end of file + + describe('structured responses', () => { + it('should handle structured responses correctly', async () => { + const provider = new GeminiProvider({ apiKey: 'test-key' }); + const mockResponse = { + content: JSON.stringify({ name: 'John Doe', age: 30 }), + model: 'gemini-1.5-pro', + usage: { promptTokens: 10, completionTokens: 20, totalTokens: 30 }, + id: 'test-id' + }; + (provider as any).doComplete = jest.fn().mockResolvedValue(mockResponse); + (provider as any).initialized = true; + + const responseType = createResponseType<{ name: string; age: number }> + (`{ name: string; age: number }`, + 'A user profile' + ); + + const response = await provider.complete<{ name: string; age: number }>({ + messages: [{ role: 'user', content: 'test' }], + responseType + }); + + expect(response.content).toEqual({ name: 'John Doe', age: 30 }); + expect(response.rawContent).toBe(JSON.stringify({ name: 'John Doe', age: 30 })); + expect((provider as any).doComplete).toHaveBeenCalled(); + }); + }); +}); \ No newline at end of file diff --git a/tests/openai.test.ts b/tests/openai.test.ts index 1908838..d8f8be2 100644 --- a/tests/openai.test.ts +++ b/tests/openai.test.ts @@ -2,8 +2,8 @@ * Tests for OpenAI Provider */ -import { describe, it, expect, beforeEach } from 'bun:test'; -import { OpenAIProvider, AIProviderError, AIErrorType } from '../src/index.js'; +import { describe, it, expect, beforeEach, jest } from 'bun:test'; +import { OpenAIProvider, createResponseType, AIProviderError, AIErrorType } from '../src/index.js'; describe('OpenAIProvider', () => { let provider: OpenAIProvider; @@ -201,7 +201,7 @@ describe('OpenAIProvider', () => { it('should handle messages with metadata', () => { const messages = [ - { + { role: 'user' as const, content: 'Hello', metadata: { timestamp: '2024-01-01' } @@ -258,4 +258,32 @@ describe('OpenAIProvider', () => { }).toThrow('No content found in OpenAI response'); }); }); -}); \ No newline at end of file + + describe('structured responses', () => { + it('should handle structured responses correctly', async () => { + const provider = new OpenAIProvider({ apiKey: 'test-key' }); + const mockResponse = { + content: JSON.stringify({ name: 'John Doe', age: 30 }), + model: 'gpt-4', + usage: { promptTokens: 10, completionTokens: 20, totalTokens: 30 }, + id: 'test-id' + }; + (provider as any).doComplete = jest.fn().mockResolvedValue(mockResponse); + (provider as any).initialized = true; + + const responseType = createResponseType<{ name: string; age: number }> + (`{ name: string; age: number }`, + 'A user profile' + ); + + const response = await provider.complete<{ name: string; age: number }>({ + messages: [{ role: 'user', content: 'test' }], + responseType + }); + + expect(response.content).toEqual({ name: 'John Doe', age: 30 }); + expect(response.rawContent).toBe(JSON.stringify({ name: 'John Doe', age: 30 })); + expect((provider as any).doComplete).toHaveBeenCalled(); + }); + }); +}); \ No newline at end of file diff --git a/tests/openwebui.test.ts b/tests/openwebui.test.ts index 82414bc..522111d 100644 --- a/tests/openwebui.test.ts +++ b/tests/openwebui.test.ts @@ -2,9 +2,9 @@ * Tests for OpenWebUI provider implementation */ -import { describe, it, expect, beforeEach } from 'bun:test'; +import { describe, it, expect, beforeEach, jest } from 'bun:test'; import { OpenWebUIProvider, type OpenWebUIConfig } from '../src/providers/openwebui.js'; -import { AIProviderError, AIErrorType, type CompletionParams } from '../src/types/index.js'; +import { AIProviderError, AIErrorType, type CompletionParams, createResponseType } from '../src/types/index.js'; describe('OpenWebUIProvider', () => { let provider: OpenWebUIProvider; @@ -416,4 +416,32 @@ describe('OpenWebUIProvider', () => { await expect(provider.complete(params)).rejects.toThrow('Provider must be initialized before use'); }); }); -}); \ No newline at end of file + + describe('structured responses', () => { + it('should handle structured responses correctly', async () => { + const provider = new OpenWebUIProvider({ apiKey: 'test-key' }); + const mockResponse = { + content: JSON.stringify({ name: 'John Doe', age: 30 }), + model: 'llama3.1', + usage: { promptTokens: 10, completionTokens: 20, totalTokens: 30 }, + id: 'test-id' + }; + (provider as any).doComplete = jest.fn().mockResolvedValue(mockResponse); + (provider as any).initialized = true; + + const responseType = createResponseType<{ name: string; age: number }> + (`{ name: string; age: number }`, + 'A user profile' + ); + + const response = await provider.complete<{ name: string; age: number }>({ + messages: [{ role: 'user', content: 'test' }], + responseType + }); + + expect(response.content).toEqual({ name: 'John Doe', age: 30 }); + expect(response.rawContent).toBe(JSON.stringify({ name: 'John Doe', age: 30 })); + expect((provider as any).doComplete).toHaveBeenCalled(); + }); + }); +}); \ No newline at end of file