Skip to content

Commit

Permalink
Move packet writing code, add graceful shutdown handler, fix close ha…
Browse files Browse the repository at this point in the history
…ndshake (#43)

* Group inbound and outbound in WebSocket

* Fix closing websocket

* Treat cancallation and graceful shutdown differently

* Finish websocket outbound after close, improve logging

* Clean up close because of error

* Update after rebase

* Split WebSocket back into two
  • Loading branch information
adam-fowler committed Mar 22, 2024
1 parent 0f61374 commit df44f88
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 109 deletions.
6 changes: 5 additions & 1 deletion Snippets/WebsocketTest.swift
@@ -1,7 +1,10 @@
import HTTPTypes
import Hummingbird
import HummingbirdWebSocket
import Logging

var logger = Logger(label: "Echo")
logger.logLevel = .trace
let router = Router(context: BasicWebSocketRequestContext.self)
router.middlewares.add(FileMiddleware("Snippets/public"))
router.get { _, _ in
Expand All @@ -19,6 +22,7 @@ router.ws("/ws") { inbound, outbound, _ in

let app = Application(
router: router,
server: .webSocketUpgrade(webSocketRouter: router)
server: .webSocketUpgrade(webSocketRouter: router),
logger: logger
)
try await app.runService()
2 changes: 1 addition & 1 deletion Sources/HummingbirdWebSocket/WebSocketDataHandler.swift
Expand Up @@ -19,4 +19,4 @@ import NIOCore
import NIOWebSocket

/// Function that handles websocket data and text blocks
public typealias WebSocketDataHandler<Context: WebSocketContextProtocol> = @Sendable (WebSocketHandlerInbound, WebSocketHandlerOutboundWriter, Context) async throws -> Void
public typealias WebSocketDataHandler<Context: WebSocketContextProtocol> = @Sendable (WebSocketInboundStream, WebSocketOutboundWriter, Context) async throws -> Void
192 changes: 107 additions & 85 deletions Sources/HummingbirdWebSocket/WebSocketHandler.swift
Expand Up @@ -16,96 +16,131 @@ import Hummingbird
import Logging
import NIOCore
import NIOWebSocket
import ServiceLifecycle

/// WebSocket type
public enum WebSocketType: Sendable {
case client
case server
}

/// Handler processing raw WebSocket packets.
///
/// Manages ping, pong and close messages. Collates data and text messages into final frame
/// and passes them onto the ``WebSocketDataHandler`` data handler setup by the user.
actor WebSocketHandler: Sendable {
enum SocketType: Sendable {
case client
case server
}

static let pingDataSize = 16

let asyncChannel: NIOAsyncChannel<WebSocketFrame, WebSocketFrame>
let type: SocketType
var closed = false
let type: WebSocketType
var closed: Bool
var pingData: ByteBuffer

init(asyncChannel: NIOAsyncChannel<WebSocketFrame, WebSocketFrame>, type: WebSocketHandler.SocketType) {
init(asyncChannel: NIOAsyncChannel<WebSocketFrame, WebSocketFrame>, type: WebSocketType) {
self.asyncChannel = asyncChannel
self.type = type
self.pingData = ByteBufferAllocator().buffer(capacity: Self.pingDataSize)
self.closed = false
}

/// Handle WebSocket AsynChannel
func handle<Context: WebSocketContextProtocol>(handler: @escaping WebSocketDataHandler<Context>, context: Context) async {
try? await self.asyncChannel.executeThenClose { inbound, outbound in
do {
try await withThrowingTaskGroup(of: Void.self) { group in
let webSocketHandlerInbound = WebSocketHandlerInbound()
defer {
asyncChannel.channel.close(promise: nil)
webSocketHandlerInbound.finish()
}
let webSocketHandlerOutbound = WebSocketHandlerOutboundWriter(webSocket: self, outbound: outbound)
group.addTask {
// parse messages coming from inbound
var frameSequence: WebSocketFrameSequence?
for try await frame in inbound {
switch frame.opcode {
case .connectionClose:
return
case .ping:
try await self.onPing(frame, outbound: outbound, context: context)
case .pong:
try await self.onPong(frame, outbound: outbound, context: context)
case .text, .binary:
if var frameSeq = frameSequence {
frameSeq.append(frame)
frameSequence = frameSeq
} else {
frameSequence = WebSocketFrameSequence(frame: frame)
let asyncChannel = self.asyncChannel
try? await asyncChannel.executeThenClose { inbound, outbound in
let webSocketInbound = WebSocketInboundStream()
let webSocketOutbound = WebSocketOutboundWriter(
type: self.type,
allocator: asyncChannel.channel.allocator,
outbound: outbound
)
try await withTaskCancellationHandler {
try await withGracefulShutdownHandler {
try await withThrowingTaskGroup(of: Void.self) { group in
group.addTask {
defer {
webSocketInbound.finish()
}
// parse messages coming from inbound
var frameSequence: WebSocketFrameSequence?
for try await frame in inbound {
do {
context.logger.trace("Received \(frame.opcode)")
switch frame.opcode {
case .connectionClose:
// we received a connection close. Finish the inbound data stream,
// send a close back if it hasn't already been send and exit
webSocketInbound.finish()
_ = try await self.close(code: .normalClosure, outbound: webSocketOutbound, context: context)
return
case .ping:
try await self.onPing(frame, outbound: webSocketOutbound, context: context)
case .pong:
try await self.onPong(frame, outbound: webSocketOutbound, context: context)
case .text, .binary:
if var frameSeq = frameSequence {
frameSeq.append(frame)
frameSequence = frameSeq
} else {
frameSequence = WebSocketFrameSequence(frame: frame)
}
case .continuation:
if var frameSeq = frameSequence {
frameSeq.append(frame)
frameSequence = frameSeq
} else {
try await self.close(code: .protocolError, outbound: webSocketOutbound, context: context)
}
default:
break
}
if let frameSeq = frameSequence, frame.fin {
await webSocketInbound.send(frameSeq.data)
frameSequence = nil
}
} catch {
// catch errors while processing websocket frames so responding close message
// can be dealt with
let errorCode = WebSocketErrorCode(error)
try await self.close(code: errorCode, outbound: webSocketOutbound, context: context)
}
case .continuation:
if var frameSeq = frameSequence {
frameSeq.append(frame)
frameSequence = frameSeq
}
}
group.addTask {
do {
// handle websocket data and text
try await handler(webSocketInbound, webSocketOutbound, context)
try await self.close(code: .normalClosure, outbound: webSocketOutbound, context: context)
} catch {
if self.type == .server {
let errorCode = WebSocketErrorCode.unexpectedServerError
try await self.close(code: errorCode, outbound: webSocketOutbound, context: context)
} else {
try await self.close(code: .protocolError, outbound: outbound, context: context)
try await asyncChannel.channel.close(mode: .input)
}
default:
break
}
if let frameSeq = frameSequence, frame.fin {
await webSocketHandlerInbound.send(frameSeq.data)
frameSequence = nil
}
}
try await group.next()
webSocketInbound.finish()
}
group.addTask {
// handle websocket data and text
try await handler(webSocketHandlerInbound, webSocketHandlerOutbound, context)
try await self.close(code: .normalClosure, outbound: outbound, context: context)
} onGracefulShutdown: {
Task {
try? await self.close(code: .normalClosure, outbound: webSocketOutbound, context: context)
}
try await group.next()
}
} catch let error as NIOWebSocketError {
let errorCode = WebSocketErrorCode(error)
try await self.close(code: errorCode, outbound: outbound, context: context)
} catch {
let errorCode = WebSocketErrorCode.unexpectedServerError
try await self.close(code: errorCode, outbound: outbound, context: context)
} onCancel: {
Task {
webSocketInbound.finish()
try await asyncChannel.channel.close(mode: .input)
}
}
}
context.logger.debug("Closed WebSocket")
}

/// Respond to ping
func onPing(
_ frame: WebSocketFrame,
outbound: NIOAsyncChannelOutboundWriter<WebSocketFrame>,
outbound: WebSocketOutboundWriter,
context: some WebSocketContextProtocol
) async throws {
if frame.fin {
Expand All @@ -118,9 +153,10 @@ actor WebSocketHandler: Sendable {
/// Respond to pong
func onPong(
_ frame: WebSocketFrame,
outbound: NIOAsyncChannelOutboundWriter<WebSocketFrame>,
outbound: WebSocketOutboundWriter,
context: some WebSocketContextProtocol
) async throws {
guard !self.closed else { return }
let frameData = frame.unmaskedData
guard self.pingData.readableBytes == 0 || frameData == self.pingData else {
try await self.close(code: .goingAway, outbound: outbound, context: context)
Expand All @@ -130,62 +166,48 @@ actor WebSocketHandler: Sendable {
}

/// Send ping
func ping(outbound: NIOAsyncChannelOutboundWriter<WebSocketFrame>) async throws {
func ping(outbound: WebSocketOutboundWriter) async throws {
guard !self.closed else { return }
if self.pingData.readableBytes == 0 {
// creating random payload
let random = (0..<Self.pingDataSize).map { _ in UInt8.random(in: 0...255) }
self.pingData.writeBytes(random)
}
try await self.send(frame: .init(fin: true, opcode: .ping, data: self.pingData), outbound: outbound)
try await outbound.write(frame: .init(fin: true, opcode: .ping, data: self.pingData))
}

/// Send pong
func pong(data: ByteBuffer?, outbound: NIOAsyncChannelOutboundWriter<WebSocketFrame>) async throws {
func pong(data: ByteBuffer?, outbound: WebSocketOutboundWriter) async throws {
guard !self.closed else { return }
try await self.send(frame: .init(fin: true, opcode: .pong, data: data ?? .init()), outbound: outbound)
try await outbound.write(frame: .init(fin: true, opcode: .pong, data: data ?? .init()))
}

/// Send close
func close(
code: WebSocketErrorCode = .normalClosure,
outbound: NIOAsyncChannelOutboundWriter<WebSocketFrame>,
outbound: WebSocketOutboundWriter,
context: some WebSocketContextProtocol
) async throws {
guard !self.closed else { return }
self.closed = true

var buffer = context.allocator.buffer(capacity: 2)
buffer.write(webSocketErrorCode: code)
try await self.send(frame: .init(fin: true, opcode: .connectionClose, data: buffer), outbound: outbound)
}

/// Send WebSocket frame
func send(
frame: WebSocketFrame,
outbound: NIOAsyncChannelOutboundWriter<WebSocketFrame>
) async throws {
var frame = frame
frame.maskKey = self.makeMaskKey()
try await outbound.write(frame)
}

/// Make mask key to be used in WebSocket frame
private func makeMaskKey() -> WebSocketMaskingKey? {
guard self.type == .client else { return nil }
let bytes: [UInt8] = (0...3).map { _ in UInt8.random(in: .min ... .max) }
return WebSocketMaskingKey(bytes)
try await outbound.write(frame: .init(fin: true, opcode: .connectionClose, data: buffer))
outbound.finish()
}
}

extension WebSocketErrorCode {
init(_ error: NIOWebSocketError) {
init(_ error: any Error) {
switch error {
case .invalidFrameLength:
case NIOWebSocketError.invalidFrameLength:
self = .messageTooLarge
case .fragmentedControlFrame,
.multiByteControlFrameLength:
case NIOWebSocketError.fragmentedControlFrame,
NIOWebSocketError.multiByteControlFrameLength:
self = .protocolError
default:
self = .unexpectedServerError
}
}
}
3 changes: 2 additions & 1 deletion Sources/HummingbirdWebSocket/WebSocketInboundStream.swift
Expand Up @@ -17,7 +17,8 @@ import NIOCore
import NIOWebSocket

/// Inbound websocket data AsyncSequence
public typealias WebSocketHandlerInbound = AsyncChannel<WebSocketDataFrame>
public typealias WebSocketInboundStream = AsyncChannel<WebSocketDataFrame>

/// Enumeration holding WebSocket data
public enum WebSocketDataFrame: Equatable, Sendable, CustomStringConvertible, CustomDebugStringConvertible {
case text(String)
Expand Down
40 changes: 28 additions & 12 deletions Sources/HummingbirdWebSocket/WebSocketOutboundWriter.swift
Expand Up @@ -16,7 +16,7 @@ import NIOCore
import NIOWebSocket

/// Outbound websocket writer
public struct WebSocketHandlerOutboundWriter {
public struct WebSocketOutboundWriter: Sendable {
/// WebSocket frame that can be written
public enum OutboundFrame: Sendable {
/// Text frame
Expand All @@ -25,13 +25,12 @@ public struct WebSocketHandlerOutboundWriter {
case binary(ByteBuffer)
/// Unsolicited pong frame
case pong
/// A ping frame. The returning pong will be dealt with by the underlying code
case ping
/// A custom frame not supported by the above
case custom(WebSocketFrame)
}

let webSocket: WebSocketHandler
let type: WebSocketType
let allocator: ByteBufferAllocator
let outbound: NIOAsyncChannelOutboundWriter<WebSocketFrame>

/// Write WebSocket frame
Expand All @@ -40,20 +39,37 @@ public struct WebSocketHandlerOutboundWriter {
switch frame {
case .binary(let buffer):
// send binary data
try await self.webSocket.send(frame: .init(fin: true, opcode: .binary, data: buffer), outbound: self.outbound)
try await self.write(frame: .init(fin: true, opcode: .binary, data: buffer))
case .text(let string):
// send text based data
let buffer = self.webSocket.asyncChannel.channel.allocator.buffer(string: string)
try await self.webSocket.send(frame: .init(fin: true, opcode: .text, data: buffer), outbound: self.outbound)
case .ping:
// send ping
try await self.webSocket.ping(outbound: self.outbound)
let buffer = self.allocator.buffer(string: string)
try await self.write(frame: .init(fin: true, opcode: .text, data: buffer))
case .pong:
// send unexplained pong as a heartbeat
try await self.webSocket.pong(data: nil, outbound: self.outbound)
try await self.write(frame: .init(fin: true, opcode: .pong, data: .init()))
case .custom(let frame):
// send custom WebSocketFrame
try await self.webSocket.send(frame: frame, outbound: self.outbound)
try await self.write(frame: frame)
}
}

/// Send WebSocket frame
func write(
frame: WebSocketFrame
) async throws {
var frame = frame
frame.maskKey = self.makeMaskKey()
try await self.outbound.write(frame)
}

func finish() {
self.outbound.finish()
}

/// Make mask key to be used in WebSocket frame
private func makeMaskKey() -> WebSocketMaskingKey? {
guard self.type == .client else { return nil }
let bytes: [UInt8] = (0...3).map { _ in UInt8.random(in: .min ... .max) }
return WebSocketMaskingKey(bytes)
}
}

0 comments on commit df44f88

Please sign in to comment.