diff --git a/response_writer.go b/response_writer.go index 6e03082..7f76cef 100644 --- a/response_writer.go +++ b/response_writer.go @@ -78,6 +78,10 @@ func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { return hijacker.Hijack() } +func (rw *responseWriter) CloseNotify() <-chan bool { + return rw.ResponseWriter.(http.CloseNotifier).CloseNotify() +} + func (rw *responseWriter) callBefore() { for i := len(rw.beforeFuncs) - 1; i >= 0; i-- { rw.beforeFuncs[i](rw) diff --git a/response_writer_test.go b/response_writer_test.go index c3fef49..4e5c170 100644 --- a/response_writer_test.go +++ b/response_writer_test.go @@ -6,8 +6,29 @@ import ( "net/http" "net/http/httptest" "testing" + "time" ) +type closeNotifyingRecorder struct { + *httptest.ResponseRecorder + closed chan bool +} + +func newCloseNotifyingRecorder() *closeNotifyingRecorder { + return &closeNotifyingRecorder{ + httptest.NewRecorder(), + make(chan bool, 1), + } +} + +func (c *closeNotifyingRecorder) close() { + c.closed <- true +} + +func (c *closeNotifyingRecorder) CloseNotify() <-chan bool { + return c.closed +} + type hijackableResponse struct { Hijacked bool } @@ -95,3 +116,17 @@ func Test_ResponseWriter_Hijack(t *testing.T) { } expect(t, hijackable.Hijacked, true) } + +func Test_ResponseWriter_CloseNotify(t *testing.T) { + rec := newCloseNotifyingRecorder() + rw := NewResponseWriter(rec) + closed := false + notifier := rw.(http.CloseNotifier).CloseNotify() + rec.close() + select { + case <-notifier: + closed = true + case <-time.After(time.Second): + } + expect(t, closed, true) +}