diff --git a/idtoken/idtoken.go b/idtoken/idtoken.go index be52f3b2905..be2ba5cd429 100644 --- a/idtoken/idtoken.go +++ b/idtoken/idtoken.go @@ -104,7 +104,7 @@ func tokenSourceFromBytes(ctx context.Context, data []byte, audience string, ds if err := isServiceAccount(data); err != nil { return nil, err } - cfg, err := google.JWTConfigFromJSON(data, ds.Scopes...) + cfg, err := google.JWTConfigFromJSON(data, ds.GetScopes()...) if err != nil { return nil, err } diff --git a/internal/creds.go b/internal/creds.go index dc6d50e96aa..c93daa98c32 100644 --- a/internal/creds.go +++ b/internal/creds.go @@ -34,24 +34,24 @@ func baseCreds(ctx context.Context, ds *DialSettings) (*google.Credentials, erro return ds.Credentials, nil } if ds.CredentialsJSON != nil { - return credentialsFromJSON(ctx, ds.CredentialsJSON, ds.Endpoint, ds.Scopes, ds.Audiences) + return credentialsFromJSON(ctx, ds.CredentialsJSON, ds) } if ds.CredentialsFile != "" { data, err := ioutil.ReadFile(ds.CredentialsFile) if err != nil { return nil, fmt.Errorf("cannot read credentials file: %v", err) } - return credentialsFromJSON(ctx, data, ds.Endpoint, ds.Scopes, ds.Audiences) + return credentialsFromJSON(ctx, data, ds) } if ds.TokenSource != nil { return &google.Credentials{TokenSource: ds.TokenSource}, nil } - cred, err := google.FindDefaultCredentials(ctx, ds.Scopes...) + cred, err := google.FindDefaultCredentials(ctx, ds.GetScopes()...) if err != nil { return nil, err } if len(cred.JSON) > 0 { - return credentialsFromJSON(ctx, cred.JSON, ds.Endpoint, ds.Scopes, ds.Audiences) + return credentialsFromJSON(ctx, cred.JSON, ds) } // For GAE and GCE, the JSON is empty so return the default credentials directly. return cred, nil @@ -66,12 +66,12 @@ const ( // // - If the JSON is a service account and no scopes provided, returns self-signed JWT auth flow // - Otherwise, returns OAuth 2.0 flow. -func credentialsFromJSON(ctx context.Context, data []byte, endpoint string, scopes []string, audiences []string) (*google.Credentials, error) { - cred, err := google.CredentialsFromJSON(ctx, data, scopes...) +func credentialsFromJSON(ctx context.Context, data []byte, ds *DialSettings) (*google.Credentials, error) { + cred, err := google.CredentialsFromJSON(ctx, data, ds.GetScopes()...) if err != nil { return nil, err } - if len(data) > 0 && len(scopes) == 0 { + if len(data) > 0 && len(ds.Scopes) == 0 && (ds.DefaultAudience != "" || len(ds.Audiences) > 0) { var f struct { Type string `json:"type"` // The rest JSON fields are omitted because they are not used. @@ -80,7 +80,7 @@ func credentialsFromJSON(ctx context.Context, data []byte, endpoint string, scop return nil, err } if f.Type == serviceAccountKey { - ts, err := selfSignedJWTTokenSource(data, endpoint, audiences) + ts, err := selfSignedJWTTokenSource(data, ds.DefaultAudience, ds.Audiences) if err != nil { return nil, err } @@ -90,9 +90,8 @@ func credentialsFromJSON(ctx context.Context, data []byte, endpoint string, scop return cred, err } -func selfSignedJWTTokenSource(data []byte, endpoint string, audiences []string) (oauth2.TokenSource, error) { - // Use the API endpoint as the default audience - audience := endpoint +func selfSignedJWTTokenSource(data []byte, defaultAudience string, audiences []string) (oauth2.TokenSource, error) { + audience := defaultAudience if len(audiences) > 0 { // TODO(shinfan): Update golang oauth to support multiple audiences. if len(audiences) > 1 { @@ -118,7 +117,7 @@ func QuotaProjectFromCreds(cred *google.Credentials) string { func impersonateCredentials(ctx context.Context, creds *google.Credentials, ds *DialSettings) (*google.Credentials, error) { if len(ds.ImpersonationConfig.Scopes) == 0 { - ds.ImpersonationConfig.Scopes = ds.Scopes + ds.ImpersonationConfig.Scopes = ds.GetScopes() } ts, err := impersonate.TokenSource(ctx, creds.TokenSource, ds.ImpersonationConfig) if err != nil { diff --git a/internal/creds_test.go b/internal/creds_test.go index 0be8e2ae59a..045c5bc20a6 100644 --- a/internal/creds_test.go +++ b/internal/creds_test.go @@ -119,7 +119,7 @@ const validServiceAccountJSON = `{ func TestQuotaProjectFromCreds(t *testing.T) { ctx := context.Background() - cred, err := credentialsFromJSON(ctx, []byte(validServiceAccountJSON), "foo.googleapis.com", nil, nil) + cred, err := credentialsFromJSON(ctx, []byte(validServiceAccountJSON), &DialSettings{Endpoint: "foo.googleapis.com"}) if err != nil { t.Fatalf("got %v, wanted no error", err) } @@ -133,7 +133,7 @@ func TestQuotaProjectFromCreds(t *testing.T) { "quota_project_id": "foobar" }`) - cred, err = credentialsFromJSON(ctx, []byte(quotaProjectJSON), "foo.googleapis.com", nil, nil) + cred, err = credentialsFromJSON(ctx, []byte(quotaProjectJSON), &DialSettings{Endpoint: "foo.googleapis.com"}) if err != nil { t.Fatalf("got %v, wanted no error", err) } diff --git a/internal/settings.go b/internal/settings.go index 7c0f9292d50..0ae1cb9778d 100644 --- a/internal/settings.go +++ b/internal/settings.go @@ -23,6 +23,7 @@ type DialSettings struct { DefaultEndpoint string DefaultMTLSEndpoint string Scopes []string + DefaultScopes []string TokenSource oauth2.TokenSource Credentials *google.Credentials CredentialsFile string // if set, Token Source is ignored. @@ -30,6 +31,7 @@ type DialSettings struct { UserAgent string APIKey string Audiences []string + DefaultAudience string HTTPClient *http.Client GRPCDialOpts []grpc.DialOption GRPCConn *grpc.ClientConn @@ -49,6 +51,15 @@ type DialSettings struct { RequestReason string } +// GetScopes returns the user-provided scopes, if set, or else falls back to the +// default scopes. +func (ds *DialSettings) GetScopes() []string { + if len(ds.Scopes) > 0 { + return ds.Scopes + } + return ds.DefaultScopes +} + // Validate reports an error if ds is invalid. func (ds *DialSettings) Validate() error { if ds.SkipValidation { diff --git a/option/internaloption/internaloption.go b/option/internaloption/internaloption.go index 6ee4501c048..1fff22fd5da 100644 --- a/option/internaloption/internaloption.go +++ b/option/internaloption/internaloption.go @@ -65,3 +65,32 @@ type enableDirectPath bool func (e enableDirectPath) Apply(o *internal.DialSettings) { o.EnableDirectPath = bool(e) } + +// WithDefaultAudience returns a ClientOption that specifies a default audience +// to be used as the audience field ("aud") for the JWT token authentication. +// +// It should only be used internally by generated clients. +func WithDefaultAudience(audience string) option.ClientOption { + return withDefaultAudience(audience) +} + +type withDefaultAudience string + +func (w withDefaultAudience) Apply(o *internal.DialSettings) { + o.DefaultAudience = string(w) +} + +// WithDefaultScopes returns a ClientOption that overrides the default OAuth2 +// scopes to be used for a service. +// +// It should only be used internally by generated clients. +func WithDefaultScopes(scope ...string) option.ClientOption { + return withDefaultScopes(scope) +} + +type withDefaultScopes []string + +func (w withDefaultScopes) Apply(o *internal.DialSettings) { + o.DefaultScopes = make([]string, len(w)) + copy(o.DefaultScopes, w) +}