Skip to content
This repository has been archived by the owner on Oct 23, 2023. It is now read-only.

Support API Key for authentication #384

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
38 changes: 38 additions & 0 deletions clients/go/admin/api_key.go
@@ -0,0 +1,38 @@
package admin

import (
"encoding/base64"
"encoding/json"
)

// APIKey is a struct that represents the API key used to authenticate with the Admin service.
type APIKey struct {
ClientID string `json:"id"`
Secret string `json:"secret"`
}

// Encode encodes the API key into a base64 encoded string.
func (k APIKey) Encode() (string, error) {
raw, err := json.Marshal(k)
if err != nil {
return "", err
}

return base64.StdEncoding.EncodeToString(raw), nil
}

// DecodeAPIKey decodes a base64 encoded API key into a struct.
func DecodeAPIKey(apiKey string) (APIKey, error) {
var k APIKey
raw, err := base64.StdEncoding.DecodeString(apiKey)
if err != nil {
return k, err
}

err = json.Unmarshal(raw, &k)
if err != nil {
return k, err
}

return k, nil
}
59 changes: 59 additions & 0 deletions clients/go/admin/api_key_test.go
@@ -0,0 +1,59 @@
package admin

import (
"fmt"
"github.com/stretchr/testify/assert"
"testing"
)

func TestDecodeAPIKey(t *testing.T) {
tests := []struct {
name string
apiKey string
want APIKey
wantErr assert.ErrorAssertionFunc
}{
{"happy path", "eyJpZCI6IiIsInNlY3JldCI6IiJ9", APIKey{}, assert.NoError},
{"happy path", "eyJpZCI6ImFiYyIsInNlY3JldCI6IiJ9", APIKey{ClientID: "abc"}, assert.NoError},
{"happy path", "eyJpZCI6ImFiYyIsInNlY3JldCI6ImRlZiJ9", APIKey{ClientID: "abc", Secret: "def"}, assert.NoError},
{"happy path", "eyJpZCI6IiMkQCFAYWJjIiwic2VjcmV0IjoiIChcdTAwMjZAIGRlZiAifQ==", APIKey{ClientID: "#$@!@abc", Secret: " (&@ def "}, assert.NoError},
{"invalid base64", "blah bloh==", APIKey{}, assert.Error},
{"invalid json", "aGVsbG8gd29ybGQK", APIKey{}, assert.Error},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := DecodeAPIKey(tt.apiKey)
if !tt.wantErr(t, err, fmt.Sprintf("DecodeAPIKey(%v)", tt.apiKey)) {
return
}
assert.Equalf(t, tt.want, got, "DecodeAPIKey(%v)", tt.apiKey)
})
}
}

func TestAPIKey_Encode(t *testing.T) {
tests := []struct {
name string
apiKey APIKey
want string
wantErr assert.ErrorAssertionFunc
}{
{"happy path", APIKey{}, "eyJpZCI6IiIsInNlY3JldCI6IiJ9", assert.NoError},
{"happy path", APIKey{ClientID: "abc"}, "eyJpZCI6ImFiYyIsInNlY3JldCI6IiJ9", assert.NoError},
{"happy path", APIKey{ClientID: "abc", Secret: "def"}, "eyJpZCI6ImFiYyIsInNlY3JldCI6ImRlZiJ9", assert.NoError},
{"happy path", APIKey{ClientID: "#$@!@abc", Secret: " (&@ def "}, "eyJpZCI6IiMkQCFAYWJjIiwic2VjcmV0IjoiIChcdTAwMjZAIGRlZiAifQ==", assert.NoError},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
k := APIKey{
ClientID: tt.apiKey.ClientID,
Secret: tt.apiKey.Secret,
}
got, err := k.Encode()
if !tt.wantErr(t, err, "Encode()") {
return
}
assert.Equalf(t, tt.want, got, "Encode()")
})
}
}
5 changes: 5 additions & 0 deletions clients/go/admin/config.go
Expand Up @@ -35,6 +35,8 @@ const (
AuthTypeExternalCommand
// AuthTypeDeviceFlow Uses device flow to authenticate in a constrained environment with no access to browser
AuthTypeDeviceFlow
// AuthTypeAPIKey Uses API Key to authenticate
AuthTypeAPIKey
)

