Skip to content

Commit

Permalink
Separate client library (#54)
Browse files Browse the repository at this point in the history
* Separate client into separate library

* Fix host header, close

* Don't drop left over bytes

* Copy Client from HummingbirdCore into WSClient module

Remove dependencies on Hummingbird

* Remove import HummingbirdCore I missed

* Fixes from tests with Autobahn (#55)

* Verify close code is correct

* Fail on receiving reserved opcode

* Add trace logging for errors

* Fixed connecting to echo.websocket.org. Needed SNI hostname
  • Loading branch information
adam-fowler committed Apr 12, 2024
1 parent 0016b8f commit 25963c8
Show file tree
Hide file tree
Showing 37 changed files with 1,456 additions and 58 deletions.
27 changes: 20 additions & 7 deletions Package.swift
Expand Up @@ -8,37 +8,50 @@ let package = Package(
platforms: [.macOS(.v14), .iOS(.v17), .tvOS(.v17)],
products: [
.library(name: "HummingbirdWebSocket", targets: ["HummingbirdWebSocket"]),
.library(name: "HummingbirdWSClient", targets: ["HummingbirdWSClient"]),
.library(name: "HummingbirdWSCompression", targets: ["HummingbirdWSCompression"]),
],
dependencies: [
.package(url: "https://github.com/hummingbird-project/hummingbird.git", from: "2.0.0-beta.2"),
.package(url: "https://github.com/apple/swift-async-algorithms.git", from: "1.0.0"),
.package(url: "https://github.com/apple/swift-atomics.git", from: "1.0.0"),
.package(url: "https://github.com/apple/swift-http-types.git", from: "1.0.0"),
.package(url: "https://github.com/apple/swift-log.git", from: "1.4.0"),
.package(url: "https://github.com/apple/swift-nio.git", from: "2.62.0"),
.package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.22.0"),
.package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.5.0"),
.package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.20.0"),
.package(url: "https://github.com/adam-fowler/compress-nio.git", from: "1.0.0"),
],
targets: [
.target(name: "HummingbirdWebSocket", dependencies: [
.byName(name: "HummingbirdWSCore"),
.product(name: "Hummingbird", package: "hummingbird"),
.product(name: "HummingbirdTLS", package: "hummingbird"),
.product(name: "AsyncAlgorithms", package: "swift-async-algorithms"),
.product(name: "NIOHTTPTypes", package: "swift-nio-extras"),
.product(name: "NIOHTTPTypesHTTP1", package: "swift-nio-extras"),
]),
.target(name: "HummingbirdWSClient", dependencies: [
.byName(name: "HummingbirdWSCore"),
.product(name: "HTTPTypes", package: "swift-http-types"),
.product(name: "Logging", package: "swift-log"),
.product(name: "NIOCore", package: "swift-nio"),
.product(name: "NIOHTTPTypes", package: "swift-nio-extras"),
.product(name: "NIOHTTPTypesHTTP1", package: "swift-nio-extras"),
.product(name: "NIOPosix", package: "swift-nio"),
.product(name: "NIOSSL", package: "swift-nio-ssl"),
.product(name: "NIOTransportServices", package: "swift-nio-transport-services"),
.product(name: "NIOWebSocket", package: "swift-nio"),
]),
.target(name: "HummingbirdWSCore", dependencies: [
.product(name: "HTTPTypes", package: "swift-http-types"),
.product(name: "NIOCore", package: "swift-nio"),
.product(name: "NIOWebSocket", package: "swift-nio"),
]),
.target(name: "HummingbirdWSCompression", dependencies: [
.byName(name: "HummingbirdWebSocket"),
.byName(name: "HummingbirdWSCore"),
.product(name: "CompressNIO", package: "compress-nio"),
]),
.testTarget(name: "HummingbirdWebSocketTests", dependencies: [
.byName(name: "HummingbirdWebSocket"),
.byName(name: "HummingbirdWSClient"),
.byName(name: "HummingbirdWSCompression"),
.product(name: "Atomics", package: "swift-atomics"),
.product(name: "Hummingbird", package: "hummingbird"),
.product(name: "HummingbirdTesting", package: "hummingbird"),
.product(name: "HummingbirdTLS", package: "hummingbird"),
Expand Down
32 changes: 32 additions & 0 deletions Snippets/AutobahnClientTest.swift
@@ -0,0 +1,32 @@
import HummingbirdWSClient
import HummingbirdWSCompression
import Logging

let cases = 1...1

var logger = Logger(label: "TestClient")
logger.logLevel = .trace
do {
for c in cases {
logger.info("Case \(c)")
try await WebSocketClient.connect(
url: .init("ws://127.0.0.1:9001/runCase?case=\(c)&agent=HB"),
configuration: .init(maxFrameSize: 1 << 16, extensions: [.perMessageDeflate(maxDecompressedFrameSize: 65536)]),
logger: logger
) { inbound, outbound, _ in
for try await msg in inbound.messages(maxSize: .max) {
switch msg {
case .binary(let buffer):
try await outbound.write(.binary(buffer))
case .text(let string):
try await outbound.write(.text(string))
}
}
}
}
try await WebSocketClient.connect(url: .init("ws://127.0.0.1:9001/updateReports?agent=HB"), logger: logger) { inbound, _, _ in
for try await _ in inbound {}
}
} catch {
logger.error("Error: \(error)")
}
19 changes: 19 additions & 0 deletions Snippets/WebSocketClientTest.swift
@@ -0,0 +1,19 @@
import HummingbirdWSClient
import Logging

var logger = Logger(label: "TestClient")
logger.logLevel = .trace
do {
try await WebSocketClient.connect(
url: .init("https://echo.websocket.org"),
configuration: .init(maxFrameSize: 1 << 16),
logger: logger
) { inbound, outbound, _ in
try await outbound.write(.text("Hello"))
for try await msg in inbound.messages(maxSize: .max) {
print(msg)
}
}
} catch {
logger.error("Error: \(error)")
}
1 change: 1 addition & 0 deletions Snippets/WebsocketTest.swift
Expand Up @@ -14,6 +14,7 @@ router.get { _, _ in
}

router.ws("/ws") { inbound, outbound, _ in
try await outbound.write(.text("Hello"))
for try await frame in inbound {
if frame.opcode == .text, String(buffer: frame.data) == "disconnect", frame.fin == true {
break
Expand Down
34 changes: 34 additions & 0 deletions Sources/HummingbirdWSClient/Client/ClientChannel.swift
@@ -0,0 +1,34 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2024 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import Logging
import NIOCore

/// ClientConnection child channel setup protocol
public protocol ClientConnectionChannel: Sendable {
associatedtype Value: Sendable

/// Setup child channel
/// - Parameters:
/// - channel: Child channel
/// - logger: Logger used during setup
/// - Returns: Object to process input/output on child channel
func setup(channel: Channel, logger: Logger) -> EventLoopFuture<Value>

/// handle messages being passed down the channel pipeline
/// - Parameters:
/// - value: Object to process input/output on child channel
/// - logger: Logger to use while processing messages
func handle(value: Value, logger: Logger) async throws
}
171 changes: 171 additions & 0 deletions Sources/HummingbirdWSClient/Client/ClientConnection.swift
@@ -0,0 +1,171 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2024 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import Logging
import NIOCore
import NIOPosix
#if canImport(Network)
import Network
import NIOTransportServices
#endif

/// A generic client connection to a server.
///
/// Actual client protocol is implemented in `ClientChannel` generic parameter
public struct ClientConnection<ClientChannel: ClientConnectionChannel>: Sendable {
/// Address to connect to
public struct Address: Sendable, Equatable {
enum _Internal: Equatable {
case hostname(_ host: String, port: Int)
case unixDomainSocket(path: String)
}

let value: _Internal
init(_ value: _Internal) {
self.value = value
}

// Address define by host and port
public static func hostname(_ host: String, port: Int) -> Self { .init(.hostname(host, port: port)) }
// Address defined by unxi domain socket
public static func unixDomainSocket(path: String) -> Self { .init(.unixDomainSocket(path: path)) }
}

typealias ChannelResult = ClientChannel.Value
/// Logger used by Server
let logger: Logger
let eventLoopGroup: EventLoopGroup
let clientChannel: ClientChannel
let address: Address
#if canImport(Network)
let tlsOptions: NWProtocolTLS.Options?
#endif

/// Initialize Client
public init(
_ clientChannel: ClientChannel,
address: Address,
eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton,
logger: Logger
) {
self.clientChannel = clientChannel
self.address = address
self.eventLoopGroup = eventLoopGroup
self.logger = logger
#if canImport(Network)
self.tlsOptions = nil
#endif
}

#if canImport(Network)
/// Initialize Client with TLS options
public init(
_ clientChannel: ClientChannel,
address: Address,
transportServicesTLSOptions: TSTLSOptions,
eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton,
logger: Logger
) throws {
self.clientChannel = clientChannel
self.address = address
self.eventLoopGroup = eventLoopGroup
self.logger = logger
self.tlsOptions = transportServicesTLSOptions.options
}
#endif

public func run() async throws {
let channelResult = try await self.makeClient(
clientChannel: self.clientChannel,
address: self.address
)
try await self.clientChannel.handle(value: channelResult, logger: self.logger)
}

/// Connect to server
func makeClient(clientChannel: ClientChannel, address: Address) async throws -> ChannelResult {
// get bootstrap
let bootstrap: ClientBootstrapProtocol
#if canImport(Network)
if let tsBootstrap = self.createTSBootstrap() {
bootstrap = tsBootstrap
} else {
#if os(iOS) || os(tvOS)
self.logger.warning("Running BSD sockets on iOS or tvOS is not recommended. Please use NIOTSEventLoopGroup, to run with the Network framework")
#endif
bootstrap = self.createSocketsBootstrap()
}
#else
bootstrap = self.createSocketsBootstrap()
#endif

// connect
let result: ChannelResult
do {
switch address.value {
case .hostname(let host, let port):
result = try await bootstrap
.connect(host: host, port: port) { channel in
clientChannel.setup(channel: channel, logger: self.logger)
}
self.logger.debug("Client connnected to \(host):\(port)")
case .unixDomainSocket(let path):
result = try await bootstrap
.connect(unixDomainSocketPath: path) { channel in
clientChannel.setup(channel: channel, logger: self.logger)
}
self.logger.debug("Client connnected to socket path \(path)")
}
return result
} catch {
throw error
}
}

/// create a BSD sockets based bootstrap
private func createSocketsBootstrap() -> ClientBootstrap {
return ClientBootstrap(group: self.eventLoopGroup)
}

#if canImport(Network)
/// create a NIOTransportServices bootstrap using Network.framework
private func createTSBootstrap() -> NIOTSConnectionBootstrap? {
guard let bootstrap = NIOTSConnectionBootstrap(validatingGroup: self.eventLoopGroup) else {
return nil
}
if let tlsOptions {
return bootstrap.tlsOptions(tlsOptions)
}
return bootstrap
}
#endif
}

protocol ClientBootstrapProtocol {
func connect<Output: Sendable>(
host: String,
port: Int,
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Output>
) async throws -> Output

func connect<Output: Sendable>(
unixDomainSocketPath: String,
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Output>
) async throws -> Output
}

extension ClientBootstrap: ClientBootstrapProtocol {}
#if canImport(Network)
extension NIOTSConnectionBootstrap: ClientBootstrapProtocol {}
#endif

0 comments on commit 25963c8

Please sign in to comment.