Skip to content

Commit

Permalink
Add async/await support to websockets (#9)
Browse files Browse the repository at this point in the history
* Add async/await support to websockets

* swift 5.6 fix

* More 5.6 fixes
  • Loading branch information
adam-fowler committed Jan 30, 2023
1 parent 157c9cd commit cbe5fec
Show file tree
Hide file tree
Showing 5 changed files with 224 additions and 9 deletions.
16 changes: 16 additions & 0 deletions Sources/HummingbirdWSClient/WebSocketClient.swift
Expand Up @@ -160,3 +160,19 @@ public enum HBWebSocketClient {
}
}
}

#if compiler(>=5.5.2) && canImport(_Concurrency)

@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *)
extension HBWebSocketClient {
/// Connect to WebSocket
/// - Parameters:
/// - url: URL of websocket
/// - configuration: Configuration of connection
/// - eventLoop: eventLoop to run connection on
public static func connect(url: HBURL, configuration: Configuration, on eventLoop: EventLoop) async throws -> HBWebSocket {
return try await self.connect(url: url, configuration: configuration, on: eventLoop).get()
}
}

#endif // compiler(>=5.5.2) && canImport(_Concurrency)
77 changes: 73 additions & 4 deletions Sources/HummingbirdWSCore/WebSocket.swift
Expand Up @@ -42,7 +42,7 @@ public final class HBWebSocket {
self.readCallback = cb
}

/// Set callback to be called whenever WebSocket receives data
/// Set callback to be called whenever WebSocket receives a pong
public func onPong(_ cb: @escaping PongCallback) {
self.pongCallback = cb
}
Expand All @@ -54,11 +54,20 @@ public final class HBWebSocket {
}
}

/// Write data to WebSocket
/// - Parameter data: data to be written
/// - Returns: future that is completed when data is written
public func write(_ data: WebSocketData) -> EventLoopFuture<Void> {
let promise = self.channel.eventLoop.makePromise(of: Void.self)
self.write(data, promise: promise)
return promise.futureResult
}

