Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WebSocket Router #41

Merged
merged 7 commits into from Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion Package.swift
Expand Up @@ -16,7 +16,7 @@ let package = Package(
.package(url: "https://github.com/apple/swift-atomics.git", from: "1.0.0"),
.package(url: "https://github.com/apple/swift-http-types.git", from: "1.0.0"),
.package(url: "https://github.com/apple/swift-nio.git", from: "2.62.0"),
.package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.21.0"),
.package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.22.0"),
.package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.5.0"),
.package(url: "https://github.com/swift-extras/swift-extras-base64.git", from: "0.5.0"),
.package(url: "https://github.com/adam-fowler/compress-nio.git", from: "1.0.0"),
Expand All @@ -41,6 +41,7 @@ let package = Package(
// .byName(name: "HummingbirdWSCompression"),
.product(name: "Atomics", package: "swift-atomics"),
.product(name: "Hummingbird", package: "hummingbird"),
.product(name: "HummingbirdTesting", package: "hummingbird"),
.product(name: "HummingbirdTLS", package: "hummingbird"),
]),
]
Expand Down
30 changes: 13 additions & 17 deletions Snippets/WebsocketTest.swift
@@ -1,28 +1,24 @@
import HTTPTypes
import Hummingbird
import HummingbirdWebSocket
import NIOHTTP1

let router = Router()
let router = Router(context: BasicWebSocketRequestContext.self)
router.middlewares.add(FileMiddleware("Snippets/public"))
router.get { _, _ in
"Hello"
}

