From 20559a9b8bb292684a77ab0198ade2ef2bcd3fbc Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Fri, 22 Mar 2024 14:29:37 +0000 Subject: [PATCH 1/9] Add permessage deflate code in --- Package.swift | 10 +- .../PerMessageDeflateExtension.swift | 10 +- .../WebSocketExtension.swift | 161 ++++++++++++++++++ 3 files changed, 171 insertions(+), 10 deletions(-) create mode 100644 Sources/HummingbirdWebSocket/WebSocketExtension.swift diff --git a/Package.swift b/Package.swift index 911fe6e..680eb3b 100644 --- a/Package.swift +++ b/Package.swift @@ -32,13 +32,13 @@ let package = Package( .product(name: "NIOHTTPTypesHTTP1", package: "swift-nio-extras"), .product(name: "NIOWebSocket", package: "swift-nio"), ]), - /* .target(name: "HummingbirdWSCompression", dependencies: [ - .byName(name: "HummingbirdWSCore"), - .product(name: "CompressNIO", package: "compress-nio"), - ]),*/ + .target(name: "HummingbirdWSCompression", dependencies: [ + .byName(name: "HummingbirdWebSocket"), + .product(name: "CompressNIO", package: "compress-nio"), + ]), .testTarget(name: "HummingbirdWebSocketTests", dependencies: [ .byName(name: "HummingbirdWebSocket"), - // .byName(name: "HummingbirdWSCompression"), + .byName(name: "HummingbirdWSCompression"), .product(name: "Atomics", package: "swift-atomics"), .product(name: "Hummingbird", package: "hummingbird"), .product(name: "HummingbirdTesting", package: "hummingbird"), diff --git a/Sources/HummingbirdWSCompression/PerMessageDeflateExtension.swift b/Sources/HummingbirdWSCompression/PerMessageDeflateExtension.swift index 4e9efc3..4078e8b 100644 --- a/Sources/HummingbirdWSCompression/PerMessageDeflateExtension.swift +++ b/Sources/HummingbirdWSCompression/PerMessageDeflateExtension.swift @@ -13,7 +13,7 @@ //===----------------------------------------------------------------------===// import CompressNIO -import HummingbirdWSCore +import HummingbirdWebSocket import NIOCore import NIOWebSocket @@ -201,7 +201,7 @@ struct PerMessageDeflateExtension: WebSocketExtension { self.internalState.value.shutdown() } - func processReceivedFrame(_ frame: WebSocketFrame, ws: WebSocket) throws -> WebSocketFrame { + func processReceivedFrame(_ frame: WebSocketFrame, context: some WebSocketContextProtocol) throws -> WebSocketFrame { var frame = frame if frame.rsv1 { let state = self.internalState.value @@ -209,7 +209,7 @@ struct PerMessageDeflateExtension: WebSocketExtension { // Reinstate last four bytes 0x00 0x00 0xff 0xff that were removed in the frame // send (see https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.2). frame.data.writeBytes([0, 0, 255, 255]) - frame.data = try frame.data.decompressStream(with: state.decompressor, maxSize: ws.maxFrameSize, allocator: ws.channel.allocator) + frame.data = try frame.data.decompressStream(with: state.decompressor, maxSize: ws.maxFrameSize, allocator: context.allocator) if self.configuration.receiveNoContextTakeover { try state.decompressor.resetStream() } @@ -217,7 +217,7 @@ struct PerMessageDeflateExtension: WebSocketExtension { return frame } - func processFrameToSend(_ frame: WebSocketFrame, ws: WebSocket) throws -> WebSocketFrame { + func processFrameToSend(_ frame: WebSocketFrame, context: some WebSocketContextProtocol) throws -> WebSocketFrame { let state = self.internalState.value // if the frame is larger than 16 bytes, we haven't received a final frame or we are in the process of sending a message // compress the data @@ -229,7 +229,7 @@ struct PerMessageDeflateExtension: WebSocketExtension { newFrame.rsv1 = true state.sendState = .sendingMessage } - newFrame.data = try newFrame.data.compressStream(with: state.compressor, flush: .sync, allocator: ws.channel.allocator) + newFrame.data = try newFrame.data.compressStream(with: state.compressor, flush: .sync, allocator: context.allocator) // if final frame then remove last four bytes 0x00 0x00 0xff 0xff // (see https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.1) if newFrame.fin { diff --git a/Sources/HummingbirdWebSocket/WebSocketExtension.swift b/Sources/HummingbirdWebSocket/WebSocketExtension.swift new file mode 100644 index 0000000..d9dab43 --- /dev/null +++ b/Sources/HummingbirdWebSocket/WebSocketExtension.swift @@ -0,0 +1,161 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird server framework project +// +// Copyright (c) 2023 the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import HTTPTypes +import NIOCore +import NIOWebSocket + +/// Protocol for WebSocket extension +public protocol WebSocketExtension: Sendable { + /// Process frame received from websocket + func processReceivedFrame(_ frame: WebSocketFrame, context: some WebSocketContextProtocol) throws -> WebSocketFrame + /// Process frame about to be sent to websocket + func processFrameToSend(_ frame: WebSocketFrame, context: some WebSocketContextProtocol) throws -> WebSocketFrame + /// shutdown extension + func shutdown() +} + +/// Protocol for WebSocket extension builder +public protocol WebSocketExtensionBuilder { + /// name of WebSocket extension name + static var name: String { get } + /// construct client request header + func clientRequestHeader() -> String + /// construct server response header based of client request + func serverReponseHeader(to: WebSocketExtensionHTTPParameters) -> String? + /// construct server version of extension based of client request + func serverExtension(from: WebSocketExtensionHTTPParameters, eventLoop: EventLoop) throws -> (any WebSocketExtension)? + /// construct client version of extension based of server response + func clientExtension(from: WebSocketExtensionHTTPParameters, eventLoop: EventLoop) throws -> (any WebSocketExtension)? +} + +extension WebSocketExtensionBuilder { + /// construct server response header based of all client requests + public func serverResponseHeader(to requests: [WebSocketExtensionHTTPParameters]) -> String? { + for request in requests { + guard request.name == Self.name else { continue } + if let response = serverReponseHeader(to: request) { + return response + } + } + return nil + } + + /// construct all server extensions based of all client requests + public func serverExtension(from requests: [WebSocketExtensionHTTPParameters], eventLoop: EventLoop) throws -> (any WebSocketExtension)? { + for request in requests { + guard request.name == Self.name else { continue } + if let ext = try serverExtension(from: request, eventLoop: eventLoop) { + return ext + } + } + return nil + } + + /// construct all client extensions based of all server responses + public func clientExtension(from requests: [WebSocketExtensionHTTPParameters], eventLoop: EventLoop) throws -> (any WebSocketExtension)? { + for request in requests { + guard request.name == Self.name else { continue } + if let ext = try clientExtension(from: request, eventLoop: eventLoop) { + return ext + } + } + return nil + } +} + +/// Build WebSocket extension builder +public struct WebSocketExtensionFactory: Sendable { + public let build: @Sendable () -> any WebSocketExtensionBuilder + + public init(_ build: @escaping @Sendable () -> any WebSocketExtensionBuilder) { + self.build = build + } +} + +/// Parsed parameters from `Sec-WebSocket-Extensions` header +public struct WebSocketExtensionHTTPParameters: Sendable, Equatable { + /// A single parameter + public enum Parameter: Sendable, Equatable { + // Parameter with a value + case value(String) + // Parameter with no value + case null + + // Convert to optional + public var optional: String? { + switch self { + case .value(let string): + return .some(string) + case .null: + return .none + } + } + + // Convert to integer + public var integer: Int? { + switch self { + case .value(let string): + return Int(string) + case .null: + return .none + } + } + } + + public let parameters: [String: Parameter] + let name: String + + /// initialise WebSocket extension parameters from string + init?(from header: some StringProtocol) { + let split = header.split(separator: ";", omittingEmptySubsequences: true).map { $0.trimmingCharacters(in: .whitespacesAndNewlines) }[...] + if let name = split.first { + self.name = name + } else { + return nil + } + var index = split.index(after: split.startIndex) + var parameters: [String: Parameter] = [:] + while index != split.endIndex { + let keyValue = split[index].split(separator: "=", maxSplits: 1).map { $0.trimmingCharacters(in: .whitespacesAndNewlines) } + if let key = keyValue.first { + if keyValue.count > 1 { + parameters[key] = .value(keyValue[1]) + } else { + parameters[key] = .null + } + } + index = split.index(after: index) + } + self.parameters = parameters + } + + /// Parse all `Sec-WebSocket-Extensions` header values + /// - Parameters: + /// - headers: headers coming from other + /// - type: client or server + /// - Returns: Array of extensions + public static func parseHeaders(_ headers: HTTPFields) -> [WebSocketExtensionHTTPParameters] { + let extHeaders = headers[values: .secWebSocketExtensions] + return extHeaders.compactMap { .init(from: $0) } + } +} + +extension WebSocketExtensionHTTPParameters { + /// Initialiser used by tests + init(_ name: String, parameters: [String: Parameter]) { + self.name = name + self.parameters = parameters + } +} From 32d6ff5116892c4e09bbd36bb53aa12951fe537c Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Fri, 22 Mar 2024 14:29:56 +0000 Subject: [PATCH 2/9] Make configuration generic parameter of WebSocketHandler --- .../Client/WebSocketClientConfiguration.swift | 4 ++- .../Server/WebSocketServerConfiguration.swift | 5 +++- .../WebSocketConfiguration.swift | 27 +++++++++++++++++++ 3 files changed, 34 insertions(+), 2 deletions(-) create mode 100644 Sources/HummingbirdWebSocket/WebSocketConfiguration.swift diff --git a/Sources/HummingbirdWebSocket/Client/WebSocketClientConfiguration.swift b/Sources/HummingbirdWebSocket/Client/WebSocketClientConfiguration.swift index d4085fe..eb596a6 100644 --- a/Sources/HummingbirdWebSocket/Client/WebSocketClientConfiguration.swift +++ b/Sources/HummingbirdWebSocket/Client/WebSocketClientConfiguration.swift @@ -14,11 +14,13 @@ import HTTPTypes -public struct WebSocketClientConfiguration: Sendable { +public struct WebSocketClientConfiguration: WebSocketConfiguration { /// Max websocket frame size that can be sent/received public var maxFrameSize: Int /// Additional headers to be sent with the initial HTTP request public var additionalHeaders: HTTPFields + /// WebSocket type + public var type: WebSocketType { .client } /// Initialize WebSocketClient configuration /// - Paramters diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketServerConfiguration.swift b/Sources/HummingbirdWebSocket/Server/WebSocketServerConfiguration.swift index 8d7e471..b921740 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketServerConfiguration.swift +++ b/Sources/HummingbirdWebSocket/Server/WebSocketServerConfiguration.swift @@ -13,9 +13,11 @@ //===----------------------------------------------------------------------===// /// Configuration for a WebSocket server -public struct WebSocketServerConfiguration: Sendable { +public struct WebSocketServerConfiguration: WebSocketConfiguration { /// Max websocket frame size that can be sent/received public var maxFrameSize: Int + /// WebSocket type + public var type: WebSocketType { .server } /// Initialize WebSocketClient configuration /// - Paramters @@ -26,4 +28,5 @@ public struct WebSocketServerConfiguration: Sendable { ) { self.maxFrameSize = maxFrameSize } + } diff --git a/Sources/HummingbirdWebSocket/WebSocketConfiguration.swift b/Sources/HummingbirdWebSocket/WebSocketConfiguration.swift new file mode 100644 index 0000000..8948640 --- /dev/null +++ b/Sources/HummingbirdWebSocket/WebSocketConfiguration.swift @@ -0,0 +1,27 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird server framework project +// +// Copyright (c) 2023-2024 the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +public enum WebSocketType: Sendable { + case client + case server +} + +/// Configuration for a WebSocket server +public protocol WebSocketConfiguration: Sendable { + /// Max WebSocket frame size that can be sent/received + var maxFrameSize: Int { get } + + // WebSocket type + var type: WebSocketType { get } +} From c15051de35ffba8d33b7ef531a8a0ae76c2f83c3 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Fri, 22 Mar 2024 19:06:49 +0000 Subject: [PATCH 3/9] WebSocket extension handshake --- .../PerMessageDeflateExtension.swift | 214 ++++++---- .../Client/WebSocketClientChannel.swift | 17 +- .../Client/WebSocketClientConfiguration.swift | 10 +- .../NIOWebSocketServerUpgrade+ext.swift | 5 +- .../Server/WebSocketChannel.swift | 35 +- .../Server/WebSocketRouter.swift | 11 +- .../Server/WebSocketServerConfiguration.swift | 11 +- .../WebSocketConfiguration.swift | 27 -- .../WebSocketExtension.swift | 8 +- .../WebSocketHandler.swift | 4 +- .../WebSocketExtensionTests.swift | 389 ++++++++++++------ 11 files changed, 467 insertions(+), 264 deletions(-) delete mode 100644 Sources/HummingbirdWebSocket/WebSocketConfiguration.swift diff --git a/Sources/HummingbirdWSCompression/PerMessageDeflateExtension.swift b/Sources/HummingbirdWSCompression/PerMessageDeflateExtension.swift index 4078e8b..98d7faa 100644 --- a/Sources/HummingbirdWSCompression/PerMessageDeflateExtension.swift +++ b/Sources/HummingbirdWSCompression/PerMessageDeflateExtension.swift @@ -27,6 +27,7 @@ struct PerMessageDeflateExtensionBuilder: WebSocketExtensionBuilder { let serverNoContextTakeover: Bool let compressionLevel: Int? let memoryLevel: Int? + let maxDecompressedFrameSize: Int init( clientMaxWindow: Int? = nil, @@ -34,7 +35,8 @@ struct PerMessageDeflateExtensionBuilder: WebSocketExtensionBuilder { serverMaxWindow: Int? = nil, serverNoContextTakeover: Bool = false, compressionLevel: Int? = nil, - memoryLevel: Int? = nil + memoryLevel: Int? = nil, + maxDecompressedFrameSize: Int = (1 << 14) ) { self.clientMaxWindow = clientMaxWindow self.clientNoContextTakeover = clientNoContextTakeover @@ -42,6 +44,7 @@ struct PerMessageDeflateExtensionBuilder: WebSocketExtensionBuilder { self.serverNoContextTakeover = serverNoContextTakeover self.compressionLevel = compressionLevel self.memoryLevel = memoryLevel + self.maxDecompressedFrameSize = maxDecompressedFrameSize } /// Return client request header @@ -107,7 +110,8 @@ struct PerMessageDeflateExtensionBuilder: WebSocketExtensionBuilder { sendMaxWindow: clientMaxWindowParam, sendNoContextTakeover: clientNoContextTakeoverParam, compressionLevel: self.compressionLevel, - memoryLevel: self.memoryLevel + memoryLevel: self.memoryLevel, + maxDecompressedFrameSize: self.maxDecompressedFrameSize ), eventLoop: eventLoop) } @@ -134,7 +138,8 @@ struct PerMessageDeflateExtensionBuilder: WebSocketExtensionBuilder { sendMaxWindow: optionalMin(requestServerMaxWindow?.integer, self.serverMaxWindow), sendNoContextTakeover: requestServerNoContextTakeover || self.serverNoContextTakeover, compressionLevel: self.compressionLevel, - memoryLevel: self.memoryLevel + memoryLevel: self.memoryLevel, + maxDecompressedFrameSize: self.maxDecompressedFrameSize ) } } @@ -144,11 +149,6 @@ struct PerMessageDeflateExtensionBuilder: WebSocketExtensionBuilder { /// Uses deflate to compress messages sent across a WebSocket /// See RFC 7692 for more details https://www.rfc-editor.org/rfc/rfc7692 struct PerMessageDeflateExtension: WebSocketExtension { - enum SendState: Sendable { - case idle - case sendingMessage - } - struct Configuration: Sendable { let receiveMaxWindow: Int? let receiveNoContextTakeover: Bool @@ -156,90 +156,157 @@ struct PerMessageDeflateExtension: WebSocketExtension { let sendNoContextTakeover: Bool let compressionLevel: Int? let memoryLevel: Int? + let maxDecompressedFrameSize: Int } - /// Internal mutable state and referenced types, that cannot be set to Sendable - class InternalState { + actor Decompressor { fileprivate let decompressor: any NIODecompressor + + init(_ decompressor: any NIODecompressor) throws { + self.decompressor = decompressor + try self.decompressor.startStream() + } + + func decompress(_ frame: WebSocketFrame, maxSize: Int, resetStream: Bool, context: some WebSocketContextProtocol) throws -> WebSocketFrame { + var frame = frame + precondition(frame.fin, "Only concatenated frames with fin set can be processed by the permessage-deflate extension") + // Reinstate last four bytes 0x00 0x00 0xff 0xff that were removed in the frame + // send (see https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.2). + frame.data.writeBytes([0, 0, 255, 255]) + frame.data = try frame.data.decompressStream(with: self.decompressor, maxSize: maxSize, allocator: context.allocator) + if resetStream { + try self.decompressor.resetStream() + } + return frame + } + + func shutdown() throws { + try self.decompressor.finishStream() + } + } + + actor Compressor { + enum SendState: Sendable { + case idle + case sendingMessage + } + fileprivate let compressor: any NIOCompressor - fileprivate var sendState: SendState + var sendState: SendState - init(configuration: Configuration) throws { - self.decompressor = CompressionAlgorithm.deflate( - configuration: .init( - windowSize: numericCast(configuration.receiveMaxWindow ?? 15) - ) - ).decompressor - // compression level -1 will setup the default compression level, 8 is the default memory level - self.compressor = CompressionAlgorithm.deflate( - configuration: .init( - windowSize: numericCast(configuration.sendMaxWindow ?? 15), - compressionLevel: configuration.compressionLevel.map { numericCast($0) } ?? -1, - memoryLevel: configuration.memoryLevel.map { numericCast($0) } ?? 8 - ) - ).compressor + init(_ compressor: any NIOCompressor) throws { + self.compressor = compressor self.sendState = .idle - try self.decompressor.startStream() try self.compressor.startStream() } - func shutdown() { - try? self.compressor.finishStream() - try? self.decompressor.finishStream() + func compress(_ frame: WebSocketFrame, resetStream: Bool, context: some WebSocketContextProtocol) throws -> WebSocketFrame { + // if the frame is larger than 16 bytes, we haven't received a final frame or we are in the process of sending a message + // compress the data + let shouldWeCompress = frame.data.readableBytes > 16 || !frame.fin || self.sendState != .idle + if shouldWeCompress { + var newFrame = frame + if self.sendState == .idle { + newFrame.rsv1 = true + self.sendState = .sendingMessage + } + newFrame.data = try newFrame.data.compressStream(with: self.compressor, flush: .sync, allocator: context.allocator) + // if final frame then remove last four bytes 0x00 0x00 0xff 0xff + // (see https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.1) + if newFrame.fin { + newFrame.data = newFrame.data.getSlice(at: newFrame.data.readerIndex, length: newFrame.data.readableBytes - 4) ?? newFrame.data + self.sendState = .idle + if resetStream { + try self.compressor.resetStream() + } + } + return newFrame + } + return frame + } + + func shutdown() throws { + try self.compressor.finishStream() } } + /// Internal mutable state and referenced types, that cannot be set to Sendable + /* class InternalState { + fileprivate let decompressor: any NIODecompressor + fileprivate let compressor: any NIOCompressor + fileprivate var sendState: SendState + + init(configuration: Configuration) throws { + self.decompressor = CompressionAlgorithm.deflate( + configuration: .init( + windowSize: numericCast(configuration.receiveMaxWindow ?? 15) + ) + ).decompressor + // compression level -1 will setup the default compression level, 8 is the default memory level + self.compressor = CompressionAlgorithm.deflate( + configuration: .init( + windowSize: numericCast(configuration.sendMaxWindow ?? 15), + compressionLevel: configuration.compressionLevel.map { numericCast($0) } ?? -1, + memoryLevel: configuration.memoryLevel.map { numericCast($0) } ?? 8 + ) + ).compressor + self.sendState = .idle + try self.decompressor.startStream() + try self.compressor.startStream() + } + + func shutdown() { + try? self.compressor.finishStream() + try? self.decompressor.finishStream() + } + } */ + let configuration: Configuration - let internalState: NIOLoopBound + let decompressor: Decompressor + let compressor: Compressor + // let internalState: NIOLoopBound init(configuration: Configuration, eventLoop: EventLoop) throws { self.configuration = configuration - self.internalState = try .init(.init(configuration: configuration), eventLoop: eventLoop) + self.decompressor = try .init( + CompressionAlgorithm.deflate( + configuration: .init( + windowSize: numericCast(configuration.receiveMaxWindow ?? 15) + ) + ).decompressor + ) + self.compressor = try .init( + CompressionAlgorithm.deflate( + configuration: .init( + windowSize: numericCast(configuration.sendMaxWindow ?? 15), + compressionLevel: configuration.compressionLevel.map { numericCast($0) } ?? -1, + memoryLevel: configuration.memoryLevel.map { numericCast($0) } ?? 8 + ) + ).compressor + ) } - func shutdown() { - self.internalState.value.shutdown() + func shutdown() async { + try? await self.decompressor.shutdown() + try? await self.compressor.shutdown() } - func processReceivedFrame(_ frame: WebSocketFrame, context: some WebSocketContextProtocol) throws -> WebSocketFrame { - var frame = frame + func processReceivedFrame(_ frame: WebSocketFrame, context: some WebSocketContextProtocol) async throws -> WebSocketFrame { if frame.rsv1 { - let state = self.internalState.value - precondition(frame.fin, "Only concatenated frames with fin set can be processed by the permessage-deflate extension") - // Reinstate last four bytes 0x00 0x00 0xff 0xff that were removed in the frame - // send (see https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.2). - frame.data.writeBytes([0, 0, 255, 255]) - frame.data = try frame.data.decompressStream(with: state.decompressor, maxSize: ws.maxFrameSize, allocator: context.allocator) - if self.configuration.receiveNoContextTakeover { - try state.decompressor.resetStream() - } + return try await self.decompressor.decompress( + frame, + maxSize: self.configuration.maxDecompressedFrameSize, + resetStream: self.configuration.receiveNoContextTakeover, + context: context + ) } return frame } - func processFrameToSend(_ frame: WebSocketFrame, context: some WebSocketContextProtocol) throws -> WebSocketFrame { - let state = self.internalState.value - // if the frame is larger than 16 bytes, we haven't received a final frame or we are in the process of sending a message - // compress the data - let shouldWeCompress = frame.data.readableBytes > 16 || !frame.fin || state.sendState != .idle + func processFrameToSend(_ frame: WebSocketFrame, context: some WebSocketContextProtocol) async throws -> WebSocketFrame { let isCorrectType = frame.opcode == .text || frame.opcode == .binary - if shouldWeCompress, isCorrectType { - var newFrame = frame - if state.sendState == .idle { - newFrame.rsv1 = true - state.sendState = .sendingMessage - } - newFrame.data = try newFrame.data.compressStream(with: state.compressor, flush: .sync, allocator: context.allocator) - // if final frame then remove last four bytes 0x00 0x00 0xff 0xff - // (see https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.1) - if newFrame.fin { - newFrame.data = newFrame.data.getSlice(at: newFrame.data.readerIndex, length: newFrame.data.readableBytes - 4) ?? newFrame.data - state.sendState = .idle - if self.configuration.sendNoContextTakeover { - try state.compressor.resetStream() - } - } - return newFrame + if isCorrectType { + return try await self.compressor.compress(frame, resetStream: self.configuration.sendNoContextTakeover, context: context) } return frame } @@ -250,7 +317,11 @@ extension WebSocketExtensionFactory { /// - Parameters: /// - maxWindow: Max window to be used for decompression and compression /// - noContextTakeover: Should we reset window on every message - public static func perMessageDeflate(maxWindow: Int? = nil, noContextTakeover: Bool = false) -> WebSocketExtensionFactory { + public static func perMessageDeflate( + maxWindow: Int? = nil, + noContextTakeover: Bool = false, + maxDecompressedFrameSize: Int = 1 << 14 + ) -> WebSocketExtensionFactory { return .init { PerMessageDeflateExtensionBuilder( clientMaxWindow: maxWindow, @@ -258,7 +329,8 @@ extension WebSocketExtensionFactory { serverMaxWindow: maxWindow, serverNoContextTakeover: noContextTakeover, compressionLevel: nil, - memoryLevel: nil + memoryLevel: nil, + maxDecompressedFrameSize: maxDecompressedFrameSize ) } } @@ -279,7 +351,8 @@ extension WebSocketExtensionFactory { serverMaxWindow: Int? = nil, serverNoContextTakeover: Bool = false, compressionLevel: Int? = nil, - memoryLevel: Int? = nil + memoryLevel: Int? = nil, + maxDecompressedFrameSize: Int = 1 << 14 ) -> WebSocketExtensionFactory { return .init { PerMessageDeflateExtensionBuilder( @@ -288,7 +361,8 @@ extension WebSocketExtensionFactory { serverMaxWindow: serverMaxWindow, serverNoContextTakeover: serverNoContextTakeover, compressionLevel: compressionLevel, - memoryLevel: memoryLevel + memoryLevel: memoryLevel, + maxDecompressedFrameSize: maxDecompressedFrameSize ) } } diff --git a/Sources/HummingbirdWebSocket/Client/WebSocketClientChannel.swift b/Sources/HummingbirdWebSocket/Client/WebSocketClientChannel.swift index 20d9330..b89bf68 100644 --- a/Sources/HummingbirdWebSocket/Client/WebSocketClientChannel.swift +++ b/Sources/HummingbirdWebSocket/Client/WebSocketClientChannel.swift @@ -22,7 +22,7 @@ import NIOWebSocket struct WebSocketClientChannel: ClientConnectionChannel { enum UpgradeResult { - case websocket(NIOAsyncChannel) + case websocket(NIOAsyncChannel, [any WebSocketExtension]) case notUpgraded } @@ -42,10 +42,16 @@ struct WebSocketClientChannel: ClientConnectionChannel { channel.eventLoop.makeCompletedFuture { let upgrader = NIOTypedWebSocketClientUpgrader( maxFrameSize: self.configuration.maxFrameSize, - upgradePipelineHandler: { channel, _ in + upgradePipelineHandler: { channel, head in channel.eventLoop.makeCompletedFuture { let asyncChannel = try NIOAsyncChannel(wrappingChannelSynchronously: channel) - return UpgradeResult.websocket(asyncChannel) + // work out what extensions we should add + let headerFields = HTTPFields(head.headers, splitCookie: false) + let serverExtensions = WebSocketExtensionHTTPParameters.parseHeaders(headerFields) + let extensions = try configuration.extensions.compactMap { + try $0.clientExtension(from: serverExtensions, eventLoop: channel.eventLoop) + } + return UpgradeResult.websocket(asyncChannel, extensions) } } ) @@ -55,6 +61,7 @@ struct WebSocketClientChannel: ClientConnectionChannel { headers.add(name: "Content-Length", value: "0") let additionalHeaders = HTTPHeaders(self.configuration.additionalHeaders) headers.add(contentsOf: additionalHeaders) + headers.add(contentsOf: self.configuration.extensions.map { (name: "Sec-WebSocket-Extensions", value: $0.clientRequestHeader()) }) let requestHead = HTTPRequestHead( version: .http1_1, @@ -83,8 +90,8 @@ struct WebSocketClientChannel: ClientConnectionChannel { func handle(value: Value, logger: Logger) async throws { switch try await value.get() { - case .websocket(let webSocketChannel): - let webSocket = WebSocketHandler(asyncChannel: webSocketChannel, type: .client) + case .websocket(let webSocketChannel, let extensions): + let webSocket = WebSocketHandler(asyncChannel: webSocketChannel, type: .client, extensions: extensions) await webSocket.handle(handler: self.handler, context: WebSocketContext(channel: webSocketChannel.channel, logger: logger)) case .notUpgraded: // The upgrade to websocket did not succeed. diff --git a/Sources/HummingbirdWebSocket/Client/WebSocketClientConfiguration.swift b/Sources/HummingbirdWebSocket/Client/WebSocketClientConfiguration.swift index eb596a6..d8049af 100644 --- a/Sources/HummingbirdWebSocket/Client/WebSocketClientConfiguration.swift +++ b/Sources/HummingbirdWebSocket/Client/WebSocketClientConfiguration.swift @@ -14,13 +14,13 @@ import HTTPTypes -public struct WebSocketClientConfiguration: WebSocketConfiguration { +public struct WebSocketClientConfiguration: Sendable { /// Max websocket frame size that can be sent/received public var maxFrameSize: Int /// Additional headers to be sent with the initial HTTP request public var additionalHeaders: HTTPFields - /// WebSocket type - public var type: WebSocketType { .client } + /// WebSocket extensions + public var extensions: [any WebSocketExtensionBuilder] /// Initialize WebSocketClient configuration /// - Paramters @@ -28,9 +28,11 @@ public struct WebSocketClientConfiguration: WebSocketConfiguration { /// - additionalHeaders: Additional headers to be sent with the initial HTTP request public init( maxFrameSize: Int = (1 << 14), - additionalHeaders: HTTPFields = .init() + additionalHeaders: HTTPFields = .init(), + extensions: [WebSocketExtensionFactory] = [] ) { self.maxFrameSize = maxFrameSize self.additionalHeaders = additionalHeaders + self.extensions = extensions.map { $0.build() } } } diff --git a/Sources/HummingbirdWebSocket/Server/NIOWebSocketServerUpgrade+ext.swift b/Sources/HummingbirdWebSocket/Server/NIOWebSocketServerUpgrade+ext.swift index 3ca339c..0553024 100644 --- a/Sources/HummingbirdWebSocket/Server/NIOWebSocketServerUpgrade+ext.swift +++ b/Sources/HummingbirdWebSocket/Server/NIOWebSocketServerUpgrade+ext.swift @@ -25,12 +25,13 @@ public enum ShouldUpgradeResult: Sendable { case upgrade(HTTPFields, Value) /// Map upgrade result to difference type - func map(_ map: (Value) throws -> Result) rethrows -> ShouldUpgradeResult { + func map(_ map: (HTTPFields, Value) throws -> (HTTPFields, Result)) rethrows -> ShouldUpgradeResult { switch self { case .dontUpgrade: return .dontUpgrade case .upgrade(let headers, let value): - return try .upgrade(headers, map(value)) + let result = try map(headers, value) + return .upgrade(result.0, result.1) } } } diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift b/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift index 7759a25..1c89fd1 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift +++ b/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift @@ -50,15 +50,23 @@ public struct HTTP1AndWebSocketChannel: ServerChildChannel, HTTPChannelHandler { ) { self.additionalChannelHandlers = additionalChannelHandlers self.configuration = configuration - self.shouldUpgrade = { head, channel, logger in + self.shouldUpgrade = { head, channel, logger -> EventLoopFuture> in channel.eventLoop.makeCompletedFuture { () -> ShouldUpgradeResult in try shouldUpgrade(head, channel, logger) - .map { handler in - return { asyncChannel, logger in - let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server) + .map { headers, handler -> (HTTPFields, WebSocketChannelHandler) in + var headers = headers + let clientHeaders = WebSocketExtensionHTTPParameters.parseHeaders(head.headerFields) + let responseHeaders = configuration.extensions.compactMap { $0.serverResponseHeader(to: clientHeaders) } + headers.append(contentsOf: responseHeaders.map { .init(name: HTTPField.Name.secWebSocketExtensions, value: $0) }) + let extensions = try configuration.extensions.compactMap { + try $0.serverExtension(from: clientHeaders, eventLoop: channel.eventLoop) + } + + return (headers, { asyncChannel, logger in + let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server, extensions: extensions) let context = WebSocketContext(channel: channel, logger: logger) await webSocket.handle(handler: handler, context: context) - } + }) } } } @@ -80,16 +88,23 @@ public struct HTTP1AndWebSocketChannel: ServerChildChannel, HTTPChannelHandler { ) { self.additionalChannelHandlers = additionalChannelHandlers self.configuration = configuration - self.shouldUpgrade = { head, channel, logger in + self.shouldUpgrade = { head, channel, logger -> EventLoopFuture> in let promise = channel.eventLoop.makePromise(of: ShouldUpgradeResult.self) promise.completeWithTask { try await shouldUpgrade(head, channel, logger) - .map { handler in - return { asyncChannel, logger in - let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server) + .map { headers, handler in + var headers = headers + let clientHeaders = WebSocketExtensionHTTPParameters.parseHeaders(head.headerFields) + let responseHeaders = configuration.extensions.compactMap { $0.serverResponseHeader(to: clientHeaders) } + headers.append(contentsOf: responseHeaders.map { .init(name: HTTPField.Name.secWebSocketExtensions, value: $0) }) + let extensions = try configuration.extensions.compactMap { + try $0.serverExtension(from: clientHeaders, eventLoop: channel.eventLoop) + } + return (headers, { asyncChannel, logger in + let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server, extensions: extensions) let context = WebSocketContext(channel: channel, logger: logger) await webSocket.handle(handler: handler, context: context) - } + }) } } return promise.futureResult diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift b/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift index 54e1570..a8527c5 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift +++ b/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift @@ -141,8 +141,15 @@ extension HTTP1AndWebSocketChannel { do { let response = try await webSocketResponder.respond(to: request, context: context) if response.status == .ok, let webSocketHandler = context.webSocket.handler.withLockedValue({ $0 }) { - return .upgrade(response.headers) { asyncChannel, _ in - let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server) + var headers = response.headers + let clientHeaders = WebSocketExtensionHTTPParameters.parseHeaders(head.headerFields) + let responseHeaders = configuration.extensions.compactMap { $0.serverResponseHeader(to: clientHeaders) } + headers.append(contentsOf: responseHeaders.map { .init(name: HTTPField.Name.secWebSocketExtensions, value: $0) }) + let extensions = try configuration.extensions.compactMap { + try $0.serverExtension(from: clientHeaders, eventLoop: channel.eventLoop) + } + return .upgrade(headers) { asyncChannel, _ in + let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server, extensions: extensions) await webSocket.handle(handler: webSocketHandler.handler, context: webSocketHandler.context) } } else { diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketServerConfiguration.swift b/Sources/HummingbirdWebSocket/Server/WebSocketServerConfiguration.swift index b921740..c80d9d0 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketServerConfiguration.swift +++ b/Sources/HummingbirdWebSocket/Server/WebSocketServerConfiguration.swift @@ -13,20 +13,21 @@ //===----------------------------------------------------------------------===// /// Configuration for a WebSocket server -public struct WebSocketServerConfiguration: WebSocketConfiguration { +public struct WebSocketServerConfiguration: Sendable { /// Max websocket frame size that can be sent/received public var maxFrameSize: Int - /// WebSocket type - public var type: WebSocketType { .server } + /// WebSocket extensions + public var extensions: [any WebSocketExtensionBuilder] /// Initialize WebSocketClient configuration /// - Paramters /// - maxFrameSize: Max websocket frame size that can be sent/received /// - additionalHeaders: Additional headers to be sent with the initial HTTP request public init( - maxFrameSize: Int = (1 << 14) + maxFrameSize: Int = (1 << 14), + extensions: [WebSocketExtensionFactory] = [] ) { self.maxFrameSize = maxFrameSize + self.extensions = extensions.map { $0.build() } } - } diff --git a/Sources/HummingbirdWebSocket/WebSocketConfiguration.swift b/Sources/HummingbirdWebSocket/WebSocketConfiguration.swift deleted file mode 100644 index 8948640..0000000 --- a/Sources/HummingbirdWebSocket/WebSocketConfiguration.swift +++ /dev/null @@ -1,27 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the Hummingbird server framework project -// -// Copyright (c) 2023-2024 the Hummingbird authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// - -public enum WebSocketType: Sendable { - case client - case server -} - -/// Configuration for a WebSocket server -public protocol WebSocketConfiguration: Sendable { - /// Max WebSocket frame size that can be sent/received - var maxFrameSize: Int { get } - - // WebSocket type - var type: WebSocketType { get } -} diff --git a/Sources/HummingbirdWebSocket/WebSocketExtension.swift b/Sources/HummingbirdWebSocket/WebSocketExtension.swift index d9dab43..05fe603 100644 --- a/Sources/HummingbirdWebSocket/WebSocketExtension.swift +++ b/Sources/HummingbirdWebSocket/WebSocketExtension.swift @@ -19,15 +19,15 @@ import NIOWebSocket /// Protocol for WebSocket extension public protocol WebSocketExtension: Sendable { /// Process frame received from websocket - func processReceivedFrame(_ frame: WebSocketFrame, context: some WebSocketContextProtocol) throws -> WebSocketFrame + func processReceivedFrame(_ frame: WebSocketFrame, context: some WebSocketContextProtocol) async throws -> WebSocketFrame /// Process frame about to be sent to websocket - func processFrameToSend(_ frame: WebSocketFrame, context: some WebSocketContextProtocol) throws -> WebSocketFrame + func processFrameToSend(_ frame: WebSocketFrame, context: some WebSocketContextProtocol) async throws -> WebSocketFrame /// shutdown extension - func shutdown() + func shutdown() async } /// Protocol for WebSocket extension builder -public protocol WebSocketExtensionBuilder { +public protocol WebSocketExtensionBuilder: Sendable { /// name of WebSocket extension name static var name: String { get } /// construct client request header diff --git a/Sources/HummingbirdWebSocket/WebSocketHandler.swift b/Sources/HummingbirdWebSocket/WebSocketHandler.swift index ed1bc61..7ea4e3e 100644 --- a/Sources/HummingbirdWebSocket/WebSocketHandler.swift +++ b/Sources/HummingbirdWebSocket/WebSocketHandler.swift @@ -35,12 +35,14 @@ actor WebSocketHandler: Sendable { let type: WebSocketType var closed: Bool var pingData: ByteBuffer + let extensions: [any WebSocketExtension] - init(asyncChannel: NIOAsyncChannel, type: WebSocketType) { + init(asyncChannel: NIOAsyncChannel, type: WebSocketType, extensions: [any WebSocketExtension]) { self.asyncChannel = asyncChannel self.type = type self.pingData = ByteBufferAllocator().buffer(capacity: Self.pingDataSize) self.closed = false + self.extensions = extensions } /// Handle WebSocket AsynChannel diff --git a/Tests/HummingbirdWebSocketTests/WebSocketExtensionTests.swift b/Tests/HummingbirdWebSocketTests/WebSocketExtensionTests.swift index 057e67a..1696e4e 100644 --- a/Tests/HummingbirdWebSocketTests/WebSocketExtensionTests.swift +++ b/Tests/HummingbirdWebSocketTests/WebSocketExtensionTests.swift @@ -11,19 +11,196 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// -/* - import Hummingbird - import HummingbirdWebSocket - import HummingbirdWSClient - @testable import HummingbirdWSCompression - @testable import HummingbirdWSCore - import NIOCore - import NIOPosix - import NIOWebSocket - import XCTest - - final class HummingbirdWebSocketExtensionTests: XCTestCase { - static var eventLoopGroup: EventLoopGroup! + +import Hummingbird +import HummingbirdCore +@testable import HummingbirdWebSocket +@testable import HummingbirdWSCompression +import Logging +import NIOCore +import NIOWebSocket +import ServiceLifecycle +import XCTest + +final class HummingbirdWebSocketExtensionTests: XCTestCase { + func testClientAndServer( + serverExtensions: [WebSocketExtensionFactory] = [], + clientExtensions: [WebSocketExtensionFactory] = [], + server serverHandler: @escaping WebSocketDataHandler, + client clientHandler: @escaping WebSocketDataHandler + ) async throws { + try await withThrowingTaskGroup(of: Void.self) { group in + let promise = Promise() + let serverLogger = { + var logger = Logger(label: "WebSocketServer") + logger.logLevel = .debug + return logger + }() + let clientLogger = { + var logger = Logger(label: "WebSocketClient") + logger.logLevel = .debug + return logger + }() + let router = Router(context: BasicWebSocketRequestContext.self) + router.ws("/test", onUpgrade: serverHandler) + let serviceGroup: ServiceGroup + let app = Application( + router: router, + server: .webSocketUpgrade(webSocketRouter: router, configuration: .init(extensions: serverExtensions)), + onServerRunning: { channel in await promise.complete(channel.localAddress!.port!) }, + logger: serverLogger + ) + serviceGroup = ServiceGroup( + configuration: .init( + services: [app], + gracefulShutdownSignals: [.sigterm, .sigint], + logger: app.logger + ) + ) + group.addTask { + try await serviceGroup.run() + } + group.addTask { + let port = await promise.wait() + let client = try WebSocketClient( + url: .init("ws://localhost:\(port)/test"), + configuration: .init(extensions: clientExtensions), + logger: clientLogger, + handler: clientHandler + ) + do { + try await client.run() + } catch { + print("\(error)") + throw error + } + } + do { + try await group.next() + await serviceGroup.triggerGracefulShutdown() + } catch { + await serviceGroup.triggerGracefulShutdown() + throw error + } + } + } + + /// Create random buffer + /// - Parameters: + /// - size: size of buffer + /// - randomness: how random you want the buffer to be (percentage) + func createRandomBuffer(size: Int, randomness: Int = 100) -> ByteBuffer { + var buffer = ByteBufferAllocator().buffer(capacity: size) + let randomness = (randomness * randomness) / 100 + for i in 0.. WebSocketFrame { - var newBuffer = ws.channel.allocator.buffer(capacity: frame.data.readableBytes) - for byte in frame.data.readableBytesView { - newBuffer.writeInteger(byte ^ self.value) - } - var frame = frame - frame.data = newBuffer - return frame - } - - func processReceivedFrame(_ frame: WebSocketFrame, ws: WebSocket) -> WebSocketFrame { - return self.xorFrame(frame, ws: ws) - } - - func processFrameToSend(_ frame: WebSocketFrame, ws: WebSocket) throws -> WebSocketFrame { - return self.xorFrame(frame, ws: ws) - } - - let value: UInt8 - } - - struct XorWebSocketExtensionBuilder: WebSocketExtensionBuilder { - static var name = "permessage-xor" - let value: UInt8? - - init(value: UInt8? = nil) { - self.value = value - } - - func clientRequestHeader() -> String { - var header = Self.name - if let value = value { - header += ";value=\(value)" - } - return header - } - - func serverReponseHeader(to request: WebSocketExtensionHTTPParameters) -> String? { - var header = Self.name - if let value = request.parameters["value"]?.integer { - header += ";value=\(value)" - } - return header - } - - func serverExtension(from request: WebSocketExtensionHTTPParameters, eventLoop: EventLoop) throws -> (WebSocketExtension)? { - XorWebSocketExtension(value: UInt8(request.parameters["value"]?.integer ?? 255)) - } - - func clientExtension(from request: WebSocketExtensionHTTPParameters, eventLoop: EventLoop) throws -> (WebSocketExtension)? { - XorWebSocketExtension(value: UInt8(request.parameters["value"]?.integer ?? 255)) - } - } - - extension WebSocketExtensionFactory { - static func xor(value: UInt8? = nil) -> WebSocketExtensionFactory { - .init { XorWebSocketExtensionBuilder(value: value) } - } - } - */ + }*/ +} + +struct XorWebSocketExtension: WebSocketExtension { + func shutdown() {} + + func xorFrame(_ frame: WebSocketFrame, context: some WebSocketContextProtocol) -> WebSocketFrame { + var newBuffer = context.allocator.buffer(capacity: frame.data.readableBytes) + for byte in frame.data.readableBytesView { + newBuffer.writeInteger(byte ^ self.value) + } + var frame = frame + frame.data = newBuffer + return frame + } + + func processReceivedFrame(_ frame: WebSocketFrame, context: some WebSocketContextProtocol) -> WebSocketFrame { + return self.xorFrame(frame, context: context) + } + + func processFrameToSend(_ frame: WebSocketFrame, context: some WebSocketContextProtocol) throws -> WebSocketFrame { + return self.xorFrame(frame, context: context) + } + + let value: UInt8 +} + +struct XorWebSocketExtensionBuilder: WebSocketExtensionBuilder { + static var name = "permessage-xor" + let value: UInt8? + + init(value: UInt8? = nil) { + self.value = value + } + + func clientRequestHeader() -> String { + var header = Self.name + if let value { + header += ";value=\(value)" + } + return header + } + + func serverReponseHeader(to request: WebSocketExtensionHTTPParameters) -> String? { + var header = Self.name + if let value = request.parameters["value"]?.integer { + header += ";value=\(value)" + } + return header + } + + func serverExtension(from request: WebSocketExtensionHTTPParameters, eventLoop: EventLoop) throws -> (WebSocketExtension)? { + XorWebSocketExtension(value: UInt8(request.parameters["value"]?.integer ?? 255)) + } + + func clientExtension(from request: WebSocketExtensionHTTPParameters, eventLoop: EventLoop) throws -> (WebSocketExtension)? { + XorWebSocketExtension(value: UInt8(request.parameters["value"]?.integer ?? 255)) + } +} + +extension WebSocketExtensionFactory { + static func xor(value: UInt8? = nil) -> WebSocketExtensionFactory { + .init { XorWebSocketExtensionBuilder(value: value) } + } +} From 27457b1e194ab8ddded53055c17f12faf6a54728 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Sat, 23 Mar 2024 10:28:39 +0000 Subject: [PATCH 4/9] Apply extensions to inbound and outbound data --- Package.swift | 3 +- Snippets/WebsocketTest.swift | 3 +- .../PerMessageDeflateExtension.swift | 45 +-- .../Client/WebSocketClientChannel.swift | 5 +- .../Server/WebSocketChannel.swift | 63 +++- .../Server/WebSocketHTTPChannelBuilder.swift | 8 +- .../Server/WebSocketRouter.swift | 21 +- .../WebSocketDataHandler.swift | 2 +- .../WebSocketExtension.swift | 14 +- .../WebSocketHandler.swift | 32 +- .../WebSocketOutboundWriter.swift | 14 +- .../WebSocketExtensionTests.swift | 288 +++++------------- 12 files changed, 202 insertions(+), 296 deletions(-) diff --git a/Package.swift b/Package.swift index 680eb3b..9951868 100644 --- a/Package.swift +++ b/Package.swift @@ -8,7 +8,7 @@ let package = Package( platforms: [.macOS(.v14), .iOS(.v17), .tvOS(.v17)], products: [ .library(name: "HummingbirdWebSocket", targets: ["HummingbirdWebSocket"]), - // .library(name: "HummingbirdWSCompression", targets: ["HummingbirdWSCompression"]), + .library(name: "HummingbirdWSCompression", targets: ["HummingbirdWSCompression"]), ], dependencies: [ .package(url: "https://github.com/hummingbird-project/hummingbird.git", branch: "main"), @@ -18,7 +18,6 @@ let package = Package( .package(url: "https://github.com/apple/swift-nio.git", from: "2.62.0"), .package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.22.0"), .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.5.0"), - .package(url: "https://github.com/swift-extras/swift-extras-base64.git", from: "0.5.0"), .package(url: "https://github.com/adam-fowler/compress-nio.git", from: "1.0.0"), ], targets: [ diff --git a/Snippets/WebsocketTest.swift b/Snippets/WebsocketTest.swift index 1c2ac3a..5097893 100644 --- a/Snippets/WebsocketTest.swift +++ b/Snippets/WebsocketTest.swift @@ -1,6 +1,7 @@ import HTTPTypes import Hummingbird import HummingbirdWebSocket +import HummingbirdWSCompression import Logging var logger = Logger(label: "Echo") @@ -22,7 +23,7 @@ router.ws("/ws") { inbound, outbound, _ in let app = Application( router: router, - server: .webSocketUpgrade(webSocketRouter: router), + server: .webSocketUpgrade(webSocketRouter: router, configuration: .init(extensions: [.perMessageDeflate()])), logger: logger ) try await app.runService() diff --git a/Sources/HummingbirdWSCompression/PerMessageDeflateExtension.swift b/Sources/HummingbirdWSCompression/PerMessageDeflateExtension.swift index 98d7faa..420cb40 100644 --- a/Sources/HummingbirdWSCompression/PerMessageDeflateExtension.swift +++ b/Sources/HummingbirdWSCompression/PerMessageDeflateExtension.swift @@ -89,17 +89,15 @@ struct PerMessageDeflateExtensionBuilder: WebSocketExtensionBuilder { /// Create server PerMessageDeflateExtension based off request headers /// - Parameters: /// - request: Client request - /// - eventLoop: EventLoop it is bound to - func serverExtension(from request: WebSocketExtensionHTTPParameters, eventLoop: EventLoop) throws -> (WebSocketExtension)? { + func serverExtension(from request: WebSocketExtensionHTTPParameters) throws -> (WebSocketExtension)? { let configuration = self.responseConfiguration(to: request) - return try PerMessageDeflateExtension(configuration: configuration, eventLoop: eventLoop) + return try PerMessageDeflateExtension(configuration: configuration) } /// Create client PerMessageDeflateExtension based off response headers /// - Parameters: /// - response: Server response - /// - eventLoop: EventLoop it is bound to - func clientExtension(from response: WebSocketExtensionHTTPParameters, eventLoop: EventLoop) throws -> WebSocketExtension? { + func clientExtension(from response: WebSocketExtensionHTTPParameters) throws -> WebSocketExtension? { let clientMaxWindowParam = response.parameters["client_max_window_bits"]?.integer let clientNoContextTakeoverParam = response.parameters["client_no_context_takeover"] != nil let serverMaxWindowParam = response.parameters["server_max_window_bits"]?.integer @@ -112,7 +110,7 @@ struct PerMessageDeflateExtensionBuilder: WebSocketExtensionBuilder { compressionLevel: self.compressionLevel, memoryLevel: self.memoryLevel, maxDecompressedFrameSize: self.maxDecompressedFrameSize - ), eventLoop: eventLoop) + )) } private func responseConfiguration(to request: WebSocketExtensionHTTPParameters) -> PerMessageDeflateExtension.Configuration { @@ -230,43 +228,12 @@ struct PerMessageDeflateExtension: WebSocketExtension { } } - /// Internal mutable state and referenced types, that cannot be set to Sendable - /* class InternalState { - fileprivate let decompressor: any NIODecompressor - fileprivate let compressor: any NIOCompressor - fileprivate var sendState: SendState - - init(configuration: Configuration) throws { - self.decompressor = CompressionAlgorithm.deflate( - configuration: .init( - windowSize: numericCast(configuration.receiveMaxWindow ?? 15) - ) - ).decompressor - // compression level -1 will setup the default compression level, 8 is the default memory level - self.compressor = CompressionAlgorithm.deflate( - configuration: .init( - windowSize: numericCast(configuration.sendMaxWindow ?? 15), - compressionLevel: configuration.compressionLevel.map { numericCast($0) } ?? -1, - memoryLevel: configuration.memoryLevel.map { numericCast($0) } ?? 8 - ) - ).compressor - self.sendState = .idle - try self.decompressor.startStream() - try self.compressor.startStream() - } - - func shutdown() { - try? self.compressor.finishStream() - try? self.decompressor.finishStream() - } - } */ - + let name = "permessage-deflate" let configuration: Configuration let decompressor: Decompressor let compressor: Compressor - // let internalState: NIOLoopBound - init(configuration: Configuration, eventLoop: EventLoop) throws { + init(configuration: Configuration) throws { self.configuration = configuration self.decompressor = try .init( CompressionAlgorithm.deflate( diff --git a/Sources/HummingbirdWebSocket/Client/WebSocketClientChannel.swift b/Sources/HummingbirdWebSocket/Client/WebSocketClientChannel.swift index b89bf68..43d9220 100644 --- a/Sources/HummingbirdWebSocket/Client/WebSocketClientChannel.swift +++ b/Sources/HummingbirdWebSocket/Client/WebSocketClientChannel.swift @@ -45,11 +45,11 @@ struct WebSocketClientChannel: ClientConnectionChannel { upgradePipelineHandler: { channel, head in channel.eventLoop.makeCompletedFuture { let asyncChannel = try NIOAsyncChannel(wrappingChannelSynchronously: channel) - // work out what extensions we should add + // work out what extensions we should add based off the server response let headerFields = HTTPFields(head.headers, splitCookie: false) let serverExtensions = WebSocketExtensionHTTPParameters.parseHeaders(headerFields) let extensions = try configuration.extensions.compactMap { - try $0.clientExtension(from: serverExtensions, eventLoop: channel.eventLoop) + try $0.clientExtension(from: serverExtensions) } return UpgradeResult.websocket(asyncChannel, extensions) } @@ -61,6 +61,7 @@ struct WebSocketClientChannel: ClientConnectionChannel { headers.add(name: "Content-Length", value: "0") let additionalHeaders = HTTPHeaders(self.configuration.additionalHeaders) headers.add(contentsOf: additionalHeaders) + // add websocket extensions to headers headers.add(contentsOf: self.configuration.extensions.map { (name: "Sec-WebSocket-Extensions", value: $0.clientRequestHeader()) }) let requestHead = HTTPRequestHead( diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift b/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift index 1c89fd1..11f2b45 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift +++ b/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift @@ -24,7 +24,7 @@ import NIOHTTPTypesHTTP1 import NIOWebSocket /// Child channel supporting a web socket upgrade from HTTP1 -public struct HTTP1AndWebSocketChannel: ServerChildChannel, HTTPChannelHandler { +public struct HTTP1WebSocketUpgradeChannel: ServerChildChannel, HTTPChannelHandler { public typealias WebSocketChannelHandler = @Sendable (NIOAsyncChannel, Logger) async -> Void /// Upgrade result (either a websocket AsyncChannel, or an HTTP1 AsyncChannel) public enum UpgradeResult { @@ -54,14 +54,12 @@ public struct HTTP1AndWebSocketChannel: ServerChildChannel, HTTPChannelHandler { channel.eventLoop.makeCompletedFuture { () -> ShouldUpgradeResult in try shouldUpgrade(head, channel, logger) .map { headers, handler -> (HTTPFields, WebSocketChannelHandler) in - var headers = headers - let clientHeaders = WebSocketExtensionHTTPParameters.parseHeaders(head.headerFields) - let responseHeaders = configuration.extensions.compactMap { $0.serverResponseHeader(to: clientHeaders) } - headers.append(contentsOf: responseHeaders.map { .init(name: HTTPField.Name.secWebSocketExtensions, value: $0) }) - let extensions = try configuration.extensions.compactMap { - try $0.serverExtension(from: clientHeaders, eventLoop: channel.eventLoop) - } - + let (headers, extensions) = try Self.webSocketExtensionNegociation( + extensionBuilders: configuration.extensions, + requestHeaders: head.headerFields, + responseHeaders: headers, + logger: logger + ) return (headers, { asyncChannel, logger in let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server, extensions: extensions) let context = WebSocketContext(channel: channel, logger: logger) @@ -93,13 +91,12 @@ public struct HTTP1AndWebSocketChannel: ServerChildChannel, HTTPChannelHandler { promise.completeWithTask { try await shouldUpgrade(head, channel, logger) .map { headers, handler in - var headers = headers - let clientHeaders = WebSocketExtensionHTTPParameters.parseHeaders(head.headerFields) - let responseHeaders = configuration.extensions.compactMap { $0.serverResponseHeader(to: clientHeaders) } - headers.append(contentsOf: responseHeaders.map { .init(name: HTTPField.Name.secWebSocketExtensions, value: $0) }) - let extensions = try configuration.extensions.compactMap { - try $0.serverExtension(from: clientHeaders, eventLoop: channel.eventLoop) - } + let (headers, extensions) = try Self.webSocketExtensionNegociation( + extensionBuilders: configuration.extensions, + requestHeaders: head.headerFields, + responseHeaders: headers, + logger: logger + ) return (headers, { asyncChannel, logger in let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server, extensions: extensions) let context = WebSocketContext(channel: channel, logger: logger) @@ -213,6 +210,40 @@ public struct HTTP1AndWebSocketChannel: ServerChildChannel, HTTPChannelHandler { } } + /// WebSocket extension negociation + /// - Parameters: + /// - requestHeaders: Request headers + /// - headers: Response headers + /// - logger: Logger + /// - Returns: Response headers and extensions enabled + static func webSocketExtensionNegociation( + extensionBuilders: [any WebSocketExtensionBuilder], + requestHeaders: HTTPFields, + responseHeaders: HTTPFields, + logger: Logger + ) throws -> (responseHeaders: HTTPFields, extensions: [any WebSocketExtension]) { + var responseHeaders = responseHeaders + let clientHeaders = WebSocketExtensionHTTPParameters.parseHeaders(requestHeaders) + if clientHeaders.count > 0 { + logger.trace( + "Extensions requested", + metadata: ["hb_extensions": .string(clientHeaders.map(\.name).joined(separator: ","))] + ) + } + let extensionResponseHeaders = extensionBuilders.compactMap { $0.serverResponseHeader(to: clientHeaders) } + responseHeaders.append(contentsOf: extensionResponseHeaders.map { .init(name: .secWebSocketExtensions, value: $0) }) + let extensions = try extensionBuilders.compactMap { + try $0.serverExtension(from: clientHeaders) + } + if extensions.count > 0 { + logger.debug( + "Enabled extensions", + metadata: ["hb_extensions": .string(extensions.map(\.name).joined(separator: ","))] + ) + } + return (responseHeaders, extensions) + } + public var responder: @Sendable (Request, Channel) async throws -> Response let shouldUpgrade: @Sendable (HTTPRequest, Channel, Logger) -> EventLoopFuture> let configuration: WebSocketServerConfiguration diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketHTTPChannelBuilder.swift b/Sources/HummingbirdWebSocket/Server/WebSocketHTTPChannelBuilder.swift index fe55a7b..e89cb40 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketHTTPChannelBuilder.swift +++ b/Sources/HummingbirdWebSocket/Server/WebSocketHTTPChannelBuilder.swift @@ -24,9 +24,9 @@ extension HTTPChannelBuilder { configuration: WebSocketServerConfiguration = .init(), additionalChannelHandlers: @autoclosure @escaping @Sendable () -> [any RemovableChannelHandler] = [], shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) async throws -> ShouldUpgradeResult> - ) -> HTTPChannelBuilder { + ) -> HTTPChannelBuilder { return .init { responder in - return HTTP1AndWebSocketChannel( + return HTTP1WebSocketUpgradeChannel( responder: responder, configuration: configuration, additionalChannelHandlers: additionalChannelHandlers, @@ -40,9 +40,9 @@ extension HTTPChannelBuilder { configuration: WebSocketServerConfiguration = .init(), additionalChannelHandlers: @autoclosure @escaping @Sendable () -> [any RemovableChannelHandler] = [], shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) throws -> ShouldUpgradeResult> - ) -> HTTPChannelBuilder { + ) -> HTTPChannelBuilder { return .init { responder in - return HTTP1AndWebSocketChannel( + return HTTP1WebSocketUpgradeChannel( responder: responder, configuration: configuration, additionalChannelHandlers: additionalChannelHandlers, diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift b/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift index a8527c5..18fc503 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift +++ b/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift @@ -117,8 +117,8 @@ public struct WebSocketUpgradeMiddleware: Rout } } -extension HTTP1AndWebSocketChannel { - /// Initialize HTTP1AndWebSocketChannel with async `shouldUpgrade` function +extension HTTP1WebSocketUpgradeChannel { + /// Initialize HTTP1WebSocketUpgradeChannel with async `shouldUpgrade` function /// - Parameters: /// - additionalChannelHandlers: Additional channel handlers to add /// - responder: HTTP responder @@ -141,13 +141,12 @@ extension HTTP1AndWebSocketChannel { do { let response = try await webSocketResponder.respond(to: request, context: context) if response.status == .ok, let webSocketHandler = context.webSocket.handler.withLockedValue({ $0 }) { - var headers = response.headers - let clientHeaders = WebSocketExtensionHTTPParameters.parseHeaders(head.headerFields) - let responseHeaders = configuration.extensions.compactMap { $0.serverResponseHeader(to: clientHeaders) } - headers.append(contentsOf: responseHeaders.map { .init(name: HTTPField.Name.secWebSocketExtensions, value: $0) }) - let extensions = try configuration.extensions.compactMap { - try $0.serverExtension(from: clientHeaders, eventLoop: channel.eventLoop) - } + let (headers, extensions) = try Self.webSocketExtensionNegociation( + extensionBuilders: configuration.extensions, + requestHeaders: head.headerFields, + responseHeaders: response.headers, + logger: logger + ) return .upgrade(headers) { asyncChannel, _ in let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server, extensions: extensions) await webSocket.handle(handler: webSocketHandler.handler, context: webSocketHandler.context) @@ -181,10 +180,10 @@ extension HTTPChannelBuilder { webSocketRouter: WSResponderBuilder, configuration: WebSocketServerConfiguration = .init(), additionalChannelHandlers: @autoclosure @escaping @Sendable () -> [any RemovableChannelHandler] = [] - ) -> HTTPChannelBuilder where WSResponderBuilder.Responder.Context: WebSocketRequestContext { + ) -> HTTPChannelBuilder where WSResponderBuilder.Responder.Context: WebSocketRequestContext { let webSocketReponder = webSocketRouter.buildResponder() return .init { responder in - return HTTP1AndWebSocketChannel( + return HTTP1WebSocketUpgradeChannel( responder: responder, webSocketResponder: webSocketReponder, configuration: configuration, diff --git a/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift b/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift index 6bc9a7a..91064c3 100644 --- a/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift +++ b/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift @@ -19,4 +19,4 @@ import NIOCore import NIOWebSocket /// Function that handles websocket data and text blocks -public typealias WebSocketDataHandler = @Sendable (WebSocketInboundStream, WebSocketOutboundWriter, Context) async throws -> Void +public typealias WebSocketDataHandler = @Sendable (WebSocketInboundStream, WebSocketOutboundWriter, Context) async throws -> Void diff --git a/Sources/HummingbirdWebSocket/WebSocketExtension.swift b/Sources/HummingbirdWebSocket/WebSocketExtension.swift index 05fe603..012f902 100644 --- a/Sources/HummingbirdWebSocket/WebSocketExtension.swift +++ b/Sources/HummingbirdWebSocket/WebSocketExtension.swift @@ -18,6 +18,8 @@ import NIOWebSocket /// Protocol for WebSocket extension public protocol WebSocketExtension: Sendable { + /// Extension name + var name: String { get } /// Process frame received from websocket func processReceivedFrame(_ frame: WebSocketFrame, context: some WebSocketContextProtocol) async throws -> WebSocketFrame /// Process frame about to be sent to websocket @@ -35,9 +37,9 @@ public protocol WebSocketExtensionBuilder: Sendable { /// construct server response header based of client request func serverReponseHeader(to: WebSocketExtensionHTTPParameters) -> String? /// construct server version of extension based of client request - func serverExtension(from: WebSocketExtensionHTTPParameters, eventLoop: EventLoop) throws -> (any WebSocketExtension)? + func serverExtension(from: WebSocketExtensionHTTPParameters) throws -> (any WebSocketExtension)? /// construct client version of extension based of server response - func clientExtension(from: WebSocketExtensionHTTPParameters, eventLoop: EventLoop) throws -> (any WebSocketExtension)? + func clientExtension(from: WebSocketExtensionHTTPParameters) throws -> (any WebSocketExtension)? } extension WebSocketExtensionBuilder { @@ -53,10 +55,10 @@ extension WebSocketExtensionBuilder { } /// construct all server extensions based of all client requests - public func serverExtension(from requests: [WebSocketExtensionHTTPParameters], eventLoop: EventLoop) throws -> (any WebSocketExtension)? { + public func serverExtension(from requests: [WebSocketExtensionHTTPParameters]) throws -> (any WebSocketExtension)? { for request in requests { guard request.name == Self.name else { continue } - if let ext = try serverExtension(from: request, eventLoop: eventLoop) { + if let ext = try serverExtension(from: request) { return ext } } @@ -64,10 +66,10 @@ extension WebSocketExtensionBuilder { } /// construct all client extensions based of all server responses - public func clientExtension(from requests: [WebSocketExtensionHTTPParameters], eventLoop: EventLoop) throws -> (any WebSocketExtension)? { + public func clientExtension(from requests: [WebSocketExtensionHTTPParameters]) throws -> (any WebSocketExtension)? { for request in requests { guard request.name == Self.name else { continue } - if let ext = try clientExtension(from: request, eventLoop: eventLoop) { + if let ext = try clientExtension(from: request) { return ext } } diff --git a/Sources/HummingbirdWebSocket/WebSocketHandler.swift b/Sources/HummingbirdWebSocket/WebSocketHandler.swift index 7ea4e3e..ab02742 100644 --- a/Sources/HummingbirdWebSocket/WebSocketHandler.swift +++ b/Sources/HummingbirdWebSocket/WebSocketHandler.swift @@ -29,12 +29,16 @@ public enum WebSocketType: Sendable { /// Manages ping, pong and close messages. Collates data and text messages into final frame /// and passes them onto the ``WebSocketDataHandler`` data handler setup by the user. actor WebSocketHandler: Sendable { + enum InternalError: Error { + case close(WebSocketErrorCode) + } + static let pingDataSize = 16 let asyncChannel: NIOAsyncChannel let type: WebSocketType - var closed: Bool var pingData: ByteBuffer + var closed = false let extensions: [any WebSocketExtension] init(asyncChannel: NIOAsyncChannel, type: WebSocketType, extensions: [any WebSocketExtension]) { @@ -53,7 +57,9 @@ actor WebSocketHandler: Sendable { let webSocketOutbound = WebSocketOutboundWriter( type: self.type, allocator: asyncChannel.channel.allocator, - outbound: outbound + outbound: outbound, + extensions: self.extensions, + context: context ) try await withTaskCancellationHandler { try await withGracefulShutdownHandler { @@ -96,8 +102,14 @@ actor WebSocketHandler: Sendable { break } if let frameSeq = frameSequence, frame.fin { - await webSocketInbound.send(frameSeq.data) - frameSequence = nil + var collatedFrame = frameSeq.collapsed + for ext in self.extensions.reversed() { + collatedFrame = try await ext.processReceivedFrame(collatedFrame, context: context) + } + if let finalFrame = WebSocketDataFrame(frame: collatedFrame) { + await webSocketInbound.send(finalFrame) + frameSequence = nil + } } } catch { // catch errors while processing websocket frames so responding close message @@ -112,6 +124,8 @@ actor WebSocketHandler: Sendable { // handle websocket data and text try await handler(webSocketInbound, webSocketOutbound, context) try await self.close(code: .normalClosure, outbound: webSocketOutbound, context: context) + } catch InternalError.close(let code) { + try await self.close(code: code, outbound: webSocketOutbound, context: context) } catch { if self.type == .server { let errorCode = WebSocketErrorCode.unexpectedServerError @@ -142,7 +156,7 @@ actor WebSocketHandler: Sendable { /// Respond to ping func onPing( _ frame: WebSocketFrame, - outbound: WebSocketOutboundWriter, + outbound: WebSocketOutboundWriter, context: some WebSocketContextProtocol ) async throws { if frame.fin { @@ -155,7 +169,7 @@ actor WebSocketHandler: Sendable { /// Respond to pong func onPong( _ frame: WebSocketFrame, - outbound: WebSocketOutboundWriter, + outbound: WebSocketOutboundWriter, context: some WebSocketContextProtocol ) async throws { guard !self.closed else { return } @@ -168,7 +182,7 @@ actor WebSocketHandler: Sendable { } /// Send ping - func ping(outbound: WebSocketOutboundWriter) async throws { + func ping(outbound: WebSocketOutboundWriter) async throws { guard !self.closed else { return } if self.pingData.readableBytes == 0 { // creating random payload @@ -179,7 +193,7 @@ actor WebSocketHandler: Sendable { } /// Send pong - func pong(data: ByteBuffer?, outbound: WebSocketOutboundWriter) async throws { + func pong(data: ByteBuffer?, outbound: WebSocketOutboundWriter) async throws { guard !self.closed else { return } try await outbound.write(frame: .init(fin: true, opcode: .pong, data: data ?? .init())) } @@ -187,7 +201,7 @@ actor WebSocketHandler: Sendable { /// Send close func close( code: WebSocketErrorCode = .normalClosure, - outbound: WebSocketOutboundWriter, + outbound: WebSocketOutboundWriter, context: some WebSocketContextProtocol ) async throws { guard !self.closed else { return } diff --git a/Sources/HummingbirdWebSocket/WebSocketOutboundWriter.swift b/Sources/HummingbirdWebSocket/WebSocketOutboundWriter.swift index 4d334ee..dbc19d7 100644 --- a/Sources/HummingbirdWebSocket/WebSocketOutboundWriter.swift +++ b/Sources/HummingbirdWebSocket/WebSocketOutboundWriter.swift @@ -16,7 +16,7 @@ import NIOCore import NIOWebSocket /// Outbound websocket writer -public struct WebSocketOutboundWriter: Sendable { +public struct WebSocketOutboundWriter: Sendable { /// WebSocket frame that can be written public enum OutboundFrame: Sendable { /// Text frame @@ -32,6 +32,8 @@ public struct WebSocketOutboundWriter: Sendable { let type: WebSocketType let allocator: ByteBufferAllocator let outbound: NIOAsyncChannelOutboundWriter + let extensions: [any WebSocketExtension] + let context: Context /// Write WebSocket frame public func write(_ frame: OutboundFrame) async throws { @@ -58,8 +60,18 @@ public struct WebSocketOutboundWriter: Sendable { frame: WebSocketFrame ) async throws { var frame = frame + do { + for ext in self.extensions { + frame = try await ext.processFrameToSend(frame, context: self.context) + } + } catch { + self.context.logger.debug("Closing as we failed to generate valid frame data") + throw WebSocketHandler.InternalError.close(.unexpectedServerError) + } frame.maskKey = self.makeMaskKey() try await self.outbound.write(frame) + + self.context.logger.trace("Sent \(frame.opcode)") } func finish() { diff --git a/Tests/HummingbirdWebSocketTests/WebSocketExtensionTests.swift b/Tests/HummingbirdWebSocketTests/WebSocketExtensionTests.swift index 1696e4e..867f359 100644 --- a/Tests/HummingbirdWebSocketTests/WebSocketExtensionTests.swift +++ b/Tests/HummingbirdWebSocketTests/WebSocketExtensionTests.swift @@ -24,29 +24,27 @@ import XCTest final class HummingbirdWebSocketExtensionTests: XCTestCase { func testClientAndServer( - serverExtensions: [WebSocketExtensionFactory] = [], + serverChannel: HTTPChannelBuilder, clientExtensions: [WebSocketExtensionFactory] = [], - server serverHandler: @escaping WebSocketDataHandler, client clientHandler: @escaping WebSocketDataHandler ) async throws { try await withThrowingTaskGroup(of: Void.self) { group in let promise = Promise() let serverLogger = { var logger = Logger(label: "WebSocketServer") - logger.logLevel = .debug + logger.logLevel = .trace return logger }() let clientLogger = { var logger = Logger(label: "WebSocketClient") - logger.logLevel = .debug + logger.logLevel = .trace return logger }() - let router = Router(context: BasicWebSocketRequestContext.self) - router.ws("/test", onUpgrade: serverHandler) let serviceGroup: ServiceGroup + let router = Router() let app = Application( router: router, - server: .webSocketUpgrade(webSocketRouter: router, configuration: .init(extensions: serverExtensions)), + server: serverChannel, onServerRunning: { channel in await promise.complete(channel.localAddress!.port!) }, logger: serverLogger ) @@ -85,6 +83,21 @@ final class HummingbirdWebSocketExtensionTests: XCTestCase { } } + func testClientAndServer( + serverExtensions: [WebSocketExtensionFactory] = [], + clientExtensions: [WebSocketExtensionFactory] = [], + server serverHandler: @escaping WebSocketDataHandler, + client clientHandler: @escaping WebSocketDataHandler + ) async throws { + try await self.testClientAndServer( + serverChannel: .webSocketUpgrade(configuration: .init(extensions: serverExtensions)) { _, _, _ in + .upgrade([:], serverHandler) + }, + clientExtensions: clientExtensions, + client: clientHandler + ) + } + /// Create random buffer /// - Parameters: /// - size: size of buffer @@ -176,210 +189,77 @@ final class HummingbirdWebSocketExtensionTests: XCTestCase { try await outbound.write(.text("Hello")) for try await _ in inbound {} } - /* onServer: { ws in - XCTAssertNotNil(ws.extensions.first as? PerMessageDeflateExtension) - let stream = ws.readStream() - Task { - var iterator = stream.makeAsyncIterator() - let firstMessage = await iterator.next() - XCTAssertEqual(firstMessage, .text("Hello, testing this is compressed")) - let secondMessage = await iterator.next() - XCTAssertEqual(secondMessage, .text("Hello")) - for await _ in stream {} - ws.onClose { _ in - promise.succeed() - } - } - }, - onClient: { ws in - XCTAssertNotNil(ws.extensions.first as? PerMessageDeflateExtension) - try await ws.write(.text("Hello, testing this is compressed")) - try await ws.write(.text("Hello")) - try await ws.close() - } - )*/ } - /* static var eventLoopGroup: EventLoopGroup! - - override class func setUp() { - self.eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) - } - - override class func tearDown() { - XCTAssertNoThrow(try self.eventLoopGroup.syncShutdownGracefully()) - } - - /// Create random buffer - /// - Parameters: - /// - size: size of buffer - /// - randomness: how random you want the buffer to be (percentage) - func createRandomBuffer(size: Int, randomness: Int = 100) -> ByteBuffer { - var buffer = ByteBufferAllocator().buffer(capacity: size) - let randomness = (randomness * randomness) / 100 - for i in 0.. Void, - onClient: @escaping (WebSocket) async throws -> Void - ) async throws -> HBApplication { - let app = HBApplication(configuration: .init(address: .hostname(port: 0))) - app.logger.logLevel = .trace - // add HTTP to WebSocket upgrade - app.ws.addUpgrade(maxFrameSize: 1 << 14, extensions: serverExtensions) - // on websocket connect. - app.ws.on("/test", onUpgrade: { _, ws in - try await onServer(ws) - return .ok - }) - try app.start() - - let eventLoop = app.eventLoopGroup.next() - let ws = try await WebSocketClient.connect( - url: HBURL("ws://localhost:\(app.server.port!)/test"), - configuration: .init(extensions: clientExtensions), - on: eventLoop - ) - try await onClient(ws) - return app - } - - func testPerMessageDeflate() async throws { - let promise = TimeoutPromise(eventLoop: Self.eventLoopGroup.next(), timeout: .seconds(10)) - - let app = try await self.setupClientAndServer( - serverExtensions: [.perMessageDeflate()], - clientExtensions: [.perMessageDeflate()], - onServer: { ws in - XCTAssertNotNil(ws.extensions.first as? PerMessageDeflateExtension) - let stream = ws.readStream() - Task { - var iterator = stream.makeAsyncIterator() - let firstMessage = await iterator.next() - XCTAssertEqual(firstMessage, .text("Hello, testing this is compressed")) - let secondMessage = await iterator.next() - XCTAssertEqual(secondMessage, .text("Hello")) - for await _ in stream {} - ws.onClose { _ in - promise.succeed() - } - } - }, - onClient: { ws in - XCTAssertNotNil(ws.extensions.first as? PerMessageDeflateExtension) - try await ws.write(.text("Hello, testing this is compressed")) - try await ws.write(.text("Hello")) - try await ws.close() - } - ) - defer { app.stop() } - - try promise.wait() - } - - func testPerMessageDeflateMaxWindow() async throws { - let promise = TimeoutPromise(eventLoop: Self.eventLoopGroup.next(), timeout: .seconds(10)) - - let buffer = self.createRandomBuffer(size: 4096, randomness: 10) - let app = try await self.setupClientAndServer( - serverExtensions: [.perMessageDeflate()], - clientExtensions: [.perMessageDeflate(maxWindow: 10)], - onServer: { ws in - XCTAssertEqual((ws.extensions.first as? PerMessageDeflateExtension)?.configuration.receiveMaxWindow, 10) - let stream = ws.readStream() - Task { - for try await data in stream { - XCTAssertEqual(data, .binary(buffer)) - } - ws.onClose { _ in - promise.succeed() - } - } - }, - onClient: { ws in - XCTAssertEqual((ws.extensions.first as? PerMessageDeflateExtension)?.configuration.sendMaxWindow, 10) - try await ws.write(.binary(buffer)) - try await ws.close() - } - ) - defer { app.stop() } - - try promise.wait() - } - - func testPerMessageDeflateNoContextTakeover() async throws { - let promise = TimeoutPromise(eventLoop: Self.eventLoopGroup.next(), timeout: .seconds(10)) - - let buffer = self.createRandomBuffer(size: 4096, randomness: 10) - let app = try await self.setupClientAndServer( - serverExtensions: [.perMessageDeflate()], - clientExtensions: [.perMessageDeflate(clientNoContextTakeover: true)], - onServer: { ws in - XCTAssertEqual((ws.extensions.first as? PerMessageDeflateExtension)?.configuration.receiveNoContextTakeover, true) - let stream = ws.readStream() - Task { - for try await data in stream { - XCTAssertEqual(data, .binary(buffer)) - } - ws.onClose { _ in - promise.succeed() - } - } - }, - onClient: { ws in - XCTAssertEqual((ws.extensions.first as? PerMessageDeflateExtension)?.configuration.sendNoContextTakeover, true) - try await ws.write(.binary(buffer)) - try await ws.close() - } - ) - defer { app.stop() } + func testPerMessageDeflateMaxWindow() async throws { + let buffer = self.createRandomBuffer(size: 4096, randomness: 10) + try await self.testClientAndServer( + serverExtensions: [.perMessageDeflate()], + clientExtensions: [.perMessageDeflate(maxWindow: 10)] + ) { inbound, outbound, _ in + XCTAssertEqual((outbound.extensions.first as? PerMessageDeflateExtension)?.configuration.receiveMaxWindow, 10) + for try await data in inbound { + XCTAssertEqual(data, .binary(buffer)) + } + } client: { _, outbound, _ in + XCTAssertEqual((outbound.extensions.first as? PerMessageDeflateExtension)?.configuration.sendMaxWindow, 10) + try await outbound.write(.binary(buffer)) + } + } - try promise.wait() - } + func testPerMessageDeflateNoContextTakeover() async throws { + let buffer = self.createRandomBuffer(size: 4096, randomness: 10) + try await self.testClientAndServer( + serverExtensions: [.perMessageDeflate()], + clientExtensions: [.perMessageDeflate(clientNoContextTakeover: true)] + ) { inbound, outbound, _ in + XCTAssertEqual((outbound.extensions.first as? PerMessageDeflateExtension)?.configuration.receiveNoContextTakeover, true) + for try await data in inbound { + XCTAssertEqual(data, .binary(buffer)) + } + } client: { _, outbound, _ in + XCTAssertEqual((outbound.extensions.first as? PerMessageDeflateExtension)?.configuration.sendNoContextTakeover, true) - func testPerMessageExtensionOrdering() async throws { - let promise = TimeoutPromise(eventLoop: Self.eventLoopGroup.next(), timeout: .seconds(10)) + try await outbound.write(.binary(buffer)) + } + } - let buffer = self.createRandomBuffer(size: 4096, randomness: 10) - let app = try await self.setupClientAndServer( - serverExtensions: [.xor(), .perMessageDeflate()], - clientExtensions: [.xor(value: 34), .perMessageDeflate()], - onServer: { ws in - // XCTAssertEqual((ws.extensions.first as? PerMessageDeflateExtension)?.configuration.receiveNoContextTakeover, true) - let stream = ws.readStream() - Task { - for try await data in stream { - XCTAssertEqual(data, .binary(buffer)) - } - ws.onClose { _ in - promise.succeed() - } - } - }, - onClient: { ws in - // XCTAssertEqual((ws.extensions.first as? PerMessageDeflateExtension)?.configuration.sendNoContextTakeover, true) - try await ws.write(.binary(buffer)) - try await ws.close() - } - ) - defer { app.stop() } + func testPerMessageExtensionOrdering() async throws { + let buffer = self.createRandomBuffer(size: 4096, randomness: 10) + try await self.testClientAndServer( + serverExtensions: [.xor(), .perMessageDeflate(serverNoContextTakeover: true)], + clientExtensions: [.xor(value: 34), .perMessageDeflate()] + ) { inbound, _, _ in + for try await data in inbound { + XCTAssertEqual(data, .binary(buffer)) + } + } client: { _, outbound, _ in + try await outbound.write(.binary(buffer)) + } + } - try promise.wait() - }*/ + func testPerMessageDeflateWithRouter() async throws { + let router = Router(context: BasicWebSocketRequestContext.self) + router.ws("/test") { inbound, _, _ in + var iterator = inbound.makeAsyncIterator() + let firstMessage = await iterator.next() + XCTAssertEqual(firstMessage, .text("Hello, testing this is compressed")) + let secondMessage = await iterator.next() + XCTAssertEqual(secondMessage, .text("Hello")) + } + try await self.testClientAndServer( + serverChannel: .webSocketUpgrade(webSocketRouter: router, configuration: .init(extensions: [.perMessageDeflate()])), + clientExtensions: [.perMessageDeflate()] + ) { inbound, outbound, _ in + try await outbound.write(.text("Hello, testing this is compressed")) + try await outbound.write(.text("Hello")) + for try await _ in inbound {} + } + } } struct XorWebSocketExtension: WebSocketExtension { + let name = "xor" func shutdown() {} func xorFrame(_ frame: WebSocketFrame, context: some WebSocketContextProtocol) -> WebSocketFrame { @@ -427,11 +307,11 @@ struct XorWebSocketExtensionBuilder: WebSocketExtensionBuilder { return header } - func serverExtension(from request: WebSocketExtensionHTTPParameters, eventLoop: EventLoop) throws -> (WebSocketExtension)? { + func serverExtension(from request: WebSocketExtensionHTTPParameters) throws -> (WebSocketExtension)? { XorWebSocketExtension(value: UInt8(request.parameters["value"]?.integer ?? 255)) } - func clientExtension(from request: WebSocketExtensionHTTPParameters, eventLoop: EventLoop) throws -> (WebSocketExtension)? { + func clientExtension(from request: WebSocketExtensionHTTPParameters) throws -> (WebSocketExtension)? { XorWebSocketExtension(value: UInt8(request.parameters["value"]?.integer ?? 255)) } } From 054cfd1dd26ae537a9e38596b280b91d18b93ab4 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Tue, 26 Mar 2024 09:56:06 +0000 Subject: [PATCH 5/9] Update Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift Co-authored-by: Joannis Orlandos --- Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift b/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift index 11f2b45..f49070e 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift +++ b/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift @@ -91,7 +91,7 @@ public struct HTTP1WebSocketUpgradeChannel: ServerChildChannel, HTTPChannelHandl promise.completeWithTask { try await shouldUpgrade(head, channel, logger) .map { headers, handler in - let (headers, extensions) = try Self.webSocketExtensionNegociation( + let (headers, extensions) = try Self.webSocketExtensionNegotiation( extensionBuilders: configuration.extensions, requestHeaders: head.headerFields, responseHeaders: headers, From fc8d7c7fb270526cc7825bfcfb2e7ec99f7548d7 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Tue, 26 Mar 2024 09:56:12 +0000 Subject: [PATCH 6/9] Update Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift Co-authored-by: Joannis Orlandos --- Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift b/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift index f49070e..564b1b2 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift +++ b/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift @@ -54,7 +54,7 @@ public struct HTTP1WebSocketUpgradeChannel: ServerChildChannel, HTTPChannelHandl channel.eventLoop.makeCompletedFuture { () -> ShouldUpgradeResult in try shouldUpgrade(head, channel, logger) .map { headers, handler -> (HTTPFields, WebSocketChannelHandler) in - let (headers, extensions) = try Self.webSocketExtensionNegociation( + let (headers, extensions) = try Self.webSocketExtensionNegotiation( extensionBuilders: configuration.extensions, requestHeaders: head.headerFields, responseHeaders: headers, From df7b14c07f24e31d7f6044449d3239f3833ebf5d Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Tue, 26 Mar 2024 09:56:19 +0000 Subject: [PATCH 7/9] Update Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift Co-authored-by: Joannis Orlandos --- Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift b/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift index 564b1b2..65c4bd7 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift +++ b/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift @@ -210,7 +210,7 @@ public struct HTTP1WebSocketUpgradeChannel: ServerChildChannel, HTTPChannelHandl } } - /// WebSocket extension negociation + /// WebSocket extension negotiation /// - Parameters: /// - requestHeaders: Request headers /// - headers: Response headers From 5cce1b88eb30643ada7f7c76ca160f8e6303b09a Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Tue, 26 Mar 2024 09:56:25 +0000 Subject: [PATCH 8/9] Update Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift Co-authored-by: Joannis Orlandos --- Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift b/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift index 18fc503..ede01bb 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift +++ b/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift @@ -141,7 +141,7 @@ extension HTTP1WebSocketUpgradeChannel { do { let response = try await webSocketResponder.respond(to: request, context: context) if response.status == .ok, let webSocketHandler = context.webSocket.handler.withLockedValue({ $0 }) { - let (headers, extensions) = try Self.webSocketExtensionNegociation( + let (headers, extensions) = try Self.webSocketExtensionNegotiation( extensionBuilders: configuration.extensions, requestHeaders: head.headerFields, responseHeaders: response.headers, From 21e7aa95b7757aa259f708b1f1bcb353e49039bf Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Tue, 26 Mar 2024 09:56:32 +0000 Subject: [PATCH 9/9] Update Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift Co-authored-by: Joannis Orlandos --- Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift b/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift index 65c4bd7..79f3a3b 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift +++ b/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift @@ -216,7 +216,7 @@ public struct HTTP1WebSocketUpgradeChannel: ServerChildChannel, HTTPChannelHandl /// - headers: Response headers /// - logger: Logger /// - Returns: Response headers and extensions enabled - static func webSocketExtensionNegociation( + static func webSocketExtensionNegotiation( extensionBuilders: [any WebSocketExtensionBuilder], requestHeaders: HTTPFields, responseHeaders: HTTPFields,