Skip to content

Commit

Permalink
Cleanup API (#42)
Browse files Browse the repository at this point in the history
* WebSocketDataHandler is now a struct

* Renaming

* Remove generic parameter from HTTP1AndWebSocketChannel

* Upgrade handle function takes a Logger

* Move code around

* Add testRouterContextUpdate

* handle -> onUpgrade
  • Loading branch information
adam-fowler committed Mar 22, 2024
1 parent 60f1f10 commit 0f61374
Show file tree
Hide file tree
Showing 10 changed files with 171 additions and 162 deletions.
2 changes: 1 addition & 1 deletion Package.swift
Expand Up @@ -11,7 +11,7 @@ let package = Package(
// .library(name: "HummingbirdWSCompression", targets: ["HummingbirdWSCompression"]),
],
dependencies: [
.package(url: "https://github.com/hummingbird-project/hummingbird.git", from: "2.0.0-beta.1"),
.package(url: "https://github.com/hummingbird-project/hummingbird.git", branch: "main"),
.package(url: "https://github.com/apple/swift-async-algorithms.git", from: "1.0.0"),
.package(url: "https://github.com/apple/swift-atomics.git", from: "1.0.0"),
.package(url: "https://github.com/apple/swift-http-types.git", from: "1.0.0"),
Expand Down
20 changes: 10 additions & 10 deletions Sources/HummingbirdWebSocket/Client/WebSocketClient.swift
Expand Up @@ -27,7 +27,7 @@ import ServiceLifecycle
///
/// Supports TLS via both NIOSSL and Network framework.
///
/// Initialize the WebSocketClient with your handler and then call ``WebSocketClient/run``
/// Initialize the WebSocketClient with your handler and then call ``WebSocketClient/run()``
/// to connect. The handler is provider with an `inbound` stream of WebSocket packets coming
/// from the server and an `outbound` writer that can be used to write packets to the server.
/// ```swift
Expand All @@ -50,7 +50,7 @@ public struct WebSocketClient {
/// WebSocket URL
let url: URI
/// WebSocket data handler
let handler: WebSocketDataCallbackHandler
let handler: WebSocketDataHandler<WebSocketContext>
/// configuration
let configuration: WebSocketClientConfiguration
/// EventLoopGroup to use
Expand All @@ -75,10 +75,10 @@ public struct WebSocketClient {
tlsConfiguration: TLSConfiguration? = nil,
eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton,
logger: Logger,
process: @escaping WebSocketDataCallbackHandler.Callback
handler: @escaping WebSocketDataHandler<WebSocketContext>
) throws {
self.url = url
self.handler = .init(process)
self.handler = handler
self.configuration = configuration
self.eventLoopGroup = eventLoopGroup
self.logger = logger
Expand All @@ -101,10 +101,10 @@ public struct WebSocketClient {
transportServicesTLSOptions: TSTLSOptions,
eventLoopGroup: NIOTSEventLoopGroup = NIOTSEventLoopGroup.singleton,
logger: Logger,
process: @escaping WebSocketDataCallbackHandler.Callback
handler: @escaping WebSocketDataHandler<WebSocketContext>
) throws {
self.url = url
self.handler = .init(process)
self.handler = handler
self.configuration = configuration
self.eventLoopGroup = eventLoopGroup
self.logger = logger
Expand Down Expand Up @@ -193,15 +193,15 @@ extension WebSocketClient {
tlsConfiguration: TLSConfiguration? = nil,
eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton,
logger: Logger,
process: @escaping WebSocketDataCallbackHandler.Callback
handler: @escaping WebSocketDataHandler<WebSocketContext>
) async throws {
let ws = try self.init(
url: url,
configuration: configuration,
tlsConfiguration: tlsConfiguration,
eventLoopGroup: eventLoopGroup,
logger: logger,
process: process
handler: handler
)
try await ws.run()
}
Expand All @@ -222,15 +222,15 @@ extension WebSocketClient {
transportServicesTLSOptions: TSTLSOptions,
eventLoopGroup: NIOTSEventLoopGroup = NIOTSEventLoopGroup.singleton,
logger: Logger,
process: @escaping WebSocketDataCallbackHandler.Callback
handler: @escaping WebSocketDataHandler<WebSocketContext>
) async throws {
let ws = try self.init(
url: url,
configuration: configuration,
transportServicesTLSOptions: transportServicesTLSOptions,
eventLoopGroup: eventLoopGroup,
logger: logger,
process: process
handler: handler
)
try await ws.run()
}
Expand Down
17 changes: 8 additions & 9 deletions Sources/HummingbirdWebSocket/Client/WebSocketClientChannel.swift
Expand Up @@ -20,25 +20,25 @@ import NIOHTTP1
import NIOHTTPTypesHTTP1
import NIOWebSocket

public struct WebSocketClientChannel<Handler: WebSocketDataHandler>: ClientConnectionChannel {
public enum UpgradeResult {
struct WebSocketClientChannel: ClientConnectionChannel {
enum UpgradeResult {
case websocket(NIOAsyncChannel<WebSocketFrame, WebSocketFrame>)
case notUpgraded
}

public typealias Value = EventLoopFuture<UpgradeResult>
typealias Value = EventLoopFuture<UpgradeResult>

let url: String
let handler: Handler
let handler: WebSocketDataHandler<WebSocketContext>
let configuration: WebSocketClientConfiguration

init(handler: Handler, url: String, configuration: WebSocketClientConfiguration) {
init(handler: @escaping WebSocketDataHandler<WebSocketContext>, url: String, configuration: WebSocketClientConfiguration) {
self.url = url
self.handler = handler
self.configuration = configuration
}

public func setup(channel: any Channel, logger: Logger) -> NIOCore.EventLoopFuture<Value> {
func setup(channel: any Channel, logger: Logger) -> NIOCore.EventLoopFuture<Value> {
channel.eventLoop.makeCompletedFuture {
let upgrader = NIOTypedWebSocketClientUpgrader<UpgradeResult>(
maxFrameSize: self.configuration.maxFrameSize,
Expand Down Expand Up @@ -81,12 +81,11 @@ public struct WebSocketClientChannel<Handler: WebSocketDataHandler>: ClientConne
}
}

public func handle(value: Value, logger: Logger) async throws {
func handle(value: Value, logger: Logger) async throws {
switch try await value.get() {
case .websocket(let webSocketChannel):
let webSocket = WebSocketHandler(asyncChannel: webSocketChannel, type: .client)
let context = self.handler.alreadySetupContext ?? .init(channel: webSocketChannel.channel, logger: logger)
await webSocket.handle(handler: self.handler, context: context)
await webSocket.handle(handler: self.handler, context: WebSocketContext(channel: webSocketChannel.channel, logger: logger))
case .notUpgraded:
// The upgrade to websocket did not succeed.
logger.debug("Upgrade declined")
Expand Down
Expand Up @@ -23,6 +23,16 @@ import NIOWebSocket
public enum ShouldUpgradeResult<Value: Sendable>: Sendable {
case dontUpgrade
case upgrade(HTTPFields, Value)

/// Map upgrade result to difference type
func map<Result>(_ map: (Value) throws -> Result) rethrows -> ShouldUpgradeResult<Result> {
switch self {
case .dontUpgrade:
return .dontUpgrade
case .upgrade(let headers, let value):
return try .upgrade(headers, map(value))
}
}
}

extension NIOTypedWebSocketServerUpgrader {
Expand All @@ -47,7 +57,7 @@ extension NIOTypedWebSocketServerUpgrader {
/// websocket protocol. This only needs to add the user handlers: the
/// `WebSocketFrameEncoder` and `WebSocketFrameDecoder` will have been added to the
/// pipeline automatically.
public convenience init<Value>(
convenience init<Value>(
maxFrameSize: Int = 1 << 14,
enableAutomaticErrorHandling: Bool = true,
shouldUpgrade: @escaping @Sendable (Channel, HTTPRequest) -> EventLoopFuture<ShouldUpgradeResult<Value>>,
Expand Down
62 changes: 42 additions & 20 deletions Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift
Expand Up @@ -13,6 +13,7 @@
//===----------------------------------------------------------------------===//

import HTTPTypes
import Hummingbird
import HummingbirdCore
import Logging
import NIOConcurrencyHelpers
Expand All @@ -23,11 +24,13 @@ import NIOHTTPTypesHTTP1
import NIOWebSocket

/// Child channel supporting a web socket upgrade from HTTP1
public struct HTTP1AndWebSocketChannel<Handler: WebSocketDataHandler>: ServerChildChannel, HTTPChannelHandler {
public struct HTTP1AndWebSocketChannel: ServerChildChannel, HTTPChannelHandler {
public typealias WebSocketChannelHandler = @Sendable (NIOAsyncChannel<WebSocketFrame, WebSocketFrame>, Logger) async -> Void
/// Upgrade result (either a websocket AsyncChannel, or an HTTP1 AsyncChannel)
public enum UpgradeResult {
case websocket(NIOAsyncChannel<WebSocketFrame, WebSocketFrame>, Handler)
case notUpgraded(NIOAsyncChannel<HTTPRequestPart, HTTPResponsePart>, failed: Bool)
case websocket(NIOAsyncChannel<WebSocketFrame, WebSocketFrame>, WebSocketChannelHandler, Logger)
case notUpgraded(NIOAsyncChannel<HTTPRequestPart, HTTPResponsePart>)
case failedUpgrade(NIOAsyncChannel<HTTPRequestPart, HTTPResponsePart>, Logger)
}

public typealias Value = EventLoopFuture<UpgradeResult>
Expand All @@ -43,13 +46,20 @@ public struct HTTP1AndWebSocketChannel<Handler: WebSocketDataHandler>: ServerChi
responder: @escaping @Sendable (Request, Channel) async throws -> Response,
configuration: WebSocketServerConfiguration,
additionalChannelHandlers: @escaping @Sendable () -> [any RemovableChannelHandler] = { [] },
shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) throws -> ShouldUpgradeResult<Handler>
shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) throws -> ShouldUpgradeResult<WebSocketDataHandler<WebSocketContext>>
) {
self.additionalChannelHandlers = additionalChannelHandlers
self.configuration = configuration
self.shouldUpgrade = { head, channel, logger in
channel.eventLoop.makeCompletedFuture {
channel.eventLoop.makeCompletedFuture { () -> ShouldUpgradeResult<WebSocketChannelHandler> in
try shouldUpgrade(head, channel, logger)
.map { handler in
return { asyncChannel, logger in
let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server)
let context = WebSocketContext(channel: channel, logger: logger)
await webSocket.handle(handler: handler, context: context)
}
}
}
}
self.responder = responder
Expand All @@ -66,14 +76,21 @@ public struct HTTP1AndWebSocketChannel<Handler: WebSocketDataHandler>: ServerChi
responder: @escaping @Sendable (Request, Channel) async throws -> Response,
configuration: WebSocketServerConfiguration,
additionalChannelHandlers: @escaping @Sendable () -> [any RemovableChannelHandler] = { [] },
shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) async throws -> ShouldUpgradeResult<Handler>
shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) async throws -> ShouldUpgradeResult<WebSocketDataHandler<WebSocketContext>>
) {
self.additionalChannelHandlers = additionalChannelHandlers
self.configuration = configuration
self.shouldUpgrade = { head, channel, logger in
let promise = channel.eventLoop.makePromise(of: ShouldUpgradeResult<Handler>.self)
let promise = channel.eventLoop.makePromise(of: ShouldUpgradeResult<WebSocketChannelHandler>.self)
promise.completeWithTask {
try await shouldUpgrade(head, channel, logger)
.map { handler in
return { asyncChannel, logger in
let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server)
let context = WebSocketContext(channel: channel, logger: logger)
await webSocket.handle(handler: handler, context: context)
}
}
}
return promise.futureResult
}
Expand All @@ -89,6 +106,7 @@ public struct HTTP1AndWebSocketChannel<Handler: WebSocketDataHandler>: ServerChi
public func setup(channel: Channel, logger: Logger) -> EventLoopFuture<Value> {
return channel.eventLoop.makeCompletedFuture {
let upgradeAttempted = NIOLoopBoundBox(false, eventLoop: channel.eventLoop)
let logger = logger.with(metadataKey: "hb_id", value: .stringConvertible(RequestID()))
let upgrader = NIOTypedWebSocketServerUpgrader<UpgradeResult>(
maxFrameSize: self.configuration.maxFrameSize,
shouldUpgrade: { channel, head in
Expand All @@ -98,7 +116,7 @@ public struct HTTP1AndWebSocketChannel<Handler: WebSocketDataHandler>: ServerChi
upgradePipelineHandler: { channel, handler in
channel.eventLoop.makeCompletedFuture {
let asyncChannel = try NIOAsyncChannel<WebSocketFrame, WebSocketFrame>(wrappingChannelSynchronously: channel)
return UpgradeResult.websocket(asyncChannel, handler)
return UpgradeResult.websocket(asyncChannel, handler, logger)
}
}
)
Expand All @@ -113,7 +131,11 @@ public struct HTTP1AndWebSocketChannel<Handler: WebSocketDataHandler>: ServerChi
return channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandlers(childChannelHandlers)
let asyncChannel = try NIOAsyncChannel<HTTPRequestPart, HTTPResponsePart>(wrappingChannelSynchronously: channel)
return UpgradeResult.notUpgraded(asyncChannel, failed: upgradeAttempted.value)
if upgradeAttempted.value {
return UpgradeResult.failedUpgrade(asyncChannel, logger)
} else {
return UpgradeResult.notUpgraded(asyncChannel)
}
}
}
)
Expand All @@ -134,16 +156,16 @@ public struct HTTP1AndWebSocketChannel<Handler: WebSocketDataHandler>: ServerChi
do {
let result = try await upgradeResult.get()
switch result {
case .notUpgraded(let http1, let failed):
if failed {
await self.write405(asyncChannel: http1, logger: logger)
} else {
await self.handleHTTP(asyncChannel: http1, logger: logger)
}
case .websocket(let asyncChannel, let handler):
let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server)
let context = handler.alreadySetupContext ?? .init(channel: asyncChannel.channel, logger: logger)
await webSocket.handle(handler: handler, context: context)
case .notUpgraded(let http1):
await self.handleHTTP(asyncChannel: http1, logger: logger)

case .failedUpgrade(let http1, let logger):
logger.debug("Websocket upgrade failed")
await self.write405(asyncChannel: http1, logger: logger)

case .websocket(let asyncChannel, let handler, let logger):
logger.debug("Websocket upgrade")
await handler(asyncChannel, logger)
}
} catch {
logger.error("Error handling upgrade result: \(error)")
Expand Down Expand Up @@ -177,7 +199,7 @@ public struct HTTP1AndWebSocketChannel<Handler: WebSocketDataHandler>: ServerChi
}

public var responder: @Sendable (Request, Channel) async throws -> Response
let shouldUpgrade: @Sendable (HTTPRequest, Channel, Logger) -> EventLoopFuture<ShouldUpgradeResult<Handler>>
let shouldUpgrade: @Sendable (HTTPRequest, Channel, Logger) -> EventLoopFuture<ShouldUpgradeResult<WebSocketChannelHandler>>
let configuration: WebSocketServerConfiguration
let additionalChannelHandlers: @Sendable () -> [any RemovableChannelHandler]
}
Expand Up @@ -20,11 +20,11 @@ import NIOCore
extension HTTPChannelBuilder {
/// HTTP1 channel builder supporting a websocket upgrade
/// - parameters
public static func webSocketUpgrade<Handler: WebSocketDataHandler>(
public static func webSocketUpgrade(
configuration: WebSocketServerConfiguration = .init(),
additionalChannelHandlers: @autoclosure @escaping @Sendable () -> [any RemovableChannelHandler] = [],
shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) async throws -> ShouldUpgradeResult<Handler>
) -> HTTPChannelBuilder<HTTP1AndWebSocketChannel<Handler>> {
shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) async throws -> ShouldUpgradeResult<WebSocketDataHandler<WebSocketContext>>
) -> HTTPChannelBuilder<HTTP1AndWebSocketChannel> {
return .init { responder in
return HTTP1AndWebSocketChannel(
responder: responder,
Expand All @@ -36,13 +36,13 @@ extension HTTPChannelBuilder {
}

/// HTTP1 channel builder supporting a websocket upgrade
public static func webSocketUpgrade<Handler: WebSocketDataHandler>(
public static func webSocketUpgrade(
configuration: WebSocketServerConfiguration = .init(),
additionalChannelHandlers: @autoclosure @escaping @Sendable () -> [any RemovableChannelHandler] = [],
shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) throws -> ShouldUpgradeResult<Handler>
) -> HTTPChannelBuilder<HTTP1AndWebSocketChannel<Handler>> {
shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) throws -> ShouldUpgradeResult<WebSocketDataHandler<WebSocketContext>>
) -> HTTPChannelBuilder<HTTP1AndWebSocketChannel> {
return .init { responder in
return HTTP1AndWebSocketChannel<Handler>(
return HTTP1AndWebSocketChannel(
responder: responder,
configuration: configuration,
additionalChannelHandlers: additionalChannelHandlers,
Expand Down

0 comments on commit 0f61374

Please sign in to comment.