type Config struct {
Expand All @@ -52,6 +54,8 @@ type Config struct {
ClientID string `json:"clientId" pflag:",Client ID"`
ClientSecretLocation string `json:"clientSecretLocation" pflag:",File containing the client secret"`
ClientSecretEnvVar string `json:"clientSecretEnvVar" pflag:",Environment variable containing the client secret"`
APIKeyEnvVar string `json:"apiKeyEnvVar" pflag:",Environment variable containing the API Key"`
APIKeyLocation string `json:"apiKeyLocation" pflag:",File containing the API Key"`
Scopes []string `json:"scopes" pflag:",List of scopes to request"`
UseAudienceFromAdmin bool `json:"useAudienceFromAdmin" pflag:",Use Audience configured from admins public endpoint config."`
Audience string `json:"audience" pflag:",Audience to use when initiating OAuth2 authorization requests."`
Expand Down Expand Up @@ -89,6 +93,7 @@ var (
ClientID: DefaultClientID,
AuthType: AuthTypeClientSecret,
ClientSecretLocation: DefaultClientSecretLocation,
APIKeyEnvVar: "FLYTE_API_KEY",
PkceConfig: pkce.Config{
TokenRefreshGracePeriod: config.Duration{Duration: 5 * time.Minute},
BrowserSessionTimeout: config.Duration{Duration: 2 * time.Minute},
Expand Down
116 changes: 87 additions & 29 deletions clients/go/admin/token_source_provider.go
Expand Up @@ -3,7 +3,6 @@ package admin
import (
"context"
"fmt"
"io/ioutil"
"net/url"
"os"
"strings"
Expand Down Expand Up @@ -43,34 +42,31 @@ func NewTokenSourceProvider(ctx context.Context, cfg *Config, tokenCache cache.T

var tokenProvider TokenSourceProvider
var err error

apiKey := os.Getenv(cfg.APIKeyEnvVar)
if len(apiKey) > 0 {
logger.Debugf(ctx, "Using API key from environment variable [%s]", cfg.APIKeyEnvVar)
cfg.AuthType = AuthTypeAPIKey
} else if len(cfg.APIKeyLocation) > 0 {
logger.Debugf(ctx, "Using API Key from file [%s]", cfg.APIKeyLocation)
cfg.AuthType = AuthTypeAPIKey
}

switch cfg.AuthType {
case AuthTypeClientSecret:
tokenURL := cfg.TokenURL
if len(tokenURL) == 0 {
metadata, err := authClient.GetOAuth2Metadata(ctx, &service.OAuth2MetadataRequest{})
if err != nil {
return nil, fmt.Errorf("failed to fetch auth metadata. Error: %v", err)
}

tokenURL = metadata.TokenEndpoint
case AuthTypeAPIKey:
tokenURL, scopes, audienceValue, err := resolveAuthMetadata(ctx, cfg, authClient)
if err != nil {
return nil, err
}

scopes := cfg.Scopes
audienceValue := cfg.Audience

if len(scopes) == 0 || cfg.UseAudienceFromAdmin {
publicClientConfig, err := authClient.GetPublicClientConfig(ctx, &service.PublicClientAuthConfigRequest{})
if err != nil {
return nil, fmt.Errorf("failed to fetch client metadata. Error: %v", err)
}
// Update scopes from publicClientConfig
if len(scopes) == 0 {
scopes = publicClientConfig.Scopes
}
// Update audience from publicClientConfig
if cfg.UseAudienceFromAdmin {
audienceValue = publicClientConfig.Audience
}
tokenProvider, err = NewClientCredentialsTokenSourceProviderFromAPIKey(cfg, scopes, tokenURL, tokenCache, audienceValue)
if err != nil {
return nil, err
}
case AuthTypeClientSecret:
tokenURL, scopes, audienceValue, err := resolveAuthMetadata(ctx, cfg, authClient)
if err != nil {
return nil, err
}

tokenProvider, err = NewClientCredentialsTokenSourceProvider(ctx, cfg, scopes, tokenURL, tokenCache, audienceValue)
Expand Down Expand Up @@ -109,6 +105,37 @@ func NewTokenSourceProvider(ctx context.Context, cfg *Config, tokenCache cache.T
return tokenProvider, nil
}

func resolveAuthMetadata(ctx context.Context, cfg *Config, authClient service.AuthMetadataServiceClient) (string, []string, string, error) {
tokenURL := cfg.TokenURL
if len(tokenURL) == 0 {
metadata, err := authClient.GetOAuth2Metadata(ctx, &service.OAuth2MetadataRequest{})
if err != nil {
return "", nil, "", fmt.Errorf("failed to fetch auth metadata. Error: %v", err)
}

tokenURL = metadata.TokenEndpoint
}

scopes := cfg.Scopes
audienceValue := cfg.Audience

if len(scopes) == 0 || cfg.UseAudienceFromAdmin {
publicClientConfig, err := authClient.GetPublicClientConfig(ctx, &service.PublicClientAuthConfigRequest{})
if err != nil {
return "", nil, "", fmt.Errorf("failed to fetch client metadata. Error: %v", err)
}
// Update scopes from publicClientConfig
if len(scopes) == 0 {
scopes = publicClientConfig.Scopes
}
// Update audience from publicClientConfig
if cfg.UseAudienceFromAdmin {
audienceValue = publicClientConfig.Audience
}
}
return tokenURL, scopes, audienceValue, nil
}

type ExternalTokenSourceProvider struct {
command []string
}
Expand Down Expand Up @@ -172,19 +199,49 @@ type ClientCredentialsTokenSourceProvider struct {
tokenCache cache.TokenCache
}

func NewClientCredentialsTokenSourceProviderFromAPIKey(cfg *Config, scopes []string, tokenURL string,
tokenCache cache.TokenCache, audience string) (TokenSourceProvider, error) {
apiKey := os.Getenv(cfg.APIKeyEnvVar)
if len(apiKey) == 0 && len(cfg.APIKeyLocation) == 0 {
return nil, fmt.Errorf("API key is empty at Env Var [%v]", cfg.APIKeyEnvVar)
}

if len(apiKey) == 0 {
raw, err := os.ReadFile(cfg.APIKeyLocation)
if err != nil {
return nil, fmt.Errorf("failed to read API Key from [%v]. Error: %w", cfg.APIKeyLocation, err)
}

apiKey = string(raw)
}

decoded, err := DecodeAPIKey(apiKey)
if err != nil {
return nil, fmt.Errorf("failed to decode API key. Error: %w", err)
}

return newClientCredentialsTokenSourceProvider(cfg, decoded.ClientID, decoded.Secret, scopes, tokenURL, tokenCache, audience)
}

func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, scopes []string, tokenURL string,
tokenCache cache.TokenCache, audience string) (TokenSourceProvider, error) {
var secret string
if len(cfg.ClientSecretEnvVar) > 0 {
secret = os.Getenv(cfg.ClientSecretEnvVar)
} else if len(cfg.ClientSecretLocation) > 0 {
secretBytes, err := ioutil.ReadFile(cfg.ClientSecretLocation)
secretBytes, err := os.ReadFile(cfg.ClientSecretLocation)
if err != nil {
logger.Errorf(ctx, "Error reading secret from location %s", cfg.ClientSecretLocation)
return nil, err
}
secret = string(secretBytes)
}

return newClientCredentialsTokenSourceProvider(cfg, cfg.ClientID, secret, scopes, tokenURL, tokenCache, audience)
}

func newClientCredentialsTokenSourceProvider(cfg *Config, clientID, secret string, scopes []string, tokenURL string,
tokenCache cache.TokenCache, audience string) (TokenSourceProvider, error) {
endpointParams := url.Values{}
if len(audience) > 0 {
endpointParams = url.Values{audienceKey: {audience}}
Expand All @@ -195,14 +252,15 @@ func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, s
}
return ClientCredentialsTokenSourceProvider{
ccConfig: clientcredentials.Config{
ClientID: cfg.ClientID,
ClientID: clientID,
ClientSecret: secret,
TokenURL: tokenURL,
Scopes: scopes,
EndpointParams: endpointParams,
},
tokenRefreshWindow: cfg.TokenRefreshWindow.Duration,
tokenCache: tokenCache}, nil
tokenCache: tokenCache,
}, nil
}

func (p ClientCredentialsTokenSourceProvider) GetTokenSource(ctx context.Context) (oauth2.TokenSource, error) {
Expand Down