Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup API #42

Merged
merged 7 commits into from Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion Package.swift
Expand Up @@ -11,7 +11,7 @@ let package = Package(
// .library(name: "HummingbirdWSCompression", targets: ["HummingbirdWSCompression"]),
],
dependencies: [
.package(url: "https://github.com/hummingbird-project/hummingbird.git", from: "2.0.0-beta.1"),
.package(url: "https://github.com/hummingbird-project/hummingbird.git", branch: "main"),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reverted to main temporarily until beta 2 is released

.package(url: "https://github.com/apple/swift-async-algorithms.git", from: "1.0.0"),
.package(url: "https://github.com/apple/swift-atomics.git", from: "1.0.0"),
.package(url: "https://github.com/apple/swift-http-types.git", from: "1.0.0"),
Expand Down
20 changes: 10 additions & 10 deletions Sources/HummingbirdWebSocket/Client/WebSocketClient.swift
Expand Up @@ -27,7 +27,7 @@ import ServiceLifecycle
///
/// Supports TLS via both NIOSSL and Network framework.
///
/// Initialize the WebSocketClient with your handler and then call ``WebSocketClient/run``
/// Initialize the WebSocketClient with your handler and then call ``WebSocketClient/run()``
/// to connect. The handler is provider with an `inbound` stream of WebSocket packets coming
/// from the server and an `outbound` writer that can be used to write packets to the server.
/// ```swift
Expand All @@ -50,7 +50,7 @@ public struct WebSocketClient {
/// WebSocket URL
let url: URI
/// WebSocket data handler
let handler: WebSocketDataCallbackHandler
let handler: WebSocketDataHandler<WebSocketContext>
/// configuration
let configuration: WebSocketClientConfiguration
/// EventLoopGroup to use
Expand All @@ -75,10 +75,10 @@ public struct WebSocketClient {
tlsConfiguration: TLSConfiguration? = nil,
eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton,
logger: Logger,
process: @escaping WebSocketDataCallbackHandler.Callback
handler: @escaping WebSocketDataHandler<WebSocketContext>
) throws {
self.url = url
self.handler = .init(process)
self.handler = handler
self.configuration = configuration
self.eventLoopGroup = eventLoopGroup
self.logger = logger
Expand All @@ -101,10 +101,10 @@ public struct WebSocketClient {
transportServicesTLSOptions: TSTLSOptions,
eventLoopGroup: NIOTSEventLoopGroup = NIOTSEventLoopGroup.singleton,
logger: Logger,
process: @escaping WebSocketDataCallbackHandler.Callback
handler: @escaping WebSocketDataHandler<WebSocketContext>
) throws {
self.url = url
self.handler = .init(process)
self.handler = handler
self.configuration = configuration
self.eventLoopGroup = eventLoopGroup
self.logger = logger
Expand Down Expand Up @@ -193,15 +193,15 @@ extension WebSocketClient {
tlsConfiguration: TLSConfiguration? = nil,
eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton,
logger: Logger,
process: @escaping WebSocketDataCallbackHandler.Callback
handler: @escaping WebSocketDataHandler<WebSocketContext>
) async throws {
let ws = try self.init(
url: url,
configuration: configuration,
tlsConfiguration: tlsConfiguration,
eventLoopGroup: eventLoopGroup,
logger: logger,
process: process
handler: handler
)
try await ws.run()
}
Expand All @@ -222,15 +222,15 @@ extension WebSocketClient {
transportServicesTLSOptions: TSTLSOptions,
eventLoopGroup: NIOTSEventLoopGroup = NIOTSEventLoopGroup.singleton,
logger: Logger,
process: @escaping WebSocketDataCallbackHandler.Callback
handler: @escaping WebSocketDataHandler<WebSocketContext>
) async throws {
let ws = try self.init(
url: url,
configuration: configuration,
transportServicesTLSOptions: transportServicesTLSOptions,
eventLoopGroup: eventLoopGroup,
logger: logger,
process: process
handler: handler
)
try await ws.run()
}
Expand Down
17 changes: 8 additions & 9 deletions Sources/HummingbirdWebSocket/Client/WebSocketClientChannel.swift
Expand Up @@ -20,25 +20,25 @@ import NIOHTTP1
import NIOHTTPTypesHTTP1
import NIOWebSocket

public struct WebSocketClientChannel<Handler: WebSocketDataHandler>: ClientConnectionChannel {
public enum UpgradeResult {
struct WebSocketClientChannel: ClientConnectionChannel {
enum UpgradeResult {
case websocket(NIOAsyncChannel<WebSocketFrame, WebSocketFrame>)
case notUpgraded
}

public typealias Value = EventLoopFuture<UpgradeResult>
typealias Value = EventLoopFuture<UpgradeResult>

let url: String
let handler: Handler
let handler: WebSocketDataHandler<WebSocketContext>
let configuration: WebSocketClientConfiguration

init(handler: Handler, url: String, configuration: WebSocketClientConfiguration) {
init(handler: @escaping WebSocketDataHandler<WebSocketContext>, url: String, configuration: WebSocketClientConfiguration) {
self.url = url
self.handler = handler
self.configuration = configuration
}

public func setup(channel: any Channel, logger: Logger) -> NIOCore.EventLoopFuture<Value> {
func setup(channel: any Channel, logger: Logger) -> NIOCore.EventLoopFuture<Value> {
channel.eventLoop.makeCompletedFuture {
let upgrader = NIOTypedWebSocketClientUpgrader<UpgradeResult>(
maxFrameSize: self.configuration.maxFrameSize,
Expand Down Expand Up @@ -81,12 +81,11 @@ public struct WebSocketClientChannel<Handler: WebSocketDataHandler>: ClientConne
}
}

public func handle(value: Value, logger: Logger) async throws {
func handle(value: Value, logger: Logger) async throws {
switch try await value.get() {
case .websocket(let webSocketChannel):
let webSocket = WebSocketHandler(asyncChannel: webSocketChannel, type: .client)
let context = self.handler.alreadySetupContext ?? .init(channel: webSocketChannel.channel, logger: logger)
await webSocket.handle(handler: self.handler, context: context)
await webSocket.handle(handler: self.handler, context: WebSocketContext(channel: webSocketChannel.channel, logger: logger))
case .notUpgraded:
// The upgrade to websocket did not succeed.
logger.debug("Upgrade declined")
Expand Down
Expand Up @@ -23,6 +23,16 @@ import NIOWebSocket
public enum ShouldUpgradeResult<Value: Sendable>: Sendable {
case dontUpgrade
case upgrade(HTTPFields, Value)

/// Map upgrade result to difference type
func map<Result>(_ map: (Value) throws -> Result) rethrows -> ShouldUpgradeResult<Result> {
switch self {
case .dontUpgrade:
return .dontUpgrade
case .upgrade(let headers, let value):
return try .upgrade(headers, map(value))
}
}
}

extension NIOTypedWebSocketServerUpgrader {
Expand All @@ -47,7 +57,7 @@ extension NIOTypedWebSocketServerUpgrader {
/// websocket protocol. This only needs to add the user handlers: the
/// `WebSocketFrameEncoder` and `WebSocketFrameDecoder` will have been added to the
/// pipeline automatically.
public convenience init<Value>(
convenience init<Value>(
maxFrameSize: Int = 1 << 14,
enableAutomaticErrorHandling: Bool = true,
shouldUpgrade: @escaping @Sendable (Channel, HTTPRequest) -> EventLoopFuture<ShouldUpgradeResult<Value>>,
Expand Down
62 changes: 42 additions & 20 deletions Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift
Expand Up @@ -13,6 +13,7 @@
//===----------------------------------------------------------------------===//

import HTTPTypes
import Hummingbird
import HummingbirdCore
import Logging
import NIOConcurrencyHelpers
Expand All @@ -23,11 +24,13 @@ import NIOHTTPTypesHTTP1
import NIOWebSocket

/// Child channel supporting a web socket upgrade from HTTP1
public struct HTTP1AndWebSocketChannel<Handler: WebSocketDataHandler>: ServerChildChannel, HTTPChannelHandler {
public struct HTTP1AndWebSocketChannel: ServerChildChannel, HTTPChannelHandler {
public typealias WebSocketChannelHandler = @Sendable (NIOAsyncChannel<WebSocketFrame, WebSocketFrame>, Logger) async -> Void
/// Upgrade result (either a websocket AsyncChannel, or an HTTP1 AsyncChannel)
public enum UpgradeResult {
case websocket(NIOAsyncChannel<WebSocketFrame, WebSocketFrame>, Handler)
case notUpgraded(NIOAsyncChannel<HTTPRequestPart, HTTPResponsePart>, failed: Bool)
case websocket(NIOAsyncChannel<WebSocketFrame, WebSocketFrame>, WebSocketChannelHandler, Logger)
case notUpgraded(NIOAsyncChannel<HTTPRequestPart, HTTPResponsePart>)
case failedUpgrade(NIOAsyncChannel<HTTPRequestPart, HTTPResponsePart>, Logger)
}

public typealias Value = EventLoopFuture<UpgradeResult>
Expand All @@ -43,13 +46,20 @@ public struct HTTP1AndWebSocketChannel<Handler: WebSocketDataHandler>: ServerChi
responder: @escaping @Sendable (Request, Channel) async throws -> Response,
configuration: WebSocketServerConfiguration,
additionalChannelHandlers: @escaping @Sendable () -> [any RemovableChannelHandler] = { [] },
shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) throws -> ShouldUpgradeResult<Handler>
shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) throws -> ShouldUpgradeResult<WebSocketDataHandler<WebSocketContext>>
) {
self.additionalChannelHandlers = additionalChannelHandlers
self.configuration = configuration
self.shouldUpgrade = { head, channel, logger in
channel.eventLoop.makeCompletedFuture {
channel.eventLoop.makeCompletedFuture { () -> ShouldUpgradeResult<WebSocketChannelHandler> in
try shouldUpgrade(head, channel, logger)
.map { handler in
return { asyncChannel, logger in
let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server)
let context = WebSocketContext(channel: channel, logger: logger)
await webSocket.handle(handler: handler, context: context)
}
}
}
}
self.responder = responder
Expand All @@ -66,14 +76,21 @@ public struct HTTP1AndWebSocketChannel<Handler: WebSocketDataHandler>: ServerChi
responder: @escaping @Sendable (Request, Channel) async throws -> Response,
configuration: WebSocketServerConfiguration,
additionalChannelHandlers: @escaping @Sendable () -> [any RemovableChannelHandler] = { [] },
shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) async throws -> ShouldUpgradeResult<Handler>
shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) async throws -> ShouldUpgradeResult<WebSocketDataHandler<WebSocketContext>>
) {
self.additionalChannelHandlers = additionalChannelHandlers
self.configuration = configuration
self.shouldUpgrade = { head, channel, logger in
let promise = channel.eventLoop.makePromise(of: ShouldUpgradeResult<Handler>.self)
let promise = channel.eventLoop.makePromise(of: ShouldUpgradeResult<WebSocketChannelHandler>.self)
promise.completeWithTask {
try await shouldUpgrade(head, channel, logger)
.map { handler in
return { asyncChannel, logger in
let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server)
let context = WebSocketContext(channel: channel, logger: logger)
await webSocket.handle(handler: handler, context: context)
}
}
}
return promise.futureResult
}
Expand All @@ -89,6 +106,7 @@ public struct HTTP1AndWebSocketChannel<Handler: WebSocketDataHandler>: ServerChi
public func setup(channel: Channel, logger: Logger) -> EventLoopFuture<Value> {
return channel.eventLoop.makeCompletedFuture {
let upgradeAttempted = NIOLoopBoundBox(false, eventLoop: channel.eventLoop)
let logger = logger.with(metadataKey: "hb_id", value: .stringConvertible(RequestID()))
let upgrader = NIOTypedWebSocketServerUpgrader<UpgradeResult>(
maxFrameSize: self.configuration.maxFrameSize,
shouldUpgrade: { channel, head in
Expand All @@ -98,7 +116,7 @@ public struct HTTP1AndWebSocketChannel<Handler: WebSocketDataHandler>: ServerChi
upgradePipelineHandler: { channel, handler in
channel.eventLoop.makeCompletedFuture {
let asyncChannel = try NIOAsyncChannel<WebSocketFrame, WebSocketFrame>(wrappingChannelSynchronously: channel)
return UpgradeResult.websocket(asyncChannel, handler)
return UpgradeResult.websocket(asyncChannel, handler, logger)
}
}
)
Expand All @@ -113,7 +131,11 @@ public struct HTTP1AndWebSocketChannel<Handler: WebSocketDataHandler>: ServerChi
return channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandlers(childChannelHandlers)
let asyncChannel = try NIOAsyncChannel<HTTPRequestPart, HTTPResponsePart>(wrappingChannelSynchronously: channel)
return UpgradeResult.notUpgraded(asyncChannel, failed: upgradeAttempted.value)
if upgradeAttempted.value {
return UpgradeResult.failedUpgrade(asyncChannel, logger)
} else {
return UpgradeResult.notUpgraded(asyncChannel)
}
}
}
)
Expand All @@ -134,16 +156,16 @@ public struct HTTP1AndWebSocketChannel<Handler: WebSocketDataHandler>: ServerChi
do {
let result = try await upgradeResult.get()
switch result {
case .notUpgraded(let http1, let failed):
if failed {
await self.write405(asyncChannel: http1, logger: logger)
} else {
await self.handleHTTP(asyncChannel: http1, logger: logger)
}
case .websocket(let asyncChannel, let handler):
let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server)
let context = handler.alreadySetupContext ?? .init(channel: asyncChannel.channel, logger: logger)
await webSocket.handle(handler: handler, context: context)
case .notUpgraded(let http1):
await self.handleHTTP(asyncChannel: http1, logger: logger)

case .failedUpgrade(let http1, let logger):
logger.debug("Websocket upgrade failed")
await self.write405(asyncChannel: http1, logger: logger)

case .websocket(let asyncChannel, let handler, let logger):
logger.debug("Websocket upgrade")
await handler(asyncChannel, logger)
}
} catch {
logger.error("Error handling upgrade result: \(error)")
Expand Down Expand Up @@ -177,7 +199,7 @@ public struct HTTP1AndWebSocketChannel<Handler: WebSocketDataHandler>: ServerChi
}

public var responder: @Sendable (Request, Channel) async throws -> Response
let shouldUpgrade: @Sendable (HTTPRequest, Channel, Logger) -> EventLoopFuture<ShouldUpgradeResult<Handler>>
let shouldUpgrade: @Sendable (HTTPRequest, Channel, Logger) -> EventLoopFuture<ShouldUpgradeResult<WebSocketChannelHandler>>
let configuration: WebSocketServerConfiguration
let additionalChannelHandlers: @Sendable () -> [any RemovableChannelHandler]
}
Expand Up @@ -20,11 +20,11 @@ import NIOCore
extension HTTPChannelBuilder {
/// HTTP1 channel builder supporting a websocket upgrade
/// - parameters
public static func webSocketUpgrade<Handler: WebSocketDataHandler>(
public static func webSocketUpgrade(
configuration: WebSocketServerConfiguration = .init(),
additionalChannelHandlers: @autoclosure @escaping @Sendable () -> [any RemovableChannelHandler] = [],
shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) async throws -> ShouldUpgradeResult<Handler>
) -> HTTPChannelBuilder<HTTP1AndWebSocketChannel<Handler>> {
shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) async throws -> ShouldUpgradeResult<WebSocketDataHandler<WebSocketContext>>
) -> HTTPChannelBuilder<HTTP1AndWebSocketChannel> {
return .init { responder in
return HTTP1AndWebSocketChannel(
responder: responder,
Expand All @@ -36,13 +36,13 @@ extension HTTPChannelBuilder {
}

/// HTTP1 channel builder supporting a websocket upgrade
public static func webSocketUpgrade<Handler: WebSocketDataHandler>(
public static func webSocketUpgrade(
configuration: WebSocketServerConfiguration = .init(),
additionalChannelHandlers: @autoclosure @escaping @Sendable () -> [any RemovableChannelHandler] = [],
shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) throws -> ShouldUpgradeResult<Handler>
) -> HTTPChannelBuilder<HTTP1AndWebSocketChannel<Handler>> {
shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) throws -> ShouldUpgradeResult<WebSocketDataHandler<WebSocketContext>>
) -> HTTPChannelBuilder<HTTP1AndWebSocketChannel> {
return .init { responder in
return HTTP1AndWebSocketChannel<Handler>(
return HTTP1AndWebSocketChannel(
responder: responder,
configuration: configuration,
additionalChannelHandlers: additionalChannelHandlers,
Expand Down