diff --git a/Snippets/WebsocketTest.swift b/Snippets/WebsocketTest.swift index fad3fce..1c2ac3a 100644 --- a/Snippets/WebsocketTest.swift +++ b/Snippets/WebsocketTest.swift @@ -1,7 +1,10 @@ import HTTPTypes import Hummingbird import HummingbirdWebSocket +import Logging +var logger = Logger(label: "Echo") +logger.logLevel = .trace let router = Router(context: BasicWebSocketRequestContext.self) router.middlewares.add(FileMiddleware("Snippets/public")) router.get { _, _ in @@ -19,6 +22,7 @@ router.ws("/ws") { inbound, outbound, _ in let app = Application( router: router, - server: .webSocketUpgrade(webSocketRouter: router) + server: .webSocketUpgrade(webSocketRouter: router), + logger: logger ) try await app.runService() diff --git a/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift b/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift index 74b4c2a..6bc9a7a 100644 --- a/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift +++ b/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift @@ -19,4 +19,4 @@ import NIOCore import NIOWebSocket /// Function that handles websocket data and text blocks -public typealias WebSocketDataHandler = @Sendable (WebSocketHandlerInbound, WebSocketHandlerOutboundWriter, Context) async throws -> Void +public typealias WebSocketDataHandler = @Sendable (WebSocketInboundStream, WebSocketOutboundWriter, Context) async throws -> Void diff --git a/Sources/HummingbirdWebSocket/WebSocketHandler.swift b/Sources/HummingbirdWebSocket/WebSocketHandler.swift index 6bcac5e..ed1bc61 100644 --- a/Sources/HummingbirdWebSocket/WebSocketHandler.swift +++ b/Sources/HummingbirdWebSocket/WebSocketHandler.swift @@ -16,96 +16,131 @@ import Hummingbird import Logging import NIOCore import NIOWebSocket +import ServiceLifecycle + +/// WebSocket type +public enum WebSocketType: Sendable { + case client + case server +} /// Handler processing raw WebSocket packets. /// /// Manages ping, pong and close messages. Collates data and text messages into final frame /// and passes them onto the ``WebSocketDataHandler`` data handler setup by the user. actor WebSocketHandler: Sendable { - enum SocketType: Sendable { - case client - case server - } - static let pingDataSize = 16 let asyncChannel: NIOAsyncChannel - let type: SocketType - var closed = false + let type: WebSocketType + var closed: Bool var pingData: ByteBuffer - init(asyncChannel: NIOAsyncChannel, type: WebSocketHandler.SocketType) { + init(asyncChannel: NIOAsyncChannel, type: WebSocketType) { self.asyncChannel = asyncChannel self.type = type self.pingData = ByteBufferAllocator().buffer(capacity: Self.pingDataSize) + self.closed = false } /// Handle WebSocket AsynChannel func handle(handler: @escaping WebSocketDataHandler, context: Context) async { - try? await self.asyncChannel.executeThenClose { inbound, outbound in - do { - try await withThrowingTaskGroup(of: Void.self) { group in - let webSocketHandlerInbound = WebSocketHandlerInbound() - defer { - asyncChannel.channel.close(promise: nil) - webSocketHandlerInbound.finish() - } - let webSocketHandlerOutbound = WebSocketHandlerOutboundWriter(webSocket: self, outbound: outbound) - group.addTask { - // parse messages coming from inbound - var frameSequence: WebSocketFrameSequence? - for try await frame in inbound { - switch frame.opcode { - case .connectionClose: - return - case .ping: - try await self.onPing(frame, outbound: outbound, context: context) - case .pong: - try await self.onPong(frame, outbound: outbound, context: context) - case .text, .binary: - if var frameSeq = frameSequence { - frameSeq.append(frame) - frameSequence = frameSeq - } else { - frameSequence = WebSocketFrameSequence(frame: frame) + let asyncChannel = self.asyncChannel + try? await asyncChannel.executeThenClose { inbound, outbound in + let webSocketInbound = WebSocketInboundStream() + let webSocketOutbound = WebSocketOutboundWriter( + type: self.type, + allocator: asyncChannel.channel.allocator, + outbound: outbound + ) + try await withTaskCancellationHandler { + try await withGracefulShutdownHandler { + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + defer { + webSocketInbound.finish() + } + // parse messages coming from inbound + var frameSequence: WebSocketFrameSequence? + for try await frame in inbound { + do { + context.logger.trace("Received \(frame.opcode)") + switch frame.opcode { + case .connectionClose: + // we received a connection close. Finish the inbound data stream, + // send a close back if it hasn't already been send and exit + webSocketInbound.finish() + _ = try await self.close(code: .normalClosure, outbound: webSocketOutbound, context: context) + return + case .ping: + try await self.onPing(frame, outbound: webSocketOutbound, context: context) + case .pong: + try await self.onPong(frame, outbound: webSocketOutbound, context: context) + case .text, .binary: + if var frameSeq = frameSequence { + frameSeq.append(frame) + frameSequence = frameSeq + } else { + frameSequence = WebSocketFrameSequence(frame: frame) + } + case .continuation: + if var frameSeq = frameSequence { + frameSeq.append(frame) + frameSequence = frameSeq + } else { + try await self.close(code: .protocolError, outbound: webSocketOutbound, context: context) + } + default: + break + } + if let frameSeq = frameSequence, frame.fin { + await webSocketInbound.send(frameSeq.data) + frameSequence = nil + } + } catch { + // catch errors while processing websocket frames so responding close message + // can be dealt with + let errorCode = WebSocketErrorCode(error) + try await self.close(code: errorCode, outbound: webSocketOutbound, context: context) } - case .continuation: - if var frameSeq = frameSequence { - frameSeq.append(frame) - frameSequence = frameSeq + } + } + group.addTask { + do { + // handle websocket data and text + try await handler(webSocketInbound, webSocketOutbound, context) + try await self.close(code: .normalClosure, outbound: webSocketOutbound, context: context) + } catch { + if self.type == .server { + let errorCode = WebSocketErrorCode.unexpectedServerError + try await self.close(code: errorCode, outbound: webSocketOutbound, context: context) } else { - try await self.close(code: .protocolError, outbound: outbound, context: context) + try await asyncChannel.channel.close(mode: .input) } - default: - break - } - if let frameSeq = frameSequence, frame.fin { - await webSocketHandlerInbound.send(frameSeq.data) - frameSequence = nil } } + try await group.next() + webSocketInbound.finish() } - group.addTask { - // handle websocket data and text - try await handler(webSocketHandlerInbound, webSocketHandlerOutbound, context) - try await self.close(code: .normalClosure, outbound: outbound, context: context) + } onGracefulShutdown: { + Task { + try? await self.close(code: .normalClosure, outbound: webSocketOutbound, context: context) } - try await group.next() } - } catch let error as NIOWebSocketError { - let errorCode = WebSocketErrorCode(error) - try await self.close(code: errorCode, outbound: outbound, context: context) - } catch { - let errorCode = WebSocketErrorCode.unexpectedServerError - try await self.close(code: errorCode, outbound: outbound, context: context) + } onCancel: { + Task { + webSocketInbound.finish() + try await asyncChannel.channel.close(mode: .input) + } } } + context.logger.debug("Closed WebSocket") } /// Respond to ping func onPing( _ frame: WebSocketFrame, - outbound: NIOAsyncChannelOutboundWriter, + outbound: WebSocketOutboundWriter, context: some WebSocketContextProtocol ) async throws { if frame.fin { @@ -118,9 +153,10 @@ actor WebSocketHandler: Sendable { /// Respond to pong func onPong( _ frame: WebSocketFrame, - outbound: NIOAsyncChannelOutboundWriter, + outbound: WebSocketOutboundWriter, context: some WebSocketContextProtocol ) async throws { + guard !self.closed else { return } let frameData = frame.unmaskedData guard self.pingData.readableBytes == 0 || frameData == self.pingData else { try await self.close(code: .goingAway, outbound: outbound, context: context) @@ -130,26 +166,26 @@ actor WebSocketHandler: Sendable { } /// Send ping - func ping(outbound: NIOAsyncChannelOutboundWriter) async throws { + func ping(outbound: WebSocketOutboundWriter) async throws { guard !self.closed else { return } if self.pingData.readableBytes == 0 { // creating random payload let random = (0..) async throws { + func pong(data: ByteBuffer?, outbound: WebSocketOutboundWriter) async throws { guard !self.closed else { return } - try await self.send(frame: .init(fin: true, opcode: .pong, data: data ?? .init()), outbound: outbound) + try await outbound.write(frame: .init(fin: true, opcode: .pong, data: data ?? .init())) } /// Send close func close( code: WebSocketErrorCode = .normalClosure, - outbound: NIOAsyncChannelOutboundWriter, + outbound: WebSocketOutboundWriter, context: some WebSocketContextProtocol ) async throws { guard !self.closed else { return } @@ -157,35 +193,21 @@ actor WebSocketHandler: Sendable { var buffer = context.allocator.buffer(capacity: 2) buffer.write(webSocketErrorCode: code) - try await self.send(frame: .init(fin: true, opcode: .connectionClose, data: buffer), outbound: outbound) - } - - /// Send WebSocket frame - func send( - frame: WebSocketFrame, - outbound: NIOAsyncChannelOutboundWriter - ) async throws { - var frame = frame - frame.maskKey = self.makeMaskKey() - try await outbound.write(frame) - } - - /// Make mask key to be used in WebSocket frame - private func makeMaskKey() -> WebSocketMaskingKey? { - guard self.type == .client else { return nil } - let bytes: [UInt8] = (0...3).map { _ in UInt8.random(in: .min ... .max) } - return WebSocketMaskingKey(bytes) + try await outbound.write(frame: .init(fin: true, opcode: .connectionClose, data: buffer)) + outbound.finish() } } extension WebSocketErrorCode { - init(_ error: NIOWebSocketError) { + init(_ error: any Error) { switch error { - case .invalidFrameLength: + case NIOWebSocketError.invalidFrameLength: self = .messageTooLarge - case .fragmentedControlFrame, - .multiByteControlFrameLength: + case NIOWebSocketError.fragmentedControlFrame, + NIOWebSocketError.multiByteControlFrameLength: self = .protocolError + default: + self = .unexpectedServerError } } } diff --git a/Sources/HummingbirdWebSocket/WebSocketInboundStream.swift b/Sources/HummingbirdWebSocket/WebSocketInboundStream.swift index 487afa5..d6067ae 100644 --- a/Sources/HummingbirdWebSocket/WebSocketInboundStream.swift +++ b/Sources/HummingbirdWebSocket/WebSocketInboundStream.swift @@ -17,7 +17,8 @@ import NIOCore import NIOWebSocket /// Inbound websocket data AsyncSequence -public typealias WebSocketHandlerInbound = AsyncChannel +public typealias WebSocketInboundStream = AsyncChannel + /// Enumeration holding WebSocket data public enum WebSocketDataFrame: Equatable, Sendable, CustomStringConvertible, CustomDebugStringConvertible { case text(String) diff --git a/Sources/HummingbirdWebSocket/WebSocketOutboundWriter.swift b/Sources/HummingbirdWebSocket/WebSocketOutboundWriter.swift index dccc044..4d334ee 100644 --- a/Sources/HummingbirdWebSocket/WebSocketOutboundWriter.swift +++ b/Sources/HummingbirdWebSocket/WebSocketOutboundWriter.swift @@ -16,7 +16,7 @@ import NIOCore import NIOWebSocket /// Outbound websocket writer -public struct WebSocketHandlerOutboundWriter { +public struct WebSocketOutboundWriter: Sendable { /// WebSocket frame that can be written public enum OutboundFrame: Sendable { /// Text frame @@ -25,13 +25,12 @@ public struct WebSocketHandlerOutboundWriter { case binary(ByteBuffer) /// Unsolicited pong frame case pong - /// A ping frame. The returning pong will be dealt with by the underlying code - case ping /// A custom frame not supported by the above case custom(WebSocketFrame) } - let webSocket: WebSocketHandler + let type: WebSocketType + let allocator: ByteBufferAllocator let outbound: NIOAsyncChannelOutboundWriter /// Write WebSocket frame @@ -40,20 +39,37 @@ public struct WebSocketHandlerOutboundWriter { switch frame { case .binary(let buffer): // send binary data - try await self.webSocket.send(frame: .init(fin: true, opcode: .binary, data: buffer), outbound: self.outbound) + try await self.write(frame: .init(fin: true, opcode: .binary, data: buffer)) case .text(let string): // send text based data - let buffer = self.webSocket.asyncChannel.channel.allocator.buffer(string: string) - try await self.webSocket.send(frame: .init(fin: true, opcode: .text, data: buffer), outbound: self.outbound) - case .ping: - // send ping - try await self.webSocket.ping(outbound: self.outbound) + let buffer = self.allocator.buffer(string: string) + try await self.write(frame: .init(fin: true, opcode: .text, data: buffer)) case .pong: // send unexplained pong as a heartbeat - try await self.webSocket.pong(data: nil, outbound: self.outbound) + try await self.write(frame: .init(fin: true, opcode: .pong, data: .init())) case .custom(let frame): // send custom WebSocketFrame - try await self.webSocket.send(frame: frame, outbound: self.outbound) + try await self.write(frame: frame) } } + + /// Send WebSocket frame + func write( + frame: WebSocketFrame + ) async throws { + var frame = frame + frame.maskKey = self.makeMaskKey() + try await self.outbound.write(frame) + } + + func finish() { + self.outbound.finish() + } + + /// Make mask key to be used in WebSocket frame + private func makeMaskKey() -> WebSocketMaskingKey? { + guard self.type == .client else { return nil } + let bytes: [UInt8] = (0...3).map { _ in UInt8.random(in: .min ... .max) } + return WebSocketMaskingKey(bytes) + } } diff --git a/Tests/HummingbirdWebSocketTests/WebSocketTests.swift b/Tests/HummingbirdWebSocketTests/WebSocketTests.swift index b7666bf..780613c 100644 --- a/Tests/HummingbirdWebSocketTests/WebSocketTests.swift +++ b/Tests/HummingbirdWebSocketTests/WebSocketTests.swift @@ -82,8 +82,13 @@ final class HummingbirdWebSocketTests: XCTestCase { ) async throws { try await withThrowingTaskGroup(of: Void.self) { group in let promise = Promise() - let logger = { - var logger = Logger(label: "WebSocketTest") + let serverLogger = { + var logger = Logger(label: "WebSocketServer") + logger.logLevel = .debug + return logger + }() + let clientLogger = { + var logger = Logger(label: "WebSocketClient") logger.logLevel = .debug return logger }() @@ -101,7 +106,7 @@ final class HummingbirdWebSocketTests: XCTestCase { router: router, server: .tls(webSocketUpgrade, tlsConfiguration: serverTLSConfiguration), onServerRunning: { channel in await promise.complete(channel.localAddress!.port!) }, - logger: logger + logger: serverLogger ) serviceGroup = ServiceGroup( configuration: .init( @@ -115,7 +120,7 @@ final class HummingbirdWebSocketTests: XCTestCase { router: router, server: webSocketUpgrade, onServerRunning: { channel in await promise.complete(channel.localAddress!.port!) }, - logger: logger + logger: serverLogger ) serviceGroup = ServiceGroup( configuration: .init( @@ -129,7 +134,7 @@ final class HummingbirdWebSocketTests: XCTestCase { try await serviceGroup.run() } group.addTask { - let client = try await getClient(promise.wait(), logger) + let client = try await getClient(promise.wait(), clientLogger) try await client.run() } do { @@ -169,8 +174,13 @@ final class HummingbirdWebSocketTests: XCTestCase { ) async throws { try await withThrowingTaskGroup(of: Void.self) { group in let promise = Promise() - let logger = { - var logger = Logger(label: "WebSocketTest") + let serverLogger = { + var logger = Logger(label: "WebSocketServer") + logger.logLevel = .debug + return logger + }() + let clientLogger = { + var logger = Logger(label: "WebSocketClient") logger.logLevel = .debug return logger }() @@ -180,7 +190,7 @@ final class HummingbirdWebSocketTests: XCTestCase { router: router, server: .webSocketUpgrade(webSocketRouter: webSocketRouter), onServerRunning: { channel in await promise.complete(channel.localAddress!.port!) }, - logger: logger + logger: serverLogger ) serviceGroup = ServiceGroup( configuration: .init( @@ -193,7 +203,7 @@ final class HummingbirdWebSocketTests: XCTestCase { try await serviceGroup.run() } group.addTask { - let client = try await getClient(promise.wait(), logger) + let client = try await getClient(promise.wait(), clientLogger) try await client.run() } do { @@ -567,3 +577,16 @@ final class HummingbirdWebSocketTests: XCTestCase { */ } + +extension Logger { + /// Create new Logger with additional metadata value + /// - Parameters: + /// - metadataKey: Metadata key + /// - value: Metadata value + /// - Returns: Logger + func with(metadataKey: String, value: MetadataValue) -> Logger { + var logger = self + logger[metadataKey: metadataKey] = value + return logger + } +}