From 9440c8b444ae2ba26329fc402112f0886995d22d Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Wed, 20 Mar 2024 12:11:58 +0000 Subject: [PATCH 1/7] WebSocketDataHandler is now a struct --- Package.swift | 2 +- .../Client/WebSocketClient.swift | 20 +-- .../Client/WebSocketClientChannel.swift | 18 +-- .../NIOWebSocketServerUpgrade+ext.swift | 12 +- .../Server/WebSocketChannel.swift | 116 ++++++++++-------- .../Server/WebSocketHTTPChannelBuilder.swift | 14 +-- .../Server/WebSocketRouter.swift | 80 +++++------- .../WebSocketDataHandler.swift | 55 +++------ .../WebSocketHandler.swift | 8 +- .../WebSocketTests.swift | 12 +- 10 files changed, 155 insertions(+), 182 deletions(-) diff --git a/Package.swift b/Package.swift index bda3e6a..911fe6e 100644 --- a/Package.swift +++ b/Package.swift @@ -11,7 +11,7 @@ let package = Package( // .library(name: "HummingbirdWSCompression", targets: ["HummingbirdWSCompression"]), ], dependencies: [ - .package(url: "https://github.com/hummingbird-project/hummingbird.git", from: "2.0.0-beta.1"), + .package(url: "https://github.com/hummingbird-project/hummingbird.git", branch: "main"), .package(url: "https://github.com/apple/swift-async-algorithms.git", from: "1.0.0"), .package(url: "https://github.com/apple/swift-atomics.git", from: "1.0.0"), .package(url: "https://github.com/apple/swift-http-types.git", from: "1.0.0"), diff --git a/Sources/HummingbirdWebSocket/Client/WebSocketClient.swift b/Sources/HummingbirdWebSocket/Client/WebSocketClient.swift index e5b3720..d34b61f 100644 --- a/Sources/HummingbirdWebSocket/Client/WebSocketClient.swift +++ b/Sources/HummingbirdWebSocket/Client/WebSocketClient.swift @@ -27,7 +27,7 @@ import ServiceLifecycle /// /// Supports TLS via both NIOSSL and Network framework. /// -/// Initialize the WebSocketClient with your handler and then call ``WebSocketClient/run`` +/// Initialize the WebSocketClient with your handler and then call ``WebSocketClient/run()`` /// to connect. The handler is provider with an `inbound` stream of WebSocket packets coming /// from the server and an `outbound` writer that can be used to write packets to the server. /// ```swift @@ -50,7 +50,7 @@ public struct WebSocketClient { /// WebSocket URL let url: URI /// WebSocket data handler - let handler: WebSocketDataCallbackHandler + let handler: WebSocketDataHandler.Handler /// configuration let configuration: WebSocketClientConfiguration /// EventLoopGroup to use @@ -75,10 +75,10 @@ public struct WebSocketClient { tlsConfiguration: TLSConfiguration? = nil, eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton, logger: Logger, - process: @escaping WebSocketDataCallbackHandler.Callback + handler: @escaping WebSocketDataHandler.Handler ) throws { self.url = url - self.handler = .init(process) + self.handler = handler self.configuration = configuration self.eventLoopGroup = eventLoopGroup self.logger = logger @@ -101,10 +101,10 @@ public struct WebSocketClient { transportServicesTLSOptions: TSTLSOptions, eventLoopGroup: NIOTSEventLoopGroup = NIOTSEventLoopGroup.singleton, logger: Logger, - process: @escaping WebSocketDataCallbackHandler.Callback + handler: @escaping WebSocketDataHandler.Handler ) throws { self.url = url - self.handler = .init(process) + self.handler = handler self.configuration = configuration self.eventLoopGroup = eventLoopGroup self.logger = logger @@ -193,7 +193,7 @@ extension WebSocketClient { tlsConfiguration: TLSConfiguration? = nil, eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton, logger: Logger, - process: @escaping WebSocketDataCallbackHandler.Callback + handler: @escaping WebSocketDataHandler.Handler ) async throws { let ws = try self.init( url: url, @@ -201,7 +201,7 @@ extension WebSocketClient { tlsConfiguration: tlsConfiguration, eventLoopGroup: eventLoopGroup, logger: logger, - process: process + handler: handler ) try await ws.run() } @@ -222,7 +222,7 @@ extension WebSocketClient { transportServicesTLSOptions: TSTLSOptions, eventLoopGroup: NIOTSEventLoopGroup = NIOTSEventLoopGroup.singleton, logger: Logger, - process: @escaping WebSocketDataCallbackHandler.Callback + handler: @escaping WebSocketDataHandler.Handler ) async throws { let ws = try self.init( url: url, @@ -230,7 +230,7 @@ extension WebSocketClient { transportServicesTLSOptions: transportServicesTLSOptions, eventLoopGroup: eventLoopGroup, logger: logger, - process: process + handler: handler ) try await ws.run() } diff --git a/Sources/HummingbirdWebSocket/Client/WebSocketClientChannel.swift b/Sources/HummingbirdWebSocket/Client/WebSocketClientChannel.swift index dfed093..3899b6e 100644 --- a/Sources/HummingbirdWebSocket/Client/WebSocketClientChannel.swift +++ b/Sources/HummingbirdWebSocket/Client/WebSocketClientChannel.swift @@ -20,25 +20,25 @@ import NIOHTTP1 import NIOHTTPTypesHTTP1 import NIOWebSocket -public struct WebSocketClientChannel: ClientConnectionChannel { - public enum UpgradeResult { +struct WebSocketClientChannel: ClientConnectionChannel { + enum UpgradeResult { case websocket(NIOAsyncChannel) case notUpgraded } - public typealias Value = EventLoopFuture + typealias Value = EventLoopFuture let url: String - let handler: Handler + let handler: WebSocketDataHandler.Handler let configuration: WebSocketClientConfiguration - init(handler: Handler, url: String, configuration: WebSocketClientConfiguration) { + init(handler: @escaping WebSocketDataHandler.Handler, url: String, configuration: WebSocketClientConfiguration) { self.url = url self.handler = handler self.configuration = configuration } - public func setup(channel: any Channel, logger: Logger) -> NIOCore.EventLoopFuture { + func setup(channel: any Channel, logger: Logger) -> NIOCore.EventLoopFuture { channel.eventLoop.makeCompletedFuture { let upgrader = NIOTypedWebSocketClientUpgrader( maxFrameSize: self.configuration.maxFrameSize, @@ -81,12 +81,12 @@ public struct WebSocketClientChannel: ClientConne } } - public func handle(value: Value, logger: Logger) async throws { + func handle(value: Value, logger: Logger) async throws { switch try await value.get() { case .websocket(let webSocketChannel): let webSocket = WebSocketHandler(asyncChannel: webSocketChannel, type: .client) - let context = self.handler.alreadySetupContext ?? .init(channel: webSocketChannel.channel, logger: logger) - await webSocket.handle(handler: self.handler, context: context) + let dataHandler = WebSocketDataHandler(context: .init(channel: webSocketChannel.channel, logger: logger), handler: self.handler) + await webSocket.handle(handler: dataHandler) case .notUpgraded: // The upgrade to websocket did not succeed. logger.debug("Upgrade declined") diff --git a/Sources/HummingbirdWebSocket/Server/NIOWebSocketServerUpgrade+ext.swift b/Sources/HummingbirdWebSocket/Server/NIOWebSocketServerUpgrade+ext.swift index 9994496..3ca339c 100644 --- a/Sources/HummingbirdWebSocket/Server/NIOWebSocketServerUpgrade+ext.swift +++ b/Sources/HummingbirdWebSocket/Server/NIOWebSocketServerUpgrade+ext.swift @@ -23,6 +23,16 @@ import NIOWebSocket public enum ShouldUpgradeResult: Sendable { case dontUpgrade case upgrade(HTTPFields, Value) + + /// Map upgrade result to difference type + func map(_ map: (Value) throws -> Result) rethrows -> ShouldUpgradeResult { + switch self { + case .dontUpgrade: + return .dontUpgrade + case .upgrade(let headers, let value): + return try .upgrade(headers, map(value)) + } + } } extension NIOTypedWebSocketServerUpgrader { @@ -47,7 +57,7 @@ extension NIOTypedWebSocketServerUpgrader { /// websocket protocol. This only needs to add the user handlers: the /// `WebSocketFrameEncoder` and `WebSocketFrameDecoder` will have been added to the /// pipeline automatically. - public convenience init( + convenience init( maxFrameSize: Int = 1 << 14, enableAutomaticErrorHandling: Bool = true, shouldUpgrade: @escaping @Sendable (Channel, HTTPRequest) -> EventLoopFuture>, diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift b/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift index 0f55d9b..bf142c3 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift +++ b/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift @@ -23,62 +23,15 @@ import NIOHTTPTypesHTTP1 import NIOWebSocket /// Child channel supporting a web socket upgrade from HTTP1 -public struct HTTP1AndWebSocketChannel: ServerChildChannel, HTTPChannelHandler { +public struct HTTP1AndWebSocketChannel: ServerChildChannel, HTTPChannelHandler { /// Upgrade result (either a websocket AsyncChannel, or an HTTP1 AsyncChannel) public enum UpgradeResult { - case websocket(NIOAsyncChannel, Handler) + case websocket(NIOAsyncChannel, WebSocketDataHandler) case notUpgraded(NIOAsyncChannel, failed: Bool) } public typealias Value = EventLoopFuture - - /// Initialize HTTP1AndWebSocketChannel with synchronous `shouldUpgrade` function - /// - Parameters: - /// - additionalChannelHandlers: Additional channel handlers to add - /// - responder: HTTP responder - /// - maxFrameSize: Max frame size WebSocket will allow - /// - shouldUpgrade: Function returning whether upgrade should be allowed - /// - Returns: Upgrade result future - public init( - responder: @escaping @Sendable (Request, Channel) async throws -> Response, - configuration: WebSocketServerConfiguration, - additionalChannelHandlers: @escaping @Sendable () -> [any RemovableChannelHandler] = { [] }, - shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) throws -> ShouldUpgradeResult - ) { - self.additionalChannelHandlers = additionalChannelHandlers - self.configuration = configuration - self.shouldUpgrade = { head, channel, logger in - channel.eventLoop.makeCompletedFuture { - try shouldUpgrade(head, channel, logger) - } - } - self.responder = responder - } - - /// Initialize HTTP1AndWebSocketChannel with async `shouldUpgrade` function - /// - Parameters: - /// - additionalChannelHandlers: Additional channel handlers to add - /// - responder: HTTP responder - /// - maxFrameSize: Max frame size WebSocket will allow - /// - shouldUpgrade: Function returning whether upgrade should be allowed - /// - Returns: Upgrade result future - public init( - responder: @escaping @Sendable (Request, Channel) async throws -> Response, - configuration: WebSocketServerConfiguration, - additionalChannelHandlers: @escaping @Sendable () -> [any RemovableChannelHandler] = { [] }, - shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) async throws -> ShouldUpgradeResult - ) { - self.additionalChannelHandlers = additionalChannelHandlers - self.configuration = configuration - self.shouldUpgrade = { head, channel, logger in - let promise = channel.eventLoop.makePromise(of: ShouldUpgradeResult.self) - promise.completeWithTask { - try await shouldUpgrade(head, channel, logger) - } - return promise.futureResult - } - self.responder = responder - } + public typealias Handler = WebSocketDataHandler.Handler /// Setup channel to accept HTTP1 with a WebSocket upgrade /// - Parameters: @@ -142,8 +95,7 @@ public struct HTTP1AndWebSocketChannel: ServerChi } case .websocket(let asyncChannel, let handler): let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server) - let context = handler.alreadySetupContext ?? .init(channel: asyncChannel.channel, logger: logger) - await webSocket.handle(handler: handler, context: context) + await webSocket.handle(handler: handler) } } catch { logger.error("Error handling upgrade result: \(error)") @@ -177,7 +129,65 @@ public struct HTTP1AndWebSocketChannel: ServerChi } public var responder: @Sendable (Request, Channel) async throws -> Response - let shouldUpgrade: @Sendable (HTTPRequest, Channel, Logger) -> EventLoopFuture> + let shouldUpgrade: @Sendable (HTTPRequest, Channel, Logger) -> EventLoopFuture>> let configuration: WebSocketServerConfiguration let additionalChannelHandlers: @Sendable () -> [any RemovableChannelHandler] } + +extension HTTP1AndWebSocketChannel where Context == WebSocketContext { + /// Initialize HTTP1AndWebSocketChannel with synchronous `shouldUpgrade` function + /// - Parameters: + /// - additionalChannelHandlers: Additional channel handlers to add + /// - responder: HTTP responder + /// - maxFrameSize: Max frame size WebSocket will allow + /// - shouldUpgrade: Function returning whether upgrade should be allowed + /// - Returns: Upgrade result future + public init( + responder: @escaping @Sendable (Request, Channel) async throws -> Response, + configuration: WebSocketServerConfiguration, + additionalChannelHandlers: @escaping @Sendable () -> [any RemovableChannelHandler] = { [] }, + shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) throws -> ShouldUpgradeResult + ) { + self.additionalChannelHandlers = additionalChannelHandlers + self.configuration = configuration + self.shouldUpgrade = { head, channel, logger in + channel.eventLoop.makeCompletedFuture { () -> ShouldUpgradeResult> in + try shouldUpgrade(head, channel, logger) + .map { + let context = WebSocketContext(channel: channel, logger: logger) + return WebSocketDataHandler(context: context, handler: $0) + } + } + } + self.responder = responder + } + + /// Initialize HTTP1AndWebSocketChannel with async `shouldUpgrade` function + /// - Parameters: + /// - additionalChannelHandlers: Additional channel handlers to add + /// - responder: HTTP responder + /// - maxFrameSize: Max frame size WebSocket will allow + /// - shouldUpgrade: Function returning whether upgrade should be allowed + /// - Returns: Upgrade result future + public init( + responder: @escaping @Sendable (Request, Channel) async throws -> Response, + configuration: WebSocketServerConfiguration, + additionalChannelHandlers: @escaping @Sendable () -> [any RemovableChannelHandler] = { [] }, + shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) async throws -> ShouldUpgradeResult + ) { + self.additionalChannelHandlers = additionalChannelHandlers + self.configuration = configuration + self.shouldUpgrade = { head, channel, logger in + let promise = channel.eventLoop.makePromise(of: ShouldUpgradeResult>.self) + promise.completeWithTask { + try await shouldUpgrade(head, channel, logger) + .map { + let context = WebSocketContext(channel: channel, logger: logger) + return WebSocketDataHandler(context: context, handler: $0) + } + } + return promise.futureResult + } + self.responder = responder + } +} diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketHTTPChannelBuilder.swift b/Sources/HummingbirdWebSocket/Server/WebSocketHTTPChannelBuilder.swift index c02d4f2..1085114 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketHTTPChannelBuilder.swift +++ b/Sources/HummingbirdWebSocket/Server/WebSocketHTTPChannelBuilder.swift @@ -20,11 +20,11 @@ import NIOCore extension HTTPChannelBuilder { /// HTTP1 channel builder supporting a websocket upgrade /// - parameters - public static func webSocketUpgrade( + public static func webSocketUpgrade( configuration: WebSocketServerConfiguration = .init(), additionalChannelHandlers: @autoclosure @escaping @Sendable () -> [any RemovableChannelHandler] = [], - shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) async throws -> ShouldUpgradeResult - ) -> HTTPChannelBuilder> { + shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) async throws -> ShouldUpgradeResult.Handler> + ) -> HTTPChannelBuilder> { return .init { responder in return HTTP1AndWebSocketChannel( responder: responder, @@ -36,13 +36,13 @@ extension HTTPChannelBuilder { } /// HTTP1 channel builder supporting a websocket upgrade - public static func webSocketUpgrade( + public static func webSocketUpgrade( configuration: WebSocketServerConfiguration = .init(), additionalChannelHandlers: @autoclosure @escaping @Sendable () -> [any RemovableChannelHandler] = [], - shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) throws -> ShouldUpgradeResult - ) -> HTTPChannelBuilder> { + shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) throws -> ShouldUpgradeResult.Handler> + ) -> HTTPChannelBuilder> { return .init { responder in - return HTTP1AndWebSocketChannel( + return HTTP1AndWebSocketChannel( responder: responder, configuration: configuration, additionalChannelHandlers: additionalChannelHandlers, diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift b/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift index eee133a..4067a3c 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift +++ b/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift @@ -23,23 +23,23 @@ import NIOCore /// WebSocket Router context type. /// /// Includes reference to optional websocket handler -public struct WebSocketRouterContext: Sendable { +public struct WebSocketRouterContext: Sendable { public init() { self.handler = .init(nil) } - let handler: NIOLockedValueBox + let handler: NIOLockedValueBox?> } /// Request context protocol requirement for routers that support websockets public protocol WebSocketRequestContext: RequestContext, WebSocketContextProtocol { - var webSocket: WebSocketRouterContext { get } + var webSocket: WebSocketRouterContext { get } } /// Default implementation of a request context that supports WebSockets public struct BasicWebSocketRequestContext: WebSocketRequestContext { public var coreContext: CoreRequestContext - public let webSocket: WebSocketRouterContext + public let webSocket: WebSocketRouterContext public init(channel: Channel, logger: Logger) { self.coreContext = .init(allocator: channel.allocator, logger: logger) @@ -63,7 +63,7 @@ extension RouterMethods { @discardableResult public func ws( _ path: String = "", shouldUpgrade: @Sendable @escaping (Request, Context) async throws -> RouterShouldUpgrade = { _, _ in .upgrade([:]) }, - handle: @escaping WebSocketDataCallbackHandler.Callback + handle: @escaping WebSocketDataHandler.Handler ) -> Self where Context: WebSocketRequestContext { return on(path, method: .get) { request, context -> Response in let result = try await shouldUpgrade(request, context) @@ -71,7 +71,7 @@ extension RouterMethods { case .dontUpgrade: return .init(status: .methodNotAllowed) case .upgrade(let headers): - context.webSocket.handler.withLockedValue { $0 = WebSocketDataCallbackHandler(handle) } + context.webSocket.handler.withLockedValue { $0 = WebSocketDataHandler(context: context, handler: handle) } return .init(status: .ok, headers: headers) } } @@ -84,7 +84,7 @@ extension RouterMethods { /// with ``Hummingbird/Router`` if you add a route immediately after it. public struct WebSocketUpgradeMiddleware: RouterMiddleware { let shouldUpgrade: @Sendable (Request, Context) async throws -> RouterShouldUpgrade - let handle: WebSocketDataCallbackHandler.Callback + let handler: WebSocketDataHandler.Handler /// Initialize WebSocketUpgradeMiddleare /// - Parameters: @@ -92,10 +92,10 @@ public struct WebSocketUpgradeMiddleware: Rout /// - handle: WebSocket handler public init( shouldUpgrade: @Sendable @escaping (Request, Context) async throws -> RouterShouldUpgrade = { _, _ in .upgrade([:]) }, - handle: @escaping WebSocketDataCallbackHandler.Callback + handler: @escaping WebSocketDataHandler.Handler ) { self.shouldUpgrade = shouldUpgrade - self.handle = handle + self.handler = handler } /// WebSocketUpgradeMiddleware handler @@ -105,13 +105,13 @@ public struct WebSocketUpgradeMiddleware: Rout case .dontUpgrade: return .init(status: .methodNotAllowed) case .upgrade(let headers): - context.webSocket.handler.withLockedValue { $0 = WebSocketDataCallbackHandler(self.handle) } + context.webSocket.handler.withLockedValue { $0 = .init(context: context, handler: self.handler) } return .init(status: .ok, headers: headers) } } } -extension HTTP1AndWebSocketChannel { +extension HTTP1AndWebSocketChannel where Context: WebSocketRequestContext { /// Initialize HTTP1AndWebSocketChannel with async `shouldUpgrade` function /// - Parameters: /// - additionalChannelHandlers: Additional channel handlers to add @@ -119,26 +119,33 @@ extension HTTP1AndWebSocketChannel { /// - maxFrameSize: Max frame size WebSocket will allow /// - webSocketRouter: WebSocket router /// - Returns: Upgrade result future - public init( + public init( responder: @escaping @Sendable (Request, Channel) async throws -> Response, webSocketResponder: WSResponder, configuration: WebSocketServerConfiguration, additionalChannelHandlers: @escaping @Sendable () -> [any RemovableChannelHandler] = { [] } - ) where Handler == WebSocketDataCallbackHandler, WSResponder.Context == Context { - self.init(responder: responder, configuration: configuration, additionalChannelHandlers: additionalChannelHandlers) { head, channel, logger in - let request = Request(head: head, body: .init(buffer: .init())) - let context = Context(channel: channel, logger: logger.with(metadataKey: "hb_id", value: .stringConvertible(RequestID()))) - do { - let response = try await webSocketResponder.respond(to: request, context: context) - if response.status == .ok, let webSocketHandler = context.webSocket.handler.withLockedValue({ $0 }) { - return .upgrade(response.headers, webSocketHandler) - } else { + ) where WSResponder.Context == Context { + self.additionalChannelHandlers = additionalChannelHandlers + self.configuration = configuration + self.shouldUpgrade = { head, channel, logger in + let promise = channel.eventLoop.makePromise(of: ShouldUpgradeResult>.self) + promise.completeWithTask { + let request = Request(head: head, body: .init(buffer: .init())) + let context = Context(channel: channel, logger: logger.with(metadataKey: "hb_id", value: .stringConvertible(RequestID()))) + do { + let response = try await webSocketResponder.respond(to: request, context: context) + if response.status == .ok, let webSocketHandler = context.webSocket.handler.withLockedValue({ $0 }) { + return .upgrade(response.headers, webSocketHandler) + } else { + return .dontUpgrade + } + } catch { return .dontUpgrade } - } catch { - return .dontUpgrade } + return promise.futureResult } + self.responder = responder } } @@ -158,7 +165,7 @@ extension HTTPChannelBuilder { webSocketRouter: WSResponderBuilder, configuration: WebSocketServerConfiguration = .init(), additionalChannelHandlers: @autoclosure @escaping @Sendable () -> [any RemovableChannelHandler] = [] - ) -> HTTPChannelBuilder> where WSResponderBuilder.Responder.Context: WebSocketRequestContext { + ) -> HTTPChannelBuilder> where WSResponderBuilder.Responder.Context: WebSocketRequestContext { let webSocketReponder = webSocketRouter.buildResponder() return .init { responder in return HTTP1AndWebSocketChannel( @@ -183,28 +190,3 @@ extension Logger { return logger } } - -/// Generate Unique ID for each request. This is a duplicate of the RequestID in Hummingbird -package struct RequestID: CustomStringConvertible { - let low: UInt64 - - package init() { - self.low = Self.globalRequestID.loadThenWrappingIncrement(by: 1, ordering: .relaxed) - } - - package var description: String { - Self.high + self.formatAsHexWithLeadingZeros(self.low) - } - - func formatAsHexWithLeadingZeros(_ value: UInt64) -> String { - let string = String(value, radix: 16) - if string.count < 16 { - return String(repeating: "0", count: 16 - string.count) + string - } else { - return string - } - } - - private static let high = String(UInt64.random(in: .min ... .max), radix: 16) - private static let globalRequestID = ManagedAtomic(UInt64.random(in: .min ... .max)) -} diff --git a/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift b/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift index 9c3321a..6b00522 100644 --- a/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift +++ b/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift @@ -14,6 +14,7 @@ import AsyncAlgorithms import HTTPTypes +import Logging import NIOCore import NIOWebSocket @@ -22,48 +23,20 @@ import NIOWebSocket /// This is the users interface into HummingbirdWebSocket. They provide an implementation of this protocol when /// contructing their WebSocket upgrade handler. The user needs to return a type conforming to this protocol in /// the `shouldUpgrade` closure in HTTP1AndWebSocketChannel.init -public protocol WebSocketDataHandler: Sendable { - /// Context type supplied to the handle function. - /// - /// The `WebSocketDataHandler` can chose to setup a context or accept the default one from - /// ``WebSocketHandler``. - associatedtype Context: WebSocketContextProtocol = WebSocketContext - /// If a `WebSocketDataHandler` requires a context with custom data it should - /// setup this variable on initialization - var alreadySetupContext: Context? { get } - /// Handler WebSocket data packets - /// - Parameters: - /// - inbound: An AsyncSequence of text or binary WebSocket frames. - /// - outbound: An outbound Writer to write websocket frames to - /// - context: Associated context to this websocket channel - func handle(_ inbound: WebSocketHandlerInbound, _ outbound: WebSocketHandlerOutboundWriter, context: Context) async throws -} - -extension WebSocketDataHandler { - /// Default implementaion of ``alreadySetupContext`` returns nil, so the Context will be - /// created by the ``WebSocketChannelHandler`` - public var alreadySetupContext: Context? { nil } -} - -/// WebSocketDataHandler that is is initialized via a callback -public struct WebSocketDataCallbackHandler: WebSocketDataHandler { - public typealias Callback = @Sendable (WebSocketHandlerInbound, WebSocketHandlerOutboundWriter, WebSocketContext) async throws -> Void - - let callback: Callback - - public init(_ callback: @escaping Callback) { - self.callback = callback - } - - /// Handler WebSocket data packets by passing directly to the callback - public func handle(_ inbound: WebSocketHandlerInbound, _ outbound: WebSocketHandlerOutboundWriter, context: WebSocketContext) async throws { - try await self.callback(inbound, outbound, context) +public struct WebSocketDataHandler: Sendable { + /// Handler closure type + public typealias Handler = @Sendable (WebSocketHandlerInbound, WebSocketHandlerOutboundWriter, Context) async throws -> Void + /// Context sent to handler + let context: Context + /// handler function + let handler: Handler + + public init(context: Context, handler: @escaping Handler) { + self.context = context + self.handler = handler } -} -extension ShouldUpgradeResult where Value == WebSocketDataCallbackHandler { - /// Extension to ShouldUpgradeResult that takes just a callback - public static func upgrade(_ headers: HTTPFields, _ callback: @escaping WebSocketDataCallbackHandler.Callback) -> Self { - .upgrade(headers, WebSocketDataCallbackHandler(callback)) + func withContext(channel: Channel, logger: Logger) -> Self { + .init(context: .init(channel: channel, logger: logger), handler: self.handler) } } diff --git a/Sources/HummingbirdWebSocket/WebSocketHandler.swift b/Sources/HummingbirdWebSocket/WebSocketHandler.swift index 13df27c..df8c602 100644 --- a/Sources/HummingbirdWebSocket/WebSocketHandler.swift +++ b/Sources/HummingbirdWebSocket/WebSocketHandler.swift @@ -41,10 +41,8 @@ actor WebSocketHandler: Sendable { } /// Handle WebSocket AsynChannel - func handle( - handler: Handler, - context: Handler.Context - ) async { + func handle(handler: WebSocketDataHandler) async { + let context = handler.context try? await self.asyncChannel.executeThenClose { inbound, outbound in do { try await withThrowingTaskGroup(of: Void.self) { group in @@ -90,7 +88,7 @@ actor WebSocketHandler: Sendable { } group.addTask { // handle websocket data and text - try await handler.handle(webSocketHandlerInbound, webSocketHandlerOutbound, context: context) + try await handler.handler(webSocketHandlerInbound, webSocketHandlerOutbound, context) try await self.close(code: .normalClosure, outbound: outbound, context: context) } try await group.next() diff --git a/Tests/HummingbirdWebSocketTests/WebSocketTests.swift b/Tests/HummingbirdWebSocketTests/WebSocketTests.swift index e286221..a6c166d 100644 --- a/Tests/HummingbirdWebSocketTests/WebSocketTests.swift +++ b/Tests/HummingbirdWebSocketTests/WebSocketTests.swift @@ -76,7 +76,7 @@ final class HummingbirdWebSocketTests: XCTestCase { func testClientAndServer( serverTLSConfiguration: TLSConfiguration? = nil, - server serverHandler: @escaping WebSocketDataCallbackHandler.Callback, + server serverHandler: @escaping WebSocketDataHandler.Handler, shouldUpgrade: @escaping @Sendable (HTTPRequest) throws -> HTTPFields? = { _ in return [:] }, getClient: @escaping @Sendable (Int, Logger) throws -> WebSocketClient ) async throws { @@ -91,7 +91,7 @@ final class HummingbirdWebSocketTests: XCTestCase { let serviceGroup: ServiceGroup let webSocketUpgrade: HTTPChannelBuilder = .webSocketUpgrade { head, _, _ in if let headers = try shouldUpgrade(head) { - return .upgrade(headers, WebSocketDataCallbackHandler(serverHandler)) + return .upgrade(headers, serverHandler) } else { return .dontUpgrade } @@ -144,9 +144,9 @@ final class HummingbirdWebSocketTests: XCTestCase { func testClientAndServer( serverTLSConfiguration: TLSConfiguration? = nil, - server serverHandler: @escaping WebSocketDataCallbackHandler.Callback, + server serverHandler: @escaping WebSocketDataHandler.Handler, shouldUpgrade: @escaping @Sendable (HTTPRequest) throws -> HTTPFields? = { _ in return [:] }, - client clientHandler: @escaping WebSocketDataCallbackHandler.Callback + client clientHandler: @escaping WebSocketDataHandler.Handler ) async throws { try await self.testClientAndServer( serverTLSConfiguration: serverTLSConfiguration, @@ -156,7 +156,7 @@ final class HummingbirdWebSocketTests: XCTestCase { try WebSocketClient( url: .init("ws://localhost:\(port)"), logger: logger, - process: clientHandler + handler: clientHandler ) } ) @@ -422,7 +422,7 @@ final class HummingbirdWebSocketTests: XCTestCase { router.group("/ws") .add(middleware: WebSocketUpgradeMiddleware { _, _ in return .upgrade([:]) - } handle: { _, outbound, _ in + } handler: { _, outbound, _ in try await outbound.write(.text("One")) }) .get { _, _ -> Response in return .init(status: .ok) } From c3275b42038ea9829af994905ed258364634de28 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Wed, 20 Mar 2024 12:41:51 +0000 Subject: [PATCH 2/7] Renaming --- .../Client/WebSocketClient.swift | 10 +++++----- .../Client/WebSocketClientChannel.swift | 7 +++---- .../Server/WebSocketChannel.swift | 19 +++++++++++-------- .../Server/WebSocketHTTPChannelBuilder.swift | 4 ++-- .../Server/WebSocketRouter.swift | 12 ++++++------ .../WebSocketDataHandler.swift | 17 +++++++---------- .../WebSocketHandler.swift | 5 ++--- .../WebSocketTests.swift | 6 +++--- 8 files changed, 39 insertions(+), 41 deletions(-) diff --git a/Sources/HummingbirdWebSocket/Client/WebSocketClient.swift b/Sources/HummingbirdWebSocket/Client/WebSocketClient.swift index d34b61f..99aeb8a 100644 --- a/Sources/HummingbirdWebSocket/Client/WebSocketClient.swift +++ b/Sources/HummingbirdWebSocket/Client/WebSocketClient.swift @@ -50,7 +50,7 @@ public struct WebSocketClient { /// WebSocket URL let url: URI /// WebSocket data handler - let handler: WebSocketDataHandler.Handler + let handler: WebSocketDataHandler /// configuration let configuration: WebSocketClientConfiguration /// EventLoopGroup to use @@ -75,7 +75,7 @@ public struct WebSocketClient { tlsConfiguration: TLSConfiguration? = nil, eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton, logger: Logger, - handler: @escaping WebSocketDataHandler.Handler + handler: @escaping WebSocketDataHandler ) throws { self.url = url self.handler = handler @@ -101,7 +101,7 @@ public struct WebSocketClient { transportServicesTLSOptions: TSTLSOptions, eventLoopGroup: NIOTSEventLoopGroup = NIOTSEventLoopGroup.singleton, logger: Logger, - handler: @escaping WebSocketDataHandler.Handler + handler: @escaping WebSocketDataHandler ) throws { self.url = url self.handler = handler @@ -193,7 +193,7 @@ extension WebSocketClient { tlsConfiguration: TLSConfiguration? = nil, eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton, logger: Logger, - handler: @escaping WebSocketDataHandler.Handler + handler: @escaping WebSocketDataHandler ) async throws { let ws = try self.init( url: url, @@ -222,7 +222,7 @@ extension WebSocketClient { transportServicesTLSOptions: TSTLSOptions, eventLoopGroup: NIOTSEventLoopGroup = NIOTSEventLoopGroup.singleton, logger: Logger, - handler: @escaping WebSocketDataHandler.Handler + handler: @escaping WebSocketDataHandler ) async throws { let ws = try self.init( url: url, diff --git a/Sources/HummingbirdWebSocket/Client/WebSocketClientChannel.swift b/Sources/HummingbirdWebSocket/Client/WebSocketClientChannel.swift index 3899b6e..20d9330 100644 --- a/Sources/HummingbirdWebSocket/Client/WebSocketClientChannel.swift +++ b/Sources/HummingbirdWebSocket/Client/WebSocketClientChannel.swift @@ -29,10 +29,10 @@ struct WebSocketClientChannel: ClientConnectionChannel { typealias Value = EventLoopFuture let url: String - let handler: WebSocketDataHandler.Handler + let handler: WebSocketDataHandler let configuration: WebSocketClientConfiguration - init(handler: @escaping WebSocketDataHandler.Handler, url: String, configuration: WebSocketClientConfiguration) { + init(handler: @escaping WebSocketDataHandler, url: String, configuration: WebSocketClientConfiguration) { self.url = url self.handler = handler self.configuration = configuration @@ -85,8 +85,7 @@ struct WebSocketClientChannel: ClientConnectionChannel { switch try await value.get() { case .websocket(let webSocketChannel): let webSocket = WebSocketHandler(asyncChannel: webSocketChannel, type: .client) - let dataHandler = WebSocketDataHandler(context: .init(channel: webSocketChannel.channel, logger: logger), handler: self.handler) - await webSocket.handle(handler: dataHandler) + await webSocket.handle(handler: self.handler, context: WebSocketContext(channel: webSocketChannel.channel, logger: logger)) case .notUpgraded: // The upgrade to websocket did not succeed. logger.debug("Upgrade declined") diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift b/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift index bf142c3..c63de92 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift +++ b/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// import HTTPTypes +import Hummingbird import HummingbirdCore import Logging import NIOConcurrencyHelpers @@ -26,12 +27,12 @@ import NIOWebSocket public struct HTTP1AndWebSocketChannel: ServerChildChannel, HTTPChannelHandler { /// Upgrade result (either a websocket AsyncChannel, or an HTTP1 AsyncChannel) public enum UpgradeResult { - case websocket(NIOAsyncChannel, WebSocketDataHandler) + case websocket(NIOAsyncChannel, WebSocketDataHandlerAndContext) case notUpgraded(NIOAsyncChannel, failed: Bool) } public typealias Value = EventLoopFuture - public typealias Handler = WebSocketDataHandler.Handler + public typealias Handler = WebSocketDataHandler /// Setup channel to accept HTTP1 with a WebSocket upgrade /// - Parameters: @@ -95,7 +96,7 @@ public struct HTTP1AndWebSocketChannel: Serve } case .websocket(let asyncChannel, let handler): let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server) - await webSocket.handle(handler: handler) + await webSocket.handle(handler: handler.handler, context: handler.context) } } catch { logger.error("Error handling upgrade result: \(error)") @@ -129,7 +130,7 @@ public struct HTTP1AndWebSocketChannel: Serve } public var responder: @Sendable (Request, Channel) async throws -> Response - let shouldUpgrade: @Sendable (HTTPRequest, Channel, Logger) -> EventLoopFuture>> + let shouldUpgrade: @Sendable (HTTPRequest, Channel, Logger) -> EventLoopFuture>> let configuration: WebSocketServerConfiguration let additionalChannelHandlers: @Sendable () -> [any RemovableChannelHandler] } @@ -151,11 +152,12 @@ extension HTTP1AndWebSocketChannel where Context == WebSocketContext { self.additionalChannelHandlers = additionalChannelHandlers self.configuration = configuration self.shouldUpgrade = { head, channel, logger in - channel.eventLoop.makeCompletedFuture { () -> ShouldUpgradeResult> in + channel.eventLoop.makeCompletedFuture { () -> ShouldUpgradeResult> in try shouldUpgrade(head, channel, logger) .map { + let logger = logger.with(metadataKey: "hb_id", value: .stringConvertible(RequestID())) let context = WebSocketContext(channel: channel, logger: logger) - return WebSocketDataHandler(context: context, handler: $0) + return WebSocketDataHandlerAndContext(context: context, handler: $0) } } } @@ -178,12 +180,13 @@ extension HTTP1AndWebSocketChannel where Context == WebSocketContext { self.additionalChannelHandlers = additionalChannelHandlers self.configuration = configuration self.shouldUpgrade = { head, channel, logger in - let promise = channel.eventLoop.makePromise(of: ShouldUpgradeResult>.self) + let promise = channel.eventLoop.makePromise(of: ShouldUpgradeResult>.self) promise.completeWithTask { try await shouldUpgrade(head, channel, logger) .map { + let logger = logger.with(metadataKey: "hb_id", value: .stringConvertible(RequestID())) let context = WebSocketContext(channel: channel, logger: logger) - return WebSocketDataHandler(context: context, handler: $0) + return WebSocketDataHandlerAndContext(context: context, handler: $0) } } return promise.futureResult diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketHTTPChannelBuilder.swift b/Sources/HummingbirdWebSocket/Server/WebSocketHTTPChannelBuilder.swift index 1085114..47227db 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketHTTPChannelBuilder.swift +++ b/Sources/HummingbirdWebSocket/Server/WebSocketHTTPChannelBuilder.swift @@ -23,7 +23,7 @@ extension HTTPChannelBuilder { public static func webSocketUpgrade( configuration: WebSocketServerConfiguration = .init(), additionalChannelHandlers: @autoclosure @escaping @Sendable () -> [any RemovableChannelHandler] = [], - shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) async throws -> ShouldUpgradeResult.Handler> + shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) async throws -> ShouldUpgradeResult> ) -> HTTPChannelBuilder> { return .init { responder in return HTTP1AndWebSocketChannel( @@ -39,7 +39,7 @@ extension HTTPChannelBuilder { public static func webSocketUpgrade( configuration: WebSocketServerConfiguration = .init(), additionalChannelHandlers: @autoclosure @escaping @Sendable () -> [any RemovableChannelHandler] = [], - shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) throws -> ShouldUpgradeResult.Handler> + shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) throws -> ShouldUpgradeResult> ) -> HTTPChannelBuilder> { return .init { responder in return HTTP1AndWebSocketChannel( diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift b/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift index 4067a3c..85a6295 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift +++ b/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift @@ -28,7 +28,7 @@ public struct WebSocketRouterContext: Sendable self.handler = .init(nil) } - let handler: NIOLockedValueBox?> + let handler: NIOLockedValueBox?> } /// Request context protocol requirement for routers that support websockets @@ -63,7 +63,7 @@ extension RouterMethods { @discardableResult public func ws( _ path: String = "", shouldUpgrade: @Sendable @escaping (Request, Context) async throws -> RouterShouldUpgrade = { _, _ in .upgrade([:]) }, - handle: @escaping WebSocketDataHandler.Handler + handle: @escaping WebSocketDataHandler ) -> Self where Context: WebSocketRequestContext { return on(path, method: .get) { request, context -> Response in let result = try await shouldUpgrade(request, context) @@ -71,7 +71,7 @@ extension RouterMethods { case .dontUpgrade: return .init(status: .methodNotAllowed) case .upgrade(let headers): - context.webSocket.handler.withLockedValue { $0 = WebSocketDataHandler(context: context, handler: handle) } + context.webSocket.handler.withLockedValue { $0 = WebSocketDataHandlerAndContext(context: context, handler: handle) } return .init(status: .ok, headers: headers) } } @@ -84,7 +84,7 @@ extension RouterMethods { /// with ``Hummingbird/Router`` if you add a route immediately after it. public struct WebSocketUpgradeMiddleware: RouterMiddleware { let shouldUpgrade: @Sendable (Request, Context) async throws -> RouterShouldUpgrade - let handler: WebSocketDataHandler.Handler + let handler: WebSocketDataHandler /// Initialize WebSocketUpgradeMiddleare /// - Parameters: @@ -92,7 +92,7 @@ public struct WebSocketUpgradeMiddleware: Rout /// - handle: WebSocket handler public init( shouldUpgrade: @Sendable @escaping (Request, Context) async throws -> RouterShouldUpgrade = { _, _ in .upgrade([:]) }, - handler: @escaping WebSocketDataHandler.Handler + handler: @escaping WebSocketDataHandler ) { self.shouldUpgrade = shouldUpgrade self.handler = handler @@ -128,7 +128,7 @@ extension HTTP1AndWebSocketChannel where Context: WebSocketRequestContext { self.additionalChannelHandlers = additionalChannelHandlers self.configuration = configuration self.shouldUpgrade = { head, channel, logger in - let promise = channel.eventLoop.makePromise(of: ShouldUpgradeResult>.self) + let promise = channel.eventLoop.makePromise(of: ShouldUpgradeResult>.self) promise.completeWithTask { let request = Request(head: head, body: .init(buffer: .init())) let context = Context(channel: channel, logger: logger.with(metadataKey: "hb_id", value: .stringConvertible(RequestID()))) diff --git a/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift b/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift index 6b00522..ac3d54b 100644 --- a/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift +++ b/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift @@ -18,20 +18,17 @@ import Logging import NIOCore import NIOWebSocket -/// Protocol for web socket data handling -/// -/// This is the users interface into HummingbirdWebSocket. They provide an implementation of this protocol when -/// contructing their WebSocket upgrade handler. The user needs to return a type conforming to this protocol in -/// the `shouldUpgrade` closure in HTTP1AndWebSocketChannel.init -public struct WebSocketDataHandler: Sendable { - /// Handler closure type - public typealias Handler = @Sendable (WebSocketHandlerInbound, WebSocketHandlerOutboundWriter, Context) async throws -> Void +/// Handle websocket data and text blocks +public typealias WebSocketDataHandler = @Sendable (WebSocketHandlerInbound, WebSocketHandlerOutboundWriter, Context) async throws -> Void + +/// Struct holding for web socket data handler and context. +public struct WebSocketDataHandlerAndContext: Sendable { /// Context sent to handler let context: Context /// handler function - let handler: Handler + let handler: WebSocketDataHandler - public init(context: Context, handler: @escaping Handler) { + public init(context: Context, handler: @escaping WebSocketDataHandler) { self.context = context self.handler = handler } diff --git a/Sources/HummingbirdWebSocket/WebSocketHandler.swift b/Sources/HummingbirdWebSocket/WebSocketHandler.swift index df8c602..6bcac5e 100644 --- a/Sources/HummingbirdWebSocket/WebSocketHandler.swift +++ b/Sources/HummingbirdWebSocket/WebSocketHandler.swift @@ -41,8 +41,7 @@ actor WebSocketHandler: Sendable { } /// Handle WebSocket AsynChannel - func handle(handler: WebSocketDataHandler) async { - let context = handler.context + func handle(handler: @escaping WebSocketDataHandler, context: Context) async { try? await self.asyncChannel.executeThenClose { inbound, outbound in do { try await withThrowingTaskGroup(of: Void.self) { group in @@ -88,7 +87,7 @@ actor WebSocketHandler: Sendable { } group.addTask { // handle websocket data and text - try await handler.handler(webSocketHandlerInbound, webSocketHandlerOutbound, context) + try await handler(webSocketHandlerInbound, webSocketHandlerOutbound, context) try await self.close(code: .normalClosure, outbound: outbound, context: context) } try await group.next() diff --git a/Tests/HummingbirdWebSocketTests/WebSocketTests.swift b/Tests/HummingbirdWebSocketTests/WebSocketTests.swift index a6c166d..4f436f6 100644 --- a/Tests/HummingbirdWebSocketTests/WebSocketTests.swift +++ b/Tests/HummingbirdWebSocketTests/WebSocketTests.swift @@ -76,7 +76,7 @@ final class HummingbirdWebSocketTests: XCTestCase { func testClientAndServer( serverTLSConfiguration: TLSConfiguration? = nil, - server serverHandler: @escaping WebSocketDataHandler.Handler, + server serverHandler: @escaping WebSocketDataHandler, shouldUpgrade: @escaping @Sendable (HTTPRequest) throws -> HTTPFields? = { _ in return [:] }, getClient: @escaping @Sendable (Int, Logger) throws -> WebSocketClient ) async throws { @@ -144,9 +144,9 @@ final class HummingbirdWebSocketTests: XCTestCase { func testClientAndServer( serverTLSConfiguration: TLSConfiguration? = nil, - server serverHandler: @escaping WebSocketDataHandler.Handler, + server serverHandler: @escaping WebSocketDataHandler, shouldUpgrade: @escaping @Sendable (HTTPRequest) throws -> HTTPFields? = { _ in return [:] }, - client clientHandler: @escaping WebSocketDataHandler.Handler + client clientHandler: @escaping WebSocketDataHandler ) async throws { try await self.testClientAndServer( serverTLSConfiguration: serverTLSConfiguration, From daf2faffca71bf12011256428b0510b57dc75d67 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Wed, 20 Mar 2024 14:34:53 +0000 Subject: [PATCH 3/7] Remove generic parameter from HTTP1AndWebSocketChannel --- .../Server/WebSocketChannel.swift | 69 +++++++++++-------- .../Server/WebSocketHTTPChannelBuilder.swift | 4 +- .../Server/WebSocketRouter.swift | 25 ++++--- .../WebSocketDataHandler.swift | 19 +---- 4 files changed, 60 insertions(+), 57 deletions(-) diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift b/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift index c63de92..2ee05de 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift +++ b/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift @@ -24,15 +24,16 @@ import NIOHTTPTypesHTTP1 import NIOWebSocket /// Child channel supporting a web socket upgrade from HTTP1 -public struct HTTP1AndWebSocketChannel: ServerChildChannel, HTTPChannelHandler { +public struct HTTP1AndWebSocketChannel: ServerChildChannel, HTTPChannelHandler { + public typealias WebSocketChannelHandler = @Sendable (NIOAsyncChannel) async -> Void /// Upgrade result (either a websocket AsyncChannel, or an HTTP1 AsyncChannel) public enum UpgradeResult { - case websocket(NIOAsyncChannel, WebSocketDataHandlerAndContext) - case notUpgraded(NIOAsyncChannel, failed: Bool) + case websocket(NIOAsyncChannel, WebSocketChannelHandler, Logger) + case notUpgraded(NIOAsyncChannel) + case failedUpgrade(NIOAsyncChannel, Logger) } public typealias Value = EventLoopFuture - public typealias Handler = WebSocketDataHandler /// Setup channel to accept HTTP1 with a WebSocket upgrade /// - Parameters: @@ -43,6 +44,7 @@ public struct HTTP1AndWebSocketChannel: Serve public func setup(channel: Channel, logger: Logger) -> EventLoopFuture { return channel.eventLoop.makeCompletedFuture { let upgradeAttempted = NIOLoopBoundBox(false, eventLoop: channel.eventLoop) + let logger = logger.with(metadataKey: "hb_id", value: .stringConvertible(RequestID())) let upgrader = NIOTypedWebSocketServerUpgrader( maxFrameSize: self.configuration.maxFrameSize, shouldUpgrade: { channel, head in @@ -52,7 +54,7 @@ public struct HTTP1AndWebSocketChannel: Serve upgradePipelineHandler: { channel, handler in channel.eventLoop.makeCompletedFuture { let asyncChannel = try NIOAsyncChannel(wrappingChannelSynchronously: channel) - return UpgradeResult.websocket(asyncChannel, handler) + return UpgradeResult.websocket(asyncChannel, handler, logger) } } ) @@ -67,7 +69,11 @@ public struct HTTP1AndWebSocketChannel: Serve return channel.eventLoop.makeCompletedFuture { try channel.pipeline.syncOperations.addHandlers(childChannelHandlers) let asyncChannel = try NIOAsyncChannel(wrappingChannelSynchronously: channel) - return UpgradeResult.notUpgraded(asyncChannel, failed: upgradeAttempted.value) + if upgradeAttempted.value { + return UpgradeResult.failedUpgrade(asyncChannel, logger) + } else { + return UpgradeResult.notUpgraded(asyncChannel) + } } } ) @@ -88,15 +94,16 @@ public struct HTTP1AndWebSocketChannel: Serve do { let result = try await upgradeResult.get() switch result { - case .notUpgraded(let http1, let failed): - if failed { - await self.write405(asyncChannel: http1, logger: logger) - } else { - await self.handleHTTP(asyncChannel: http1, logger: logger) - } - case .websocket(let asyncChannel, let handler): - let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server) - await webSocket.handle(handler: handler.handler, context: handler.context) + case .notUpgraded(let http1): + await self.handleHTTP(asyncChannel: http1, logger: logger) + + case .failedUpgrade(let http1, let logger): + logger.debug("Websocket upgrade failed") + await self.write405(asyncChannel: http1, logger: logger) + + case .websocket(let asyncChannel, let handler, let logger): + logger.debug("Websocket upgrade") + await handler(asyncChannel) } } catch { logger.error("Error handling upgrade result: \(error)") @@ -130,12 +137,12 @@ public struct HTTP1AndWebSocketChannel: Serve } public var responder: @Sendable (Request, Channel) async throws -> Response - let shouldUpgrade: @Sendable (HTTPRequest, Channel, Logger) -> EventLoopFuture>> + let shouldUpgrade: @Sendable (HTTPRequest, Channel, Logger) -> EventLoopFuture> let configuration: WebSocketServerConfiguration let additionalChannelHandlers: @Sendable () -> [any RemovableChannelHandler] } -extension HTTP1AndWebSocketChannel where Context == WebSocketContext { +extension HTTP1AndWebSocketChannel { /// Initialize HTTP1AndWebSocketChannel with synchronous `shouldUpgrade` function /// - Parameters: /// - additionalChannelHandlers: Additional channel handlers to add @@ -147,17 +154,19 @@ extension HTTP1AndWebSocketChannel where Context == WebSocketContext { responder: @escaping @Sendable (Request, Channel) async throws -> Response, configuration: WebSocketServerConfiguration, additionalChannelHandlers: @escaping @Sendable () -> [any RemovableChannelHandler] = { [] }, - shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) throws -> ShouldUpgradeResult + shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) throws -> ShouldUpgradeResult> ) { self.additionalChannelHandlers = additionalChannelHandlers self.configuration = configuration self.shouldUpgrade = { head, channel, logger in - channel.eventLoop.makeCompletedFuture { () -> ShouldUpgradeResult> in + channel.eventLoop.makeCompletedFuture { () -> ShouldUpgradeResult in try shouldUpgrade(head, channel, logger) - .map { - let logger = logger.with(metadataKey: "hb_id", value: .stringConvertible(RequestID())) - let context = WebSocketContext(channel: channel, logger: logger) - return WebSocketDataHandlerAndContext(context: context, handler: $0) + .map { handler in + return { asyncChannel in + let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server) + let context = WebSocketContext(channel: channel, logger: logger) + await webSocket.handle(handler: handler, context: context) + } } } } @@ -175,18 +184,20 @@ extension HTTP1AndWebSocketChannel where Context == WebSocketContext { responder: @escaping @Sendable (Request, Channel) async throws -> Response, configuration: WebSocketServerConfiguration, additionalChannelHandlers: @escaping @Sendable () -> [any RemovableChannelHandler] = { [] }, - shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) async throws -> ShouldUpgradeResult + shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) async throws -> ShouldUpgradeResult> ) { self.additionalChannelHandlers = additionalChannelHandlers self.configuration = configuration self.shouldUpgrade = { head, channel, logger in - let promise = channel.eventLoop.makePromise(of: ShouldUpgradeResult>.self) + let promise = channel.eventLoop.makePromise(of: ShouldUpgradeResult.self) promise.completeWithTask { try await shouldUpgrade(head, channel, logger) - .map { - let logger = logger.with(metadataKey: "hb_id", value: .stringConvertible(RequestID())) - let context = WebSocketContext(channel: channel, logger: logger) - return WebSocketDataHandlerAndContext(context: context, handler: $0) + .map { handler in + return { asyncChannel in + let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server) + let context = WebSocketContext(channel: channel, logger: logger) + await webSocket.handle(handler: handler, context: context) + } } } return promise.futureResult diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketHTTPChannelBuilder.swift b/Sources/HummingbirdWebSocket/Server/WebSocketHTTPChannelBuilder.swift index 47227db..fe55a7b 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketHTTPChannelBuilder.swift +++ b/Sources/HummingbirdWebSocket/Server/WebSocketHTTPChannelBuilder.swift @@ -24,7 +24,7 @@ extension HTTPChannelBuilder { configuration: WebSocketServerConfiguration = .init(), additionalChannelHandlers: @autoclosure @escaping @Sendable () -> [any RemovableChannelHandler] = [], shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) async throws -> ShouldUpgradeResult> - ) -> HTTPChannelBuilder> { + ) -> HTTPChannelBuilder { return .init { responder in return HTTP1AndWebSocketChannel( responder: responder, @@ -40,7 +40,7 @@ extension HTTPChannelBuilder { configuration: WebSocketServerConfiguration = .init(), additionalChannelHandlers: @autoclosure @escaping @Sendable () -> [any RemovableChannelHandler] = [], shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) throws -> ShouldUpgradeResult> - ) -> HTTPChannelBuilder> { + ) -> HTTPChannelBuilder { return .init { responder in return HTTP1AndWebSocketChannel( responder: responder, diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift b/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift index 85a6295..a9bbfd2 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift +++ b/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift @@ -24,11 +24,17 @@ import NIOCore /// /// Includes reference to optional websocket handler public struct WebSocketRouterContext: Sendable { + /// Holds WebSocket context and handler to call + struct Value: Sendable { + let context: Context + let handler: WebSocketDataHandler + } + public init() { self.handler = .init(nil) } - let handler: NIOLockedValueBox?> + let handler: NIOLockedValueBox } /// Request context protocol requirement for routers that support websockets @@ -71,7 +77,7 @@ extension RouterMethods { case .dontUpgrade: return .init(status: .methodNotAllowed) case .upgrade(let headers): - context.webSocket.handler.withLockedValue { $0 = WebSocketDataHandlerAndContext(context: context, handler: handle) } + context.webSocket.handler.withLockedValue { $0 = WebSocketRouterContext.Value(context: context, handler: handle) } return .init(status: .ok, headers: headers) } } @@ -111,7 +117,7 @@ public struct WebSocketUpgradeMiddleware: Rout } } -extension HTTP1AndWebSocketChannel where Context: WebSocketRequestContext { +extension HTTP1AndWebSocketChannel { /// Initialize HTTP1AndWebSocketChannel with async `shouldUpgrade` function /// - Parameters: /// - additionalChannelHandlers: Additional channel handlers to add @@ -124,18 +130,21 @@ extension HTTP1AndWebSocketChannel where Context: WebSocketRequestContext { webSocketResponder: WSResponder, configuration: WebSocketServerConfiguration, additionalChannelHandlers: @escaping @Sendable () -> [any RemovableChannelHandler] = { [] } - ) where WSResponder.Context == Context { + ) where WSResponder.Context: WebSocketRequestContext { self.additionalChannelHandlers = additionalChannelHandlers self.configuration = configuration self.shouldUpgrade = { head, channel, logger in - let promise = channel.eventLoop.makePromise(of: ShouldUpgradeResult>.self) + let promise = channel.eventLoop.makePromise(of: ShouldUpgradeResult.self) promise.completeWithTask { let request = Request(head: head, body: .init(buffer: .init())) - let context = Context(channel: channel, logger: logger.with(metadataKey: "hb_id", value: .stringConvertible(RequestID()))) + let context = WSResponder.Context(channel: channel, logger: logger) do { let response = try await webSocketResponder.respond(to: request, context: context) if response.status == .ok, let webSocketHandler = context.webSocket.handler.withLockedValue({ $0 }) { - return .upgrade(response.headers, webSocketHandler) + return .upgrade(response.headers) { asyncChannel in + let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server) + await webSocket.handle(handler: webSocketHandler.handler, context: webSocketHandler.context) + } } else { return .dontUpgrade } @@ -165,7 +174,7 @@ extension HTTPChannelBuilder { webSocketRouter: WSResponderBuilder, configuration: WebSocketServerConfiguration = .init(), additionalChannelHandlers: @autoclosure @escaping @Sendable () -> [any RemovableChannelHandler] = [] - ) -> HTTPChannelBuilder> where WSResponderBuilder.Responder.Context: WebSocketRequestContext { + ) -> HTTPChannelBuilder where WSResponderBuilder.Responder.Context: WebSocketRequestContext { let webSocketReponder = webSocketRouter.buildResponder() return .init { responder in return HTTP1AndWebSocketChannel( diff --git a/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift b/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift index ac3d54b..74b4c2a 100644 --- a/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift +++ b/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift @@ -18,22 +18,5 @@ import Logging import NIOCore import NIOWebSocket -/// Handle websocket data and text blocks +/// Function that handles websocket data and text blocks public typealias WebSocketDataHandler = @Sendable (WebSocketHandlerInbound, WebSocketHandlerOutboundWriter, Context) async throws -> Void - -/// Struct holding for web socket data handler and context. -public struct WebSocketDataHandlerAndContext: Sendable { - /// Context sent to handler - let context: Context - /// handler function - let handler: WebSocketDataHandler - - public init(context: Context, handler: @escaping WebSocketDataHandler) { - self.context = context - self.handler = handler - } - - func withContext(channel: Channel, logger: Logger) -> Self { - .init(context: .init(channel: channel, logger: logger), handler: self.handler) - } -} From c4e3e38b1fc411c247266c75b85a11bb1939f1e0 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Wed, 20 Mar 2024 14:37:27 +0000 Subject: [PATCH 4/7] Upgrade handle function takes a Logger --- .../HummingbirdWebSocket/Server/WebSocketChannel.swift | 8 ++++---- Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift b/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift index 2ee05de..f1d6388 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift +++ b/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift @@ -25,7 +25,7 @@ import NIOWebSocket /// Child channel supporting a web socket upgrade from HTTP1 public struct HTTP1AndWebSocketChannel: ServerChildChannel, HTTPChannelHandler { - public typealias WebSocketChannelHandler = @Sendable (NIOAsyncChannel) async -> Void + public typealias WebSocketChannelHandler = @Sendable (NIOAsyncChannel, Logger) async -> Void /// Upgrade result (either a websocket AsyncChannel, or an HTTP1 AsyncChannel) public enum UpgradeResult { case websocket(NIOAsyncChannel, WebSocketChannelHandler, Logger) @@ -103,7 +103,7 @@ public struct HTTP1AndWebSocketChannel: ServerChildChannel, HTTPChannelHandler { case .websocket(let asyncChannel, let handler, let logger): logger.debug("Websocket upgrade") - await handler(asyncChannel) + await handler(asyncChannel, logger) } } catch { logger.error("Error handling upgrade result: \(error)") @@ -162,7 +162,7 @@ extension HTTP1AndWebSocketChannel { channel.eventLoop.makeCompletedFuture { () -> ShouldUpgradeResult in try shouldUpgrade(head, channel, logger) .map { handler in - return { asyncChannel in + return { asyncChannel, logger in let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server) let context = WebSocketContext(channel: channel, logger: logger) await webSocket.handle(handler: handler, context: context) @@ -193,7 +193,7 @@ extension HTTP1AndWebSocketChannel { promise.completeWithTask { try await shouldUpgrade(head, channel, logger) .map { handler in - return { asyncChannel in + return { asyncChannel, logger in let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server) let context = WebSocketContext(channel: channel, logger: logger) await webSocket.handle(handler: handler, context: context) diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift b/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift index a9bbfd2..2c60583 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift +++ b/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift @@ -141,7 +141,7 @@ extension HTTP1AndWebSocketChannel { do { let response = try await webSocketResponder.respond(to: request, context: context) if response.status == .ok, let webSocketHandler = context.webSocket.handler.withLockedValue({ $0 }) { - return .upgrade(response.headers) { asyncChannel in + return .upgrade(response.headers) { asyncChannel, _ in let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server) await webSocket.handle(handler: webSocketHandler.handler, context: webSocketHandler.context) } From 2719f8e467eca692fc6205ce990290cfe82282ee Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Wed, 20 Mar 2024 14:42:49 +0000 Subject: [PATCH 5/7] Move code around --- .../Server/WebSocketChannel.swift | 126 +++++++++--------- 1 file changed, 62 insertions(+), 64 deletions(-) diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift b/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift index f1d6388..7759a25 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift +++ b/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift @@ -35,6 +35,68 @@ public struct HTTP1AndWebSocketChannel: ServerChildChannel, HTTPChannelHandler { public typealias Value = EventLoopFuture + /// Initialize HTTP1AndWebSocketChannel with synchronous `shouldUpgrade` function + /// - Parameters: + /// - additionalChannelHandlers: Additional channel handlers to add + /// - responder: HTTP responder + /// - maxFrameSize: Max frame size WebSocket will allow + /// - shouldUpgrade: Function returning whether upgrade should be allowed + /// - Returns: Upgrade result future + public init( + responder: @escaping @Sendable (Request, Channel) async throws -> Response, + configuration: WebSocketServerConfiguration, + additionalChannelHandlers: @escaping @Sendable () -> [any RemovableChannelHandler] = { [] }, + shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) throws -> ShouldUpgradeResult> + ) { + self.additionalChannelHandlers = additionalChannelHandlers + self.configuration = configuration + self.shouldUpgrade = { head, channel, logger in + channel.eventLoop.makeCompletedFuture { () -> ShouldUpgradeResult in + try shouldUpgrade(head, channel, logger) + .map { handler in + return { asyncChannel, logger in + let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server) + let context = WebSocketContext(channel: channel, logger: logger) + await webSocket.handle(handler: handler, context: context) + } + } + } + } + self.responder = responder + } + + /// Initialize HTTP1AndWebSocketChannel with async `shouldUpgrade` function + /// - Parameters: + /// - additionalChannelHandlers: Additional channel handlers to add + /// - responder: HTTP responder + /// - maxFrameSize: Max frame size WebSocket will allow + /// - shouldUpgrade: Function returning whether upgrade should be allowed + /// - Returns: Upgrade result future + public init( + responder: @escaping @Sendable (Request, Channel) async throws -> Response, + configuration: WebSocketServerConfiguration, + additionalChannelHandlers: @escaping @Sendable () -> [any RemovableChannelHandler] = { [] }, + shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) async throws -> ShouldUpgradeResult> + ) { + self.additionalChannelHandlers = additionalChannelHandlers + self.configuration = configuration + self.shouldUpgrade = { head, channel, logger in + let promise = channel.eventLoop.makePromise(of: ShouldUpgradeResult.self) + promise.completeWithTask { + try await shouldUpgrade(head, channel, logger) + .map { handler in + return { asyncChannel, logger in + let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server) + let context = WebSocketContext(channel: channel, logger: logger) + await webSocket.handle(handler: handler, context: context) + } + } + } + return promise.futureResult + } + self.responder = responder + } + /// Setup channel to accept HTTP1 with a WebSocket upgrade /// - Parameters: /// - channel: Child channel @@ -141,67 +203,3 @@ public struct HTTP1AndWebSocketChannel: ServerChildChannel, HTTPChannelHandler { let configuration: WebSocketServerConfiguration let additionalChannelHandlers: @Sendable () -> [any RemovableChannelHandler] } - -extension HTTP1AndWebSocketChannel { - /// Initialize HTTP1AndWebSocketChannel with synchronous `shouldUpgrade` function - /// - Parameters: - /// - additionalChannelHandlers: Additional channel handlers to add - /// - responder: HTTP responder - /// - maxFrameSize: Max frame size WebSocket will allow - /// - shouldUpgrade: Function returning whether upgrade should be allowed - /// - Returns: Upgrade result future - public init( - responder: @escaping @Sendable (Request, Channel) async throws -> Response, - configuration: WebSocketServerConfiguration, - additionalChannelHandlers: @escaping @Sendable () -> [any RemovableChannelHandler] = { [] }, - shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) throws -> ShouldUpgradeResult> - ) { - self.additionalChannelHandlers = additionalChannelHandlers - self.configuration = configuration - self.shouldUpgrade = { head, channel, logger in - channel.eventLoop.makeCompletedFuture { () -> ShouldUpgradeResult in - try shouldUpgrade(head, channel, logger) - .map { handler in - return { asyncChannel, logger in - let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server) - let context = WebSocketContext(channel: channel, logger: logger) - await webSocket.handle(handler: handler, context: context) - } - } - } - } - self.responder = responder - } - - /// Initialize HTTP1AndWebSocketChannel with async `shouldUpgrade` function - /// - Parameters: - /// - additionalChannelHandlers: Additional channel handlers to add - /// - responder: HTTP responder - /// - maxFrameSize: Max frame size WebSocket will allow - /// - shouldUpgrade: Function returning whether upgrade should be allowed - /// - Returns: Upgrade result future - public init( - responder: @escaping @Sendable (Request, Channel) async throws -> Response, - configuration: WebSocketServerConfiguration, - additionalChannelHandlers: @escaping @Sendable () -> [any RemovableChannelHandler] = { [] }, - shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) async throws -> ShouldUpgradeResult> - ) { - self.additionalChannelHandlers = additionalChannelHandlers - self.configuration = configuration - self.shouldUpgrade = { head, channel, logger in - let promise = channel.eventLoop.makePromise(of: ShouldUpgradeResult.self) - promise.completeWithTask { - try await shouldUpgrade(head, channel, logger) - .map { handler in - return { asyncChannel, logger in - let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server) - let context = WebSocketContext(channel: channel, logger: logger) - await webSocket.handle(handler: handler, context: context) - } - } - } - return promise.futureResult - } - self.responder = responder - } -} From 8f6164fab7a39321d3a272e258045b819a594c55 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Wed, 20 Mar 2024 14:57:23 +0000 Subject: [PATCH 6/7] Add testRouterContextUpdate --- .../WebSocketTests.swift | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/Tests/HummingbirdWebSocketTests/WebSocketTests.swift b/Tests/HummingbirdWebSocketTests/WebSocketTests.swift index 4f436f6..6de603d 100644 --- a/Tests/HummingbirdWebSocketTests/WebSocketTests.swift +++ b/Tests/HummingbirdWebSocketTests/WebSocketTests.swift @@ -447,6 +447,43 @@ final class HummingbirdWebSocketTests: XCTestCase { } catch let error as WebSocketClientError where error == .webSocketUpgradeFailed {} } + /// Test context from router is passed through to web socket + func testRouterContextUpdate() async throws { + struct MyRequestContext: WebSocketRequestContext { + var coreContext: CoreRequestContext + var webSocket: WebSocketRouterContext + var name: String + + init(channel: Channel, logger: Logger) { + self.coreContext = .init(allocator: channel.allocator, logger: logger) + self.webSocket = .init() + self.name = "" + } + } + struct MyMiddleware: RouterMiddleware { + func handle(_ request: Request, context: MyRequestContext, next: (Request, MyRequestContext) async throws -> Response) async throws -> Response { + var context = context + context.name = "Roger Moore" + return try await next(request, context) + } + } + let router = Router(context: MyRequestContext.self) + router.middlewares.add(MyMiddleware()) + router.ws("/ws") { _, _ in + return .upgrade([:]) + } handle: { _, outbound, context in + try await outbound.write(.text(context.name)) + } + do { + try await self.testClientAndServerWithRouter(webSocketRouter: router, uri: "localhost:8080") { port, logger in + try WebSocketClient(url: .init("ws://localhost:\(port)/ws"), logger: logger) { inbound, _, _ in + let text = await inbound.first { _ in true } + XCTAssertEqual(text, .text("Roger Moore")) + } + } + } catch let error as WebSocketClientError where error == .webSocketUpgradeFailed {} + } + func testHTTPRequest() async throws { let router = Router(context: BasicWebSocketRequestContext.self) router.ws("/ws") { _, _ in From 0d0cb03f00b74c1fe2a4301e791315f4497156e8 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Fri, 22 Mar 2024 10:24:26 +0000 Subject: [PATCH 7/7] handle -> onUpgrade --- .../Server/WebSocketRouter.swift | 8 ++++---- Tests/HummingbirdWebSocketTests/WebSocketTests.swift | 12 ++++++------ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift b/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift index 2c60583..54e1570 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift +++ b/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift @@ -65,11 +65,11 @@ extension RouterMethods { /// - Parameters: /// - path: Path to match /// - shouldUpgrade: Should request be upgraded - /// - handle: WebSocket channel handler + /// - handler: WebSocket channel handler @discardableResult public func ws( _ path: String = "", shouldUpgrade: @Sendable @escaping (Request, Context) async throws -> RouterShouldUpgrade = { _, _ in .upgrade([:]) }, - handle: @escaping WebSocketDataHandler + onUpgrade handler: @escaping WebSocketDataHandler ) -> Self where Context: WebSocketRequestContext { return on(path, method: .get) { request, context -> Response in let result = try await shouldUpgrade(request, context) @@ -77,7 +77,7 @@ extension RouterMethods { case .dontUpgrade: return .init(status: .methodNotAllowed) case .upgrade(let headers): - context.webSocket.handler.withLockedValue { $0 = WebSocketRouterContext.Value(context: context, handler: handle) } + context.webSocket.handler.withLockedValue { $0 = WebSocketRouterContext.Value(context: context, handler: handler) } return .init(status: .ok, headers: headers) } } @@ -98,7 +98,7 @@ public struct WebSocketUpgradeMiddleware: Rout /// - handle: WebSocket handler public init( shouldUpgrade: @Sendable @escaping (Request, Context) async throws -> RouterShouldUpgrade = { _, _ in .upgrade([:]) }, - handler: @escaping WebSocketDataHandler + onUpgrade handler: @escaping WebSocketDataHandler ) { self.shouldUpgrade = shouldUpgrade self.handler = handler diff --git a/Tests/HummingbirdWebSocketTests/WebSocketTests.swift b/Tests/HummingbirdWebSocketTests/WebSocketTests.swift index 6de603d..b7666bf 100644 --- a/Tests/HummingbirdWebSocketTests/WebSocketTests.swift +++ b/Tests/HummingbirdWebSocketTests/WebSocketTests.swift @@ -393,12 +393,12 @@ final class HummingbirdWebSocketTests: XCTestCase { let router = Router(context: BasicWebSocketRequestContext.self) router.ws("/ws1") { _, _ in return .upgrade([:]) - } handle: { _, outbound, _ in + } onUpgrade: { _, outbound, _ in try await outbound.write(.text("One")) } router.ws("/ws2") { _, _ in return .upgrade([:]) - } handle: { _, outbound, _ in + } onUpgrade: { _, outbound, _ in try await outbound.write(.text("Two")) } try await self.testClientAndServerWithRouter(webSocketRouter: router, uri: "localhost:8080") { port, logger in @@ -422,7 +422,7 @@ final class HummingbirdWebSocketTests: XCTestCase { router.group("/ws") .add(middleware: WebSocketUpgradeMiddleware { _, _ in return .upgrade([:]) - } handler: { _, outbound, _ in + } onUpgrade: { _, outbound, _ in try await outbound.write(.text("One")) }) .get { _, _ -> Response in return .init(status: .ok) } @@ -437,7 +437,7 @@ final class HummingbirdWebSocketTests: XCTestCase { let router = Router(context: BasicWebSocketRequestContext.self) router.ws("/ws") { _, _ in return .upgrade([:]) - } handle: { _, outbound, _ in + } onUpgrade: { _, outbound, _ in try await outbound.write(.text("One")) } do { @@ -471,7 +471,7 @@ final class HummingbirdWebSocketTests: XCTestCase { router.middlewares.add(MyMiddleware()) router.ws("/ws") { _, _ in return .upgrade([:]) - } handle: { _, outbound, context in + } onUpgrade: { _, outbound, context in try await outbound.write(.text(context.name)) } do { @@ -488,7 +488,7 @@ final class HummingbirdWebSocketTests: XCTestCase { let router = Router(context: BasicWebSocketRequestContext.self) router.ws("/ws") { _, _ in return .upgrade([:]) - } handle: { _, outbound, _ in + } onUpgrade: { _, outbound, _ in try await outbound.write(.text("Hello")) } router.get("/http") { _, _ in