From 929d35afc2f8c010d5ce94ac6f977fa9848aa9a4 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Thu, 21 Mar 2024 18:54:15 +0000 Subject: [PATCH 1/7] Group inbound and outbound in WebSocket --- Snippets/WebsocketTest.swift | 6 +- Sources/HummingbirdWebSocket/WebSocket.swift | 39 +++++ .../WebSocketDataHandler.swift | 2 +- .../WebSocketHandler.swift | 149 ++++++++---------- .../WebSocketInboundStream.swift | 3 +- .../WebSocketOutboundWriter.swift | 36 +++-- .../WebSocketTests.swift | 124 +++++++-------- 7 files changed, 199 insertions(+), 160 deletions(-) create mode 100644 Sources/HummingbirdWebSocket/WebSocket.swift diff --git a/Snippets/WebsocketTest.swift b/Snippets/WebsocketTest.swift index fad3fce..78c7dae 100644 --- a/Snippets/WebsocketTest.swift +++ b/Snippets/WebsocketTest.swift @@ -8,12 +8,12 @@ router.get { _, _ in "Hello" } -router.ws("/ws") { inbound, outbound, _ in - for try await packet in inbound { +router.ws("/ws") { ws, _ in + for try await packet in ws.inbound { if case .text("disconnect") = packet { break } - try await outbound.write(.custom(packet.webSocketFrame)) + try await ws.outbound.write(.custom(packet.webSocketFrame)) } } diff --git a/Sources/HummingbirdWebSocket/WebSocket.swift b/Sources/HummingbirdWebSocket/WebSocket.swift new file mode 100644 index 0000000..9383486 --- /dev/null +++ b/Sources/HummingbirdWebSocket/WebSocket.swift @@ -0,0 +1,39 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird server framework project +// +// Copyright (c) 2023-2024 the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import NIOCore +import NIOWebSocket + +/// WebSocket +public struct WebSocket: Sendable { + enum SocketType: Sendable { + case client + case server + } + + /// Inbound stream type + public typealias Inbound = WebSocketInboundStream + /// Outbound writer type + public typealias Outbound = WebSocketOutboundWriter + + /// Inbound stream + public let inbound: Inbound + /// Outbound writer + public let outbound: Outbound + + init(type: SocketType, outbound: NIOAsyncChannelOutboundWriter, allocator: ByteBufferAllocator) { + self.inbound = .init() + self.outbound = .init(type: type, allocator: allocator, outbound: outbound) + } +} diff --git a/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift b/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift index 74b4c2a..8bf402d 100644 --- a/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift +++ b/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift @@ -19,4 +19,4 @@ import NIOCore import NIOWebSocket /// Function that handles websocket data and text blocks -public typealias WebSocketDataHandler = @Sendable (WebSocketHandlerInbound, WebSocketHandlerOutboundWriter, Context) async throws -> Void +public typealias WebSocketDataHandler = @Sendable (WebSocket, Context) async throws -> Void diff --git a/Sources/HummingbirdWebSocket/WebSocketHandler.swift b/Sources/HummingbirdWebSocket/WebSocketHandler.swift index 6bcac5e..96d9cdf 100644 --- a/Sources/HummingbirdWebSocket/WebSocketHandler.swift +++ b/Sources/HummingbirdWebSocket/WebSocketHandler.swift @@ -16,25 +16,21 @@ import Hummingbird import Logging import NIOCore import NIOWebSocket +import ServiceLifecycle /// Handler processing raw WebSocket packets. /// /// Manages ping, pong and close messages. Collates data and text messages into final frame /// and passes them onto the ``WebSocketDataHandler`` data handler setup by the user. actor WebSocketHandler: Sendable { - enum SocketType: Sendable { - case client - case server - } - static let pingDataSize = 16 let asyncChannel: NIOAsyncChannel - let type: SocketType + let type: WebSocket.SocketType var closed = false var pingData: ByteBuffer - init(asyncChannel: NIOAsyncChannel, type: WebSocketHandler.SocketType) { + init(asyncChannel: NIOAsyncChannel, type: WebSocket.SocketType) { self.asyncChannel = asyncChannel self.type = type self.pingData = ByteBufferAllocator().buffer(capacity: Self.pingDataSize) @@ -43,61 +39,69 @@ actor WebSocketHandler: Sendable { /// Handle WebSocket AsynChannel func handle(handler: @escaping WebSocketDataHandler, context: Context) async { try? await self.asyncChannel.executeThenClose { inbound, outbound in - do { - try await withThrowingTaskGroup(of: Void.self) { group in - let webSocketHandlerInbound = WebSocketHandlerInbound() - defer { - asyncChannel.channel.close(promise: nil) - webSocketHandlerInbound.finish() - } - let webSocketHandlerOutbound = WebSocketHandlerOutboundWriter(webSocket: self, outbound: outbound) - group.addTask { - // parse messages coming from inbound - var frameSequence: WebSocketFrameSequence? - for try await frame in inbound { - switch frame.opcode { - case .connectionClose: - return - case .ping: - try await self.onPing(frame, outbound: outbound, context: context) - case .pong: - try await self.onPong(frame, outbound: outbound, context: context) - case .text, .binary: - if var frameSeq = frameSequence { - frameSeq.append(frame) - frameSequence = frameSeq - } else { - frameSequence = WebSocketFrameSequence(frame: frame) + let webSocket = WebSocket(type: self.type, outbound: outbound, allocator: self.asyncChannel.channel.allocator) + defer { + asyncChannel.channel.close(promise: nil) + webSocket.inbound.finish() + } + try await withTaskCancellationOrGracefulShutdownHandler { + do { + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + // parse messages coming from inbound + var frameSequence: WebSocketFrameSequence? + for try await frame in inbound { + switch frame.opcode { + case .connectionClose: + print("\(self.type): Received close") + return + case .ping: + try await self.onPing(frame, outbound: webSocket.outbound, context: context) + case .pong: + try await self.onPong(frame, outbound: webSocket.outbound, context: context) + case .text, .binary: + if var frameSeq = frameSequence { + frameSeq.append(frame) + frameSequence = frameSeq + } else { + frameSequence = WebSocketFrameSequence(frame: frame) + } + case .continuation: + if var frameSeq = frameSequence { + frameSeq.append(frame) + frameSequence = frameSeq + } else { + try await self.close(code: .protocolError, outbound: webSocket.outbound, context: context) + } + default: + break } - case .continuation: - if var frameSeq = frameSequence { - frameSeq.append(frame) - frameSequence = frameSeq - } else { - try await self.close(code: .protocolError, outbound: outbound, context: context) + if let frameSeq = frameSequence, frame.fin { + await webSocket.inbound.send(frameSeq.data) + frameSequence = nil } - default: - break - } - if let frameSeq = frameSequence, frame.fin { - await webSocketHandlerInbound.send(frameSeq.data) - frameSequence = nil } } + group.addTask { + // handle websocket data and text + try await handler(webSocket, context) + try await self.close(code: .normalClosure, outbound: webSocket.outbound, context: context) + } + try await group.next() + print("\(self.type): Closed") } - group.addTask { - // handle websocket data and text - try await handler(webSocketHandlerInbound, webSocketHandlerOutbound, context) - try await self.close(code: .normalClosure, outbound: outbound, context: context) - } - try await group.next() + } catch let error as NIOWebSocketError { + let errorCode = WebSocketErrorCode(error) + try await self.close(code: errorCode, outbound: webSocket.outbound, context: context) + } catch { + let errorCode = WebSocketErrorCode.unexpectedServerError + try await self.close(code: errorCode, outbound: webSocket.outbound, context: context) + } + } onCancelOrGracefulShutdown: { + Task { + try? await self.close(code: .normalClosure, outbound: webSocket.outbound, context: context) + webSocket.inbound.finish() } - } catch let error as NIOWebSocketError { - let errorCode = WebSocketErrorCode(error) - try await self.close(code: errorCode, outbound: outbound, context: context) - } catch { - let errorCode = WebSocketErrorCode.unexpectedServerError - try await self.close(code: errorCode, outbound: outbound, context: context) } } } @@ -105,7 +109,7 @@ actor WebSocketHandler: Sendable { /// Respond to ping func onPing( _ frame: WebSocketFrame, - outbound: NIOAsyncChannelOutboundWriter, + outbound: WebSocketOutboundWriter, context: some WebSocketContextProtocol ) async throws { if frame.fin { @@ -118,7 +122,7 @@ actor WebSocketHandler: Sendable { /// Respond to pong func onPong( _ frame: WebSocketFrame, - outbound: NIOAsyncChannelOutboundWriter, + outbound: WebSocketOutboundWriter, context: some WebSocketContextProtocol ) async throws { let frameData = frame.unmaskedData @@ -130,26 +134,26 @@ actor WebSocketHandler: Sendable { } /// Send ping - func ping(outbound: NIOAsyncChannelOutboundWriter) async throws { + func ping(outbound: WebSocketOutboundWriter) async throws { guard !self.closed else { return } if self.pingData.readableBytes == 0 { // creating random payload let random = (0..) async throws { + func pong(data: ByteBuffer?, outbound: WebSocketOutboundWriter) async throws { guard !self.closed else { return } - try await self.send(frame: .init(fin: true, opcode: .pong, data: data ?? .init()), outbound: outbound) + try await outbound.write(frame: .init(fin: true, opcode: .pong, data: data ?? .init())) } /// Send close func close( code: WebSocketErrorCode = .normalClosure, - outbound: NIOAsyncChannelOutboundWriter, + outbound: WebSocketOutboundWriter, context: some WebSocketContextProtocol ) async throws { guard !self.closed else { return } @@ -157,24 +161,7 @@ actor WebSocketHandler: Sendable { var buffer = context.allocator.buffer(capacity: 2) buffer.write(webSocketErrorCode: code) - try await self.send(frame: .init(fin: true, opcode: .connectionClose, data: buffer), outbound: outbound) - } - - /// Send WebSocket frame - func send( - frame: WebSocketFrame, - outbound: NIOAsyncChannelOutboundWriter - ) async throws { - var frame = frame - frame.maskKey = self.makeMaskKey() - try await outbound.write(frame) - } - - /// Make mask key to be used in WebSocket frame - private func makeMaskKey() -> WebSocketMaskingKey? { - guard self.type == .client else { return nil } - let bytes: [UInt8] = (0...3).map { _ in UInt8.random(in: .min ... .max) } - return WebSocketMaskingKey(bytes) + try await outbound.write(frame: .init(fin: true, opcode: .connectionClose, data: buffer)) } } diff --git a/Sources/HummingbirdWebSocket/WebSocketInboundStream.swift b/Sources/HummingbirdWebSocket/WebSocketInboundStream.swift index 487afa5..d6067ae 100644 --- a/Sources/HummingbirdWebSocket/WebSocketInboundStream.swift +++ b/Sources/HummingbirdWebSocket/WebSocketInboundStream.swift @@ -17,7 +17,8 @@ import NIOCore import NIOWebSocket /// Inbound websocket data AsyncSequence -public typealias WebSocketHandlerInbound = AsyncChannel +public typealias WebSocketInboundStream = AsyncChannel + /// Enumeration holding WebSocket data public enum WebSocketDataFrame: Equatable, Sendable, CustomStringConvertible, CustomDebugStringConvertible { case text(String) diff --git a/Sources/HummingbirdWebSocket/WebSocketOutboundWriter.swift b/Sources/HummingbirdWebSocket/WebSocketOutboundWriter.swift index dccc044..d3b9d4e 100644 --- a/Sources/HummingbirdWebSocket/WebSocketOutboundWriter.swift +++ b/Sources/HummingbirdWebSocket/WebSocketOutboundWriter.swift @@ -16,7 +16,7 @@ import NIOCore import NIOWebSocket /// Outbound websocket writer -public struct WebSocketHandlerOutboundWriter { +public struct WebSocketOutboundWriter: Sendable { /// WebSocket frame that can be written public enum OutboundFrame: Sendable { /// Text frame @@ -25,13 +25,12 @@ public struct WebSocketHandlerOutboundWriter { case binary(ByteBuffer) /// Unsolicited pong frame case pong - /// A ping frame. The returning pong will be dealt with by the underlying code - case ping /// A custom frame not supported by the above case custom(WebSocketFrame) } - let webSocket: WebSocketHandler + let type: WebSocket.SocketType + let allocator: ByteBufferAllocator let outbound: NIOAsyncChannelOutboundWriter /// Write WebSocket frame @@ -40,20 +39,33 @@ public struct WebSocketHandlerOutboundWriter { switch frame { case .binary(let buffer): // send binary data - try await self.webSocket.send(frame: .init(fin: true, opcode: .binary, data: buffer), outbound: self.outbound) + try await self.write(frame: .init(fin: true, opcode: .binary, data: buffer)) case .text(let string): // send text based data - let buffer = self.webSocket.asyncChannel.channel.allocator.buffer(string: string) - try await self.webSocket.send(frame: .init(fin: true, opcode: .text, data: buffer), outbound: self.outbound) - case .ping: - // send ping - try await self.webSocket.ping(outbound: self.outbound) + let buffer = self.allocator.buffer(string: string) + try await self.write(frame: .init(fin: true, opcode: .text, data: buffer)) case .pong: // send unexplained pong as a heartbeat - try await self.webSocket.pong(data: nil, outbound: self.outbound) + try await self.write(frame: .init(fin: true, opcode: .pong, data: .init())) case .custom(let frame): // send custom WebSocketFrame - try await self.webSocket.send(frame: frame, outbound: self.outbound) + try await self.write(frame: frame) } } + + /// Send WebSocket frame + func write( + frame: WebSocketFrame + ) async throws { + var frame = frame + frame.maskKey = self.makeMaskKey() + try await self.outbound.write(frame) + } + + /// Make mask key to be used in WebSocket frame + private func makeMaskKey() -> WebSocketMaskingKey? { + guard self.type == .client else { return nil } + let bytes: [UInt8] = (0...3).map { _ in UInt8.random(in: .min ... .max) } + return WebSocketMaskingKey(bytes) + } } diff --git a/Tests/HummingbirdWebSocketTests/WebSocketTests.swift b/Tests/HummingbirdWebSocketTests/WebSocketTests.swift index b7666bf..f03de62 100644 --- a/Tests/HummingbirdWebSocketTests/WebSocketTests.swift +++ b/Tests/HummingbirdWebSocketTests/WebSocketTests.swift @@ -209,37 +209,37 @@ final class HummingbirdWebSocketTests: XCTestCase { // MARK: Tests func testServerToClientMessage() async throws { - try await self.testClientAndServer { _, outbound, _ in - try await outbound.write(.text("Hello")) - } client: { inbound, _, _ in - var inboundIterator = inbound.makeAsyncIterator() + try await self.testClientAndServer { ws, _ in + try await ws.outbound.write(.text("Hello")) + } client: { ws, _ in + var inboundIterator = ws.inbound.makeAsyncIterator() let msg = await inboundIterator.next() XCTAssertEqual(msg, .text("Hello")) } } func testClientToServerMessage() async throws { - try await self.testClientAndServer { inbound, _, _ in - var inboundIterator = inbound.makeAsyncIterator() + try await self.testClientAndServer { ws, _ in + var inboundIterator = ws.inbound.makeAsyncIterator() let msg = await inboundIterator.next() XCTAssertEqual(msg, .text("Hello")) - } client: { _, outbound, _ in - try await outbound.write(.text("Hello")) + } client: { ws, _ in + try await ws.outbound.write(.text("Hello")) } } func testClientToServerSplitPacket() async throws { - try await self.testClientAndServer { inbound, outbound, _ in - for try await packet in inbound { - try await outbound.write(.custom(packet.webSocketFrame)) + try await self.testClientAndServer { ws, _ in + for try await packet in ws.inbound { + try await ws.outbound.write(.custom(packet.webSocketFrame)) } - } client: { inbound, outbound, _ in + } client: { ws, _ in let buffer = ByteBuffer(string: "Hello ") - try await outbound.write(.custom(.init(fin: false, opcode: .text, data: buffer))) + try await ws.outbound.write(.custom(.init(fin: false, opcode: .text, data: buffer))) let buffer2 = ByteBuffer(string: "World!") - try await outbound.write(.custom(.init(fin: true, opcode: .text, data: buffer2))) + try await ws.outbound.write(.custom(.init(fin: true, opcode: .text, data: buffer2))) - var inboundIterator = inbound.makeAsyncIterator() + var inboundIterator = ws.inbound.makeAsyncIterator() let msg = await inboundIterator.next() XCTAssertEqual(msg, .text("Hello World!")) } @@ -247,23 +247,23 @@ final class HummingbirdWebSocketTests: XCTestCase { // test connection is closed when buffer is too large func testTooLargeBuffer() async throws { - try await self.testClientAndServer { inbound, outbound, _ in + try await self.testClientAndServer { ws, _ in let buffer = ByteBuffer(repeating: 1, count: (1 << 14) + 1) - try await outbound.write(.binary(buffer)) - for try await _ in inbound {} - } client: { inbound, _, _ in - for try await _ in inbound {} + try await ws.outbound.write(.binary(buffer)) + for try await _ in ws.inbound {} + } client: { ws, _ in + for try await _ in ws.inbound {} } } func testNotWebSocket() async throws { do { - try await self.testClientAndServer { inbound, _, _ in - for try await _ in inbound {} + try await self.testClientAndServer { ws, _ in + for try await _ in ws.inbound {} } shouldUpgrade: { _ in return nil - } client: { inbound, _, _ in - for try await _ in inbound {} + } client: { ws, _ in + for try await _ in ws.inbound {} } } catch let error as WebSocketClientError where error == .webSocketUpgradeFailed {} } @@ -272,7 +272,7 @@ final class HummingbirdWebSocketTests: XCTestCase { let client = try WebSocketClient( url: .init("ws://localhost:10245"), logger: Logger(label: "TestNoConnection") - ) { _, _, _ in + ) { _, _ in } do { try await client.run() @@ -281,8 +281,8 @@ final class HummingbirdWebSocketTests: XCTestCase { } func testTLS() async throws { - try await self.testClientAndServer(serverTLSConfiguration: getServerTLSConfiguration()) { _, outbound, _ in - try await outbound.write(.text("Hello")) + try await self.testClientAndServer(serverTLSConfiguration: getServerTLSConfiguration()) { ws, _ in + try await ws.outbound.write(.text("Hello")) } getClient: { port, logger in var clientTLSConfiguration = try getClientTLSConfiguration() clientTLSConfiguration.certificateVerification = .none @@ -290,8 +290,8 @@ final class HummingbirdWebSocketTests: XCTestCase { url: .init("wss://localhost:\(port)"), tlsConfiguration: clientTLSConfiguration, logger: logger - ) { inbound, _, _ in - var inboundIterator = inbound.makeAsyncIterator() + ) { ws, _ in + var inboundIterator = ws.inbound.makeAsyncIterator() let msg = await inboundIterator.next() XCTAssertEqual(msg, .text("Hello")) } @@ -299,8 +299,8 @@ final class HummingbirdWebSocketTests: XCTestCase { } func testURLPath() async throws { - try await self.testClientAndServer { inbound, _, _ in - for try await _ in inbound {} + try await self.testClientAndServer { ws, _ in + for try await _ in ws.inbound {} } shouldUpgrade: { head in XCTAssertEqual(head.path, "/ws") return [:] @@ -308,14 +308,14 @@ final class HummingbirdWebSocketTests: XCTestCase { try WebSocketClient( url: .init("ws://localhost:\(port)/ws"), logger: logger - ) { _, _, _ in + ) { _, _ in } } } func testQueryParameters() async throws { - try await self.testClientAndServer { inbound, _, _ in - for try await _ in inbound {} + try await self.testClientAndServer { ws, _ in + for try await _ in ws.inbound {} } shouldUpgrade: { head in let request = Request(head: head, body: .init(buffer: ByteBuffer())) XCTAssertEqual(request.uri.query, "query=parameters&test=true") @@ -324,14 +324,14 @@ final class HummingbirdWebSocketTests: XCTestCase { try WebSocketClient( url: .init("ws://localhost:\(port)/ws?query=parameters&test=true"), logger: logger - ) { _, _, _ in + ) { _, _ in } } } func testAdditionalHeaders() async throws { - try await self.testClientAndServer { inbound, _, _ in - for try await _ in inbound {} + try await self.testClientAndServer { ws, _ in + for try await _ in ws.inbound {} } shouldUpgrade: { head in let request = Request(head: head, body: .init(buffer: ByteBuffer())) XCTAssertEqual(request.headers[.secWebSocketExtensions], "hb") @@ -341,7 +341,7 @@ final class HummingbirdWebSocketTests: XCTestCase { url: .init("ws://localhost:\(port)/ws?query=parameters&test=true"), configuration: .init(additionalHeaders: [.secWebSocketExtensions: "hb"]), logger: logger - ) { _, _, _ in + ) { _, _ in } } } @@ -360,8 +360,8 @@ final class HummingbirdWebSocketTests: XCTestCase { let app = Application( router: router, server: .webSocketUpgrade { _, _, _ in - return .upgrade([:]) { _, outbound, _ in - try await outbound.write(.text("Hello")) + return .upgrade([:]) { ws, _ in + try await ws.outbound.write(.text("Hello")) } }, onServerRunning: { channel in await promise.complete(channel.localAddress!.port!) }, @@ -378,8 +378,8 @@ final class HummingbirdWebSocketTests: XCTestCase { try await serviceGroup.run() } group.addTask { - try await WebSocketClient.connect(url: .init("ws://localhost:\(promise.wait())/ws"), logger: logger) { inbound, _, _ in - var inboundIterator = inbound.makeAsyncIterator() + try await WebSocketClient.connect(url: .init("ws://localhost:\(promise.wait())/ws"), logger: logger) { ws, _ in + var inboundIterator = ws.inbound.makeAsyncIterator() let msg = await inboundIterator.next() XCTAssertEqual(msg, .text("Hello")) } @@ -393,24 +393,24 @@ final class HummingbirdWebSocketTests: XCTestCase { let router = Router(context: BasicWebSocketRequestContext.self) router.ws("/ws1") { _, _ in return .upgrade([:]) - } onUpgrade: { _, outbound, _ in - try await outbound.write(.text("One")) + } handle: { ws, _ in + try await ws.outbound.write(.text("One")) } router.ws("/ws2") { _, _ in return .upgrade([:]) - } onUpgrade: { _, outbound, _ in - try await outbound.write(.text("Two")) + } onUpgrade: { ws, _ in + try await ws.outbound.write(.text("Two")) } try await self.testClientAndServerWithRouter(webSocketRouter: router, uri: "localhost:8080") { port, logger in - try WebSocketClient(url: .init("ws://localhost:\(port)/ws1"), logger: logger) { inbound, _, _ in - var inboundIterator = inbound.makeAsyncIterator() + try WebSocketClient(url: .init("ws://localhost:\(port)/ws1"), logger: logger) { ws, _ in + var inboundIterator = ws.inbound.makeAsyncIterator() let msg = await inboundIterator.next() XCTAssertEqual(msg, .text("One")) } } try await self.testClientAndServerWithRouter(webSocketRouter: router, uri: "localhost:8080") { port, logger in - try WebSocketClient(url: .init("ws://localhost:\(port)/ws2"), logger: logger) { inbound, _, _ in - var inboundIterator = inbound.makeAsyncIterator() + try WebSocketClient(url: .init("ws://localhost:\(port)/ws2"), logger: logger) { ws, _ in + var inboundIterator = ws.inbound.makeAsyncIterator() let msg = await inboundIterator.next() XCTAssertEqual(msg, .text("Two")) } @@ -422,13 +422,13 @@ final class HummingbirdWebSocketTests: XCTestCase { router.group("/ws") .add(middleware: WebSocketUpgradeMiddleware { _, _ in return .upgrade([:]) - } onUpgrade: { _, outbound, _ in - try await outbound.write(.text("One")) + } onUpgrade: { ws, _ in + try await ws.outbound.write(.text("One")) }) .get { _, _ -> Response in return .init(status: .ok) } do { try await self.testClientAndServerWithRouter(webSocketRouter: router, uri: "localhost:8080") { port, logger in - try WebSocketClient(url: .init("ws://localhost:\(port)/ws"), logger: logger) { _, _, _ in } + try WebSocketClient(url: .init("ws://localhost:\(port)/ws"), logger: logger) { _, _ in } } } } @@ -437,12 +437,12 @@ final class HummingbirdWebSocketTests: XCTestCase { let router = Router(context: BasicWebSocketRequestContext.self) router.ws("/ws") { _, _ in return .upgrade([:]) - } onUpgrade: { _, outbound, _ in - try await outbound.write(.text("One")) + } onUpgrade: { ws, _ in + try await ws.outbound.write(.text("One")) } do { try await self.testClientAndServerWithRouter(webSocketRouter: router, uri: "localhost:8080") { port, logger in - try WebSocketClient(url: .init("ws://localhost:\(port)/not-ws"), logger: logger) { _, _, _ in } + try WebSocketClient(url: .init("ws://localhost:\(port)/not-ws"), logger: logger) { _, _ in } } } catch let error as WebSocketClientError where error == .webSocketUpgradeFailed {} } @@ -471,13 +471,13 @@ final class HummingbirdWebSocketTests: XCTestCase { router.middlewares.add(MyMiddleware()) router.ws("/ws") { _, _ in return .upgrade([:]) - } onUpgrade: { _, outbound, context in - try await outbound.write(.text(context.name)) + } onUpgrade: { ws, context in + try await ws.outbound.write(.text(context.name)) } do { try await self.testClientAndServerWithRouter(webSocketRouter: router, uri: "localhost:8080") { port, logger in - try WebSocketClient(url: .init("ws://localhost:\(port)/ws"), logger: logger) { inbound, _, _ in - let text = await inbound.first { _ in true } + try WebSocketClient(url: .init("ws://localhost:\(port)/ws"), logger: logger) { ws, _ in + let text = await ws.inbound.first { _ in true } XCTAssertEqual(text, .text("Roger Moore")) } } @@ -488,8 +488,8 @@ final class HummingbirdWebSocketTests: XCTestCase { let router = Router(context: BasicWebSocketRequestContext.self) router.ws("/ws") { _, _ in return .upgrade([:]) - } onUpgrade: { _, outbound, _ in - try await outbound.write(.text("Hello")) + } onUpgrade: { ws, _ in + try await ws.outbound.write(.text("Hello")) } router.get("/http") { _, _ in return "Hello" From 0c2a888a81bf71eba562fe2a96120353f233e85e Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Thu, 21 Mar 2024 20:44:56 +0000 Subject: [PATCH 2/7] Fix closing websocket --- .../WebSocketHandler.swift | 51 +++++++++++-------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/Sources/HummingbirdWebSocket/WebSocketHandler.swift b/Sources/HummingbirdWebSocket/WebSocketHandler.swift index 96d9cdf..25d4206 100644 --- a/Sources/HummingbirdWebSocket/WebSocketHandler.swift +++ b/Sources/HummingbirdWebSocket/WebSocketHandler.swift @@ -27,33 +27,34 @@ actor WebSocketHandler: Sendable { let asyncChannel: NIOAsyncChannel let type: WebSocket.SocketType - var closed = false + var closed: Bool var pingData: ByteBuffer init(asyncChannel: NIOAsyncChannel, type: WebSocket.SocketType) { self.asyncChannel = asyncChannel self.type = type self.pingData = ByteBufferAllocator().buffer(capacity: Self.pingDataSize) + self.closed = false } /// Handle WebSocket AsynChannel func handle(handler: @escaping WebSocketDataHandler, context: Context) async { try? await self.asyncChannel.executeThenClose { inbound, outbound in let webSocket = WebSocket(type: self.type, outbound: outbound, allocator: self.asyncChannel.channel.allocator) - defer { - asyncChannel.channel.close(promise: nil) - webSocket.inbound.finish() - } try await withTaskCancellationOrGracefulShutdownHandler { - do { - try await withThrowingTaskGroup(of: Void.self) { group in - group.addTask { - // parse messages coming from inbound - var frameSequence: WebSocketFrameSequence? - for try await frame in inbound { + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + // parse messages coming from inbound + var frameSequence: WebSocketFrameSequence? + for try await frame in inbound { + do { + print("\(self.type): Received \(frame.opcode)") switch frame.opcode { case .connectionClose: - print("\(self.type): Received close") + // we received a connection close. Finish the inbound data stream, + // send a close back if it hasn't already been send and exit + webSocket.inbound.finish() + _ = try await self.close(code: .normalClosure, outbound: webSocket.outbound, context: context) return case .ping: try await self.onPing(frame, outbound: webSocket.outbound, context: context) @@ -80,30 +81,35 @@ actor WebSocketHandler: Sendable { await webSocket.inbound.send(frameSeq.data) frameSequence = nil } + } catch let error as NIOWebSocketError { + let errorCode = WebSocketErrorCode(error) + try await self.close(code: errorCode, outbound: webSocket.outbound, context: context) + } catch { + let errorCode = WebSocketErrorCode.unexpectedServerError + try await self.close(code: errorCode, outbound: webSocket.outbound, context: context) } } - group.addTask { + } + group.addTask { + do { // handle websocket data and text try await handler(webSocket, context) try await self.close(code: .normalClosure, outbound: webSocket.outbound, context: context) + } catch { + let errorCode = WebSocketErrorCode.unexpectedServerError + try await self.close(code: errorCode, outbound: webSocket.outbound, context: context) } - try await group.next() - print("\(self.type): Closed") } - } catch let error as NIOWebSocketError { - let errorCode = WebSocketErrorCode(error) - try await self.close(code: errorCode, outbound: webSocket.outbound, context: context) - } catch { - let errorCode = WebSocketErrorCode.unexpectedServerError - try await self.close(code: errorCode, outbound: webSocket.outbound, context: context) + try await group.next() + webSocket.inbound.finish() } } onCancelOrGracefulShutdown: { Task { try? await self.close(code: .normalClosure, outbound: webSocket.outbound, context: context) - webSocket.inbound.finish() } } } + print("\(self.type): Really Closed") } /// Respond to ping @@ -125,6 +131,7 @@ actor WebSocketHandler: Sendable { outbound: WebSocketOutboundWriter, context: some WebSocketContextProtocol ) async throws { + guard !self.closed else { return } let frameData = frame.unmaskedData guard self.pingData.readableBytes == 0 || frameData == self.pingData else { try await self.close(code: .goingAway, outbound: outbound, context: context) From b8753a577c2011e7308ab30ae97aa650b9838713 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Thu, 21 Mar 2024 21:41:42 +0000 Subject: [PATCH 3/7] Treat cancallation and graceful shutdown differently --- .../WebSocketHandler.swift | 129 ++++++++++-------- 1 file changed, 72 insertions(+), 57 deletions(-) diff --git a/Sources/HummingbirdWebSocket/WebSocketHandler.swift b/Sources/HummingbirdWebSocket/WebSocketHandler.swift index 25d4206..e5b68b4 100644 --- a/Sources/HummingbirdWebSocket/WebSocketHandler.swift +++ b/Sources/HummingbirdWebSocket/WebSocketHandler.swift @@ -39,73 +39,88 @@ actor WebSocketHandler: Sendable { /// Handle WebSocket AsynChannel func handle(handler: @escaping WebSocketDataHandler, context: Context) async { - try? await self.asyncChannel.executeThenClose { inbound, outbound in - let webSocket = WebSocket(type: self.type, outbound: outbound, allocator: self.asyncChannel.channel.allocator) - try await withTaskCancellationOrGracefulShutdownHandler { - try await withThrowingTaskGroup(of: Void.self) { group in - group.addTask { - // parse messages coming from inbound - var frameSequence: WebSocketFrameSequence? - for try await frame in inbound { - do { - print("\(self.type): Received \(frame.opcode)") - switch frame.opcode { - case .connectionClose: - // we received a connection close. Finish the inbound data stream, - // send a close back if it hasn't already been send and exit - webSocket.inbound.finish() - _ = try await self.close(code: .normalClosure, outbound: webSocket.outbound, context: context) - return - case .ping: - try await self.onPing(frame, outbound: webSocket.outbound, context: context) - case .pong: - try await self.onPong(frame, outbound: webSocket.outbound, context: context) - case .text, .binary: - if var frameSeq = frameSequence { - frameSeq.append(frame) - frameSequence = frameSeq - } else { - frameSequence = WebSocketFrameSequence(frame: frame) + let asyncChannel = self.asyncChannel + try? await asyncChannel.executeThenClose { inbound, outbound in + let webSocket = WebSocket(type: self.type, outbound: outbound, allocator: asyncChannel.channel.allocator) + try await withTaskCancellationHandler { + try await withGracefulShutdownHandler { + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + defer { + webSocket.inbound.finish() + } + // parse messages coming from inbound + var frameSequence: WebSocketFrameSequence? + for try await frame in inbound { + do { + print("\(self.type): Received \(frame.opcode)") + switch frame.opcode { + case .connectionClose: + // we received a connection close. Finish the inbound data stream, + // send a close back if it hasn't already been send and exit + webSocket.inbound.finish() + _ = try await self.close(code: .normalClosure, outbound: webSocket.outbound, context: context) + return + case .ping: + try await self.onPing(frame, outbound: webSocket.outbound, context: context) + case .pong: + try await self.onPong(frame, outbound: webSocket.outbound, context: context) + case .text, .binary: + if var frameSeq = frameSequence { + frameSeq.append(frame) + frameSequence = frameSeq + } else { + frameSequence = WebSocketFrameSequence(frame: frame) + } + case .continuation: + if var frameSeq = frameSequence { + frameSeq.append(frame) + frameSequence = frameSeq + } else { + try await self.close(code: .protocolError, outbound: webSocket.outbound, context: context) + } + default: + break } - case .continuation: - if var frameSeq = frameSequence { - frameSeq.append(frame) - frameSequence = frameSeq - } else { - try await self.close(code: .protocolError, outbound: webSocket.outbound, context: context) + if let frameSeq = frameSequence, frame.fin { + await webSocket.inbound.send(frameSeq.data) + frameSequence = nil } - default: - break - } - if let frameSeq = frameSequence, frame.fin { - await webSocket.inbound.send(frameSeq.data) - frameSequence = nil + } catch let error as NIOWebSocketError { + let errorCode = WebSocketErrorCode(error) + try await self.close(code: errorCode, outbound: webSocket.outbound, context: context) + } catch { + let errorCode = WebSocketErrorCode.unexpectedServerError + try await self.close(code: errorCode, outbound: webSocket.outbound, context: context) } - } catch let error as NIOWebSocketError { - let errorCode = WebSocketErrorCode(error) - try await self.close(code: errorCode, outbound: webSocket.outbound, context: context) + } + } + group.addTask { + do { + // handle websocket data and text + try await handler(webSocket, context) + try await self.close(code: .normalClosure, outbound: webSocket.outbound, context: context) } catch { - let errorCode = WebSocketErrorCode.unexpectedServerError - try await self.close(code: errorCode, outbound: webSocket.outbound, context: context) + if self.type == .server { + let errorCode = WebSocketErrorCode.unexpectedServerError + try await self.close(code: errorCode, outbound: webSocket.outbound, context: context) + } else { + try await asyncChannel.channel.close(mode: .input) + } } } + try await group.next() + webSocket.inbound.finish() } - group.addTask { - do { - // handle websocket data and text - try await handler(webSocket, context) - try await self.close(code: .normalClosure, outbound: webSocket.outbound, context: context) - } catch { - let errorCode = WebSocketErrorCode.unexpectedServerError - try await self.close(code: errorCode, outbound: webSocket.outbound, context: context) - } + } onGracefulShutdown: { + Task { + try? await self.close(code: .normalClosure, outbound: webSocket.outbound, context: context) } - try await group.next() - webSocket.inbound.finish() } - } onCancelOrGracefulShutdown: { + } onCancel: { Task { - try? await self.close(code: .normalClosure, outbound: webSocket.outbound, context: context) + webSocket.inbound.finish() + try await asyncChannel.channel.close(mode: .input) } } } From b5c5b2d7a3d2de544250d7853d1b678597bb72c7 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Fri, 22 Mar 2024 08:32:04 +0000 Subject: [PATCH 4/7] Finish websocket outbound after close, improve logging --- Snippets/WebsocketTest.swift | 6 ++- .../WebSocketHandler.swift | 5 ++- .../WebSocketOutboundWriter.swift | 4 ++ .../WebSocketTests.swift | 41 +++++++++++++++---- 4 files changed, 44 insertions(+), 12 deletions(-) diff --git a/Snippets/WebsocketTest.swift b/Snippets/WebsocketTest.swift index 78c7dae..b0ef9dc 100644 --- a/Snippets/WebsocketTest.swift +++ b/Snippets/WebsocketTest.swift @@ -1,7 +1,10 @@ import HTTPTypes import Hummingbird import HummingbirdWebSocket +import Logging +var logger = Logger(label: "Echo") +logger.logLevel = .trace let router = Router(context: BasicWebSocketRequestContext.self) router.middlewares.add(FileMiddleware("Snippets/public")) router.get { _, _ in @@ -19,6 +22,7 @@ router.ws("/ws") { ws, _ in let app = Application( router: router, - server: .webSocketUpgrade(webSocketRouter: router) + server: .webSocketUpgrade(webSocketRouter: router), + logger: logger ) try await app.runService() diff --git a/Sources/HummingbirdWebSocket/WebSocketHandler.swift b/Sources/HummingbirdWebSocket/WebSocketHandler.swift index e5b68b4..40e1d42 100644 --- a/Sources/HummingbirdWebSocket/WebSocketHandler.swift +++ b/Sources/HummingbirdWebSocket/WebSocketHandler.swift @@ -53,7 +53,7 @@ actor WebSocketHandler: Sendable { var frameSequence: WebSocketFrameSequence? for try await frame in inbound { do { - print("\(self.type): Received \(frame.opcode)") + context.logger.trace("Received \(frame.opcode)") switch frame.opcode { case .connectionClose: // we received a connection close. Finish the inbound data stream, @@ -124,7 +124,7 @@ actor WebSocketHandler: Sendable { } } } - print("\(self.type): Really Closed") + context.logger.debug("Closed WebSocket") } /// Respond to ping @@ -184,6 +184,7 @@ actor WebSocketHandler: Sendable { var buffer = context.allocator.buffer(capacity: 2) buffer.write(webSocketErrorCode: code) try await outbound.write(frame: .init(fin: true, opcode: .connectionClose, data: buffer)) + outbound.finish() } } diff --git a/Sources/HummingbirdWebSocket/WebSocketOutboundWriter.swift b/Sources/HummingbirdWebSocket/WebSocketOutboundWriter.swift index d3b9d4e..f9e5e5c 100644 --- a/Sources/HummingbirdWebSocket/WebSocketOutboundWriter.swift +++ b/Sources/HummingbirdWebSocket/WebSocketOutboundWriter.swift @@ -62,6 +62,10 @@ public struct WebSocketOutboundWriter: Sendable { try await self.outbound.write(frame) } + func finish() { + self.outbound.finish() + } + /// Make mask key to be used in WebSocket frame private func makeMaskKey() -> WebSocketMaskingKey? { guard self.type == .client else { return nil } diff --git a/Tests/HummingbirdWebSocketTests/WebSocketTests.swift b/Tests/HummingbirdWebSocketTests/WebSocketTests.swift index f03de62..66b07f4 100644 --- a/Tests/HummingbirdWebSocketTests/WebSocketTests.swift +++ b/Tests/HummingbirdWebSocketTests/WebSocketTests.swift @@ -82,8 +82,13 @@ final class HummingbirdWebSocketTests: XCTestCase { ) async throws { try await withThrowingTaskGroup(of: Void.self) { group in let promise = Promise() - let logger = { - var logger = Logger(label: "WebSocketTest") + let serverLogger = { + var logger = Logger(label: "WebSocketServer") + logger.logLevel = .debug + return logger + }() + let clientLogger = { + var logger = Logger(label: "WebSocketClient") logger.logLevel = .debug return logger }() @@ -101,7 +106,7 @@ final class HummingbirdWebSocketTests: XCTestCase { router: router, server: .tls(webSocketUpgrade, tlsConfiguration: serverTLSConfiguration), onServerRunning: { channel in await promise.complete(channel.localAddress!.port!) }, - logger: logger + logger: serverLogger ) serviceGroup = ServiceGroup( configuration: .init( @@ -115,7 +120,7 @@ final class HummingbirdWebSocketTests: XCTestCase { router: router, server: webSocketUpgrade, onServerRunning: { channel in await promise.complete(channel.localAddress!.port!) }, - logger: logger + logger: serverLogger ) serviceGroup = ServiceGroup( configuration: .init( @@ -129,7 +134,7 @@ final class HummingbirdWebSocketTests: XCTestCase { try await serviceGroup.run() } group.addTask { - let client = try await getClient(promise.wait(), logger) + let client = try await getClient(promise.wait(), clientLogger) try await client.run() } do { @@ -169,8 +174,13 @@ final class HummingbirdWebSocketTests: XCTestCase { ) async throws { try await withThrowingTaskGroup(of: Void.self) { group in let promise = Promise() - let logger = { - var logger = Logger(label: "WebSocketTest") + let serverLogger = { + var logger = Logger(label: "WebSocketServer") + logger.logLevel = .debug + return logger + }() + let clientLogger = { + var logger = Logger(label: "WebSocketClient") logger.logLevel = .debug return logger }() @@ -180,7 +190,7 @@ final class HummingbirdWebSocketTests: XCTestCase { router: router, server: .webSocketUpgrade(webSocketRouter: webSocketRouter), onServerRunning: { channel in await promise.complete(channel.localAddress!.port!) }, - logger: logger + logger: serverLogger ) serviceGroup = ServiceGroup( configuration: .init( @@ -193,7 +203,7 @@ final class HummingbirdWebSocketTests: XCTestCase { try await serviceGroup.run() } group.addTask { - let client = try await getClient(promise.wait(), logger) + let client = try await getClient(promise.wait(), clientLogger) try await client.run() } do { @@ -567,3 +577,16 @@ final class HummingbirdWebSocketTests: XCTestCase { */ } + +extension Logger { + /// Create new Logger with additional metadata value + /// - Parameters: + /// - metadataKey: Metadata key + /// - value: Metadata value + /// - Returns: Logger + func with(metadataKey: String, value: MetadataValue) -> Logger { + var logger = self + logger[metadataKey: metadataKey] = value + return logger + } +} From b6b959d9761afb27b7f15384057e66c403535c7a Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Fri, 22 Mar 2024 09:07:58 +0000 Subject: [PATCH 5/7] Clean up close because of error --- .../HummingbirdWebSocket/WebSocketHandler.swift | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/Sources/HummingbirdWebSocket/WebSocketHandler.swift b/Sources/HummingbirdWebSocket/WebSocketHandler.swift index 40e1d42..3afe300 100644 --- a/Sources/HummingbirdWebSocket/WebSocketHandler.swift +++ b/Sources/HummingbirdWebSocket/WebSocketHandler.swift @@ -86,11 +86,10 @@ actor WebSocketHandler: Sendable { await webSocket.inbound.send(frameSeq.data) frameSequence = nil } - } catch let error as NIOWebSocketError { - let errorCode = WebSocketErrorCode(error) - try await self.close(code: errorCode, outbound: webSocket.outbound, context: context) } catch { - let errorCode = WebSocketErrorCode.unexpectedServerError + // catch errors while processing websocket frames so responding close message + // can be dealt with + let errorCode = WebSocketErrorCode(error) try await self.close(code: errorCode, outbound: webSocket.outbound, context: context) } } @@ -189,13 +188,15 @@ actor WebSocketHandler: Sendable { } extension WebSocketErrorCode { - init(_ error: NIOWebSocketError) { + init(_ error: any Error) { switch error { - case .invalidFrameLength: + case NIOWebSocketError.invalidFrameLength: self = .messageTooLarge - case .fragmentedControlFrame, - .multiByteControlFrameLength: + case NIOWebSocketError.fragmentedControlFrame, + NIOWebSocketError.multiByteControlFrameLength: self = .protocolError + default: + self = .unexpectedServerError } } } From 5140f093b45b2268667e8429cd268b1f8124b29b Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Fri, 22 Mar 2024 10:37:10 +0000 Subject: [PATCH 6/7] Update after rebase --- Tests/HummingbirdWebSocketTests/WebSocketTests.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Tests/HummingbirdWebSocketTests/WebSocketTests.swift b/Tests/HummingbirdWebSocketTests/WebSocketTests.swift index 66b07f4..bb4ca0f 100644 --- a/Tests/HummingbirdWebSocketTests/WebSocketTests.swift +++ b/Tests/HummingbirdWebSocketTests/WebSocketTests.swift @@ -403,7 +403,7 @@ final class HummingbirdWebSocketTests: XCTestCase { let router = Router(context: BasicWebSocketRequestContext.self) router.ws("/ws1") { _, _ in return .upgrade([:]) - } handle: { ws, _ in + } onUpgrade: { ws, _ in try await ws.outbound.write(.text("One")) } router.ws("/ws2") { _, _ in From fa31f9075376ed51a96f436dd24700e5b164b334 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Fri, 22 Mar 2024 15:44:18 +0000 Subject: [PATCH 7/7] Split WebSocket back into two --- Snippets/WebsocketTest.swift | 6 +- Sources/HummingbirdWebSocket/WebSocket.swift | 39 ------ .../WebSocketDataHandler.swift | 2 +- .../WebSocketHandler.swift | 45 ++++--- .../WebSocketOutboundWriter.swift | 2 +- .../WebSocketTests.swift | 124 +++++++++--------- 6 files changed, 95 insertions(+), 123 deletions(-) delete mode 100644 Sources/HummingbirdWebSocket/WebSocket.swift diff --git a/Snippets/WebsocketTest.swift b/Snippets/WebsocketTest.swift index b0ef9dc..1c2ac3a 100644 --- a/Snippets/WebsocketTest.swift +++ b/Snippets/WebsocketTest.swift @@ -11,12 +11,12 @@ router.get { _, _ in "Hello" } -router.ws("/ws") { ws, _ in - for try await packet in ws.inbound { +router.ws("/ws") { inbound, outbound, _ in + for try await packet in inbound { if case .text("disconnect") = packet { break } - try await ws.outbound.write(.custom(packet.webSocketFrame)) + try await outbound.write(.custom(packet.webSocketFrame)) } } diff --git a/Sources/HummingbirdWebSocket/WebSocket.swift b/Sources/HummingbirdWebSocket/WebSocket.swift deleted file mode 100644 index 9383486..0000000 --- a/Sources/HummingbirdWebSocket/WebSocket.swift +++ /dev/null @@ -1,39 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the Hummingbird server framework project -// -// Copyright (c) 2023-2024 the Hummingbird authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// - -import NIOCore -import NIOWebSocket - -/// WebSocket -public struct WebSocket: Sendable { - enum SocketType: Sendable { - case client - case server - } - - /// Inbound stream type - public typealias Inbound = WebSocketInboundStream - /// Outbound writer type - public typealias Outbound = WebSocketOutboundWriter - - /// Inbound stream - public let inbound: Inbound - /// Outbound writer - public let outbound: Outbound - - init(type: SocketType, outbound: NIOAsyncChannelOutboundWriter, allocator: ByteBufferAllocator) { - self.inbound = .init() - self.outbound = .init(type: type, allocator: allocator, outbound: outbound) - } -} diff --git a/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift b/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift index 8bf402d..6bc9a7a 100644 --- a/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift +++ b/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift @@ -19,4 +19,4 @@ import NIOCore import NIOWebSocket /// Function that handles websocket data and text blocks -public typealias WebSocketDataHandler = @Sendable (WebSocket, Context) async throws -> Void +public typealias WebSocketDataHandler = @Sendable (WebSocketInboundStream, WebSocketOutboundWriter, Context) async throws -> Void diff --git a/Sources/HummingbirdWebSocket/WebSocketHandler.swift b/Sources/HummingbirdWebSocket/WebSocketHandler.swift index 3afe300..ed1bc61 100644 --- a/Sources/HummingbirdWebSocket/WebSocketHandler.swift +++ b/Sources/HummingbirdWebSocket/WebSocketHandler.swift @@ -18,6 +18,12 @@ import NIOCore import NIOWebSocket import ServiceLifecycle +/// WebSocket type +public enum WebSocketType: Sendable { + case client + case server +} + /// Handler processing raw WebSocket packets. /// /// Manages ping, pong and close messages. Collates data and text messages into final frame @@ -26,11 +32,11 @@ actor WebSocketHandler: Sendable { static let pingDataSize = 16 let asyncChannel: NIOAsyncChannel - let type: WebSocket.SocketType + let type: WebSocketType var closed: Bool var pingData: ByteBuffer - init(asyncChannel: NIOAsyncChannel, type: WebSocket.SocketType) { + init(asyncChannel: NIOAsyncChannel, type: WebSocketType) { self.asyncChannel = asyncChannel self.type = type self.pingData = ByteBufferAllocator().buffer(capacity: Self.pingDataSize) @@ -41,13 +47,18 @@ actor WebSocketHandler: Sendable { func handle(handler: @escaping WebSocketDataHandler, context: Context) async { let asyncChannel = self.asyncChannel try? await asyncChannel.executeThenClose { inbound, outbound in - let webSocket = WebSocket(type: self.type, outbound: outbound, allocator: asyncChannel.channel.allocator) + let webSocketInbound = WebSocketInboundStream() + let webSocketOutbound = WebSocketOutboundWriter( + type: self.type, + allocator: asyncChannel.channel.allocator, + outbound: outbound + ) try await withTaskCancellationHandler { try await withGracefulShutdownHandler { try await withThrowingTaskGroup(of: Void.self) { group in group.addTask { defer { - webSocket.inbound.finish() + webSocketInbound.finish() } // parse messages coming from inbound var frameSequence: WebSocketFrameSequence? @@ -58,13 +69,13 @@ actor WebSocketHandler: Sendable { case .connectionClose: // we received a connection close. Finish the inbound data stream, // send a close back if it hasn't already been send and exit - webSocket.inbound.finish() - _ = try await self.close(code: .normalClosure, outbound: webSocket.outbound, context: context) + webSocketInbound.finish() + _ = try await self.close(code: .normalClosure, outbound: webSocketOutbound, context: context) return case .ping: - try await self.onPing(frame, outbound: webSocket.outbound, context: context) + try await self.onPing(frame, outbound: webSocketOutbound, context: context) case .pong: - try await self.onPong(frame, outbound: webSocket.outbound, context: context) + try await self.onPong(frame, outbound: webSocketOutbound, context: context) case .text, .binary: if var frameSeq = frameSequence { frameSeq.append(frame) @@ -77,48 +88,48 @@ actor WebSocketHandler: Sendable { frameSeq.append(frame) frameSequence = frameSeq } else { - try await self.close(code: .protocolError, outbound: webSocket.outbound, context: context) + try await self.close(code: .protocolError, outbound: webSocketOutbound, context: context) } default: break } if let frameSeq = frameSequence, frame.fin { - await webSocket.inbound.send(frameSeq.data) + await webSocketInbound.send(frameSeq.data) frameSequence = nil } } catch { // catch errors while processing websocket frames so responding close message // can be dealt with let errorCode = WebSocketErrorCode(error) - try await self.close(code: errorCode, outbound: webSocket.outbound, context: context) + try await self.close(code: errorCode, outbound: webSocketOutbound, context: context) } } } group.addTask { do { // handle websocket data and text - try await handler(webSocket, context) - try await self.close(code: .normalClosure, outbound: webSocket.outbound, context: context) + try await handler(webSocketInbound, webSocketOutbound, context) + try await self.close(code: .normalClosure, outbound: webSocketOutbound, context: context) } catch { if self.type == .server { let errorCode = WebSocketErrorCode.unexpectedServerError - try await self.close(code: errorCode, outbound: webSocket.outbound, context: context) + try await self.close(code: errorCode, outbound: webSocketOutbound, context: context) } else { try await asyncChannel.channel.close(mode: .input) } } } try await group.next() - webSocket.inbound.finish() + webSocketInbound.finish() } } onGracefulShutdown: { Task { - try? await self.close(code: .normalClosure, outbound: webSocket.outbound, context: context) + try? await self.close(code: .normalClosure, outbound: webSocketOutbound, context: context) } } } onCancel: { Task { - webSocket.inbound.finish() + webSocketInbound.finish() try await asyncChannel.channel.close(mode: .input) } } diff --git a/Sources/HummingbirdWebSocket/WebSocketOutboundWriter.swift b/Sources/HummingbirdWebSocket/WebSocketOutboundWriter.swift index f9e5e5c..4d334ee 100644 --- a/Sources/HummingbirdWebSocket/WebSocketOutboundWriter.swift +++ b/Sources/HummingbirdWebSocket/WebSocketOutboundWriter.swift @@ -29,7 +29,7 @@ public struct WebSocketOutboundWriter: Sendable { case custom(WebSocketFrame) } - let type: WebSocket.SocketType + let type: WebSocketType let allocator: ByteBufferAllocator let outbound: NIOAsyncChannelOutboundWriter diff --git a/Tests/HummingbirdWebSocketTests/WebSocketTests.swift b/Tests/HummingbirdWebSocketTests/WebSocketTests.swift index bb4ca0f..780613c 100644 --- a/Tests/HummingbirdWebSocketTests/WebSocketTests.swift +++ b/Tests/HummingbirdWebSocketTests/WebSocketTests.swift @@ -219,37 +219,37 @@ final class HummingbirdWebSocketTests: XCTestCase { // MARK: Tests func testServerToClientMessage() async throws { - try await self.testClientAndServer { ws, _ in - try await ws.outbound.write(.text("Hello")) - } client: { ws, _ in - var inboundIterator = ws.inbound.makeAsyncIterator() + try await self.testClientAndServer { _, outbound, _ in + try await outbound.write(.text("Hello")) + } client: { inbound, _, _ in + var inboundIterator = inbound.makeAsyncIterator() let msg = await inboundIterator.next() XCTAssertEqual(msg, .text("Hello")) } } func testClientToServerMessage() async throws { - try await self.testClientAndServer { ws, _ in - var inboundIterator = ws.inbound.makeAsyncIterator() + try await self.testClientAndServer { inbound, _, _ in + var inboundIterator = inbound.makeAsyncIterator() let msg = await inboundIterator.next() XCTAssertEqual(msg, .text("Hello")) - } client: { ws, _ in - try await ws.outbound.write(.text("Hello")) + } client: { _, outbound, _ in + try await outbound.write(.text("Hello")) } } func testClientToServerSplitPacket() async throws { - try await self.testClientAndServer { ws, _ in - for try await packet in ws.inbound { - try await ws.outbound.write(.custom(packet.webSocketFrame)) + try await self.testClientAndServer { inbound, outbound, _ in + for try await packet in inbound { + try await outbound.write(.custom(packet.webSocketFrame)) } - } client: { ws, _ in + } client: { inbound, outbound, _ in let buffer = ByteBuffer(string: "Hello ") - try await ws.outbound.write(.custom(.init(fin: false, opcode: .text, data: buffer))) + try await outbound.write(.custom(.init(fin: false, opcode: .text, data: buffer))) let buffer2 = ByteBuffer(string: "World!") - try await ws.outbound.write(.custom(.init(fin: true, opcode: .text, data: buffer2))) + try await outbound.write(.custom(.init(fin: true, opcode: .text, data: buffer2))) - var inboundIterator = ws.inbound.makeAsyncIterator() + var inboundIterator = inbound.makeAsyncIterator() let msg = await inboundIterator.next() XCTAssertEqual(msg, .text("Hello World!")) } @@ -257,23 +257,23 @@ final class HummingbirdWebSocketTests: XCTestCase { // test connection is closed when buffer is too large func testTooLargeBuffer() async throws { - try await self.testClientAndServer { ws, _ in + try await self.testClientAndServer { inbound, outbound, _ in let buffer = ByteBuffer(repeating: 1, count: (1 << 14) + 1) - try await ws.outbound.write(.binary(buffer)) - for try await _ in ws.inbound {} - } client: { ws, _ in - for try await _ in ws.inbound {} + try await outbound.write(.binary(buffer)) + for try await _ in inbound {} + } client: { inbound, _, _ in + for try await _ in inbound {} } } func testNotWebSocket() async throws { do { - try await self.testClientAndServer { ws, _ in - for try await _ in ws.inbound {} + try await self.testClientAndServer { inbound, _, _ in + for try await _ in inbound {} } shouldUpgrade: { _ in return nil - } client: { ws, _ in - for try await _ in ws.inbound {} + } client: { inbound, _, _ in + for try await _ in inbound {} } } catch let error as WebSocketClientError where error == .webSocketUpgradeFailed {} } @@ -282,7 +282,7 @@ final class HummingbirdWebSocketTests: XCTestCase { let client = try WebSocketClient( url: .init("ws://localhost:10245"), logger: Logger(label: "TestNoConnection") - ) { _, _ in + ) { _, _, _ in } do { try await client.run() @@ -291,8 +291,8 @@ final class HummingbirdWebSocketTests: XCTestCase { } func testTLS() async throws { - try await self.testClientAndServer(serverTLSConfiguration: getServerTLSConfiguration()) { ws, _ in - try await ws.outbound.write(.text("Hello")) + try await self.testClientAndServer(serverTLSConfiguration: getServerTLSConfiguration()) { _, outbound, _ in + try await outbound.write(.text("Hello")) } getClient: { port, logger in var clientTLSConfiguration = try getClientTLSConfiguration() clientTLSConfiguration.certificateVerification = .none @@ -300,8 +300,8 @@ final class HummingbirdWebSocketTests: XCTestCase { url: .init("wss://localhost:\(port)"), tlsConfiguration: clientTLSConfiguration, logger: logger - ) { ws, _ in - var inboundIterator = ws.inbound.makeAsyncIterator() + ) { inbound, _, _ in + var inboundIterator = inbound.makeAsyncIterator() let msg = await inboundIterator.next() XCTAssertEqual(msg, .text("Hello")) } @@ -309,8 +309,8 @@ final class HummingbirdWebSocketTests: XCTestCase { } func testURLPath() async throws { - try await self.testClientAndServer { ws, _ in - for try await _ in ws.inbound {} + try await self.testClientAndServer { inbound, _, _ in + for try await _ in inbound {} } shouldUpgrade: { head in XCTAssertEqual(head.path, "/ws") return [:] @@ -318,14 +318,14 @@ final class HummingbirdWebSocketTests: XCTestCase { try WebSocketClient( url: .init("ws://localhost:\(port)/ws"), logger: logger - ) { _, _ in + ) { _, _, _ in } } } func testQueryParameters() async throws { - try await self.testClientAndServer { ws, _ in - for try await _ in ws.inbound {} + try await self.testClientAndServer { inbound, _, _ in + for try await _ in inbound {} } shouldUpgrade: { head in let request = Request(head: head, body: .init(buffer: ByteBuffer())) XCTAssertEqual(request.uri.query, "query=parameters&test=true") @@ -334,14 +334,14 @@ final class HummingbirdWebSocketTests: XCTestCase { try WebSocketClient( url: .init("ws://localhost:\(port)/ws?query=parameters&test=true"), logger: logger - ) { _, _ in + ) { _, _, _ in } } } func testAdditionalHeaders() async throws { - try await self.testClientAndServer { ws, _ in - for try await _ in ws.inbound {} + try await self.testClientAndServer { inbound, _, _ in + for try await _ in inbound {} } shouldUpgrade: { head in let request = Request(head: head, body: .init(buffer: ByteBuffer())) XCTAssertEqual(request.headers[.secWebSocketExtensions], "hb") @@ -351,7 +351,7 @@ final class HummingbirdWebSocketTests: XCTestCase { url: .init("ws://localhost:\(port)/ws?query=parameters&test=true"), configuration: .init(additionalHeaders: [.secWebSocketExtensions: "hb"]), logger: logger - ) { _, _ in + ) { _, _, _ in } } } @@ -370,8 +370,8 @@ final class HummingbirdWebSocketTests: XCTestCase { let app = Application( router: router, server: .webSocketUpgrade { _, _, _ in - return .upgrade([:]) { ws, _ in - try await ws.outbound.write(.text("Hello")) + return .upgrade([:]) { _, outbound, _ in + try await outbound.write(.text("Hello")) } }, onServerRunning: { channel in await promise.complete(channel.localAddress!.port!) }, @@ -388,8 +388,8 @@ final class HummingbirdWebSocketTests: XCTestCase { try await serviceGroup.run() } group.addTask { - try await WebSocketClient.connect(url: .init("ws://localhost:\(promise.wait())/ws"), logger: logger) { ws, _ in - var inboundIterator = ws.inbound.makeAsyncIterator() + try await WebSocketClient.connect(url: .init("ws://localhost:\(promise.wait())/ws"), logger: logger) { inbound, _, _ in + var inboundIterator = inbound.makeAsyncIterator() let msg = await inboundIterator.next() XCTAssertEqual(msg, .text("Hello")) } @@ -403,24 +403,24 @@ final class HummingbirdWebSocketTests: XCTestCase { let router = Router(context: BasicWebSocketRequestContext.self) router.ws("/ws1") { _, _ in return .upgrade([:]) - } onUpgrade: { ws, _ in - try await ws.outbound.write(.text("One")) + } onUpgrade: { _, outbound, _ in + try await outbound.write(.text("One")) } router.ws("/ws2") { _, _ in return .upgrade([:]) - } onUpgrade: { ws, _ in - try await ws.outbound.write(.text("Two")) + } onUpgrade: { _, outbound, _ in + try await outbound.write(.text("Two")) } try await self.testClientAndServerWithRouter(webSocketRouter: router, uri: "localhost:8080") { port, logger in - try WebSocketClient(url: .init("ws://localhost:\(port)/ws1"), logger: logger) { ws, _ in - var inboundIterator = ws.inbound.makeAsyncIterator() + try WebSocketClient(url: .init("ws://localhost:\(port)/ws1"), logger: logger) { inbound, _, _ in + var inboundIterator = inbound.makeAsyncIterator() let msg = await inboundIterator.next() XCTAssertEqual(msg, .text("One")) } } try await self.testClientAndServerWithRouter(webSocketRouter: router, uri: "localhost:8080") { port, logger in - try WebSocketClient(url: .init("ws://localhost:\(port)/ws2"), logger: logger) { ws, _ in - var inboundIterator = ws.inbound.makeAsyncIterator() + try WebSocketClient(url: .init("ws://localhost:\(port)/ws2"), logger: logger) { inbound, _, _ in + var inboundIterator = inbound.makeAsyncIterator() let msg = await inboundIterator.next() XCTAssertEqual(msg, .text("Two")) } @@ -432,13 +432,13 @@ final class HummingbirdWebSocketTests: XCTestCase { router.group("/ws") .add(middleware: WebSocketUpgradeMiddleware { _, _ in return .upgrade([:]) - } onUpgrade: { ws, _ in - try await ws.outbound.write(.text("One")) + } onUpgrade: { _, outbound, _ in + try await outbound.write(.text("One")) }) .get { _, _ -> Response in return .init(status: .ok) } do { try await self.testClientAndServerWithRouter(webSocketRouter: router, uri: "localhost:8080") { port, logger in - try WebSocketClient(url: .init("ws://localhost:\(port)/ws"), logger: logger) { _, _ in } + try WebSocketClient(url: .init("ws://localhost:\(port)/ws"), logger: logger) { _, _, _ in } } } } @@ -447,12 +447,12 @@ final class HummingbirdWebSocketTests: XCTestCase { let router = Router(context: BasicWebSocketRequestContext.self) router.ws("/ws") { _, _ in return .upgrade([:]) - } onUpgrade: { ws, _ in - try await ws.outbound.write(.text("One")) + } onUpgrade: { _, outbound, _ in + try await outbound.write(.text("One")) } do { try await self.testClientAndServerWithRouter(webSocketRouter: router, uri: "localhost:8080") { port, logger in - try WebSocketClient(url: .init("ws://localhost:\(port)/not-ws"), logger: logger) { _, _ in } + try WebSocketClient(url: .init("ws://localhost:\(port)/not-ws"), logger: logger) { _, _, _ in } } } catch let error as WebSocketClientError where error == .webSocketUpgradeFailed {} } @@ -481,13 +481,13 @@ final class HummingbirdWebSocketTests: XCTestCase { router.middlewares.add(MyMiddleware()) router.ws("/ws") { _, _ in return .upgrade([:]) - } onUpgrade: { ws, context in - try await ws.outbound.write(.text(context.name)) + } onUpgrade: { _, outbound, context in + try await outbound.write(.text(context.name)) } do { try await self.testClientAndServerWithRouter(webSocketRouter: router, uri: "localhost:8080") { port, logger in - try WebSocketClient(url: .init("ws://localhost:\(port)/ws"), logger: logger) { ws, _ in - let text = await ws.inbound.first { _ in true } + try WebSocketClient(url: .init("ws://localhost:\(port)/ws"), logger: logger) { inbound, _, _ in + let text = await inbound.first { _ in true } XCTAssertEqual(text, .text("Roger Moore")) } } @@ -498,8 +498,8 @@ final class HummingbirdWebSocketTests: XCTestCase { let router = Router(context: BasicWebSocketRequestContext.self) router.ws("/ws") { _, _ in return .upgrade([:]) - } onUpgrade: { ws, _ in - try await ws.outbound.write(.text("Hello")) + } onUpgrade: { _, outbound, _ in + try await outbound.write(.text("Hello")) } router.get("/http") { _, _ in return "Hello"