From 470f92d29d7a32f7903a3cb6d49fb09363185866 Mon Sep 17 00:00:00 2001 From: Shan Huang <74681657+shawnhuang-gg@users.noreply.github.com> Date: Sat, 20 Feb 2021 20:25:19 -0800 Subject: [PATCH] feat: Added '-enable_iam_login' flag for IAM db authentication (#583) --- README.md | 1 + cmd/cloud_sql_proxy/cloud_sql_proxy.go | 28 ++++++++----- proxy/certs/certs.go | 39 +++++++++++++++--- proxy/proxy/client.go | 11 ++++- proxy/proxy/client_test.go | 57 +++++++++++++++++++++++++- tests/postgres_test.go | 34 +++++++++++++++ 6 files changed, 151 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index 0ca41ef21..b87fd7ca4 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,7 @@ cloud_sql_proxy takes a few arguments to configure what instances to connect to * `-skip_failed_instance_config`: Setting this flag will allow you to prevent the proxy from terminating when some instance configurations could not be parsed and/or are unavailable. * `-log_debug_stdout=true`: This is to log non-error output to stdOut instead of stdErr. For example, if you don't want connection related messages to log as errors, set this flag to true. Defaults to false. +* `-enable_iam_login`: This enables the proxy to use Cloud SQL IAM database authentication. This will cause the proxy to use IAM account credentials for database user authentication. Note: `-instances` and `-instances_metadata` may be used at the same time but are not compatible with the `-fuse` flag. diff --git a/cmd/cloud_sql_proxy/cloud_sql_proxy.go b/cmd/cloud_sql_proxy/cloud_sql_proxy.go index 0ff478ff1..36b55f53f 100644 --- a/cmd/cloud_sql_proxy/cloud_sql_proxy.go +++ b/cmd/cloud_sql_proxy/cloud_sql_proxy.go @@ -86,6 +86,8 @@ can be removed automatically by this program.`) tokenFile = flag.String("credential_file", "", `If provided, this json file will be used to retrieve Service Account credentials. You may set the GOOGLE_APPLICATION_CREDENTIALS environment variable for the same effect.`) ipAddressTypes = flag.String("ip_address_types", "PUBLIC,PRIVATE", "Default to be 'PUBLIC,PRIVATE'. Options: a list of strings separated by ',', e.g. 'PUBLIC,PRIVATE' ") + // Settings for IAM db proxy authentication + enableIAMLogin = flag.Bool("enable_iam_login", false, "Enables database user authentication using Cloud SQL's IAM DB Authentication.") skipInvalidInstanceConfigs = flag.Bool("skip_failed_instance_config", false, `Setting this flag will allow you to prevent the proxy from terminating when some instance configurations could not be parsed and/or are unavailable.`) @@ -123,6 +125,10 @@ Authorization: This will override gcloud or GCE (Google Compute Engine) credentials, if they exist. + * To configure the proxy using IAM authentication, pass the -enable_iam_login + flag. This will cause the proxy to use IAM account credentials for + database user authentication. + General: -quiet Disable log messages (e.g. when new connections are established). @@ -271,31 +277,31 @@ func checkFlags(onGCE bool) error { return nil } -func authenticatedClientFromPath(ctx context.Context, f string) (*http.Client, error) { +func authenticatedClientFromPath(ctx context.Context, f string) (*http.Client, oauth2.TokenSource, error) { all, err := ioutil.ReadFile(f) if err != nil { - return nil, fmt.Errorf("invalid json file %q: %v", f, err) + return nil, nil, fmt.Errorf("invalid json file %q: %v", f, err) } // First try and load this as a service account config, which allows us to see the service account email: if cfg, err := goauth.JWTConfigFromJSON(all, proxy.SQLScope); err == nil { logging.Infof("using credential file for authentication; email=%s", cfg.Email) - return cfg.Client(ctx), nil + return cfg.Client(ctx), cfg.TokenSource(ctx), nil } cred, err := goauth.CredentialsFromJSON(ctx, all, proxy.SQLScope) if err != nil { - return nil, fmt.Errorf("invalid json file %q: %v", f, err) + return nil, nil, fmt.Errorf("invalid json file %q: %v", f, err) } logging.Infof("using credential file for authentication; path=%q", f) - return oauth2.NewClient(ctx, cred.TokenSource), nil + return oauth2.NewClient(ctx, cred.TokenSource), cred.TokenSource, nil } -func authenticatedClient(ctx context.Context) (*http.Client, error) { +func authenticatedClient(ctx context.Context) (*http.Client, oauth2.TokenSource, error) { if *tokenFile != "" { return authenticatedClientFromPath(ctx, *tokenFile) } else if tok := *token; tok != "" { src := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: tok}) - return oauth2.NewClient(ctx, src), nil + return oauth2.NewClient(ctx, src), src, nil } else if f := os.Getenv("GOOGLE_APPLICATION_CREDENTIALS"); f != "" { return authenticatedClientFromPath(ctx, f) } @@ -307,10 +313,10 @@ func authenticatedClient(ctx context.Context) (*http.Client, error) { src, err = goauth.DefaultTokenSource(ctx, proxy.SQLScope) } if err != nil { - return nil, err + return nil, nil, err } - return oauth2.NewClient(ctx, src), nil + return oauth2.NewClient(ctx, src), src, nil } func stringList(s string) []string { @@ -462,7 +468,7 @@ func main() { } ctx := context.Background() - client, err := authenticatedClient(ctx) + client, tokSrc, err := authenticatedClient(ctx) if err != nil { log.Fatal(err) } @@ -497,6 +503,8 @@ func main() { IgnoreRegion: !*checkRegion, UserAgent: userAgentFromVersionString(), IPAddrTypeOpts: ipAddrTypeOptsInput, + EnableIAMLogin: *enableIAMLogin, + TokenSource: tokSrc, }), Conns: connset, RefreshCfgThrottle: refreshCfgThrottle, diff --git a/proxy/certs/certs.go b/proxy/certs/certs.go index 334162001..b01ab28e3 100644 --- a/proxy/certs/certs.go +++ b/proxy/certs/certs.go @@ -31,6 +31,7 @@ import ( "github.com/GoogleCloudPlatform/cloudsql-proxy/logging" "github.com/GoogleCloudPlatform/cloudsql-proxy/proxy/util" + "golang.org/x/oauth2" "google.golang.org/api/googleapi" sqladmin "google.golang.org/api/sqladmin/v1beta4" ) @@ -68,6 +69,12 @@ type RemoteOpts struct { // IP address type options IPAddrTypeOpts []string + + // Enable IAM proxy db authentication + EnableIAMLogin bool + + // Token source for token information used in cert creation + TokenSource oauth2.TokenSource } // NewCertSourceOpts returns a CertSource configured with the provided Opts. @@ -105,7 +112,7 @@ func NewCertSourceOpts(c *http.Client, opts RemoteOpts) *RemoteCertSource { } } - return &RemoteCertSource{pkey, serv, !opts.IgnoreRegion, opts.IPAddrTypeOpts} + return &RemoteCertSource{pkey, serv, !opts.IgnoreRegion, opts.IPAddrTypeOpts, opts.EnableIAMLogin, opts.TokenSource} } // RemoteCertSource implements a CertSource, using Cloud SQL APIs to @@ -123,6 +130,10 @@ type RemoteCertSource struct { checkRegion bool // a list of ip address types that users select IPAddrTypes []string + // flag to enable IAM proxy db authentication + EnableIAMLogin bool + // token source for the token information used in cert creation + TokenSource oauth2.TokenSource } // Constants for backoffAPIRetry. These cause the retry logic to scale the @@ -171,11 +182,18 @@ func (s *RemoteCertSource) Local(instance string) (ret tls.Certificate, err erro p, r, n := util.SplitName(instance) regionName := fmt.Sprintf("%s~%s", r, n) - req := s.serv.SslCerts.CreateEphemeral(p, regionName, - &sqladmin.SslCertsCreateEphemeralRequest{ - PublicKey: string(pem.EncodeToMemory(&pem.Block{Bytes: pkix, Type: "RSA PUBLIC KEY"})), - }, - ) + pubKey := string(pem.EncodeToMemory(&pem.Block{Bytes: pkix, Type: "RSA PUBLIC KEY"})) + createEphemeralRequest := sqladmin.SslCertsCreateEphemeralRequest{ + PublicKey: pubKey, + } + if s.EnableIAMLogin { + tok, e := s.TokenSource.Token() + if e != nil { + return ret, e + } + createEphemeralRequest.AccessToken = tok.AccessToken + } + req := s.serv.SslCerts.CreateEphemeral(p, regionName, &createEphemeralRequest) var data *sqladmin.SslCert err = backoffAPIRetry("createEphemeral for", instance, func() error { @@ -277,3 +295,12 @@ func (s *RemoteCertSource) Remote(instance string) (cert *x509.Certificate, addr return c, ipAddrInUse, p + ":" + n, data.DatabaseVersion, err } + +// TokenExpiration returns the expiration time for token source associated with remote cert source. +func (s *RemoteCertSource) TokenExpiration() (ret time.Time, err error) { + tok, err := s.TokenSource.Token() + if err != nil { + return ret, err + } + return tok.Expiry, nil +} diff --git a/proxy/proxy/client.go b/proxy/proxy/client.go index cb551d220..15e2d3d9a 100644 --- a/proxy/proxy/client.go +++ b/proxy/proxy/client.go @@ -55,6 +55,8 @@ type CertSource interface { Local(instance string) (tls.Certificate, error) // Remote returns the instance's CA certificate, address, and name. Remote(instance string) (cert *x509.Certificate, addr, name, version string, err error) + // TokenExpiration returns expiration time of the token information. + TokenExpiration() (time.Time, error) } // Client is a type to handle connecting to a Server. All fields are required @@ -232,11 +234,18 @@ func (c *Client) refreshCfg(instance string) (addr string, cfg *tls.Config, vers VerifyPeerCertificate: genVerifyPeerCertificateFunc(name, certs), } + tokenExpiry, err := c.Certs.TokenExpiration() + if err != nil { + return "", nil, "", err + } expire := mycert.Leaf.NotAfter + if expire.After(tokenExpiry) { + expire = tokenExpiry + } now := time.Now() timeToRefresh := expire.Sub(now) - refreshCfgBuffer if timeToRefresh <= 0 { - err = fmt.Errorf("new ephemeral certificate expires too soon: current time: %v, certificate expires: %v", expire, now) + err = fmt.Errorf("new ephemeral certificate expires too soon: current time: %v, certificate expires: %v", now, expire) logging.Errorf("ephemeral certificate (%+v) error: %v", mycert, err) return "", nil, "", err } diff --git a/proxy/proxy/client_test.go b/proxy/proxy/client_test.go index f663bc6cb..f143197dc 100644 --- a/proxy/proxy/client_test.go +++ b/proxy/proxy/client_test.go @@ -41,8 +41,9 @@ type fakeCerts struct { } type blockingCertSource struct { - values map[string]*fakeCerts - validUntil time.Time + values map[string]*fakeCerts + validUntil time.Time + tokenExpire time.Time } func (cs *blockingCertSource) Local(instance string) (tls.Certificate, error) { @@ -66,6 +67,10 @@ func (cs *blockingCertSource) Remote(instance string) (cert *x509.Certificate, a return &x509.Certificate{}, "fake address", "fake name", "fake version", nil } +func (cs *blockingCertSource) TokenExpiration() (ret time.Time, err error) { + return cs.tokenExpire, nil +} + func TestContextDialer(t *testing.T) { b := &fakeCerts{} c := &Client{ @@ -74,6 +79,7 @@ func TestContextDialer(t *testing.T) { instance: b, }, forever, + forever, }, ContextDialer: func(context.Context, string, string) (net.Conn, error) { return nil, errFakeDial @@ -96,6 +102,7 @@ func TestClientCache(t *testing.T) { instance: b, }, forever, + forever, }, Dialer: func(string, string) (net.Conn, error) { return nil, errFakeDial @@ -123,6 +130,7 @@ func TestConcurrentRefresh(t *testing.T) { instance: b, }, forever, + forever, }, Dialer: func(string, string) (net.Conn, error) { return nil, errFakeDial @@ -164,6 +172,7 @@ func TestMaximumConnectionsCount(t *testing.T) { certSource := blockingCertSource{ map[string]*fakeCerts{}, forever, + forever, } firstDialExited := make(chan struct{}) c := &Client{ @@ -225,6 +234,7 @@ func TestShutdownTerminatesEarly(t *testing.T) { instance: b, }, forever, + forever, }, Dialer: func(string, string) (net.Conn, error) { return nil, nil @@ -256,6 +266,49 @@ func TestRefreshTimer(t *testing.T) { instance: b, }, certCreated.Add(timeToExpire), + forever, + }, + Dialer: func(string, string) (net.Conn, error) { + return nil, errFakeDial + }, + RefreshCfgThrottle: 20 * time.Millisecond, + RefreshCfgBuffer: time.Second, + } + // Call Dial to cache the cert. + if _, err := c.Dial(instance); err != errFakeDial { + t.Fatalf("Dial(%s) failed: %v", instance, err) + } + c.cacheL.Lock() + cfg, ok := c.cfgCache[instance] + c.cacheL.Unlock() + if !ok { + t.Fatalf("expected instance to be cached") + } + + time.Sleep(timeToExpire - time.Since(certCreated)) + // Check if cert was refreshed in the background, without calling Dial again. + c.cacheL.Lock() + newCfg, ok := c.cfgCache[instance] + c.cacheL.Unlock() + if !ok { + t.Fatalf("expected instance to be cached") + } + if !newCfg.lastRefreshed.After(cfg.lastRefreshed) { + t.Error("expected cert to be refreshed.") + } +} + +func TestRefreshTimerTokenExpires(t *testing.T) { + timeToExpire := 5 * time.Second + b := &fakeCerts{} + certCreated := time.Now() + c := &Client{ + Certs: &blockingCertSource{ + map[string]*fakeCerts{ + instance: b, + }, + forever, + certCreated.Add(timeToExpire), }, Dialer: func(string, string) (net.Conn, error) { return nil, errFakeDial diff --git a/tests/postgres_test.go b/tests/postgres_test.go index 27c3c4c8d..d2c635a71 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -18,6 +18,8 @@ package tests import ( + "context" + "database/sql" "flag" "fmt" "io/ioutil" @@ -35,6 +37,8 @@ var ( postgresPass = flag.String("postgres_pass", os.Getenv("POSTGRES_PASS"), "Password for the database user; be careful when entering a password on the command line (it may go into your terminal's history).") postgresDb = flag.String("postgres_db", os.Getenv("POSTGRES_DB"), "Name of the database to connect to.") + postgresIAMUser = flag.String("postgres_user_iam", os.Getenv("POSTGRES_USER_IAM"), "Name of database user configured with IAM DB Authentication.") + postgresPort = 5432 ) @@ -76,3 +80,33 @@ func TestPostgresConnLimit(t *testing.T) { dsn := fmt.Sprintf("user=%s password=%s database=%s sslmode=disable", *postgresUser, *postgresPass, *postgresDb) proxyConnLimitTest(t, *postgresConnName, "postgres", dsn, postgresPort) } + +func TestPostgresIAMDBAuthn(t *testing.T) { + requirePostgresVars(t) + + ctx := context.Background() + + // Start the proxy + p, err := StartProxy(ctx, fmt.Sprintf("-instances=%s=tcp:%d", *postgresConnName, 5432), "-enable_iam_login") + if err != nil { + t.Fatalf("unable to start proxy: %v", err) + } + defer p.Close() + output, err := p.WaitForServe(ctx) + if err != nil { + t.Fatalf("unable to verify proxy was serving: %s \n %s", err, output) + } + + dsn := fmt.Sprintf("user=%s database=%s sslmode=disable", *postgresIAMUser, *postgresDb) + db, err := sql.Open("postgres", dsn) + if err != nil { + t.Fatalf("unable to connect to db: %s", err) + } + defer db.Close() + _, err = db.Exec("SELECT 1;") + if err != nil { + + t.Fatalf("unable to exec on db: %s", err) + } + +}