
🚧 wip: test gemini function call #1810

Closed
wants to merge 2 commits
1 change: 1 addition & 0 deletions package.json
@@ -210,6 +210,7 @@
"stylelint": "^15",
"tsx": "^4",
"typescript": "^5",
"undici": "6.5.0",
"unified": "^11",
"unist-util-visit": "^5",
"vite": "^5",
12 changes: 6 additions & 6 deletions src/app/api/chat/google/route.ts
@@ -3,15 +3,15 @@ import { POST as UniverseRoute } from '../[provider]/route';
// due to the Chinese region does not support accessing Google
// we need to use proxy to access it
// refs: https://github.com/google/generative-ai-js/issues/29#issuecomment-1866246513
// if (process.env.HTTP_PROXY_URL) {
// const { setGlobalDispatcher, ProxyAgent } = require('undici');
//
// setGlobalDispatcher(new ProxyAgent({ uri: process.env.HTTP_PROXY_URL }));
// }
if (process.env.HTTP_PROXY_URL) {
const { setGlobalDispatcher, ProxyAgent } = require('undici');

setGlobalDispatcher(new ProxyAgent({ uri: process.env.HTTP_PROXY_URL }));
}

// but undici only can be used in NodeJS
// so if you want to use with proxy, you need comment the code below
export const runtime = 'edge';
// export const runtime = 'edge';

// due to gemini-1.5-pro only can be used in us, so we need to set the preferred region only in US
export const preferredRegion = ['cle1', 'iad1', 'pdx1', 'sfo1'];
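Context for the runtime switch above: undici's global dispatcher only works on the Node.js runtime, which is why enabling the proxy means commenting out the edge runtime export. A minimal standalone sketch of the same pattern, assuming Node.js 18+, a reachable proxy in HTTP_PROXY_URL, and a GOOGLE_API_KEY (the gemini-pro call is only an illustrative connectivity check):

import { GoogleGenerativeAI } from '@google/generative-ai';
import { ProxyAgent, setGlobalDispatcher } from 'undici';

if (process.env.HTTP_PROXY_URL) {
  // every fetch() the Google SDK issues is now routed through the proxy
  setGlobalDispatcher(new ProxyAgent({ uri: process.env.HTTP_PROXY_URL }));
}

const genAI = new GoogleGenerativeAI(process.env.GOOGLE_API_KEY!);
const model = genAI.getGenerativeModel({ model: 'gemini-pro' }, { apiVersion: 'v1beta' });

const result = await model.generateContent('ping');
console.log(result.response.text());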
1 change: 1 addition & 0 deletions src/config/modelProviders/google.ts
Expand Up @@ -21,6 +21,7 @@ const Google: ModelProviderCard = {
{
description: 'The best model for scaling across a wide range of tasks',
displayName: 'Gemini 1.0 Pro',
functionCall: true,
id: 'gemini-pro',
maxOutput: 2048,
tokens: 32_768,
164 changes: 125 additions & 39 deletions src/libs/agent-runtime/google/index.ts
@@ -1,10 +1,21 @@
import { Content, GoogleGenerativeAI, Part } from '@google/generative-ai';
import {
Content,
FunctionDeclaration,
FunctionDeclarationSchemaProperty,
FunctionDeclarationSchemaType,
Tool as GoogleFunctionCallTool,
GoogleGenerativeAI,
Part,
} from '@google/generative-ai';
import { GoogleGenerativeAIStream, StreamingTextResponse } from 'ai';
import { JSONSchema7 } from 'json-schema';
import { transform } from 'lodash-es';

import { LobeRuntimeAI } from '../BaseAI';
import { AgentRuntimeErrorType, ILobeAgentRuntimeErrorType } from '../error';
import {
ChatCompetitionOptions,
ChatCompletionTool,
ChatStreamPayload,
OpenAIChatMessage,
UserMessageContentPart,
@@ -72,7 +83,7 @@ export class LobeGoogleAI implements LobeRuntimeAI {
},
{ apiVersion: 'v1beta' },
)
.generateContentStream({ contents });
.generateContentStream({ contents, tools: this.buildGoogleTools(payload.tools) });

// Convert the response into a friendly text-stream
const stream = GoogleGenerativeAIStream(geminiStream, options?.callback);
@@ -94,38 +105,19 @@
}
}

private convertContentToGooglePart = (content: UserMessageContentPart): Part => {
switch (content.type) {
case 'text': {
return { text: content.text };
}
case 'image_url': {
const { mimeType, base64 } = parseDataUri(content.image_url.url);
private convertModel = (model: string, messages: OpenAIChatMessage[]) => {
let finalModel: string = model;

if (!base64) {
throw new TypeError("Image URL doesn't contain base64 data");
}
if (model.includes('pro-vision')) {
// if message are all text message, use vision will return an error:
// "[400 Bad Request] Add an image to use models/gemini-pro-vision, or switch your model to a text model."
const noNeedVision = messages.every((m) => typeof m.content === 'string');

return {
inlineData: {
data: base64,
mimeType: mimeType || 'image/png',
},
};
}
// so we need to downgrade to gemini-pro
if (noNeedVision) finalModel = 'gemini-pro';
}
};

private convertOAIMessagesToGoogleMessage = (message: OpenAIChatMessage): Content => {
const content = message.content as string | UserMessageContentPart[];

return {
parts:
typeof content === 'string'
? [{ text: content }]
: content.map((c) => this.convertContentToGooglePart(c)),
role: message.role === 'assistant' ? 'model' : 'user',
};
return finalModel;
};

// convert messages from the Vercel AI SDK Format to the format
@@ -169,19 +161,113 @@ export class LobeGoogleAI implements LobeRuntimeAI {
return contents;
};

private convertModel = (model: string, messages: OpenAIChatMessage[]) => {
let finalModel: string = model;
private buildGoogleTools(
tools: ChatCompletionTool[] | undefined,
): GoogleFunctionCallTool[] | undefined {
if (!tools || tools.length === 0) return;

return [
{
functionDeclarations: tools.map((tool) => {
const t = this.convertToolToGoogleTool(tool);
console.log('output Schema', t);
return t;
}),
},
];
}

if (model.includes('pro-vision')) {
// if message are all text message, use vision will return an error:
// "[400 Bad Request] Add an image to use models/gemini-pro-vision, or switch your model to a text model."
const noNeedVision = messages.every((m) => typeof m.content === 'string');
private convertToolToGoogleTool = (tool: ChatCompletionTool): FunctionDeclaration => {
const functionDeclaration = tool.function;
const parameters = functionDeclaration.parameters;

// so we need to downgrade to gemini-pro
if (noNeedVision) finalModel = 'gemini-pro';
console.log('input Schema', JSON.stringify(parameters, null, 2));

return {
description: functionDeclaration.description,
name: functionDeclaration.name,
parameters: {
description: parameters?.description,
properties: transform(parameters?.properties, (result, value, key: string) => {
result[key] = this.convertSchemaObject(value as JSONSchema7);
}),
required: parameters?.required,
type: FunctionDeclarationSchemaType.OBJECT,
},
};
};

private convertSchemaObject(schema: JSONSchema7): FunctionDeclarationSchemaProperty {
console.log('input:', schema);

switch (schema.type) {
case 'object': {
return {
...schema,
properties: Object.fromEntries(
Object.entries(schema.properties || {}).map(([key, value]) => [
key,
this.convertSchemaObject(value as JSONSchema7),
]),
),
type: FunctionDeclarationSchemaType.OBJECT,
};
}

case 'array': {
return {
...schema,
items: this.convertSchemaObject(schema.items as JSONSchema7),
type: FunctionDeclarationSchemaType.ARRAY,
};
}

case 'string': {
return { ...schema, type: FunctionDeclarationSchemaType.STRING };
}

case 'number': {
return { ...schema, type: FunctionDeclarationSchemaType.NUMBER };
}

case 'boolean': {
return { ...schema, type: FunctionDeclarationSchemaType.BOOLEAN };
}
}
}

return finalModel;
private convertContentToGooglePart = (content: UserMessageContentPart): Part => {
switch (content.type) {
case 'text': {
return { text: content.text };
}
case 'image_url': {
const { mimeType, base64 } = parseDataUri(content.image_url.url);

if (!base64) {
throw new TypeError("Image URL doesn't contain base64 data");
}

return {
inlineData: {
data: base64,
mimeType: mimeType || 'image/png',
},
};
}
}
};

private convertOAIMessagesToGoogleMessage = (message: OpenAIChatMessage): Content => {
const content = message.content as string | UserMessageContentPart[];

return {
parts:
typeof content === 'string'
? [{ text: content }]
: content.map((c) => this.convertContentToGooglePart(c)),
role: message.role === 'assistant' ? 'model' : 'user',
};
};

private parseErrorMessage(message: string): {
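For reference, a sketch of what the new tool-conversion path produces. Given an OpenAI-style tool definition like the one below (the searchWeb function is purely illustrative), buildGoogleTools wraps it into a single Tool whose functionDeclarations use Google's FunctionDeclarationSchemaType enums in place of JSON Schema type strings, and that array is what gets passed as tools to generateContentStream:

// illustrative input: an OpenAI-style ChatCompletionTool
const openAITool = {
  type: 'function' as const,
  function: {
    name: 'searchWeb', // hypothetical function name
    description: 'Search the web for a query',
    parameters: {
      type: 'object',
      properties: {
        query: { type: 'string', description: 'The search query' },
        limit: { type: 'number', description: 'Maximum number of results' },
      },
      required: ['query'],
    },
  },
};

// expected shape of buildGoogleTools([openAITool]):
// [
//   {
//     functionDeclarations: [
//       {
//         name: 'searchWeb',
//         description: 'Search the web for a query',
//         parameters: {
//           type: FunctionDeclarationSchemaType.OBJECT,
//           properties: {
//             query: { type: FunctionDeclarationSchemaType.STRING, description: 'The search query' },
//             limit: { type: FunctionDeclarationSchemaType.NUMBER, description: 'Maximum number of results' },
//           },
//           required: ['query'],
//         },
//       },
//     ],
//   },
// ]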