Skip to content

Commit

Permalink
fix: ensure certificate expiration is correct (#659)
Browse files Browse the repository at this point in the history
  • Loading branch information
enocom committed Mar 5, 2021
1 parent f3bf3f9 commit 2fd2504
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 95 deletions.
36 changes: 13 additions & 23 deletions proxy/certs/certs.go
Expand Up @@ -196,18 +196,20 @@ func (s *RemoteCertSource) Local(instance string) (tls.Certificate, error) {
createEphemeralRequest := sqladmin.SslCertsCreateEphemeralRequest{
PublicKey: pubKey,
}
var tok *oauth2.Token
// If IAM login is enabled, add the OAuth2 token into the ephemeral
// certificate request.
if s.EnableIAMLogin {
tok, err := s.TokenSource.Token()
var tokErr error
tok, tokErr = s.TokenSource.Token()
if err != nil {
return tls.Certificate{}, err
return tls.Certificate{}, tokErr
}
// Always refresh the token to ensure its expiration is far enough in
// the future.
tok, err = refreshToken(s.TokenSource, tok)
if err != nil {
return tls.Certificate{}, err
tok, tokErr = refreshToken(s.TokenSource, tok)
if tokErr != nil {
return tls.Certificate{}, tokErr
}
createEphemeralRequest.AccessToken = tok.AccessToken
}
Expand All @@ -226,18 +228,19 @@ func (s *RemoteCertSource) Local(instance string) (tls.Certificate, error) {
if err != nil {
return tls.Certificate{}, fmt.Errorf("couldn't parse ephemeral certificate for instance %q: %v", instance, err)
}
if s.EnableIAMLogin {
// Adjust the certificate's expiration to be the earlier of tok.Expiry or c.NotAfter
if tok.Expiry.Before(c.NotAfter) {
c.NotAfter = tok.Expiry
}
}
return tls.Certificate{
Certificate: [][]byte{c.Raw},
PrivateKey: s.key,
Leaf: c,
}, nil
}

// IAMLoginEnabled reports whether IAM login has been enabled.
func (s *RemoteCertSource) IAMLoginEnabled() bool {
return s.EnableIAMLogin
}

func parseCert(pemCert string) (*x509.Certificate, error) {
bl, _ := pem.Decode([]byte(pemCert))
if bl == nil {
Expand Down Expand Up @@ -318,16 +321,3 @@ func (s *RemoteCertSource) Remote(instance string) (cert *x509.Certificate, addr

return c, ipAddrInUse, p + ":" + n, data.DatabaseVersion, err
}

// TokenExpiration returns the expiration time for token source associated with remote cert source.
func (s *RemoteCertSource) TokenExpiration() (time.Time, error) {
// if no token is being used, return zero for expiration
if s.TokenSource == nil {
return time.Time{}, nil
}
tok, err := s.TokenSource.Token()
if err != nil {
return time.Time{}, err
}
return tok.Expiry, nil
}
13 changes: 0 additions & 13 deletions proxy/proxy/client.go
Expand Up @@ -50,15 +50,11 @@ type Conn struct {

// CertSource is how a Client obtains various certificates required for operation.
type CertSource interface {
// IAMLoginEnabled reports whether IAM Login has been enabled.
IAMLoginEnabled() bool
// Local returns a certificate that can be used to authenticate with the
// provided instance.
Local(instance string) (tls.Certificate, error)
// Remote returns the instance's CA certificate, address, and name.
Remote(instance string) (cert *x509.Certificate, addr, name, version string, err error)
// TokenExpiration returns expiration time of the token information.
TokenExpiration() (time.Time, error)
}

// Client is a type to handle connecting to a Server. All fields are required
Expand Down Expand Up @@ -237,15 +233,6 @@ func (c *Client) refreshCfg(instance string) (addr string, cfg *tls.Config, vers
}

certExpiration := mycert.Leaf.NotAfter
if c.Certs.IAMLoginEnabled() {
tokenExpiration, tokErr := c.Certs.TokenExpiration()
if tokErr != nil {
return "", nil, "", tokErr
}
if certExpiration.After(tokenExpiration) {
certExpiration = tokenExpiration
}
}
now := time.Now()
timeToRefresh := certExpiration.Sub(now) - refreshCfgBuffer
if timeToRefresh <= 0 {
Expand Down
61 changes: 2 additions & 59 deletions proxy/proxy/client_test.go
Expand Up @@ -41,9 +41,8 @@ type fakeCerts struct {
}

type blockingCertSource struct {
values map[string]*fakeCerts
validUntil time.Time
tokenExpire time.Time
values map[string]*fakeCerts
validUntil time.Time
}

func (cs *blockingCertSource) Local(instance string) (tls.Certificate, error) {
Expand All @@ -67,14 +66,6 @@ func (cs *blockingCertSource) Remote(instance string) (cert *x509.Certificate, a
return &x509.Certificate{}, "fake address", "fake name", "fake version", nil
}

func (cs *blockingCertSource) TokenExpiration() (ret time.Time, err error) {
return cs.tokenExpire, nil
}

func (cs *blockingCertSource) IAMLoginEnabled() bool {
return true
}

func TestContextDialer(t *testing.T) {
b := &fakeCerts{}
c := &Client{
Expand All @@ -83,7 +74,6 @@ func TestContextDialer(t *testing.T) {
instance: b,
},
forever,
forever,
},
ContextDialer: func(context.Context, string, string) (net.Conn, error) {
return nil, errFakeDial
Expand All @@ -106,7 +96,6 @@ func TestClientCache(t *testing.T) {
instance: b,
},
forever,
forever,
},
Dialer: func(string, string) (net.Conn, error) {
return nil, errFakeDial
Expand Down Expand Up @@ -134,7 +123,6 @@ func TestConcurrentRefresh(t *testing.T) {
instance: b,
},
forever,
forever,
},
Dialer: func(string, string) (net.Conn, error) {
return nil, errFakeDial
Expand Down Expand Up @@ -176,7 +164,6 @@ func TestMaximumConnectionsCount(t *testing.T) {
certSource := blockingCertSource{
map[string]*fakeCerts{},
forever,
forever,
}
firstDialExited := make(chan struct{})
c := &Client{
Expand Down Expand Up @@ -238,7 +225,6 @@ func TestShutdownTerminatesEarly(t *testing.T) {
instance: b,
},
forever,
forever,
},
Dialer: func(string, string) (net.Conn, error) {
return nil, nil
Expand Down Expand Up @@ -270,49 +256,6 @@ func TestRefreshTimer(t *testing.T) {
instance: b,
},
certCreated.Add(timeToExpire),
forever,
},
Dialer: func(string, string) (net.Conn, error) {
return nil, errFakeDial
},
RefreshCfgThrottle: 20 * time.Millisecond,
RefreshCfgBuffer: time.Second,
}
// Call Dial to cache the cert.
if _, err := c.Dial(instance); err != errFakeDial {
t.Fatalf("Dial(%s) failed: %v", instance, err)
}
c.cacheL.Lock()
cfg, ok := c.cfgCache[instance]
c.cacheL.Unlock()
if !ok {
t.Fatalf("expected instance to be cached")
}

time.Sleep(timeToExpire - time.Since(certCreated))
// Check if cert was refreshed in the background, without calling Dial again.
c.cacheL.Lock()
newCfg, ok := c.cfgCache[instance]
c.cacheL.Unlock()
if !ok {
t.Fatalf("expected instance to be cached")
}
if !newCfg.lastRefreshed.After(cfg.lastRefreshed) {
t.Error("expected cert to be refreshed.")
}
}

func TestRefreshTimerTokenExpires(t *testing.T) {
timeToExpire := 5 * time.Second
b := &fakeCerts{}
certCreated := time.Now()
c := &Client{
Certs: &blockingCertSource{
map[string]*fakeCerts{
instance: b,
},
forever,
certCreated.Add(timeToExpire),
},
Dialer: func(string, string) (net.Conn, error) {
return nil, errFakeDial
Expand Down

0 comments on commit 2fd2504

Please sign in to comment.