Skip to content

Commit

Permalink
Add Google Vertex AI as provider to utilize Gemini via GCP
Browse files Browse the repository at this point in the history
  • Loading branch information
mfahlandt committed Feb 25, 2024
1 parent 35f5185 commit 2afe94c
Show file tree
Hide file tree
Showing 7 changed files with 203 additions and 4 deletions.
3 changes: 2 additions & 1 deletion README.md
Expand Up @@ -19,7 +19,7 @@

It has SRE experience codified into its analyzers and helps to pull out the most relevant information to enrich it with AI.

_Out of the box integration with OpenAI, Azure, Cohere, Amazon Bedrock and local models._
_Out of the box integration with OpenAI, Azure, Cohere, Amazon Bedrock, Google Gemini and local models._

<a href="https://www.producthunt.com/posts/k8sgpt?utm_source=badge-featured&utm_medium=badge&utm_souce=badge-k8sgpt" target="_blank"><img src="https://api.producthunt.com/widgets/embed-image/v1/featured.svg?post_id=389489&theme=light" alt="K8sGPT - K8sGPT&#0032;gives&#0032;Kubernetes&#0032;Superpowers&#0032;to&#0032;everyone | Product Hunt" style="width: 250px; height: 54px;" width="250" height="54" /></a>

Expand Down Expand Up @@ -314,6 +314,7 @@ Unused:
> google
> huggingface
> noopai
> googlevertexai
```

For detailed documentation on how to configure and use each provider see [here](https://docs.k8sgpt.ai/reference/providers/backend/).
Expand Down
5 changes: 4 additions & 1 deletion cmd/auth/add.go
Expand Up @@ -119,6 +119,7 @@ var addCmd = &cobra.Command{
Engine: engine,
Temperature: temperature,
ProviderRegion: providerRegion,
ProviderId: providerId,
TopP: topP,
MaxTokens: maxTokens,
}
Expand Down Expand Up @@ -159,5 +160,7 @@ func init() {
// add flag for azure open ai engine/deployment name
addCmd.Flags().StringVarP(&engine, "engine", "e", "", "Azure AI deployment name (only for azureopenai backend)")
//add flag for amazonbedrock region name
addCmd.Flags().StringVarP(&providerRegion, "providerRegion", "r", "", "Provider Region name (only for amazonbedrock backend)")
addCmd.Flags().StringVarP(&providerRegion, "providerRegion", "r", "", "Provider Region name (only for amazonbedrock, googlevertexai backend)")
//add flag for vertexAI Project ID
addCmd.Flags().StringVarP(&providerId, "providerId", "i", "", "Provider Region name (only for googlevertexai backend)")
}
1 change: 1 addition & 0 deletions cmd/auth/auth.go
Expand Up @@ -27,6 +27,7 @@ var (
engine string
temperature float32
providerRegion string
providerId string
topP float32
maxTokens int
)
Expand Down
4 changes: 3 additions & 1 deletion go.mod
Expand Up @@ -51,10 +51,12 @@ require (
atomicgo.dev/schedule v0.1.0 // indirect
cloud.google.com/go v0.112.0 // indirect
cloud.google.com/go/ai v0.3.0 // indirect
cloud.google.com/go/aiplatform v1.59.0 // indirect
cloud.google.com/go/compute v1.24.0 // indirect
cloud.google.com/go/compute/metadata v0.2.3 // indirect
cloud.google.com/go/iam v1.1.6 // indirect
cloud.google.com/go/longrunning v0.5.4 // indirect
cloud.google.com/go/longrunning v0.5.5 // indirect
cloud.google.com/go/vertexai v0.7.1 // indirect
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.2 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2 // indirect
github.com/AzureAD/microsoft-authentication-library-for-go v1.2.1 // indirect
Expand Down
6 changes: 6 additions & 0 deletions go.sum
Expand Up @@ -99,6 +99,8 @@ cloud.google.com/go/aiplatform v1.52.0/go.mod h1:pwZMGvqe0JRkI1GWSZCtnAfrR4K1bv6
cloud.google.com/go/aiplatform v1.54.0/go.mod h1:pwZMGvqe0JRkI1GWSZCtnAfrR4K1bv65IHILGA//VEU=
cloud.google.com/go/aiplatform v1.57.0/go.mod h1:pwZMGvqe0JRkI1GWSZCtnAfrR4K1bv65IHILGA//VEU=
cloud.google.com/go/aiplatform v1.58.0/go.mod h1:pwZMGvqe0JRkI1GWSZCtnAfrR4K1bv65IHILGA//VEU=
cloud.google.com/go/aiplatform v1.59.0 h1:r+P9YStPWrYF52fKyYCQKzTDw4fLiyzLdTEIdxcjmjU=
cloud.google.com/go/aiplatform v1.59.0/go.mod h1:eTlGuHOahHprZw3Hio5VKmtThIOak5/qy6pzdsqcQnM=
cloud.google.com/go/analytics v0.11.0/go.mod h1:DjEWCu41bVbYcKyvlws9Er60YE4a//bK6mnhWvQeFNI=
cloud.google.com/go/analytics v0.12.0/go.mod h1:gkfj9h6XRf9+TS4bmuhPEShsh3hH8PAZzm/41OOhQd4=
cloud.google.com/go/analytics v0.17.0/go.mod h1:WXFa3WSym4IZ+JiKmavYdJwGG/CvpqiqczmL59bTD9M=
Expand Down Expand Up @@ -715,6 +717,8 @@ cloud.google.com/go/longrunning v0.5.2/go.mod h1:nqo6DQbNV2pXhGDbDMoN2bWz68MjZUz
cloud.google.com/go/longrunning v0.5.3/go.mod h1:y/0ga59EYu58J6SHmmQOvekvND2qODbu8ywBBW7EK7Y=
cloud.google.com/go/longrunning v0.5.4 h1:w8xEcbZodnA2BbW6sVirkkoC+1gP8wS57EUUgGS0GVg=
cloud.google.com/go/longrunning v0.5.4/go.mod h1:zqNVncI0BOP8ST6XQD1+VcvuShMmq7+xFSzOL++V0dI=
cloud.google.com/go/longrunning v0.5.5 h1:GOE6pZFdSrTb4KAiKnXsJBtlE6mEyaW44oKyMILWnOg=
cloud.google.com/go/longrunning v0.5.5/go.mod h1:WV2LAxD8/rg5Z1cNW6FJ/ZpX4E4VnDnoTk0yawPBB7s=
cloud.google.com/go/managedidentities v1.3.0/go.mod h1:UzlW3cBOiPrzucO5qWkNkh0w33KFtBJU281hacNvsdE=
cloud.google.com/go/managedidentities v1.4.0/go.mod h1:NWSBYbEMgqmbZsLIyKvxrYbtqOsxY1ZrGM+9RgDqInM=
cloud.google.com/go/managedidentities v1.5.0/go.mod h1:+dWcZ0JlUmpuxpIDfyP5pP5y0bLdRwOS4Lp7gMni/LA=
Expand Down Expand Up @@ -1125,6 +1129,8 @@ cloud.google.com/go/translate v1.9.1/go.mod h1:TWIgDZknq2+JD4iRcojgeDtqGEp154HN/
cloud.google.com/go/translate v1.9.2/go.mod h1:E3Tc6rUTsQkVrXW6avbUhKJSr7ZE3j7zNmqzXKHqRrY=
cloud.google.com/go/translate v1.9.3/go.mod h1:Kbq9RggWsbqZ9W5YpM94Q1Xv4dshw/gr/SHfsl5yCZ0=
cloud.google.com/go/translate v1.10.0/go.mod h1:Kbq9RggWsbqZ9W5YpM94Q1Xv4dshw/gr/SHfsl5yCZ0=
cloud.google.com/go/vertexai v0.7.1 h1:CSdqsEwjklLIlI1e5SrsnkwG/I+CeJekkBbMTzeYhVg=
cloud.google.com/go/vertexai v0.7.1/go.mod h1:HfnfYR9aPS+qF2436S6Hzuw0Fp+PORjzK3ggqymdzSU=
cloud.google.com/go/video v1.8.0/go.mod h1:sTzKFc0bUSByE8Yoh8X0mn8bMymItVGPfTuUBUyRgxk=
cloud.google.com/go/video v1.9.0/go.mod h1:0RhNKFRF5v92f8dQt0yhaHrEuH95m068JYOvLZYnJSw=
cloud.google.com/go/video v1.12.0/go.mod h1:MLQew95eTuaNDEGriQdcYn0dTwf9oWiA4uYebxM5kdg=
Expand Down
178 changes: 178 additions & 0 deletions pkg/ai/googlevertexai.go
@@ -0,0 +1,178 @@
/*
Copyright 2023 The K8sGPT Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ai

import (
"context"
"errors"
"fmt"

"cloud.google.com/go/vertexai/genai"
"github.com/fatih/color"
)

const googleVertexAIClientName = "googlevertexai"

type GoogleVertexAIClient struct {
client *genai.Client

model string
temperature float32
topP float32
maxTokens int
}

// Vertex AI Gemini supported Regions
// https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini
const VERTEXAI_DEFAULT_REGION = "us-central1" // default use us-east-1 region

const (
US_Central_1 = "us-central1"
US_West_4 = "us-west4"
North_America_Northeast1 = "northamerica-northeast1"
US_East_4 = "us-east4"
US_West_1 = "us-west1"
Asia_Northeast_3 = "asia-northeast3"
Asia_Southeast_1 = "asia-southeast1"
Asia_Northeast_1 = "asia-northeast1"
)

var VERTEXAI_SUPPORTED_REGION = []string{
US_Central_1,
US_West_4,
North_America_Northeast1,
US_East_4,
US_West_1,
Asia_Northeast_3,
Asia_Southeast_1,
Asia_Northeast_1,
}

const (
ModelGeminiProV1 = "gemini-1.0-pro-001"
)

var VERTEXAI_MODELS = []string{
ModelGeminiProV1,
}

// GetModelOrDefault check config model
func GetVertexAIModelOrDefault(model string) string {

// Check if the provided model is in the list
for _, m := range VERTEXAI_MODELS {
if m == model {
return model // Return the provided model
}
}

// Return the default model if the provided model is not in the list
return VERTEXAI_MODELS[0]
}

// GetModelOrDefault check config region
func GetVertexAIRegionOrDefault(region string) string {

// Check if the provided model is in the list
for _, m := range VERTEXAI_SUPPORTED_REGION {
if m == region {
return region // Return the provided model
}
}

// Return the default model if the provided model is not in the list
return VERTEXAI_DEFAULT_REGION
}

func (g *GoogleVertexAIClient) Configure(config IAIConfig) error {
ctx := context.Background()

// Currently you can access VertexAI either by being authenticated via OAuth or Bearer token so we need to consider both
projectId := config.GetProviderId()
region := GetVertexAIRegionOrDefault(config.GetProviderRegion())

client, err := genai.NewClient(ctx, projectId, region)
if err != nil {
return fmt.Errorf("creating genai Google SDK client: %w", err)
}

g.client = client
g.model = GetVertexAIModelOrDefault(config.GetModel())
g.temperature = config.GetTemperature()
g.topP = config.GetTopP()
g.maxTokens = config.GetMaxTokens()

return nil
}

func (g *GoogleVertexAIClient) GetCompletion(ctx context.Context, prompt string) (string, error) {

model := g.client.GenerativeModel(g.model)
model.SetTemperature(g.temperature)
model.SetTopP(g.topP)
model.SetMaxOutputTokens(int32(g.maxTokens))

// Google AI SDK is capable of different inputs than just text, for now set explicit text prompt type.
// Similarly, we could stream the response. For now k8sgpt does not support streaming.
resp, err := model.GenerateContent(ctx, genai.Text(prompt))
if err != nil {
return "", err
}

if len(resp.Candidates) == 0 {
if resp.PromptFeedback.BlockReason > 0 {
for _, r := range resp.PromptFeedback.SafetyRatings {
if !r.Blocked {
continue
}
return "", fmt.Errorf("complection blocked due to %v with probability %v", r.Category.String(), r.Probability.String())
}
}
return "", errors.New("no complection returned; unknown reason")
}

// Format output.
// TODO(bwplotka): Provider richer output in certain cases e.g. suddenly finished
// completion based on finish reasons or safety rankings.
got := resp.Candidates[0]
var output string
for _, part := range got.Content.Parts {
switch o := part.(type) {
case genai.Text:
output += string(o)
output += "\n"
default:
color.Yellow("found unsupported AI response part of type %T; ignoring", part)
}
}

if got.CitationMetadata != nil && len(got.CitationMetadata.Citations) > 0 {
output += "Citations:\n"
for _, source := range got.CitationMetadata.Citations {
// TODO(bwplotka): Give details around what exactly words could be attributed to the citation.
output += fmt.Sprintf("* %s, %s\n", source.URI, source.License)
}
}
return output, nil
}

func (g *GoogleVertexAIClient) GetName() string {
return googleVertexAIClientName
}

func (g *GoogleVertexAIClient) Close() {
if err := g.client.Close(); err != nil {
color.Red("googleai client close error: %v", err)
}
}
10 changes: 9 additions & 1 deletion pkg/ai/iai.go
Expand Up @@ -28,6 +28,7 @@ var (
&SageMakerAIClient{},
&GoogleGenAIClient{},
&HuggingfaceClient{},
&GoogleVertexAIClient{},
}
Backends = []string{
openAIClientName,
Expand All @@ -39,6 +40,7 @@ var (
googleAIClientName,
noopAIClientName,
huggingfaceAIClientName,
googleVertexAIClientName,
}
)

Expand Down Expand Up @@ -70,6 +72,7 @@ type IAIConfig interface {
GetProviderRegion() string
GetTopP() float32
GetMaxTokens() int
GetProviderId() string
}

func NewClient(provider string) IAI {
Expand All @@ -96,6 +99,7 @@ type AIProvider struct {
Engine string `mapstructure:"engine" yaml:"engine,omitempty"`
Temperature float32 `mapstructure:"temperature" yaml:"temperature,omitempty"`
ProviderRegion string `mapstructure:"providerregion" yaml:"providerregion,omitempty"`
ProviderId string `mapstructure:"providerid" yaml:"providerid,omitempty"`
TopP float32 `mapstructure:"topp" yaml:"topp,omitempty"`
MaxTokens int `mapstructure:"maxtokens" yaml:"maxtokens,omitempty"`
}
Expand Down Expand Up @@ -135,7 +139,11 @@ func (p *AIProvider) GetProviderRegion() string {
return p.ProviderRegion
}

var passwordlessProviders = []string{"localai", "amazonsagemaker", "amazonbedrock"}
func (p *AIProvider) GetProviderId() string {
return p.ProviderId
}

var passwordlessProviders = []string{"localai", "amazonsagemaker", "amazonbedrock", "googlevertexai"}

func NeedPassword(backend string) bool {
for _, b := range passwordlessProviders {
Expand Down

0 comments on commit 2afe94c

Please sign in to comment.