feat: add Google Gemini provider integration and docs

This commit is contained in:
2025-05-28 12:13:29 +02:00
parent aa2fd98cc1
commit 5da37f388f
9 changed files with 1015 additions and 87 deletions

View File

@@ -29,12 +29,14 @@ export { BaseAIProvider } from './providers/base.js';
// Concrete provider implementations
export { ClaudeProvider, type ClaudeConfig } from './providers/claude.js';
export { OpenAIProvider, type OpenAIConfig } from './providers/openai.js';
export { GeminiProvider, type GeminiConfig } from './providers/gemini.js';
// Utility functions and factory
export {
createProvider,
createClaudeProvider,
createOpenAIProvider,
createGeminiProvider,
ProviderRegistry,
type ProviderType,
type ProviderConfigMap
@@ -51,4 +53,4 @@ export const VERSION = '1.0.0';
/**
* List of supported providers
*/
export const SUPPORTED_PROVIDERS = ['claude', 'openai'] as const;
export const SUPPORTED_PROVIDERS = ['claude', 'openai', 'gemini'] as const;

403
src/providers/gemini.ts Normal file
View File

@@ -0,0 +1,403 @@
/**
* Gemini Provider implementation using Google's Generative AI API
* Provides integration with Gemini models through a standardized interface
*/
import { GoogleGenerativeAI, GenerativeModel } from '@google/generative-ai';
import type { Content, Part } from '@google/generative-ai';
import type {
AIProviderConfig,
CompletionParams,
CompletionResponse,
CompletionChunk,
ProviderInfo,
AIMessage
} from '../types/index.js';
import { BaseAIProvider } from './base.js';
import { AIProviderError, AIErrorType } from '../types/index.js';
/**
* Configuration specific to Gemini provider
*/
export interface GeminiConfig extends AIProviderConfig {
/** Default model to use if not specified in requests (default: gemini-1.5-flash) */
defaultModel?: string;
/** Safety settings for content filtering */
safetySettings?: any[];
/** Generation configuration */
generationConfig?: {
temperature?: number;
topP?: number;
topK?: number;
maxOutputTokens?: number;
stopSequences?: string[];
};
}
/**
* Gemini provider implementation
*/
export class GeminiProvider extends BaseAIProvider {
private client: GoogleGenerativeAI | null = null;
private model: GenerativeModel | null = null;
private readonly defaultModel: string;
private readonly safetySettings?: any[];
private readonly generationConfig?: any;
constructor(config: GeminiConfig) {
super(config);
this.defaultModel = config.defaultModel || 'gemini-1.5-flash';
this.safetySettings = config.safetySettings;
this.generationConfig = config.generationConfig;
}
/**
* Initialize the Gemini provider by setting up the Google Generative AI client
*/
protected async doInitialize(): Promise<void> {
try {
this.client = new GoogleGenerativeAI(this.config.apiKey);
this.model = this.client.getGenerativeModel({
model: this.defaultModel,
safetySettings: this.safetySettings,
generationConfig: this.generationConfig
});
// Test the connection by making a simple request
await this.validateConnection();
} catch (error) {
throw new AIProviderError(
`Failed to initialize Gemini provider: ${(error as Error).message}`,
AIErrorType.AUTHENTICATION,
undefined,
error as Error
);
}
}
/**
* Generate a completion using Gemini
*/
protected async doComplete(params: CompletionParams): Promise<CompletionResponse> {
if (!this.client || !this.model) {
throw new AIProviderError('Client not initialized', AIErrorType.INVALID_REQUEST);
}
try {
// Get the model for this request (might be different from default)
const model = params.model && params.model !== this.defaultModel
? this.client.getGenerativeModel({
model: params.model,
safetySettings: this.safetySettings,
generationConfig: this.buildGenerationConfig(params)
})
: this.model;
const { systemInstruction, contents } = this.convertMessages(params.messages);
// Create chat session or use generateContent
if (contents.length > 1) {
// Multi-turn conversation - use chat session
const chat = model.startChat({
history: contents.slice(0, -1),
systemInstruction,
generationConfig: this.buildGenerationConfig(params)
});
const lastMessage = contents[contents.length - 1];
if (!lastMessage) {
throw new AIProviderError('No valid messages provided', AIErrorType.INVALID_REQUEST);
}
const result = await chat.sendMessage(lastMessage.parts);
return this.formatCompletionResponse(result.response, params.model || this.defaultModel);
} else {
// Single message - use generateContent
const result = await model.generateContent({
contents,
systemInstruction,
generationConfig: this.buildGenerationConfig(params)
});
return this.formatCompletionResponse(result.response, params.model || this.defaultModel);
}
} catch (error) {
throw this.handleGeminiError(error as Error);
}
}
/**
* Generate a streaming completion using Gemini
*/
protected async *doStream(params: CompletionParams): AsyncIterable<CompletionChunk> {
if (!this.client || !this.model) {
throw new AIProviderError('Client not initialized', AIErrorType.INVALID_REQUEST);
}
try {
// Get the model for this request
const model = params.model && params.model !== this.defaultModel
? this.client.getGenerativeModel({
model: params.model,
safetySettings: this.safetySettings,
generationConfig: this.buildGenerationConfig(params)
})
: this.model;
const { systemInstruction, contents } = this.convertMessages(params.messages);
let stream;
if (contents.length > 1) {
// Multi-turn conversation
const chat = model.startChat({
history: contents.slice(0, -1),
systemInstruction,
generationConfig: this.buildGenerationConfig(params)
});
const lastMessage = contents[contents.length - 1];
if (!lastMessage) {
throw new AIProviderError('No valid messages provided', AIErrorType.INVALID_REQUEST);
}
stream = await chat.sendMessageStream(lastMessage.parts);
} else {
// Single message
stream = await model.generateContentStream({
contents,
systemInstruction,
generationConfig: this.buildGenerationConfig(params)
});
}
let fullText = '';
const requestId = `gemini-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`;
for await (const chunk of stream.stream) {
const chunkText = chunk.text();
fullText += chunkText;
yield {
content: chunkText,
isComplete: false,
id: requestId
};
}
// Final chunk with usage info
const finalResponse = await stream.response;
const usageMetadata = finalResponse.usageMetadata;
yield {
content: '',
isComplete: true,
id: requestId,
usage: {
promptTokens: usageMetadata?.promptTokenCount || 0,
completionTokens: usageMetadata?.candidatesTokenCount || 0,
totalTokens: usageMetadata?.totalTokenCount || 0
}
};
} catch (error) {
throw this.handleGeminiError(error as Error);
}
}
/**
* Get information about the Gemini provider
*/
public getInfo(): ProviderInfo {
return {
name: 'Gemini',
version: '1.0.0',
models: [
'gemini-1.5-flash',
'gemini-1.5-flash-8b',
'gemini-1.5-pro',
'gemini-1.0-pro',
'gemini-1.0-pro-vision'
],
maxContextLength: 1000000, // Gemini 1.5 context length
supportsStreaming: true,
capabilities: {
vision: true,
functionCalling: true,
systemMessages: true,
multimodal: true,
largeContext: true
}
};
}
/**
* Validate the connection by making a simple request
*/
private async validateConnection(): Promise<void> {
if (!this.model) {
throw new Error('Model not initialized');
}
try {
// Make a minimal request to validate credentials
await this.model.generateContent('Hi');
} catch (error: any) {
if (error.message?.includes('API key') || error.message?.includes('authentication')) {
throw new AIProviderError(
'Invalid API key. Please check your Google AI API key.',
AIErrorType.AUTHENTICATION
);
}
// For other errors during validation, we'll let initialization proceed
// as they might be temporary issues
}
}
/**
* Convert our generic message format to Gemini's format
* Gemini uses Contents with Parts and supports system instructions separately
*/
private convertMessages(messages: AIMessage[]): { systemInstruction?: string; contents: Content[] } {
let systemInstruction: string | undefined;
const contents: Content[] = [];
for (const message of messages) {
if (message.role === 'system') {
// Combine multiple system messages
if (systemInstruction) {
systemInstruction += '\n\n' + message.content;
} else {
systemInstruction = message.content;
}
} else {
contents.push({
role: message.role === 'assistant' ? 'model' : 'user',
parts: [{ text: message.content }]
});
}
}
return { systemInstruction, contents };
}
/**
* Build generation config from completion parameters
*/
private buildGenerationConfig(params: CompletionParams) {
return {
temperature: params.temperature ?? 0.7,
topP: params.topP,
maxOutputTokens: params.maxTokens || 1000,
stopSequences: params.stopSequences,
...this.generationConfig
};
}
/**
* Format Gemini's response to our standard format
*/
private formatCompletionResponse(response: any, model: string): CompletionResponse {
const candidate = response.candidates?.[0];
if (!candidate || !candidate.content?.parts?.[0]?.text) {
throw new AIProviderError(
'No content in Gemini response',
AIErrorType.UNKNOWN
);
}
const content = candidate.content.parts
.filter((part: Part) => part.text)
.map((part: Part) => part.text)
.join('');
const usageMetadata = response.usageMetadata;
const requestId = `gemini-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`;
return {
content,
model,
usage: {
promptTokens: usageMetadata?.promptTokenCount || 0,
completionTokens: usageMetadata?.candidatesTokenCount || 0,
totalTokens: usageMetadata?.totalTokenCount || 0
},
id: requestId,
metadata: {
finishReason: candidate.finishReason,
safetyRatings: candidate.safetyRatings,
citationMetadata: candidate.citationMetadata
}
};
}
/**
* Handle Gemini-specific errors and convert them to our standard format
*/
private handleGeminiError(error: any): AIProviderError {
if (error instanceof AIProviderError) {
return error;
}
const message = error.message || 'Unknown Gemini API error';
// Handle common Gemini error patterns
if (message.includes('API key')) {
return new AIProviderError(
'Authentication failed. Please check your Google AI API key.',
AIErrorType.AUTHENTICATION,
undefined,
error
);
}
if (message.includes('quota') || message.includes('rate limit')) {
return new AIProviderError(
'Rate limit exceeded. Please slow down your requests.',
AIErrorType.RATE_LIMIT,
undefined,
error
);
}
if (message.includes('model') && message.includes('not found')) {
return new AIProviderError(
'Model not found. Please check the model name.',
AIErrorType.MODEL_NOT_FOUND,
undefined,
error
);
}
if (message.includes('invalid') || message.includes('bad request')) {
return new AIProviderError(
`Invalid request: ${message}`,
AIErrorType.INVALID_REQUEST,
undefined,
error
);
}
if (message.includes('network') || message.includes('connection')) {
return new AIProviderError(
'Network error occurred. Please check your connection.',
AIErrorType.NETWORK,
undefined,
error
);
}
if (message.includes('timeout')) {
return new AIProviderError(
'Request timed out. Please try again.',
AIErrorType.TIMEOUT,
undefined,
error
);
}
return new AIProviderError(
`Gemini API error: ${message}`,
AIErrorType.UNKNOWN,
undefined,
error
);
}
}

View File

@@ -5,4 +5,5 @@
export { BaseAIProvider } from './base.js';
export { ClaudeProvider, type ClaudeConfig } from './claude.js';
export { OpenAIProvider, type OpenAIConfig } from './openai.js';
export { OpenAIProvider, type OpenAIConfig } from './openai.js';
export { GeminiProvider, type GeminiConfig } from './gemini.js';

View File

@@ -6,12 +6,13 @@
import type { AIProviderConfig } from '../types/index.js';
import { ClaudeProvider, type ClaudeConfig } from '../providers/claude.js';
import { OpenAIProvider, type OpenAIConfig } from '../providers/openai.js';
import { GeminiProvider, type GeminiConfig } from '../providers/gemini.js';
import { BaseAIProvider } from '../providers/base.js';
/**
* Supported AI provider types
*/
export type ProviderType = 'claude' | 'openai';
export type ProviderType = 'claude' | 'openai' | 'gemini';
/**
* Configuration map for different provider types
@@ -19,6 +20,7 @@ export type ProviderType = 'claude' | 'openai';
export interface ProviderConfigMap {
claude: ClaudeConfig;
openai: OpenAIConfig;
gemini: GeminiConfig;
}
/**
@@ -36,6 +38,8 @@ export function createProvider<T extends ProviderType>(
return new ClaudeProvider(config as ClaudeConfig);
case 'openai':
return new OpenAIProvider(config as OpenAIConfig);
case 'gemini':
return new GeminiProvider(config as GeminiConfig);
default:
throw new Error(`Unsupported provider type: ${type}`);
}
@@ -73,6 +77,22 @@ export function createOpenAIProvider(
});
}
/**
* Create a Gemini provider with simplified configuration
* @param apiKey - Google AI API key
* @param options - Optional additional configuration
* @returns Configured Gemini provider instance
*/
export function createGeminiProvider(
apiKey: string,
options: Partial<Omit<GeminiConfig, 'apiKey'>> = {}
): GeminiProvider {
return new GeminiProvider({
apiKey,
...options
});
}
/**
* Provider registry for dynamic provider creation
*/
@@ -122,4 +142,5 @@ export class ProviderRegistry {
// Pre-register built-in providers
ProviderRegistry.register('claude', ClaudeProvider);
ProviderRegistry.register('openai', OpenAIProvider);
ProviderRegistry.register('openai', OpenAIProvider);
ProviderRegistry.register('gemini', GeminiProvider);