Skip to content

Commit

Permalink
feat: new timeout writer implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
hf committed May 17, 2024
1 parent 39ca026 commit d4f8ca5
Showing 1 changed file with 75 additions and 54 deletions.
129 changes: 75 additions & 54 deletions internal/api/middleware.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
package api

import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/sirupsen/logrus"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/observability"
"github.com/supabase/auth/internal/security"
Expand Down Expand Up @@ -263,57 +264,62 @@ func (a *API) databaseCleanup(cleanup *models.Cleanup) func(http.Handler) http.H
}
}

// timeoutResponseWriter is a http.ResponseWriter that prevents subsequent
// writes after the context contained in it has exceeded the deadline. If a
// partial write occurs before the deadline is exceeded, but the writing is not
// complete it will allow further writes.
// timeoutResponseWriter is a http.ResponseWriter that queues up a response
// body to be sent if the serving completes before the context has exceeded its
// deadline.
type timeoutResponseWriter struct {
ctx context.Context
w http.ResponseWriter
wrote int32
mu sync.Mutex
sync.Mutex

header http.Header
wroteHeader bool
statusCode int
buf bytes.Buffer
}

func (t *timeoutResponseWriter) Header() http.Header {
t.mu.Lock()
defer t.mu.Unlock()
return t.w.Header()
t.Lock()
defer t.Unlock()

return t.header
}

func (t *timeoutResponseWriter) Write(bytes []byte) (int, error) {
t.mu.Lock()
defer t.mu.Unlock()
if t.ctx.Err() == context.DeadlineExceeded {
if atomic.LoadInt32(&t.wrote) == 0 {
return 0, context.DeadlineExceeded
}
t.Lock()
defer t.Unlock()

// writing started before the deadline exceeded, but the
// deadline came in the middle, so letting the writes go
// through
if !t.wroteHeader {
t.WriteHeader(http.StatusOK)
}

t.wrote = 1

return t.w.Write(bytes)
return t.buf.Write(bytes)
}

func (t *timeoutResponseWriter) WriteHeader(statusCode int) {
t.mu.Lock()
defer t.mu.Unlock()
if t.ctx.Err() == context.DeadlineExceeded {
if atomic.LoadInt32(&t.wrote) == 0 {
return
}
t.Lock()
defer t.Unlock()

// writing started before the deadline exceeded, but the
// deadline came in the middle, so letting the writes go
// through
t.statusCode = statusCode
t.wroteHeader = true
}

func (t *timeoutResponseWriter) finallyWrite(w http.ResponseWriter) {
t.Lock()
defer t.Unlock()

dst := w.Header()
for k, vv := range t.header {
dst[k] = vv
}
if !t.wroteHeader {
t.statusCode = http.StatusOK
}

t.wrote = 1
fmt.Printf("@@@@@@@@@@@@@@@@@@@@@@ %v %v\n", t.header, dst)

t.w.WriteHeader(statusCode)
w.WriteHeader(t.statusCode)
if _, err := w.Write(t.buf.Bytes()); err != nil {
logrus.WithError(err).Warn("Write failed")
}
}

func (a *API) timeoutMiddleware(timeout time.Duration) func(http.Handler) http.Handler {
Expand All @@ -323,35 +329,50 @@ func (a *API) timeoutMiddleware(timeout time.Duration) func(http.Handler) http.H
defer cancel()

timeoutWriter := &timeoutResponseWriter{
w: w,
ctx: ctx,
header: make(http.Header),
}

panicChan := make(chan any, 1)
serverDone := make(chan struct{})
go func() {
<-ctx.Done()
defer func() {
if p := recover(); p != nil {
panicChan <- p
}
}()

next.ServeHTTP(timeoutWriter, r.WithContext(ctx))
close(serverDone)
}()

select {
case p := <-panicChan:
panic(p)

case <-serverDone:
timeoutWriter.finallyWrite(w)

case <-ctx.Done():
err := ctx.Err()

if err == context.DeadlineExceeded {
timeoutWriter.mu.Lock()
defer timeoutWriter.mu.Unlock()
if timeoutWriter.wrote == 0 {
// writer wasn't written to, so we're sending the error payload

httpError := &HTTPError{
HTTPStatus: http.StatusGatewayTimeout,
ErrorCode: ErrorCodeRequestTimeout,
Message: "Processing this request timed out, please retry after a moment.",
}
httpError := &HTTPError{
HTTPStatus: http.StatusGatewayTimeout,
ErrorCode: ErrorCodeRequestTimeout,
Message: "Processing this request timed out, please retry after a moment.",
}

httpError = httpError.WithInternalError(err)
httpError = httpError.WithInternalError(err)

HandleResponseError(httpError, w, r)
}
}
}()
HandleResponseError(httpError, w, r)
} else {
// unrecognized context error, so we should wait for the server to finish
// and write out the response
<-serverDone

next.ServeHTTP(timeoutWriter, r.WithContext(ctx))
timeoutWriter.finallyWrite(w)
}
}
})
}
}

0 comments on commit d4f8ca5

Please sign in to comment.