Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added '-enable_iam_login' flag for IAM db authentication #583

Merged
merged 6 commits into from
Feb 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ cloud_sql_proxy takes a few arguments to configure what instances to connect to
* `-skip_failed_instance_config`: Setting this flag will allow you to prevent the proxy from terminating when
some instance configurations could not be parsed and/or are unavailable.
* `-log_debug_stdout=true`: This is to log non-error output to stdOut instead of stdErr. For example, if you don't want connection related messages to log as errors, set this flag to true. Defaults to false.
* `-enable_iam_login`: This enables the proxy to use Cloud SQL IAM database authentication. This will cause the proxy to use IAM account credentials for database user authentication.

Note: `-instances` and `-instances_metadata` may be used at the same time but
are not compatible with the `-fuse` flag.
Expand Down
28 changes: 18 additions & 10 deletions cmd/cloud_sql_proxy/cloud_sql_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ can be removed automatically by this program.`)
tokenFile = flag.String("credential_file", "", `If provided, this json file will be used to retrieve Service Account credentials.
You may set the GOOGLE_APPLICATION_CREDENTIALS environment variable for the same effect.`)
ipAddressTypes = flag.String("ip_address_types", "PUBLIC,PRIVATE", "Default to be 'PUBLIC,PRIVATE'. Options: a list of strings separated by ',', e.g. 'PUBLIC,PRIVATE' ")
// Settings for IAM db proxy authentication
enableIAMLogin = flag.Bool("enable_iam_login", false, "Enables database user authentication using Cloud SQL's IAM DB Authentication.")

skipInvalidInstanceConfigs = flag.Bool("skip_failed_instance_config", false, `Setting this flag will allow you to prevent the proxy from terminating when
some instance configurations could not be parsed and/or are unavailable.`)
Expand Down Expand Up @@ -123,6 +125,10 @@ Authorization:
This will override gcloud or GCE (Google Compute Engine) credentials,
if they exist.

* To configure the proxy using IAM authentication, pass the -enable_iam_login
flag. This will cause the proxy to use IAM account credentials for
database user authentication.

General:
-quiet
Disable log messages (e.g. when new connections are established).
Expand Down Expand Up @@ -271,31 +277,31 @@ func checkFlags(onGCE bool) error {
return nil
}

func authenticatedClientFromPath(ctx context.Context, f string) (*http.Client, error) {
func authenticatedClientFromPath(ctx context.Context, f string) (*http.Client, oauth2.TokenSource, error) {
all, err := ioutil.ReadFile(f)
if err != nil {
return nil, fmt.Errorf("invalid json file %q: %v", f, err)
return nil, nil, fmt.Errorf("invalid json file %q: %v", f, err)
}
// First try and load this as a service account config, which allows us to see the service account email:
if cfg, err := goauth.JWTConfigFromJSON(all, proxy.SQLScope); err == nil {
logging.Infof("using credential file for authentication; email=%s", cfg.Email)
return cfg.Client(ctx), nil
return cfg.Client(ctx), cfg.TokenSource(ctx), nil
}

cred, err := goauth.CredentialsFromJSON(ctx, all, proxy.SQLScope)
if err != nil {
return nil, fmt.Errorf("invalid json file %q: %v", f, err)
return nil, nil, fmt.Errorf("invalid json file %q: %v", f, err)
}
logging.Infof("using credential file for authentication; path=%q", f)
return oauth2.NewClient(ctx, cred.TokenSource), nil
return oauth2.NewClient(ctx, cred.TokenSource), cred.TokenSource, nil
}

func authenticatedClient(ctx context.Context) (*http.Client, error) {
func authenticatedClient(ctx context.Context) (*http.Client, oauth2.TokenSource, error) {
if *tokenFile != "" {
return authenticatedClientFromPath(ctx, *tokenFile)
} else if tok := *token; tok != "" {
src := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: tok})
return oauth2.NewClient(ctx, src), nil
return oauth2.NewClient(ctx, src), src, nil
} else if f := os.Getenv("GOOGLE_APPLICATION_CREDENTIALS"); f != "" {
return authenticatedClientFromPath(ctx, f)
}
Expand All @@ -307,10 +313,10 @@ func authenticatedClient(ctx context.Context) (*http.Client, error) {
src, err = goauth.DefaultTokenSource(ctx, proxy.SQLScope)
}
if err != nil {
return nil, err
return nil, nil, err
}

return oauth2.NewClient(ctx, src), nil
return oauth2.NewClient(ctx, src), src, nil
}

func stringList(s string) []string {
Expand Down Expand Up @@ -462,7 +468,7 @@ func main() {
}

ctx := context.Background()
client, err := authenticatedClient(ctx)
client, tokSrc, err := authenticatedClient(ctx)
if err != nil {
log.Fatal(err)
}
Expand Down Expand Up @@ -497,6 +503,8 @@ func main() {
IgnoreRegion: !*checkRegion,
UserAgent: userAgentFromVersionString(),
IPAddrTypeOpts: ipAddrTypeOptsInput,
EnableIAMLogin: *enableIAMLogin,
TokenSource: tokSrc,
}),
Conns: connset,
RefreshCfgThrottle: refreshCfgThrottle,
Expand Down
39 changes: 33 additions & 6 deletions proxy/certs/certs.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (

"github.com/GoogleCloudPlatform/cloudsql-proxy/logging"
"github.com/GoogleCloudPlatform/cloudsql-proxy/proxy/util"
"golang.org/x/oauth2"
"google.golang.org/api/googleapi"
sqladmin "google.golang.org/api/sqladmin/v1beta4"
)
Expand Down Expand Up @@ -68,6 +69,12 @@ type RemoteOpts struct {

// IP address type options
IPAddrTypeOpts []string

// Enable IAM proxy db authentication
EnableIAMLogin bool

// Token source for token information used in cert creation
TokenSource oauth2.TokenSource
}

// NewCertSourceOpts returns a CertSource configured with the provided Opts.
Expand Down Expand Up @@ -105,7 +112,7 @@ func NewCertSourceOpts(c *http.Client, opts RemoteOpts) *RemoteCertSource {
}
}

return &RemoteCertSource{pkey, serv, !opts.IgnoreRegion, opts.IPAddrTypeOpts}
return &RemoteCertSource{pkey, serv, !opts.IgnoreRegion, opts.IPAddrTypeOpts, opts.EnableIAMLogin, opts.TokenSource}
}

// RemoteCertSource implements a CertSource, using Cloud SQL APIs to
Expand All @@ -123,6 +130,10 @@ type RemoteCertSource struct {
checkRegion bool
// a list of ip address types that users select
IPAddrTypes []string
// flag to enable IAM proxy db authentication
EnableIAMLogin bool
// token source for the token information used in cert creation
TokenSource oauth2.TokenSource
}

// Constants for backoffAPIRetry. These cause the retry logic to scale the
Expand Down Expand Up @@ -171,11 +182,18 @@ func (s *RemoteCertSource) Local(instance string) (ret tls.Certificate, err erro

p, r, n := util.SplitName(instance)
regionName := fmt.Sprintf("%s~%s", r, n)
req := s.serv.SslCerts.CreateEphemeral(p, regionName,
&sqladmin.SslCertsCreateEphemeralRequest{
PublicKey: string(pem.EncodeToMemory(&pem.Block{Bytes: pkix, Type: "RSA PUBLIC KEY"})),
},
)
pubKey := string(pem.EncodeToMemory(&pem.Block{Bytes: pkix, Type: "RSA PUBLIC KEY"}))
createEphemeralRequest := sqladmin.SslCertsCreateEphemeralRequest{
PublicKey: pubKey,
}
if s.EnableIAMLogin {
tok, e := s.TokenSource.Token()
if e != nil {
return ret, e
}
createEphemeralRequest.AccessToken = tok.AccessToken
}
req := s.serv.SslCerts.CreateEphemeral(p, regionName, &createEphemeralRequest)

var data *sqladmin.SslCert
err = backoffAPIRetry("createEphemeral for", instance, func() error {
Expand Down Expand Up @@ -277,3 +295,12 @@ func (s *RemoteCertSource) Remote(instance string) (cert *x509.Certificate, addr

return c, ipAddrInUse, p + ":" + n, data.DatabaseVersion, err
}

// TokenExpiration returns the expiration time for token source associated with remote cert source.
func (s *RemoteCertSource) TokenExpiration() (ret time.Time, err error) {
tok, err := s.TokenSource.Token()
if err != nil {
return ret, err
}
return tok.Expiry, nil
}
11 changes: 10 additions & 1 deletion proxy/proxy/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ type CertSource interface {
Local(instance string) (tls.Certificate, error)
// Remote returns the instance's CA certificate, address, and name.
Remote(instance string) (cert *x509.Certificate, addr, name, version string, err error)
// TokenExpiration returns expiration time of the token information.
TokenExpiration() (time.Time, error)
}

// Client is a type to handle connecting to a Server. All fields are required
Expand Down Expand Up @@ -232,11 +234,18 @@ func (c *Client) refreshCfg(instance string) (addr string, cfg *tls.Config, vers
VerifyPeerCertificate: genVerifyPeerCertificateFunc(name, certs),
}

tokenExpiry, err := c.Certs.TokenExpiration()
if err != nil {
return "", nil, "", err
}
expire := mycert.Leaf.NotAfter
if expire.After(tokenExpiry) {
expire = tokenExpiry
}
now := time.Now()
timeToRefresh := expire.Sub(now) - refreshCfgBuffer
if timeToRefresh <= 0 {
err = fmt.Errorf("new ephemeral certificate expires too soon: current time: %v, certificate expires: %v", expire, now)
err = fmt.Errorf("new ephemeral certificate expires too soon: current time: %v, certificate expires: %v", now, expire)
logging.Errorf("ephemeral certificate (%+v) error: %v", mycert, err)
return "", nil, "", err
}
Expand Down
57 changes: 55 additions & 2 deletions proxy/proxy/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ type fakeCerts struct {
}

type blockingCertSource struct {
values map[string]*fakeCerts
validUntil time.Time
values map[string]*fakeCerts
validUntil time.Time
tokenExpire time.Time
}

func (cs *blockingCertSource) Local(instance string) (tls.Certificate, error) {
Expand All @@ -66,6 +67,10 @@ func (cs *blockingCertSource) Remote(instance string) (cert *x509.Certificate, a
return &x509.Certificate{}, "fake address", "fake name", "fake version", nil
}

func (cs *blockingCertSource) TokenExpiration() (ret time.Time, err error) {
return cs.tokenExpire, nil
}

func TestContextDialer(t *testing.T) {
b := &fakeCerts{}
c := &Client{
Expand All @@ -74,6 +79,7 @@ func TestContextDialer(t *testing.T) {
instance: b,
},
forever,
forever,
},
ContextDialer: func(context.Context, string, string) (net.Conn, error) {
return nil, errFakeDial
Expand All @@ -96,6 +102,7 @@ func TestClientCache(t *testing.T) {
instance: b,
},
forever,
forever,
},
Dialer: func(string, string) (net.Conn, error) {
return nil, errFakeDial
Expand Down Expand Up @@ -123,6 +130,7 @@ func TestConcurrentRefresh(t *testing.T) {
instance: b,
},
forever,
forever,
},
Dialer: func(string, string) (net.Conn, error) {
return nil, errFakeDial
Expand Down Expand Up @@ -164,6 +172,7 @@ func TestMaximumConnectionsCount(t *testing.T) {
certSource := blockingCertSource{
map[string]*fakeCerts{},
forever,
forever,
}
firstDialExited := make(chan struct{})
c := &Client{
Expand Down Expand Up @@ -225,6 +234,7 @@ func TestShutdownTerminatesEarly(t *testing.T) {
instance: b,
},
forever,
forever,
},
Dialer: func(string, string) (net.Conn, error) {
return nil, nil
Expand Down Expand Up @@ -256,6 +266,49 @@ func TestRefreshTimer(t *testing.T) {
instance: b,
},
certCreated.Add(timeToExpire),
forever,
},
Dialer: func(string, string) (net.Conn, error) {
return nil, errFakeDial
},
RefreshCfgThrottle: 20 * time.Millisecond,
RefreshCfgBuffer: time.Second,
}
// Call Dial to cache the cert.
if _, err := c.Dial(instance); err != errFakeDial {
t.Fatalf("Dial(%s) failed: %v", instance, err)
}
c.cacheL.Lock()
cfg, ok := c.cfgCache[instance]
c.cacheL.Unlock()
if !ok {
t.Fatalf("expected instance to be cached")
}

time.Sleep(timeToExpire - time.Since(certCreated))
// Check if cert was refreshed in the background, without calling Dial again.
c.cacheL.Lock()
newCfg, ok := c.cfgCache[instance]
c.cacheL.Unlock()
if !ok {
t.Fatalf("expected instance to be cached")
}
if !newCfg.lastRefreshed.After(cfg.lastRefreshed) {
t.Error("expected cert to be refreshed.")
}
}

func TestRefreshTimerTokenExpires(t *testing.T) {
timeToExpire := 5 * time.Second
b := &fakeCerts{}
certCreated := time.Now()
c := &Client{
Certs: &blockingCertSource{
map[string]*fakeCerts{
instance: b,
},
forever,
certCreated.Add(timeToExpire),
},
Dialer: func(string, string) (net.Conn, error) {
return nil, errFakeDial
Expand Down
34 changes: 34 additions & 0 deletions tests/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
package tests

import (
"context"
"database/sql"
"flag"
"fmt"
"io/ioutil"
Expand All @@ -35,6 +37,8 @@ var (
postgresPass = flag.String("postgres_pass", os.Getenv("POSTGRES_PASS"), "Password for the database user; be careful when entering a password on the command line (it may go into your terminal's history).")
postgresDb = flag.String("postgres_db", os.Getenv("POSTGRES_DB"), "Name of the database to connect to.")

postgresIAMUser = flag.String("postgres_user_iam", os.Getenv("POSTGRES_USER_IAM"), "Name of database user configured with IAM DB Authentication.")

postgresPort = 5432
)

Expand Down Expand Up @@ -76,3 +80,33 @@ func TestPostgresConnLimit(t *testing.T) {
dsn := fmt.Sprintf("user=%s password=%s database=%s sslmode=disable", *postgresUser, *postgresPass, *postgresDb)
proxyConnLimitTest(t, *postgresConnName, "postgres", dsn, postgresPort)
}

func TestPostgresIAMDBAuthn(t *testing.T) {
requirePostgresVars(t)

ctx := context.Background()

// Start the proxy
p, err := StartProxy(ctx, fmt.Sprintf("-instances=%s=tcp:%d", *postgresConnName, 5432), "-enable_iam_login")
if err != nil {
t.Fatalf("unable to start proxy: %v", err)
}
defer p.Close()
output, err := p.WaitForServe(ctx)
if err != nil {
t.Fatalf("unable to verify proxy was serving: %s \n %s", err, output)
}

dsn := fmt.Sprintf("user=%s database=%s sslmode=disable", *postgresIAMUser, *postgresDb)
db, err := sql.Open("postgres", dsn)
if err != nil {
t.Fatalf("unable to connect to db: %s", err)
}
defer db.Close()
_, err = db.Exec("SELECT 1;")
if err != nil {

t.Fatalf("unable to exec on db: %s", err)
}

}