From 5da37f388fa892dca1a571c0b62a1e10d71a878c Mon Sep 17 00:00:00 2001 From: Jan-Marlon Leibl Date: Wed, 28 May 2025 12:13:29 +0200 Subject: [PATCH] feat: add Google Gemini provider integration and docs --- README.md | 205 ++++++++++++------- bun.lock | 3 + examples/multi-provider.ts | 99 ++++++++- package.json | 3 + src/index.ts | 4 +- src/providers/gemini.ts | 403 +++++++++++++++++++++++++++++++++++++ src/providers/index.ts | 3 +- src/utils/factory.ts | 25 ++- tests/gemini.test.ts | 357 ++++++++++++++++++++++++++++++++ 9 files changed, 1015 insertions(+), 87 deletions(-) create mode 100644 src/providers/gemini.ts create mode 100644 tests/gemini.test.ts diff --git a/README.md b/README.md index f62ac83..fda8446 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Simple AI Provider -A professional, extensible TypeScript package for integrating multiple AI providers into your applications with a unified interface. Currently supports **Claude (Anthropic)** and **OpenAI (GPT)** with plans to add more providers. +A professional, extensible TypeScript package for integrating multiple AI providers into your applications with a unified interface. Currently supports **Claude (Anthropic)**, **OpenAI (GPT)**, and **Google Gemini** with plans to add more providers. ## Features @@ -11,7 +11,7 @@ A professional, extensible TypeScript package for integrating multiple AI provid - šŸ›”ļø **Error Handling**: Robust error handling with categorized error types - šŸ”§ **Extensible**: Easy to add new AI providers - šŸ“¦ **Modern**: Built with ES modules and modern JavaScript features -- 🌐 **Multi-Provider**: Switch between Claude and OpenAI seamlessly +- 🌐 **Multi-Provider**: Switch between Claude, OpenAI, and Gemini seamlessly ## Installation @@ -71,39 +71,65 @@ const response = await openai.complete({ console.log(response.content); ``` +### Basic Usage with Gemini + +```typescript +import { createGeminiProvider } from 'simple-ai-provider'; + +// Create a Gemini provider +const gemini = createGeminiProvider('your-google-ai-api-key'); + +// Initialize the provider +await gemini.initialize(); + +// Generate a completion +const response = await gemini.complete({ + messages: [ + { role: 'user', content: 'Hello! How are you today?' 
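+    // messages follow the shared AIMessage shape; role is 'system' | 'user' | 'assistant'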
} + ], + maxTokens: 100, + temperature: 0.7 +}); + +console.log(response.content); +``` + ### Multi-Provider Usage ```typescript -import { createProvider, createClaudeProvider, createOpenAIProvider } from 'simple-ai-provider'; +import { createProvider, createClaudeProvider, createOpenAIProvider, createGeminiProvider } from 'simple-ai-provider'; // Method 1: Using specific factory functions const claude = createClaudeProvider('your-anthropic-api-key'); const openai = createOpenAIProvider('your-openai-api-key'); +const gemini = createGeminiProvider('your-google-ai-api-key'); // Method 2: Using generic factory const claude2 = createProvider('claude', { apiKey: 'your-anthropic-api-key' }); const openai2 = createProvider('openai', { apiKey: 'your-openai-api-key' }); +const gemini2 = createProvider('gemini', { apiKey: 'your-google-ai-api-key' }); -// Initialize both -await Promise.all([claude.initialize(), openai.initialize()]); +// Initialize all +await Promise.all([claude.initialize(), openai.initialize(), gemini.initialize()]); -// Use the same interface for both providers +// Use the same interface for all providers const prompt = { messages: [{ role: 'user', content: 'Explain AI' }] }; const claudeResponse = await claude.complete(prompt); const openaiResponse = await openai.complete(prompt); +const geminiResponse = await gemini.complete(prompt); ``` ### Streaming Responses ```typescript -import { createOpenAIProvider } from 'simple-ai-provider'; +import { createGeminiProvider } from 'simple-ai-provider'; -const openai = createOpenAIProvider('your-openai-api-key'); -await openai.initialize(); +const gemini = createGeminiProvider('your-google-ai-api-key'); +await gemini.initialize(); // Stream a completion -for await (const chunk of openai.stream({ +for await (const chunk of gemini.stream({ messages: [ { role: 'user', content: 'Write a short story about a robot.' } ], @@ -120,7 +146,7 @@ for await (const chunk of openai.stream({ ### Advanced Configuration ```typescript -import { ClaudeProvider, OpenAIProvider } from 'simple-ai-provider'; +import { ClaudeProvider, OpenAIProvider, GeminiProvider } from 'simple-ai-provider'; // Claude with custom configuration const claude = new ClaudeProvider({ @@ -141,14 +167,28 @@ const openai = new OpenAIProvider({ maxRetries: 5 }); -await Promise.all([claude.initialize(), openai.initialize()]); +// Gemini with safety settings and generation config +const gemini = new GeminiProvider({ + apiKey: 'your-google-ai-api-key', + defaultModel: 'gemini-1.5-pro', + safetySettings: [], // Configure content filtering + generationConfig: { + temperature: 0.8, + topP: 0.9, + topK: 40, + maxOutputTokens: 2048 + }, + timeout: 45000 +}); -const response = await openai.complete({ +await Promise.all([claude.initialize(), openai.initialize(), gemini.initialize()]); + +const response = await gemini.complete({ messages: [ { role: 'system', content: 'You are a helpful assistant.' }, { role: 'user', content: 'Explain quantum computing in simple terms.' } ], - model: 'gpt-4-turbo', + model: 'gemini-1.5-flash', maxTokens: 300, temperature: 0.5, topP: 0.9, @@ -216,6 +256,20 @@ const openai = createOpenAIProvider('your-api-key', { }); ``` +#### `createGeminiProvider(apiKey, options?)` +Creates a Gemini provider with simplified configuration. 
+ +```typescript +const gemini = createGeminiProvider('your-api-key', { + defaultModel: 'gemini-1.5-pro', + safetySettings: [], + generationConfig: { + temperature: 0.8, + topK: 40 + } +}); +``` + #### `createProvider(type, config)` Generic factory function for creating any provider type. @@ -229,6 +283,11 @@ const openai = createProvider('openai', { apiKey: 'your-api-key', defaultModel: 'gpt-4' }); + +const gemini = createProvider('gemini', { + apiKey: 'your-api-key', + defaultModel: 'gemini-1.5-flash' +}); ``` ### Provider Methods @@ -297,6 +356,13 @@ try { - `gpt-3.5-turbo-0125` - `gpt-3.5-turbo-1106` +### Google Gemini +- `gemini-1.5-flash` (default) +- `gemini-1.5-flash-8b` +- `gemini-1.5-pro` +- `gemini-1.0-pro` +- `gemini-1.0-pro-vision` + ## Environment Variables You can set your API keys as environment variables: @@ -304,35 +370,42 @@ You can set your API keys as environment variables: ```bash export ANTHROPIC_API_KEY="your-anthropic-api-key" export OPENAI_API_KEY="your-openai-api-key" +export GOOGLE_AI_API_KEY="your-google-ai-api-key" ``` ```typescript const claude = createClaudeProvider(process.env.ANTHROPIC_API_KEY!); const openai = createOpenAIProvider(process.env.OPENAI_API_KEY!); +const gemini = createGeminiProvider(process.env.GOOGLE_AI_API_KEY!); ``` ## Provider Comparison -| Feature | Claude | OpenAI | -|---------|--------|--------| -| **Models** | 5 models | 8+ models | -| **Max Context** | 200K tokens | 128K tokens | -| **Streaming** | āœ… | āœ… | -| **Vision** | āœ… | āœ… | -| **Function Calling** | āœ… | āœ… | -| **JSON Mode** | āŒ | āœ… | -| **System Messages** | āœ… (separate) | āœ… (inline) | +| Feature | Claude | OpenAI | Gemini | +|---------|--------|--------|--------| +| **Models** | 5 models | 8+ models | 5 models | +| **Max Context** | 200K tokens | 128K tokens | 1M tokens | +| **Streaming** | āœ… | āœ… | āœ… | +| **Vision** | āœ… | āœ… | āœ… | +| **Function Calling** | āœ… | āœ… | āœ… | +| **JSON Mode** | āŒ | āœ… | āŒ | +| **System Messages** | āœ… (separate) | āœ… (inline) | āœ… (separate) | +| **Multimodal** | āœ… | āœ… | āœ… | +| **Safety Controls** | Basic | Basic | Advanced | +| **Special Features** | Advanced reasoning | JSON mode, plugins | Largest context, advanced safety | ## Best Practices 1. **Always initialize providers** before using them 2. **Handle errors gracefully** with proper error types -3. **Use appropriate models** for your use case (speed vs. capability) +3. **Use appropriate models** for your use case (speed vs. capability vs. context) 4. **Set reasonable timeouts** for your application 5. **Implement retry logic** for production applications 6. **Monitor token usage** to control costs 7. **Use environment variables** for API keys 8. **Consider provider-specific features** when choosing +9. **Configure safety settings** appropriately for Gemini +10. 
**Leverage large context** capabilities of Gemini for complex tasks ## Advanced Usage @@ -342,70 +415,51 @@ const openai = createOpenAIProvider(process.env.OPENAI_API_KEY!); import { ProviderRegistry } from 'simple-ai-provider'; // List all registered providers -console.log(ProviderRegistry.getRegisteredProviders()); // ['claude', 'openai'] +console.log(ProviderRegistry.getRegisteredProviders()); // ['claude', 'openai', 'gemini'] // Create provider by name -const provider = ProviderRegistry.create('openai', { +const provider = ProviderRegistry.create('gemini', { apiKey: 'your-api-key' }); // Check if provider is registered -if (ProviderRegistry.isRegistered('claude')) { - console.log('Claude is available!'); +if (ProviderRegistry.isRegistered('gemini')) { + console.log('Gemini is available!'); } ``` -### Custom Error Handling +### Gemini-Specific Features ```typescript -function handleAIError(error: unknown, providerName: string) { - if (error instanceof AIProviderError) { - console.error(`${providerName} Error (${error.type}):`, error.message); - - if (error.statusCode) { - console.error('HTTP Status:', error.statusCode); - } - - if (error.originalError) { - console.error('Original Error:', error.originalError.message); +import { createGeminiProvider } from 'simple-ai-provider'; + +const gemini = createGeminiProvider('your-api-key', { + defaultModel: 'gemini-1.5-pro', + safetySettings: [ + { + category: 'HARM_CATEGORY_HARASSMENT', + threshold: 'BLOCK_MEDIUM_AND_ABOVE' } + ], + generationConfig: { + temperature: 0.9, + topP: 0.8, + topK: 40, + maxOutputTokens: 2048, + stopSequences: ['END', 'STOP'] } -} -``` +}); -## Extending the Package +await gemini.initialize(); -To add a new AI provider, extend the `BaseAIProvider` class: - -```typescript -import { BaseAIProvider } from 'simple-ai-provider'; - -class MyCustomProvider extends BaseAIProvider { - protected async doInitialize(): Promise { - // Initialize your provider - } - - protected async doComplete(params: CompletionParams): Promise { - // Implement completion logic - } - - protected async *doStream(params: CompletionParams): AsyncIterable { - // Implement streaming logic - } - - public getInfo(): ProviderInfo { - return { - name: 'MyCustomProvider', - version: '1.0.0', - models: ['my-model'], - maxContextLength: 4096, - supportsStreaming: true - }; - } -} - -// Register your provider -ProviderRegistry.register('mycustom', MyCustomProvider); +// Large context example (up to 1M tokens) +const response = await gemini.complete({ + messages: [ + { role: 'system', content: 'You are analyzing a large document.' }, + { role: 'user', content: 'Your very large text here...' 
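+    // Scale note (rough assumption): a 1M-token window is on the order of
+    // 700K English words, so whole books or sizable codebases can fit in one request.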
} + ], + maxTokens: 2048 +}); ``` ## Contributing @@ -422,8 +476,11 @@ MIT - Initial release - Claude provider implementation - OpenAI provider implementation -- Streaming support for both providers +- Gemini provider implementation +- Streaming support for all providers - Comprehensive error handling - TypeScript support - Provider registry system - Multi-provider examples +- Large context support (Gemini) +- Advanced safety controls (Gemini) diff --git a/bun.lock b/bun.lock index 650ec78..104638e 100644 --- a/bun.lock +++ b/bun.lock @@ -5,6 +5,7 @@ "name": "simple-ai-provider", "dependencies": { "@anthropic-ai/sdk": "^0.52.0", + "@google/generative-ai": "^0.24.1", "openai": "^4.103.0", }, "devDependencies": { @@ -19,6 +20,8 @@ "packages": { "@anthropic-ai/sdk": ["@anthropic-ai/sdk@0.52.0", "", { "bin": { "anthropic-ai-sdk": "bin/cli" } }, "sha512-d4c+fg+xy9e46c8+YnrrgIQR45CZlAi7PwdzIfDXDM6ACxEZli1/fxhURsq30ZpMZy6LvSkr41jGq5aF5TD7rQ=="], + "@google/generative-ai": ["@google/generative-ai@0.24.1", "", {}, "sha512-MqO+MLfM6kjxcKoy0p1wRzG3b4ZZXtPI+z2IE26UogS2Cm/XHO+7gGRBh6gcJsOiIVoH93UwKvW4HdgiOZCy9Q=="], + "@types/bun": ["@types/bun@1.2.14", "", { "dependencies": { "bun-types": "1.2.14" } }, "sha512-VsFZKs8oKHzI7zwvECiAJ5oSorWndIWEVhfbYqZd4HI/45kzW7PN2Rr5biAzvGvRuNmYLSANY+H59ubHq8xw7Q=="], "@types/node": ["@types/node@20.17.51", "", { "dependencies": { "undici-types": "~6.19.2" } }, "sha512-hccptBl7C8lHiKxTBsY6vYYmqpmw1E/aGR/8fmueE+B390L3pdMOpNSRvFO4ZnXzW5+p2HBXV0yNABd2vdk22Q=="], diff --git a/examples/multi-provider.ts b/examples/multi-provider.ts index a602d65..dd60571 100644 --- a/examples/multi-provider.ts +++ b/examples/multi-provider.ts @@ -1,11 +1,12 @@ /** * Multi-provider example for Simple AI Provider - * Demonstrates how to use both Claude and OpenAI providers + * Demonstrates how to use Claude, OpenAI, and Gemini providers */ import { createClaudeProvider, createOpenAIProvider, + createGeminiProvider, createProvider, ProviderRegistry, AIProviderError, @@ -18,6 +19,7 @@ async function multiProviderExample() { // Get API keys from environment const claudeApiKey = process.env.ANTHROPIC_API_KEY || 'your-claude-api-key'; const openaiApiKey = process.env.OPENAI_API_KEY || 'your-openai-api-key'; + const geminiApiKey = process.env.GOOGLE_AI_API_KEY || 'your-gemini-api-key'; try { // Method 1: Using factory functions @@ -31,6 +33,10 @@ async function multiProviderExample() { defaultModel: 'gpt-3.5-turbo' }); + const gemini = createGeminiProvider(geminiApiKey, { + defaultModel: 'gemini-1.5-flash' + }); + console.log('āœ“ Providers created\n'); // Method 2: Using generic createProvider function @@ -46,6 +52,11 @@ async function multiProviderExample() { defaultModel: 'gpt-3.5-turbo' }); + const gemini2 = createProvider('gemini', { + apiKey: geminiApiKey, + defaultModel: 'gemini-1.5-flash' + }); + console.log('āœ“ Generic providers created\n'); // Method 3: Using Provider Registry @@ -63,7 +74,8 @@ async function multiProviderExample() { console.log('4. Initializing providers...'); await Promise.all([ claude.initialize(), - openai.initialize() + openai.initialize(), + gemini.initialize() ]); console.log('āœ“ All providers initialized\n'); @@ -71,12 +83,13 @@ async function multiProviderExample() { console.log('5. 
Provider Information:'); console.log('Claude Info:', claude.getInfo()); console.log('OpenAI Info:', openai.getInfo()); + console.log('Gemini Info:', gemini.getInfo()); console.log(); - // Test the same prompt with both providers + // Test the same prompt with all providers const testPrompt = 'Explain the concept of recursion in programming in one sentence.'; - console.log('6. Testing same prompt with both providers...'); + console.log('6. Testing same prompt with all providers...'); console.log(`Prompt: "${testPrompt}"\n`); // Claude response @@ -111,6 +124,22 @@ async function multiProviderExample() { console.log('Model:', openaiResponse.model); console.log(); + // Gemini response + console.log('--- Gemini Response ---'); + const geminiResponse = await gemini.complete({ + messages: [ + { role: 'system', content: 'You are a concise programming tutor.' }, + { role: 'user', content: testPrompt } + ], + maxTokens: 100, + temperature: 0.7 + }); + + console.log('Response:', geminiResponse.content); + console.log('Usage:', geminiResponse.usage); + console.log('Model:', geminiResponse.model); + console.log(); + // Streaming comparison console.log('7. Streaming comparison...'); console.log('Streaming from Claude:'); @@ -139,6 +168,19 @@ async function multiProviderExample() { } } + console.log('Streaming from Gemini:'); + + for await (const chunk of gemini.stream({ + messages: [{ role: 'user', content: 'Count from 1 to 5.' }], + maxTokens: 50 + })) { + if (!chunk.isComplete) { + process.stdout.write(chunk.content); + } else { + console.log('\nāœ“ Gemini streaming complete\n'); + } + } + // Provider-specific features demo console.log('8. Provider-specific features...'); @@ -158,6 +200,18 @@ async function multiProviderExample() { }); console.log('āœ“ Created Claude provider with custom settings'); + // Gemini with safety settings + const geminiCustom = createGeminiProvider(geminiApiKey, { + defaultModel: 'gemini-1.5-pro', + safetySettings: [], + generationConfig: { + temperature: 0.9, + topP: 0.8, + topK: 40 + } + }); + console.log('āœ“ Created Gemini provider with safety and generation settings'); + console.log('\nšŸŽ‰ Multi-provider example completed successfully!'); } catch (error) { @@ -169,7 +223,7 @@ async function multiProviderExample() { switch (error.type) { case AIErrorType.AUTHENTICATION: console.error('šŸ’” Hint: Check your API keys in environment variables'); - console.error(' Set ANTHROPIC_API_KEY and OPENAI_API_KEY'); + console.error(' Set ANTHROPIC_API_KEY, OPENAI_API_KEY, and GOOGLE_AI_API_KEY'); break; case AIErrorType.RATE_LIMIT: console.error('šŸ’” Hint: You are being rate limited. 
Wait and try again.');
          break;
@@ -192,18 +246,44 @@ async function compareProviders() {
 
   const providers = [
     { name: 'Claude', factory: () => createClaudeProvider('dummy-key') },
-    { name: 'OpenAI', factory: () => createOpenAIProvider('dummy-key') }
+    { name: 'OpenAI', factory: () => createOpenAIProvider('dummy-key') },
+    { name: 'Gemini', factory: () => createGeminiProvider('dummy-key') }
   ];
 
   console.log('\nProvider Capabilities:');
-  console.log('| Provider | Models | Context | Streaming | Vision | Functions |');
-  console.log('|----------|--------|---------|-----------|--------|-----------|');
+  console.log('| Provider | Models | Context | Streaming | Vision | Functions | Multimodal |');
+  console.log('|----------|--------|---------|-----------|--------|-----------|------------|');
 
   for (const { name, factory } of providers) {
     const provider = factory();
     const info = provider.getInfo();
-    console.log(`| ${name.padEnd(8)} | ${info.models.length.toString().padEnd(6)} | ${info.maxContextLength.toLocaleString().padEnd(7)} | ${info.supportsStreaming ? 'āœ“' : 'āœ—'.padEnd(9)} | ${info.capabilities?.vision ? 'āœ“' : 'āœ—'.padEnd(6)} | ${info.capabilities?.functionCalling ? 'āœ“' : 'āœ—'.padEnd(9)} |`);
+    const contextStr = info.maxContextLength >= 1000000
+      ? `${(info.maxContextLength / 1000000).toFixed(1)}M`
+      : `${(info.maxContextLength / 1000).toFixed(0)}K`;
+
+    // Parenthesize each ternary so padEnd applies to whichever mark was chosen,
+    // keeping the table columns aligned.
+    console.log(`| ${name.padEnd(8)} | ${info.models.length.toString().padEnd(6)} | ${contextStr.padEnd(7)} | ${(info.supportsStreaming ? 'āœ“' : 'āœ—').padEnd(9)} | ${(info.capabilities?.vision ? 'āœ“' : 'āœ—').padEnd(6)} | ${(info.capabilities?.functionCalling ? 'āœ“' : 'āœ—').padEnd(9)} | ${(info.capabilities?.multimodal ? 'āœ“' : 'āœ—').padEnd(10)} |`);
+  }
+
+  console.log();
+}
+
+// Feature comparison
+async function featureComparison() {
+  console.log('\n=== Feature Comparison ===');
+
+  const features = [
+    ['Provider', 'Context Window', 'Streaming', 'Vision', 'Function Calling', 'System Messages', 'Special Features'],
+    ['Claude', '200K tokens', 'āœ…', 'āœ…', 'āœ…', 'āœ… (separate)', 'Advanced reasoning'],
+    ['OpenAI', '128K tokens', 'āœ…', 'āœ…', 'āœ…', 'āœ… (inline)', 'JSON mode, plugins'],
+    ['Gemini', '1M tokens', 'āœ…', 'āœ…', 'āœ…', 'āœ… (separate)', 'Largest context, multimodal']
+  ];
+
+  for (const row of features) {
+    console.log('| ' + row.map(cell => cell.padEnd(15)).join(' | ') + ' |');
+    if (row[0] === 'Provider') {
+      console.log('|' + ''.padEnd(row.length * 17 + row.length - 1, '-') + '|');
+    }
+  }
 
   console.log();
 }
@@ -213,4 +293,5 @@
 if (import.meta.main) {
   await multiProviderExample();
   await compareProviders();
+  await featureComparison();
 }
\ No newline at end of file
diff --git a/package.json b/package.json
index 9fd1c98..b2278ce 100644
--- a/package.json
+++ b/package.json
@@ -33,6 +33,8 @@
     "anthropic",
     "openai",
     "gpt",
+    "gemini",
+    "google",
     "provider",
     "typescript",
     "nodejs"
@@ -45,6 +47,7 @@
   },
   "dependencies": {
     "@anthropic-ai/sdk": "^0.52.0",
+    "@google/generative-ai": "^0.24.1",
    "openai": "^4.103.0"
   },
   "devDependencies": {
diff --git a/src/index.ts b/src/index.ts
index e556a55..efe97ef 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -29,12 +29,14 @@
 export { BaseAIProvider } from './providers/base.js';
 
 // Concrete provider implementations
 export { ClaudeProvider, type ClaudeConfig } from './providers/claude.js';
 export { OpenAIProvider, type OpenAIConfig } from './providers/openai.js';
+export { GeminiProvider, type GeminiConfig } from './providers/gemini.js';
 
 // Utility functions and factory
 export {
   createProvider,
   createClaudeProvider,
   createOpenAIProvider,
+  createGeminiProvider,
   ProviderRegistry,
   type ProviderType,
   type ProviderConfigMap
@@ -51,4 +53,4 @@ export const VERSION = '1.0.0';
 /**
  * List of supported providers
  */
-export const SUPPORTED_PROVIDERS = ['claude', 'openai'] as const;
\ No newline at end of file
+export const SUPPORTED_PROVIDERS = ['claude', 'openai', 'gemini'] as const;
\ No newline at end of file
diff --git a/src/providers/gemini.ts b/src/providers/gemini.ts
new file mode 100644
index 0000000..fe06b1d
--- /dev/null
+++ b/src/providers/gemini.ts
@@ -0,0 +1,403 @@
+/**
+ * Gemini Provider implementation using Google's Generative AI API
+ * Provides integration with Gemini models through a standardized interface
+ */
+
+import { GoogleGenerativeAI, GenerativeModel } from '@google/generative-ai';
+import type { Content, Part } from '@google/generative-ai';
+import type {
+  AIProviderConfig,
+  CompletionParams,
+  CompletionResponse,
+  CompletionChunk,
+  ProviderInfo,
+  AIMessage
+} from '../types/index.js';
+import { BaseAIProvider } from './base.js';
+import { AIProviderError, AIErrorType } from '../types/index.js';
+
+/**
+ * Configuration specific to Gemini provider
+ */
+export interface GeminiConfig extends AIProviderConfig {
+  /** Default model to use if not specified in requests (default: gemini-1.5-flash) */
+  defaultModel?: string;
+  /** Safety settings for content filtering */
+  safetySettings?: any[];
+  /** Generation configuration */
+  generationConfig?: {
+    temperature?: number;
+    topP?: number;
+    topK?: number;
+    maxOutputTokens?: number;
+    stopSequences?: string[];
+  };
+}
+
+/**
+ * Gemini provider implementation
+ */
+export class GeminiProvider extends BaseAIProvider {
+  private client: GoogleGenerativeAI | null = null;
+  private model: GenerativeModel | null = null;
+  private readonly defaultModel: string;
+  private readonly safetySettings?: any[];
+  private readonly generationConfig?: any;
+
+  constructor(config: GeminiConfig) {
+    super(config);
+    this.defaultModel = config.defaultModel || 'gemini-1.5-flash';
+    this.safetySettings = config.safetySettings;
+    this.generationConfig = config.generationConfig;
+  }
+
+  /**
+   * Initialize the Gemini provider by setting up the Google Generative AI client
+   */
+  protected async doInitialize(): Promise<void> {
+    try {
+      this.client = new GoogleGenerativeAI(this.config.apiKey);
+      this.model = this.client.getGenerativeModel({
+        model: this.defaultModel,
+        safetySettings: this.safetySettings,
+        generationConfig: this.generationConfig
+      });
+
+      // Test the connection by making a simple request
+      await this.validateConnection();
+    } catch (error) {
+      throw new AIProviderError(
+        `Failed to initialize Gemini provider: ${(error as Error).message}`,
+        AIErrorType.AUTHENTICATION,
+        undefined,
+        error as Error
+      );
+    }
+  }
+
+  /**
+   * Generate a completion using Gemini
+   */
+  protected async doComplete(params: CompletionParams): Promise<CompletionResponse> {
+    if (!this.client || !this.model) {
+      throw new AIProviderError('Client not initialized', AIErrorType.INVALID_REQUEST);
+    }
+
+    try {
+      // Get the model for this request (might be different from default)
+      const model = params.model && params.model !== this.defaultModel
+        ? this.client.getGenerativeModel({
+            model: params.model,
+            safetySettings: this.safetySettings,
+            generationConfig: this.buildGenerationConfig(params)
+          })
+        : this.model;
+
+      const { systemInstruction, contents } = this.convertMessages(params.messages);
+
+      // Create chat session or use generateContent
+      if (contents.length > 1) {
+        // Multi-turn conversation - use chat session
+        const chat = model.startChat({
+          history: contents.slice(0, -1),
+          systemInstruction,
+          generationConfig: this.buildGenerationConfig(params)
+        });
+
+        const lastMessage = contents[contents.length - 1];
+        if (!lastMessage) {
+          throw new AIProviderError('No valid messages provided', AIErrorType.INVALID_REQUEST);
+        }
+        const result = await chat.sendMessage(lastMessage.parts);
+
+        return this.formatCompletionResponse(result.response, params.model || this.defaultModel);
+      } else {
+        // Single message - use generateContent
+        const result = await model.generateContent({
+          contents,
+          systemInstruction,
+          generationConfig: this.buildGenerationConfig(params)
+        });
+
+        return this.formatCompletionResponse(result.response, params.model || this.defaultModel);
+      }
+    } catch (error) {
+      throw this.handleGeminiError(error as Error);
+    }
+  }
+
+  /**
+   * Generate a streaming completion using Gemini
+   */
+  protected async *doStream(params: CompletionParams): AsyncIterable<CompletionChunk> {
+    if (!this.client || !this.model) {
+      throw new AIProviderError('Client not initialized', AIErrorType.INVALID_REQUEST);
+    }
+
+    try {
+      // Get the model for this request
+      const model = params.model && params.model !== this.defaultModel
+        ? this.client.getGenerativeModel({
+            model: params.model,
+            safetySettings: this.safetySettings,
+            generationConfig: this.buildGenerationConfig(params)
+          })
+        : this.model;
+
+      const { systemInstruction, contents } = this.convertMessages(params.messages);
+
+      let stream;
+      if (contents.length > 1) {
+        // Multi-turn conversation
+        const chat = model.startChat({
+          history: contents.slice(0, -1),
+          systemInstruction,
+          generationConfig: this.buildGenerationConfig(params)
+        });
+
+        const lastMessage = contents[contents.length - 1];
+        if (!lastMessage) {
+          throw new AIProviderError('No valid messages provided', AIErrorType.INVALID_REQUEST);
+        }
+        stream = await chat.sendMessageStream(lastMessage.parts);
+      } else {
+        // Single message
+        stream = await model.generateContentStream({
+          contents,
+          systemInstruction,
+          generationConfig: this.buildGenerationConfig(params)
+        });
+      }
+
+      const requestId = `gemini-${Date.now()}-${Math.random().toString(36).slice(2, 11)}`;
+
+      for await (const chunk of stream.stream) {
+        yield {
+          content: chunk.text(),
+          isComplete: false,
+          id: requestId
+        };
+      }
+
+      // Final chunk with usage info
+      const finalResponse = await stream.response;
+      const usageMetadata = finalResponse.usageMetadata;
+
+      yield {
+        content: '',
+        isComplete: true,
+        id: requestId,
+        usage: {
+          promptTokens: usageMetadata?.promptTokenCount || 0,
+          completionTokens: usageMetadata?.candidatesTokenCount || 0,
+          totalTokens: usageMetadata?.totalTokenCount || 0
+        }
+      };
+    } catch (error) {
+      throw this.handleGeminiError(error as Error);
+    }
+  }
+
+  /**
+   * Get information about the Gemini provider
+   */
+  public getInfo(): ProviderInfo {
+    return {
+      name: 'Gemini',
+      version: '1.0.0',
+      models: [
+        'gemini-1.5-flash',
+        'gemini-1.5-flash-8b',
+        'gemini-1.5-pro',
+        'gemini-1.0-pro',
+        'gemini-1.0-pro-vision'
+      ],
+      maxContextLength: 1000000, // Gemini 1.5 context length
+      supportsStreaming: true,
+      capabilities: {
+        vision: true,
+        functionCalling: true,
+        systemMessages: true,
+        multimodal: true,
+        largeContext: true
+      }
+    };
+  }
+
+  /**
+   * Validate the connection by making a simple request
+   */
+  private async validateConnection(): Promise<void> {
+    if (!this.model) {
+      throw new Error('Model not initialized');
+    }
+
+    try {
+      // Make a minimal request to validate credentials
+      await this.model.generateContent('Hi');
+    } catch (error: any) {
+      if (error.message?.includes('API key') || error.message?.includes('authentication')) {
+        throw new AIProviderError(
+          'Invalid API key. Please check your Google AI API key.',
+          AIErrorType.AUTHENTICATION
+        );
+      }
+      // For other errors during validation we let initialization proceed,
+      // as they may be transient issues.
+    }
+  }
+
+  /**
+   * Convert our generic message format to Gemini's format
+   * Gemini uses Contents with Parts and supports system instructions separately
+   */
+  private convertMessages(messages: AIMessage[]): { systemInstruction?: string; contents: Content[] } {
+    let systemInstruction: string | undefined;
+    const contents: Content[] = [];
+
+    for (const message of messages) {
+      if (message.role === 'system') {
+        // Combine multiple system messages
+        if (systemInstruction) {
+          systemInstruction += '\n\n' + message.content;
+        } else {
+          systemInstruction = message.content;
+        }
+      } else {
+        contents.push({
+          role: message.role === 'assistant' ? 'model' : 'user',
+          parts: [{ text: message.content }]
+        });
+      }
+    }
+
+    return { systemInstruction, contents };
+  }
+
+  /**
+   * Build generation config from completion parameters.
+   * Per-request parameters take precedence over the constructor-level
+   * generationConfig defaults.
+   */
+  private buildGenerationConfig(params: CompletionParams) {
+    return {
+      ...this.generationConfig,
+      temperature: params.temperature ?? this.generationConfig?.temperature ?? 0.7,
+      topP: params.topP ?? this.generationConfig?.topP,
+      maxOutputTokens: params.maxTokens ?? this.generationConfig?.maxOutputTokens ?? 1000,
+      stopSequences: params.stopSequences ?? this.generationConfig?.stopSequences
+    };
+  }
+
+  /**
+   * Format Gemini's response to our standard format
+   */
+  private formatCompletionResponse(response: any, model: string): CompletionResponse {
+    // Join all text parts; non-text parts (e.g. function calls) are filtered out,
+    // so a leading non-text part no longer causes a spurious "no content" error.
+    const candidate = response.candidates?.[0];
+    const content = candidate?.content?.parts
+      ?.filter((part: Part) => part.text)
+      .map((part: Part) => part.text)
+      .join('') ?? '';
+
+    if (!candidate || !content) {
+      throw new AIProviderError(
+        'No content in Gemini response',
+        AIErrorType.UNKNOWN
+      );
+    }
+
+    const usageMetadata = response.usageMetadata;
+    const requestId = `gemini-${Date.now()}-${Math.random().toString(36).slice(2, 11)}`;
+
+    return {
+      content,
+      model,
+      usage: {
+        promptTokens: usageMetadata?.promptTokenCount || 0,
+        completionTokens: usageMetadata?.candidatesTokenCount || 0,
+        totalTokens: usageMetadata?.totalTokenCount || 0
+      },
+      id: requestId,
+      metadata: {
+        finishReason: candidate.finishReason,
+        safetyRatings: candidate.safetyRatings,
+        citationMetadata: candidate.citationMetadata
+      }
+    };
+  }
+
+  /**
+   * Handle Gemini-specific errors and convert them to our standard format
+   */
+  private handleGeminiError(error: any): AIProviderError {
+    if (error instanceof AIProviderError) {
+      return error;
+    }
+
+    const message = error.message || 'Unknown Gemini API error';
+
+    // Handle common Gemini error patterns
+    if (message.includes('API key')) {
+      return new AIProviderError(
+        'Authentication failed. Please check your Google AI API key.',
+        AIErrorType.AUTHENTICATION,
+        undefined,
+        error
+      );
+    }
+
+    if (message.includes('quota') || message.includes('rate limit')) {
+      return new AIProviderError(
+        'Rate limit exceeded. Please slow down your requests.',
+        AIErrorType.RATE_LIMIT,
+        undefined,
+        error
+      );
+    }
+
+    if (message.includes('model') && message.includes('not found')) {
+      return new AIProviderError(
+        'Model not found. Please check the model name.',
+        AIErrorType.MODEL_NOT_FOUND,
+        undefined,
+        error
+      );
+    }
+
+    if (message.includes('invalid') || message.includes('bad request')) {
+      return new AIProviderError(
+        `Invalid request: ${message}`,
+        AIErrorType.INVALID_REQUEST,
+        undefined,
+        error
+      );
+    }
+
+    if (message.includes('network') || message.includes('connection')) {
+      return new AIProviderError(
+        'Network error occurred. Please check your connection.',
+        AIErrorType.NETWORK,
+        undefined,
+        error
+      );
+    }
+
+    if (message.includes('timeout')) {
+      return new AIProviderError(
+        'Request timed out. Please try again.',
+        AIErrorType.TIMEOUT,
+        undefined,
+        error
+      );
+    }
+
+    return new AIProviderError(
+      `Gemini API error: ${message}`,
+      AIErrorType.UNKNOWN,
+      undefined,
+      error
+    );
+  }
+}
\ No newline at end of file
diff --git a/src/providers/index.ts b/src/providers/index.ts
index 568a3e6..6414fd8 100644
--- a/src/providers/index.ts
+++ b/src/providers/index.ts
@@ -5,4 +5,5 @@
 
 export { BaseAIProvider } from './base.js';
 export { ClaudeProvider, type ClaudeConfig } from './claude.js';
-export { OpenAIProvider, type OpenAIConfig } from './openai.js';
\ No newline at end of file
+export { OpenAIProvider, type OpenAIConfig } from './openai.js';
+export { GeminiProvider, type GeminiConfig } from './gemini.js';
\ No newline at end of file
diff --git a/src/utils/factory.ts b/src/utils/factory.ts
index ed5549a..7572ab3 100644
--- a/src/utils/factory.ts
+++ b/src/utils/factory.ts
@@ -6,12 +6,13 @@
 import type { AIProviderConfig } from '../types/index.js';
 import { ClaudeProvider, type ClaudeConfig } from '../providers/claude.js';
 import { OpenAIProvider, type OpenAIConfig } from '../providers/openai.js';
+import { GeminiProvider, type GeminiConfig } from '../providers/gemini.js';
 import { BaseAIProvider } from '../providers/base.js';
 
 /**
  * Supported AI provider types
  */
-export type ProviderType = 'claude' | 'openai';
+export type ProviderType = 'claude' | 'openai' | 'gemini';
 
 /**
  * Configuration map for different provider types
@@ -19,6 +20,7 @@
 export interface ProviderConfigMap {
   claude: ClaudeConfig;
   openai: OpenAIConfig;
+  gemini: GeminiConfig;
 }
 
 /**
@@ -36,6 +38,8 @@
       return new ClaudeProvider(config as ClaudeConfig);
     case 'openai':
       return new OpenAIProvider(config as OpenAIConfig);
+    case 'gemini':
+      return new GeminiProvider(config as GeminiConfig);
     default:
       throw new Error(`Unsupported provider type: ${type}`);
   }
@@ -73,6 +77,22 @@
   });
 }
 
+/**
+ * Create a Gemini provider with simplified configuration
+ * @param apiKey - Google AI API key
+ * @param options - Optional additional configuration
+ * @returns Configured Gemini provider instance
+ */
+export function createGeminiProvider(
+  apiKey: string,
+  options: Partial<Omit<GeminiConfig, 'apiKey'>> = {}
+): GeminiProvider {
+  return new GeminiProvider({
+    apiKey,
+    ...options
+  });
+}
+
 /**
  * Provider registry for dynamic provider creation
 */
@@ -122,4 +142,5 @@
 // Pre-register built-in providers
 ProviderRegistry.register('claude', ClaudeProvider);
-ProviderRegistry.register('openai', OpenAIProvider);
\ No newline at end of file
+ProviderRegistry.register('openai', OpenAIProvider);
+ProviderRegistry.register('gemini', 
GeminiProvider); \ No newline at end of file diff --git a/tests/gemini.test.ts b/tests/gemini.test.ts new file mode 100644 index 0000000..35eb952 --- /dev/null +++ b/tests/gemini.test.ts @@ -0,0 +1,357 @@ +/** + * Tests for Gemini Provider + */ + +import { describe, it, expect, beforeEach } from 'bun:test'; +import { GeminiProvider, AIProviderError, AIErrorType } from '../src/index.js'; + +describe('GeminiProvider', () => { + let provider: GeminiProvider; + + beforeEach(() => { + provider = new GeminiProvider({ + apiKey: 'test-api-key', + defaultModel: 'gemini-1.5-flash' + }); + }); + + describe('constructor', () => { + it('should create provider with valid config', () => { + expect(provider).toBeInstanceOf(GeminiProvider); + expect(provider.isInitialized()).toBe(false); + }); + + it('should throw error for missing API key', () => { + expect(() => { + new GeminiProvider({ apiKey: '' }); + }).toThrow(AIProviderError); + }); + + it('should set default model', () => { + const customProvider = new GeminiProvider({ + apiKey: 'test-key', + defaultModel: 'gemini-1.5-pro' + }); + expect(customProvider).toBeInstanceOf(GeminiProvider); + }); + + it('should handle safety settings and generation config', () => { + const customProvider = new GeminiProvider({ + apiKey: 'test-key', + safetySettings: [], + generationConfig: { + temperature: 0.8, + topP: 0.9, + topK: 40 + } + }); + expect(customProvider).toBeInstanceOf(GeminiProvider); + }); + }); + + describe('getInfo', () => { + it('should return provider information', () => { + const info = provider.getInfo(); + + expect(info.name).toBe('Gemini'); + expect(info.version).toBe('1.0.0'); + expect(info.supportsStreaming).toBe(true); + expect(info.models).toContain('gemini-1.5-flash'); + expect(info.models).toContain('gemini-1.5-pro'); + expect(info.models).toContain('gemini-1.0-pro'); + expect(info.maxContextLength).toBe(1000000); + expect(info.capabilities).toHaveProperty('vision', true); + expect(info.capabilities).toHaveProperty('functionCalling', true); + expect(info.capabilities).toHaveProperty('systemMessages', true); + expect(info.capabilities).toHaveProperty('multimodal', true); + expect(info.capabilities).toHaveProperty('largeContext', true); + }); + }); + + describe('validation', () => { + it('should validate temperature range', async () => { + // Mock initialization to avoid API call + (provider as any).initialized = true; + (provider as any).client = {}; + (provider as any).model = {}; + + await expect( + provider.complete({ + messages: [{ role: 'user', content: 'test' }], + temperature: 1.5 + }) + ).rejects.toThrow('Temperature must be between 0.0 and 1.0'); + }); + + it('should validate top_p range', async () => { + (provider as any).initialized = true; + (provider as any).client = {}; + (provider as any).model = {}; + + await expect( + provider.complete({ + messages: [{ role: 'user', content: 'test' }], + topP: 1.5 + }) + ).rejects.toThrow('Top-p must be between 0.0 and 1.0'); + }); + + it('should validate message format', async () => { + (provider as any).initialized = true; + (provider as any).client = {}; + (provider as any).model = {}; + + await expect( + provider.complete({ + messages: [{ role: 'invalid' as any, content: 'test' }] + }) + ).rejects.toThrow('Each message must have a valid role'); + }); + + it('should validate empty content', async () => { + (provider as any).initialized = true; + (provider as any).client = {}; + (provider as any).model = {}; + + await expect( + provider.complete({ + messages: [{ role: 'user', content: '' 
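+          // the empty string should be rejected by input validation before any API call is made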
}] + }) + ).rejects.toThrow('Each message must have non-empty string content'); + }); + + it('should require initialization before use', async () => { + await expect( + provider.complete({ + messages: [{ role: 'user', content: 'test' }] + }) + ).rejects.toThrow('Provider must be initialized before use'); + }); + }); + + describe('error handling', () => { + it('should handle authentication errors', () => { + const error = new Error('API key invalid'); + + const providerError = (provider as any).handleGeminiError(error); + + expect(providerError).toBeInstanceOf(AIProviderError); + expect(providerError.type).toBe(AIErrorType.AUTHENTICATION); + expect(providerError.message).toContain('Authentication failed'); + }); + + it('should handle rate limit errors', () => { + const error = new Error('quota exceeded'); + + const providerError = (provider as any).handleGeminiError(error); + + expect(providerError).toBeInstanceOf(AIProviderError); + expect(providerError.type).toBe(AIErrorType.RATE_LIMIT); + expect(providerError.message).toContain('Rate limit exceeded'); + }); + + it('should handle model not found errors', () => { + const error = new Error('model not found'); + + const providerError = (provider as any).handleGeminiError(error); + + expect(providerError).toBeInstanceOf(AIProviderError); + expect(providerError.type).toBe(AIErrorType.MODEL_NOT_FOUND); + expect(providerError.message).toContain('Model not found'); + }); + + it('should handle invalid request errors', () => { + const error = new Error('invalid request parameters'); + + const providerError = (provider as any).handleGeminiError(error); + + expect(providerError).toBeInstanceOf(AIProviderError); + expect(providerError.type).toBe(AIErrorType.INVALID_REQUEST); + }); + + it('should handle network errors', () => { + const error = new Error('network connection failed'); + + const providerError = (provider as any).handleGeminiError(error); + + expect(providerError).toBeInstanceOf(AIProviderError); + expect(providerError.type).toBe(AIErrorType.NETWORK); + }); + + it('should handle timeout errors', () => { + const error = new Error('request timeout'); + + const providerError = (provider as any).handleGeminiError(error); + + expect(providerError).toBeInstanceOf(AIProviderError); + expect(providerError.type).toBe(AIErrorType.TIMEOUT); + }); + + it('should handle unknown errors', () => { + const error = new Error('Unknown error'); + + const providerError = (provider as any).handleGeminiError(error); + + expect(providerError).toBeInstanceOf(AIProviderError); + expect(providerError.type).toBe(AIErrorType.UNKNOWN); + }); + }); + + describe('message conversion', () => { + it('should convert messages to Gemini format', () => { + const messages = [ + { role: 'system' as const, content: 'You are helpful' }, + { role: 'user' as const, content: 'Hello' }, + { role: 'assistant' as const, content: 'Hi there' } + ]; + + const result = (provider as any).convertMessages(messages); + + expect(result.systemInstruction).toBe('You are helpful'); + expect(result.contents).toHaveLength(2); + expect(result.contents[0]).toEqual({ + role: 'user', + parts: [{ text: 'Hello' }] + }); + expect(result.contents[1]).toEqual({ + role: 'model', + parts: [{ text: 'Hi there' }] + }); + }); + + it('should handle multiple system messages', () => { + const messages = [ + { role: 'system' as const, content: 'You are helpful' }, + { role: 'user' as const, content: 'Hello' }, + { role: 'system' as const, content: 'Be concise' } + ]; + + const result = (provider as 
any).convertMessages(messages); + + expect(result.systemInstruction).toBe('You are helpful\n\nBe concise'); + expect(result.contents).toHaveLength(1); + expect(result.contents[0].role).toBe('user'); + }); + + it('should handle messages without system prompts', () => { + const messages = [ + { role: 'user' as const, content: 'Hello' }, + { role: 'assistant' as const, content: 'Hi there' } + ]; + + const result = (provider as any).convertMessages(messages); + + expect(result.systemInstruction).toBeUndefined(); + expect(result.contents).toHaveLength(2); + }); + + it('should convert assistant role to model role', () => { + const messages = [ + { role: 'assistant' as const, content: 'I am an assistant' } + ]; + + const result = (provider as any).convertMessages(messages); + + expect(result.contents[0].role).toBe('model'); + expect(result.contents[0].parts[0].text).toBe('I am an assistant'); + }); + }); + + describe('generation config', () => { + it('should build generation config from completion params', () => { + const params = { + messages: [{ role: 'user' as const, content: 'test' }], + temperature: 0.8, + topP: 0.9, + maxTokens: 500, + stopSequences: ['STOP', 'END'] + }; + + const result = (provider as any).buildGenerationConfig(params); + + expect(result.temperature).toBe(0.8); + expect(result.topP).toBe(0.9); + expect(result.maxOutputTokens).toBe(500); + expect(result.stopSequences).toEqual(['STOP', 'END']); + }); + + it('should use default temperature when not provided', () => { + const params = { + messages: [{ role: 'user' as const, content: 'test' }] + }; + + const result = (provider as any).buildGenerationConfig(params); + + expect(result.temperature).toBe(0.7); + expect(result.maxOutputTokens).toBe(1000); + }); + }); + + describe('response formatting', () => { + it('should format completion response correctly', () => { + const mockResponse = { + candidates: [{ + content: { + parts: [{ text: 'Hello there!' }] + }, + finishReason: 'STOP', + safetyRatings: [], + citationMetadata: null + }], + usageMetadata: { + promptTokenCount: 10, + candidatesTokenCount: 20, + totalTokenCount: 30 + } + }; + + const result = (provider as any).formatCompletionResponse(mockResponse, 'gemini-1.5-flash'); + + expect(result.content).toBe('Hello there!'); + expect(result.model).toBe('gemini-1.5-flash'); + expect(result.usage.promptTokens).toBe(10); + expect(result.usage.completionTokens).toBe(20); + expect(result.usage.totalTokens).toBe(30); + expect(result.metadata.finishReason).toBe('STOP'); + }); + + it('should handle multiple text parts', () => { + const mockResponse = { + candidates: [{ + content: { + parts: [ + { text: 'Hello ' }, + { text: 'there!' }, + { functionCall: { name: 'test' } } // Non-text part should be filtered + ] + }, + finishReason: 'STOP' + }], + usageMetadata: { + promptTokenCount: 5, + candidatesTokenCount: 10, + totalTokenCount: 15 + } + }; + + const result = (provider as any).formatCompletionResponse(mockResponse, 'gemini-1.5-flash'); + + expect(result.content).toBe('Hello there!'); + }); + + it('should throw error for empty response', () => { + const mockResponse = { + candidates: [], + usageMetadata: { + promptTokenCount: 5, + candidatesTokenCount: 0, + totalTokenCount: 5 + } + }; + + expect(() => { + (provider as any).formatCompletionResponse(mockResponse, 'gemini-1.5-flash'); + }).toThrow('No content in Gemini response'); + }); + }); +}); \ No newline at end of file