Skip to content

Commit

Permalink
feat: Added '-enable_iam_login' flag for IAM db authentication (#583)
Browse files Browse the repository at this point in the history
  • Loading branch information
shawnhuang-gg committed Feb 21, 2021
1 parent a363ba0 commit 470f92d
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 19 deletions.
1 change: 1 addition & 0 deletions README.md
Expand Up @@ -37,6 +37,7 @@ cloud_sql_proxy takes a few arguments to configure what instances to connect to
* `-skip_failed_instance_config`: Setting this flag will allow you to prevent the proxy from terminating when
some instance configurations could not be parsed and/or are unavailable.
* `-log_debug_stdout=true`: This is to log non-error output to stdOut instead of stdErr. For example, if you don't want connection related messages to log as errors, set this flag to true. Defaults to false.
* `-enable_iam_login`: This enables the proxy to use Cloud SQL IAM database authentication. This will cause the proxy to use IAM account credentials for database user authentication.

Note: `-instances` and `-instances_metadata` may be used at the same time but
are not compatible with the `-fuse` flag.
Expand Down
28 changes: 18 additions & 10 deletions cmd/cloud_sql_proxy/cloud_sql_proxy.go
Expand Up @@ -86,6 +86,8 @@ can be removed automatically by this program.`)
tokenFile = flag.String("credential_file", "", `If provided, this json file will be used to retrieve Service Account credentials.
You may set the GOOGLE_APPLICATION_CREDENTIALS environment variable for the same effect.`)
ipAddressTypes = flag.String("ip_address_types", "PUBLIC,PRIVATE", "Default to be 'PUBLIC,PRIVATE'. Options: a list of strings separated by ',', e.g. 'PUBLIC,PRIVATE' ")
// Settings for IAM db proxy authentication
enableIAMLogin = flag.Bool("enable_iam_login", false, "Enables database user authentication using Cloud SQL's IAM DB Authentication.")

skipInvalidInstanceConfigs = flag.Bool("skip_failed_instance_config", false, `Setting this flag will allow you to prevent the proxy from terminating when
some instance configurations could not be parsed and/or are unavailable.`)
Expand Down Expand Up @@ -123,6 +125,10 @@ Authorization:
This will override gcloud or GCE (Google Compute Engine) credentials,
if they exist.
* To configure the proxy using IAM authentication, pass the -enable_iam_login
flag. This will cause the proxy to use IAM account credentials for
database user authentication.
General:
-quiet
Disable log messages (e.g. when new connections are established).
Expand Down Expand Up @@ -271,31 +277,31 @@ func checkFlags(onGCE bool) error {
return nil
}

func authenticatedClientFromPath(ctx context.Context, f string) (*http.Client, error) {
func authenticatedClientFromPath(ctx context.Context, f string) (*http.Client, oauth2.TokenSource, error) {
all, err := ioutil.ReadFile(f)
if err != nil {
return nil, fmt.Errorf("invalid json file %q: %v", f, err)
return nil, nil, fmt.Errorf("invalid json file %q: %v", f, err)
}
// First try and load this as a service account config, which allows us to see the service account email:
if cfg, err := goauth.JWTConfigFromJSON(all, proxy.SQLScope); err == nil {
logging.Infof("using credential file for authentication; email=%s", cfg.Email)
return cfg.Client(ctx), nil
return cfg.Client(ctx), cfg.TokenSource(ctx), nil
}

cred, err := goauth.CredentialsFromJSON(ctx, all, proxy.SQLScope)
if err != nil {
return nil, fmt.Errorf("invalid json file %q: %v", f, err)
return nil, nil, fmt.Errorf("invalid json file %q: %v", f, err)
}
logging.Infof("using credential file for authentication; path=%q", f)
return oauth2.NewClient(ctx, cred.TokenSource), nil
return oauth2.NewClient(ctx, cred.TokenSource), cred.TokenSource, nil
}

func authenticatedClient(ctx context.Context) (*http.Client, error) {
func authenticatedClient(ctx context.Context) (*http.Client, oauth2.TokenSource, error) {
if *tokenFile != "" {
return authenticatedClientFromPath(ctx, *tokenFile)
} else if tok := *token; tok != "" {
src := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: tok})
return oauth2.NewClient(ctx, src), nil
return oauth2.NewClient(ctx, src), src, nil
} else if f := os.Getenv("GOOGLE_APPLICATION_CREDENTIALS"); f != "" {
return authenticatedClientFromPath(ctx, f)
}
Expand All @@ -307,10 +313,10 @@ func authenticatedClient(ctx context.Context) (*http.Client, error) {
src, err = goauth.DefaultTokenSource(ctx, proxy.SQLScope)
}
if err != nil {
return nil, err
return nil, nil, err
}

return oauth2.NewClient(ctx, src), nil
return oauth2.NewClient(ctx, src), src, nil
}

func stringList(s string) []string {
Expand Down Expand Up @@ -462,7 +468,7 @@ func main() {
}

ctx := context.Background()
client, err := authenticatedClient(ctx)
client, tokSrc, err := authenticatedClient(ctx)
if err != nil {
log.Fatal(err)
}
Expand Down Expand Up @@ -497,6 +503,8 @@ func main() {
IgnoreRegion: !*checkRegion,
UserAgent: userAgentFromVersionString(),
IPAddrTypeOpts: ipAddrTypeOptsInput,
EnableIAMLogin: *enableIAMLogin,
TokenSource: tokSrc,
}),
Conns: connset,
RefreshCfgThrottle: refreshCfgThrottle,
Expand Down
39 changes: 33 additions & 6 deletions proxy/certs/certs.go
Expand Up @@ -31,6 +31,7 @@ import (

"github.com/GoogleCloudPlatform/cloudsql-proxy/logging"
"github.com/GoogleCloudPlatform/cloudsql-proxy/proxy/util"
"golang.org/x/oauth2"
"google.golang.org/api/googleapi"
sqladmin "google.golang.org/api/sqladmin/v1beta4"
)
Expand Down Expand Up @@ -68,6 +69,12 @@ type RemoteOpts struct {

// IP address type options
IPAddrTypeOpts []string

// Enable IAM proxy db authentication
EnableIAMLogin bool

// Token source for token information used in cert creation
TokenSource oauth2.TokenSource
}

// NewCertSourceOpts returns a CertSource configured with the provided Opts.
Expand Down Expand Up @@ -105,7 +112,7 @@ func NewCertSourceOpts(c *http.Client, opts RemoteOpts) *RemoteCertSource {
}
}

return &RemoteCertSource{pkey, serv, !opts.IgnoreRegion, opts.IPAddrTypeOpts}
return &RemoteCertSource{pkey, serv, !opts.IgnoreRegion, opts.IPAddrTypeOpts, opts.EnableIAMLogin, opts.TokenSource}
}

// RemoteCertSource implements a CertSource, using Cloud SQL APIs to
Expand All @@ -123,6 +130,10 @@ type RemoteCertSource struct {
checkRegion bool
// a list of ip address types that users select
IPAddrTypes []string
// flag to enable IAM proxy db authentication
EnableIAMLogin bool
// token source for the token information used in cert creation
TokenSource oauth2.TokenSource
}

// Constants for backoffAPIRetry. These cause the retry logic to scale the
Expand Down Expand Up @@ -171,11 +182,18 @@ func (s *RemoteCertSource) Local(instance string) (ret tls.Certificate, err erro

p, r, n := util.SplitName(instance)
regionName := fmt.Sprintf("%s~%s", r, n)
req := s.serv.SslCerts.CreateEphemeral(p, regionName,
&sqladmin.SslCertsCreateEphemeralRequest{
PublicKey: string(pem.EncodeToMemory(&pem.Block{Bytes: pkix, Type: "RSA PUBLIC KEY"})),
},
)
pubKey := string(pem.EncodeToMemory(&pem.Block{Bytes: pkix, Type: "RSA PUBLIC KEY"}))
createEphemeralRequest := sqladmin.SslCertsCreateEphemeralRequest{
PublicKey: pubKey,
}
if s.EnableIAMLogin {
tok, e := s.TokenSource.Token()
if e != nil {
return ret, e
}
createEphemeralRequest.AccessToken = tok.AccessToken
}
req := s.serv.SslCerts.CreateEphemeral(p, regionName, &createEphemeralRequest)

var data *sqladmin.SslCert
err = backoffAPIRetry("createEphemeral for", instance, func() error {
Expand Down Expand Up @@ -277,3 +295,12 @@ func (s *RemoteCertSource) Remote(instance string) (cert *x509.Certificate, addr

return c, ipAddrInUse, p + ":" + n, data.DatabaseVersion, err
}

// TokenExpiration returns the expiration time for token source associated with remote cert source.
func (s *RemoteCertSource) TokenExpiration() (ret time.Time, err error) {
tok, err := s.TokenSource.Token()
if err != nil {
return ret, err
}
return tok.Expiry, nil
}
11 changes: 10 additions & 1 deletion proxy/proxy/client.go
Expand Up @@ -55,6 +55,8 @@ type CertSource interface {
Local(instance string) (tls.Certificate, error)
// Remote returns the instance's CA certificate, address, and name.
Remote(instance string) (cert *x509.Certificate, addr, name, version string, err error)
// TokenExpiration returns expiration time of the token information.
TokenExpiration() (time.Time, error)
}

// Client is a type to handle connecting to a Server. All fields are required
Expand Down Expand Up @@ -232,11 +234,18 @@ func (c *Client) refreshCfg(instance string) (addr string, cfg *tls.Config, vers
VerifyPeerCertificate: genVerifyPeerCertificateFunc(name, certs),
}

tokenExpiry, err := c.Certs.TokenExpiration()
if err != nil {
return "", nil, "", err
}
expire := mycert.Leaf.NotAfter
if expire.After(tokenExpiry) {
expire = tokenExpiry
}
now := time.Now()
timeToRefresh := expire.Sub(now) - refreshCfgBuffer
if timeToRefresh <= 0 {
err = fmt.Errorf("new ephemeral certificate expires too soon: current time: %v, certificate expires: %v", expire, now)
err = fmt.Errorf("new ephemeral certificate expires too soon: current time: %v, certificate expires: %v", now, expire)
logging.Errorf("ephemeral certificate (%+v) error: %v", mycert, err)
return "", nil, "", err
}
Expand Down
57 changes: 55 additions & 2 deletions proxy/proxy/client_test.go
Expand Up @@ -41,8 +41,9 @@ type fakeCerts struct {
}

type blockingCertSource struct {
values map[string]*fakeCerts
validUntil time.Time
values map[string]*fakeCerts
validUntil time.Time
tokenExpire time.Time
}

func (cs *blockingCertSource) Local(instance string) (tls.Certificate, error) {
Expand All @@ -66,6 +67,10 @@ func (cs *blockingCertSource) Remote(instance string) (cert *x509.Certificate, a
return &x509.Certificate{}, "fake address", "fake name", "fake version", nil
}

func (cs *blockingCertSource) TokenExpiration() (ret time.Time, err error) {
return cs.tokenExpire, nil
}

func TestContextDialer(t *testing.T) {
b := &fakeCerts{}
c := &Client{
Expand All @@ -74,6 +79,7 @@ func TestContextDialer(t *testing.T) {
instance: b,
},
forever,
forever,
},
ContextDialer: func(context.Context, string, string) (net.Conn, error) {
return nil, errFakeDial
Expand All @@ -96,6 +102,7 @@ func TestClientCache(t *testing.T) {
instance: b,
},
forever,
forever,
},
Dialer: func(string, string) (net.Conn, error) {
return nil, errFakeDial
Expand Down Expand Up @@ -123,6 +130,7 @@ func TestConcurrentRefresh(t *testing.T) {
instance: b,
},
forever,
forever,
},
Dialer: func(string, string) (net.Conn, error) {
return nil, errFakeDial
Expand Down Expand Up @@ -164,6 +172,7 @@ func TestMaximumConnectionsCount(t *testing.T) {
certSource := blockingCertSource{
map[string]*fakeCerts{},
forever,
forever,
}
firstDialExited := make(chan struct{})
c := &Client{
Expand Down Expand Up @@ -225,6 +234,7 @@ func TestShutdownTerminatesEarly(t *testing.T) {
instance: b,
},
forever,
forever,
},
Dialer: func(string, string) (net.Conn, error) {
return nil, nil
Expand Down Expand Up @@ -256,6 +266,49 @@ func TestRefreshTimer(t *testing.T) {
instance: b,
},
certCreated.Add(timeToExpire),
forever,
},
Dialer: func(string, string) (net.Conn, error) {
return nil, errFakeDial
},
RefreshCfgThrottle: 20 * time.Millisecond,
RefreshCfgBuffer: time.Second,
}
// Call Dial to cache the cert.
if _, err := c.Dial(instance); err != errFakeDial {
t.Fatalf("Dial(%s) failed: %v", instance, err)
}
c.cacheL.Lock()
cfg, ok := c.cfgCache[instance]
c.cacheL.Unlock()
if !ok {
t.Fatalf("expected instance to be cached")
}

time.Sleep(timeToExpire - time.Since(certCreated))
// Check if cert was refreshed in the background, without calling Dial again.
c.cacheL.Lock()
newCfg, ok := c.cfgCache[instance]
c.cacheL.Unlock()
if !ok {
t.Fatalf("expected instance to be cached")
}
if !newCfg.lastRefreshed.After(cfg.lastRefreshed) {
t.Error("expected cert to be refreshed.")
}
}

func TestRefreshTimerTokenExpires(t *testing.T) {
timeToExpire := 5 * time.Second
b := &fakeCerts{}
certCreated := time.Now()
c := &Client{
Certs: &blockingCertSource{
map[string]*fakeCerts{
instance: b,
},
forever,
certCreated.Add(timeToExpire),
},
Dialer: func(string, string) (net.Conn, error) {
return nil, errFakeDial
Expand Down
34 changes: 34 additions & 0 deletions tests/postgres_test.go
Expand Up @@ -18,6 +18,8 @@
package tests

import (
"context"
"database/sql"
"flag"
"fmt"
"io/ioutil"
Expand All @@ -35,6 +37,8 @@ var (
postgresPass = flag.String("postgres_pass", os.Getenv("POSTGRES_PASS"), "Password for the database user; be careful when entering a password on the command line (it may go into your terminal's history).")
postgresDb = flag.String("postgres_db", os.Getenv("POSTGRES_DB"), "Name of the database to connect to.")

postgresIAMUser = flag.String("postgres_user_iam", os.Getenv("POSTGRES_USER_IAM"), "Name of database user configured with IAM DB Authentication.")

postgresPort = 5432
)

Expand Down Expand Up @@ -76,3 +80,33 @@ func TestPostgresConnLimit(t *testing.T) {
dsn := fmt.Sprintf("user=%s password=%s database=%s sslmode=disable", *postgresUser, *postgresPass, *postgresDb)
proxyConnLimitTest(t, *postgresConnName, "postgres", dsn, postgresPort)
}

func TestPostgresIAMDBAuthn(t *testing.T) {
requirePostgresVars(t)

ctx := context.Background()

// Start the proxy
p, err := StartProxy(ctx, fmt.Sprintf("-instances=%s=tcp:%d", *postgresConnName, 5432), "-enable_iam_login")
if err != nil {
t.Fatalf("unable to start proxy: %v", err)
}
defer p.Close()
output, err := p.WaitForServe(ctx)
if err != nil {
t.Fatalf("unable to verify proxy was serving: %s \n %s", err, output)
}

dsn := fmt.Sprintf("user=%s database=%s sslmode=disable", *postgresIAMUser, *postgresDb)
db, err := sql.Open("postgres", dsn)
if err != nil {
t.Fatalf("unable to connect to db: %s", err)
}
defer db.Close()
_, err = db.Exec("SELECT 1;")
if err != nil {

t.Fatalf("unable to exec on db: %s", err)
}

}

0 comments on commit 470f92d

Please sign in to comment.