diff --git a/clients/go/admin/api_key.go b/clients/go/admin/api_key.go new file mode 100644 index 000000000..1beb28ee4 --- /dev/null +++ b/clients/go/admin/api_key.go @@ -0,0 +1,38 @@ +package admin + +import ( + "encoding/base64" + "encoding/json" +) + +// APIKey is a struct that represents the API key used to authenticate with the Admin service. +type APIKey struct { + ClientID string `json:"id"` + Secret string `json:"secret"` +} + +// Encode encodes the API key into a base64 encoded string. +func (k APIKey) Encode() (string, error) { + raw, err := json.Marshal(k) + if err != nil { + return "", err + } + + return base64.StdEncoding.EncodeToString(raw), nil +} + +// DecodeAPIKey decodes a base64 encoded API key into a struct. +func DecodeAPIKey(apiKey string) (APIKey, error) { + var k APIKey + raw, err := base64.StdEncoding.DecodeString(apiKey) + if err != nil { + return k, err + } + + err = json.Unmarshal(raw, &k) + if err != nil { + return k, err + } + + return k, nil +} diff --git a/clients/go/admin/api_key_test.go b/clients/go/admin/api_key_test.go new file mode 100644 index 000000000..7ac356897 --- /dev/null +++ b/clients/go/admin/api_key_test.go @@ -0,0 +1,59 @@ +package admin + +import ( + "fmt" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestDecodeAPIKey(t *testing.T) { + tests := []struct { + name string + apiKey string + want APIKey + wantErr assert.ErrorAssertionFunc + }{ + {"happy path", "eyJpZCI6IiIsInNlY3JldCI6IiJ9", APIKey{}, assert.NoError}, + {"happy path", "eyJpZCI6ImFiYyIsInNlY3JldCI6IiJ9", APIKey{ClientID: "abc"}, assert.NoError}, + {"happy path", "eyJpZCI6ImFiYyIsInNlY3JldCI6ImRlZiJ9", APIKey{ClientID: "abc", Secret: "def"}, assert.NoError}, + {"happy path", "eyJpZCI6IiMkQCFAYWJjIiwic2VjcmV0IjoiIChcdTAwMjZAIGRlZiAifQ==", APIKey{ClientID: "#$@!@abc", Secret: " (&@ def "}, assert.NoError}, + {"invalid base64", "blah bloh==", APIKey{}, assert.Error}, + {"invalid json", "aGVsbG8gd29ybGQK", APIKey{}, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := DecodeAPIKey(tt.apiKey) + if !tt.wantErr(t, err, fmt.Sprintf("DecodeAPIKey(%v)", tt.apiKey)) { + return + } + assert.Equalf(t, tt.want, got, "DecodeAPIKey(%v)", tt.apiKey) + }) + } +} + +func TestAPIKey_Encode(t *testing.T) { + tests := []struct { + name string + apiKey APIKey + want string + wantErr assert.ErrorAssertionFunc + }{ + {"happy path", APIKey{}, "eyJpZCI6IiIsInNlY3JldCI6IiJ9", assert.NoError}, + {"happy path", APIKey{ClientID: "abc"}, "eyJpZCI6ImFiYyIsInNlY3JldCI6IiJ9", assert.NoError}, + {"happy path", APIKey{ClientID: "abc", Secret: "def"}, "eyJpZCI6ImFiYyIsInNlY3JldCI6ImRlZiJ9", assert.NoError}, + {"happy path", APIKey{ClientID: "#$@!@abc", Secret: " (&@ def "}, "eyJpZCI6IiMkQCFAYWJjIiwic2VjcmV0IjoiIChcdTAwMjZAIGRlZiAifQ==", assert.NoError}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := APIKey{ + ClientID: tt.apiKey.ClientID, + Secret: tt.apiKey.Secret, + } + got, err := k.Encode() + if !tt.wantErr(t, err, "Encode()") { + return + } + assert.Equalf(t, tt.want, got, "Encode()") + }) + } +} diff --git a/clients/go/admin/config.go b/clients/go/admin/config.go index dd0652606..29eb85b3f 100644 --- a/clients/go/admin/config.go +++ b/clients/go/admin/config.go @@ -35,6 +35,8 @@ const ( AuthTypeExternalCommand // AuthTypeDeviceFlow Uses device flow to authenticate in a constrained environment with no access to browser AuthTypeDeviceFlow + // AuthTypeAPIKey Uses API Key to authenticate + AuthTypeAPIKey ) type Config struct { @@ -52,6 +54,8 @@ type Config struct { ClientID string `json:"clientId" pflag:",Client ID"` ClientSecretLocation string `json:"clientSecretLocation" pflag:",File containing the client secret"` ClientSecretEnvVar string `json:"clientSecretEnvVar" pflag:",Environment variable containing the client secret"` + APIKeyEnvVar string `json:"apiKeyEnvVar" pflag:",Environment variable containing the API Key"` + APIKeyLocation string `json:"apiKeyLocation" pflag:",File containing the API Key"` Scopes []string `json:"scopes" pflag:",List of scopes to request"` UseAudienceFromAdmin bool `json:"useAudienceFromAdmin" pflag:",Use Audience configured from admins public endpoint config."` Audience string `json:"audience" pflag:",Audience to use when initiating OAuth2 authorization requests."` @@ -89,6 +93,7 @@ var ( ClientID: DefaultClientID, AuthType: AuthTypeClientSecret, ClientSecretLocation: DefaultClientSecretLocation, + APIKeyEnvVar: "FLYTE_API_KEY", PkceConfig: pkce.Config{ TokenRefreshGracePeriod: config.Duration{Duration: 5 * time.Minute}, BrowserSessionTimeout: config.Duration{Duration: 2 * time.Minute}, diff --git a/clients/go/admin/token_source_provider.go b/clients/go/admin/token_source_provider.go index 41db678f6..becfb9b18 100644 --- a/clients/go/admin/token_source_provider.go +++ b/clients/go/admin/token_source_provider.go @@ -3,7 +3,6 @@ package admin import ( "context" "fmt" - "io/ioutil" "net/url" "os" "strings" @@ -43,34 +42,31 @@ func NewTokenSourceProvider(ctx context.Context, cfg *Config, tokenCache cache.T var tokenProvider TokenSourceProvider var err error + + apiKey := os.Getenv(cfg.APIKeyEnvVar) + if len(apiKey) > 0 { + logger.Debugf(ctx, "Using API key from environment variable [%s]", cfg.APIKeyEnvVar) + cfg.AuthType = AuthTypeAPIKey + } else if len(cfg.APIKeyLocation) > 0 { + logger.Debugf(ctx, "Using API Key from file [%s]", cfg.APIKeyLocation) + cfg.AuthType = AuthTypeAPIKey + } + switch cfg.AuthType { - case AuthTypeClientSecret: - tokenURL := cfg.TokenURL - if len(tokenURL) == 0 { - metadata, err := authClient.GetOAuth2Metadata(ctx, &service.OAuth2MetadataRequest{}) - if err != nil { - return nil, fmt.Errorf("failed to fetch auth metadata. Error: %v", err) - } - - tokenURL = metadata.TokenEndpoint + case AuthTypeAPIKey: + tokenURL, scopes, audienceValue, err := resolveAuthMetadata(ctx, cfg, authClient) + if err != nil { + return nil, err } - scopes := cfg.Scopes - audienceValue := cfg.Audience - - if len(scopes) == 0 || cfg.UseAudienceFromAdmin { - publicClientConfig, err := authClient.GetPublicClientConfig(ctx, &service.PublicClientAuthConfigRequest{}) - if err != nil { - return nil, fmt.Errorf("failed to fetch client metadata. Error: %v", err) - } - // Update scopes from publicClientConfig - if len(scopes) == 0 { - scopes = publicClientConfig.Scopes - } - // Update audience from publicClientConfig - if cfg.UseAudienceFromAdmin { - audienceValue = publicClientConfig.Audience - } + tokenProvider, err = NewClientCredentialsTokenSourceProviderFromAPIKey(cfg, scopes, tokenURL, tokenCache, audienceValue) + if err != nil { + return nil, err + } + case AuthTypeClientSecret: + tokenURL, scopes, audienceValue, err := resolveAuthMetadata(ctx, cfg, authClient) + if err != nil { + return nil, err } tokenProvider, err = NewClientCredentialsTokenSourceProvider(ctx, cfg, scopes, tokenURL, tokenCache, audienceValue) @@ -109,6 +105,37 @@ func NewTokenSourceProvider(ctx context.Context, cfg *Config, tokenCache cache.T return tokenProvider, nil } +func resolveAuthMetadata(ctx context.Context, cfg *Config, authClient service.AuthMetadataServiceClient) (string, []string, string, error) { + tokenURL := cfg.TokenURL + if len(tokenURL) == 0 { + metadata, err := authClient.GetOAuth2Metadata(ctx, &service.OAuth2MetadataRequest{}) + if err != nil { + return "", nil, "", fmt.Errorf("failed to fetch auth metadata. Error: %v", err) + } + + tokenURL = metadata.TokenEndpoint + } + + scopes := cfg.Scopes + audienceValue := cfg.Audience + + if len(scopes) == 0 || cfg.UseAudienceFromAdmin { + publicClientConfig, err := authClient.GetPublicClientConfig(ctx, &service.PublicClientAuthConfigRequest{}) + if err != nil { + return "", nil, "", fmt.Errorf("failed to fetch client metadata. Error: %v", err) + } + // Update scopes from publicClientConfig + if len(scopes) == 0 { + scopes = publicClientConfig.Scopes + } + // Update audience from publicClientConfig + if cfg.UseAudienceFromAdmin { + audienceValue = publicClientConfig.Audience + } + } + return tokenURL, scopes, audienceValue, nil +} + type ExternalTokenSourceProvider struct { command []string } @@ -172,19 +199,49 @@ type ClientCredentialsTokenSourceProvider struct { tokenCache cache.TokenCache } +func NewClientCredentialsTokenSourceProviderFromAPIKey(cfg *Config, scopes []string, tokenURL string, + tokenCache cache.TokenCache, audience string) (TokenSourceProvider, error) { + apiKey := os.Getenv(cfg.APIKeyEnvVar) + if len(apiKey) == 0 && len(cfg.APIKeyLocation) == 0 { + return nil, fmt.Errorf("API key is empty at Env Var [%v]", cfg.APIKeyEnvVar) + } + + if len(apiKey) == 0 { + raw, err := os.ReadFile(cfg.APIKeyLocation) + if err != nil { + return nil, fmt.Errorf("failed to read API Key from [%v]. Error: %w", cfg.APIKeyLocation, err) + } + + apiKey = string(raw) + } + + decoded, err := DecodeAPIKey(apiKey) + if err != nil { + return nil, fmt.Errorf("failed to decode API key. Error: %w", err) + } + + return newClientCredentialsTokenSourceProvider(cfg, decoded.ClientID, decoded.Secret, scopes, tokenURL, tokenCache, audience) +} + func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, scopes []string, tokenURL string, tokenCache cache.TokenCache, audience string) (TokenSourceProvider, error) { var secret string if len(cfg.ClientSecretEnvVar) > 0 { secret = os.Getenv(cfg.ClientSecretEnvVar) } else if len(cfg.ClientSecretLocation) > 0 { - secretBytes, err := ioutil.ReadFile(cfg.ClientSecretLocation) + secretBytes, err := os.ReadFile(cfg.ClientSecretLocation) if err != nil { logger.Errorf(ctx, "Error reading secret from location %s", cfg.ClientSecretLocation) return nil, err } secret = string(secretBytes) } + + return newClientCredentialsTokenSourceProvider(cfg, cfg.ClientID, secret, scopes, tokenURL, tokenCache, audience) +} + +func newClientCredentialsTokenSourceProvider(cfg *Config, clientID, secret string, scopes []string, tokenURL string, + tokenCache cache.TokenCache, audience string) (TokenSourceProvider, error) { endpointParams := url.Values{} if len(audience) > 0 { endpointParams = url.Values{audienceKey: {audience}} @@ -195,14 +252,15 @@ func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, s } return ClientCredentialsTokenSourceProvider{ ccConfig: clientcredentials.Config{ - ClientID: cfg.ClientID, + ClientID: clientID, ClientSecret: secret, TokenURL: tokenURL, Scopes: scopes, EndpointParams: endpointParams, }, tokenRefreshWindow: cfg.TokenRefreshWindow.Duration, - tokenCache: tokenCache}, nil + tokenCache: tokenCache, + }, nil } func (p ClientCredentialsTokenSourceProvider) GetTokenSource(ctx context.Context) (oauth2.TokenSource, error) {