Skip to content
This repository has been archived by the owner on Oct 23, 2023. It is now read-only.

Commit

Permalink
Cache tokens obtained from external commands
Browse files Browse the repository at this point in the history
Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>
  • Loading branch information
Fabio Grätz committed Sep 16, 2023
1 parent 35b7327 commit 3b7389f
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 39 deletions.
36 changes: 21 additions & 15 deletions clients/go/admin/auth_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ const ProxyAuthorizationHeader = "proxy-authorization"

// MaterializeCredentials will attempt to build a TokenSource given the anonymously available information exposed by the server.
// Once established, it'll invoke PerRPCCredentialsFuture.Store() on perRPCCredentials to populate it with the appropriate values.
func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.TokenCache, perRPCCredentials *PerRPCCredentialsFuture) error {
authMetadataClient, err := InitializeAuthMetadataClient(ctx, cfg)
func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.TokenCache, proxyTokenCache cache.TokenCache, perRPCCredentials *PerRPCCredentialsFuture) error {
authMetadataClient, err := InitializeAuthMetadataClient(ctx, cfg, proxyTokenCache)
if err != nil {
return fmt.Errorf("failed to initialized Auth Metadata Client. Error: %w", err)
}
Expand Down Expand Up @@ -51,8 +51,8 @@ func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.T
return nil
}

