diff --git a/internal/settings.go b/internal/settings.go index 89c7bc86fa3..278e2b34cc3 100644 --- a/internal/settings.go +++ b/internal/settings.go @@ -19,32 +19,33 @@ import ( // DialSettings holds information needed to establish a connection with a // Google API service. type DialSettings struct { - Endpoint string - DefaultEndpoint string - DefaultMTLSEndpoint string - Scopes []string - DefaultScopes []string - EnableJwtWithScope bool - TokenSource oauth2.TokenSource - Credentials *google.Credentials - CredentialsFile string // if set, Token Source is ignored. - CredentialsJSON []byte - UserAgent string - APIKey string - Audiences []string - DefaultAudience string - HTTPClient *http.Client - GRPCDialOpts []grpc.DialOption - GRPCConn *grpc.ClientConn - GRPCConnPool ConnPool - GRPCConnPoolSize int - NoAuth bool - TelemetryDisabled bool - ClientCertSource func(*tls.CertificateRequestInfo) (*tls.Certificate, error) - CustomClaims map[string]interface{} - SkipValidation bool - ImpersonationConfig *impersonate.Config - EnableDirectPath bool + Endpoint string + DefaultEndpoint string + DefaultMTLSEndpoint string + Scopes []string + DefaultScopes []string + EnableJwtWithScope bool + TokenSource oauth2.TokenSource + Credentials *google.Credentials + CredentialsFile string // if set, Token Source is ignored. + CredentialsJSON []byte + UserAgent string + APIKey string + Audiences []string + DefaultAudience string + HTTPClient *http.Client + GRPCDialOpts []grpc.DialOption + GRPCConn *grpc.ClientConn + GRPCConnPool ConnPool + GRPCConnPoolSize int + NoAuth bool + TelemetryDisabled bool + ClientCertSource func(*tls.CertificateRequestInfo) (*tls.Certificate, error) + CustomClaims map[string]interface{} + SkipValidation bool + ImpersonationConfig *impersonate.Config + EnableDirectPath bool + AllowNonDefaultServiceAccount bool // Google API system parameters. For more information please read: // https://cloud.google.com/apis/docs/system-parameters diff --git a/option/internaloption/internaloption.go b/option/internaloption/internaloption.go index ed0b7aaf13e..4c33916165a 100644 --- a/option/internaloption/internaloption.go +++ b/option/internaloption/internaloption.go @@ -66,6 +66,21 @@ func (e enableDirectPath) Apply(o *internal.DialSettings) { o.EnableDirectPath = bool(e) } +// AllowNonDefaultServiceAccount returns a ClientOption that overrides the default +// requirement for using the default service account for DirectPath. +// +// It should only be used internally by generated clients. +// This is an EXPERIMENTAL API and may be changed or removed in the future. +func AllowNonDefaultServiceAccount(nd bool) option.ClientOption { + return allowNonDefaultServiceAccount(nd) +} + +type allowNonDefaultServiceAccount bool + +func (a allowNonDefaultServiceAccount) Apply(o *internal.DialSettings) { + o.AllowNonDefaultServiceAccount = bool(a) +} + // WithDefaultAudience returns a ClientOption that specifies a default audience // to be used as the audience field ("aud") for the JWT token authentication. // diff --git a/transport/grpc/dial.go b/transport/grpc/dial.go index 727a5beff1b..8cce06f4611 100644 --- a/transport/grpc/dial.go +++ b/transport/grpc/dial.go @@ -138,7 +138,7 @@ func dial(ctx context.Context, insecure bool, o *internal.DialSettings) (*grpc.C // * The endpoint is a host:port (or dns:///host:port). // * Credentials are obtained via GCE metadata server, using the default // service account. - if o.EnableDirectPath && checkDirectPathEndPoint(endpoint) && isTokenSourceDirectPathCompatible(creds.TokenSource) && metadata.OnGCE() { + if o.EnableDirectPath && checkDirectPathEndPoint(endpoint) && isTokenSourceDirectPathCompatible(creds.TokenSource, o) && metadata.OnGCE() { if !strings.HasPrefix(endpoint, "dns:///") { endpoint = "dns:///" + endpoint } @@ -228,7 +228,7 @@ func (ts grpcTokenSource) GetRequestMetadata(ctx context.Context, uri ...string) return metadata, nil } -func isTokenSourceDirectPathCompatible(ts oauth2.TokenSource) bool { +func isTokenSourceDirectPathCompatible(ts oauth2.TokenSource, o *internal.DialSettings) bool { if ts == nil { return false } @@ -239,6 +239,9 @@ func isTokenSourceDirectPathCompatible(ts oauth2.TokenSource) bool { if tok == nil { return false } + if o.AllowNonDefaultServiceAccount { + return true + } if source, _ := tok.Extra("oauth2.google.tokenSource").(string); source != "compute-metadata" { return false }