Skip to content

Commit

Permalink
[IMPROVED] Websocket: generating INFO to send to clients
Browse files Browse the repository at this point in the history
PR #4255 added code in generateClientInfoJSON to set the proper
info Host/Port/TLSAvailable/TLSRequired fields to send to clients.

However, this was requiring a lock but more importantly was computing
the listener's host/port everytime, which is not necessary since
this is immutable because we don't support the change during a config
reload.

Also, the TLSRequired field was set based on the server TLSConfig's
InsecureSkipVerify value, which is irrelevant for a server. The
mere presence of a TLSConfig (c.srv.websocket.tls being true) is
enough.

I have modified the TestWSReloadTLSConfig test to verify that the
tls block cannot be removed and no_tls set to true, which means
that tls value can't change. I also added check for the info's
Host/Port/TLSAvailable/TLSRequired values.

Signed-off-by: Ivan Kozlovic <ivan@synadia.com>
  • Loading branch information
kozlovic committed May 9, 2024
1 parent 4bf29a9 commit 63204b7
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 49 deletions.
23 changes: 5 additions & 18 deletions server/client.go
Expand Up @@ -2383,24 +2383,11 @@ func (c *client) generateClientInfoJSON(info Info) []byte {
info.MaxPayload = c.mpay
if c.isWebsocket() {
info.ClientConnectURLs = info.WSConnectURLs
if c.srv != nil { // Otherwise lame duck info can panic
c.srv.websocket.mu.RLock()
info.TLSAvailable = c.srv.websocket.tls
if c.srv.websocket.tls && c.srv.websocket.server != nil {
if tc := c.srv.websocket.server.TLSConfig; tc != nil {
info.TLSRequired = !tc.InsecureSkipVerify
}
}
if c.srv.websocket.listener != nil {
laddr := c.srv.websocket.listener.Addr().String()
if h, p, err := net.SplitHostPort(laddr); err == nil {
if p, err := strconv.Atoi(p); err == nil {
info.Host = h
info.Port = p
}
}
}
c.srv.websocket.mu.RUnlock()
// Otherwise lame duck info can panic
if c.srv != nil {
ws := &c.srv.websocket
info.TLSAvailable, info.TLSRequired = ws.tls, ws.tls
info.Host, info.Port = ws.host, ws.port
}
}
info.WSConnectURLs = nil
Expand Down
16 changes: 14 additions & 2 deletions server/websocket.go
Expand Up @@ -128,13 +128,18 @@ type srvWebsocket struct {
server *http.Server
listener net.Listener
listenerErr error
tls bool
allowedOrigins map[string]*allowedOrigin // host will be the key
sameOrigin bool
connectURLs []string
connectURLsMap refCountedUrlSet
authOverride bool // indicate if there is auth override in websocket config
rawHeaders string // raw headers to be used in the upgrade response.

// These are immutable and can be accessed without lock.
// This is the case when generating the client INFO.
tls bool // True if TLS is required (TLSConfig is specified).
host string // Host/IP the webserver is listening on (shortcut to opts.Websocket.Host).
port int // Port the webserver is listening on. This is after an ephemeral port may have been selected (shortcut to opts.Websocket.Port).
}

type allowedOrigin struct {
Expand Down Expand Up @@ -1153,7 +1158,12 @@ func (s *Server) startWebsocketServer() {
s.Warnf("Websocket not configured with TLS. DO NOT USE IN PRODUCTION!")
}

s.websocket.tls = proto == "wss"
// These 3 are immutable and will be accessed without lock by the client
// when generating/sending the INFO protocols.
s.websocket.tls = proto == wsSchemePrefixTLS
s.websocket.host, s.websocket.port = o.Host, o.Port

// This will be updated when/if the cluster changes.
s.websocket.connectURLs, err = s.getConnectURLs(o.Advertise, o.Host, o.Port)
if err != nil {
s.Fatalf("Unable to get websocket connect URLs: %v", err)
Expand Down Expand Up @@ -1192,8 +1202,10 @@ func (s *Server) startWebsocketServer() {
ReadTimeout: o.HandshakeTimeout,
ErrorLog: log.New(&captureHTTPServerLog{s, "websocket: "}, _EMPTY_, 0),
}
s.websocket.mu.Lock()
s.websocket.server = hs
s.websocket.listener = hl
s.websocket.mu.Unlock()
go func() {
if err := hs.Serve(hl); err != http.ErrServerClosed {
s.Fatalf("websocket listener error: %v", err)
Expand Down
134 changes: 105 additions & 29 deletions server/websocket_test.go
Expand Up @@ -4061,59 +4061,135 @@ func TestWSJWTCookieUser(t *testing.T) {
}

func TestWSReloadTLSConfig(t *testing.T) {
tlsBlock := `
tls {
cert_file: '%s'
key_file: '%s'
ca_file: '../test/configs/certs/ca.pem'
verify: %v
}
`
template := `
listen: "127.0.0.1:-1"
websocket {
listen: "127.0.0.1:-1"
tls {
cert_file: '%s'
key_file: '%s'
ca_file: '../test/configs/certs/ca.pem'
}
%s
no_tls: %v
}
`
conf := createConfFile(t, []byte(fmt.Sprintf(template,
"../test/configs/certs/server-noip.pem",
"../test/configs/certs/server-key-noip.pem")))
fmt.Sprintf(tlsBlock,
"../test/configs/certs/server-noip.pem",
"../test/configs/certs/server-key-noip.pem",
false), false)))

s, o := RunServerWithConfig(conf)
defer s.Shutdown()

addr := fmt.Sprintf("127.0.0.1:%d", o.Websocket.Port)
wsc, err := net.Dial("tcp", addr)
if err != nil {
t.Fatalf("Error creating ws connection: %v", err)

check := func(tlsConfig *tls.Config, handshakeFail bool, errTxt string) {
t.Helper()

wsc, err := net.Dial("tcp", addr)
require_NoError(t, err)
defer wsc.Close()

wsc = tls.Client(wsc, tlsConfig)
err = wsc.(*tls.Conn).Handshake()
if handshakeFail {
require_True(t, err != nil)
require_Contains(t, err.Error(), errTxt)
return
}
require_NoError(t, err)

req := testWSCreateValidReq()
req.URL, _ = url.Parse(wsSchemePrefixTLS + "://" + addr)
err = req.Write(wsc)
require_NoError(t, err)

br := bufio.NewReader(wsc)
resp, err := http.ReadResponse(br, req)
if errTxt == _EMPTY_ {
require_NoError(t, err)
} else {
require_True(t, err != nil)
require_Contains(t, err.Error(), errTxt)
return
}
defer resp.Body.Close()
l := testWSReadFrame(t, br)
require_True(t, bytes.HasPrefix(l, []byte("INFO {")))
var info Info
err = json.Unmarshal(l[5:], &info)
require_NoError(t, err)
require_True(t, info.TLSAvailable)
require_True(t, info.TLSRequired)
require_Equal[string](t, info.Host, "127.0.0.1")
require_Equal[int](t, info.Port, o.Websocket.Port)
}
defer wsc.Close()

tc := &TLSConfigOpts{CaFile: "../test/configs/certs/ca.pem"}
tlsConfig, err := GenTLSConfig(tc)
if err != nil {
t.Fatalf("Error generating TLS config: %v", err)
}
require_NoError(t, err)
tlsConfig.ServerName = "127.0.0.1"
tlsConfig.RootCAs = tlsConfig.ClientCAs
tlsConfig.ClientCAs = nil
wsc = tls.Client(wsc, tlsConfig.Clone())
if err := wsc.(*tls.Conn).Handshake(); err == nil || !strings.Contains(err.Error(), "SAN") {
t.Fatalf("Unexpected error: %v", err)
}
wsc.Close()

// Handshake should fail with error regarding SANs
check(tlsConfig.Clone(), true, "SAN")

// Replace certs with ones that allow IP.
reloadUpdateConfig(t, s, conf, fmt.Sprintf(template,
"../test/configs/certs/server-cert.pem",
"../test/configs/certs/server-key.pem"))
fmt.Sprintf(tlsBlock,
"../test/configs/certs/server-cert.pem",
"../test/configs/certs/server-key.pem",
false), false))

wsc, err = net.Dial("tcp", addr)
if err != nil {
t.Fatalf("Error creating ws connection: %v", err)
}
defer wsc.Close()
// Connection should succeed
check(tlsConfig.Clone(), false, _EMPTY_)

wsc = tls.Client(wsc, tlsConfig.Clone())
if err := wsc.(*tls.Conn).Handshake(); err != nil {
t.Fatalf("Error on TLS handshake: %v", err)
// Udpate config to require client cert.
reloadUpdateConfig(t, s, conf, fmt.Sprintf(template,
fmt.Sprintf(tlsBlock,
"../test/configs/certs/server-cert.pem",
"../test/configs/certs/server-key.pem",
true), false))

// Connection should fail saying that a tls cert is required
check(tlsConfig.Clone(), false, "required")

// Add a client cert
tc = &TLSConfigOpts{
CertFile: "../test/configs/certs/client-cert.pem",
KeyFile: "../test/configs/certs/client-key.pem",
}
tlsConfig, err = GenTLSConfig(tc)
require_NoError(t, err)
tlsConfig.InsecureSkipVerify = true

// Connection should succeed
check(tlsConfig.Clone(), false, _EMPTY_)

// Removing the tls{} block but with no_tls still false should fail
changeCurrentConfigContentWithNewContent(t, conf, []byte(fmt.Sprintf(template, _EMPTY_, false)))
err = s.Reload()
require_True(t, err != nil)
require_Contains(t, err.Error(), "TLS configuration")

// We should still be able to connect a TLS client
check(tlsConfig.Clone(), false, _EMPTY_)

// Now remove the tls{} block and set no_tls: true and that should fail
// since this is not supported.
changeCurrentConfigContentWithNewContent(t, conf, []byte(fmt.Sprintf(template, _EMPTY_, true)))
err = s.Reload()
require_True(t, err != nil)
require_Contains(t, err.Error(), "not supported")

// We should still be able to connect a TLS client
check(tlsConfig.Clone(), false, _EMPTY_)
}

type captureClientConnectedLogger struct {
Expand Down

0 comments on commit 63204b7

Please sign in to comment.