func GetProxyTokenSource(ctx context.Context, cfg *Config) (oauth2.TokenSource, error) {
tokenSourceProvider, err := NewExternalTokenSourceProvider(cfg.ProxyCommand)
func GetProxyTokenSource(ctx context.Context, cfg *Config, proxyTokenCache cache.TokenCache) (oauth2.TokenSource, error) {
tokenSourceProvider, err := NewExternalTokenSourceProvider(cfg.ProxyCommand, proxyTokenCache)
if err != nil {
return nil, fmt.Errorf("failed to initialized proxy authorization token source provider. Err: %w", err)
}
Expand All @@ -63,8 +63,8 @@ func GetProxyTokenSource(ctx context.Context, cfg *Config) (oauth2.TokenSource,
return proxyTokenSource, nil
}

func MaterializeProxyAuthCredentials(ctx context.Context, cfg *Config) (context.Context, error) {
proxyTokenSource, err := GetProxyTokenSource(ctx, cfg)
func MaterializeProxyAuthCredentials(ctx context.Context, cfg *Config, proxyTokenCache cache.TokenCache) (context.Context, error) {
proxyTokenSource, err := GetProxyTokenSource(ctx, cfg, proxyTokenCache)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -98,7 +98,7 @@ func (c *proxyAuthTransport) RoundTrip(req *http.Request) (*http.Response, error
}

// Set up http client used in oauth2
func setHTTPClientContext(ctx context.Context, cfg *Config) context.Context {
func setHTTPClientContext(ctx context.Context, cfg *Config, proxyTokenCache cache.TokenCache) (context.Context, error) {
httpClient := &http.Client{}
transport := &http.Transport{}

Expand All @@ -108,7 +108,10 @@ func setHTTPClientContext(ctx context.Context, cfg *Config) context.Context {
}

if cfg.ProxyCommand != nil {
proxyTokenSource, _ := GetProxyTokenSource(ctx, cfg)
proxyTokenSource, err := GetProxyTokenSource(ctx, cfg, proxyTokenCache)
if err != nil {
return nil, err
}

httpClient.Transport = &proxyAuthTransport{
transport: transport,
Expand All @@ -118,7 +121,7 @@ func setHTTPClientContext(ctx context.Context, cfg *Config) context.Context {
httpClient.Transport = transport
}

return context.WithValue(ctx, oauth2.HTTPClient, httpClient)
return context.WithValue(ctx, oauth2.HTTPClient, httpClient), nil
}

// NewAuthInterceptor creates a new grpc.UnaryClientInterceptor that forwards the grpc call and inspects the error.
Expand All @@ -131,19 +134,22 @@ func setHTTPClientContext(ctx context.Context, cfg *Config) context.Context {
// more. It'll fail hard if it couldn't do so (i.e. it will no longer attempt to send an unauthenticated request). Once
// a token source has been created, it'll invoke the grpc pipeline again, this time the grpc.PerRPCCredentials should
// be able to find and acquire a valid AccessToken to annotate the request with.
func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFuture *PerRPCCredentialsFuture) grpc.UnaryClientInterceptor {
func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, proxyTokenCache cache.TokenCache, credentialsFuture *PerRPCCredentialsFuture) grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
ctx = setHTTPClientContext(ctx, cfg)
ctx, err := setHTTPClientContext(ctx, cfg, proxyTokenCache)
if err != nil {
return err
}

err := invoker(ctx, method, req, reply, cc, opts...)
err = invoker(ctx, method, req, reply, cc, opts...)
if err != nil {
logger.Debugf(ctx, "Request failed due to [%v]. If it's an unauthenticated error, we will attempt to establish an authenticated context.", err)

if st, ok := status.FromError(err); ok {
// If the error we receive from executing the request expects
if shouldAttemptToAuthenticate(st.Code()) {
logger.Debugf(ctx, "Request failed due to [%v]. Attempting to establish an authenticated connection and trying again.", st.Code())
newErr := MaterializeCredentials(ctx, cfg, tokenCache, credentialsFuture)
newErr := MaterializeCredentials(ctx, cfg, tokenCache, proxyTokenCache, credentialsFuture)
if newErr != nil {
return fmt.Errorf("authentication error! Original Error: %v, Auth Error: %w", err, newErr)
}
Expand All @@ -157,9 +163,9 @@ func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFut
}
}

func NewProxyAuthInterceptor(cfg *Config) grpc.UnaryClientInterceptor {
func NewProxyAuthInterceptor(cfg *Config, proxyTokenCache cache.TokenCache) grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
ctx, err := MaterializeProxyAuthCredentials(ctx, cfg)
ctx, err := MaterializeProxyAuthCredentials(ctx, cfg, proxyTokenCache)
if err != nil {
return fmt.Errorf("proxy authorization error! Original Error: %v", err)
}
Expand Down
18 changes: 9 additions & 9 deletions clients/go/admin/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,17 +110,17 @@ func getAuthenticationDialOption(ctx context.Context, cfg *Config, tokenSourcePr
}

// InitializeAuthMetadataClient creates a new anonymously Auth Metadata Service client.
func InitializeAuthMetadataClient(ctx context.Context, cfg *Config) (client service.AuthMetadataServiceClient, err error) {
func InitializeAuthMetadataClient(ctx context.Context, cfg *Config, proxyTokenCache cache.TokenCache) (client service.AuthMetadataServiceClient, err error) {
// Create an unauthenticated connection to fetch AuthMetadata
authMetadataConnection, err := NewAdminConnection(ctx, cfg)
authMetadataConnection, err := NewAdminConnection(ctx, cfg, proxyTokenCache)
if err != nil {
return nil, fmt.Errorf("failed to initialized admin connection. Error: %w", err)
}

return service.NewAuthMetadataServiceClient(authMetadataConnection), nil
}

func NewAdminConnection(ctx context.Context, cfg *Config, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
func NewAdminConnection(ctx context.Context, cfg *Config, proxyTokenCache cache.TokenCache, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
if opts == nil {
// Initialize opts list to the potential number of options we will add. Initialization optimizes memory
// allocation.
Expand Down Expand Up @@ -155,7 +155,7 @@ func NewAdminConnection(ctx context.Context, cfg *Config, opts ...grpc.DialOptio

// Ensure proxy auth interceptor is invoked prior to auth interceptor
if cfg.ProxyCommand != nil {
opts = append([]grpc.DialOption{grpc.WithChainUnaryInterceptor(NewProxyAuthInterceptor(cfg))}, opts...)
opts = append([]grpc.DialOption{grpc.WithChainUnaryInterceptor(NewProxyAuthInterceptor(cfg, proxyTokenCache))}, opts...)
}

return grpc.Dial(cfg.Endpoint.String(), opts...)
Expand All @@ -164,7 +164,7 @@ func NewAdminConnection(ctx context.Context, cfg *Config, opts ...grpc.DialOptio
// InitializeAdminClient creates an AdminClient with a shared Admin connection for the process
// Deprecated: Please use initializeClients instead.
func InitializeAdminClient(ctx context.Context, cfg *Config, opts ...grpc.DialOption) service.AdminServiceClient {
set, err := initializeClients(ctx, cfg, nil, opts...)
set, err := initializeClients(ctx, cfg, nil, nil, opts...)
if err != nil {
logger.Panicf(ctx, "Failed to initialized client. Error: %v", err)
return nil
Expand All @@ -175,18 +175,18 @@ func InitializeAdminClient(ctx context.Context, cfg *Config, opts ...grpc.DialOp

// initializeClients creates an AdminClient, AuthServiceClient and IdentityServiceClient with a shared Admin connection
// for the process. Note that if called with different cfg/dialoptions, it will not refresh the connection.
func initializeClients(ctx context.Context, cfg *Config, tokenCache cache.TokenCache, opts ...grpc.DialOption) (*Clientset, error) {
func initializeClients(ctx context.Context, cfg *Config, tokenCache cache.TokenCache, proxyTokenCache cache.TokenCache, opts ...grpc.DialOption) (*Clientset, error) {
credentialsFuture := NewPerRPCCredentialsFuture()

opts = append(opts,
grpc.WithChainUnaryInterceptor(NewAuthInterceptor(cfg, tokenCache, credentialsFuture)),
grpc.WithChainUnaryInterceptor(NewAuthInterceptor(cfg, tokenCache, proxyTokenCache, credentialsFuture)),
grpc.WithPerRPCCredentials(credentialsFuture))

if cfg.DefaultServiceConfig != "" {
opts = append(opts, grpc.WithDefaultServiceConfig(cfg.DefaultServiceConfig))
}

adminConnection, err := NewAdminConnection(ctx, cfg, opts...)
adminConnection, err := NewAdminConnection(ctx, cfg, proxyTokenCache, opts...)
if err != nil {
logger.Panicf(ctx, "failed to initialized Admin connection. Err: %s", err.Error())
}
Expand All @@ -204,7 +204,7 @@ func initializeClients(ctx context.Context, cfg *Config, tokenCache cache.TokenC

// Deprecated: Please use NewClientsetBuilder() instead.
func InitializeAdminClientFromConfig(ctx context.Context, tokenCache cache.TokenCache, opts ...grpc.DialOption) (service.AdminServiceClient, error) {
clientSet, err := initializeClients(ctx, GetConfig(ctx), tokenCache, opts...)
clientSet, err := initializeClients(ctx, GetConfig(ctx), tokenCache, nil, opts...)
if err != nil {
return nil, err
}
Expand Down
19 changes: 15 additions & 4 deletions clients/go/admin/client_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@ import (

// ClientsetBuilder is used to build the clientset. This allows custom token cache implementations to be plugged in.
type ClientsetBuilder struct {
config *Config
tokenCache cache.TokenCache
opts []grpc.DialOption
config *Config
tokenCache cache.TokenCache
proxyTokenCache cache.TokenCache
opts []grpc.DialOption
}

// ClientSetBuilder is constructor function to be used by the clients in interacting with the builder
Expand All @@ -32,6 +33,13 @@ func (cb *ClientsetBuilder) WithTokenCache(tokenCache cache.TokenCache) *Clients
return cb
}

// TokenCache is designed to cache a single token. When clients choose to send `"proxy-authorization`"
// headers, we, thus, employ a separate token cache.
func (cb *ClientsetBuilder) WithProxyTokenCache(tokenCache cache.TokenCache) *ClientsetBuilder {
cb.proxyTokenCache = tokenCache
return cb
}

func (cb *ClientsetBuilder) WithDialOptions(opts ...grpc.DialOption) *ClientsetBuilder {
cb.opts = opts
return cb
Expand All @@ -42,12 +50,15 @@ func (cb *ClientsetBuilder) Build(ctx context.Context) (*Clientset, error) {
if cb.tokenCache == nil {
cb.tokenCache = &cache.TokenCacheInMemoryProvider{}
}
if cb.proxyTokenCache == nil {
cb.proxyTokenCache = &cache.TokenCacheInMemoryProvider{}
}

if cb.config == nil {
cb.config = GetConfig(ctx)
}

return initializeClients(ctx, cb.config, cb.tokenCache, cb.opts...)
return initializeClients(ctx, cb.config, cb.tokenCache, cb.proxyTokenCache, cb.opts...)
}

func NewClientsetBuilder() *ClientsetBuilder {
Expand Down
37 changes: 37 additions & 0 deletions clients/go/admin/externalprocess/token.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package externalprocess

import (
"encoding/base64"
"encoding/json"
"errors"
"strings"
"time"
)

type jwtClaims struct {
Exp int64 `json:"exp"`
}

// When we receive a token from an external process, we don't have a key to validate it
// For caching purposes, we still want to know when the token *allegedly* expires.
func GetUnvalidatedTokenExpiration(tokenString string) (time.Time, error) {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
return time.Time{}, errors.New("invalid token")
}

payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return time.Time{}, err
}

var claims jwtClaims
err = json.Unmarshal(payload, &claims)
if err != nil {
return time.Time{}, err
}

expiry := time.Unix(claims.Exp, 0)

return expiry, nil
}
24 changes: 24 additions & 0 deletions clients/go/admin/externalprocess/token_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package externalprocess

import (
"encoding/base64"
"encoding/json"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestGetUnvalidatedTokenExpiration(t *testing.T) {

exp := time.Now().Unix()
claims := jwtClaims{Exp: exp}
payload, _ := json.Marshal(claims)
token := "header." + base64.RawURLEncoding.EncodeToString(payload) + ".signature"

expiry, err := GetUnvalidatedTokenExpiration(token)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
assert.Equal(t, exp, expiry.Unix())
}
48 changes: 37 additions & 11 deletions clients/go/admin/token_source_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"os"
"strings"
"sync"
"time"

"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"
Expand Down Expand Up @@ -86,7 +87,7 @@ func NewTokenSourceProvider(ctx context.Context, cfg *Config, tokenCache cache.T
return nil, err
}
case AuthTypeExternalCommand:
tokenProvider, err = NewExternalTokenSourceProvider(cfg.Command)
tokenProvider, err = NewExternalTokenSourceProvider(cfg.Command, tokenCache)
if err != nil {
return nil, err
}
Expand All @@ -107,24 +108,51 @@ func NewTokenSourceProvider(ctx context.Context, cfg *Config, tokenCache cache.T
return tokenProvider, nil
}

type ExternalTokenSourceProvider struct {
type ExternalTokenSource struct {
command []string
}

func NewExternalTokenSourceProvider(command []string) (TokenSourceProvider, error) {
return &ExternalTokenSourceProvider{command: command}, nil
func NewExternalTokenSource(command []string) oauth2.TokenSource {
return &ExternalTokenSource{command: command}
}

func (e ExternalTokenSourceProvider) GetTokenSource(ctx context.Context) (oauth2.TokenSource, error) {
func (e *ExternalTokenSource) Token() (*oauth2.Token, error) {
output, err := externalprocess.Execute(e.command)
if err != nil {
return nil, err
}

return oauth2.StaticTokenSource(&oauth2.Token{
AccessToken: strings.Trim(string(output), "\t \n"),
token := strings.Trim(string(output), "\t \n")
exp, err := externalprocess.GetUnvalidatedTokenExpiration(token)
if err != nil {
// If we cannot extract an expiration, as a precaution, we do not
// want to cache the token as otherwise the external command would
// never be called again. Note that `exp = time.Time{}` would cause
// the token to be considered valid forever.
exp = time.Unix(0, 0)
}
return &oauth2.Token{
AccessToken: token,
TokenType: "bearer",
}), nil
Expiry: exp,
}, nil
}

type ExternalTokenSourceProvider struct {
command []string
tokenCache cache.TokenCache
}

func NewExternalTokenSourceProvider(command []string, tokenCache cache.TokenCache) (TokenSourceProvider, error) {
return &ExternalTokenSourceProvider{command: command, tokenCache: tokenCache}, nil
}

func (e *ExternalTokenSourceProvider) GetTokenSource(ctx context.Context) (oauth2.TokenSource, error) {
return &customTokenSource{
ctx: ctx,
new: NewExternalTokenSource(e.command),
mu: sync.Mutex{},
tokenCache: e.tokenCache,
}, nil
}

type PKCETokenSourceProvider struct {
Expand All @@ -151,8 +179,6 @@ func GetPKCEAuthTokenSource(ctx context.Context, pkceTokenOrchestrator pkce.Toke
logger.Warnf(ctx, "Failed fetching from cache. Will restart the flow. Error: %v", err)
}

authToken = nil

if authToken == nil {
// Fetch using auth flow
if authToken, err = pkceTokenOrchestrator.FetchTokenFromAuthFlow(ctx); err != nil {
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ require (
github.com/antihax/optional v1.0.0
github.com/flyteorg/flytestdlib v1.0.0
github.com/go-test/deep v1.0.7
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b
github.com/golang/protobuf v1.4.3
github.com/grpc-ecosystem/go-grpc-middleware v1.1.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,8 @@ github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7a
github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4=
github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
Expand Down

0 comments on commit 3b7389f

Please sign in to comment.