Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WebSocket Router #41

Merged
merged 7 commits into from Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
30 changes: 13 additions & 17 deletions Snippets/WebsocketTest.swift
@@ -1,28 +1,24 @@
import HTTPTypes
import Hummingbird
import HummingbirdWebSocket
import NIOHTTP1

let router = Router()
let router = Router(context: BasicWebSocketRequestContext.self)
router.middlewares.add(FileMiddleware("Snippets/public"))
router.get { _, _ in
"Hello"
}

router.middlewares.add(FileMiddleware("Snippets/public"))
let app = Application(
router: router,
server: .webSocketUpgrade { _, head in
if head.uri == "/ws" {
return .upgrade(HTTPHeaders()) { inbound, outbound, _ in
for try await packet in inbound {
if case .text("disconnect") = packet {
break
}
try await outbound.write(.custom(packet.webSocketFrame))
}
}
} else {
return .dontUpgrade
router.ws("/ws") { inbound, outbound, _ in
for try await packet in inbound {
if case .text("disconnect") = packet {
break
}
try await outbound.write(.custom(packet.webSocketFrame))
}
}

let app = Application(
router: router,
server: .webSocketUpgrade(webSocketRouter: router)
)
try await app.runService()
Expand Up @@ -85,9 +85,9 @@ public struct WebSocketClientChannel<Handler: WebSocketDataHandler>: ClientConne

public func handle(value: Value, logger: Logger) async throws {
switch try await value.get() {
case .websocket(let websocketChannel):
let webSocket = WebSocketHandler(asyncChannel: websocketChannel, type: .client)
let context = self.handler.alreadySetupContext ?? .init(logger: logger, allocator: websocketChannel.channel.allocator)
case .websocket(let webSocketChannel):
let webSocket = WebSocketHandler(asyncChannel: webSocketChannel, type: .client)
let context = self.handler.alreadySetupContext ?? .init(channel: webSocketChannel.channel, logger: logger)
await webSocket.handle(handler: self.handler, context: context)
case .notUpgraded:
// The upgrade to websocket did not succeed.
Expand Down
Expand Up @@ -12,14 +12,17 @@
//
//===----------------------------------------------------------------------===//

import HTTPTypes
import NIOConcurrencyHelpers
import NIOCore
import NIOHTTP1
import NIOHTTPTypesHTTP1
import NIOWebSocket

/// Should HTTP channel upgrade to WebSocket
public enum ShouldUpgradeResult<Value: Sendable>: Sendable {
case dontUpgrade
case upgrade(HTTPHeaders, Value)
case upgrade(HTTPFields, Value)
}

extension NIOTypedWebSocketServerUpgrader {
Expand Down Expand Up @@ -47,21 +50,27 @@ extension NIOTypedWebSocketServerUpgrader {
public convenience init<Value>(
maxFrameSize: Int = 1 << 14,
enableAutomaticErrorHandling: Bool = true,
shouldUpgrade: @escaping @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture<ShouldUpgradeResult<Value>>,
shouldUpgrade: @escaping @Sendable (Channel, HTTPRequest) -> EventLoopFuture<ShouldUpgradeResult<Value>>,
upgradePipelineHandler: @escaping @Sendable (Channel, Value) -> EventLoopFuture<UpgradeResult>
) {
let shouldUpgradeResult = NIOLockedValueBox<Value?>(nil)
self.init(
maxFrameSize: maxFrameSize,
enableAutomaticErrorHandling: enableAutomaticErrorHandling,
shouldUpgrade: { channel, head in
shouldUpgrade(channel, head).map { result in
shouldUpgrade: { (channel, head: HTTPRequestHead) in
let request: HTTPRequest
do {
request = try HTTPRequest(head, secure: false, splitCookie: false)
} catch {
return channel.eventLoop.makeFailedFuture(error)
}
return shouldUpgrade(channel, request).map { result in
switch result {
case .dontUpgrade:
return nil
case .upgrade(let headers, let value):
shouldUpgradeResult.withLockedValue { $0 = value }
return headers
return .init(headers)
}
}
},
Expand Down
22 changes: 11 additions & 11 deletions Sources/HummingbirdWebSocket/Server/WebSocketChannelHandler.swift
Expand Up @@ -32,7 +32,7 @@ public struct HTTP1AndWebSocketChannel<Handler: WebSocketDataHandler>: ServerChi

public typealias Value = EventLoopFuture<UpgradeResult>

/// Initialize HTTP1AndWebSocketChannel with async `shouldUpgrade` function
/// Initialize HTTP1AndWebSocketChannel with synchronous `shouldUpgrade` function
/// - Parameters:
/// - additionalChannelHandlers: Additional channel handlers to add
/// - responder: HTTP responder
Expand All @@ -43,19 +43,19 @@ public struct HTTP1AndWebSocketChannel<Handler: WebSocketDataHandler>: ServerChi
additionalChannelHandlers: @escaping @Sendable () -> [any RemovableChannelHandler] = { [] },
responder: @escaping @Sendable (Request, Channel) async throws -> Response = { _, _ in throw HTTPError(.notImplemented) },
maxFrameSize: Int = (1 << 14),
shouldUpgrade: @escaping @Sendable (Channel, HTTPRequestHead) throws -> ShouldUpgradeResult<Handler>
shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) throws -> ShouldUpgradeResult<Handler>
) {
self.additionalChannelHandlers = additionalChannelHandlers
self.maxFrameSize = maxFrameSize
self.shouldUpgrade = { channel, head in
self.shouldUpgrade = { head, channel, logger in
channel.eventLoop.makeCompletedFuture {
try shouldUpgrade(channel, head)
try shouldUpgrade(head, channel, logger)
}
}
self.responder = responder
}

/// Initialize HTTP1AndWebSocketChannel with synchronous `shouldUpgrade` function
/// Initialize HTTP1AndWebSocketChannel with async `shouldUpgrade` function
/// - Parameters:
/// - additionalChannelHandlers: Additional channel handlers to add
/// - responder: HTTP responder
Expand All @@ -66,14 +66,14 @@ public struct HTTP1AndWebSocketChannel<Handler: WebSocketDataHandler>: ServerChi
additionalChannelHandlers: @escaping @Sendable () -> [any RemovableChannelHandler] = { [] },
responder: @escaping @Sendable (Request, Channel) async throws -> Response = { _, _ in throw HTTPError(.notImplemented) },
maxFrameSize: Int = (1 << 14),
shouldUpgrade: @escaping @Sendable (Channel, HTTPRequestHead) async throws -> ShouldUpgradeResult<Handler>
shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) async throws -> ShouldUpgradeResult<Handler>
) {
self.additionalChannelHandlers = additionalChannelHandlers
self.maxFrameSize = maxFrameSize
self.shouldUpgrade = { channel, head in
self.shouldUpgrade = { head, channel, logger in
let promise = channel.eventLoop.makePromise(of: ShouldUpgradeResult<Handler>.self)
promise.completeWithTask {
try await shouldUpgrade(channel, head)
try await shouldUpgrade(head, channel, logger)
}
return promise.futureResult
}
Expand All @@ -91,7 +91,7 @@ public struct HTTP1AndWebSocketChannel<Handler: WebSocketDataHandler>: ServerChi
let upgrader = NIOTypedWebSocketServerUpgrader<UpgradeResult>(
maxFrameSize: self.maxFrameSize,
shouldUpgrade: { channel, head in
self.shouldUpgrade(channel, head)
self.shouldUpgrade(head, channel, logger)
},
upgradePipelineHandler: { channel, handler in
channel.eventLoop.makeCompletedFuture {
Expand Down Expand Up @@ -136,7 +136,7 @@ public struct HTTP1AndWebSocketChannel<Handler: WebSocketDataHandler>: ServerChi
await handleHTTP(asyncChannel: http1, logger: logger)
case .websocket(let asyncChannel, let handler):
let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server)
let context = handler.alreadySetupContext ?? .init(logger: logger, allocator: asyncChannel.channel.allocator)
let context = handler.alreadySetupContext ?? .init(channel: asyncChannel.channel, logger: logger)
await webSocket.handle(handler: handler, context: context)
}
} catch {
Expand All @@ -145,7 +145,7 @@ public struct HTTP1AndWebSocketChannel<Handler: WebSocketDataHandler>: ServerChi
}

public var responder: @Sendable (Request, Channel) async throws -> Response
let shouldUpgrade: @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture<ShouldUpgradeResult<Handler>>
let shouldUpgrade: @Sendable (HTTPRequest, Channel, Logger) -> EventLoopFuture<ShouldUpgradeResult<Handler>>
let maxFrameSize: Int
let additionalChannelHandlers: @Sendable () -> [any RemovableChannelHandler]
}
Expand Up @@ -12,17 +12,18 @@
//
//===----------------------------------------------------------------------===//

import HTTPTypes
import HummingbirdCore
import Logging
import NIOCore
import NIOHTTP1

extension HTTPChannelBuilder {
/// HTTP1 channel builder supporting a websocket upgrade
/// - parameters
public static func webSocketUpgrade<Handler: WebSocketDataHandler>(
additionalChannelHandlers: @autoclosure @escaping @Sendable () -> [any RemovableChannelHandler] = [],
maxFrameSize: Int = 1 << 14,
shouldUpgrade: @escaping @Sendable (Channel, HTTPRequestHead) async throws -> ShouldUpgradeResult<Handler>
shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) async throws -> ShouldUpgradeResult<Handler>
) -> HTTPChannelBuilder<HTTP1AndWebSocketChannel<Handler>> {
return .init { responder in
return HTTP1AndWebSocketChannel(
Expand All @@ -38,7 +39,7 @@ extension HTTPChannelBuilder {
public static func webSocketUpgrade<Handler: WebSocketDataHandler>(
additionalChannelHandlers: @autoclosure @escaping @Sendable () -> [any RemovableChannelHandler] = [],
maxFrameSize: Int = 1 << 14,
shouldUpgrade: @escaping @Sendable (Channel, HTTPRequestHead) throws -> ShouldUpgradeResult<Handler>
shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) throws -> ShouldUpgradeResult<Handler>
) -> HTTPChannelBuilder<HTTP1AndWebSocketChannel<Handler>> {
return .init { responder in
return HTTP1AndWebSocketChannel<Handler>(
Expand Down
157 changes: 157 additions & 0 deletions Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift
@@ -0,0 +1,157 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2023-2024 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import Atomics
import HTTPTypes
import Hummingbird
import HummingbirdCore
import Logging
import NIOConcurrencyHelpers
import NIOCore

public struct WebSocketRouterContext: Sendable {
public init() {
self.handler = .init(nil)
}

let handler: NIOLockedValueBox<WebSocketDataCallbackHandler?>
}

public protocol WebSocketRequestContext: RequestContext, WebSocketContextProtocol {
var webSocket: WebSocketRouterContext { get }
}

public struct BasicWebSocketRequestContext: WebSocketRequestContext {
public var coreContext: CoreRequestContext
public let webSocket: WebSocketRouterContext

public init(channel: Channel, logger: Logger) {
self.coreContext = .init(allocator: channel.allocator, logger: logger)
self.webSocket = .init()
}
}

public enum RouterShouldUpgrade: Sendable {
case dontUpgrade
case upgrade(HTTPFields)
}

extension RouterMethods {
/// GET path for async closure returning type conforming to ResponseGenerator
@discardableResult public func ws(
_ path: String = "",
shouldUpgrade: @Sendable @escaping (Request, Context) async throws -> RouterShouldUpgrade = { _, _ in .upgrade([:]) },
handle: @escaping WebSocketDataCallbackHandler.Callback
) -> Self where Context: WebSocketRequestContext {
return on(path, method: .get) { request, context -> Response in
let result = try await shouldUpgrade(request, context)
switch result {
case .dontUpgrade:
return .init(status: .notAcceptable)
case .upgrade(let headers):
context.webSocket.handler.withLockedValue { $0 = WebSocketDataCallbackHandler(handle) }
return .init(status: .ok, headers: headers)
}
}
}
}

extension HTTP1AndWebSocketChannel {
/// Initialize HTTP1AndWebSocketChannel with async `shouldUpgrade` function
/// - Parameters:
/// - additionalChannelHandlers: Additional channel handlers to add
/// - responder: HTTP responder
/// - maxFrameSize: Max frame size WebSocket will allow
/// - webSocketRouter: WebSocket router
/// - Returns: Upgrade result future
public init<Context: WebSocketRequestContext, ResponderBuilder: HTTPResponderBuilder>(
additionalChannelHandlers: @escaping @Sendable () -> [any RemovableChannelHandler] = { [] },
responder: @escaping @Sendable (Request, Channel) async throws -> Response = { _, _ in throw HTTPError(.notImplemented) },
maxFrameSize: Int = (1 << 14),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make this a parameter configuration: WebSocketServerConfiguration?

webSocketRouter: ResponderBuilder
) where Handler == WebSocketDataCallbackHandler, ResponderBuilder.Responder.Context == Context {
let webSocketRouterResponder = webSocketRouter.buildResponder()
self.init(additionalChannelHandlers: additionalChannelHandlers, responder: responder, maxFrameSize: maxFrameSize) { head, channel, logger in
let request = Request(head: head, body: .init(buffer: .init()))
let context = Context(channel: channel, logger: logger.with(metadataKey: "hb_id", value: .stringConvertible(RequestID())))
do {
let response = try await webSocketRouterResponder.respond(to: request, context: context)
if response.status == .ok, let webSocketHandler = context.webSocket.handler.withLockedValue({ $0 }) {
return .upgrade(response.headers, webSocketHandler)
} else {
return .dontUpgrade
}
} catch {
return .dontUpgrade
}
}
}
}

extension HTTPChannelBuilder {
/// HTTP1 channel builder supporting a websocket upgrade
/// - parameters
public static func webSocketUpgrade<ResponderBuilder: HTTPResponderBuilder>(
additionalChannelHandlers: @autoclosure @escaping @Sendable () -> [any RemovableChannelHandler] = [],
maxFrameSize: Int = 1 << 14,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make this a parameter configuration: WebSocketServerConfiguration?
maxFrameSize is definitely the most important configurable property, but I can imagine us adding a few more in the future.

webSocketRouter: ResponderBuilder
) -> HTTPChannelBuilder<HTTP1AndWebSocketChannel<WebSocketDataCallbackHandler>> where ResponderBuilder.Responder.Context: WebSocketRequestContext {
return .init { responder in
return HTTP1AndWebSocketChannel(
additionalChannelHandlers: additionalChannelHandlers,
responder: responder,
maxFrameSize: maxFrameSize,
webSocketRouter: webSocketRouter
)
}
}
}

extension Logger {
/// Create new Logger with additional metadata value
/// - Parameters:
/// - metadataKey: Metadata key
/// - value: Metadata value
/// - Returns: Logger
func with(metadataKey: String, value: MetadataValue) -> Logger {
var logger = self
logger[metadataKey: metadataKey] = value
return logger
}
}

/// Generate Unique ID for each request
package struct RequestID: CustomStringConvertible {
let low: UInt64

package init() {
self.low = Self.globalRequestID.loadThenWrappingIncrement(by: 1, ordering: .relaxed)
}

package var description: String {
Self.high + self.formatAsHexWithLeadingZeros(self.low)
}

func formatAsHexWithLeadingZeros(_ value: UInt64) -> String {
let string = String(value, radix: 16)
if string.count < 16 {
return String(repeating: "0", count: 16 - string.count) + string
} else {
return string
}
}

private static let high = String(UInt64.random(in: .min ... .max), radix: 16)
private static let globalRequestID = ManagedAtomic<UInt64>(UInt64.random(in: .min ... .max))
}
Comment on lines +188 to +210
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be removed now

6 changes: 3 additions & 3 deletions Sources/HummingbirdWebSocket/WebSocketContext.swift
Expand Up @@ -19,16 +19,16 @@ import NIOCore
public protocol WebSocketContextProtocol: Sendable {
var logger: Logger { get }
var allocator: ByteBufferAllocator { get }
init(logger: Logger, allocator: ByteBufferAllocator)
init(channel: Channel, logger: Logger)
}

/// Default implementation of ``WebSocketContextProtocol``
public struct WebSocketContext: WebSocketContextProtocol {
public let logger: Logger
public let allocator: ByteBufferAllocator
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can imagine that the remote address is useful

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Likewise the original HTTPRequest could be useful in a WebSocketContextProtocol?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is basically a duplicate of the requirements from the RequestContext so including the request might be hard. You get the request in the shouldUpgrade call so things like permessage-deflate can be negotiated there


public init(logger: Logger, allocator: ByteBufferAllocator) {
public init(channel: Channel, logger: Logger) {
self.logger = logger
self.allocator = allocator
self.allocator = channel.allocator
}
}