diff --git a/app.go b/app.go index 2a993387..1a7b0082 100644 --- a/app.go +++ b/app.go @@ -51,5 +51,8 @@ func main() { engine, closeable := router.Create(db, vInfo, conf) defer closeable() - runner.Run(engine, conf) + if err := runner.Run(engine, conf); err != nil { + fmt.Println("Server error: ", err) + os.Exit(1) + } } diff --git a/go.mod b/go.mod index 89828a43..e777edc6 100644 --- a/go.mod +++ b/go.mod @@ -46,7 +46,6 @@ require ( github.com/ugorji/go/codec v1.2.11 // indirect golang.org/x/arch v0.3.0 // indirect golang.org/x/net v0.10.0 // indirect - golang.org/x/sync v0.6.0 golang.org/x/sys v0.12.0 // indirect golang.org/x/text v0.13.0 // indirect google.golang.org/protobuf v1.30.0 // indirect diff --git a/go.sum b/go.sum index 19d99cc4..8c6cd01f 100644 --- a/go.sum +++ b/go.sum @@ -158,8 +158,6 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= -golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190109145017-48ac38b7c8cb/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/router/router.go b/router/router.go index 99268bbd..9ba81d86 100644 --- a/router/router.go +++ b/router/router.go @@ -30,26 +30,26 @@ func Create(db *database.GormDatabase, vInfo *model.VersionInfo, conf *config.Co g.Use(gin.LoggerWithFormatter(logFormatter), gin.Recovery(), gerror.Handler(), location.Default()) g.NoRoute(gerror.NotFound()) - if conf.Server.SSL.Enabled != nil && conf.Server.SSL.RedirectToHTTPS != nil && *conf.Server.SSL.Enabled && *conf.Server.SSL.RedirectToHTTPS { + if *conf.Server.SSL.Enabled && *conf.Server.SSL.RedirectToHTTPS { g.Use(func(ctx *gin.Context) { - if ctx.Request.TLS == nil { - if ctx.Request.Method != http.MethodGet && ctx.Request.Method != http.MethodHead { - ctx.Data(http.StatusBadRequest, "text/plain; charset=utf-8", []byte("Use HTTPS")) - ctx.Abort() - return - } - host := ctx.Request.Host - if strings.Contains(host, ":") { - host = host[:strings.Index(host, ":")] - } - if conf.Server.SSL.Port != 443 { - host = fmt.Sprintf("%s:%d", host, conf.Server.SSL.Port) - } - ctx.Redirect(http.StatusFound, fmt.Sprintf("https://%s%s", host, ctx.Request.RequestURI)) + if ctx.Request.TLS != nil { + ctx.Next() + return + } + if ctx.Request.Method != http.MethodGet && ctx.Request.Method != http.MethodHead { + ctx.Data(http.StatusBadRequest, "text/plain; charset=utf-8", []byte("Use HTTPS")) ctx.Abort() return } - ctx.Next() + host := ctx.Request.Host + if idx := strings.LastIndex(host, ":"); idx != -1 { + host = host[:idx] + } + if conf.Server.SSL.Port != 443 { + host = fmt.Sprintf("%s:%d", host, conf.Server.SSL.Port) + } + ctx.Redirect(http.StatusFound, fmt.Sprintf("https://%s%s", host, ctx.Request.RequestURI)) + ctx.Abort() }) } streamHandler := stream.New( diff --git a/runner/runner.go b/runner/runner.go index f6b420a2..c4ea7bdc 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -14,113 +14,86 @@ import ( "github.com/gotify/server/v2/config" "golang.org/x/crypto/acme/autocert" - "golang.org/x/sync/errgroup" ) // Run starts the http server and if configured a https server. -func Run(router http.Handler, conf *config.Configuration) { - ctx, cancel := context.WithCancel(context.Background()) - go func() { - done := make(chan os.Signal, 1) - signal.Notify(done, os.Interrupt, syscall.SIGTERM) - <-done - cancel() - }() +func Run(router http.Handler, conf *config.Configuration) error { + shutdown := make(chan error) + go doShutdownOnSignal(shutdown) + + httpListener, err := startListening("plain connection", conf.Server.ListenAddr, conf.Server.Port, conf.Server.KeepAlivePeriodSeconds) + if err != nil { + return err + } + defer httpListener.Close() + s := &http.Server{Handler: router} - g, gCtx := errgroup.WithContext(ctx) if *conf.Server.SSL.Enabled { if *conf.Server.SSL.LetsEncrypt.Enabled { - certManager := autocert.Manager{ - Prompt: func(tosURL string) bool { return *conf.Server.SSL.LetsEncrypt.AcceptTOS }, - HostPolicy: autocert.HostWhitelist(conf.Server.SSL.LetsEncrypt.Hosts...), - Cache: autocert.DirCache(conf.Server.SSL.LetsEncrypt.Cache), - } - s.Handler = certManager.HTTPHandler(s.Handler) - s.TLSConfig = &tls.Config{GetCertificate: certManager.GetCertificate} + applyLetsEncrypt(s, conf) } - g.Go(func() error { - return runTLS(s, conf) - }) - } - g.Go(func() error { - return run(s, conf) - }) - g.Go(func() error { - <-gCtx.Done() - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - return s.Shutdown(ctx) - }) - if err := g.Wait(); err != nil { - fmt.Println("Error while running the server", err) - os.Exit(1) + + httpsListener, err := startListening("TLS connection", conf.Server.SSL.ListenAddr, conf.Server.SSL.Port, conf.Server.KeepAlivePeriodSeconds) + if err != nil { + return err + } + defer httpsListener.Close() + + go func() { + doShutdown(shutdown, s.ServeTLS(httpsListener, conf.Server.SSL.CertFile, conf.Server.SSL.CertKey)) + }() } + go func() { + doShutdown(shutdown, s.Serve(httpListener)) + }() + + err = <-shutdown + fmt.Println("Shutting down:", err) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + return s.Shutdown(ctx) } -func run(s *http.Server, conf *config.Configuration) error { - network, addr := listenAddrParse(conf.Server.ListenAddr, conf.Server.Port) - l, err := startListening(network, addr, conf.Server.KeepAlivePeriodSeconds) - if err != nil { - return err - } - fmt.Println("Started Listening for plain connection on", l.Addr().Network(), l.Addr().String()) - defer l.Close() - if err := s.Serve(l); err != http.ErrServerClosed { - return err - } - return nil +func doShutdownOnSignal(shutdown chan<- error) { + onSignal := make(chan os.Signal, 1) + signal.Notify(onSignal, os.Interrupt, syscall.SIGTERM) + sig := <-onSignal + doShutdown(shutdown, fmt.Errorf("received signal %s", sig)) } -func runTLS(s *http.Server, conf *config.Configuration) error { - network, addr := listenAddrParse(conf.Server.SSL.ListenAddr, conf.Server.SSL.Port) - l, err := startListening(network, addr, conf.Server.KeepAlivePeriodSeconds) - if err != nil { - return err - } - fmt.Println("Started Listening for TLS connection on", l.Addr().Network(), l.Addr().String()) - defer l.Close() - if err := s.ServeTLS(l, conf.Server.SSL.CertFile, conf.Server.SSL.CertKey); err != http.ErrServerClosed { - return err +func doShutdown(shutdown chan<- error, err error) { + select { + case shutdown <- err: + default: + // we only handle the first error, the other can be ignored } - return nil } -func startListening(network, addr string, keepAlive int) (net.Listener, error) { +func startListening(connectionType, listenAddr string, port, keepAlive int) (net.Listener, error) { + network, addr := getNetworkAndAddr(listenAddr, port) lc := net.ListenConfig{KeepAlive: time.Duration(keepAlive) * time.Second} - conn, err := lc.Listen(context.Background(), network, addr) - if err != nil { - return nil, err - } - return conn, nil -} -func listenAddrParse(ListenAddr string, Port int) (string, string) { - if strings.HasPrefix(ListenAddr, "unix:") { - return "unix", strings.TrimPrefix(ListenAddr, "unix:") + l, err := lc.Listen(context.Background(), network, addr) + if err == nil { + fmt.Println("Started listening for", connectionType, "on", l.Addr().Network(), l.Addr().String()) } - return "tcp", fmt.Sprintf("%s:%d", ListenAddr, Port) + return l, err } -func redirectToHTTPS(port string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if r.Method != "GET" && r.Method != "HEAD" { - http.Error(w, "Use HTTPS", http.StatusBadRequest) - return - } - - target := "https://" + changePort(r.Host, port) + r.URL.RequestURI() - http.Redirect(w, r, target, http.StatusFound) +func getNetworkAndAddr(listenAddr string, port int) (string, string) { + if strings.HasPrefix(listenAddr, "unix:") { + return "unix", strings.TrimPrefix(listenAddr, "unix:") } + return "tcp", fmt.Sprintf("%s:%d", listenAddr, port) } -func changePort(hostPort, port string) string { - host, _, err := net.SplitHostPort(hostPort) - if err != nil { - // There is no exported error. - if !strings.Contains(err.Error(), "missing port") { - return hostPort - } - host = hostPort +func applyLetsEncrypt(s *http.Server, conf *config.Configuration) { + certManager := autocert.Manager{ + Prompt: func(tosURL string) bool { return *conf.Server.SSL.LetsEncrypt.AcceptTOS }, + HostPolicy: autocert.HostWhitelist(conf.Server.SSL.LetsEncrypt.Hosts...), + Cache: autocert.DirCache(conf.Server.SSL.LetsEncrypt.Cache), } - return net.JoinHostPort(host, port) + s.Handler = certManager.HTTPHandler(s.Handler) + s.TLSConfig = &tls.Config{GetCertificate: certManager.GetCertificate} } diff --git a/runner/runner_test.go b/runner/runner_test.go deleted file mode 100644 index fa222558..00000000 --- a/runner/runner_test.go +++ /dev/null @@ -1,35 +0,0 @@ -package runner - -import ( - "fmt" - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestRedirect(t *testing.T) { - cases := []struct { - Request string - TLS int - Expect string - }{ - {Request: "http://gotify.net/meow", TLS: 443, Expect: "https://gotify.net:443/meow"}, - {Request: "http://gotify.net:8080/meow", TLS: 443, Expect: "https://gotify.net:443/meow"}, - {Request: "http://gotify.net:8080/meow", TLS: 8443, Expect: "https://gotify.net:8443/meow"}, - } - - for _, testCase := range cases { - name := fmt.Sprintf("%s -- %d -> %s", testCase.Request, testCase.TLS, testCase.Expect) - t.Run(name, func(t *testing.T) { - req := httptest.NewRequest("GET", testCase.Request, nil) - rec := httptest.NewRecorder() - - redirectToHTTPS(fmt.Sprint(testCase.TLS)).ServeHTTP(rec, req) - - assert.Equal(t, http.StatusFound, rec.Result().StatusCode) - assert.Equal(t, testCase.Expect, rec.Header().Get("location")) - }) - } -}