Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WebSocket extensions and permessage-deflate extension #45

Merged
merged 9 commits into from Mar 29, 2024
13 changes: 6 additions & 7 deletions Package.swift
Expand Up @@ -8,7 +8,7 @@ let package = Package(
platforms: [.macOS(.v14), .iOS(.v17), .tvOS(.v17)],
products: [
.library(name: "HummingbirdWebSocket", targets: ["HummingbirdWebSocket"]),
// .library(name: "HummingbirdWSCompression", targets: ["HummingbirdWSCompression"]),
.library(name: "HummingbirdWSCompression", targets: ["HummingbirdWSCompression"]),
],
dependencies: [
.package(url: "https://github.com/hummingbird-project/hummingbird.git", branch: "main"),
Expand All @@ -18,7 +18,6 @@ let package = Package(
.package(url: "https://github.com/apple/swift-nio.git", from: "2.62.0"),
.package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.22.0"),
.package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.5.0"),
.package(url: "https://github.com/swift-extras/swift-extras-base64.git", from: "0.5.0"),
.package(url: "https://github.com/adam-fowler/compress-nio.git", from: "1.0.0"),
],
targets: [
Expand All @@ -32,13 +31,13 @@ let package = Package(
.product(name: "NIOHTTPTypesHTTP1", package: "swift-nio-extras"),
.product(name: "NIOWebSocket", package: "swift-nio"),
]),
/* .target(name: "HummingbirdWSCompression", dependencies: [
.byName(name: "HummingbirdWSCore"),
.product(name: "CompressNIO", package: "compress-nio"),
]),*/
.target(name: "HummingbirdWSCompression", dependencies: [
.byName(name: "HummingbirdWebSocket"),
.product(name: "CompressNIO", package: "compress-nio"),
]),
.testTarget(name: "HummingbirdWebSocketTests", dependencies: [
.byName(name: "HummingbirdWebSocket"),
// .byName(name: "HummingbirdWSCompression"),
.byName(name: "HummingbirdWSCompression"),
.product(name: "Atomics", package: "swift-atomics"),
.product(name: "Hummingbird", package: "hummingbird"),
.product(name: "HummingbirdTesting", package: "hummingbird"),
Expand Down
3 changes: 2 additions & 1 deletion Snippets/WebsocketTest.swift
@@ -1,6 +1,7 @@
import HTTPTypes
import Hummingbird
import HummingbirdWebSocket
import HummingbirdWSCompression
import Logging

var logger = Logger(label: "Echo")
Expand All @@ -22,7 +23,7 @@ router.ws("/ws") { inbound, outbound, _ in

let app = Application(
router: router,
server: .webSocketUpgrade(webSocketRouter: router),
server: .webSocketUpgrade(webSocketRouter: router, configuration: .init(extensions: [.perMessageDeflate()])),
logger: logger
)
try await app.runService()
197 changes: 119 additions & 78 deletions Sources/HummingbirdWSCompression/PerMessageDeflateExtension.swift
Expand Up @@ -13,7 +13,7 @@
//===----------------------------------------------------------------------===//

import CompressNIO
import HummingbirdWSCore
import HummingbirdWebSocket
import NIOCore
import NIOWebSocket

Expand All @@ -27,21 +27,24 @@ struct PerMessageDeflateExtensionBuilder: WebSocketExtensionBuilder {
let serverNoContextTakeover: Bool
let compressionLevel: Int?
let memoryLevel: Int?
let maxDecompressedFrameSize: Int

init(
clientMaxWindow: Int? = nil,
clientNoContextTakeover: Bool = false,
serverMaxWindow: Int? = nil,
serverNoContextTakeover: Bool = false,
compressionLevel: Int? = nil,
memoryLevel: Int? = nil
memoryLevel: Int? = nil,
maxDecompressedFrameSize: Int = (1 << 14)
) {
self.clientMaxWindow = clientMaxWindow
self.clientNoContextTakeover = clientNoContextTakeover
self.serverMaxWindow = serverMaxWindow
self.serverNoContextTakeover = serverNoContextTakeover
self.compressionLevel = compressionLevel
self.memoryLevel = memoryLevel
self.maxDecompressedFrameSize = maxDecompressedFrameSize
}

/// Return client request header
Expand Down Expand Up @@ -86,17 +89,15 @@ struct PerMessageDeflateExtensionBuilder: WebSocketExtensionBuilder {
/// Create server PerMessageDeflateExtension based off request headers
/// - Parameters:
/// - request: Client request
/// - eventLoop: EventLoop it is bound to
func serverExtension(from request: WebSocketExtensionHTTPParameters, eventLoop: EventLoop) throws -> (WebSocketExtension)? {
func serverExtension(from request: WebSocketExtensionHTTPParameters) throws -> (WebSocketExtension)? {
let configuration = self.responseConfiguration(to: request)
return try PerMessageDeflateExtension(configuration: configuration, eventLoop: eventLoop)
return try PerMessageDeflateExtension(configuration: configuration)
}

/// Create client PerMessageDeflateExtension based off response headers
/// - Parameters:
/// - response: Server response
/// - eventLoop: EventLoop it is bound to
func clientExtension(from response: WebSocketExtensionHTTPParameters, eventLoop: EventLoop) throws -> WebSocketExtension? {
func clientExtension(from response: WebSocketExtensionHTTPParameters) throws -> WebSocketExtension? {
let clientMaxWindowParam = response.parameters["client_max_window_bits"]?.integer
let clientNoContextTakeoverParam = response.parameters["client_no_context_takeover"] != nil
let serverMaxWindowParam = response.parameters["server_max_window_bits"]?.integer
Expand All @@ -107,8 +108,9 @@ struct PerMessageDeflateExtensionBuilder: WebSocketExtensionBuilder {
sendMaxWindow: clientMaxWindowParam,
sendNoContextTakeover: clientNoContextTakeoverParam,
compressionLevel: self.compressionLevel,
memoryLevel: self.memoryLevel
), eventLoop: eventLoop)
memoryLevel: self.memoryLevel,
maxDecompressedFrameSize: self.maxDecompressedFrameSize
))
}

private func responseConfiguration(to request: WebSocketExtensionHTTPParameters) -> PerMessageDeflateExtension.Configuration {
Expand All @@ -134,7 +136,8 @@ struct PerMessageDeflateExtensionBuilder: WebSocketExtensionBuilder {
sendMaxWindow: optionalMin(requestServerMaxWindow?.integer, self.serverMaxWindow),
sendNoContextTakeover: requestServerNoContextTakeover || self.serverNoContextTakeover,
compressionLevel: self.compressionLevel,
memoryLevel: self.memoryLevel
memoryLevel: self.memoryLevel,
maxDecompressedFrameSize: self.maxDecompressedFrameSize
)
}
}
Expand All @@ -144,102 +147,133 @@ struct PerMessageDeflateExtensionBuilder: WebSocketExtensionBuilder {
/// Uses deflate to compress messages sent across a WebSocket
/// See RFC 7692 for more details https://www.rfc-editor.org/rfc/rfc7692
struct PerMessageDeflateExtension: WebSocketExtension {
enum SendState: Sendable {
case idle
case sendingMessage
}

struct Configuration: Sendable {
let receiveMaxWindow: Int?
let receiveNoContextTakeover: Bool
let sendMaxWindow: Int?
let sendNoContextTakeover: Bool
let compressionLevel: Int?
let memoryLevel: Int?
let maxDecompressedFrameSize: Int
}

/// Internal mutable state and referenced types, that cannot be set to Sendable
class InternalState {
actor Decompressor {
fileprivate let decompressor: any NIODecompressor

init(_ decompressor: any NIODecompressor) throws {
self.decompressor = decompressor
try self.decompressor.startStream()
}

func decompress(_ frame: WebSocketFrame, maxSize: Int, resetStream: Bool, context: some WebSocketContextProtocol) throws -> WebSocketFrame {
var frame = frame
precondition(frame.fin, "Only concatenated frames with fin set can be processed by the permessage-deflate extension")
// Reinstate last four bytes 0x00 0x00 0xff 0xff that were removed in the frame
// send (see https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.2).
frame.data.writeBytes([0, 0, 255, 255])
frame.data = try frame.data.decompressStream(with: self.decompressor, maxSize: maxSize, allocator: context.allocator)
if resetStream {
try self.decompressor.resetStream()
}
return frame
}

func shutdown() throws {
try self.decompressor.finishStream()
}
}

actor Compressor {
enum SendState: Sendable {
case idle
case sendingMessage
}

fileprivate let compressor: any NIOCompressor
fileprivate var sendState: SendState
var sendState: SendState

init(configuration: Configuration) throws {
self.decompressor = CompressionAlgorithm.deflate(
configuration: .init(
windowSize: numericCast(configuration.receiveMaxWindow ?? 15)
)
).decompressor
// compression level -1 will setup the default compression level, 8 is the default memory level
self.compressor = CompressionAlgorithm.deflate(
configuration: .init(
windowSize: numericCast(configuration.sendMaxWindow ?? 15),
compressionLevel: configuration.compressionLevel.map { numericCast($0) } ?? -1,
memoryLevel: configuration.memoryLevel.map { numericCast($0) } ?? 8
)
).compressor
init(_ compressor: any NIOCompressor) throws {
self.compressor = compressor
self.sendState = .idle
try self.decompressor.startStream()
try self.compressor.startStream()
}

func shutdown() {
try? self.compressor.finishStream()
try? self.decompressor.finishStream()
func compress(_ frame: WebSocketFrame, resetStream: Bool, context: some WebSocketContextProtocol) throws -> WebSocketFrame {
// if the frame is larger than 16 bytes, we haven't received a final frame or we are in the process of sending a message
// compress the data
let shouldWeCompress = frame.data.readableBytes > 16 || !frame.fin || self.sendState != .idle
if shouldWeCompress {
var newFrame = frame
if self.sendState == .idle {
newFrame.rsv1 = true
self.sendState = .sendingMessage
}
newFrame.data = try newFrame.data.compressStream(with: self.compressor, flush: .sync, allocator: context.allocator)
// if final frame then remove last four bytes 0x00 0x00 0xff 0xff
// (see https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.1)
if newFrame.fin {
newFrame.data = newFrame.data.getSlice(at: newFrame.data.readerIndex, length: newFrame.data.readableBytes - 4) ?? newFrame.data
self.sendState = .idle
if resetStream {
try self.compressor.resetStream()
}
}
return newFrame
}
return frame
}

func shutdown() throws {
try self.compressor.finishStream()
}
}

let name = "permessage-deflate"
let configuration: Configuration
let internalState: NIOLoopBound<InternalState>
let decompressor: Decompressor
let compressor: Compressor

init(configuration: Configuration, eventLoop: EventLoop) throws {
init(configuration: Configuration) throws {
self.configuration = configuration
self.internalState = try .init(.init(configuration: configuration), eventLoop: eventLoop)
self.decompressor = try .init(
CompressionAlgorithm.deflate(
configuration: .init(
windowSize: numericCast(configuration.receiveMaxWindow ?? 15)
)
).decompressor
)
self.compressor = try .init(
CompressionAlgorithm.deflate(
configuration: .init(
windowSize: numericCast(configuration.sendMaxWindow ?? 15),
compressionLevel: configuration.compressionLevel.map { numericCast($0) } ?? -1,
memoryLevel: configuration.memoryLevel.map { numericCast($0) } ?? 8
)
).compressor
)
}

func shutdown() {
self.internalState.value.shutdown()
func shutdown() async {
try? await self.decompressor.shutdown()
try? await self.compressor.shutdown()
}

func processReceivedFrame(_ frame: WebSocketFrame, ws: WebSocket) throws -> WebSocketFrame {
var frame = frame
func processReceivedFrame(_ frame: WebSocketFrame, context: some WebSocketContextProtocol) async throws -> WebSocketFrame {
if frame.rsv1 {
let state = self.internalState.value
precondition(frame.fin, "Only concatenated frames with fin set can be processed by the permessage-deflate extension")
// Reinstate last four bytes 0x00 0x00 0xff 0xff that were removed in the frame
// send (see https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.2).
frame.data.writeBytes([0, 0, 255, 255])
frame.data = try frame.data.decompressStream(with: state.decompressor, maxSize: ws.maxFrameSize, allocator: ws.channel.allocator)
if self.configuration.receiveNoContextTakeover {
try state.decompressor.resetStream()
}
return try await self.decompressor.decompress(
frame,
maxSize: self.configuration.maxDecompressedFrameSize,
resetStream: self.configuration.receiveNoContextTakeover,
context: context
)
}
return frame
}

func processFrameToSend(_ frame: WebSocketFrame, ws: WebSocket) throws -> WebSocketFrame {
let state = self.internalState.value
// if the frame is larger than 16 bytes, we haven't received a final frame or we are in the process of sending a message
// compress the data
let shouldWeCompress = frame.data.readableBytes > 16 || !frame.fin || state.sendState != .idle
func processFrameToSend(_ frame: WebSocketFrame, context: some WebSocketContextProtocol) async throws -> WebSocketFrame {
let isCorrectType = frame.opcode == .text || frame.opcode == .binary
if shouldWeCompress, isCorrectType {
var newFrame = frame
if state.sendState == .idle {
newFrame.rsv1 = true
state.sendState = .sendingMessage
}
newFrame.data = try newFrame.data.compressStream(with: state.compressor, flush: .sync, allocator: ws.channel.allocator)
// if final frame then remove last four bytes 0x00 0x00 0xff 0xff
// (see https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.1)
if newFrame.fin {
newFrame.data = newFrame.data.getSlice(at: newFrame.data.readerIndex, length: newFrame.data.readableBytes - 4) ?? newFrame.data
state.sendState = .idle
if self.configuration.sendNoContextTakeover {
try state.compressor.resetStream()
}
}
return newFrame
if isCorrectType {
return try await self.compressor.compress(frame, resetStream: self.configuration.sendNoContextTakeover, context: context)
}
return frame
}
Expand All @@ -250,15 +284,20 @@ extension WebSocketExtensionFactory {
/// - Parameters:
/// - maxWindow: Max window to be used for decompression and compression
/// - noContextTakeover: Should we reset window on every message
public static func perMessageDeflate(maxWindow: Int? = nil, noContextTakeover: Bool = false) -> WebSocketExtensionFactory {
public static func perMessageDeflate(
maxWindow: Int? = nil,
noContextTakeover: Bool = false,
maxDecompressedFrameSize: Int = 1 << 14
) -> WebSocketExtensionFactory {
return .init {
PerMessageDeflateExtensionBuilder(
clientMaxWindow: maxWindow,
clientNoContextTakeover: noContextTakeover,
serverMaxWindow: maxWindow,
serverNoContextTakeover: noContextTakeover,
compressionLevel: nil,
memoryLevel: nil
memoryLevel: nil,
maxDecompressedFrameSize: maxDecompressedFrameSize
)
}
}
Expand All @@ -279,7 +318,8 @@ extension WebSocketExtensionFactory {
serverMaxWindow: Int? = nil,
serverNoContextTakeover: Bool = false,
compressionLevel: Int? = nil,
memoryLevel: Int? = nil
memoryLevel: Int? = nil,
maxDecompressedFrameSize: Int = 1 << 14
) -> WebSocketExtensionFactory {
return .init {
PerMessageDeflateExtensionBuilder(
Expand All @@ -288,7 +328,8 @@ extension WebSocketExtensionFactory {
serverMaxWindow: serverMaxWindow,
serverNoContextTakeover: serverNoContextTakeover,
compressionLevel: compressionLevel,
memoryLevel: memoryLevel
memoryLevel: memoryLevel,
maxDecompressedFrameSize: maxDecompressedFrameSize
)
}
}
Expand Down