feat: add Google Gemini provider integration and docs
This commit is contained in:
205
README.md
205
README.md
@ -1,6 +1,6 @@
|
||||
# Simple AI Provider
|
||||
|
||||
A professional, extensible TypeScript package for integrating multiple AI providers into your applications with a unified interface. Currently supports **Claude (Anthropic)** and **OpenAI (GPT)** with plans to add more providers.
|
||||
A professional, extensible TypeScript package for integrating multiple AI providers into your applications with a unified interface. Currently supports **Claude (Anthropic)**, **OpenAI (GPT)**, and **Google Gemini** with plans to add more providers.
|
||||
|
||||
## Features
|
||||
|
||||
@ -11,7 +11,7 @@ A professional, extensible TypeScript package for integrating multiple AI provid
|
||||
- 🛡️ **Error Handling**: Robust error handling with categorized error types
|
||||
- 🔧 **Extensible**: Easy to add new AI providers
|
||||
- 📦 **Modern**: Built with ES modules and modern JavaScript features
|
||||
- 🌐 **Multi-Provider**: Switch between Claude and OpenAI seamlessly
|
||||
- 🌐 **Multi-Provider**: Switch between Claude, OpenAI, and Gemini seamlessly
|
||||
|
||||
## Installation
|
||||
|
||||
@ -71,39 +71,65 @@ const response = await openai.complete({
|
||||
console.log(response.content);
|
||||
```
|
||||
|
||||
### Basic Usage with Gemini
|
||||
|
||||
```typescript
|
||||
import { createGeminiProvider } from 'simple-ai-provider';
|
||||
|
||||
// Create a Gemini provider
|
||||
const gemini = createGeminiProvider('your-google-ai-api-key');
|
||||
|
||||
// Initialize the provider
|
||||
await gemini.initialize();
|
||||
|
||||
// Generate a completion
|
||||
const response = await gemini.complete({
|
||||
messages: [
|
||||
{ role: 'user', content: 'Hello! How are you today?' }
|
||||
],
|
||||
maxTokens: 100,
|
||||
temperature: 0.7
|
||||
});
|
||||
|
||||
console.log(response.content);
|
||||
```
|
||||
|
||||
### Multi-Provider Usage
|
||||
|
||||
```typescript
|
||||
import { createProvider, createClaudeProvider, createOpenAIProvider } from 'simple-ai-provider';
|
||||
import { createProvider, createClaudeProvider, createOpenAIProvider, createGeminiProvider } from 'simple-ai-provider';
|
||||
|
||||
// Method 1: Using specific factory functions
|
||||
const claude = createClaudeProvider('your-anthropic-api-key');
|
||||
const openai = createOpenAIProvider('your-openai-api-key');
|
||||
const gemini = createGeminiProvider('your-google-ai-api-key');
|
||||
|
||||
// Method 2: Using generic factory
|
||||
const claude2 = createProvider('claude', { apiKey: 'your-anthropic-api-key' });
|
||||
const openai2 = createProvider('openai', { apiKey: 'your-openai-api-key' });
|
||||
const gemini2 = createProvider('gemini', { apiKey: 'your-google-ai-api-key' });
|
||||
|
||||
// Initialize both
|
||||
await Promise.all([claude.initialize(), openai.initialize()]);
|
||||
// Initialize all
|
||||
await Promise.all([claude.initialize(), openai.initialize(), gemini.initialize()]);
|
||||
|
||||
// Use the same interface for both providers
|
||||
// Use the same interface for all providers
|
||||
const prompt = { messages: [{ role: 'user', content: 'Explain AI' }] };
|
||||
|
||||
const claudeResponse = await claude.complete(prompt);
|
||||
const openaiResponse = await openai.complete(prompt);
|
||||
const geminiResponse = await gemini.complete(prompt);
|
||||
```
|
||||
|
||||
### Streaming Responses
|
||||
|
||||
```typescript
|
||||
import { createOpenAIProvider } from 'simple-ai-provider';
|
||||
import { createGeminiProvider } from 'simple-ai-provider';
|
||||
|
||||
const openai = createOpenAIProvider('your-openai-api-key');
|
||||
await openai.initialize();
|
||||
const gemini = createGeminiProvider('your-google-ai-api-key');
|
||||
await gemini.initialize();
|
||||
|
||||
// Stream a completion
|
||||
for await (const chunk of openai.stream({
|
||||
for await (const chunk of gemini.stream({
|
||||
messages: [
|
||||
{ role: 'user', content: 'Write a short story about a robot.' }
|
||||
],
|
||||
@ -120,7 +146,7 @@ for await (const chunk of openai.stream({
|
||||
### Advanced Configuration
|
||||
|
||||
```typescript
|
||||
import { ClaudeProvider, OpenAIProvider } from 'simple-ai-provider';
|
||||
import { ClaudeProvider, OpenAIProvider, GeminiProvider } from 'simple-ai-provider';
|
||||
|
||||
// Claude with custom configuration
|
||||
const claude = new ClaudeProvider({
|
||||
@ -141,14 +167,28 @@ const openai = new OpenAIProvider({
|
||||
maxRetries: 5
|
||||
});
|
||||
|
||||
await Promise.all([claude.initialize(), openai.initialize()]);
|
||||
// Gemini with safety settings and generation config
|
||||
const gemini = new GeminiProvider({
|
||||
apiKey: 'your-google-ai-api-key',
|
||||
defaultModel: 'gemini-1.5-pro',
|
||||
safetySettings: [], // Configure content filtering
|
||||
generationConfig: {
|
||||
temperature: 0.8,
|
||||
topP: 0.9,
|
||||
topK: 40,
|
||||
maxOutputTokens: 2048
|
||||
},
|
||||
timeout: 45000
|
||||
});
|
||||
|
||||
const response = await openai.complete({
|
||||
await Promise.all([claude.initialize(), openai.initialize(), gemini.initialize()]);
|
||||
|
||||
const response = await gemini.complete({
|
||||
messages: [
|
||||
{ role: 'system', content: 'You are a helpful assistant.' },
|
||||
{ role: 'user', content: 'Explain quantum computing in simple terms.' }
|
||||
],
|
||||
model: 'gpt-4-turbo',
|
||||
model: 'gemini-1.5-flash',
|
||||
maxTokens: 300,
|
||||
temperature: 0.5,
|
||||
topP: 0.9,
|
||||
@ -216,6 +256,20 @@ const openai = createOpenAIProvider('your-api-key', {
|
||||
});
|
||||
```
|
||||
|
||||
#### `createGeminiProvider(apiKey, options?)`
|
||||
Creates a Gemini provider with simplified configuration.
|
||||
|
||||
```typescript
|
||||
const gemini = createGeminiProvider('your-api-key', {
|
||||
defaultModel: 'gemini-1.5-pro',
|
||||
safetySettings: [],
|
||||
generationConfig: {
|
||||
temperature: 0.8,
|
||||
topK: 40
|
||||
}
|
||||
});
|
||||
```
|
||||
|
||||
#### `createProvider(type, config)`
|
||||
Generic factory function for creating any provider type.
|
||||
|
||||
@ -229,6 +283,11 @@ const openai = createProvider('openai', {
|
||||
apiKey: 'your-api-key',
|
||||
defaultModel: 'gpt-4'
|
||||
});
|
||||
|
||||
const gemini = createProvider('gemini', {
|
||||
apiKey: 'your-api-key',
|
||||
defaultModel: 'gemini-1.5-flash'
|
||||
});
|
||||
```
|
||||
|
||||
### Provider Methods
|
||||
@ -297,6 +356,13 @@ try {
|
||||
- `gpt-3.5-turbo-0125`
|
||||
- `gpt-3.5-turbo-1106`
|
||||
|
||||
### Google Gemini
|
||||
- `gemini-1.5-flash` (default)
|
||||
- `gemini-1.5-flash-8b`
|
||||
- `gemini-1.5-pro`
|
||||
- `gemini-1.0-pro`
|
||||
- `gemini-1.0-pro-vision`
|
||||
|
||||
## Environment Variables
|
||||
|
||||
You can set your API keys as environment variables:
|
||||
@ -304,35 +370,42 @@ You can set your API keys as environment variables:
|
||||
```bash
|
||||
export ANTHROPIC_API_KEY="your-anthropic-api-key"
|
||||
export OPENAI_API_KEY="your-openai-api-key"
|
||||
export GOOGLE_AI_API_KEY="your-google-ai-api-key"
|
||||
```
|
||||
|
||||
```typescript
|
||||
const claude = createClaudeProvider(process.env.ANTHROPIC_API_KEY!);
|
||||
const openai = createOpenAIProvider(process.env.OPENAI_API_KEY!);
|
||||
const gemini = createGeminiProvider(process.env.GOOGLE_AI_API_KEY!);
|
||||
```
|
||||
|
||||
## Provider Comparison
|
||||
|
||||
| Feature | Claude | OpenAI |
|
||||
|---------|--------|--------|
|
||||
| **Models** | 5 models | 8+ models |
|
||||
| **Max Context** | 200K tokens | 128K tokens |
|
||||
| **Streaming** | ✅ | ✅ |
|
||||
| **Vision** | ✅ | ✅ |
|
||||
| **Function Calling** | ✅ | ✅ |
|
||||
| **JSON Mode** | ❌ | ✅ |
|
||||
| **System Messages** | ✅ (separate) | ✅ (inline) |
|
||||
| Feature | Claude | OpenAI | Gemini |
|
||||
|---------|--------|--------|--------|
|
||||
| **Models** | 5 models | 8+ models | 5 models |
|
||||
| **Max Context** | 200K tokens | 128K tokens | 1M tokens |
|
||||
| **Streaming** | ✅ | ✅ | ✅ |
|
||||
| **Vision** | ✅ | ✅ | ✅ |
|
||||
| **Function Calling** | ✅ | ✅ | ✅ |
|
||||
| **JSON Mode** | ❌ | ✅ | ❌ |
|
||||
| **System Messages** | ✅ (separate) | ✅ (inline) | ✅ (separate) |
|
||||
| **Multimodal** | ✅ | ✅ | ✅ |
|
||||
| **Safety Controls** | Basic | Basic | Advanced |
|
||||
| **Special Features** | Advanced reasoning | JSON mode, plugins | Largest context, advanced safety |
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Always initialize providers** before using them
|
||||
2. **Handle errors gracefully** with proper error types
|
||||
3. **Use appropriate models** for your use case (speed vs. capability)
|
||||
3. **Use appropriate models** for your use case (speed vs. capability vs. context)
|
||||
4. **Set reasonable timeouts** for your application
|
||||
5. **Implement retry logic** for production applications
|
||||
6. **Monitor token usage** to control costs
|
||||
7. **Use environment variables** for API keys
|
||||
8. **Consider provider-specific features** when choosing
|
||||
9. **Configure safety settings** appropriately for Gemini
|
||||
10. **Leverage large context** capabilities of Gemini for complex tasks
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
@ -342,70 +415,51 @@ const openai = createOpenAIProvider(process.env.OPENAI_API_KEY!);
|
||||
import { ProviderRegistry } from 'simple-ai-provider';
|
||||
|
||||
// List all registered providers
|
||||
console.log(ProviderRegistry.getRegisteredProviders()); // ['claude', 'openai']
|
||||
console.log(ProviderRegistry.getRegisteredProviders()); // ['claude', 'openai', 'gemini']
|
||||
|
||||
// Create provider by name
|
||||
const provider = ProviderRegistry.create('openai', {
|
||||
const provider = ProviderRegistry.create('gemini', {
|
||||
apiKey: 'your-api-key'
|
||||
});
|
||||
|
||||
// Check if provider is registered
|
||||
if (ProviderRegistry.isRegistered('claude')) {
|
||||
console.log('Claude is available!');
|
||||
if (ProviderRegistry.isRegistered('gemini')) {
|
||||
console.log('Gemini is available!');
|
||||
}
|
||||
```
|
||||
|
||||
### Custom Error Handling
|
||||
### Gemini-Specific Features
|
||||
|
||||
```typescript
|
||||
function handleAIError(error: unknown, providerName: string) {
|
||||
if (error instanceof AIProviderError) {
|
||||
console.error(`${providerName} Error (${error.type}):`, error.message);
|
||||
|
||||
if (error.statusCode) {
|
||||
console.error('HTTP Status:', error.statusCode);
|
||||
}
|
||||
|
||||
if (error.originalError) {
|
||||
console.error('Original Error:', error.originalError.message);
|
||||
import { createGeminiProvider } from 'simple-ai-provider';
|
||||
|
||||
const gemini = createGeminiProvider('your-api-key', {
|
||||
defaultModel: 'gemini-1.5-pro',
|
||||
safetySettings: [
|
||||
{
|
||||
category: 'HARM_CATEGORY_HARASSMENT',
|
||||
threshold: 'BLOCK_MEDIUM_AND_ABOVE'
|
||||
}
|
||||
],
|
||||
generationConfig: {
|
||||
temperature: 0.9,
|
||||
topP: 0.8,
|
||||
topK: 40,
|
||||
maxOutputTokens: 2048,
|
||||
stopSequences: ['END', 'STOP']
|
||||
}
|
||||
}
|
||||
```
|
||||
});
|
||||
|
||||
## Extending the Package
|
||||
await gemini.initialize();
|
||||
|
||||
To add a new AI provider, extend the `BaseAIProvider` class:
|
||||
|
||||
```typescript
|
||||
import { BaseAIProvider } from 'simple-ai-provider';
|
||||
|
||||
class MyCustomProvider extends BaseAIProvider {
|
||||
protected async doInitialize(): Promise<void> {
|
||||
// Initialize your provider
|
||||
}
|
||||
|
||||
protected async doComplete(params: CompletionParams): Promise<CompletionResponse> {
|
||||
// Implement completion logic
|
||||
}
|
||||
|
||||
protected async *doStream(params: CompletionParams): AsyncIterable<CompletionChunk> {
|
||||
// Implement streaming logic
|
||||
}
|
||||
|
||||
public getInfo(): ProviderInfo {
|
||||
return {
|
||||
name: 'MyCustomProvider',
|
||||
version: '1.0.0',
|
||||
models: ['my-model'],
|
||||
maxContextLength: 4096,
|
||||
supportsStreaming: true
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// Register your provider
|
||||
ProviderRegistry.register('mycustom', MyCustomProvider);
|
||||
// Large context example (up to 1M tokens)
|
||||
const response = await gemini.complete({
|
||||
messages: [
|
||||
{ role: 'system', content: 'You are analyzing a large document.' },
|
||||
{ role: 'user', content: 'Your very large text here...' }
|
||||
],
|
||||
maxTokens: 2048
|
||||
});
|
||||
```
|
||||
|
||||
## Contributing
|
||||
@ -422,8 +476,11 @@ MIT
|
||||
- Initial release
|
||||
- Claude provider implementation
|
||||
- OpenAI provider implementation
|
||||
- Streaming support for both providers
|
||||
- Gemini provider implementation
|
||||
- Streaming support for all providers
|
||||
- Comprehensive error handling
|
||||
- TypeScript support
|
||||
- Provider registry system
|
||||
- Multi-provider examples
|
||||
- Large context support (Gemini)
|
||||
- Advanced safety controls (Gemini)
|
||||
|
3
bun.lock
3
bun.lock
@ -5,6 +5,7 @@
|
||||
"name": "simple-ai-provider",
|
||||
"dependencies": {
|
||||
"@anthropic-ai/sdk": "^0.52.0",
|
||||
"@google/generative-ai": "^0.24.1",
|
||||
"openai": "^4.103.0",
|
||||
},
|
||||
"devDependencies": {
|
||||
@ -19,6 +20,8 @@
|
||||
"packages": {
|
||||
"@anthropic-ai/sdk": ["@anthropic-ai/sdk@0.52.0", "", { "bin": { "anthropic-ai-sdk": "bin/cli" } }, "sha512-d4c+fg+xy9e46c8+YnrrgIQR45CZlAi7PwdzIfDXDM6ACxEZli1/fxhURsq30ZpMZy6LvSkr41jGq5aF5TD7rQ=="],
|
||||
|
||||
"@google/generative-ai": ["@google/generative-ai@0.24.1", "", {}, "sha512-MqO+MLfM6kjxcKoy0p1wRzG3b4ZZXtPI+z2IE26UogS2Cm/XHO+7gGRBh6gcJsOiIVoH93UwKvW4HdgiOZCy9Q=="],
|
||||
|
||||
"@types/bun": ["@types/bun@1.2.14", "", { "dependencies": { "bun-types": "1.2.14" } }, "sha512-VsFZKs8oKHzI7zwvECiAJ5oSorWndIWEVhfbYqZd4HI/45kzW7PN2Rr5biAzvGvRuNmYLSANY+H59ubHq8xw7Q=="],
|
||||
|
||||
"@types/node": ["@types/node@20.17.51", "", { "dependencies": { "undici-types": "~6.19.2" } }, "sha512-hccptBl7C8lHiKxTBsY6vYYmqpmw1E/aGR/8fmueE+B390L3pdMOpNSRvFO4ZnXzW5+p2HBXV0yNABd2vdk22Q=="],
|
||||
|
@ -1,11 +1,12 @@
|
||||
/**
|
||||
* Multi-provider example for Simple AI Provider
|
||||
* Demonstrates how to use both Claude and OpenAI providers
|
||||
* Demonstrates how to use Claude, OpenAI, and Gemini providers
|
||||
*/
|
||||
|
||||
import {
|
||||
createClaudeProvider,
|
||||
createOpenAIProvider,
|
||||
createGeminiProvider,
|
||||
createProvider,
|
||||
ProviderRegistry,
|
||||
AIProviderError,
|
||||
@ -18,6 +19,7 @@ async function multiProviderExample() {
|
||||
// Get API keys from environment
|
||||
const claudeApiKey = process.env.ANTHROPIC_API_KEY || 'your-claude-api-key';
|
||||
const openaiApiKey = process.env.OPENAI_API_KEY || 'your-openai-api-key';
|
||||
const geminiApiKey = process.env.GOOGLE_AI_API_KEY || 'your-gemini-api-key';
|
||||
|
||||
try {
|
||||
// Method 1: Using factory functions
|
||||
@ -31,6 +33,10 @@ async function multiProviderExample() {
|
||||
defaultModel: 'gpt-3.5-turbo'
|
||||
});
|
||||
|
||||
const gemini = createGeminiProvider(geminiApiKey, {
|
||||
defaultModel: 'gemini-1.5-flash'
|
||||
});
|
||||
|
||||
console.log('✓ Providers created\n');
|
||||
|
||||
// Method 2: Using generic createProvider function
|
||||
@ -46,6 +52,11 @@ async function multiProviderExample() {
|
||||
defaultModel: 'gpt-3.5-turbo'
|
||||
});
|
||||
|
||||
const gemini2 = createProvider('gemini', {
|
||||
apiKey: geminiApiKey,
|
||||
defaultModel: 'gemini-1.5-flash'
|
||||
});
|
||||
|
||||
console.log('✓ Generic providers created\n');
|
||||
|
||||
// Method 3: Using Provider Registry
|
||||
@ -63,7 +74,8 @@ async function multiProviderExample() {
|
||||
console.log('4. Initializing providers...');
|
||||
await Promise.all([
|
||||
claude.initialize(),
|
||||
openai.initialize()
|
||||
openai.initialize(),
|
||||
gemini.initialize()
|
||||
]);
|
||||
console.log('✓ All providers initialized\n');
|
||||
|
||||
@ -71,12 +83,13 @@ async function multiProviderExample() {
|
||||
console.log('5. Provider Information:');
|
||||
console.log('Claude Info:', claude.getInfo());
|
||||
console.log('OpenAI Info:', openai.getInfo());
|
||||
console.log('Gemini Info:', gemini.getInfo());
|
||||
console.log();
|
||||
|
||||
// Test the same prompt with both providers
|
||||
// Test the same prompt with all providers
|
||||
const testPrompt = 'Explain the concept of recursion in programming in one sentence.';
|
||||
|
||||
console.log('6. Testing same prompt with both providers...');
|
||||
console.log('6. Testing same prompt with all providers...');
|
||||
console.log(`Prompt: "${testPrompt}"\n`);
|
||||
|
||||
// Claude response
|
||||
@ -111,6 +124,22 @@ async function multiProviderExample() {
|
||||
console.log('Model:', openaiResponse.model);
|
||||
console.log();
|
||||
|
||||
// Gemini response
|
||||
console.log('--- Gemini Response ---');
|
||||
const geminiResponse = await gemini.complete({
|
||||
messages: [
|
||||
{ role: 'system', content: 'You are a concise programming tutor.' },
|
||||
{ role: 'user', content: testPrompt }
|
||||
],
|
||||
maxTokens: 100,
|
||||
temperature: 0.7
|
||||
});
|
||||
|
||||
console.log('Response:', geminiResponse.content);
|
||||
console.log('Usage:', geminiResponse.usage);
|
||||
console.log('Model:', geminiResponse.model);
|
||||
console.log();
|
||||
|
||||
// Streaming comparison
|
||||
console.log('7. Streaming comparison...');
|
||||
console.log('Streaming from Claude:');
|
||||
@ -139,6 +168,19 @@ async function multiProviderExample() {
|
||||
}
|
||||
}
|
||||
|
||||
console.log('Streaming from Gemini:');
|
||||
|
||||
for await (const chunk of gemini.stream({
|
||||
messages: [{ role: 'user', content: 'Count from 1 to 5.' }],
|
||||
maxTokens: 50
|
||||
})) {
|
||||
if (!chunk.isComplete) {
|
||||
process.stdout.write(chunk.content);
|
||||
} else {
|
||||
console.log('\n✓ Gemini streaming complete\n');
|
||||
}
|
||||
}
|
||||
|
||||
// Provider-specific features demo
|
||||
console.log('8. Provider-specific features...');
|
||||
|
||||
@ -158,6 +200,18 @@ async function multiProviderExample() {
|
||||
});
|
||||
console.log('✓ Created Claude provider with custom settings');
|
||||
|
||||
// Gemini with safety settings
|
||||
const geminiCustom = createGeminiProvider(geminiApiKey, {
|
||||
defaultModel: 'gemini-1.5-pro',
|
||||
safetySettings: [],
|
||||
generationConfig: {
|
||||
temperature: 0.9,
|
||||
topP: 0.8,
|
||||
topK: 40
|
||||
}
|
||||
});
|
||||
console.log('✓ Created Gemini provider with safety and generation settings');
|
||||
|
||||
console.log('\n🎉 Multi-provider example completed successfully!');
|
||||
|
||||
} catch (error) {
|
||||
@ -169,7 +223,7 @@ async function multiProviderExample() {
|
||||
switch (error.type) {
|
||||
case AIErrorType.AUTHENTICATION:
|
||||
console.error('💡 Hint: Check your API keys in environment variables');
|
||||
console.error(' Set ANTHROPIC_API_KEY and OPENAI_API_KEY');
|
||||
console.error(' Set ANTHROPIC_API_KEY, OPENAI_API_KEY, and GOOGLE_AI_API_KEY');
|
||||
break;
|
||||
case AIErrorType.RATE_LIMIT:
|
||||
console.error('💡 Hint: You are being rate limited. Wait and try again.');
|
||||
@ -192,18 +246,44 @@ async function compareProviders() {
|
||||
|
||||
const providers = [
|
||||
{ name: 'Claude', factory: () => createClaudeProvider('dummy-key') },
|
||||
{ name: 'OpenAI', factory: () => createOpenAIProvider('dummy-key') }
|
||||
{ name: 'OpenAI', factory: () => createOpenAIProvider('dummy-key') },
|
||||
{ name: 'Gemini', factory: () => createGeminiProvider('dummy-key') }
|
||||
];
|
||||
|
||||
console.log('\nProvider Capabilities:');
|
||||
console.log('| Provider | Models | Context | Streaming | Vision | Functions |');
|
||||
console.log('|----------|--------|---------|-----------|--------|-----------|');
|
||||
console.log('| Provider | Models | Context | Streaming | Vision | Functions | Multimodal |');
|
||||
console.log('|----------|--------|---------|-----------|--------|-----------|------------|');
|
||||
|
||||
for (const { name, factory } of providers) {
|
||||
const provider = factory();
|
||||
const info = provider.getInfo();
|
||||
|
||||
console.log(`| ${name.padEnd(8)} | ${info.models.length.toString().padEnd(6)} | ${info.maxContextLength.toLocaleString().padEnd(7)} | ${info.supportsStreaming ? '✓' : '✗'.padEnd(9)} | ${info.capabilities?.vision ? '✓' : '✗'.padEnd(6)} | ${info.capabilities?.functionCalling ? '✓' : '✗'.padEnd(9)} |`);
|
||||
const contextStr = info.maxContextLength >= 1000000
|
||||
? `${(info.maxContextLength / 1000000).toFixed(1)}M`
|
||||
: `${(info.maxContextLength / 1000).toFixed(0)}K`;
|
||||
|
||||
console.log(`| ${name.padEnd(8)} | ${info.models.length.toString().padEnd(6)} | ${contextStr.padEnd(7)} | ${info.supportsStreaming ? '✓' : '✗'.padEnd(9)} | ${info.capabilities?.vision ? '✓' : '✗'.padEnd(6)} | ${info.capabilities?.functionCalling ? '✓' : '✗'.padEnd(9)} | ${info.capabilities?.multimodal ? '✓' : '✗'.padEnd(10)} |`);
|
||||
}
|
||||
|
||||
console.log();
|
||||
}
|
||||
|
||||
// Feature comparison
|
||||
async function featureComparison() {
|
||||
console.log('\n=== Feature Comparison ===');
|
||||
|
||||
const features = [
|
||||
['Provider', 'Context Window', 'Streaming', 'Vision', 'Function Calling', 'System Messages', 'Special Features'],
|
||||
['Claude', '200K tokens', '✅', '✅', '✅', '✅ (separate)', 'Advanced reasoning'],
|
||||
['OpenAI', '128K tokens', '✅', '✅', '✅', '✅ (inline)', 'JSON mode, plugins'],
|
||||
['Gemini', '1M tokens', '✅', '✅', '✅', '✅ (separate)', 'Largest context, multimodal']
|
||||
];
|
||||
|
||||
for (const row of features) {
|
||||
console.log('| ' + row.map(cell => cell.padEnd(15)).join(' | ') + ' |');
|
||||
if (row[0] === 'Provider') {
|
||||
console.log('|' + ''.padEnd(row.length * 17 + row.length - 1, '-') + '|');
|
||||
}
|
||||
}
|
||||
|
||||
console.log();
|
||||
@ -213,4 +293,5 @@ async function compareProviders() {
|
||||
if (import.meta.main) {
|
||||
await multiProviderExample();
|
||||
await compareProviders();
|
||||
await featureComparison();
|
||||
}
|
@ -33,6 +33,8 @@
|
||||
"anthropic",
|
||||
"openai",
|
||||
"gpt",
|
||||
"gemini",
|
||||
"google",
|
||||
"provider",
|
||||
"typescript",
|
||||
"nodejs"
|
||||
@ -45,6 +47,7 @@
|
||||
},
|
||||
"dependencies": {
|
||||
"@anthropic-ai/sdk": "^0.52.0",
|
||||
"@google/generative-ai": "^0.24.1",
|
||||
"openai": "^4.103.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
|
@ -29,12 +29,14 @@ export { BaseAIProvider } from './providers/base.js';
|
||||
// Concrete provider implementations
|
||||
export { ClaudeProvider, type ClaudeConfig } from './providers/claude.js';
|
||||
export { OpenAIProvider, type OpenAIConfig } from './providers/openai.js';
|
||||
export { GeminiProvider, type GeminiConfig } from './providers/gemini.js';
|
||||
|
||||
// Utility functions and factory
|
||||
export {
|
||||
createProvider,
|
||||
createClaudeProvider,
|
||||
createOpenAIProvider,
|
||||
createGeminiProvider,
|
||||
ProviderRegistry,
|
||||
type ProviderType,
|
||||
type ProviderConfigMap
|
||||
@ -51,4 +53,4 @@ export const VERSION = '1.0.0';
|
||||
/**
|
||||
* List of supported providers
|
||||
*/
|
||||
export const SUPPORTED_PROVIDERS = ['claude', 'openai'] as const;
|
||||
export const SUPPORTED_PROVIDERS = ['claude', 'openai', 'gemini'] as const;
|
403
src/providers/gemini.ts
Normal file
403
src/providers/gemini.ts
Normal file
@ -0,0 +1,403 @@
|
||||
/**
|
||||
* Gemini Provider implementation using Google's Generative AI API
|
||||
* Provides integration with Gemini models through a standardized interface
|
||||
*/
|
||||
|
||||
import { GoogleGenerativeAI, GenerativeModel } from '@google/generative-ai';
|
||||
import type { Content, Part } from '@google/generative-ai';
|
||||
import type {
|
||||
AIProviderConfig,
|
||||
CompletionParams,
|
||||
CompletionResponse,
|
||||
CompletionChunk,
|
||||
ProviderInfo,
|
||||
AIMessage
|
||||
} from '../types/index.js';
|
||||
import { BaseAIProvider } from './base.js';
|
||||
import { AIProviderError, AIErrorType } from '../types/index.js';
|
||||
|
||||
/**
|
||||
* Configuration specific to Gemini provider
|
||||
*/
|
||||
export interface GeminiConfig extends AIProviderConfig {
|
||||
/** Default model to use if not specified in requests (default: gemini-1.5-flash) */
|
||||
defaultModel?: string;
|
||||
/** Safety settings for content filtering */
|
||||
safetySettings?: any[];
|
||||
/** Generation configuration */
|
||||
generationConfig?: {
|
||||
temperature?: number;
|
||||
topP?: number;
|
||||
topK?: number;
|
||||
maxOutputTokens?: number;
|
||||
stopSequences?: string[];
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Gemini provider implementation
|
||||
*/
|
||||
export class GeminiProvider extends BaseAIProvider {
|
||||
private client: GoogleGenerativeAI | null = null;
|
||||
private model: GenerativeModel | null = null;
|
||||
private readonly defaultModel: string;
|
||||
private readonly safetySettings?: any[];
|
||||
private readonly generationConfig?: any;
|
||||
|
||||
constructor(config: GeminiConfig) {
|
||||
super(config);
|
||||
this.defaultModel = config.defaultModel || 'gemini-1.5-flash';
|
||||
this.safetySettings = config.safetySettings;
|
||||
this.generationConfig = config.generationConfig;
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize the Gemini provider by setting up the Google Generative AI client
|
||||
*/
|
||||
protected async doInitialize(): Promise<void> {
|
||||
try {
|
||||
this.client = new GoogleGenerativeAI(this.config.apiKey);
|
||||
this.model = this.client.getGenerativeModel({
|
||||
model: this.defaultModel,
|
||||
safetySettings: this.safetySettings,
|
||||
generationConfig: this.generationConfig
|
||||
});
|
||||
|
||||
// Test the connection by making a simple request
|
||||
await this.validateConnection();
|
||||
} catch (error) {
|
||||
throw new AIProviderError(
|
||||
`Failed to initialize Gemini provider: ${(error as Error).message}`,
|
||||
AIErrorType.AUTHENTICATION,
|
||||
undefined,
|
||||
error as Error
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a completion using Gemini
|
||||
*/
|
||||
protected async doComplete(params: CompletionParams): Promise<CompletionResponse> {
|
||||
if (!this.client || !this.model) {
|
||||
throw new AIProviderError('Client not initialized', AIErrorType.INVALID_REQUEST);
|
||||
}
|
||||
|
||||
try {
|
||||
// Get the model for this request (might be different from default)
|
||||
const model = params.model && params.model !== this.defaultModel
|
||||
? this.client.getGenerativeModel({
|
||||
model: params.model,
|
||||
safetySettings: this.safetySettings,
|
||||
generationConfig: this.buildGenerationConfig(params)
|
||||
})
|
||||
: this.model;
|
||||
|
||||
const { systemInstruction, contents } = this.convertMessages(params.messages);
|
||||
|
||||
// Create chat session or use generateContent
|
||||
if (contents.length > 1) {
|
||||
// Multi-turn conversation - use chat session
|
||||
const chat = model.startChat({
|
||||
history: contents.slice(0, -1),
|
||||
systemInstruction,
|
||||
generationConfig: this.buildGenerationConfig(params)
|
||||
});
|
||||
|
||||
const lastMessage = contents[contents.length - 1];
|
||||
if (!lastMessage) {
|
||||
throw new AIProviderError('No valid messages provided', AIErrorType.INVALID_REQUEST);
|
||||
}
|
||||
const result = await chat.sendMessage(lastMessage.parts);
|
||||
|
||||
return this.formatCompletionResponse(result.response, params.model || this.defaultModel);
|
||||
} else {
|
||||
// Single message - use generateContent
|
||||
const result = await model.generateContent({
|
||||
contents,
|
||||
systemInstruction,
|
||||
generationConfig: this.buildGenerationConfig(params)
|
||||
});
|
||||
|
||||
return this.formatCompletionResponse(result.response, params.model || this.defaultModel);
|
||||
}
|
||||
} catch (error) {
|
||||
throw this.handleGeminiError(error as Error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a streaming completion using Gemini
|
||||
*/
|
||||
protected async *doStream(params: CompletionParams): AsyncIterable<CompletionChunk> {
|
||||
if (!this.client || !this.model) {
|
||||
throw new AIProviderError('Client not initialized', AIErrorType.INVALID_REQUEST);
|
||||
}
|
||||
|
||||
try {
|
||||
// Get the model for this request
|
||||
const model = params.model && params.model !== this.defaultModel
|
||||
? this.client.getGenerativeModel({
|
||||
model: params.model,
|
||||
safetySettings: this.safetySettings,
|
||||
generationConfig: this.buildGenerationConfig(params)
|
||||
})
|
||||
: this.model;
|
||||
|
||||
const { systemInstruction, contents } = this.convertMessages(params.messages);
|
||||
|
||||
let stream;
|
||||
if (contents.length > 1) {
|
||||
// Multi-turn conversation
|
||||
const chat = model.startChat({
|
||||
history: contents.slice(0, -1),
|
||||
systemInstruction,
|
||||
generationConfig: this.buildGenerationConfig(params)
|
||||
});
|
||||
|
||||
const lastMessage = contents[contents.length - 1];
|
||||
if (!lastMessage) {
|
||||
throw new AIProviderError('No valid messages provided', AIErrorType.INVALID_REQUEST);
|
||||
}
|
||||
stream = await chat.sendMessageStream(lastMessage.parts);
|
||||
} else {
|
||||
// Single message
|
||||
stream = await model.generateContentStream({
|
||||
contents,
|
||||
systemInstruction,
|
||||
generationConfig: this.buildGenerationConfig(params)
|
||||
});
|
||||
}
|
||||
|
||||
let fullText = '';
|
||||
const requestId = `gemini-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`;
|
||||
|
||||
for await (const chunk of stream.stream) {
|
||||
const chunkText = chunk.text();
|
||||
fullText += chunkText;
|
||||
|
||||
yield {
|
||||
content: chunkText,
|
||||
isComplete: false,
|
||||
id: requestId
|
||||
};
|
||||
}
|
||||
|
||||
// Final chunk with usage info
|
||||
const finalResponse = await stream.response;
|
||||
const usageMetadata = finalResponse.usageMetadata;
|
||||
|
||||
yield {
|
||||
content: '',
|
||||
isComplete: true,
|
||||
id: requestId,
|
||||
usage: {
|
||||
promptTokens: usageMetadata?.promptTokenCount || 0,
|
||||
completionTokens: usageMetadata?.candidatesTokenCount || 0,
|
||||
totalTokens: usageMetadata?.totalTokenCount || 0
|
||||
}
|
||||
};
|
||||
} catch (error) {
|
||||
throw this.handleGeminiError(error as Error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get information about the Gemini provider
|
||||
*/
|
||||
public getInfo(): ProviderInfo {
|
||||
return {
|
||||
name: 'Gemini',
|
||||
version: '1.0.0',
|
||||
models: [
|
||||
'gemini-1.5-flash',
|
||||
'gemini-1.5-flash-8b',
|
||||
'gemini-1.5-pro',
|
||||
'gemini-1.0-pro',
|
||||
'gemini-1.0-pro-vision'
|
||||
],
|
||||
maxContextLength: 1000000, // Gemini 1.5 context length
|
||||
supportsStreaming: true,
|
||||
capabilities: {
|
||||
vision: true,
|
||||
functionCalling: true,
|
||||
systemMessages: true,
|
||||
multimodal: true,
|
||||
largeContext: true
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate the connection by making a simple request
|
||||
*/
|
||||
private async validateConnection(): Promise<void> {
|
||||
if (!this.model) {
|
||||
throw new Error('Model not initialized');
|
||||
}
|
||||
|
||||
try {
|
||||
// Make a minimal request to validate credentials
|
||||
await this.model.generateContent('Hi');
|
||||
} catch (error: any) {
|
||||
if (error.message?.includes('API key') || error.message?.includes('authentication')) {
|
||||
throw new AIProviderError(
|
||||
'Invalid API key. Please check your Google AI API key.',
|
||||
AIErrorType.AUTHENTICATION
|
||||
);
|
||||
}
|
||||
// For other errors during validation, we'll let initialization proceed
|
||||
// as they might be temporary issues
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert our generic message format to Gemini's format
|
||||
* Gemini uses Contents with Parts and supports system instructions separately
|
||||
*/
|
||||
private convertMessages(messages: AIMessage[]): { systemInstruction?: string; contents: Content[] } {
|
||||
let systemInstruction: string | undefined;
|
||||
const contents: Content[] = [];
|
||||
|
||||
for (const message of messages) {
|
||||
if (message.role === 'system') {
|
||||
// Combine multiple system messages
|
||||
if (systemInstruction) {
|
||||
systemInstruction += '\n\n' + message.content;
|
||||
} else {
|
||||
systemInstruction = message.content;
|
||||
}
|
||||
} else {
|
||||
contents.push({
|
||||
role: message.role === 'assistant' ? 'model' : 'user',
|
||||
parts: [{ text: message.content }]
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return { systemInstruction, contents };
|
||||
}
|
||||
|
||||
/**
|
||||
* Build generation config from completion parameters
|
||||
*/
|
||||
private buildGenerationConfig(params: CompletionParams) {
|
||||
return {
|
||||
temperature: params.temperature ?? 0.7,
|
||||
topP: params.topP,
|
||||
maxOutputTokens: params.maxTokens || 1000,
|
||||
stopSequences: params.stopSequences,
|
||||
...this.generationConfig
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Format Gemini's response to our standard format
|
||||
*/
|
||||
private formatCompletionResponse(response: any, model: string): CompletionResponse {
|
||||
const candidate = response.candidates?.[0];
|
||||
if (!candidate || !candidate.content?.parts?.[0]?.text) {
|
||||
throw new AIProviderError(
|
||||
'No content in Gemini response',
|
||||
AIErrorType.UNKNOWN
|
||||
);
|
||||
}
|
||||
|
||||
const content = candidate.content.parts
|
||||
.filter((part: Part) => part.text)
|
||||
.map((part: Part) => part.text)
|
||||
.join('');
|
||||
|
||||
const usageMetadata = response.usageMetadata;
|
||||
const requestId = `gemini-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`;
|
||||
|
||||
return {
|
||||
content,
|
||||
model,
|
||||
usage: {
|
||||
promptTokens: usageMetadata?.promptTokenCount || 0,
|
||||
completionTokens: usageMetadata?.candidatesTokenCount || 0,
|
||||
totalTokens: usageMetadata?.totalTokenCount || 0
|
||||
},
|
||||
id: requestId,
|
||||
metadata: {
|
||||
finishReason: candidate.finishReason,
|
||||
safetyRatings: candidate.safetyRatings,
|
||||
citationMetadata: candidate.citationMetadata
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle Gemini-specific errors and convert them to our standard format
|
||||
*/
|
||||
private handleGeminiError(error: any): AIProviderError {
|
||||
if (error instanceof AIProviderError) {
|
||||
return error;
|
||||
}
|
||||
|
||||
const message = error.message || 'Unknown Gemini API error';
|
||||
|
||||
// Handle common Gemini error patterns
|
||||
if (message.includes('API key')) {
|
||||
return new AIProviderError(
|
||||
'Authentication failed. Please check your Google AI API key.',
|
||||
AIErrorType.AUTHENTICATION,
|
||||
undefined,
|
||||
error
|
||||
);
|
||||
}
|
||||
|
||||
if (message.includes('quota') || message.includes('rate limit')) {
|
||||
return new AIProviderError(
|
||||
'Rate limit exceeded. Please slow down your requests.',
|
||||
AIErrorType.RATE_LIMIT,
|
||||
undefined,
|
||||
error
|
||||
);
|
||||
}
|
||||
|
||||
if (message.includes('model') && message.includes('not found')) {
|
||||
return new AIProviderError(
|
||||
'Model not found. Please check the model name.',
|
||||
AIErrorType.MODEL_NOT_FOUND,
|
||||
undefined,
|
||||
error
|
||||
);
|
||||
}
|
||||
|
||||
if (message.includes('invalid') || message.includes('bad request')) {
|
||||
return new AIProviderError(
|
||||
`Invalid request: ${message}`,
|
||||
AIErrorType.INVALID_REQUEST,
|
||||
undefined,
|
||||
error
|
||||
);
|
||||
}
|
||||
|
||||
if (message.includes('network') || message.includes('connection')) {
|
||||
return new AIProviderError(
|
||||
'Network error occurred. Please check your connection.',
|
||||
AIErrorType.NETWORK,
|
||||
undefined,
|
||||
error
|
||||
);
|
||||
}
|
||||
|
||||
if (message.includes('timeout')) {
|
||||
return new AIProviderError(
|
||||
'Request timed out. Please try again.',
|
||||
AIErrorType.TIMEOUT,
|
||||
undefined,
|
||||
error
|
||||
);
|
||||
}
|
||||
|
||||
return new AIProviderError(
|
||||
`Gemini API error: ${message}`,
|
||||
AIErrorType.UNKNOWN,
|
||||
undefined,
|
||||
error
|
||||
);
|
||||
}
|
||||
}
|
@ -5,4 +5,5 @@
|
||||
|
||||
export { BaseAIProvider } from './base.js';
|
||||
export { ClaudeProvider, type ClaudeConfig } from './claude.js';
|
||||
export { OpenAIProvider, type OpenAIConfig } from './openai.js';
|
||||
export { OpenAIProvider, type OpenAIConfig } from './openai.js';
|
||||
export { GeminiProvider, type GeminiConfig } from './gemini.js';
|
@ -6,12 +6,13 @@
|
||||
import type { AIProviderConfig } from '../types/index.js';
|
||||
import { ClaudeProvider, type ClaudeConfig } from '../providers/claude.js';
|
||||
import { OpenAIProvider, type OpenAIConfig } from '../providers/openai.js';
|
||||
import { GeminiProvider, type GeminiConfig } from '../providers/gemini.js';
|
||||
import { BaseAIProvider } from '../providers/base.js';
|
||||
|
||||
/**
|
||||
* Supported AI provider types
|
||||
*/
|
||||
export type ProviderType = 'claude' | 'openai';
|
||||
export type ProviderType = 'claude' | 'openai' | 'gemini';
|
||||
|
||||
/**
|
||||
* Configuration map for different provider types
|
||||
@ -19,6 +20,7 @@ export type ProviderType = 'claude' | 'openai';
|
||||
export interface ProviderConfigMap {
|
||||
claude: ClaudeConfig;
|
||||
openai: OpenAIConfig;
|
||||
gemini: GeminiConfig;
|
||||
}
|
||||
|
||||
/**
|
||||
@ -36,6 +38,8 @@ export function createProvider<T extends ProviderType>(
|
||||
return new ClaudeProvider(config as ClaudeConfig);
|
||||
case 'openai':
|
||||
return new OpenAIProvider(config as OpenAIConfig);
|
||||
case 'gemini':
|
||||
return new GeminiProvider(config as GeminiConfig);
|
||||
default:
|
||||
throw new Error(`Unsupported provider type: ${type}`);
|
||||
}
|
||||
@ -73,6 +77,22 @@ export function createOpenAIProvider(
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a Gemini provider with simplified configuration
|
||||
* @param apiKey - Google AI API key
|
||||
* @param options - Optional additional configuration
|
||||
* @returns Configured Gemini provider instance
|
||||
*/
|
||||
export function createGeminiProvider(
|
||||
apiKey: string,
|
||||
options: Partial<Omit<GeminiConfig, 'apiKey'>> = {}
|
||||
): GeminiProvider {
|
||||
return new GeminiProvider({
|
||||
apiKey,
|
||||
...options
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Provider registry for dynamic provider creation
|
||||
*/
|
||||
@ -122,4 +142,5 @@ export class ProviderRegistry {
|
||||
|
||||
// Pre-register built-in providers
|
||||
ProviderRegistry.register('claude', ClaudeProvider);
|
||||
ProviderRegistry.register('openai', OpenAIProvider);
|
||||
ProviderRegistry.register('openai', OpenAIProvider);
|
||||
ProviderRegistry.register('gemini', GeminiProvider);
|
357
tests/gemini.test.ts
Normal file
357
tests/gemini.test.ts
Normal file
@ -0,0 +1,357 @@
|
||||
/**
|
||||
* Tests for Gemini Provider
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach } from 'bun:test';
|
||||
import { GeminiProvider, AIProviderError, AIErrorType } from '../src/index.js';
|
||||
|
||||
describe('GeminiProvider', () => {
|
||||
let provider: GeminiProvider;
|
||||
|
||||
beforeEach(() => {
|
||||
provider = new GeminiProvider({
|
||||
apiKey: 'test-api-key',
|
||||
defaultModel: 'gemini-1.5-flash'
|
||||
});
|
||||
});
|
||||
|
||||
describe('constructor', () => {
|
||||
it('should create provider with valid config', () => {
|
||||
expect(provider).toBeInstanceOf(GeminiProvider);
|
||||
expect(provider.isInitialized()).toBe(false);
|
||||
});
|
||||
|
||||
it('should throw error for missing API key', () => {
|
||||
expect(() => {
|
||||
new GeminiProvider({ apiKey: '' });
|
||||
}).toThrow(AIProviderError);
|
||||
});
|
||||
|
||||
it('should set default model', () => {
|
||||
const customProvider = new GeminiProvider({
|
||||
apiKey: 'test-key',
|
||||
defaultModel: 'gemini-1.5-pro'
|
||||
});
|
||||
expect(customProvider).toBeInstanceOf(GeminiProvider);
|
||||
});
|
||||
|
||||
it('should handle safety settings and generation config', () => {
|
||||
const customProvider = new GeminiProvider({
|
||||
apiKey: 'test-key',
|
||||
safetySettings: [],
|
||||
generationConfig: {
|
||||
temperature: 0.8,
|
||||
topP: 0.9,
|
||||
topK: 40
|
||||
}
|
||||
});
|
||||
expect(customProvider).toBeInstanceOf(GeminiProvider);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getInfo', () => {
|
||||
it('should return provider information', () => {
|
||||
const info = provider.getInfo();
|
||||
|
||||
expect(info.name).toBe('Gemini');
|
||||
expect(info.version).toBe('1.0.0');
|
||||
expect(info.supportsStreaming).toBe(true);
|
||||
expect(info.models).toContain('gemini-1.5-flash');
|
||||
expect(info.models).toContain('gemini-1.5-pro');
|
||||
expect(info.models).toContain('gemini-1.0-pro');
|
||||
expect(info.maxContextLength).toBe(1000000);
|
||||
expect(info.capabilities).toHaveProperty('vision', true);
|
||||
expect(info.capabilities).toHaveProperty('functionCalling', true);
|
||||
expect(info.capabilities).toHaveProperty('systemMessages', true);
|
||||
expect(info.capabilities).toHaveProperty('multimodal', true);
|
||||
expect(info.capabilities).toHaveProperty('largeContext', true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('validation', () => {
|
||||
it('should validate temperature range', async () => {
|
||||
// Mock initialization to avoid API call
|
||||
(provider as any).initialized = true;
|
||||
(provider as any).client = {};
|
||||
(provider as any).model = {};
|
||||
|
||||
await expect(
|
||||
provider.complete({
|
||||
messages: [{ role: 'user', content: 'test' }],
|
||||
temperature: 1.5
|
||||
})
|
||||
).rejects.toThrow('Temperature must be between 0.0 and 1.0');
|
||||
});
|
||||
|
||||
it('should validate top_p range', async () => {
|
||||
(provider as any).initialized = true;
|
||||
(provider as any).client = {};
|
||||
(provider as any).model = {};
|
||||
|
||||
await expect(
|
||||
provider.complete({
|
||||
messages: [{ role: 'user', content: 'test' }],
|
||||
topP: 1.5
|
||||
})
|
||||
).rejects.toThrow('Top-p must be between 0.0 and 1.0');
|
||||
});
|
||||
|
||||
it('should validate message format', async () => {
|
||||
(provider as any).initialized = true;
|
||||
(provider as any).client = {};
|
||||
(provider as any).model = {};
|
||||
|
||||
await expect(
|
||||
provider.complete({
|
||||
messages: [{ role: 'invalid' as any, content: 'test' }]
|
||||
})
|
||||
).rejects.toThrow('Each message must have a valid role');
|
||||
});
|
||||
|
||||
it('should validate empty content', async () => {
|
||||
(provider as any).initialized = true;
|
||||
(provider as any).client = {};
|
||||
(provider as any).model = {};
|
||||
|
||||
await expect(
|
||||
provider.complete({
|
||||
messages: [{ role: 'user', content: '' }]
|
||||
})
|
||||
).rejects.toThrow('Each message must have non-empty string content');
|
||||
});
|
||||
|
||||
it('should require initialization before use', async () => {
|
||||
await expect(
|
||||
provider.complete({
|
||||
messages: [{ role: 'user', content: 'test' }]
|
||||
})
|
||||
).rejects.toThrow('Provider must be initialized before use');
|
||||
});
|
||||
});
|
||||
|
||||
describe('error handling', () => {
|
||||
it('should handle authentication errors', () => {
|
||||
const error = new Error('API key invalid');
|
||||
|
||||
const providerError = (provider as any).handleGeminiError(error);
|
||||
|
||||
expect(providerError).toBeInstanceOf(AIProviderError);
|
||||
expect(providerError.type).toBe(AIErrorType.AUTHENTICATION);
|
||||
expect(providerError.message).toContain('Authentication failed');
|
||||
});
|
||||
|
||||
it('should handle rate limit errors', () => {
|
||||
const error = new Error('quota exceeded');
|
||||
|
||||
const providerError = (provider as any).handleGeminiError(error);
|
||||
|
||||
expect(providerError).toBeInstanceOf(AIProviderError);
|
||||
expect(providerError.type).toBe(AIErrorType.RATE_LIMIT);
|
||||
expect(providerError.message).toContain('Rate limit exceeded');
|
||||
});
|
||||
|
||||
it('should handle model not found errors', () => {
|
||||
const error = new Error('model not found');
|
||||
|
||||
const providerError = (provider as any).handleGeminiError(error);
|
||||
|
||||
expect(providerError).toBeInstanceOf(AIProviderError);
|
||||
expect(providerError.type).toBe(AIErrorType.MODEL_NOT_FOUND);
|
||||
expect(providerError.message).toContain('Model not found');
|
||||
});
|
||||
|
||||
it('should handle invalid request errors', () => {
|
||||
const error = new Error('invalid request parameters');
|
||||
|
||||
const providerError = (provider as any).handleGeminiError(error);
|
||||
|
||||
expect(providerError).toBeInstanceOf(AIProviderError);
|
||||
expect(providerError.type).toBe(AIErrorType.INVALID_REQUEST);
|
||||
});
|
||||
|
||||
it('should handle network errors', () => {
|
||||
const error = new Error('network connection failed');
|
||||
|
||||
const providerError = (provider as any).handleGeminiError(error);
|
||||
|
||||
expect(providerError).toBeInstanceOf(AIProviderError);
|
||||
expect(providerError.type).toBe(AIErrorType.NETWORK);
|
||||
});
|
||||
|
||||
it('should handle timeout errors', () => {
|
||||
const error = new Error('request timeout');
|
||||
|
||||
const providerError = (provider as any).handleGeminiError(error);
|
||||
|
||||
expect(providerError).toBeInstanceOf(AIProviderError);
|
||||
expect(providerError.type).toBe(AIErrorType.TIMEOUT);
|
||||
});
|
||||
|
||||
it('should handle unknown errors', () => {
|
||||
const error = new Error('Unknown error');
|
||||
|
||||
const providerError = (provider as any).handleGeminiError(error);
|
||||
|
||||
expect(providerError).toBeInstanceOf(AIProviderError);
|
||||
expect(providerError.type).toBe(AIErrorType.UNKNOWN);
|
||||
});
|
||||
});
|
||||
|
||||
describe('message conversion', () => {
|
||||
it('should convert messages to Gemini format', () => {
|
||||
const messages = [
|
||||
{ role: 'system' as const, content: 'You are helpful' },
|
||||
{ role: 'user' as const, content: 'Hello' },
|
||||
{ role: 'assistant' as const, content: 'Hi there' }
|
||||
];
|
||||
|
||||
const result = (provider as any).convertMessages(messages);
|
||||
|
||||
expect(result.systemInstruction).toBe('You are helpful');
|
||||
expect(result.contents).toHaveLength(2);
|
||||
expect(result.contents[0]).toEqual({
|
||||
role: 'user',
|
||||
parts: [{ text: 'Hello' }]
|
||||
});
|
||||
expect(result.contents[1]).toEqual({
|
||||
role: 'model',
|
||||
parts: [{ text: 'Hi there' }]
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle multiple system messages', () => {
|
||||
const messages = [
|
||||
{ role: 'system' as const, content: 'You are helpful' },
|
||||
{ role: 'user' as const, content: 'Hello' },
|
||||
{ role: 'system' as const, content: 'Be concise' }
|
||||
];
|
||||
|
||||
const result = (provider as any).convertMessages(messages);
|
||||
|
||||
expect(result.systemInstruction).toBe('You are helpful\n\nBe concise');
|
||||
expect(result.contents).toHaveLength(1);
|
||||
expect(result.contents[0].role).toBe('user');
|
||||
});
|
||||
|
||||
it('should handle messages without system prompts', () => {
|
||||
const messages = [
|
||||
{ role: 'user' as const, content: 'Hello' },
|
||||
{ role: 'assistant' as const, content: 'Hi there' }
|
||||
];
|
||||
|
||||
const result = (provider as any).convertMessages(messages);
|
||||
|
||||
expect(result.systemInstruction).toBeUndefined();
|
||||
expect(result.contents).toHaveLength(2);
|
||||
});
|
||||
|
||||
it('should convert assistant role to model role', () => {
|
||||
const messages = [
|
||||
{ role: 'assistant' as const, content: 'I am an assistant' }
|
||||
];
|
||||
|
||||
const result = (provider as any).convertMessages(messages);
|
||||
|
||||
expect(result.contents[0].role).toBe('model');
|
||||
expect(result.contents[0].parts[0].text).toBe('I am an assistant');
|
||||
});
|
||||
});
|
||||
|
||||
describe('generation config', () => {
|
||||
it('should build generation config from completion params', () => {
|
||||
const params = {
|
||||
messages: [{ role: 'user' as const, content: 'test' }],
|
||||
temperature: 0.8,
|
||||
topP: 0.9,
|
||||
maxTokens: 500,
|
||||
stopSequences: ['STOP', 'END']
|
||||
};
|
||||
|
||||
const result = (provider as any).buildGenerationConfig(params);
|
||||
|
||||
expect(result.temperature).toBe(0.8);
|
||||
expect(result.topP).toBe(0.9);
|
||||
expect(result.maxOutputTokens).toBe(500);
|
||||
expect(result.stopSequences).toEqual(['STOP', 'END']);
|
||||
});
|
||||
|
||||
it('should use default temperature when not provided', () => {
|
||||
const params = {
|
||||
messages: [{ role: 'user' as const, content: 'test' }]
|
||||
};
|
||||
|
||||
const result = (provider as any).buildGenerationConfig(params);
|
||||
|
||||
expect(result.temperature).toBe(0.7);
|
||||
expect(result.maxOutputTokens).toBe(1000);
|
||||
});
|
||||
});
|
||||
|
||||
describe('response formatting', () => {
|
||||
it('should format completion response correctly', () => {
|
||||
const mockResponse = {
|
||||
candidates: [{
|
||||
content: {
|
||||
parts: [{ text: 'Hello there!' }]
|
||||
},
|
||||
finishReason: 'STOP',
|
||||
safetyRatings: [],
|
||||
citationMetadata: null
|
||||
}],
|
||||
usageMetadata: {
|
||||
promptTokenCount: 10,
|
||||
candidatesTokenCount: 20,
|
||||
totalTokenCount: 30
|
||||
}
|
||||
};
|
||||
|
||||
const result = (provider as any).formatCompletionResponse(mockResponse, 'gemini-1.5-flash');
|
||||
|
||||
expect(result.content).toBe('Hello there!');
|
||||
expect(result.model).toBe('gemini-1.5-flash');
|
||||
expect(result.usage.promptTokens).toBe(10);
|
||||
expect(result.usage.completionTokens).toBe(20);
|
||||
expect(result.usage.totalTokens).toBe(30);
|
||||
expect(result.metadata.finishReason).toBe('STOP');
|
||||
});
|
||||
|
||||
it('should handle multiple text parts', () => {
|
||||
const mockResponse = {
|
||||
candidates: [{
|
||||
content: {
|
||||
parts: [
|
||||
{ text: 'Hello ' },
|
||||
{ text: 'there!' },
|
||||
{ functionCall: { name: 'test' } } // Non-text part should be filtered
|
||||
]
|
||||
},
|
||||
finishReason: 'STOP'
|
||||
}],
|
||||
usageMetadata: {
|
||||
promptTokenCount: 5,
|
||||
candidatesTokenCount: 10,
|
||||
totalTokenCount: 15
|
||||
}
|
||||
};
|
||||
|
||||
const result = (provider as any).formatCompletionResponse(mockResponse, 'gemini-1.5-flash');
|
||||
|
||||
expect(result.content).toBe('Hello there!');
|
||||
});
|
||||
|
||||
it('should throw error for empty response', () => {
|
||||
const mockResponse = {
|
||||
candidates: [],
|
||||
usageMetadata: {
|
||||
promptTokenCount: 5,
|
||||
candidatesTokenCount: 0,
|
||||
totalTokenCount: 5
|
||||
}
|
||||
};
|
||||
|
||||
expect(() => {
|
||||
(provider as any).formatCompletionResponse(mockResponse, 'gemini-1.5-flash');
|
||||
}).toThrow('No content in Gemini response');
|
||||
});
|
||||
});
|
||||
});
|
Reference in New Issue
Block a user