router.middlewares.add(FileMiddleware("Snippets/public"))
let app = Application(
router: router,
server: .webSocketUpgrade { _, head in
if head.uri == "/ws" {
return .upgrade(HTTPHeaders()) { inbound, outbound, _ in
for try await packet in inbound {
if case .text("disconnect") = packet {
break
}
try await outbound.write(.custom(packet.webSocketFrame))
}
}
} else {
return .dontUpgrade
router.ws("/ws") { inbound, outbound, _ in
for try await packet in inbound {
if case .text("disconnect") = packet {
break
}
try await outbound.write(.custom(packet.webSocketFrame))
}
}

let app = Application(
router: router,
server: .webSocketUpgrade(webSocketRouter: router)
)
try await app.runService()
39 changes: 9 additions & 30 deletions Sources/HummingbirdWebSocket/Client/WebSocketClient.swift
Expand Up @@ -40,25 +40,6 @@ import ServiceLifecycle
/// }
/// ```
public struct WebSocketClient {
public struct Configuration: Sendable {
/// Max websocket frame size that can be sent/received
public var maxFrameSize: Int
/// Additional headers to be sent with the initial HTTP request
public var additionalHeaders: HTTPFields

/// Initialize WebSocketClient configuration
/// - Paramters
/// - maxFrameSize: Max websocket frame size that can be sent/received
/// - additionalHeaders: Additional headers to be sent with the initial HTTP request
public init(
maxFrameSize: Int = (1 << 14),
additionalHeaders: HTTPFields = .init()
) {
self.maxFrameSize = maxFrameSize
self.additionalHeaders = additionalHeaders
}
}

enum MultiPlatformTLSConfiguration: Sendable {
case niossl(TLSConfiguration)
#if canImport(Network)
Expand All @@ -71,7 +52,7 @@ public struct WebSocketClient {
/// WebSocket data handler
let handler: WebSocketDataCallbackHandler
/// configuration
let configuration: Configuration
let configuration: WebSocketClientConfiguration
/// EventLoopGroup to use
let eventLoopGroup: EventLoopGroup
/// Logger
Expand All @@ -90,7 +71,7 @@ public struct WebSocketClient {
/// - logger: Logger
public init(
url: URI,
configuration: Configuration = .init(),
configuration: WebSocketClientConfiguration = .init(),
tlsConfiguration: TLSConfiguration? = nil,
eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton,
logger: Logger,
Expand All @@ -116,7 +97,7 @@ public struct WebSocketClient {
/// - logger: Logger
public init(
url: URI,
configuration: Configuration = .init(),
configuration: WebSocketClientConfiguration = .init(),
transportServicesTLSOptions: TSTLSOptions,
eventLoopGroup: NIOTSEventLoopGroup = NIOTSEventLoopGroup.singleton,
logger: Logger,
Expand All @@ -143,7 +124,7 @@ public struct WebSocketClient {
case .niossl(let tlsConfiguration):
let client = try ClientConnection(
TLSClientChannel(
WebSocketClientChannel(handler: handler, url: urlPath, maxFrameSize: self.configuration.maxFrameSize),
WebSocketClientChannel(handler: handler, url: urlPath, configuration: self.configuration),
tlsConfiguration: tlsConfiguration
),
address: .hostname(host, port: port),
Expand All @@ -155,7 +136,7 @@ public struct WebSocketClient {
#if canImport(Network)
case .ts(let tlsOptions):
let client = try ClientConnection(
WebSocketClientChannel(handler: handler, url: urlPath, maxFrameSize: self.configuration.maxFrameSize),
WebSocketClientChannel(handler: handler, url: urlPath, configuration: self.configuration),
address: .hostname(host, port: port),
transportServicesTLSOptions: tlsOptions,
eventLoopGroup: self.eventLoopGroup,
Expand All @@ -170,8 +151,7 @@ public struct WebSocketClient {
WebSocketClientChannel(
handler: handler,
url: urlPath,
maxFrameSize: self.configuration.maxFrameSize,
additionalHeaders: self.configuration.additionalHeaders
configuration: self.configuration
),
tlsConfiguration: TLSConfiguration.makeClientConfiguration()
),
Expand All @@ -186,8 +166,7 @@ public struct WebSocketClient {
WebSocketClientChannel(
handler: handler,
url: urlPath,
maxFrameSize: self.configuration.maxFrameSize,
additionalHeaders: self.configuration.additionalHeaders
configuration: self.configuration
),
address: .hostname(host, port: port),
eventLoopGroup: self.eventLoopGroup,
Expand All @@ -210,7 +189,7 @@ extension WebSocketClient {
/// - process: Closure handling webSocket
public static func connect(
url: URI,
configuration: Configuration = .init(),
configuration: WebSocketClientConfiguration = .init(),
tlsConfiguration: TLSConfiguration? = nil,
eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton,
logger: Logger,
Expand Down Expand Up @@ -239,7 +218,7 @@ extension WebSocketClient {
/// - process: WebSocket data handler
public static func connect(
url: URI,
configuration: Configuration = .init(),
configuration: WebSocketClientConfiguration = .init(),
transportServicesTLSOptions: TSTLSOptions,
eventLoopGroup: NIOTSEventLoopGroup = NIOTSEventLoopGroup.singleton,
logger: Logger,
Expand Down
18 changes: 8 additions & 10 deletions Sources/HummingbirdWebSocket/Client/WebSocketClientChannel.swift
Expand Up @@ -30,20 +30,18 @@ public struct WebSocketClientChannel<Handler: WebSocketDataHandler>: ClientConne

let url: String
let handler: Handler
let maxFrameSize: Int
let additionalHeaders: HTTPFields
let configuration: WebSocketClientConfiguration

init(handler: Handler, url: String, maxFrameSize: Int = 1 << 14, additionalHeaders: HTTPFields = .init()) {
init(handler: Handler, url: String, configuration: WebSocketClientConfiguration) {
self.url = url
self.handler = handler
self.maxFrameSize = maxFrameSize
self.additionalHeaders = additionalHeaders
self.configuration = configuration
}

public func setup(channel: any Channel, logger: Logger) -> NIOCore.EventLoopFuture<Value> {
channel.eventLoop.makeCompletedFuture {
let upgrader = NIOTypedWebSocketClientUpgrader<UpgradeResult>(
maxFrameSize: maxFrameSize,
maxFrameSize: self.configuration.maxFrameSize,
upgradePipelineHandler: { channel, _ in
channel.eventLoop.makeCompletedFuture {
let asyncChannel = try NIOAsyncChannel<WebSocketFrame, WebSocketFrame>(wrappingChannelSynchronously: channel)
Expand All @@ -55,7 +53,7 @@ public struct WebSocketClientChannel<Handler: WebSocketDataHandler>: ClientConne
var headers = HTTPHeaders()
headers.add(name: "Content-Type", value: "text/plain; charset=utf-8")
headers.add(name: "Content-Length", value: "0")
let additionalHeaders = HTTPHeaders(self.additionalHeaders)
let additionalHeaders = HTTPHeaders(self.configuration.additionalHeaders)
headers.add(contentsOf: additionalHeaders)

let requestHead = HTTPRequestHead(
Expand Down Expand Up @@ -85,9 +83,9 @@ public struct WebSocketClientChannel<Handler: WebSocketDataHandler>: ClientConne

public func handle(value: Value, logger: Logger) async throws {
switch try await value.get() {
case .websocket(let websocketChannel):
let webSocket = WebSocketHandler(asyncChannel: websocketChannel, type: .client)
let context = self.handler.alreadySetupContext ?? .init(logger: logger, allocator: websocketChannel.channel.allocator)
case .websocket(let webSocketChannel):
let webSocket = WebSocketHandler(asyncChannel: webSocketChannel, type: .client)
let context = self.handler.alreadySetupContext ?? .init(channel: webSocketChannel.channel, logger: logger)
await webSocket.handle(handler: self.handler, context: context)
case .notUpgraded:
// The upgrade to websocket did not succeed.
Expand Down
@@ -0,0 +1,34 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2023-2024 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import HTTPTypes

public struct WebSocketClientConfiguration: Sendable {
/// Max websocket frame size that can be sent/received
public var maxFrameSize: Int
/// Additional headers to be sent with the initial HTTP request
public var additionalHeaders: HTTPFields

/// Initialize WebSocketClient configuration
/// - Paramters
/// - maxFrameSize: Max websocket frame size that can be sent/received
/// - additionalHeaders: Additional headers to be sent with the initial HTTP request
public init(
maxFrameSize: Int = (1 << 14),
additionalHeaders: HTTPFields = .init()
) {
self.maxFrameSize = maxFrameSize
self.additionalHeaders = additionalHeaders
}
}
Expand Up @@ -12,14 +12,17 @@
//
//===----------------------------------------------------------------------===//

import HTTPTypes
import NIOConcurrencyHelpers
import NIOCore
import NIOHTTP1
import NIOHTTPTypesHTTP1
import NIOWebSocket

/// Should HTTP channel upgrade to WebSocket
public enum ShouldUpgradeResult<Value: Sendable>: Sendable {
case dontUpgrade
case upgrade(HTTPHeaders, Value)
case upgrade(HTTPFields, Value)
}

extension NIOTypedWebSocketServerUpgrader {
Expand Down Expand Up @@ -47,21 +50,27 @@ extension NIOTypedWebSocketServerUpgrader {
public convenience init<Value>(
maxFrameSize: Int = 1 << 14,
enableAutomaticErrorHandling: Bool = true,
shouldUpgrade: @escaping @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture<ShouldUpgradeResult<Value>>,
shouldUpgrade: @escaping @Sendable (Channel, HTTPRequest) -> EventLoopFuture<ShouldUpgradeResult<Value>>,
upgradePipelineHandler: @escaping @Sendable (Channel, Value) -> EventLoopFuture<UpgradeResult>
) {
let shouldUpgradeResult = NIOLockedValueBox<Value?>(nil)
self.init(
maxFrameSize: maxFrameSize,
enableAutomaticErrorHandling: enableAutomaticErrorHandling,
shouldUpgrade: { channel, head in
shouldUpgrade(channel, head).map { result in
shouldUpgrade: { (channel, head: HTTPRequestHead) in
let request: HTTPRequest
do {
request = try HTTPRequest(head, secure: false, splitCookie: false)
} catch {
return channel.eventLoop.makeFailedFuture(error)
}
return shouldUpgrade(channel, request).map { result in
switch result {
case .dontUpgrade:
return nil
case .upgrade(let headers, let value):
shouldUpgradeResult.withLockedValue { $0 = value }
return headers
return .init(headers)
}
}
},
Expand Down