/// Write data to WebSocket
/// - Parameters:
/// - data: Data to be written
/// - promise:promise that is completed when data has been sent
public func write(_ data: WebSocketData, promise: EventLoopPromise<Void>? = nil) {
/// - promise: promise that is completed when data has been sent
public func write(_ data: WebSocketData, promise: EventLoopPromise<Void>?) {
switch data {
case .text(let string):
let buffer = self.channel.allocator.buffer(string: string)
Expand All @@ -68,11 +77,20 @@ public final class HBWebSocket {
}
}

/// Close websocket connection
/// - Parameter code:
/// - Returns: future that is complete once close message is sent
public func close(code: WebSocketErrorCode = .normalClosure) -> EventLoopFuture<Void> {
let promise = self.channel.eventLoop.makePromise(of: Void.self)
self.close(code: code, promise: promise)
return promise.futureResult
}

/// Close websocket connection
/// - Parameters:
/// - code: Close reason
/// - promise: promise that is completed when close has been sent
public func close(code: WebSocketErrorCode = .goingAway, promise: EventLoopPromise<Void>?) {
public func close(code: WebSocketErrorCode = .normalClosure, promise: EventLoopPromise<Void>?) {
guard self.isClosed == false else {
promise?.succeed(())
return
Expand All @@ -84,6 +102,16 @@ public final class HBWebSocket {
self.send(buffer: buffer, opcode: .connectionClose, fin: true, promise: promise)
}

/// Send ping message
/// - Returns: future that is complete when ping message is sent
public func sendPing() -> EventLoopFuture<Void> {
let promise = self.channel.eventLoop.makePromise(of: Void.self)
self.sendPing(promise: promise)
return promise.futureResult
}

/// Send ping message
/// - Parameter promise: promise that is completed when ping message has been sent
public func sendPing(promise: EventLoopPromise<Void>?) {
_ = self.channel.eventLoop.submit {
if self.waitingOnPong {
Expand Down Expand Up @@ -197,3 +225,44 @@ public final class HBWebSocket {
private var readCallback: ReadCallback?
private var isClosed: Bool = false
}

#if compiler(>=5.5.2) && canImport(_Concurrency)

@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *)
extension HBWebSocket {
/// Write data to WebSocket
/// - Parameters:
/// - data: Data to be written
public func write(_ data: WebSocketData) async throws {
return try await self.write(data).get()
}

/// Close websocket connection
/// - Parameter code: reason for closing socket
public func close(code: WebSocketErrorCode = .normalClosure) async throws {
return try await self.close(code: code).get()
}

/// Send ping message
public func sendPing() async throws {
return try await self.sendPing().get()
}

/// Return stream of web socket data
///
/// This uses the `onRead`` and `onClose` functions so should not be used
/// at the same time as these functions.
/// - Returns: Web socket data stream
public func readStream() -> AsyncStream<WebSocketData> {
return AsyncStream { cont in
self.onRead { data, _ in
cont.yield(data)
}
self.onClose { _ in
cont.finish()
}
}
}
}

#endif
20 changes: 20 additions & 0 deletions Sources/HummingbirdWebSocket/Application+WebSocket.swift
Expand Up @@ -98,3 +98,23 @@ extension HBApplication {
/// WebSocket interface
public var ws: WebSocket { .init(application: self) }
}

#if compiler(>=5.5.2) && canImport(_Concurrency)

@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *)
extension HBApplication.WebSocket {
/// Add WebSocket connection upgrade at given path
/// - Parameters:
/// - path: URI path connection upgrade is available
/// - shouldUpgrade: Return whether upgrade should be allowed
/// - onUpgrade: Called on upgrade with reference to WebSocket
@discardableResult public func on(
_ path: String = "",
shouldUpgrade: @escaping (HBRequest) async throws -> HTTPHeaders? = { _ in return nil },
onUpgrade: @escaping (HBRequest, HBWebSocket) async throws -> HTTPResponseStatus
) -> HBWebSocketRouterGroup {
self.routerGroup.on(path, shouldUpgrade: shouldUpgrade, onUpgrade: onUpgrade)
}
}

#endif // compiler(>=5.5.2) && canImport(_Concurrency)
67 changes: 63 additions & 4 deletions Sources/HummingbirdWebSocket/WebSocketRouterGroup.swift
Expand Up @@ -38,7 +38,7 @@ public struct HBWebSocketRouterGroup {
@discardableResult public func on(
_ path: String = "",
shouldUpgrade: @escaping (HBRequest) -> EventLoopFuture<HTTPHeaders?>,
onUpgrade: @escaping (HBRequest, HBWebSocket) throws -> Void
onUpgrade: @escaping (HBRequest, HBWebSocket) -> EventLoopFuture<HTTPResponseStatus>
) -> Self {
let responder = HBCallbackResponder { request in
var request = request
Expand All @@ -50,10 +50,9 @@ public struct HBWebSocketRouterGroup {
}
}
} else if let webSocket = request.webSocket {
return request.body.consumeBody(on: request.eventLoop).flatMapThrowing { buffer in
return request.body.consumeBody(on: request.eventLoop).flatMap { buffer in
request.body = .byteBuffer(buffer)
try onUpgrade(request, webSocket)
return HBResponse(status: .ok)
return onUpgrade(request, webSocket).map { HBResponse(status: $0) }
}
} else {
return request.failure(.upgradeRequired)
Expand All @@ -62,4 +61,64 @@ public struct HBWebSocketRouterGroup {
self.router.add(path, method: .GET, responder: self.middlewares.constructResponder(finalResponder: responder))
return self
}

/// Add path for websocket with shouldUpgrade and onUpgrade closures
/// - Parameters:
/// - path: URI path that the websocket upgrade will proceed
/// - shouldUpgrade: Closure indicating whether we should upgrade or not. Return a failed `EventLoopFuture` for no.
/// - onUpgrade: Closure called with web socket when connection has been upgraded
@discardableResult public func on(
_ path: String = "",
shouldUpgrade: @escaping (HBRequest) -> EventLoopFuture<HTTPHeaders?>,
onUpgrade: @escaping (HBRequest, HBWebSocket) throws -> Void
) -> Self {
return self.on(
path,
shouldUpgrade: shouldUpgrade,
onUpgrade: { request, ws -> EventLoopFuture<HTTPResponseStatus> in
do {
try onUpgrade(request, ws)
return request.eventLoop.makeSucceededFuture(.ok)
} catch {
return request.eventLoop.makeFailedFuture(error)
}
}
)
}
}

#if compiler(>=5.5.2) && canImport(_Concurrency)

@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *)
extension HBWebSocketRouterGroup {
/// Add WebSocket connection upgrade at given path
/// - Parameters:
/// - path: URI path connection upgrade is available
/// - shouldUpgrade: Return whether upgrade should be allowed
/// - onUpgrade: Called on upgrade with reference to WebSocket
@discardableResult public func on(
_ path: String = "",
shouldUpgrade: @escaping (HBRequest) async throws -> HTTPHeaders? = { _ in return nil },
onUpgrade: @escaping (HBRequest, HBWebSocket) async throws -> HTTPResponseStatus
) -> HBWebSocketRouterGroup {
self.on(
path,
shouldUpgrade: { request -> EventLoopFuture<HTTPHeaders?> in
let promise = request.eventLoop.makePromise(of: HTTPHeaders?.self)
promise.completeWithTask {
try await shouldUpgrade(request)
}
return promise.futureResult
},
onUpgrade: { request, ws -> EventLoopFuture<HTTPResponseStatus> in
let promise = request.eventLoop.makePromise(of: HTTPResponseStatus.self)
promise.completeWithTask {
try await onUpgrade(request, ws)
}
return promise.futureResult
}
)
}
}

#endif // compiler(>=5.5.2) && canImport(_Concurrency)
53 changes: 52 additions & 1 deletion Tests/HummingbirdWebSocketTests/WebSocketTests.swift
Expand Up @@ -66,6 +66,23 @@ final class HummingbirdWebSocketTests: XCTestCase {
return app
}

func setupClientAndServer(onServer: @escaping (HBWebSocket) async throws -> Void, onClient: @escaping (HBWebSocket) async throws -> Void) async throws -> HBApplication {
let app = HBApplication(configuration: .init(address: .hostname(port: 8080)))
// add HTTP to WebSocket upgrade
app.ws.addUpgrade()
// on websocket connect.
app.ws.on("/test", onUpgrade: { _, ws in
try await onServer(ws)
return .ok
})
try app.start()

let eventLoop = app.eventLoopGroup.next()
let ws = try await HBWebSocketClient.connect(url: "ws://localhost:8080/test", configuration: .init(), on: eventLoop)
try await onClient(ws)
return app
}

func testClientAndServerConnection() throws {
var serverHello: Bool = false
var clientHello: Bool = false
Expand All @@ -77,7 +94,7 @@ final class HummingbirdWebSocketTests: XCTestCase {
ws.onRead { data, ws in
XCTAssertEqual(data, .text("Hello"))
serverHello = true
ws.write(.text("Hello back"))
ws.write(.text("Hello back"), promise: nil)
}
},
onClient: { ws in
Expand Down Expand Up @@ -285,3 +302,37 @@ final class HummingbirdWebSocketTests: XCTestCase {
_ = try wsFuture.wait()
}
}

#if compiler(>=5.5.2) && canImport(_Concurrency)

@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *)
extension HummingbirdWebSocketTests {
func testServerAsyncReadWrite() async throws {
let elg = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer { XCTAssertNoThrow(try elg.syncShutdownGracefully()) }
let promise = TimeoutPromise(eventLoop: elg.next(), timeout: .seconds(10))

let app = try await self.setupClientAndServer(
onServer: { ws in
let stream = ws.readStream()
Task {
for try await data in stream {
XCTAssertEqual(data, .text("Hello"))
}
ws.onClose { _ in
promise.succeed()
}
}
},
onClient: { ws in
try await ws.write(.text("Hello"))
try await ws.close()
}
)
defer { app.stop() }

try promise.wait()
}
}

#endif // compiler(>=5.5.2) && canImport(_Concurrency)

0 comments on commit cbe5fec

Please sign in to comment.