From 34feedac6ddf6f97e2aa17d4a02670c93420dd7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20Elsd=C3=B6rfer?= Date: Sun, 2 Feb 2014 02:25:04 +0100 Subject: [PATCH] router: Add support for serving HTTPS with SNI. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Update client example to support adding domains with certificates. Signed-off-by: Michael Elsdörfer --- Godeps | 6 +- examples/client/client.go | 97 ++++++++++++++++++++++- http.go | 162 +++++++++++++++++++++++++++++++------- rpc.go | 2 +- server.go | 4 +- 5 files changed, 238 insertions(+), 33 deletions(-) diff --git a/Godeps b/Godeps index 50836f80af..ae6d10f069 100644 --- a/Godeps +++ b/Godeps @@ -18,6 +18,10 @@ { "ImportPath": "github.com/flynn/rpcplus", "Rev": "68355e62cadb10d431f81b507ee30805b99022cd" - } + }, + { + "ImportPath": "github.com/inconshreveable/go-vhost", + "Rev": "abc5f6a77596abba0c71c07cfe22b6ccfd0a7d2d" + } ] } diff --git a/examples/client/client.go b/examples/client/client.go index dbf4aa3c98..18da5f6219 100644 --- a/examples/client/client.go +++ b/examples/client/client.go @@ -1,21 +1,114 @@ package main import ( + "bytes" + "encoding/pem" + "flag" + "fmt" + "io/ioutil" "log" "os" "github.com/flynn/rpcplus" "github.com/flynn/strowger/types" + "strings" ) func main() { - client, err := rpcplus.DialHTTP("tcp", "localhost:1115") + rpcAddr := flag.String("rpc", "localhost:1115", "strowger RPC address to connect to") + certPath := flag.String("cert", "", "path to DER encoded certificate for SSL, - for stdin") + keyPath := flag.String("key", "", "path to DER encoded private key for SSL, - for stdin") + flag.Parse() + if len(flag.Args()) != 2 { + fmt.Fprintf(os.Stderr, "Usage: %s [flags] domain service-name\n", os.Args[0]) + flag.PrintDefaults() + os.Exit(64) + } + + domain, serviceName := flag.Arg(0), flag.Arg(1) + + var stdin []byte + var err error + if *certPath == "-" || *keyPath == "-" { + stdin, err = ioutil.ReadAll(os.Stdin) + if err != nil { + log.Fatal("Failed to read from stdin: ", err) + } + } + + tlsCert, err := readCert(*certPath, stdin) + if err != nil { + return + } + tlsKey, err := readKey(*keyPath, stdin) + if err != nil { + return + } + + client, err := rpcplus.DialHTTP("tcp", *rpcAddr) if err != nil { log.Fatal(err) } - err = client.Call("Router.AddFrontend", &strowger.Config{Service: "example-server", HTTPDomain: os.Args[1]}, &struct{}{}) + frontendConfig := &strowger.Config{ + Service: serviceName, + HTTPDomain: domain, + HTTPSCert: tlsCert, + HTTPSKey: tlsKey, + } + err = client.Call("Router.AddFrontend", frontendConfig, &struct{}{}) if err != nil { log.Fatal(err) } } + +func readCert(path string, stdin []byte) ([]byte, error) { + if path == "-" { + var buffer bytes.Buffer + var derBlock *pem.Block + for { + derBlock, stdin = pem.Decode(stdin) + if derBlock == nil { + break + } + if derBlock.Type == "CERTIFICATE" { + buffer.Write(pem.EncodeToMemory(derBlock)) + } + } + if buffer.Len() > 0 { + return buffer.Bytes(), nil + } + log.Fatal("No certificate PEM blocks found in stdin") + } + return readFile(path) +} + +func readKey(path string, stdin []byte) ([]byte, error) { + if path == "-" { + var derBlock *pem.Block + for { + derBlock, stdin = pem.Decode(stdin) + if derBlock == nil { + break + } + if strings.Contains(derBlock.Type, "PRIVATE KEY") { + return pem.EncodeToMemory(derBlock), nil + } + } + log.Fatal("No private key PEM blocks found in stdin") + } + return readFile(path) +} + +func readFile(path string) ([]byte, error) { + if path == "" { + return nil, nil + } + + contents, err := ioutil.ReadFile(path) + if err != nil { + log.Printf("Failed to open %s: %s", path, err) + return nil, err + } + return contents, nil +} diff --git a/http.go b/http.go index 4d4f5a701d..b9886463d8 100644 --- a/http.go +++ b/http.go @@ -19,6 +19,7 @@ import ( "github.com/coreos/go-etcd/etcd" "github.com/flynn/go-discoverd" + "github.com/inconshreveable/go-vhost" ) type HTTPFrontend struct { @@ -27,7 +28,7 @@ type HTTPFrontend struct { TLSConfig *tls.Config mtx sync.RWMutex - domains map[string]*httpServer + domains map[string]*domain services map[string]*httpServer etcdPrefix string @@ -46,28 +47,29 @@ type DiscoverdClient interface { NewServiceSet(string) (discoverd.ServiceSet, error) } -func NewHTTPFrontend(addr string, etcdc EtcdClient, discoverdc DiscoverdClient) *HTTPFrontend { +func NewHTTPFrontend(addr, tlsAddr string, etcdc EtcdClient, discoverdc DiscoverdClient) *HTTPFrontend { return &HTTPFrontend{ Addr: addr, + TLSAddr: tlsAddr, etcd: etcdc, etcdPrefix: "/strowger/http/", discoverd: discoverdc, - domains: make(map[string]*httpServer), + domains: make(map[string]*domain), services: make(map[string]*httpServer), } } -func (s *HTTPFrontend) AddHTTPDomain(domain string, service string, certs [][]byte, key []byte) error { - return s.addDomain(domain, service, true) +func (s *HTTPFrontend) AddHTTPDomain(domain string, service string, cert []byte, key []byte) error { + return s.addDomain(domain, service, cert, key, true) } var ErrDomainExists = errors.New("strowger: domain exists with different service") -func (s *HTTPFrontend) addDomain(domain string, service string, persist bool) error { +func (s *HTTPFrontend) addDomain(name string, service string, cert []byte, key []byte, persist bool) error { s.mtx.Lock() defer s.mtx.Unlock() - if server, ok := s.domains[domain]; ok { + if server, ok := s.domains[name]; ok { if server.name != service { return ErrDomainExists } @@ -82,19 +84,34 @@ func (s *HTTPFrontend) addDomain(domain string, service string, persist bool) er } server = &httpServer{name: service, services: services} } + + var keypair *tls.Certificate + if cert != nil && key != nil { + created, err := tls.X509KeyPair(cert, key) + if err != nil { + return err + } + keypair = &created + } + domainCfg := &domain{name: name, server: server, keypair: keypair} + if persist { - if _, err := s.etcd.Create(s.etcdPrefix+domain+"/service", service, 0); err != nil { + if _, err := s.etcd.Create(s.etcdPrefix+name+"/service", service, 0); err != nil { + return err + } + if _, err := s.etcd.Create(s.etcdPrefix+name+"/tls/key", string(key), 0); err != nil { + return err + } + if _, err := s.etcd.Create(s.etcdPrefix+name+"/tls/cert", string(cert), 0); err != nil { return err } } - // TODO: set cert/key data if provided server.refs++ - s.domains[domain] = server + s.domains[name] = domainCfg s.services[service] = server - // TODO: TLS config - log.Println("Add service", service, "to domain", domain) + log.Println("Add service", service, "to domain", name) return nil } @@ -102,7 +119,7 @@ func (s *HTTPFrontend) addDomain(domain string, service string, persist bool) er func (s *HTTPFrontend) RemoveHTTPDomain(domain string) { s.mtx.Lock() defer s.mtx.Unlock() - server := s.domains[domain] + server := s.domains[domain].server if server == nil { return } @@ -115,6 +132,24 @@ func (s *HTTPFrontend) RemoveHTTPDomain(domain string) { // TODO: persist } +// Given a recursive node structure from etcd, find particular +// values. "name" may contain slashes ("foo/bar") +func etcdFindChild(node *etcd.Node, name string) *etcd.Node { + parts := strings.Split(name, "/") +outer: + for _, part := range parts { + for _, childNode := range node.Nodes { + if path.Base(childNode.Key) == part { + node = &childNode + continue outer + } + } + return nil + } + + return node +} + func (s *HTTPFrontend) syncDatabase() { var since uint64 data, err := s.etcd.Get(s.etcdPrefix, false, true) @@ -132,11 +167,29 @@ func (s *HTTPFrontend) syncDatabase() { continue } domain := path.Base(node.Key) - serviceRes, err := s.etcd.Get(node.Key+"/service", false, false) + + // Recursively get the whole service structure + serviceRes, err := s.etcd.Get(node.Key, false, true) if err != nil { log.Fatal(err) } - if err := s.addDomain(domain, serviceRes.Node.Value, false); err != nil { + + // Find the individual (service name, ssl info) + serviceNode := etcdFindChild(serviceRes.Node, "service") + if serviceNode == nil { + continue + } + var cert, key []byte + certNode := etcdFindChild(serviceRes.Node, "tls/cert") + if certNode != nil && certNode.Value != "" { + cert = []byte(certNode.Value) + } + keyNode := etcdFindChild(serviceRes.Node, "tls/key") + if keyNode != nil && keyNode.Value != "" { + key = []byte(keyNode.Value) + } + + if err := s.addDomain(domain, serviceNode.Value, cert, key, false); err != nil { log.Fatal(err) } } @@ -147,6 +200,7 @@ watch: // TODO: store stop go s.etcd.Watch(s.etcdPrefix, since, false, stream, stop) for res := range stream { + // XXX support changes to cert/key if res.Node.Dir || path.Base(res.Node.Key) == "service" { continue } @@ -155,7 +209,7 @@ watch: _, exists := s.domains[domain] s.mtx.Unlock() if !exists { - if err := s.addDomain(domain, res.Node.Value, false); err != nil { + if err := s.addDomain(domain, res.Node.Value, nil, nil, false); err != nil { // TODO: log error } } @@ -176,7 +230,7 @@ func (s *HTTPFrontend) serve() { // TODO: log error break } - go s.handle(conn) + go s.handle(conn, false) } } @@ -192,10 +246,20 @@ func (s *HTTPFrontend) serveTLS() { // TODO: log error break } - go s.handle(conn) + go s.handle(conn, true) } } +func (s *HTTPFrontend) findBackendForHost(host string) *domain { + s.mtx.RLock() + defer s.mtx.RUnlock() + // TODO: handle wildcard domains + log.Println(s.domains, host) + backend := s.domains[host] + log.Println("Backend match:", backend) + return backend +} + func fail(sc *httputil.ServerConn, req *http.Request, code int, msg string) { resp := &http.Response{ StatusCode: code, @@ -208,8 +272,40 @@ func fail(sc *httputil.ServerConn, req *http.Request, code int, msg string) { sc.Write(req, resp) } -func (s *HTTPFrontend) handle(conn net.Conn) { +func (s *HTTPFrontend) handle(conn net.Conn, isTLS bool) { defer conn.Close() + + var backend *domain + + // For TLS, use the SNI hello to determine the domain. + // At this stage, if we don't find a match, we simply + // close the connection down. + if isTLS { + // Parse out host via SNI first + vhostConn, err := vhost.TLS(conn) + if err != nil { + log.Println("Failed to decode TLS connection", err) + return + } + host := vhostConn.Host() + log.Println("SNI host is:", host) + + // Find a backend for the key + backend = s.findBackendForHost(host) + if backend == nil { + return + } + if backend.keypair == nil { + log.Println("Canot serve TLS, no certificate defined for this domain") + return + } + + // Init a TLS decryptor + tlscfg := &tls.Config{Certificates: []tls.Certificate{*backend.keypair}} + conn = tls.Server(vhostConn, tlscfg) + } + + // Decode the first request from the connection sc := httputil.NewServerConn(conn, nil) req, err := sc.Read() if err != nil { @@ -219,19 +315,29 @@ func (s *HTTPFrontend) handle(conn net.Conn) { return } - s.mtx.RLock() - // TODO: handle wildcard domains - backend := s.domains[req.Host] - s.mtx.RUnlock() - log.Println(req, backend) + // If we do not have a backend yet (unencrypted connection), + // look at the host header to find one or 404 out. if backend == nil { - fail(sc, req, 404, "Not Found") - return + backend = s.findBackendForHost(req.Host) + log.Println(req, backend) + if backend == nil { + fail(sc, req, 404, "Not Found") + return + } } - _, tls := conn.(*tls.Conn) - backend.handle(req, sc, tls) + + backend.server.handle(req, sc, isTLS) +} + +// A domain served by a frontend, associated TLS certs, +// and link to backend service set. +type domain struct { + name string + keypair *tls.Certificate + server *httpServer } +// A service definition: name, and set of backends. type httpServer struct { name string services discoverd.ServiceSet diff --git a/rpc.go b/rpc.go index 9afdb4feeb..55e8d7ad39 100644 --- a/rpc.go +++ b/rpc.go @@ -13,7 +13,7 @@ type Router struct { func (r *Router) AddFrontend(config *strowger.Config, res *struct{}) error { switch config.Type { case strowger.FrontendHTTP: - if err := r.s.AddHTTPDomain(config.HTTPDomain, config.Service, nil, nil); err != nil { + if err := r.s.AddHTTPDomain(config.HTTPDomain, config.Service, config.HTTPSCert, config.HTTPSKey); err != nil { return err } default: diff --git a/server.go b/server.go index 180e8ae204..dbcebbaef1 100644 --- a/server.go +++ b/server.go @@ -18,6 +18,7 @@ type Server struct { func (s *Server) ListenAndServe(quit <-chan struct{}) { go s.HTTPFrontend.serve() + go s.HTTPFrontend.serveTLS() go s.HTTPFrontend.syncDatabase() <-quit // TODO: unregister from service discovery @@ -27,6 +28,7 @@ func (s *Server) ListenAndServe(quit <-chan struct{}) { func main() { rpcAddr := flag.String("rpcaddr", ":1115", "rpc listen address") httpAddr := flag.String("httpaddr", ":8080", "http frontend listen address") + httpsAddr := flag.String("httpsaddr", ":4433", "https frontend listen address") flag.Parse() // Will use DISCOVERD environment variable @@ -42,7 +44,7 @@ func main() { } var s Server - s.HTTPFrontend = NewHTTPFrontend(*httpAddr, etcd.NewClient(etcdAddr), d) + s.HTTPFrontend = NewHTTPFrontend(*httpAddr, *httpsAddr, etcd.NewClient(etcdAddr), d) rpc.Register(&Router{s}) rpc.HandleHTTP() go http.ListenAndServe(*rpcAddr, nil)