Skip to content

Commit

Permalink
feat: implement configurable ping timeout for nhooyr.io/websocket #83
Browse files Browse the repository at this point in the history
- added `pingTimeout   time.Duration` to `internal/nhooyr.io/websocket/conn.go`'s `Conn struct`
- implemented `SetPingTimeout`, following how `nhooyr.io/websocket` exposes configurable conn values
- if `conn`'s `pingTimeout` is nonpositive, then use `handleControl`'s 5 second timeout https://github.com/xmidt-org/xmidt-agent/blob/78dffe0cad394ab82f581940e3a8117f04941077/internal/nhooyr.io/websocket/read.go#L301
  • Loading branch information
denopink committed May 3, 2024
1 parent 78dffe0 commit 924a2c6
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 1 deletion.
2 changes: 2 additions & 0 deletions internal/nhooyr.io/websocket/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"strconv"
"sync"
"sync/atomic"
"time"
)

// MessageType represents the type of a WebSocket message.
Expand Down Expand Up @@ -79,6 +80,7 @@ type Conn struct {
closeErr error
wroteClose bool

pingTimeout time.Duration
pingCounter int32
activePingsMu sync.Mutex
activePings map[string]chan<- struct{}
Expand Down
12 changes: 12 additions & 0 deletions internal/nhooyr.io/websocket/read.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,13 @@ func (c *Conn) SetPongListener(f func(context.Context, []byte)) {
c.pongListener = f
}

// SetPingTimeout sets the maximum time allowed between PINGs for the connection
// before the connection is closed.
// Nonpositive PingTimeout will default to handleControl's 5 second timeout.
func (c *Conn) SetPingTimeout(d time.Duration) {
c.pingTimeout = d
}

// SetReadLimit sets the max number of bytes to read for a single message.
// It applies to the Reader and Read methods.
//
Expand Down Expand Up @@ -313,6 +320,11 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) {

switch h.opcode {
case opPing:
if c.pingTimeout > 0 {
ctx, cancel = context.WithTimeout(ctx, c.pingTimeout)
defer cancel()
}

c.pingListener(ctx, b)
return c.writeControl(ctx, opPong, b)
case opPong:
Expand Down
7 changes: 7 additions & 0 deletions internal/nhooyr.io/websocket/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,13 @@ func (mw *msgWriter) close() {

func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
switch opcode {
case opPong:
if c.pingTimeout > 0 {
ctx, cancel = context.WithTimeout(ctx, c.pingTimeout)
}
}

defer cancel()

_, err := c.writeFrame(ctx, true, false, opcode, p)
Expand Down
82 changes: 82 additions & 0 deletions internal/websocket/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"errors"
"fmt"
"net"
"net/http"
"net/http/httptest"
"sync"
Expand Down Expand Up @@ -358,3 +359,84 @@ func TestEndToEndConnectionIssues(t *testing.T) {
assert.True(started)
assert.True(msgCnt.Load() > 0, "got message")
}

func TestEndToEndPingTimeout(t *testing.T) {
assert := assert.New(t)
require := require.New(t)

s := httptest.NewServer(
http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
c, err := websocket.Accept(w, r, nil)
require.NoError(err)
defer c.CloseNow()

assert.Error(c.Ping(context.Background()))
}))
defer s.Close()

var (
connectCnt, disconnectCnt, heartbeatCnt atomic.Int64
got *ws.Websocket
err error
disconnectErrs []error
)
got, err = ws.New(
ws.URL(s.URL),
ws.DeviceID("mac:112233445566"),
ws.AddHeartbeatListener(
event.HeartbeatListenerFunc(
func(event.Heartbeat) {
heartbeatCnt.Add(1)
})),
ws.AddConnectListener(
event.ConnectListenerFunc(
func(event.Connect) {
connectCnt.Add(1)
})),
ws.AddDisconnectListener(
event.DisconnectListenerFunc(
func(e event.Disconnect) {
disconnectErrs = append(disconnectErrs, e.Err)
disconnectCnt.Add(1)
})),
ws.RetryPolicy(&retry.Config{
Interval: time.Second,
Multiplier: 2.0,
Jitter: 1.0 / 3.0,
MaxInterval: 341*time.Second + 333*time.Millisecond,
}),
ws.WithIPv4(),
ws.NowFunc(time.Now),
ws.ConnectTimeout(30*time.Second),
ws.FetchURLTimeout(30*time.Second),
ws.MaxMessageBytes(256*1024),
ws.CredentialsDecorator(func(h http.Header) error {
return nil
}),
// Trigger a ping timeout
ws.PingTimeout(time.Nanosecond),
)
require.NoError(err)
require.NotNil(got)

got.Start()
time.Sleep(500 * time.Millisecond)
got.Stop()
// heartbeatCnt should be zero due to a ping timeout
assert.Equal(int64(0), heartbeatCnt.Load())
assert.Greater(connectCnt.Load(), int64(0))
assert.Greater(disconnectCnt.Load(), int64(0))
// disconnectErrs should only contain
assert.NotEmpty(disconnectErrs)
// All disconnectErrs errors should be caused by context.DeadlineExceeded
for _, err := range disconnectErrs {
if errors.Is(err, net.ErrClosed) {
// net.ErrClosed may occur during testing, don't count them
continue
}

assert.ErrorIs(err, context.DeadlineExceeded)
}

}
11 changes: 10 additions & 1 deletion internal/websocket/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,11 @@ func (ws *Websocket) run(ctx context.Context) {
// Store the connection so writing can take place.
ws.m.Lock()
ws.conn = conn
ws.conn.SetPingListener((func(context.Context, []byte) {
ws.conn.SetPingListener((func(ctx context.Context, b []byte) {
if ctx.Err() != nil {
return
}

ws.heartbeatListeners.Visit(func(l event.HeartbeatListener) {
l.OnHeartbeat(event.Heartbeat{
At: ws.nowFunc(),
Expand All @@ -242,6 +246,10 @@ func (ws *Websocket) run(ctx context.Context) {
})
}))
ws.conn.SetPongListener(func(ctx context.Context, b []byte) {
if ctx.Err() != nil {
return
}

ws.heartbeatListeners.Visit(func(l event.HeartbeatListener) {
l.OnHeartbeat(event.Heartbeat{
At: ws.nowFunc(),
Expand Down Expand Up @@ -334,6 +342,7 @@ func (ws *Websocket) dial(ctx context.Context, mode ipMode) (*nhws.Conn, *http.R
}

conn.SetReadLimit(ws.maxMessageBytes)
conn.SetPingTimeout(ws.pingTimeout)
return conn, resp, nil
}

Expand Down

0 comments on commit 924a2c6

Please sign in to comment.