Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(internaloption): add EnableDirectPath internaloption #732

Merged
merged 1 commit into from Nov 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions internal/settings.go
Expand Up @@ -41,6 +41,7 @@ type DialSettings struct {
CustomClaims map[string]interface{}
SkipValidation bool
ImpersonationConfig *impersonate.Config
EnableDirectPath bool

// Google API system parameters. For more information please read:
// https://cloud.google.com/apis/docs/system-parameters
Expand Down
15 changes: 15 additions & 0 deletions option/internaloption/internaloption.go
Expand Up @@ -50,3 +50,18 @@ type skipDialSettingsValidation struct{}
func (s skipDialSettingsValidation) Apply(settings *internal.DialSettings) {
settings.SkipValidation = true
}

// EnableDirectPath returns a ClientOption that overrides the default
// attempt to use DirectPath.
//
// It should only be used internally by generated clients.
// This is an EXPERIMENTAL API and may be changed or removed in the future.
func EnableDirectPath(dp bool) option.ClientOption {
return enableDirectPath(dp)
}

type enableDirectPath bool

func (e enableDirectPath) Apply(o *internal.DialSettings) {
o.EnableDirectPath = bool(e)
}
23 changes: 8 additions & 15 deletions transport/grpc/dial.go
Expand Up @@ -12,7 +12,6 @@ import (
"crypto/tls"
"errors"
"log"
"os"
"strings"

"go.opencensus.io/plugin/ocgrpc"
Expand Down Expand Up @@ -138,9 +137,7 @@ func dial(ctx context.Context, insecure bool, o *internal.DialSettings) (*grpc.C
// * The endpoint is a host:port (or dns:///host:port).
// * Credentials are obtained via GCE metadata server, using the default
// service account.
// * Opted in via GOOGLE_CLOUD_ENABLE_DIRECT_PATH environment variable.
// For example, GOOGLE_CLOUD_ENABLE_DIRECT_PATH=spanner,pubsub
if isDirectPathEnabled(endpoint) && isTokenSourceDirectPathCompatible(creds.TokenSource) {
if o.EnableDirectPath && checkDirectPathEndPoint(endpoint) && isTokenSourceDirectPathCompatible(creds.TokenSource) {
if !strings.HasPrefix(endpoint, "dns:///") {
endpoint = "dns:///" + endpoint
}
Expand Down Expand Up @@ -189,7 +186,7 @@ func dial(ctx context.Context, insecure bool, o *internal.DialSettings) (*grpc.C
// point when isDirectPathEnabled will default to true, we guard it by
// the Directpath env var for now once we can introspect user defined
// dialer (https://github.com/grpc/grpc-go/issues/2795).
if timeoutDialerOption != nil && isDirectPathEnabled(endpoint) {
if timeoutDialerOption != nil && o.EnableDirectPath && checkDirectPathEndPoint(endpoint) {
grpcOpts = append(grpcOpts, timeoutDialerOption)
}

Expand Down Expand Up @@ -250,8 +247,8 @@ func isTokenSourceDirectPathCompatible(ts oauth2.TokenSource) bool {
return true
}

func isDirectPathEnabled(endpoint string) bool {
// Only host:port is supported, not other schemes (e.g., "tcp://" or "unix://").
func checkDirectPathEndPoint(endpoint string) bool {
// Only [dns:///]host[:port] is supported, not other schemes (e.g., "tcp://" or "unix://").
// Also don't try direct path if the user has chosen an alternate name resolver
// (i.e., via ":///" prefix).
//
Expand All @@ -261,15 +258,11 @@ func isDirectPathEnabled(endpoint string) bool {
return false
}

// Only try direct path if the user has opted in via the environment variable.
directPathAPIs := strings.Split(os.Getenv("GOOGLE_CLOUD_ENABLE_DIRECT_PATH"), ",")
mohanli-ml marked this conversation as resolved.
Show resolved Hide resolved
for _, api := range directPathAPIs {
// Ignore empty string since an empty env variable splits into [""]
if api != "" && strings.Contains(endpoint, api) {
return true
}
if endpoint == "" {
return false
}
return false

return true
}

func processAndValidateOpts(opts []option.ClientOption) (*internal.DialSettings, error) {
Expand Down
8 changes: 3 additions & 5 deletions transport/grpc/dial_socketopt_test.go
Expand Up @@ -11,14 +11,14 @@ import (
"errors"
"fmt"
"net"
"os"
"syscall"
"testing"
"time"

"golang.org/x/oauth2"
"golang.org/x/sys/unix"
"google.golang.org/api/option"
"google.golang.org/api/option/internaloption"
"google.golang.org/grpc"
)

Expand Down Expand Up @@ -90,9 +90,6 @@ func getTCPUserTimeout(conn net.Conn) (int, error) {

// Check that tcp timeout dialer overwrites user defined dialer.
func TestDialWithDirectPathEnabled(t *testing.T) {
os.Setenv("GOOGLE_CLOUD_ENABLE_DIRECT_PATH", "example,other")
defer os.Clearenv()

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)

userDialer := grpc.WithDialer(func(addr string, timeout time.Duration) (net.Conn, error) {
Expand All @@ -104,7 +101,8 @@ func TestDialWithDirectPathEnabled(t *testing.T) {
conn, err := Dial(ctx,
option.WithTokenSource(oauth2.StaticTokenSource(nil)), // No creds.
option.WithGRPCDialOption(userDialer),
option.WithEndpoint("example.google.com:443"))
option.WithEndpoint("example.google.com:443"),
internaloption.EnableDirectPath(true))
if err != nil {
t.Errorf("DialGRPC: error %v, want nil", err)
}
Expand Down
45 changes: 11 additions & 34 deletions transport/grpc/dial_test.go
Expand Up @@ -8,7 +8,6 @@ import (
"context"
"errors"
"net"
"os"
"testing"
"time"

Expand Down Expand Up @@ -55,62 +54,40 @@ func TestGRPCHook(t *testing.T) {
}
}

func TestIsDirectPathEnabled(t *testing.T) {
func TestCheckDirectPathEndPoint(t *testing.T) {
for _, testcase := range []struct {
name string
endpoint string
envVar string
want bool
}{
{
name: "matches",
endpoint: "some-api",
envVar: "some-api",
want: true,
},
{
name: "does not match",
endpoint: "some-api",
envVar: "some-other-api",
name: "empty endpoint are disallowed",
endpoint: "",
want: false,
},
{
name: "matches in list",
endpoint: "some-api-2",
envVar: "some-api-1,some-api-2,some-api-3",
name: "dns schemes are allowed",
endpoint: "dns:///foo",
want: true,
},
{
name: "empty env var",
endpoint: "",
envVar: "",
want: false,
},
{
name: "trailing comma",
endpoint: "",
envVar: "foo,bar,",
want: false,
name: "host without no prefix are allowed",
endpoint: "foo",
want: true,
},
{
name: "dns schemes are allowed",
endpoint: "dns:///foo",
envVar: "dns:///foo",
name: "host with port are allowed",
endpoint: "foo:1234",
want: true,
},
{
name: "non-dns schemes are disallowed",
endpoint: "https://foo",
envVar: "https://foo",
want: false,
},
} {
t.Run(testcase.name, func(t *testing.T) {
if err := os.Setenv("GOOGLE_CLOUD_ENABLE_DIRECT_PATH", testcase.envVar); err != nil {
t.Fatal(err)
}

if got := isDirectPathEnabled(testcase.endpoint); got != testcase.want {
if got := checkDirectPathEndPoint(testcase.endpoint); got != testcase.want {
t.Fatalf("got %v, want %v", got, testcase.want)
}
})
Expand Down