Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Google Vertex AI as provider to utilize gemini via GCP #984

Merged
merged 8 commits into from Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Expand Up @@ -19,7 +19,7 @@

It has SRE experience codified into its analyzers and helps to pull out the most relevant information to enrich it with AI.

_Out of the box integration with OpenAI, Azure, Cohere, Amazon Bedrock and local models._
_Out of the box integration with OpenAI, Azure, Cohere, Amazon Bedrock, Google Gemini and local models._

<a href="https://www.producthunt.com/posts/k8sgpt?utm_source=badge-featured&utm_medium=badge&utm_souce=badge-k8sgpt" target="_blank"><img src="https://api.producthunt.com/widgets/embed-image/v1/featured.svg?post_id=389489&theme=light" alt="K8sGPT - K8sGPT&#0032;gives&#0032;Kubernetes&#0032;Superpowers&#0032;to&#0032;everyone | Product Hunt" style="width: 250px; height: 54px;" width="250" height="54" /></a>

Expand Down Expand Up @@ -314,6 +314,7 @@ Unused:
> google
> huggingface
> noopai
> googlevertexai
```

For detailed documentation on how to configure and use each provider see [here](https://docs.k8sgpt.ai/reference/providers/backend/).
Expand Down
5 changes: 4 additions & 1 deletion cmd/auth/add.go
Expand Up @@ -119,6 +119,7 @@ var addCmd = &cobra.Command{
Engine: engine,
Temperature: temperature,
ProviderRegion: providerRegion,
ProviderId: providerId,
TopP: topP,
MaxTokens: maxTokens,
}
Expand Down Expand Up @@ -159,5 +160,7 @@ func init() {
// add flag for azure open ai engine/deployment name
addCmd.Flags().StringVarP(&engine, "engine", "e", "", "Azure AI deployment name (only for azureopenai backend)")
//add flag for amazonbedrock region name
addCmd.Flags().StringVarP(&providerRegion, "providerRegion", "r", "", "Provider Region name (only for amazonbedrock backend)")
addCmd.Flags().StringVarP(&providerRegion, "providerRegion", "r", "", "Provider Region name (only for amazonbedrock, googlevertexai backend)")
//add flag for vertexAI Project ID
addCmd.Flags().StringVarP(&providerId, "providerId", "i", "", "Provider Region name (only for googlevertexai backend)")
JuHyung-Son marked this conversation as resolved.
Show resolved Hide resolved
}
1 change: 1 addition & 0 deletions cmd/auth/auth.go
Expand Up @@ -27,6 +27,7 @@ var (
engine string
temperature float32
providerRegion string
providerId string
topP float32
maxTokens int
)
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Expand Up @@ -50,10 +50,12 @@ require (
atomicgo.dev/schedule v0.1.0 // indirect
cloud.google.com/go v0.112.0 // indirect
cloud.google.com/go/ai v0.3.0 // indirect
cloud.google.com/go/aiplatform v1.59.0 // indirect
cloud.google.com/go/compute v1.24.0 // indirect
cloud.google.com/go/compute/metadata v0.2.3 // indirect
cloud.google.com/go/iam v1.1.6 // indirect
cloud.google.com/go/longrunning v0.5.5 // indirect
cloud.google.com/go/vertexai v0.7.1 // indirect
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.2 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2 // indirect
github.com/AzureAD/microsoft-authentication-library-for-go v1.2.1 // indirect
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Expand Up @@ -99,6 +99,8 @@ cloud.google.com/go/aiplatform v1.52.0/go.mod h1:pwZMGvqe0JRkI1GWSZCtnAfrR4K1bv6
cloud.google.com/go/aiplatform v1.54.0/go.mod h1:pwZMGvqe0JRkI1GWSZCtnAfrR4K1bv65IHILGA//VEU=
cloud.google.com/go/aiplatform v1.57.0/go.mod h1:pwZMGvqe0JRkI1GWSZCtnAfrR4K1bv65IHILGA//VEU=
cloud.google.com/go/aiplatform v1.58.0/go.mod h1:pwZMGvqe0JRkI1GWSZCtnAfrR4K1bv65IHILGA//VEU=
cloud.google.com/go/aiplatform v1.59.0 h1:r+P9YStPWrYF52fKyYCQKzTDw4fLiyzLdTEIdxcjmjU=
cloud.google.com/go/aiplatform v1.59.0/go.mod h1:eTlGuHOahHprZw3Hio5VKmtThIOak5/qy6pzdsqcQnM=
cloud.google.com/go/analytics v0.11.0/go.mod h1:DjEWCu41bVbYcKyvlws9Er60YE4a//bK6mnhWvQeFNI=
cloud.google.com/go/analytics v0.12.0/go.mod h1:gkfj9h6XRf9+TS4bmuhPEShsh3hH8PAZzm/41OOhQd4=
cloud.google.com/go/analytics v0.17.0/go.mod h1:WXFa3WSym4IZ+JiKmavYdJwGG/CvpqiqczmL59bTD9M=
Expand Down Expand Up @@ -1126,6 +1128,8 @@ cloud.google.com/go/translate v1.9.1/go.mod h1:TWIgDZknq2+JD4iRcojgeDtqGEp154HN/
cloud.google.com/go/translate v1.9.2/go.mod h1:E3Tc6rUTsQkVrXW6avbUhKJSr7ZE3j7zNmqzXKHqRrY=
cloud.google.com/go/translate v1.9.3/go.mod h1:Kbq9RggWsbqZ9W5YpM94Q1Xv4dshw/gr/SHfsl5yCZ0=
cloud.google.com/go/translate v1.10.0/go.mod h1:Kbq9RggWsbqZ9W5YpM94Q1Xv4dshw/gr/SHfsl5yCZ0=
cloud.google.com/go/vertexai v0.7.1 h1:CSdqsEwjklLIlI1e5SrsnkwG/I+CeJekkBbMTzeYhVg=
cloud.google.com/go/vertexai v0.7.1/go.mod h1:HfnfYR9aPS+qF2436S6Hzuw0Fp+PORjzK3ggqymdzSU=
cloud.google.com/go/video v1.8.0/go.mod h1:sTzKFc0bUSByE8Yoh8X0mn8bMymItVGPfTuUBUyRgxk=
cloud.google.com/go/video v1.9.0/go.mod h1:0RhNKFRF5v92f8dQt0yhaHrEuH95m068JYOvLZYnJSw=
cloud.google.com/go/video v1.12.0/go.mod h1:MLQew95eTuaNDEGriQdcYn0dTwf9oWiA4uYebxM5kdg=
Expand Down
178 changes: 178 additions & 0 deletions pkg/ai/googlevertexai.go
@@ -0,0 +1,178 @@
/*
Copyright 2023 The K8sGPT Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ai

import (
"context"
"errors"
"fmt"

"cloud.google.com/go/vertexai/genai"
"github.com/fatih/color"
)

const googleVertexAIClientName = "googlevertexai"

type GoogleVertexAIClient struct {
client *genai.Client

JuHyung-Son marked this conversation as resolved.
Show resolved Hide resolved
model string
temperature float32
topP float32
maxTokens int
}

// Vertex AI Gemini supported Regions
// https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini
const VERTEXAI_DEFAULT_REGION = "us-central1" // default use us-east-1 region

const (
US_Central_1 = "us-central1"
US_West_4 = "us-west4"
North_America_Northeast1 = "northamerica-northeast1"
US_East_4 = "us-east4"
US_West_1 = "us-west1"
Asia_Northeast_3 = "asia-northeast3"
Asia_Southeast_1 = "asia-southeast1"
Asia_Northeast_1 = "asia-northeast1"
)

var VERTEXAI_SUPPORTED_REGION = []string{
US_Central_1,
US_West_4,
North_America_Northeast1,
US_East_4,
US_West_1,
Asia_Northeast_3,
Asia_Southeast_1,
Asia_Northeast_1,
}

const (
ModelGeminiProV1 = "gemini-1.0-pro-001"
)

var VERTEXAI_MODELS = []string{
ModelGeminiProV1,
}

// GetModelOrDefault check config model
func GetVertexAIModelOrDefault(model string) string {

// Check if the provided model is in the list
for _, m := range VERTEXAI_MODELS {
if m == model {
return model // Return the provided model
}
}

// Return the default model if the provided model is not in the list
return VERTEXAI_MODELS[0]
}

// GetModelOrDefault check config region
func GetVertexAIRegionOrDefault(region string) string {

// Check if the provided model is in the list
for _, m := range VERTEXAI_SUPPORTED_REGION {
if m == region {
return region // Return the provided model
}
}

// Return the default model if the provided model is not in the list
return VERTEXAI_DEFAULT_REGION
}

func (g *GoogleVertexAIClient) Configure(config IAIConfig) error {
ctx := context.Background()

// Currently you can access VertexAI either by being authenticated via OAuth or Bearer token so we need to consider both
projectId := config.GetProviderId()
region := GetVertexAIRegionOrDefault(config.GetProviderRegion())

client, err := genai.NewClient(ctx, projectId, region)
if err != nil {
return fmt.Errorf("creating genai Google SDK client: %w", err)
}

g.client = client
g.model = GetVertexAIModelOrDefault(config.GetModel())
g.temperature = config.GetTemperature()
g.topP = config.GetTopP()
g.maxTokens = config.GetMaxTokens()

return nil
}

func (g *GoogleVertexAIClient) GetCompletion(ctx context.Context, prompt string) (string, error) {

model := g.client.GenerativeModel(g.model)
model.SetTemperature(g.temperature)
model.SetTopP(g.topP)
model.SetMaxOutputTokens(int32(g.maxTokens))

// Google AI SDK is capable of different inputs than just text, for now set explicit text prompt type.
// Similarly, we could stream the response. For now k8sgpt does not support streaming.
resp, err := model.GenerateContent(ctx, genai.Text(prompt))
if err != nil {
return "", err
}

if len(resp.Candidates) == 0 {
if resp.PromptFeedback.BlockReason > 0 {
for _, r := range resp.PromptFeedback.SafetyRatings {
if !r.Blocked {
continue
}
return "", fmt.Errorf("complection blocked due to %v with probability %v", r.Category.String(), r.Probability.String())
}
}
return "", errors.New("no complection returned; unknown reason")
}

// Format output.
// TODO(bwplotka): Provider richer output in certain cases e.g. suddenly finished
// completion based on finish reasons or safety rankings.
got := resp.Candidates[0]
var output string
for _, part := range got.Content.Parts {
switch o := part.(type) {
case genai.Text:
output += string(o)
output += "\n"
default:
color.Yellow("found unsupported AI response part of type %T; ignoring", part)
}
}

if got.CitationMetadata != nil && len(got.CitationMetadata.Citations) > 0 {
output += "Citations:\n"
for _, source := range got.CitationMetadata.Citations {
// TODO(bwplotka): Give details around what exactly words could be attributed to the citation.
output += fmt.Sprintf("* %s, %s\n", source.URI, source.License)
}
}
return output, nil
}

func (g *GoogleVertexAIClient) GetName() string {
return googleVertexAIClientName
}

func (g *GoogleVertexAIClient) Close() {
if err := g.client.Close(); err != nil {
color.Red("googleai client close error: %v", err)
}
}
10 changes: 9 additions & 1 deletion pkg/ai/iai.go
Expand Up @@ -28,6 +28,7 @@ var (
&SageMakerAIClient{},
&GoogleGenAIClient{},
&HuggingfaceClient{},
&GoogleVertexAIClient{},
}
Backends = []string{
openAIClientName,
Expand All @@ -39,6 +40,7 @@ var (
googleAIClientName,
noopAIClientName,
huggingfaceAIClientName,
googleVertexAIClientName,
}
)

Expand Down Expand Up @@ -71,6 +73,7 @@ type IAIConfig interface {
GetProviderRegion() string
GetTopP() float32
GetMaxTokens() int
GetProviderId() string
}

func NewClient(provider string) IAI {
Expand Down Expand Up @@ -99,6 +102,7 @@ type AIProvider struct {
Engine string `mapstructure:"engine" yaml:"engine,omitempty"`
Temperature float32 `mapstructure:"temperature" yaml:"temperature,omitempty"`
ProviderRegion string `mapstructure:"providerregion" yaml:"providerregion,omitempty"`
ProviderId string `mapstructure:"providerid" yaml:"providerid,omitempty"`
TopP float32 `mapstructure:"topp" yaml:"topp,omitempty"`
MaxTokens int `mapstructure:"maxtokens" yaml:"maxtokens,omitempty"`
}
Expand Down Expand Up @@ -142,7 +146,11 @@ func (p *AIProvider) GetProviderRegion() string {
return p.ProviderRegion
}

var passwordlessProviders = []string{"localai", "amazonsagemaker", "amazonbedrock"}
func (p *AIProvider) GetProviderId() string {
return p.ProviderId
}

var passwordlessProviders = []string{"localai", "amazonsagemaker", "amazonbedrock", "googlevertexai"}

func NeedPassword(backend string) bool {
for _, b := range passwordlessProviders {
Expand Down