diff --git a/Package.swift b/Package.swift index bda3e6a..911fe6e 100644 --- a/Package.swift +++ b/Package.swift @@ -11,7 +11,7 @@ let package = Package( // .library(name: "HummingbirdWSCompression", targets: ["HummingbirdWSCompression"]), ], dependencies: [ - .package(url: "https://github.com/hummingbird-project/hummingbird.git", from: "2.0.0-beta.1"), + .package(url: "https://github.com/hummingbird-project/hummingbird.git", branch: "main"), .package(url: "https://github.com/apple/swift-async-algorithms.git", from: "1.0.0"), .package(url: "https://github.com/apple/swift-atomics.git", from: "1.0.0"), .package(url: "https://github.com/apple/swift-http-types.git", from: "1.0.0"), diff --git a/Sources/HummingbirdWebSocket/Client/WebSocketClient.swift b/Sources/HummingbirdWebSocket/Client/WebSocketClient.swift index e5b3720..99aeb8a 100644 --- a/Sources/HummingbirdWebSocket/Client/WebSocketClient.swift +++ b/Sources/HummingbirdWebSocket/Client/WebSocketClient.swift @@ -27,7 +27,7 @@ import ServiceLifecycle /// /// Supports TLS via both NIOSSL and Network framework. /// -/// Initialize the WebSocketClient with your handler and then call ``WebSocketClient/run`` +/// Initialize the WebSocketClient with your handler and then call ``WebSocketClient/run()`` /// to connect. The handler is provider with an `inbound` stream of WebSocket packets coming /// from the server and an `outbound` writer that can be used to write packets to the server. /// ```swift @@ -50,7 +50,7 @@ public struct WebSocketClient { /// WebSocket URL let url: URI /// WebSocket data handler - let handler: WebSocketDataCallbackHandler + let handler: WebSocketDataHandler /// configuration let configuration: WebSocketClientConfiguration /// EventLoopGroup to use @@ -75,10 +75,10 @@ public struct WebSocketClient { tlsConfiguration: TLSConfiguration? = nil, eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton, logger: Logger, - process: @escaping WebSocketDataCallbackHandler.Callback + handler: @escaping WebSocketDataHandler ) throws { self.url = url - self.handler = .init(process) + self.handler = handler self.configuration = configuration self.eventLoopGroup = eventLoopGroup self.logger = logger @@ -101,10 +101,10 @@ public struct WebSocketClient { transportServicesTLSOptions: TSTLSOptions, eventLoopGroup: NIOTSEventLoopGroup = NIOTSEventLoopGroup.singleton, logger: Logger, - process: @escaping WebSocketDataCallbackHandler.Callback + handler: @escaping WebSocketDataHandler ) throws { self.url = url - self.handler = .init(process) + self.handler = handler self.configuration = configuration self.eventLoopGroup = eventLoopGroup self.logger = logger @@ -193,7 +193,7 @@ extension WebSocketClient { tlsConfiguration: TLSConfiguration? = nil, eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton, logger: Logger, - process: @escaping WebSocketDataCallbackHandler.Callback + handler: @escaping WebSocketDataHandler ) async throws { let ws = try self.init( url: url, @@ -201,7 +201,7 @@ extension WebSocketClient { tlsConfiguration: tlsConfiguration, eventLoopGroup: eventLoopGroup, logger: logger, - process: process + handler: handler ) try await ws.run() } @@ -222,7 +222,7 @@ extension WebSocketClient { transportServicesTLSOptions: TSTLSOptions, eventLoopGroup: NIOTSEventLoopGroup = NIOTSEventLoopGroup.singleton, logger: Logger, - process: @escaping WebSocketDataCallbackHandler.Callback + handler: @escaping WebSocketDataHandler ) async throws { let ws = try self.init( url: url, @@ -230,7 +230,7 @@ extension WebSocketClient { transportServicesTLSOptions: transportServicesTLSOptions, eventLoopGroup: eventLoopGroup, logger: logger, - process: process + handler: handler ) try await ws.run() } diff --git a/Sources/HummingbirdWebSocket/Client/WebSocketClientChannel.swift b/Sources/HummingbirdWebSocket/Client/WebSocketClientChannel.swift index dfed093..20d9330 100644 --- a/Sources/HummingbirdWebSocket/Client/WebSocketClientChannel.swift +++ b/Sources/HummingbirdWebSocket/Client/WebSocketClientChannel.swift @@ -20,25 +20,25 @@ import NIOHTTP1 import NIOHTTPTypesHTTP1 import NIOWebSocket -public struct WebSocketClientChannel: ClientConnectionChannel { - public enum UpgradeResult { +struct WebSocketClientChannel: ClientConnectionChannel { + enum UpgradeResult { case websocket(NIOAsyncChannel) case notUpgraded } - public typealias Value = EventLoopFuture + typealias Value = EventLoopFuture let url: String - let handler: Handler + let handler: WebSocketDataHandler let configuration: WebSocketClientConfiguration - init(handler: Handler, url: String, configuration: WebSocketClientConfiguration) { + init(handler: @escaping WebSocketDataHandler, url: String, configuration: WebSocketClientConfiguration) { self.url = url self.handler = handler self.configuration = configuration } - public func setup(channel: any Channel, logger: Logger) -> NIOCore.EventLoopFuture { + func setup(channel: any Channel, logger: Logger) -> NIOCore.EventLoopFuture { channel.eventLoop.makeCompletedFuture { let upgrader = NIOTypedWebSocketClientUpgrader( maxFrameSize: self.configuration.maxFrameSize, @@ -81,12 +81,11 @@ public struct WebSocketClientChannel: ClientConne } } - public func handle(value: Value, logger: Logger) async throws { + func handle(value: Value, logger: Logger) async throws { switch try await value.get() { case .websocket(let webSocketChannel): let webSocket = WebSocketHandler(asyncChannel: webSocketChannel, type: .client) - let context = self.handler.alreadySetupContext ?? .init(channel: webSocketChannel.channel, logger: logger) - await webSocket.handle(handler: self.handler, context: context) + await webSocket.handle(handler: self.handler, context: WebSocketContext(channel: webSocketChannel.channel, logger: logger)) case .notUpgraded: // The upgrade to websocket did not succeed. logger.debug("Upgrade declined") diff --git a/Sources/HummingbirdWebSocket/Server/NIOWebSocketServerUpgrade+ext.swift b/Sources/HummingbirdWebSocket/Server/NIOWebSocketServerUpgrade+ext.swift index 9994496..3ca339c 100644 --- a/Sources/HummingbirdWebSocket/Server/NIOWebSocketServerUpgrade+ext.swift +++ b/Sources/HummingbirdWebSocket/Server/NIOWebSocketServerUpgrade+ext.swift @@ -23,6 +23,16 @@ import NIOWebSocket public enum ShouldUpgradeResult: Sendable { case dontUpgrade case upgrade(HTTPFields, Value) + + /// Map upgrade result to difference type + func map(_ map: (Value) throws -> Result) rethrows -> ShouldUpgradeResult { + switch self { + case .dontUpgrade: + return .dontUpgrade + case .upgrade(let headers, let value): + return try .upgrade(headers, map(value)) + } + } } extension NIOTypedWebSocketServerUpgrader { @@ -47,7 +57,7 @@ extension NIOTypedWebSocketServerUpgrader { /// websocket protocol. This only needs to add the user handlers: the /// `WebSocketFrameEncoder` and `WebSocketFrameDecoder` will have been added to the /// pipeline automatically. - public convenience init( + convenience init( maxFrameSize: Int = 1 << 14, enableAutomaticErrorHandling: Bool = true, shouldUpgrade: @escaping @Sendable (Channel, HTTPRequest) -> EventLoopFuture>, diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift b/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift index 0f55d9b..7759a25 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift +++ b/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// import HTTPTypes +import Hummingbird import HummingbirdCore import Logging import NIOConcurrencyHelpers @@ -23,11 +24,13 @@ import NIOHTTPTypesHTTP1 import NIOWebSocket /// Child channel supporting a web socket upgrade from HTTP1 -public struct HTTP1AndWebSocketChannel: ServerChildChannel, HTTPChannelHandler { +public struct HTTP1AndWebSocketChannel: ServerChildChannel, HTTPChannelHandler { + public typealias WebSocketChannelHandler = @Sendable (NIOAsyncChannel, Logger) async -> Void /// Upgrade result (either a websocket AsyncChannel, or an HTTP1 AsyncChannel) public enum UpgradeResult { - case websocket(NIOAsyncChannel, Handler) - case notUpgraded(NIOAsyncChannel, failed: Bool) + case websocket(NIOAsyncChannel, WebSocketChannelHandler, Logger) + case notUpgraded(NIOAsyncChannel) + case failedUpgrade(NIOAsyncChannel, Logger) } public typealias Value = EventLoopFuture @@ -43,13 +46,20 @@ public struct HTTP1AndWebSocketChannel: ServerChi responder: @escaping @Sendable (Request, Channel) async throws -> Response, configuration: WebSocketServerConfiguration, additionalChannelHandlers: @escaping @Sendable () -> [any RemovableChannelHandler] = { [] }, - shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) throws -> ShouldUpgradeResult + shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) throws -> ShouldUpgradeResult> ) { self.additionalChannelHandlers = additionalChannelHandlers self.configuration = configuration self.shouldUpgrade = { head, channel, logger in - channel.eventLoop.makeCompletedFuture { + channel.eventLoop.makeCompletedFuture { () -> ShouldUpgradeResult in try shouldUpgrade(head, channel, logger) + .map { handler in + return { asyncChannel, logger in + let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server) + let context = WebSocketContext(channel: channel, logger: logger) + await webSocket.handle(handler: handler, context: context) + } + } } } self.responder = responder @@ -66,14 +76,21 @@ public struct HTTP1AndWebSocketChannel: ServerChi responder: @escaping @Sendable (Request, Channel) async throws -> Response, configuration: WebSocketServerConfiguration, additionalChannelHandlers: @escaping @Sendable () -> [any RemovableChannelHandler] = { [] }, - shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) async throws -> ShouldUpgradeResult + shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) async throws -> ShouldUpgradeResult> ) { self.additionalChannelHandlers = additionalChannelHandlers self.configuration = configuration self.shouldUpgrade = { head, channel, logger in - let promise = channel.eventLoop.makePromise(of: ShouldUpgradeResult.self) + let promise = channel.eventLoop.makePromise(of: ShouldUpgradeResult.self) promise.completeWithTask { try await shouldUpgrade(head, channel, logger) + .map { handler in + return { asyncChannel, logger in + let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server) + let context = WebSocketContext(channel: channel, logger: logger) + await webSocket.handle(handler: handler, context: context) + } + } } return promise.futureResult } @@ -89,6 +106,7 @@ public struct HTTP1AndWebSocketChannel: ServerChi public func setup(channel: Channel, logger: Logger) -> EventLoopFuture { return channel.eventLoop.makeCompletedFuture { let upgradeAttempted = NIOLoopBoundBox(false, eventLoop: channel.eventLoop) + let logger = logger.with(metadataKey: "hb_id", value: .stringConvertible(RequestID())) let upgrader = NIOTypedWebSocketServerUpgrader( maxFrameSize: self.configuration.maxFrameSize, shouldUpgrade: { channel, head in @@ -98,7 +116,7 @@ public struct HTTP1AndWebSocketChannel: ServerChi upgradePipelineHandler: { channel, handler in channel.eventLoop.makeCompletedFuture { let asyncChannel = try NIOAsyncChannel(wrappingChannelSynchronously: channel) - return UpgradeResult.websocket(asyncChannel, handler) + return UpgradeResult.websocket(asyncChannel, handler, logger) } } ) @@ -113,7 +131,11 @@ public struct HTTP1AndWebSocketChannel: ServerChi return channel.eventLoop.makeCompletedFuture { try channel.pipeline.syncOperations.addHandlers(childChannelHandlers) let asyncChannel = try NIOAsyncChannel(wrappingChannelSynchronously: channel) - return UpgradeResult.notUpgraded(asyncChannel, failed: upgradeAttempted.value) + if upgradeAttempted.value { + return UpgradeResult.failedUpgrade(asyncChannel, logger) + } else { + return UpgradeResult.notUpgraded(asyncChannel) + } } } ) @@ -134,16 +156,16 @@ public struct HTTP1AndWebSocketChannel: ServerChi do { let result = try await upgradeResult.get() switch result { - case .notUpgraded(let http1, let failed): - if failed { - await self.write405(asyncChannel: http1, logger: logger) - } else { - await self.handleHTTP(asyncChannel: http1, logger: logger) - } - case .websocket(let asyncChannel, let handler): - let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server) - let context = handler.alreadySetupContext ?? .init(channel: asyncChannel.channel, logger: logger) - await webSocket.handle(handler: handler, context: context) + case .notUpgraded(let http1): + await self.handleHTTP(asyncChannel: http1, logger: logger) + + case .failedUpgrade(let http1, let logger): + logger.debug("Websocket upgrade failed") + await self.write405(asyncChannel: http1, logger: logger) + + case .websocket(let asyncChannel, let handler, let logger): + logger.debug("Websocket upgrade") + await handler(asyncChannel, logger) } } catch { logger.error("Error handling upgrade result: \(error)") @@ -177,7 +199,7 @@ public struct HTTP1AndWebSocketChannel: ServerChi } public var responder: @Sendable (Request, Channel) async throws -> Response - let shouldUpgrade: @Sendable (HTTPRequest, Channel, Logger) -> EventLoopFuture> + let shouldUpgrade: @Sendable (HTTPRequest, Channel, Logger) -> EventLoopFuture> let configuration: WebSocketServerConfiguration let additionalChannelHandlers: @Sendable () -> [any RemovableChannelHandler] } diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketHTTPChannelBuilder.swift b/Sources/HummingbirdWebSocket/Server/WebSocketHTTPChannelBuilder.swift index c02d4f2..fe55a7b 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketHTTPChannelBuilder.swift +++ b/Sources/HummingbirdWebSocket/Server/WebSocketHTTPChannelBuilder.swift @@ -20,11 +20,11 @@ import NIOCore extension HTTPChannelBuilder { /// HTTP1 channel builder supporting a websocket upgrade /// - parameters - public static func webSocketUpgrade( + public static func webSocketUpgrade( configuration: WebSocketServerConfiguration = .init(), additionalChannelHandlers: @autoclosure @escaping @Sendable () -> [any RemovableChannelHandler] = [], - shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) async throws -> ShouldUpgradeResult - ) -> HTTPChannelBuilder> { + shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) async throws -> ShouldUpgradeResult> + ) -> HTTPChannelBuilder { return .init { responder in return HTTP1AndWebSocketChannel( responder: responder, @@ -36,13 +36,13 @@ extension HTTPChannelBuilder { } /// HTTP1 channel builder supporting a websocket upgrade - public static func webSocketUpgrade( + public static func webSocketUpgrade( configuration: WebSocketServerConfiguration = .init(), additionalChannelHandlers: @autoclosure @escaping @Sendable () -> [any RemovableChannelHandler] = [], - shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) throws -> ShouldUpgradeResult - ) -> HTTPChannelBuilder> { + shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) throws -> ShouldUpgradeResult> + ) -> HTTPChannelBuilder { return .init { responder in - return HTTP1AndWebSocketChannel( + return HTTP1AndWebSocketChannel( responder: responder, configuration: configuration, additionalChannelHandlers: additionalChannelHandlers, diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift b/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift index eee133a..54e1570 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift +++ b/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift @@ -23,23 +23,29 @@ import NIOCore /// WebSocket Router context type. /// /// Includes reference to optional websocket handler -public struct WebSocketRouterContext: Sendable { +public struct WebSocketRouterContext: Sendable { + /// Holds WebSocket context and handler to call + struct Value: Sendable { + let context: Context + let handler: WebSocketDataHandler + } + public init() { self.handler = .init(nil) } - let handler: NIOLockedValueBox + let handler: NIOLockedValueBox } /// Request context protocol requirement for routers that support websockets public protocol WebSocketRequestContext: RequestContext, WebSocketContextProtocol { - var webSocket: WebSocketRouterContext { get } + var webSocket: WebSocketRouterContext { get } } /// Default implementation of a request context that supports WebSockets public struct BasicWebSocketRequestContext: WebSocketRequestContext { public var coreContext: CoreRequestContext - public let webSocket: WebSocketRouterContext + public let webSocket: WebSocketRouterContext public init(channel: Channel, logger: Logger) { self.coreContext = .init(allocator: channel.allocator, logger: logger) @@ -59,11 +65,11 @@ extension RouterMethods { /// - Parameters: /// - path: Path to match /// - shouldUpgrade: Should request be upgraded - /// - handle: WebSocket channel handler + /// - handler: WebSocket channel handler @discardableResult public func ws( _ path: String = "", shouldUpgrade: @Sendable @escaping (Request, Context) async throws -> RouterShouldUpgrade = { _, _ in .upgrade([:]) }, - handle: @escaping WebSocketDataCallbackHandler.Callback + onUpgrade handler: @escaping WebSocketDataHandler ) -> Self where Context: WebSocketRequestContext { return on(path, method: .get) { request, context -> Response in let result = try await shouldUpgrade(request, context) @@ -71,7 +77,7 @@ extension RouterMethods { case .dontUpgrade: return .init(status: .methodNotAllowed) case .upgrade(let headers): - context.webSocket.handler.withLockedValue { $0 = WebSocketDataCallbackHandler(handle) } + context.webSocket.handler.withLockedValue { $0 = WebSocketRouterContext.Value(context: context, handler: handler) } return .init(status: .ok, headers: headers) } } @@ -84,7 +90,7 @@ extension RouterMethods { /// with ``Hummingbird/Router`` if you add a route immediately after it. public struct WebSocketUpgradeMiddleware: RouterMiddleware { let shouldUpgrade: @Sendable (Request, Context) async throws -> RouterShouldUpgrade - let handle: WebSocketDataCallbackHandler.Callback + let handler: WebSocketDataHandler /// Initialize WebSocketUpgradeMiddleare /// - Parameters: @@ -92,10 +98,10 @@ public struct WebSocketUpgradeMiddleware: Rout /// - handle: WebSocket handler public init( shouldUpgrade: @Sendable @escaping (Request, Context) async throws -> RouterShouldUpgrade = { _, _ in .upgrade([:]) }, - handle: @escaping WebSocketDataCallbackHandler.Callback + onUpgrade handler: @escaping WebSocketDataHandler ) { self.shouldUpgrade = shouldUpgrade - self.handle = handle + self.handler = handler } /// WebSocketUpgradeMiddleware handler @@ -105,7 +111,7 @@ public struct WebSocketUpgradeMiddleware: Rout case .dontUpgrade: return .init(status: .methodNotAllowed) case .upgrade(let headers): - context.webSocket.handler.withLockedValue { $0 = WebSocketDataCallbackHandler(self.handle) } + context.webSocket.handler.withLockedValue { $0 = .init(context: context, handler: self.handler) } return .init(status: .ok, headers: headers) } } @@ -119,26 +125,36 @@ extension HTTP1AndWebSocketChannel { /// - maxFrameSize: Max frame size WebSocket will allow /// - webSocketRouter: WebSocket router /// - Returns: Upgrade result future - public init( + public init( responder: @escaping @Sendable (Request, Channel) async throws -> Response, webSocketResponder: WSResponder, configuration: WebSocketServerConfiguration, additionalChannelHandlers: @escaping @Sendable () -> [any RemovableChannelHandler] = { [] } - ) where Handler == WebSocketDataCallbackHandler, WSResponder.Context == Context { - self.init(responder: responder, configuration: configuration, additionalChannelHandlers: additionalChannelHandlers) { head, channel, logger in - let request = Request(head: head, body: .init(buffer: .init())) - let context = Context(channel: channel, logger: logger.with(metadataKey: "hb_id", value: .stringConvertible(RequestID()))) - do { - let response = try await webSocketResponder.respond(to: request, context: context) - if response.status == .ok, let webSocketHandler = context.webSocket.handler.withLockedValue({ $0 }) { - return .upgrade(response.headers, webSocketHandler) - } else { + ) where WSResponder.Context: WebSocketRequestContext { + self.additionalChannelHandlers = additionalChannelHandlers + self.configuration = configuration + self.shouldUpgrade = { head, channel, logger in + let promise = channel.eventLoop.makePromise(of: ShouldUpgradeResult.self) + promise.completeWithTask { + let request = Request(head: head, body: .init(buffer: .init())) + let context = WSResponder.Context(channel: channel, logger: logger) + do { + let response = try await webSocketResponder.respond(to: request, context: context) + if response.status == .ok, let webSocketHandler = context.webSocket.handler.withLockedValue({ $0 }) { + return .upgrade(response.headers) { asyncChannel, _ in + let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server) + await webSocket.handle(handler: webSocketHandler.handler, context: webSocketHandler.context) + } + } else { + return .dontUpgrade + } + } catch { return .dontUpgrade } - } catch { - return .dontUpgrade } + return promise.futureResult } + self.responder = responder } } @@ -158,7 +174,7 @@ extension HTTPChannelBuilder { webSocketRouter: WSResponderBuilder, configuration: WebSocketServerConfiguration = .init(), additionalChannelHandlers: @autoclosure @escaping @Sendable () -> [any RemovableChannelHandler] = [] - ) -> HTTPChannelBuilder> where WSResponderBuilder.Responder.Context: WebSocketRequestContext { + ) -> HTTPChannelBuilder where WSResponderBuilder.Responder.Context: WebSocketRequestContext { let webSocketReponder = webSocketRouter.buildResponder() return .init { responder in return HTTP1AndWebSocketChannel( @@ -183,28 +199,3 @@ extension Logger { return logger } } - -/// Generate Unique ID for each request. This is a duplicate of the RequestID in Hummingbird -package struct RequestID: CustomStringConvertible { - let low: UInt64 - - package init() { - self.low = Self.globalRequestID.loadThenWrappingIncrement(by: 1, ordering: .relaxed) - } - - package var description: String { - Self.high + self.formatAsHexWithLeadingZeros(self.low) - } - - func formatAsHexWithLeadingZeros(_ value: UInt64) -> String { - let string = String(value, radix: 16) - if string.count < 16 { - return String(repeating: "0", count: 16 - string.count) + string - } else { - return string - } - } - - private static let high = String(UInt64.random(in: .min ... .max), radix: 16) - private static let globalRequestID = ManagedAtomic(UInt64.random(in: .min ... .max)) -} diff --git a/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift b/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift index 9c3321a..74b4c2a 100644 --- a/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift +++ b/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift @@ -14,56 +14,9 @@ import AsyncAlgorithms import HTTPTypes +import Logging import NIOCore import NIOWebSocket -/// Protocol for web socket data handling -/// -/// This is the users interface into HummingbirdWebSocket. They provide an implementation of this protocol when -/// contructing their WebSocket upgrade handler. The user needs to return a type conforming to this protocol in -/// the `shouldUpgrade` closure in HTTP1AndWebSocketChannel.init -public protocol WebSocketDataHandler: Sendable { - /// Context type supplied to the handle function. - /// - /// The `WebSocketDataHandler` can chose to setup a context or accept the default one from - /// ``WebSocketHandler``. - associatedtype Context: WebSocketContextProtocol = WebSocketContext - /// If a `WebSocketDataHandler` requires a context with custom data it should - /// setup this variable on initialization - var alreadySetupContext: Context? { get } - /// Handler WebSocket data packets - /// - Parameters: - /// - inbound: An AsyncSequence of text or binary WebSocket frames. - /// - outbound: An outbound Writer to write websocket frames to - /// - context: Associated context to this websocket channel - func handle(_ inbound: WebSocketHandlerInbound, _ outbound: WebSocketHandlerOutboundWriter, context: Context) async throws -} - -extension WebSocketDataHandler { - /// Default implementaion of ``alreadySetupContext`` returns nil, so the Context will be - /// created by the ``WebSocketChannelHandler`` - public var alreadySetupContext: Context? { nil } -} - -/// WebSocketDataHandler that is is initialized via a callback -public struct WebSocketDataCallbackHandler: WebSocketDataHandler { - public typealias Callback = @Sendable (WebSocketHandlerInbound, WebSocketHandlerOutboundWriter, WebSocketContext) async throws -> Void - - let callback: Callback - - public init(_ callback: @escaping Callback) { - self.callback = callback - } - - /// Handler WebSocket data packets by passing directly to the callback - public func handle(_ inbound: WebSocketHandlerInbound, _ outbound: WebSocketHandlerOutboundWriter, context: WebSocketContext) async throws { - try await self.callback(inbound, outbound, context) - } -} - -extension ShouldUpgradeResult where Value == WebSocketDataCallbackHandler { - /// Extension to ShouldUpgradeResult that takes just a callback - public static func upgrade(_ headers: HTTPFields, _ callback: @escaping WebSocketDataCallbackHandler.Callback) -> Self { - .upgrade(headers, WebSocketDataCallbackHandler(callback)) - } -} +/// Function that handles websocket data and text blocks +public typealias WebSocketDataHandler = @Sendable (WebSocketHandlerInbound, WebSocketHandlerOutboundWriter, Context) async throws -> Void diff --git a/Sources/HummingbirdWebSocket/WebSocketHandler.swift b/Sources/HummingbirdWebSocket/WebSocketHandler.swift index 13df27c..6bcac5e 100644 --- a/Sources/HummingbirdWebSocket/WebSocketHandler.swift +++ b/Sources/HummingbirdWebSocket/WebSocketHandler.swift @@ -41,10 +41,7 @@ actor WebSocketHandler: Sendable { } /// Handle WebSocket AsynChannel - func handle( - handler: Handler, - context: Handler.Context - ) async { + func handle(handler: @escaping WebSocketDataHandler, context: Context) async { try? await self.asyncChannel.executeThenClose { inbound, outbound in do { try await withThrowingTaskGroup(of: Void.self) { group in @@ -90,7 +87,7 @@ actor WebSocketHandler: Sendable { } group.addTask { // handle websocket data and text - try await handler.handle(webSocketHandlerInbound, webSocketHandlerOutbound, context: context) + try await handler(webSocketHandlerInbound, webSocketHandlerOutbound, context) try await self.close(code: .normalClosure, outbound: outbound, context: context) } try await group.next() diff --git a/Tests/HummingbirdWebSocketTests/WebSocketTests.swift b/Tests/HummingbirdWebSocketTests/WebSocketTests.swift index e286221..b7666bf 100644 --- a/Tests/HummingbirdWebSocketTests/WebSocketTests.swift +++ b/Tests/HummingbirdWebSocketTests/WebSocketTests.swift @@ -76,7 +76,7 @@ final class HummingbirdWebSocketTests: XCTestCase { func testClientAndServer( serverTLSConfiguration: TLSConfiguration? = nil, - server serverHandler: @escaping WebSocketDataCallbackHandler.Callback, + server serverHandler: @escaping WebSocketDataHandler, shouldUpgrade: @escaping @Sendable (HTTPRequest) throws -> HTTPFields? = { _ in return [:] }, getClient: @escaping @Sendable (Int, Logger) throws -> WebSocketClient ) async throws { @@ -91,7 +91,7 @@ final class HummingbirdWebSocketTests: XCTestCase { let serviceGroup: ServiceGroup let webSocketUpgrade: HTTPChannelBuilder = .webSocketUpgrade { head, _, _ in if let headers = try shouldUpgrade(head) { - return .upgrade(headers, WebSocketDataCallbackHandler(serverHandler)) + return .upgrade(headers, serverHandler) } else { return .dontUpgrade } @@ -144,9 +144,9 @@ final class HummingbirdWebSocketTests: XCTestCase { func testClientAndServer( serverTLSConfiguration: TLSConfiguration? = nil, - server serverHandler: @escaping WebSocketDataCallbackHandler.Callback, + server serverHandler: @escaping WebSocketDataHandler, shouldUpgrade: @escaping @Sendable (HTTPRequest) throws -> HTTPFields? = { _ in return [:] }, - client clientHandler: @escaping WebSocketDataCallbackHandler.Callback + client clientHandler: @escaping WebSocketDataHandler ) async throws { try await self.testClientAndServer( serverTLSConfiguration: serverTLSConfiguration, @@ -156,7 +156,7 @@ final class HummingbirdWebSocketTests: XCTestCase { try WebSocketClient( url: .init("ws://localhost:\(port)"), logger: logger, - process: clientHandler + handler: clientHandler ) } ) @@ -393,12 +393,12 @@ final class HummingbirdWebSocketTests: XCTestCase { let router = Router(context: BasicWebSocketRequestContext.self) router.ws("/ws1") { _, _ in return .upgrade([:]) - } handle: { _, outbound, _ in + } onUpgrade: { _, outbound, _ in try await outbound.write(.text("One")) } router.ws("/ws2") { _, _ in return .upgrade([:]) - } handle: { _, outbound, _ in + } onUpgrade: { _, outbound, _ in try await outbound.write(.text("Two")) } try await self.testClientAndServerWithRouter(webSocketRouter: router, uri: "localhost:8080") { port, logger in @@ -422,7 +422,7 @@ final class HummingbirdWebSocketTests: XCTestCase { router.group("/ws") .add(middleware: WebSocketUpgradeMiddleware { _, _ in return .upgrade([:]) - } handle: { _, outbound, _ in + } onUpgrade: { _, outbound, _ in try await outbound.write(.text("One")) }) .get { _, _ -> Response in return .init(status: .ok) } @@ -437,7 +437,7 @@ final class HummingbirdWebSocketTests: XCTestCase { let router = Router(context: BasicWebSocketRequestContext.self) router.ws("/ws") { _, _ in return .upgrade([:]) - } handle: { _, outbound, _ in + } onUpgrade: { _, outbound, _ in try await outbound.write(.text("One")) } do { @@ -447,11 +447,48 @@ final class HummingbirdWebSocketTests: XCTestCase { } catch let error as WebSocketClientError where error == .webSocketUpgradeFailed {} } + /// Test context from router is passed through to web socket + func testRouterContextUpdate() async throws { + struct MyRequestContext: WebSocketRequestContext { + var coreContext: CoreRequestContext + var webSocket: WebSocketRouterContext + var name: String + + init(channel: Channel, logger: Logger) { + self.coreContext = .init(allocator: channel.allocator, logger: logger) + self.webSocket = .init() + self.name = "" + } + } + struct MyMiddleware: RouterMiddleware { + func handle(_ request: Request, context: MyRequestContext, next: (Request, MyRequestContext) async throws -> Response) async throws -> Response { + var context = context + context.name = "Roger Moore" + return try await next(request, context) + } + } + let router = Router(context: MyRequestContext.self) + router.middlewares.add(MyMiddleware()) + router.ws("/ws") { _, _ in + return .upgrade([:]) + } onUpgrade: { _, outbound, context in + try await outbound.write(.text(context.name)) + } + do { + try await self.testClientAndServerWithRouter(webSocketRouter: router, uri: "localhost:8080") { port, logger in + try WebSocketClient(url: .init("ws://localhost:\(port)/ws"), logger: logger) { inbound, _, _ in + let text = await inbound.first { _ in true } + XCTAssertEqual(text, .text("Roger Moore")) + } + } + } catch let error as WebSocketClientError where error == .webSocketUpgradeFailed {} + } + func testHTTPRequest() async throws { let router = Router(context: BasicWebSocketRequestContext.self) router.ws("/ws") { _, _ in return .upgrade([:]) - } handle: { _, outbound, _ in + } onUpgrade: { _, outbound, _ in try await outbound.write(.text("Hello")) } router.get("/http") { _, _ in