New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Move packet writing code, add graceful shutdown handler, fix close handshake #43
Changes from all commits
929d35a
0c2a888
b8753a5
b5c5b2d
b6b959d
5140f09
fa31f90
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,96 +16,131 @@ import Hummingbird | |
import Logging | ||
import NIOCore | ||
import NIOWebSocket | ||
import ServiceLifecycle | ||
|
||
/// WebSocket type | ||
public enum WebSocketType: Sendable { | ||
case client | ||
case server | ||
} | ||
|
||
/// Handler processing raw WebSocket packets. | ||
/// | ||
/// Manages ping, pong and close messages. Collates data and text messages into final frame | ||
/// and passes them onto the ``WebSocketDataHandler`` data handler setup by the user. | ||
actor WebSocketHandler: Sendable { | ||
enum SocketType: Sendable { | ||
case client | ||
case server | ||
} | ||
|
||
static let pingDataSize = 16 | ||
|
||
let asyncChannel: NIOAsyncChannel<WebSocketFrame, WebSocketFrame> | ||
let type: SocketType | ||
var closed = false | ||
let type: WebSocketType | ||
var closed: Bool | ||
var pingData: ByteBuffer | ||
|
||
init(asyncChannel: NIOAsyncChannel<WebSocketFrame, WebSocketFrame>, type: WebSocketHandler.SocketType) { | ||
init(asyncChannel: NIOAsyncChannel<WebSocketFrame, WebSocketFrame>, type: WebSocketType) { | ||
self.asyncChannel = asyncChannel | ||
self.type = type | ||
self.pingData = ByteBufferAllocator().buffer(capacity: Self.pingDataSize) | ||
self.closed = false | ||
} | ||
|
||
/// Handle WebSocket AsynChannel | ||
func handle<Context: WebSocketContextProtocol>(handler: @escaping WebSocketDataHandler<Context>, context: Context) async { | ||
try? await self.asyncChannel.executeThenClose { inbound, outbound in | ||
do { | ||
try await withThrowingTaskGroup(of: Void.self) { group in | ||
let webSocketHandlerInbound = WebSocketHandlerInbound() | ||
defer { | ||
asyncChannel.channel.close(promise: nil) | ||
webSocketHandlerInbound.finish() | ||
} | ||
let webSocketHandlerOutbound = WebSocketHandlerOutboundWriter(webSocket: self, outbound: outbound) | ||
group.addTask { | ||
// parse messages coming from inbound | ||
var frameSequence: WebSocketFrameSequence? | ||
for try await frame in inbound { | ||
switch frame.opcode { | ||
case .connectionClose: | ||
return | ||
case .ping: | ||
try await self.onPing(frame, outbound: outbound, context: context) | ||
case .pong: | ||
try await self.onPong(frame, outbound: outbound, context: context) | ||
case .text, .binary: | ||
if var frameSeq = frameSequence { | ||
frameSeq.append(frame) | ||
frameSequence = frameSeq | ||
} else { | ||
frameSequence = WebSocketFrameSequence(frame: frame) | ||
let asyncChannel = self.asyncChannel | ||
try? await asyncChannel.executeThenClose { inbound, outbound in | ||
let webSocketInbound = WebSocketInboundStream() | ||
let webSocketOutbound = WebSocketOutboundWriter( | ||
type: self.type, | ||
allocator: asyncChannel.channel.allocator, | ||
outbound: outbound | ||
) | ||
try await withTaskCancellationHandler { | ||
try await withGracefulShutdownHandler { | ||
try await withThrowingTaskGroup(of: Void.self) { group in | ||
group.addTask { | ||
defer { | ||
webSocketInbound.finish() | ||
} | ||
// parse messages coming from inbound | ||
var frameSequence: WebSocketFrameSequence? | ||
for try await frame in inbound { | ||
do { | ||
context.logger.trace("Received \(frame.opcode)") | ||
switch frame.opcode { | ||
case .connectionClose: | ||
// we received a connection close. Finish the inbound data stream, | ||
// send a close back if it hasn't already been send and exit | ||
webSocketInbound.finish() | ||
_ = try await self.close(code: .normalClosure, outbound: webSocketOutbound, context: context) | ||
return | ||
case .ping: | ||
try await self.onPing(frame, outbound: webSocketOutbound, context: context) | ||
case .pong: | ||
try await self.onPong(frame, outbound: webSocketOutbound, context: context) | ||
case .text, .binary: | ||
if var frameSeq = frameSequence { | ||
frameSeq.append(frame) | ||
frameSequence = frameSeq | ||
} else { | ||
frameSequence = WebSocketFrameSequence(frame: frame) | ||
} | ||
case .continuation: | ||
if var frameSeq = frameSequence { | ||
frameSeq.append(frame) | ||
frameSequence = frameSeq | ||
} else { | ||
try await self.close(code: .protocolError, outbound: webSocketOutbound, context: context) | ||
} | ||
default: | ||
break | ||
} | ||
if let frameSeq = frameSequence, frame.fin { | ||
await webSocketInbound.send(frameSeq.data) | ||
frameSequence = nil | ||
} | ||
} catch { | ||
// catch errors while processing websocket frames so responding close message | ||
// can be dealt with | ||
let errorCode = WebSocketErrorCode(error) | ||
try await self.close(code: errorCode, outbound: webSocketOutbound, context: context) | ||
} | ||
case .continuation: | ||
if var frameSeq = frameSequence { | ||
frameSeq.append(frame) | ||
frameSequence = frameSeq | ||
} | ||
} | ||
group.addTask { | ||
do { | ||
// handle websocket data and text | ||
try await handler(webSocketInbound, webSocketOutbound, context) | ||
try await self.close(code: .normalClosure, outbound: webSocketOutbound, context: context) | ||
} catch { | ||
if self.type == .server { | ||
let errorCode = WebSocketErrorCode.unexpectedServerError | ||
try await self.close(code: errorCode, outbound: webSocketOutbound, context: context) | ||
} else { | ||
try await self.close(code: .protocolError, outbound: outbound, context: context) | ||
try await asyncChannel.channel.close(mode: .input) | ||
} | ||
default: | ||
break | ||
} | ||
if let frameSeq = frameSequence, frame.fin { | ||
await webSocketHandlerInbound.send(frameSeq.data) | ||
frameSequence = nil | ||
} | ||
} | ||
try await group.next() | ||
webSocketInbound.finish() | ||
} | ||
group.addTask { | ||
// handle websocket data and text | ||
try await handler(webSocketHandlerInbound, webSocketHandlerOutbound, context) | ||
try await self.close(code: .normalClosure, outbound: outbound, context: context) | ||
} onGracefulShutdown: { | ||
Task { | ||
try? await self.close(code: .normalClosure, outbound: webSocketOutbound, context: context) | ||
} | ||
try await group.next() | ||
} | ||
} catch let error as NIOWebSocketError { | ||
let errorCode = WebSocketErrorCode(error) | ||
try await self.close(code: errorCode, outbound: outbound, context: context) | ||
} catch { | ||
let errorCode = WebSocketErrorCode.unexpectedServerError | ||
try await self.close(code: errorCode, outbound: outbound, context: context) | ||
} onCancel: { | ||
Task { | ||
webSocketInbound.finish() | ||
try await asyncChannel.channel.close(mode: .input) | ||
} | ||
} | ||
} | ||
context.logger.debug("Closed WebSocket") | ||
} | ||
|
||
/// Respond to ping | ||
func onPing( | ||
_ frame: WebSocketFrame, | ||
outbound: NIOAsyncChannelOutboundWriter<WebSocketFrame>, | ||
outbound: WebSocketOutboundWriter, | ||
context: some WebSocketContextProtocol | ||
) async throws { | ||
if frame.fin { | ||
|
@@ -118,9 +153,10 @@ actor WebSocketHandler: Sendable { | |
/// Respond to pong | ||
func onPong( | ||
_ frame: WebSocketFrame, | ||
outbound: NIOAsyncChannelOutboundWriter<WebSocketFrame>, | ||
outbound: WebSocketOutboundWriter, | ||
context: some WebSocketContextProtocol | ||
) async throws { | ||
guard !self.closed else { return } | ||
let frameData = frame.unmaskedData | ||
guard self.pingData.readableBytes == 0 || frameData == self.pingData else { | ||
try await self.close(code: .goingAway, outbound: outbound, context: context) | ||
|
@@ -130,62 +166,48 @@ actor WebSocketHandler: Sendable { | |
} | ||
|
||
/// Send ping | ||
func ping(outbound: NIOAsyncChannelOutboundWriter<WebSocketFrame>) async throws { | ||
func ping(outbound: WebSocketOutboundWriter) async throws { | ||
guard !self.closed else { return } | ||
if self.pingData.readableBytes == 0 { | ||
// creating random payload | ||
let random = (0..<Self.pingDataSize).map { _ in UInt8.random(in: 0...255) } | ||
self.pingData.writeBytes(random) | ||
} | ||
try await self.send(frame: .init(fin: true, opcode: .ping, data: self.pingData), outbound: outbound) | ||
try await outbound.write(frame: .init(fin: true, opcode: .ping, data: self.pingData)) | ||
} | ||
|
||
/// Send pong | ||
func pong(data: ByteBuffer?, outbound: NIOAsyncChannelOutboundWriter<WebSocketFrame>) async throws { | ||
func pong(data: ByteBuffer?, outbound: WebSocketOutboundWriter) async throws { | ||
guard !self.closed else { return } | ||
try await self.send(frame: .init(fin: true, opcode: .pong, data: data ?? .init()), outbound: outbound) | ||
try await outbound.write(frame: .init(fin: true, opcode: .pong, data: data ?? .init())) | ||
} | ||
|
||
/// Send close | ||
func close( | ||
code: WebSocketErrorCode = .normalClosure, | ||
outbound: NIOAsyncChannelOutboundWriter<WebSocketFrame>, | ||
outbound: WebSocketOutboundWriter, | ||
context: some WebSocketContextProtocol | ||
) async throws { | ||
guard !self.closed else { return } | ||
self.closed = true | ||
|
||
var buffer = context.allocator.buffer(capacity: 2) | ||
buffer.write(webSocketErrorCode: code) | ||
try await self.send(frame: .init(fin: true, opcode: .connectionClose, data: buffer), outbound: outbound) | ||
} | ||
|
||
/// Send WebSocket frame | ||
func send( | ||
frame: WebSocketFrame, | ||
outbound: NIOAsyncChannelOutboundWriter<WebSocketFrame> | ||
) async throws { | ||
var frame = frame | ||
frame.maskKey = self.makeMaskKey() | ||
try await outbound.write(frame) | ||
} | ||
|
||
/// Make mask key to be used in WebSocket frame | ||
private func makeMaskKey() -> WebSocketMaskingKey? { | ||
guard self.type == .client else { return nil } | ||
let bytes: [UInt8] = (0...3).map { _ in UInt8.random(in: .min ... .max) } | ||
return WebSocketMaskingKey(bytes) | ||
try await outbound.write(frame: .init(fin: true, opcode: .connectionClose, data: buffer)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should consider making There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I disagree here, this initiates the close handshake and doesn't actually close the websocket. It only gets closed when we receive the close connection message from the client. If we ignore the error then the close connection is never sent and we don't receive a close connection from the client. Remember also this is an internal function. Users will close the connection by exiting the handler. In the current situation when the error is thrown out of the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, good clarification. I misremember that then |
||
outbound.finish() | ||
} | ||
} | ||
|
||
extension WebSocketErrorCode { | ||
init(_ error: NIOWebSocketError) { | ||
init(_ error: any Error) { | ||
switch error { | ||
case .invalidFrameLength: | ||
case NIOWebSocketError.invalidFrameLength: | ||
self = .messageTooLarge | ||
case .fragmentedControlFrame, | ||
.multiByteControlFrameLength: | ||
case NIOWebSocketError.fragmentedControlFrame, | ||
NIOWebSocketError.multiByteControlFrameLength: | ||
self = .protocolError | ||
default: | ||
self = .unexpectedServerError | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a way around an unstructured task here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately not. I can't close either the inbound stream or outbound writer as they are needed for the close handshake. It's a recurring issue with graceful shutdown handlers, where something more complex is need to initiate shutdown.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I could have a child task that waits on a continuation and then calls close. I then resume the continuation in the graceful shutdown. But I'd have to somehow get the continuation to the shutdown, and also support the situation where the graceful shutdown is called before the continuation is created. That'd be a bit of a mess