diff --git a/transport/cert/default_cert.go b/transport/cert/default_cert.go index c03af65fd73..141ae457936 100644 --- a/transport/cert/default_cert.go +++ b/transport/cert/default_cert.go @@ -14,6 +14,7 @@ package cert import ( "crypto/tls" + "crypto/x509" "encoding/json" "errors" "fmt" @@ -23,6 +24,7 @@ import ( "os/user" "path/filepath" "sync" + "time" ) const ( @@ -30,10 +32,18 @@ const ( metadataFile = "context_aware_metadata.json" ) +// defaultCertData holds all the variables pertaining to +// the default certficate source created by DefaultSource. +type defaultCertData struct { + once sync.Once + source Source + err error + cachedCertMutex sync.Mutex + cachedCert *tls.Certificate +} + var ( - defaultSourceOnce sync.Once - defaultSource Source - defaultSourceErr error + defaultCert defaultCertData ) // Source is a function that can be passed into crypto/tls.Config.GetClientCertificate. @@ -44,10 +54,10 @@ type Source func(*tls.CertificateRequestInfo) (*tls.Certificate, error) // // If that file does not exist, a nil source is returned. func DefaultSource() (Source, error) { - defaultSourceOnce.Do(func() { - defaultSource, defaultSourceErr = newSecureConnectSource() + defaultCert.once.Do(func() { + defaultCert.source, defaultCert.err = newSecureConnectSource() }) - return defaultSource, defaultSourceErr + return defaultCert.source, defaultCert.err } type secureConnectSource struct { @@ -95,7 +105,11 @@ func validateMetadata(metadata secureConnectMetadata) error { } func (s *secureConnectSource) getClientCertificate(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { - // TODO(cbro): consider caching valid certificates rather than exec'ing every time. + defaultCert.cachedCertMutex.Lock() + defer defaultCert.cachedCertMutex.Unlock() + if defaultCert.cachedCert != nil && !isCertificateExpired(defaultCert.cachedCert) { + return defaultCert.cachedCert, nil + } command := s.metadata.Cmd data, err := exec.Command(command[0], command[1:]...).Output() if err != nil { @@ -106,5 +120,18 @@ func (s *secureConnectSource) getClientCertificate(info *tls.CertificateRequestI if err != nil { return nil, err } + defaultCert.cachedCert = &cert return &cert, nil } + +// isCertificateExpired returns true if the given cert is expired or invalid. +func isCertificateExpired(cert *tls.Certificate) bool { + if len(cert.Certificate) == 0 { + return true + } + parsed, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return true + } + return time.Now().After(parsed.NotAfter) +} diff --git a/transport/cert/default_cert_test.go b/transport/cert/default_cert_test.go index 0ec3c44b144..2d7e333f332 100644 --- a/transport/cert/default_cert_test.go +++ b/transport/cert/default_cert_test.go @@ -5,31 +5,34 @@ package cert import ( + "bytes" "testing" ) func TestGetClientCertificateSuccess(t *testing.T) { + defaultCert.cachedCert = nil source := secureConnectSource{metadata: secureConnectMetadata{Cmd: []string{"cat", "testdata/testcert.pem"}}} cert, err := source.getClientCertificate(nil) if err != nil { t.Error(err) } if cert.Certificate == nil { - t.Error("want non-nil cert, got nil") + t.Error("getClientCertificate: want non-nil Certificate, got nil") } if cert.PrivateKey == nil { - t.Error("want non-nil PrivateKey, got nil") + t.Error("getClientCertificate: want non-nil PrivateKey, got nil") } } func TestGetClientCertificateFailure(t *testing.T) { + defaultCert.cachedCert = nil source := secureConnectSource{metadata: secureConnectMetadata{Cmd: []string{"cat"}}} _, err := source.getClientCertificate(nil) if err == nil { t.Error("Expecting error.") } if got, want := err.Error(), "tls: failed to find any PEM data in certificate input"; got != want { - t.Errorf("getClientCertificate, want %v err, got %v", want, got) + t.Errorf("getClientCertificate: want %v err, got %v", want, got) } } @@ -51,3 +54,54 @@ func TestValidateMetadataFailure(t *testing.T) { t.Errorf("validateMetadata: want %v err, got %v", want, got) } } + +func TestIsCertificateExpiredTrue(t *testing.T) { + defaultCert.cachedCert = nil + source := secureConnectSource{metadata: secureConnectMetadata{Cmd: []string{"cat", "testdata/testcert.pem"}}} + cert, err := source.getClientCertificate(nil) + if err != nil { + t.Error(err) + } + if !isCertificateExpired(cert) { + t.Error("isCertificateExpired: want true, got false") + } +} + +func TestIsCertificateExpiredFalse(t *testing.T) { + defaultCert.cachedCert = nil + source := secureConnectSource{metadata: secureConnectMetadata{Cmd: []string{"cat", "testdata/nonexpiringtestcert.pem"}}} + cert, err := source.getClientCertificate(nil) + if err != nil { + t.Error(err) + } + if isCertificateExpired(cert) { + t.Error("isCertificateExpired: want false, got true") + } +} + +func TestCertificateCaching(t *testing.T) { + defaultCert.cachedCert = nil + source := secureConnectSource{metadata: secureConnectMetadata{Cmd: []string{"cat", "testdata/nonexpiringtestcert.pem"}}} + cert, err := source.getClientCertificate(nil) + if err != nil { + t.Error(err) + } + if cert == nil { + t.Error("getClientCertificate: want non-nil cert, got nil") + } + if defaultCert.cachedCert == nil { + t.Error("getClientCertificate: want non-nil defaultSourceCachedCert, got nil") + } + + source = secureConnectSource{metadata: secureConnectMetadata{Cmd: []string{"cat", "testdata/testcert.pem"}}} + cert, err = source.getClientCertificate(nil) + if err != nil { + t.Error(err) + } + if !bytes.Equal(cert.Certificate[0], defaultCert.cachedCert.Certificate[0]) { + t.Error("getClientCertificate: want cached Certificate, got different Certificate") + } + if cert.PrivateKey != defaultCert.cachedCert.PrivateKey { + t.Error("getClientCertificate: want cached PrivateKey, got different PrivateKey") + } +} diff --git a/transport/cert/testdata/nonexpiringtestcert.pem b/transport/cert/testdata/nonexpiringtestcert.pem new file mode 100644 index 00000000000..43260a9c7ef --- /dev/null +++ b/transport/cert/testdata/nonexpiringtestcert.pem @@ -0,0 +1,50 @@ +-----BEGIN CERTIFICATE----- +MIIDujCCAqICCQD+yrCYuiC8djANBgkqhkiG9w0BAQsFADCBnTELMAkGA1UEBhMC +VVMxEzARBgNVBAgMCldhc2hpbmd0b24xETAPBgNVBAcMCEtpcmtsYW5kMQ8wDQYD +VQQKDAZHb29nbGUxDjAMBgNVBAsMBUNsb3VkMRswGQYDVQQDDBJnb29nbGVhcGlz +dGVzdC5jb20xKDAmBgkqhkiG9w0BCQEWGWdvb2dsZWFwaXN0ZXN0QGdvb2dsZS5j +b20wIBcNMjAxMDIzMjEyNTU1WhgPMjEyMDA5MjkyMTI1NTVaMIGdMQswCQYDVQQG +EwJVUzETMBEGA1UECAwKV2FzaGluZ3RvbjERMA8GA1UEBwwIS2lya2xhbmQxDzAN +BgNVBAoMBkdvb2dsZTEOMAwGA1UECwwFQ2xvdWQxGzAZBgNVBAMMEmdvb2dsZWFw +aXN0ZXN0LmNvbTEoMCYGCSqGSIb3DQEJARYZZ29vZ2xlYXBpc3Rlc3RAZ29vZ2xl +LmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKnzFX97VP4XSQ8l +4/Z08eajnAiGpK+ZQTV9k7Qy2tpo5+iFFiL0JLGP9+GRILuDGQufYlPLDhLLho9V +YXIR9UOhhapmQJqUAUFhvZlBEixLxcfwa2LecNiJ6+8gvJCoRbrPIrz91crY+t59 +aY/09vmsCbFDX8d8WWVnww4285dfKwE2IDinqZ1VuT4zYR66f4lL8qj6t5TXeGAW +Nkd6O3yuAVO8RLiXBRRABP5217mq0jNL+kJUormzhuKgvP+oxRsi56XHPGiq7l2e +54PS/cqa4atjqbhZI1xV27y0sVr0/CmBsfeM3TwLbCSjv7r0lCz64xtCJa8R45MA +22or9z8CAwEAATANBgkqhkiG9w0BAQsFAAOCAQEAnwLY9qBIQ2IYDLNLx16av8C6 +9vca8gOzMpYZ4UKHDN+Qk2CidpmFamXWDXqmOLNZYlmEoGY5n8zg8rwYK+vauqwb +o94HzxLmQcQ4kmAI4xJnMqKZAbukRdWw2GCuvdVqG4Osngz4WBIHrAsl4btogdJy +ACU/YUA3K0tLjwe6wUYYF6eu5sb6zJkF4cfLpqECWtF9XG6nkJbo2GomHFuHm+6t +gOj7YiqU/cHCyU4FQF9/2jDLzFHxt2Bb30zi602YjuIZhYp35ktI66XwsE4kFmwo +iHCEG0fXMNN7OMFmNg2YVLhaHxrQNFxbzOQdfKg2gi2qzX4AiCo1tx5LCg6aGw== +-----END CERTIFICATE----- +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCp8xV/e1T+F0kP +JeP2dPHmo5wIhqSvmUE1fZO0MtraaOfohRYi9CSxj/fhkSC7gxkLn2JTyw4Sy4aP +VWFyEfVDoYWqZkCalAFBYb2ZQRIsS8XH8Gti3nDYievvILyQqEW6zyK8/dXK2Pre +fWmP9Pb5rAmxQ1/HfFllZ8MONvOXXysBNiA4p6mdVbk+M2Eeun+JS/Ko+reU13hg +FjZHejt8rgFTvES4lwUUQAT+dte5qtIzS/pCVKK5s4bioLz/qMUbIuelxzxoqu5d +nueD0v3KmuGrY6m4WSNcVdu8tLFa9PwpgbH3jN08C2wko7+69JQs+uMbQiWvEeOT +ANtqK/c/AgMBAAECggEAYjeE3hb1yJ7Gb0WzmDR/tI4rV9YQiRcl03cOjJ6zUnQ8 +SmnXoD2+kwuj8y1/YD7kk436MnjwWjZbPqzWUylDuGE5sX/EqFEO5K1K+K3dhdII +rIMqXIo3Zz1WJ+2gbG2DVvHsnpKIIuIBIeISxsqIjUQ6mcJZMR2RQISV+roRTxIU +1Ga0xWrExcKL8FSjs8ih0DWU4vHoSYH4DFXB1/ViyLn+DEljnOlo8Q+7DG0uQQnX +ixfYMbXSJcZxFm1iwuZv8SESjqbTsogNny5Wi6H9Vp0JFasAPUjnc+QuD/U1HTDn +PCX3eBNMcxvVJDhu/7nnO7kcU1Cx0gJeN+1bklrAcQKBgQDURl0Ac8N94I82n4Lg +wjGLWj3AMxSEHNcZuomCvoYcLTmJdd2tOnunXhh1jANnx6q8P8aR5fiTthokIUdx +bOmWwFAbP6kMe0WFWQhXjX4mXLRmJ4mWayWCE7hstnDb3/Fr7LuJeg5L3OU4ss3b +j4UvhtuQ9Qh8piVhKwFkQh3tOQKBgQDM9NSkRDVW3Q37lMUdyn8B2FBF78e/9ck+ +5bHOs52G2hXJ4tyLYNjBoLXPpMp9VWRTXxUaii+gHSa4DkHTkFwIg34hLgrCX7Gc +a0rldvkpX0xWSANfvO9bvavPgKnLSP8j3mjDiwqJuy3L5TBThIHDvPV9F/akpLne +bdcywa4ANwKBgHlvAzcGAniZJPRXjfRrwxH3/slbr0nggcDLMG0l9uxZhse3MKgv +g5t8PbvI7A3LcEWeqka+a1R84Tl3/DnL11kRDQJ5iYiFYIDnLNmBLQBfGigySAhP +pTZjd6ZhO/DcjGx0EdiUhWcqp8qmpxMKaGOG30ZulntQRKPwiSxEkoApAoGBAJ1o +h4ulawXMfnmyt3T62XJ0TKp5zoKqZSYuSNIEdr5j7goAdvuApNiI8jmISY/arlOt +mcqpSIyC9wKyyHGQ1G4hdxRKhS7lScZlTL9REWlp7HnzksvLklV2JWcXXNBovrMw +lGth9PT00eZfni72fKb1D+FEL0Qh0zJ2T6mGwHkfAoGAMOy8bbyCASCYG9MYzqaP +Lf+AKKNEYUvUGspyJUqu5ERudr5stmei6PrchxFiKjm5Qg7B/M1VnKsCtL9kk8Z9 +lHgwU5mOATZvd9k/5oiuRxzXyrWqFoT/mivI2rZE+g5cLTLytCTnyLjHm5B/aTy8 +1AmbAh5hvWYs+EMKZAlQ5GM= +-----END PRIVATE KEY-----