Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move packet writing code, add graceful shutdown handler, fix close handshake #43

Merged
merged 7 commits into from Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 5 additions & 1 deletion Snippets/WebsocketTest.swift
@@ -1,7 +1,10 @@
import HTTPTypes
import Hummingbird
import HummingbirdWebSocket
import Logging

var logger = Logger(label: "Echo")
logger.logLevel = .trace
let router = Router(context: BasicWebSocketRequestContext.self)
router.middlewares.add(FileMiddleware("Snippets/public"))
router.get { _, _ in
Expand All @@ -19,6 +22,7 @@ router.ws("/ws") { inbound, outbound, _ in

let app = Application(
router: router,
server: .webSocketUpgrade(webSocketRouter: router)
server: .webSocketUpgrade(webSocketRouter: router),
logger: logger
)
try await app.runService()
2 changes: 1 addition & 1 deletion Sources/HummingbirdWebSocket/WebSocketDataHandler.swift
Expand Up @@ -19,4 +19,4 @@ import NIOCore
import NIOWebSocket

/// Function that handles websocket data and text blocks
public typealias WebSocketDataHandler<Context: WebSocketContextProtocol> = @Sendable (WebSocketHandlerInbound, WebSocketHandlerOutboundWriter, Context) async throws -> Void
public typealias WebSocketDataHandler<Context: WebSocketContextProtocol> = @Sendable (WebSocketInboundStream, WebSocketOutboundWriter, Context) async throws -> Void
192 changes: 107 additions & 85 deletions Sources/HummingbirdWebSocket/WebSocketHandler.swift
Expand Up @@ -16,96 +16,131 @@ import Hummingbird
import Logging
import NIOCore
import NIOWebSocket
import ServiceLifecycle

/// WebSocket type
public enum WebSocketType: Sendable {
case client
case server
}

/// Handler processing raw WebSocket packets.
///
/// Manages ping, pong and close messages. Collates data and text messages into final frame
/// and passes them onto the ``WebSocketDataHandler`` data handler setup by the user.
actor WebSocketHandler: Sendable {
enum SocketType: Sendable {
case client
case server
}

static let pingDataSize = 16

let asyncChannel: NIOAsyncChannel<WebSocketFrame, WebSocketFrame>
let type: SocketType
var closed = false
let type: WebSocketType
var closed: Bool
var pingData: ByteBuffer

init(asyncChannel: NIOAsyncChannel<WebSocketFrame, WebSocketFrame>, type: WebSocketHandler.SocketType) {
init(asyncChannel: NIOAsyncChannel<WebSocketFrame, WebSocketFrame>, type: WebSocketType) {
self.asyncChannel = asyncChannel
self.type = type
self.pingData = ByteBufferAllocator().buffer(capacity: Self.pingDataSize)
self.closed = false
}

/// Handle WebSocket AsynChannel
func handle<Context: WebSocketContextProtocol>(handler: @escaping WebSocketDataHandler<Context>, context: Context) async {
try? await self.asyncChannel.executeThenClose { inbound, outbound in
do {
try await withThrowingTaskGroup(of: Void.self) { group in
let webSocketHandlerInbound = WebSocketHandlerInbound()
defer {
asyncChannel.channel.close(promise: nil)
webSocketHandlerInbound.finish()
}
let webSocketHandlerOutbound = WebSocketHandlerOutboundWriter(webSocket: self, outbound: outbound)
group.addTask {
// parse messages coming from inbound
var frameSequence: WebSocketFrameSequence?
for try await frame in inbound {
switch frame.opcode {
case .connectionClose:
return
case .ping:
try await self.onPing(frame, outbound: outbound, context: context)
case .pong:
try await self.onPong(frame, outbound: outbound, context: context)
case .text, .binary:
if var frameSeq = frameSequence {
frameSeq.append(frame)
frameSequence = frameSeq
} else {
frameSequence = WebSocketFrameSequence(frame: frame)
let asyncChannel = self.asyncChannel
try? await asyncChannel.executeThenClose { inbound, outbound in
let webSocketInbound = WebSocketInboundStream()
let webSocketOutbound = WebSocketOutboundWriter(
type: self.type,
allocator: asyncChannel.channel.allocator,
outbound: outbound
)
try await withTaskCancellationHandler {
try await withGracefulShutdownHandler {
try await withThrowingTaskGroup(of: Void.self) { group in
group.addTask {
defer {
webSocketInbound.finish()
}
// parse messages coming from inbound
var frameSequence: WebSocketFrameSequence?
for try await frame in inbound {
do {
context.logger.trace("Received \(frame.opcode)")
switch frame.opcode {
case .connectionClose:
// we received a connection close. Finish the inbound data stream,
// send a close back if it hasn't already been send and exit
webSocketInbound.finish()
_ = try await self.close(code: .normalClosure, outbound: webSocketOutbound, context: context)
return
case .ping:
try await self.onPing(frame, outbound: webSocketOutbound, context: context)
case .pong:
try await self.onPong(frame, outbound: webSocketOutbound, context: context)
case .text, .binary:
if var frameSeq = frameSequence {
frameSeq.append(frame)
frameSequence = frameSeq
} else {
frameSequence = WebSocketFrameSequence(frame: frame)
}
case .continuation:
if var frameSeq = frameSequence {
frameSeq.append(frame)
frameSequence = frameSeq
} else {
try await self.close(code: .protocolError, outbound: webSocketOutbound, context: context)
}
default:
break
}
if let frameSeq = frameSequence, frame.fin {
await webSocketInbound.send(frameSeq.data)
frameSequence = nil
}
} catch {
// catch errors while processing websocket frames so responding close message
// can be dealt with
let errorCode = WebSocketErrorCode(error)
try await self.close(code: errorCode, outbound: webSocketOutbound, context: context)
}
case .continuation:
if var frameSeq = frameSequence {
frameSeq.append(frame)
frameSequence = frameSeq
}
}
group.addTask {
do {
// handle websocket data and text
try await handler(webSocketInbound, webSocketOutbound, context)
try await self.close(code: .normalClosure, outbound: webSocketOutbound, context: context)
} catch {
if self.type == .server {
let errorCode = WebSocketErrorCode.unexpectedServerError
try await self.close(code: errorCode, outbound: webSocketOutbound, context: context)
} else {
try await self.close(code: .protocolError, outbound: outbound, context: context)
try await asyncChannel.channel.close(mode: .input)
}
default:
break
}
if let frameSeq = frameSequence, frame.fin {
await webSocketHandlerInbound.send(frameSeq.data)
frameSequence = nil
}
}
try await group.next()
webSocketInbound.finish()
}
group.addTask {
// handle websocket data and text
try await handler(webSocketHandlerInbound, webSocketHandlerOutbound, context)
try await self.close(code: .normalClosure, outbound: outbound, context: context)
} onGracefulShutdown: {
Task {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way around an unstructured task here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately not. I can't close either the inbound stream or outbound writer as they are needed for the close handshake. It's a recurring issue with graceful shutdown handlers, where something more complex is need to initiate shutdown.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could have a child task that waits on a continuation and then calls close. I then resume the continuation in the graceful shutdown. But I'd have to somehow get the continuation to the shutdown, and also support the situation where the graceful shutdown is called before the continuation is created. That'd be a bit of a mess

try? await self.close(code: .normalClosure, outbound: webSocketOutbound, context: context)
}
try await group.next()
}
} catch let error as NIOWebSocketError {
let errorCode = WebSocketErrorCode(error)
try await self.close(code: errorCode, outbound: outbound, context: context)
} catch {
let errorCode = WebSocketErrorCode.unexpectedServerError
try await self.close(code: errorCode, outbound: outbound, context: context)
} onCancel: {
Task {
webSocketInbound.finish()
try await asyncChannel.channel.close(mode: .input)
}
}
}
context.logger.debug("Closed WebSocket")
}

/// Respond to ping
func onPing(
_ frame: WebSocketFrame,
outbound: NIOAsyncChannelOutboundWriter<WebSocketFrame>,
outbound: WebSocketOutboundWriter,
context: some WebSocketContextProtocol
) async throws {
if frame.fin {
Expand All @@ -118,9 +153,10 @@ actor WebSocketHandler: Sendable {
/// Respond to pong
func onPong(
_ frame: WebSocketFrame,
outbound: NIOAsyncChannelOutboundWriter<WebSocketFrame>,
outbound: WebSocketOutboundWriter,
context: some WebSocketContextProtocol
) async throws {
guard !self.closed else { return }
let frameData = frame.unmaskedData
guard self.pingData.readableBytes == 0 || frameData == self.pingData else {
try await self.close(code: .goingAway, outbound: outbound, context: context)
Expand All @@ -130,62 +166,48 @@ actor WebSocketHandler: Sendable {
}

/// Send ping
func ping(outbound: NIOAsyncChannelOutboundWriter<WebSocketFrame>) async throws {
func ping(outbound: WebSocketOutboundWriter) async throws {
guard !self.closed else { return }
if self.pingData.readableBytes == 0 {
// creating random payload
let random = (0..<Self.pingDataSize).map { _ in UInt8.random(in: 0...255) }
self.pingData.writeBytes(random)
}
try await self.send(frame: .init(fin: true, opcode: .ping, data: self.pingData), outbound: outbound)
try await outbound.write(frame: .init(fin: true, opcode: .ping, data: self.pingData))
}

/// Send pong
func pong(data: ByteBuffer?, outbound: NIOAsyncChannelOutboundWriter<WebSocketFrame>) async throws {
func pong(data: ByteBuffer?, outbound: WebSocketOutboundWriter) async throws {
guard !self.closed else { return }
try await self.send(frame: .init(fin: true, opcode: .pong, data: data ?? .init()), outbound: outbound)
try await outbound.write(frame: .init(fin: true, opcode: .pong, data: data ?? .init()))
}

/// Send close
func close(
code: WebSocketErrorCode = .normalClosure,
outbound: NIOAsyncChannelOutboundWriter<WebSocketFrame>,
outbound: WebSocketOutboundWriter,
context: some WebSocketContextProtocol
) async throws {
guard !self.closed else { return }
self.closed = true

var buffer = context.allocator.buffer(capacity: 2)
buffer.write(webSocketErrorCode: code)
try await self.send(frame: .init(fin: true, opcode: .connectionClose, data: buffer), outbound: outbound)
}

/// Send WebSocket frame
func send(
frame: WebSocketFrame,
outbound: NIOAsyncChannelOutboundWriter<WebSocketFrame>
) async throws {
var frame = frame
frame.maskKey = self.makeMaskKey()
try await outbound.write(frame)
}

/// Make mask key to be used in WebSocket frame
private func makeMaskKey() -> WebSocketMaskingKey? {
guard self.type == .client else { return nil }
let bytes: [UInt8] = (0...3).map { _ in UInt8.random(in: .min ... .max) }
return WebSocketMaskingKey(bytes)
try await outbound.write(frame: .init(fin: true, opcode: .connectionClose, data: buffer))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should consider making close non-throwing. I get that writing a frame can fail, but that's an acceptable and ignorable failure in my book - at this point in the WebSocket's lifecycle. Possibly something to ask the NIO team's opinion on.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I disagree here, this initiates the close handshake and doesn't actually close the websocket. It only gets closed when we receive the close connection message from the client. If we ignore the error then the close connection is never sent and we don't receive a close connection from the client. Remember also this is an internal function. Users will close the connection by exiting the handler.

In the current situation when the error is thrown out of the asyncChannel.executeThenClose closure which will force close the connection.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, good clarification. I misremember that then

outbound.finish()
}
}

extension WebSocketErrorCode {
init(_ error: NIOWebSocketError) {
init(_ error: any Error) {
switch error {
case .invalidFrameLength:
case NIOWebSocketError.invalidFrameLength:
self = .messageTooLarge
case .fragmentedControlFrame,
.multiByteControlFrameLength:
case NIOWebSocketError.fragmentedControlFrame,
NIOWebSocketError.multiByteControlFrameLength:
self = .protocolError
default:
self = .unexpectedServerError
}
}
}
3 changes: 2 additions & 1 deletion Sources/HummingbirdWebSocket/WebSocketInboundStream.swift
Expand Up @@ -17,7 +17,8 @@ import NIOCore
import NIOWebSocket

/// Inbound websocket data AsyncSequence
public typealias WebSocketHandlerInbound = AsyncChannel<WebSocketDataFrame>
public typealias WebSocketInboundStream = AsyncChannel<WebSocketDataFrame>

/// Enumeration holding WebSocket data
public enum WebSocketDataFrame: Equatable, Sendable, CustomStringConvertible, CustomDebugStringConvertible {
case text(String)
Expand Down
40 changes: 28 additions & 12 deletions Sources/HummingbirdWebSocket/WebSocketOutboundWriter.swift
Expand Up @@ -16,7 +16,7 @@ import NIOCore
import NIOWebSocket

/// Outbound websocket writer
public struct WebSocketHandlerOutboundWriter {
public struct WebSocketOutboundWriter: Sendable {
/// WebSocket frame that can be written
public enum OutboundFrame: Sendable {
/// Text frame
Expand All @@ -25,13 +25,12 @@ public struct WebSocketHandlerOutboundWriter {
case binary(ByteBuffer)
/// Unsolicited pong frame
case pong
/// A ping frame. The returning pong will be dealt with by the underlying code
case ping
/// A custom frame not supported by the above
case custom(WebSocketFrame)
}

let webSocket: WebSocketHandler
let type: WebSocketType
let allocator: ByteBufferAllocator
let outbound: NIOAsyncChannelOutboundWriter<WebSocketFrame>

/// Write WebSocket frame
Expand All @@ -40,20 +39,37 @@ public struct WebSocketHandlerOutboundWriter {
switch frame {
case .binary(let buffer):
// send binary data
try await self.webSocket.send(frame: .init(fin: true, opcode: .binary, data: buffer), outbound: self.outbound)
try await self.write(frame: .init(fin: true, opcode: .binary, data: buffer))
case .text(let string):
// send text based data
let buffer = self.webSocket.asyncChannel.channel.allocator.buffer(string: string)
try await self.webSocket.send(frame: .init(fin: true, opcode: .text, data: buffer), outbound: self.outbound)
case .ping:
// send ping
try await self.webSocket.ping(outbound: self.outbound)
let buffer = self.allocator.buffer(string: string)
try await self.write(frame: .init(fin: true, opcode: .text, data: buffer))
case .pong:
// send unexplained pong as a heartbeat
try await self.webSocket.pong(data: nil, outbound: self.outbound)
try await self.write(frame: .init(fin: true, opcode: .pong, data: .init()))
case .custom(let frame):
// send custom WebSocketFrame
try await self.webSocket.send(frame: frame, outbound: self.outbound)
try await self.write(frame: frame)
}
}

/// Send WebSocket frame
func write(
frame: WebSocketFrame
) async throws {
var frame = frame
frame.maskKey = self.makeMaskKey()
try await self.outbound.write(frame)
}

func finish() {
self.outbound.finish()
}

/// Make mask key to be used in WebSocket frame
private func makeMaskKey() -> WebSocketMaskingKey? {
guard self.type == .client else { return nil }
let bytes: [UInt8] = (0...3).map { _ in UInt8.random(in: .min ... .max) }
return WebSocketMaskingKey(bytes)
}
}