diff --git a/internal/settings.go b/internal/settings.go index 26259b82abb..7c0f9292d50 100644 --- a/internal/settings.go +++ b/internal/settings.go @@ -41,6 +41,7 @@ type DialSettings struct { CustomClaims map[string]interface{} SkipValidation bool ImpersonationConfig *impersonate.Config + EnableDirectPath bool // Google API system parameters. For more information please read: // https://cloud.google.com/apis/docs/system-parameters diff --git a/option/internaloption/internaloption.go b/option/internaloption/internaloption.go index b4d78a830ae..6ee4501c048 100644 --- a/option/internaloption/internaloption.go +++ b/option/internaloption/internaloption.go @@ -50,3 +50,18 @@ type skipDialSettingsValidation struct{} func (s skipDialSettingsValidation) Apply(settings *internal.DialSettings) { settings.SkipValidation = true } + +// EnableDirectPath returns a ClientOption that overrides the default +// attempt to use DirectPath. +// +// It should only be used internally by generated clients. +// This is an EXPERIMENTAL API and may be changed or removed in the future. +func EnableDirectPath(dp bool) option.ClientOption { + return enableDirectPath(dp) +} + +type enableDirectPath bool + +func (e enableDirectPath) Apply(o *internal.DialSettings) { + o.EnableDirectPath = bool(e) +} diff --git a/transport/grpc/dial.go b/transport/grpc/dial.go index 55c04a5af06..f8a6ca29981 100644 --- a/transport/grpc/dial.go +++ b/transport/grpc/dial.go @@ -12,7 +12,6 @@ import ( "crypto/tls" "errors" "log" - "os" "strings" "go.opencensus.io/plugin/ocgrpc" @@ -138,9 +137,7 @@ func dial(ctx context.Context, insecure bool, o *internal.DialSettings) (*grpc.C // * The endpoint is a host:port (or dns:///host:port). // * Credentials are obtained via GCE metadata server, using the default // service account. - // * Opted in via GOOGLE_CLOUD_ENABLE_DIRECT_PATH environment variable. - // For example, GOOGLE_CLOUD_ENABLE_DIRECT_PATH=spanner,pubsub - if isDirectPathEnabled(endpoint) && isTokenSourceDirectPathCompatible(creds.TokenSource) { + if o.EnableDirectPath && checkDirectPathEndPoint(endpoint) && isTokenSourceDirectPathCompatible(creds.TokenSource) { if !strings.HasPrefix(endpoint, "dns:///") { endpoint = "dns:///" + endpoint } @@ -189,7 +186,7 @@ func dial(ctx context.Context, insecure bool, o *internal.DialSettings) (*grpc.C // point when isDirectPathEnabled will default to true, we guard it by // the Directpath env var for now once we can introspect user defined // dialer (https://github.com/grpc/grpc-go/issues/2795). - if timeoutDialerOption != nil && isDirectPathEnabled(endpoint) { + if timeoutDialerOption != nil && o.EnableDirectPath && checkDirectPathEndPoint(endpoint) { grpcOpts = append(grpcOpts, timeoutDialerOption) } @@ -250,8 +247,8 @@ func isTokenSourceDirectPathCompatible(ts oauth2.TokenSource) bool { return true } -func isDirectPathEnabled(endpoint string) bool { - // Only host:port is supported, not other schemes (e.g., "tcp://" or "unix://"). +func checkDirectPathEndPoint(endpoint string) bool { + // Only [dns:///]host[:port] is supported, not other schemes (e.g., "tcp://" or "unix://"). // Also don't try direct path if the user has chosen an alternate name resolver // (i.e., via ":///" prefix). // @@ -261,15 +258,11 @@ func isDirectPathEnabled(endpoint string) bool { return false } - // Only try direct path if the user has opted in via the environment variable. - directPathAPIs := strings.Split(os.Getenv("GOOGLE_CLOUD_ENABLE_DIRECT_PATH"), ",") - for _, api := range directPathAPIs { - // Ignore empty string since an empty env variable splits into [""] - if api != "" && strings.Contains(endpoint, api) { - return true - } + if endpoint == "" { + return false } - return false + + return true } func processAndValidateOpts(opts []option.ClientOption) (*internal.DialSettings, error) { diff --git a/transport/grpc/dial_socketopt_test.go b/transport/grpc/dial_socketopt_test.go index e36965d5df7..d6cc7f726f5 100644 --- a/transport/grpc/dial_socketopt_test.go +++ b/transport/grpc/dial_socketopt_test.go @@ -11,7 +11,6 @@ import ( "errors" "fmt" "net" - "os" "syscall" "testing" "time" @@ -19,6 +18,7 @@ import ( "golang.org/x/oauth2" "golang.org/x/sys/unix" "google.golang.org/api/option" + "google.golang.org/api/option/internaloption" "google.golang.org/grpc" ) @@ -90,9 +90,6 @@ func getTCPUserTimeout(conn net.Conn) (int, error) { // Check that tcp timeout dialer overwrites user defined dialer. func TestDialWithDirectPathEnabled(t *testing.T) { - os.Setenv("GOOGLE_CLOUD_ENABLE_DIRECT_PATH", "example,other") - defer os.Clearenv() - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) userDialer := grpc.WithDialer(func(addr string, timeout time.Duration) (net.Conn, error) { @@ -104,7 +101,8 @@ func TestDialWithDirectPathEnabled(t *testing.T) { conn, err := Dial(ctx, option.WithTokenSource(oauth2.StaticTokenSource(nil)), // No creds. option.WithGRPCDialOption(userDialer), - option.WithEndpoint("example.google.com:443")) + option.WithEndpoint("example.google.com:443"), + internaloption.EnableDirectPath(true)) if err != nil { t.Errorf("DialGRPC: error %v, want nil", err) } diff --git a/transport/grpc/dial_test.go b/transport/grpc/dial_test.go index c850ce615fe..6ec20369306 100644 --- a/transport/grpc/dial_test.go +++ b/transport/grpc/dial_test.go @@ -8,7 +8,6 @@ import ( "context" "errors" "net" - "os" "testing" "time" @@ -55,62 +54,40 @@ func TestGRPCHook(t *testing.T) { } } -func TestIsDirectPathEnabled(t *testing.T) { +func TestCheckDirectPathEndPoint(t *testing.T) { for _, testcase := range []struct { name string endpoint string - envVar string want bool }{ { - name: "matches", - endpoint: "some-api", - envVar: "some-api", - want: true, - }, - { - name: "does not match", - endpoint: "some-api", - envVar: "some-other-api", + name: "empty endpoint are disallowed", + endpoint: "", want: false, }, { - name: "matches in list", - endpoint: "some-api-2", - envVar: "some-api-1,some-api-2,some-api-3", + name: "dns schemes are allowed", + endpoint: "dns:///foo", want: true, }, { - name: "empty env var", - endpoint: "", - envVar: "", - want: false, - }, - { - name: "trailing comma", - endpoint: "", - envVar: "foo,bar,", - want: false, + name: "host without no prefix are allowed", + endpoint: "foo", + want: true, }, { - name: "dns schemes are allowed", - endpoint: "dns:///foo", - envVar: "dns:///foo", + name: "host with port are allowed", + endpoint: "foo:1234", want: true, }, { name: "non-dns schemes are disallowed", endpoint: "https://foo", - envVar: "https://foo", want: false, }, } { t.Run(testcase.name, func(t *testing.T) { - if err := os.Setenv("GOOGLE_CLOUD_ENABLE_DIRECT_PATH", testcase.envVar); err != nil { - t.Fatal(err) - } - - if got := isDirectPathEnabled(testcase.endpoint); got != testcase.want { + if got := checkDirectPathEndPoint(testcase.endpoint); got != testcase.want { t.Fatalf("got %v, want %v", got, testcase.want) } })