feat: add Google Gemini provider integration and docs
This commit is contained in:
@@ -29,12 +29,14 @@ export { BaseAIProvider } from './providers/base.js';
|
||||
// Concrete provider implementations
|
||||
export { ClaudeProvider, type ClaudeConfig } from './providers/claude.js';
|
||||
export { OpenAIProvider, type OpenAIConfig } from './providers/openai.js';
|
||||
export { GeminiProvider, type GeminiConfig } from './providers/gemini.js';
|
||||
|
||||
// Utility functions and factory
|
||||
export {
|
||||
createProvider,
|
||||
createClaudeProvider,
|
||||
createOpenAIProvider,
|
||||
createGeminiProvider,
|
||||
ProviderRegistry,
|
||||
type ProviderType,
|
||||
type ProviderConfigMap
|
||||
@@ -51,4 +53,4 @@ export const VERSION = '1.0.0';
|
||||
/**
|
||||
* List of supported providers
|
||||
*/
|
||||
export const SUPPORTED_PROVIDERS = ['claude', 'openai'] as const;
|
||||
export const SUPPORTED_PROVIDERS = ['claude', 'openai', 'gemini'] as const;
|
||||
403
src/providers/gemini.ts
Normal file
@@ -0,0 +1,403 @@
|
||||
/**
|
||||
* Gemini Provider implementation using Google's Generative AI API
|
||||
* Provides integration with Gemini models through a standardized interface
|
||||
*/
|
||||
|
||||
import { GoogleGenerativeAI, GenerativeModel } from '@google/generative-ai';
|
||||
import type { Content, Part } from '@google/generative-ai';
|
||||
import type {
|
||||
AIProviderConfig,
|
||||
CompletionParams,
|
||||
CompletionResponse,
|
||||
CompletionChunk,
|
||||
ProviderInfo,
|
||||
AIMessage
|
||||
} from '../types/index.js';
|
||||
import { BaseAIProvider } from './base.js';
|
||||
import { AIProviderError, AIErrorType } from '../types/index.js';
|
||||
|
||||
/**
 * Configuration specific to Gemini provider
 */
export interface GeminiConfig extends AIProviderConfig {
  /** Default model to use if not specified in requests (default: gemini-1.5-flash) */
  defaultModel?: string;
  /** Safety settings for content filtering */
  // NOTE(review): typed as any[] — presumably the SDK's SafetySetting[]; confirm and tighten.
  safetySettings?: any[];
  /** Generation configuration */
  generationConfig?: {
    /** Sampling temperature */
    temperature?: number;
    /** Nucleus-sampling probability mass cutoff */
    topP?: number;
    /** Top-k sampling cutoff */
    topK?: number;
    /** Maximum number of tokens to generate */
    maxOutputTokens?: number;
    /** Sequences that terminate generation when produced */
    stopSequences?: string[];
  };
}
|
||||
|
||||
/**
|
||||
* Gemini provider implementation
|
||||
*/
|
||||
export class GeminiProvider extends BaseAIProvider {
|
||||
private client: GoogleGenerativeAI | null = null;
|
||||
private model: GenerativeModel | null = null;
|
||||
private readonly defaultModel: string;
|
||||
private readonly safetySettings?: any[];
|
||||
private readonly generationConfig?: any;
|
||||
|
||||
constructor(config: GeminiConfig) {
|
||||
super(config);
|
||||
this.defaultModel = config.defaultModel || 'gemini-1.5-flash';
|
||||
this.safetySettings = config.safetySettings;
|
||||
this.generationConfig = config.generationConfig;
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize the Gemini provider by setting up the Google Generative AI client
|
||||
*/
|
||||
protected async doInitialize(): Promise<void> {
|
||||
try {
|
||||
this.client = new GoogleGenerativeAI(this.config.apiKey);
|
||||
this.model = this.client.getGenerativeModel({
|
||||
model: this.defaultModel,
|
||||
safetySettings: this.safetySettings,
|
||||
generationConfig: this.generationConfig
|
||||
});
|
||||
|
||||
// Test the connection by making a simple request
|
||||
await this.validateConnection();
|
||||
} catch (error) {
|
||||
throw new AIProviderError(
|
||||
`Failed to initialize Gemini provider: ${(error as Error).message}`,
|
||||
AIErrorType.AUTHENTICATION,
|
||||
undefined,
|
||||
error as Error
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a completion using Gemini
|
||||
*/
|
||||
protected async doComplete(params: CompletionParams): Promise<CompletionResponse> {
|
||||
if (!this.client || !this.model) {
|
||||
throw new AIProviderError('Client not initialized', AIErrorType.INVALID_REQUEST);
|
||||
}
|
||||
|
||||
try {
|
||||
// Get the model for this request (might be different from default)
|
||||
const model = params.model && params.model !== this.defaultModel
|
||||
? this.client.getGenerativeModel({
|
||||
model: params.model,
|
||||
safetySettings: this.safetySettings,
|
||||
generationConfig: this.buildGenerationConfig(params)
|
||||
})
|
||||
: this.model;
|
||||
|
||||
const { systemInstruction, contents } = this.convertMessages(params.messages);
|
||||
|
||||
// Create chat session or use generateContent
|
||||
if (contents.length > 1) {
|
||||
// Multi-turn conversation - use chat session
|
||||
const chat = model.startChat({
|
||||
history: contents.slice(0, -1),
|
||||
systemInstruction,
|
||||
generationConfig: this.buildGenerationConfig(params)
|
||||
});
|
||||
|
||||
const lastMessage = contents[contents.length - 1];
|
||||
if (!lastMessage) {
|
||||
throw new AIProviderError('No valid messages provided', AIErrorType.INVALID_REQUEST);
|
||||
}
|
||||
const result = await chat.sendMessage(lastMessage.parts);
|
||||
|
||||
return this.formatCompletionResponse(result.response, params.model || this.defaultModel);
|
||||
} else {
|
||||
// Single message - use generateContent
|
||||
const result = await model.generateContent({
|
||||
contents,
|
||||
systemInstruction,
|
||||
generationConfig: this.buildGenerationConfig(params)
|
||||
});
|
||||
|
||||
return this.formatCompletionResponse(result.response, params.model || this.defaultModel);
|
||||
}
|
||||
} catch (error) {
|
||||
throw this.handleGeminiError(error as Error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a streaming completion using Gemini
|
||||
*/
|
||||
protected async *doStream(params: CompletionParams): AsyncIterable<CompletionChunk> {
|
||||
if (!this.client || !this.model) {
|
||||
throw new AIProviderError('Client not initialized', AIErrorType.INVALID_REQUEST);
|
||||
}
|
||||
|
||||
try {
|
||||
// Get the model for this request
|
||||
const model = params.model && params.model !== this.defaultModel
|
||||
? this.client.getGenerativeModel({
|
||||
model: params.model,
|
||||
safetySettings: this.safetySettings,
|
||||
generationConfig: this.buildGenerationConfig(params)
|
||||
})
|
||||
: this.model;
|
||||
|
||||
const { systemInstruction, contents } = this.convertMessages(params.messages);
|
||||
|
||||
let stream;
|
||||
if (contents.length > 1) {
|
||||
// Multi-turn conversation
|
||||
const chat = model.startChat({
|
||||
history: contents.slice(0, -1),
|
||||
systemInstruction,
|
||||
generationConfig: this.buildGenerationConfig(params)
|
||||
});
|
||||
|
||||
const lastMessage = contents[contents.length - 1];
|
||||
if (!lastMessage) {
|
||||
throw new AIProviderError('No valid messages provided', AIErrorType.INVALID_REQUEST);
|
||||
}
|
||||
stream = await chat.sendMessageStream(lastMessage.parts);
|
||||
} else {
|
||||
// Single message
|
||||
stream = await model.generateContentStream({
|
||||
contents,
|
||||
systemInstruction,
|
||||
generationConfig: this.buildGenerationConfig(params)
|
||||
});
|
||||
}
|
||||
|
||||
let fullText = '';
|
||||
const requestId = `gemini-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`;
|
||||
|
||||
for await (const chunk of stream.stream) {
|
||||
const chunkText = chunk.text();
|
||||
fullText += chunkText;
|
||||
|
||||
yield {
|
||||
content: chunkText,
|
||||
isComplete: false,
|
||||
id: requestId
|
||||
};
|
||||
}
|
||||
|
||||
// Final chunk with usage info
|
||||
const finalResponse = await stream.response;
|
||||
const usageMetadata = finalResponse.usageMetadata;
|
||||
|
||||
yield {
|
||||
content: '',
|
||||
isComplete: true,
|
||||
id: requestId,
|
||||
usage: {
|
||||
promptTokens: usageMetadata?.promptTokenCount || 0,
|
||||
completionTokens: usageMetadata?.candidatesTokenCount || 0,
|
||||
totalTokens: usageMetadata?.totalTokenCount || 0
|
||||
}
|
||||
};
|
||||
} catch (error) {
|
||||
throw this.handleGeminiError(error as Error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get information about the Gemini provider
|
||||
*/
|
||||
public getInfo(): ProviderInfo {
|
||||
return {
|
||||
name: 'Gemini',
|
||||
version: '1.0.0',
|
||||
models: [
|
||||
'gemini-1.5-flash',
|
||||
'gemini-1.5-flash-8b',
|
||||
'gemini-1.5-pro',
|
||||
'gemini-1.0-pro',
|
||||
'gemini-1.0-pro-vision'
|
||||
],
|
||||
maxContextLength: 1000000, // Gemini 1.5 context length
|
||||
supportsStreaming: true,
|
||||
capabilities: {
|
||||
vision: true,
|
||||
functionCalling: true,
|
||||
systemMessages: true,
|
||||
multimodal: true,
|
||||
largeContext: true
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate the connection by making a simple request
|
||||
*/
|
||||
private async validateConnection(): Promise<void> {
|
||||
if (!this.model) {
|
||||
throw new Error('Model not initialized');
|
||||
}
|
||||
|
||||
try {
|
||||
// Make a minimal request to validate credentials
|
||||
await this.model.generateContent('Hi');
|
||||
} catch (error: any) {
|
||||
if (error.message?.includes('API key') || error.message?.includes('authentication')) {
|
||||
throw new AIProviderError(
|
||||
'Invalid API key. Please check your Google AI API key.',
|
||||
AIErrorType.AUTHENTICATION
|
||||
);
|
||||
}
|
||||
// For other errors during validation, we'll let initialization proceed
|
||||
// as they might be temporary issues
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert our generic message format to Gemini's format
|
||||
* Gemini uses Contents with Parts and supports system instructions separately
|
||||
*/
|
||||
private convertMessages(messages: AIMessage[]): { systemInstruction?: string; contents: Content[] } {
|
||||
let systemInstruction: string | undefined;
|
||||
const contents: Content[] = [];
|
||||
|
||||
for (const message of messages) {
|
||||
if (message.role === 'system') {
|
||||
// Combine multiple system messages
|
||||
if (systemInstruction) {
|
||||
systemInstruction += '\n\n' + message.content;
|
||||
} else {
|
||||
systemInstruction = message.content;
|
||||
}
|
||||
} else {
|
||||
contents.push({
|
||||
role: message.role === 'assistant' ? 'model' : 'user',
|
||||
parts: [{ text: message.content }]
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return { systemInstruction, contents };
|
||||
}
|
||||
|
||||
/**
|
||||
* Build generation config from completion parameters
|
||||
*/
|
||||
private buildGenerationConfig(params: CompletionParams) {
|
||||
return {
|
||||
temperature: params.temperature ?? 0.7,
|
||||
topP: params.topP,
|
||||
maxOutputTokens: params.maxTokens || 1000,
|
||||
stopSequences: params.stopSequences,
|
||||
...this.generationConfig
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Format Gemini's response to our standard format
|
||||
*/
|
||||
private formatCompletionResponse(response: any, model: string): CompletionResponse {
|
||||
const candidate = response.candidates?.[0];
|
||||
if (!candidate || !candidate.content?.parts?.[0]?.text) {
|
||||
throw new AIProviderError(
|
||||
'No content in Gemini response',
|
||||
AIErrorType.UNKNOWN
|
||||
);
|
||||
}
|
||||
|
||||
const content = candidate.content.parts
|
||||
.filter((part: Part) => part.text)
|
||||
.map((part: Part) => part.text)
|
||||
.join('');
|
||||
|
||||
const usageMetadata = response.usageMetadata;
|
||||
const requestId = `gemini-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`;
|
||||
|
||||
return {
|
||||
content,
|
||||
model,
|
||||
usage: {
|
||||
promptTokens: usageMetadata?.promptTokenCount || 0,
|
||||
completionTokens: usageMetadata?.candidatesTokenCount || 0,
|
||||
totalTokens: usageMetadata?.totalTokenCount || 0
|
||||
},
|
||||
id: requestId,
|
||||
metadata: {
|
||||
finishReason: candidate.finishReason,
|
||||
safetyRatings: candidate.safetyRatings,
|
||||
citationMetadata: candidate.citationMetadata
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle Gemini-specific errors and convert them to our standard format
|
||||
*/
|
||||
private handleGeminiError(error: any): AIProviderError {
|
||||
if (error instanceof AIProviderError) {
|
||||
return error;
|
||||
}
|
||||
|
||||
const message = error.message || 'Unknown Gemini API error';
|
||||
|
||||
// Handle common Gemini error patterns
|
||||
if (message.includes('API key')) {
|
||||
return new AIProviderError(
|
||||
'Authentication failed. Please check your Google AI API key.',
|
||||
AIErrorType.AUTHENTICATION,
|
||||
undefined,
|
||||
error
|
||||
);
|
||||
}
|
||||
|
||||
if (message.includes('quota') || message.includes('rate limit')) {
|
||||
return new AIProviderError(
|
||||
'Rate limit exceeded. Please slow down your requests.',
|
||||
AIErrorType.RATE_LIMIT,
|
||||
undefined,
|
||||
error
|
||||
);
|
||||
}
|
||||
|
||||
if (message.includes('model') && message.includes('not found')) {
|
||||
return new AIProviderError(
|
||||
'Model not found. Please check the model name.',
|
||||
AIErrorType.MODEL_NOT_FOUND,
|
||||
undefined,
|
||||
error
|
||||
);
|
||||
}
|
||||
|
||||
if (message.includes('invalid') || message.includes('bad request')) {
|
||||
return new AIProviderError(
|
||||
`Invalid request: ${message}`,
|
||||
AIErrorType.INVALID_REQUEST,
|
||||
undefined,
|
||||
error
|
||||
);
|
||||
}
|
||||
|
||||
if (message.includes('network') || message.includes('connection')) {
|
||||
return new AIProviderError(
|
||||
'Network error occurred. Please check your connection.',
|
||||
AIErrorType.NETWORK,
|
||||
undefined,
|
||||
error
|
||||
);
|
||||
}
|
||||
|
||||
if (message.includes('timeout')) {
|
||||
return new AIProviderError(
|
||||
'Request timed out. Please try again.',
|
||||
AIErrorType.TIMEOUT,
|
||||
undefined,
|
||||
error
|
||||
);
|
||||
}
|
||||
|
||||
return new AIProviderError(
|
||||
`Gemini API error: ${message}`,
|
||||
AIErrorType.UNKNOWN,
|
||||
undefined,
|
||||
error
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -5,4 +5,5 @@
|
||||
|
||||
export { BaseAIProvider } from './base.js';
|
||||
export { ClaudeProvider, type ClaudeConfig } from './claude.js';
|
||||
export { OpenAIProvider, type OpenAIConfig } from './openai.js';
|
||||
export { OpenAIProvider, type OpenAIConfig } from './openai.js';
|
||||
export { GeminiProvider, type GeminiConfig } from './gemini.js';
|
||||
@@ -6,12 +6,13 @@
|
||||
import type { AIProviderConfig } from '../types/index.js';
|
||||
import { ClaudeProvider, type ClaudeConfig } from '../providers/claude.js';
|
||||
import { OpenAIProvider, type OpenAIConfig } from '../providers/openai.js';
|
||||
import { GeminiProvider, type GeminiConfig } from '../providers/gemini.js';
|
||||
import { BaseAIProvider } from '../providers/base.js';
|
||||
|
||||
/**
|
||||
* Supported AI provider types
|
||||
*/
|
||||
export type ProviderType = 'claude' | 'openai';
|
||||
export type ProviderType = 'claude' | 'openai' | 'gemini';
|
||||
|
||||
/**
|
||||
* Configuration map for different provider types
|
||||
@@ -19,6 +20,7 @@ export type ProviderType = 'claude' | 'openai';
|
||||
export interface ProviderConfigMap {
  /** Configuration for the Anthropic Claude provider */
  claude: ClaudeConfig;
  /** Configuration for the OpenAI provider */
  openai: OpenAIConfig;
  /** Configuration for the Google Gemini provider */
  gemini: GeminiConfig;
}
|
||||
|
||||
/**
|
||||
@@ -36,6 +38,8 @@ export function createProvider<T extends ProviderType>(
|
||||
return new ClaudeProvider(config as ClaudeConfig);
|
||||
case 'openai':
|
||||
return new OpenAIProvider(config as OpenAIConfig);
|
||||
case 'gemini':
|
||||
return new GeminiProvider(config as GeminiConfig);
|
||||
default:
|
||||
throw new Error(`Unsupported provider type: ${type}`);
|
||||
}
|
||||
@@ -73,6 +77,22 @@ export function createOpenAIProvider(
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a Gemini provider with simplified configuration
|
||||
* @param apiKey - Google AI API key
|
||||
* @param options - Optional additional configuration
|
||||
* @returns Configured Gemini provider instance
|
||||
*/
|
||||
export function createGeminiProvider(
|
||||
apiKey: string,
|
||||
options: Partial<Omit<GeminiConfig, 'apiKey'>> = {}
|
||||
): GeminiProvider {
|
||||
return new GeminiProvider({
|
||||
apiKey,
|
||||
...options
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Provider registry for dynamic provider creation
|
||||
*/
|
||||
@@ -122,4 +142,5 @@ export class ProviderRegistry {
|
||||
|
||||
// Pre-register built-in providers
|
||||
ProviderRegistry.register('claude', ClaudeProvider);
|
||||
ProviderRegistry.register('openai', OpenAIProvider);
|
||||
ProviderRegistry.register('openai', OpenAIProvider);
|
||||
ProviderRegistry.register('gemini', GeminiProvider);
|
||||
Reference in New Issue
Block a user