From 105fc0f901f0ce66164668b9cd5ce8e98d122491 Mon Sep 17 00:00:00 2001 From: Mario Fahlandt Date: Sun, 25 Feb 2024 22:58:37 +0100 Subject: [PATCH] feat: add Google Vertex AI as provider to utilize gemini via GCP --- README.md | 3 +- cmd/auth/add.go | 5 +- cmd/auth/auth.go | 1 + go.mod | 4 +- go.sum | 6 ++ pkg/ai/googlevertexai.go | 178 +++++++++++++++++++++++++++++++++++++++ pkg/ai/iai.go | 10 ++- 7 files changed, 203 insertions(+), 4 deletions(-) create mode 100644 pkg/ai/googlevertexai.go diff --git a/README.md b/README.md index 47e715c0c8..5810ffca1c 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ It has SRE experience codified into its analyzers and helps to pull out the most relevant information to enrich it with AI. -_Out of the box integration with OpenAI, Azure, Cohere, Amazon Bedrock and local models._ +_Out of the box integration with OpenAI, Azure, Cohere, Amazon Bedrock, Google Gemini and local models._ K8sGPT - K8sGPT gives Kubernetes Superpowers to everyone | Product Hunt @@ -314,6 +314,7 @@ Unused: > google > huggingface > noopai +> googlevertexai ``` For detailed documentation on how to configure and use each provider see [here](https://docs.k8sgpt.ai/reference/providers/backend/). diff --git a/cmd/auth/add.go b/cmd/auth/add.go index 2e195d02b9..88d660fc27 100644 --- a/cmd/auth/add.go +++ b/cmd/auth/add.go @@ -119,6 +119,7 @@ var addCmd = &cobra.Command{ Engine: engine, Temperature: temperature, ProviderRegion: providerRegion, + ProviderId: providerId, TopP: topP, MaxTokens: maxTokens, } @@ -159,5 +160,7 @@ func init() { // add flag for azure open ai engine/deployment name addCmd.Flags().StringVarP(&engine, "engine", "e", "", "Azure AI deployment name (only for azureopenai backend)") //add flag for amazonbedrock region name - addCmd.Flags().StringVarP(&providerRegion, "providerRegion", "r", "", "Provider Region name (only for amazonbedrock backend)") + addCmd.Flags().StringVarP(&providerRegion, "providerRegion", "r", "", "Provider Region name (only for amazonbedrock, googlevertexai backend)") + //add flag for vertexAI Project ID + addCmd.Flags().StringVarP(&providerId, "providerId", "i", "", "Provider Region name (only for googlevertexai backend)") } diff --git a/cmd/auth/auth.go b/cmd/auth/auth.go index f252f8a8d3..31f424842b 100644 --- a/cmd/auth/auth.go +++ b/cmd/auth/auth.go @@ -27,6 +27,7 @@ var ( engine string temperature float32 providerRegion string + providerId string topP float32 maxTokens int ) diff --git a/go.mod b/go.mod index 6fb8d7dc0b..3a68f07fa0 100644 --- a/go.mod +++ b/go.mod @@ -51,10 +51,12 @@ require ( atomicgo.dev/schedule v0.1.0 // indirect cloud.google.com/go v0.112.0 // indirect cloud.google.com/go/ai v0.3.0 // indirect + cloud.google.com/go/aiplatform v1.59.0 // indirect cloud.google.com/go/compute v1.24.0 // indirect cloud.google.com/go/compute/metadata v0.2.3 // indirect cloud.google.com/go/iam v1.1.6 // indirect - cloud.google.com/go/longrunning v0.5.4 // indirect + cloud.google.com/go/longrunning v0.5.5 // indirect + cloud.google.com/go/vertexai v0.7.1 // indirect github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.2 // indirect github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2 // indirect github.com/AzureAD/microsoft-authentication-library-for-go v1.2.1 // indirect diff --git a/go.sum b/go.sum index d36b93d371..2a79e4389c 100644 --- a/go.sum +++ b/go.sum @@ -99,6 +99,8 @@ cloud.google.com/go/aiplatform v1.52.0/go.mod h1:pwZMGvqe0JRkI1GWSZCtnAfrR4K1bv6 cloud.google.com/go/aiplatform v1.54.0/go.mod h1:pwZMGvqe0JRkI1GWSZCtnAfrR4K1bv65IHILGA//VEU= cloud.google.com/go/aiplatform v1.57.0/go.mod h1:pwZMGvqe0JRkI1GWSZCtnAfrR4K1bv65IHILGA//VEU= cloud.google.com/go/aiplatform v1.58.0/go.mod h1:pwZMGvqe0JRkI1GWSZCtnAfrR4K1bv65IHILGA//VEU= +cloud.google.com/go/aiplatform v1.59.0 h1:r+P9YStPWrYF52fKyYCQKzTDw4fLiyzLdTEIdxcjmjU= +cloud.google.com/go/aiplatform v1.59.0/go.mod h1:eTlGuHOahHprZw3Hio5VKmtThIOak5/qy6pzdsqcQnM= cloud.google.com/go/analytics v0.11.0/go.mod h1:DjEWCu41bVbYcKyvlws9Er60YE4a//bK6mnhWvQeFNI= cloud.google.com/go/analytics v0.12.0/go.mod h1:gkfj9h6XRf9+TS4bmuhPEShsh3hH8PAZzm/41OOhQd4= cloud.google.com/go/analytics v0.17.0/go.mod h1:WXFa3WSym4IZ+JiKmavYdJwGG/CvpqiqczmL59bTD9M= @@ -715,6 +717,8 @@ cloud.google.com/go/longrunning v0.5.2/go.mod h1:nqo6DQbNV2pXhGDbDMoN2bWz68MjZUz cloud.google.com/go/longrunning v0.5.3/go.mod h1:y/0ga59EYu58J6SHmmQOvekvND2qODbu8ywBBW7EK7Y= cloud.google.com/go/longrunning v0.5.4 h1:w8xEcbZodnA2BbW6sVirkkoC+1gP8wS57EUUgGS0GVg= cloud.google.com/go/longrunning v0.5.4/go.mod h1:zqNVncI0BOP8ST6XQD1+VcvuShMmq7+xFSzOL++V0dI= +cloud.google.com/go/longrunning v0.5.5 h1:GOE6pZFdSrTb4KAiKnXsJBtlE6mEyaW44oKyMILWnOg= +cloud.google.com/go/longrunning v0.5.5/go.mod h1:WV2LAxD8/rg5Z1cNW6FJ/ZpX4E4VnDnoTk0yawPBB7s= cloud.google.com/go/managedidentities v1.3.0/go.mod h1:UzlW3cBOiPrzucO5qWkNkh0w33KFtBJU281hacNvsdE= cloud.google.com/go/managedidentities v1.4.0/go.mod h1:NWSBYbEMgqmbZsLIyKvxrYbtqOsxY1ZrGM+9RgDqInM= cloud.google.com/go/managedidentities v1.5.0/go.mod h1:+dWcZ0JlUmpuxpIDfyP5pP5y0bLdRwOS4Lp7gMni/LA= @@ -1125,6 +1129,8 @@ cloud.google.com/go/translate v1.9.1/go.mod h1:TWIgDZknq2+JD4iRcojgeDtqGEp154HN/ cloud.google.com/go/translate v1.9.2/go.mod h1:E3Tc6rUTsQkVrXW6avbUhKJSr7ZE3j7zNmqzXKHqRrY= cloud.google.com/go/translate v1.9.3/go.mod h1:Kbq9RggWsbqZ9W5YpM94Q1Xv4dshw/gr/SHfsl5yCZ0= cloud.google.com/go/translate v1.10.0/go.mod h1:Kbq9RggWsbqZ9W5YpM94Q1Xv4dshw/gr/SHfsl5yCZ0= +cloud.google.com/go/vertexai v0.7.1 h1:CSdqsEwjklLIlI1e5SrsnkwG/I+CeJekkBbMTzeYhVg= +cloud.google.com/go/vertexai v0.7.1/go.mod h1:HfnfYR9aPS+qF2436S6Hzuw0Fp+PORjzK3ggqymdzSU= cloud.google.com/go/video v1.8.0/go.mod h1:sTzKFc0bUSByE8Yoh8X0mn8bMymItVGPfTuUBUyRgxk= cloud.google.com/go/video v1.9.0/go.mod h1:0RhNKFRF5v92f8dQt0yhaHrEuH95m068JYOvLZYnJSw= cloud.google.com/go/video v1.12.0/go.mod h1:MLQew95eTuaNDEGriQdcYn0dTwf9oWiA4uYebxM5kdg= diff --git a/pkg/ai/googlevertexai.go b/pkg/ai/googlevertexai.go new file mode 100644 index 0000000000..b46e90e95a --- /dev/null +++ b/pkg/ai/googlevertexai.go @@ -0,0 +1,178 @@ +/* +Copyright 2023 The K8sGPT Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ai + +import ( + "context" + "errors" + "fmt" + + "cloud.google.com/go/vertexai/genai" + "github.com/fatih/color" +) + +const googleVertexAIClientName = "googlevertexai" + +type GoogleVertexAIClient struct { + client *genai.Client + + model string + temperature float32 + topP float32 + maxTokens int +} + +// Vertex AI Gemini supported Regions +// https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini +const VERTEXAI_DEFAULT_REGION = "us-central1" // default use us-east-1 region + +const ( + US_Central_1 = "us-central1" + US_West_4 = "us-west4" + North_America_Northeast1 = "northamerica-northeast1" + US_East_4 = "us-east4" + US_West_1 = "us-west1" + Asia_Northeast_3 = "asia-northeast3" + Asia_Southeast_1 = "asia-southeast1" + Asia_Northeast_1 = "asia-northeast1" +) + +var VERTEXAI_SUPPORTED_REGION = []string{ + US_Central_1, + US_West_4, + North_America_Northeast1, + US_East_4, + US_West_1, + Asia_Northeast_3, + Asia_Southeast_1, + Asia_Northeast_1, +} + +const ( + ModelGeminiProV1 = "gemini-1.0-pro-001" +) + +var VERTEXAI_MODELS = []string{ + ModelGeminiProV1, +} + +// GetModelOrDefault check config model +func GetVertexAIModelOrDefault(model string) string { + + // Check if the provided model is in the list + for _, m := range VERTEXAI_MODELS { + if m == model { + return model // Return the provided model + } + } + + // Return the default model if the provided model is not in the list + return VERTEXAI_MODELS[0] +} + +// GetModelOrDefault check config region +func GetVertexAIRegionOrDefault(region string) string { + + // Check if the provided model is in the list + for _, m := range VERTEXAI_SUPPORTED_REGION { + if m == region { + return region // Return the provided model + } + } + + // Return the default model if the provided model is not in the list + return VERTEXAI_DEFAULT_REGION +} + +func (g *GoogleVertexAIClient) Configure(config IAIConfig) error { + ctx := context.Background() + + // Currently you can access VertexAI either by being authenticated via OAuth or Bearer token so we need to consider both + projectId := config.GetProviderId() + region := GetVertexAIRegionOrDefault(config.GetProviderRegion()) + + client, err := genai.NewClient(ctx, projectId, region) + if err != nil { + return fmt.Errorf("creating genai Google SDK client: %w", err) + } + + g.client = client + g.model = GetVertexAIModelOrDefault(config.GetModel()) + g.temperature = config.GetTemperature() + g.topP = config.GetTopP() + g.maxTokens = config.GetMaxTokens() + + return nil +} + +func (g *GoogleVertexAIClient) GetCompletion(ctx context.Context, prompt string) (string, error) { + + model := g.client.GenerativeModel(g.model) + model.SetTemperature(g.temperature) + model.SetTopP(g.topP) + model.SetMaxOutputTokens(int32(g.maxTokens)) + + // Google AI SDK is capable of different inputs than just text, for now set explicit text prompt type. + // Similarly, we could stream the response. For now k8sgpt does not support streaming. + resp, err := model.GenerateContent(ctx, genai.Text(prompt)) + if err != nil { + return "", err + } + + if len(resp.Candidates) == 0 { + if resp.PromptFeedback.BlockReason > 0 { + for _, r := range resp.PromptFeedback.SafetyRatings { + if !r.Blocked { + continue + } + return "", fmt.Errorf("complection blocked due to %v with probability %v", r.Category.String(), r.Probability.String()) + } + } + return "", errors.New("no complection returned; unknown reason") + } + + // Format output. + // TODO(bwplotka): Provider richer output in certain cases e.g. suddenly finished + // completion based on finish reasons or safety rankings. + got := resp.Candidates[0] + var output string + for _, part := range got.Content.Parts { + switch o := part.(type) { + case genai.Text: + output += string(o) + output += "\n" + default: + color.Yellow("found unsupported AI response part of type %T; ignoring", part) + } + } + + if got.CitationMetadata != nil && len(got.CitationMetadata.Citations) > 0 { + output += "Citations:\n" + for _, source := range got.CitationMetadata.Citations { + // TODO(bwplotka): Give details around what exactly words could be attributed to the citation. + output += fmt.Sprintf("* %s, %s\n", source.URI, source.License) + } + } + return output, nil +} + +func (g *GoogleVertexAIClient) GetName() string { + return googleVertexAIClientName +} + +func (g *GoogleVertexAIClient) Close() { + if err := g.client.Close(); err != nil { + color.Red("googleai client close error: %v", err) + } +} diff --git a/pkg/ai/iai.go b/pkg/ai/iai.go index 99de8e3a40..8cca150798 100644 --- a/pkg/ai/iai.go +++ b/pkg/ai/iai.go @@ -28,6 +28,7 @@ var ( &SageMakerAIClient{}, &GoogleGenAIClient{}, &HuggingfaceClient{}, + &GoogleVertexAIClient{}, } Backends = []string{ openAIClientName, @@ -39,6 +40,7 @@ var ( googleAIClientName, noopAIClientName, huggingfaceAIClientName, + googleVertexAIClientName, } ) @@ -70,6 +72,7 @@ type IAIConfig interface { GetProviderRegion() string GetTopP() float32 GetMaxTokens() int + GetProviderId() string } func NewClient(provider string) IAI { @@ -96,6 +99,7 @@ type AIProvider struct { Engine string `mapstructure:"engine" yaml:"engine,omitempty"` Temperature float32 `mapstructure:"temperature" yaml:"temperature,omitempty"` ProviderRegion string `mapstructure:"providerregion" yaml:"providerregion,omitempty"` + ProviderId string `mapstructure:"providerid" yaml:"providerid,omitempty"` TopP float32 `mapstructure:"topp" yaml:"topp,omitempty"` MaxTokens int `mapstructure:"maxtokens" yaml:"maxtokens,omitempty"` } @@ -135,7 +139,11 @@ func (p *AIProvider) GetProviderRegion() string { return p.ProviderRegion } -var passwordlessProviders = []string{"localai", "amazonsagemaker", "amazonbedrock"} +func (p *AIProvider) GetProviderId() string { + return p.ProviderId +} + +var passwordlessProviders = []string{"localai", "amazonsagemaker", "amazonbedrock", "googlevertexai"} func NeedPassword(backend string) bool { for _, b := range passwordlessProviders {