Skip to content

Commit

Permalink
feat: Add access token info in ephemeral ssl cert for proxy IAM db au…
Browse files Browse the repository at this point in the history
…thentication
  • Loading branch information
shawnhuang-gg committed Jan 9, 2021
1 parent f9908cc commit dc724da
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 25 deletions.
42 changes: 26 additions & 16 deletions cmd/cloud_sql_proxy/cloud_sql_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ can be removed automatically by this program.`)
tokenFile = flag.String("credential_file", "", `If provided, this json file will be used to retrieve Service Account credentials.
You may set the GOOGLE_APPLICATION_CREDENTIALS environment variable for the same effect.`)
ipAddressTypes = flag.String("ip_address_types", "PUBLIC,PRIVATE", "Default to be 'PUBLIC,PRIVATE'. Options: a list of strings separated by ',', e.g. 'PUBLIC,PRIVATE' ")
// Settings for IAM db proxy authentication
enableIAMLogin = flag.Bool("enable_iam_login", false, "Enables IAM DB proxy authentication")

skipInvalidInstanceConfigs = flag.Bool("skip_failed_instance_config", false, `Setting this flag will allow you to prevent the proxy from terminating when
some instance configurations could not be parsed and/or are unavailable.`)
Expand Down Expand Up @@ -123,6 +125,10 @@ Authorization:
This will override gcloud or GCE (Google Compute Engine) credentials,
if they exist.
* To configure the proxy using IAM authentication, pass the -enable_iam_login
flag. This will consolidate gcloud credential into ephemeral cert that will
be checked on server side.
General:
-quiet
Disable log messages (e.g. when new connections are established).
Expand Down Expand Up @@ -271,33 +277,35 @@ func checkFlags(onGCE bool) error {
return nil
}

func authenticatedClientFromPath(ctx context.Context, f string) (*http.Client, error) {
func authenticatedClientFromPath(ctx context.Context, f string) (*http.Client, oauth2.TokenSource, error) {
all, err := ioutil.ReadFile(f)
if err != nil {
return nil, fmt.Errorf("invalid json file %q: %v", f, err)
return nil, nil, fmt.Errorf("invalid json file %q: %v", f, err)
}
// First try and load this as a service account config, which allows us to see the service account email:
if cfg, err := goauth.JWTConfigFromJSON(all, proxy.SQLScope); err == nil {
logging.Infof("using credential file for authentication; email=%s", cfg.Email)
return cfg.Client(ctx), nil
return cfg.Client(ctx), cfg.TokenSource(ctx), nil
}

cred, err := goauth.CredentialsFromJSON(ctx, all, proxy.SQLScope)
if err != nil {
return nil, fmt.Errorf("invalid json file %q: %v", f, err)
return nil, nil, fmt.Errorf("invalid json file %q: %v", f, err)
}
logging.Infof("using credential file for authentication; path=%q", f)
return oauth2.NewClient(ctx, cred.TokenSource), nil
return oauth2.NewClient(ctx, cred.TokenSource), cred.TokenSource, nil
}

func authenticatedClient(ctx context.Context) (*http.Client, error) {
if *tokenFile != "" {
return authenticatedClientFromPath(ctx, *tokenFile)
} else if tok := *token; tok != "" {
src := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: tok})
return oauth2.NewClient(ctx, src), nil
} else if f := os.Getenv("GOOGLE_APPLICATION_CREDENTIALS"); f != "" {
return authenticatedClientFromPath(ctx, f)
func authenticatedClient(ctx context.Context) (*http.Client, oauth2.TokenSource, error) {
if !*enableIAMLogin {
if *tokenFile != "" {
return authenticatedClientFromPath(ctx, *tokenFile)
} else if tok := *token; tok != "" {
src := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: tok})
return oauth2.NewClient(ctx, src), src, nil
} else if f := os.Getenv("GOOGLE_APPLICATION_CREDENTIALS"); f != "" {
return authenticatedClientFromPath(ctx, f)
}
}

// If flags or env don't specify an auth source, try either gcloud or application default
Expand All @@ -307,10 +315,10 @@ func authenticatedClient(ctx context.Context) (*http.Client, error) {
src, err = goauth.DefaultTokenSource(ctx, proxy.SQLScope)
}
if err != nil {
return nil, err
return nil, nil, err
}

return oauth2.NewClient(ctx, src), nil
return oauth2.NewClient(ctx, src), src, nil
}

func stringList(s string) []string {
Expand Down Expand Up @@ -462,7 +470,7 @@ func main() {
}

ctx := context.Background()
client, err := authenticatedClient(ctx)
client, tokSrc, err := authenticatedClient(ctx)
if err != nil {
log.Fatal(err)
}
Expand Down Expand Up @@ -497,6 +505,8 @@ func main() {
IgnoreRegion: !*checkRegion,
UserAgent: userAgentFromVersionString(),
IPAddrTypeOpts: ipAddrTypeOptsInput,
EnableIAMLogin: *enableIAMLogin,
TokenSource: tokSrc,
}),
Conns: connset,
RefreshCfgThrottle: refreshCfgThrottle,
Expand Down
42 changes: 36 additions & 6 deletions proxy/certs/certs.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (

"github.com/GoogleCloudPlatform/cloudsql-proxy/logging"
"github.com/GoogleCloudPlatform/cloudsql-proxy/proxy/util"
"golang.org/x/oauth2"
"google.golang.org/api/googleapi"
sqladmin "google.golang.org/api/sqladmin/v1beta4"
)
Expand Down Expand Up @@ -68,6 +69,12 @@ type RemoteOpts struct {

// IP address type options
IPAddrTypeOpts []string

// Enable IAM proxy db authentication
EnableIAMLogin bool

// Token source for token information used in cert creation
TokenSource oauth2.TokenSource
}

// NewCertSourceOpts returns a CertSource configured with the provided Opts.
Expand Down Expand Up @@ -105,7 +112,7 @@ func NewCertSourceOpts(c *http.Client, opts RemoteOpts) *RemoteCertSource {
}
}

return &RemoteCertSource{pkey, serv, !opts.IgnoreRegion, opts.IPAddrTypeOpts}
return &RemoteCertSource{pkey, serv, !opts.IgnoreRegion, opts.IPAddrTypeOpts, opts.EnableIAMLogin, opts.TokenSource}
}

// RemoteCertSource implements a CertSource, using Cloud SQL APIs to
Expand All @@ -123,6 +130,10 @@ type RemoteCertSource struct {
checkRegion bool
// a list of ip address types that users select
IPAddrTypes []string
// flag to enable IAM proxy db authentication
EnableIAMLogin bool
// token source for the token information used in cert creation
TokenSource oauth2.TokenSource
}

// Constants for backoffAPIRetry. These cause the retry logic to scale the
Expand Down Expand Up @@ -171,11 +182,21 @@ func (s *RemoteCertSource) Local(instance string) (ret tls.Certificate, err erro

p, r, n := util.SplitName(instance)
regionName := fmt.Sprintf("%s~%s", r, n)
req := s.serv.SslCerts.CreateEphemeral(p, regionName,
&sqladmin.SslCertsCreateEphemeralRequest{
PublicKey: string(pem.EncodeToMemory(&pem.Block{Bytes: pkix, Type: "RSA PUBLIC KEY"})),
},
)
pubKey := string(pem.EncodeToMemory(&pem.Block{Bytes: pkix, Type: "RSA PUBLIC KEY"}))
createEphemeralRequest := sqladmin.SslCertsCreateEphemeralRequest{
PublicKey: pubKey,
}
if s.EnableIAMLogin {
tok, e := s.TokenSource.Token()
if e != nil {
return ret, e
}
createEphemeralRequest = sqladmin.SslCertsCreateEphemeralRequest{
PublicKey: pubKey,
AccessToken: tok.AccessToken,
}
}
req := s.serv.SslCerts.CreateEphemeral(p, regionName, &createEphemeralRequest)

var data *sqladmin.SslCert
err = backoffAPIRetry("createEphemeral for", instance, func() error {
Expand Down Expand Up @@ -277,3 +298,12 @@ func (s *RemoteCertSource) Remote(instance string) (cert *x509.Certificate, addr

return c, ipAddrInUse, p + ":" + n, data.DatabaseVersion, err
}

// TokenExpiration returns the expiration time for token source associated with remote cert source.
func (s *RemoteCertSource) TokenExpiration() (ret time.Time, err error) {
tok, err := s.TokenSource.Token()
if err != nil {
return ret, err
}
return tok.Expiry, nil
}
11 changes: 10 additions & 1 deletion proxy/proxy/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ type CertSource interface {
Local(instance string) (tls.Certificate, error)
// Remote returns the instance's CA certificate, address, and name.
Remote(instance string) (cert *x509.Certificate, addr, name, version string, err error)
// TokenExpiration returns expiration time of the token information.
TokenExpiration() (time.Time, error)
}

// Client is a type to handle connecting to a Server. All fields are required
Expand Down Expand Up @@ -232,11 +234,18 @@ func (c *Client) refreshCfg(instance string) (addr string, cfg *tls.Config, vers
VerifyPeerCertificate: genVerifyPeerCertificateFunc(name, certs),
}

tokenExpiry, err := c.Certs.TokenExpiration()
if err != nil {
return "", nil, "", err
}
expire := mycert.Leaf.NotAfter
if expire.After(tokenExpiry) {
expire = tokenExpiry
}
now := time.Now()
timeToRefresh := expire.Sub(now) - refreshCfgBuffer
if timeToRefresh <= 0 {
err = fmt.Errorf("new ephemeral certificate expires too soon: current time: %v, certificate expires: %v", expire, now)
err = fmt.Errorf("new ephemeral certificate expires too soon: current time: %v, certificate expires: %v", now, expire)
logging.Errorf("ephemeral certificate (%+v) error: %v", mycert, err)
return "", nil, "", err
}
Expand Down
57 changes: 55 additions & 2 deletions proxy/proxy/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ type fakeCerts struct {
}

type blockingCertSource struct {
values map[string]*fakeCerts
validUntil time.Time
values map[string]*fakeCerts
validUntil time.Time
tokenExpire time.Time
}

func (cs *blockingCertSource) Local(instance string) (tls.Certificate, error) {
Expand All @@ -66,6 +67,10 @@ func (cs *blockingCertSource) Remote(instance string) (cert *x509.Certificate, a
return &x509.Certificate{}, "fake address", "fake name", "fake version", nil
}

func (cs *blockingCertSource) TokenExpiration() (ret time.Time, err error) {
return cs.tokenExpire, nil
}

func TestContextDialer(t *testing.T) {
b := &fakeCerts{}
c := &Client{
Expand All @@ -74,6 +79,7 @@ func TestContextDialer(t *testing.T) {
instance: b,
},
forever,
forever,
},
ContextDialer: func(context.Context, string, string) (net.Conn, error) {
return nil, errFakeDial
Expand All @@ -96,6 +102,7 @@ func TestClientCache(t *testing.T) {
instance: b,
},
forever,
forever,
},
Dialer: func(string, string) (net.Conn, error) {
return nil, errFakeDial
Expand Down Expand Up @@ -123,6 +130,7 @@ func TestConcurrentRefresh(t *testing.T) {
instance: b,
},
forever,
forever,
},
Dialer: func(string, string) (net.Conn, error) {
return nil, errFakeDial
Expand Down Expand Up @@ -164,6 +172,7 @@ func TestMaximumConnectionsCount(t *testing.T) {
certSource := blockingCertSource{
map[string]*fakeCerts{},
forever,
forever,
}
firstDialExited := make(chan struct{})
c := &Client{
Expand Down Expand Up @@ -225,6 +234,7 @@ func TestShutdownTerminatesEarly(t *testing.T) {
instance: b,
},
forever,
forever,
},
Dialer: func(string, string) (net.Conn, error) {
return nil, nil
Expand Down Expand Up @@ -256,6 +266,49 @@ func TestRefreshTimer(t *testing.T) {
instance: b,
},
certCreated.Add(timeToExpire),
forever,
},
Dialer: func(string, string) (net.Conn, error) {
return nil, errFakeDial
},
RefreshCfgThrottle: 20 * time.Millisecond,
RefreshCfgBuffer: time.Second,
}
// Call Dial to cache the cert.
if _, err := c.Dial(instance); err != errFakeDial {
t.Fatalf("Dial(%s) failed: %v", instance, err)
}
c.cacheL.Lock()
cfg, ok := c.cfgCache[instance]
c.cacheL.Unlock()
if !ok {
t.Fatalf("expected instance to be cached")
}

time.Sleep(timeToExpire - time.Since(certCreated))
// Check if cert was refreshed in the background, without calling Dial again.
c.cacheL.Lock()
newCfg, ok := c.cfgCache[instance]
c.cacheL.Unlock()
if !ok {
t.Fatalf("expected instance to be cached")
}
if !newCfg.lastRefreshed.After(cfg.lastRefreshed) {
t.Error("expected cert to be refreshed.")
}
}

func TestRefreshTimerTokenExpires(t *testing.T) {
timeToExpire := 5 * time.Second
b := &fakeCerts{}
certCreated := time.Now()
c := &Client{
Certs: &blockingCertSource{
map[string]*fakeCerts{
instance: b,
},
forever,
certCreated.Add(timeToExpire),
},
Dialer: func(string, string) (net.Conn, error) {
return nil, errFakeDial
Expand Down

0 comments on commit dc724da

Please sign in to comment.