Skip to content

Commit

Permalink
WebSocket extensions and permessage-deflate extension (#45)
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-fowler committed Mar 29, 2024
1 parent a82530e commit 541324e
Show file tree
Hide file tree
Showing 15 changed files with 738 additions and 436 deletions.
13 changes: 6 additions & 7 deletions Package.swift
Expand Up @@ -8,7 +8,7 @@ let package = Package(
platforms: [.macOS(.v14), .iOS(.v17), .tvOS(.v17)],
products: [
.library(name: "HummingbirdWebSocket", targets: ["HummingbirdWebSocket"]),
// .library(name: "HummingbirdWSCompression", targets: ["HummingbirdWSCompression"]),
.library(name: "HummingbirdWSCompression", targets: ["HummingbirdWSCompression"]),
],
dependencies: [
.package(url: "https://github.com/hummingbird-project/hummingbird.git", branch: "main"),
Expand All @@ -18,7 +18,6 @@ let package = Package(
.package(url: "https://github.com/apple/swift-nio.git", from: "2.62.0"),
.package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.22.0"),
.package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.5.0"),
.package(url: "https://github.com/swift-extras/swift-extras-base64.git", from: "0.5.0"),
.package(url: "https://github.com/adam-fowler/compress-nio.git", from: "1.0.0"),
],
targets: [
Expand All @@ -32,13 +31,13 @@ let package = Package(
.product(name: "NIOHTTPTypesHTTP1", package: "swift-nio-extras"),
.product(name: "NIOWebSocket", package: "swift-nio"),
]),
/* .target(name: "HummingbirdWSCompression", dependencies: [
.byName(name: "HummingbirdWSCore"),
.product(name: "CompressNIO", package: "compress-nio"),
]),*/
.target(name: "HummingbirdWSCompression", dependencies: [
.byName(name: "HummingbirdWebSocket"),
.product(name: "CompressNIO", package: "compress-nio"),
]),
.testTarget(name: "HummingbirdWebSocketTests", dependencies: [
.byName(name: "HummingbirdWebSocket"),
// .byName(name: "HummingbirdWSCompression"),
.byName(name: "HummingbirdWSCompression"),
.product(name: "Atomics", package: "swift-atomics"),
.product(name: "Hummingbird", package: "hummingbird"),
.product(name: "HummingbirdTesting", package: "hummingbird"),
Expand Down
3 changes: 2 additions & 1 deletion Snippets/WebsocketTest.swift
@@ -1,6 +1,7 @@
import HTTPTypes
import Hummingbird
import HummingbirdWebSocket
import HummingbirdWSCompression
import Logging

var logger = Logger(label: "Echo")
Expand All @@ -22,7 +23,7 @@ router.ws("/ws") { inbound, outbound, _ in

let app = Application(
router: router,
server: .webSocketUpgrade(webSocketRouter: router),
server: .webSocketUpgrade(webSocketRouter: router, configuration: .init(extensions: [.perMessageDeflate()])),
logger: logger
)
try await app.runService()
197 changes: 119 additions & 78 deletions Sources/HummingbirdWSCompression/PerMessageDeflateExtension.swift
Expand Up @@ -13,7 +13,7 @@
//===----------------------------------------------------------------------===//

import CompressNIO
import HummingbirdWSCore
import HummingbirdWebSocket
import NIOCore
import NIOWebSocket

Expand All @@ -27,21 +27,24 @@ struct PerMessageDeflateExtensionBuilder: WebSocketExtensionBuilder {
let serverNoContextTakeover: Bool
let compressionLevel: Int?
let memoryLevel: Int?
let maxDecompressedFrameSize: Int

init(
clientMaxWindow: Int? = nil,
clientNoContextTakeover: Bool = false,
serverMaxWindow: Int? = nil,
serverNoContextTakeover: Bool = false,
compressionLevel: Int? = nil,
memoryLevel: Int? = nil
memoryLevel: Int? = nil,
maxDecompressedFrameSize: Int = (1 << 14)
) {
self.clientMaxWindow = clientMaxWindow
self.clientNoContextTakeover = clientNoContextTakeover
self.serverMaxWindow = serverMaxWindow
self.serverNoContextTakeover = serverNoContextTakeover
self.compressionLevel = compressionLevel
self.memoryLevel = memoryLevel
self.maxDecompressedFrameSize = maxDecompressedFrameSize
}

/// Return client request header
Expand Down Expand Up @@ -86,17 +89,15 @@ struct PerMessageDeflateExtensionBuilder: WebSocketExtensionBuilder {
/// Create server PerMessageDeflateExtension based off request headers
/// - Parameters:
/// - request: Client request
/// - eventLoop: EventLoop it is bound to
func serverExtension(from request: WebSocketExtensionHTTPParameters, eventLoop: EventLoop) throws -> (WebSocketExtension)? {
func serverExtension(from request: WebSocketExtensionHTTPParameters) throws -> (WebSocketExtension)? {
let configuration = self.responseConfiguration(to: request)
return try PerMessageDeflateExtension(configuration: configuration, eventLoop: eventLoop)
return try PerMessageDeflateExtension(configuration: configuration)
}

/// Create client PerMessageDeflateExtension based off response headers
/// - Parameters:
/// - response: Server response
/// - eventLoop: EventLoop it is bound to
func clientExtension(from response: WebSocketExtensionHTTPParameters, eventLoop: EventLoop) throws -> WebSocketExtension? {
func clientExtension(from response: WebSocketExtensionHTTPParameters) throws -> WebSocketExtension? {
let clientMaxWindowParam = response.parameters["client_max_window_bits"]?.integer
let clientNoContextTakeoverParam = response.parameters["client_no_context_takeover"] != nil
let serverMaxWindowParam = response.parameters["server_max_window_bits"]?.integer
Expand All @@ -107,8 +108,9 @@ struct PerMessageDeflateExtensionBuilder: WebSocketExtensionBuilder {
sendMaxWindow: clientMaxWindowParam,
sendNoContextTakeover: clientNoContextTakeoverParam,
compressionLevel: self.compressionLevel,
memoryLevel: self.memoryLevel
), eventLoop: eventLoop)
memoryLevel: self.memoryLevel,
maxDecompressedFrameSize: self.maxDecompressedFrameSize
))
}

private func responseConfiguration(to request: WebSocketExtensionHTTPParameters) -> PerMessageDeflateExtension.Configuration {
Expand All @@ -134,7 +136,8 @@ struct PerMessageDeflateExtensionBuilder: WebSocketExtensionBuilder {
sendMaxWindow: optionalMin(requestServerMaxWindow?.integer, self.serverMaxWindow),
sendNoContextTakeover: requestServerNoContextTakeover || self.serverNoContextTakeover,
compressionLevel: self.compressionLevel,
memoryLevel: self.memoryLevel
memoryLevel: self.memoryLevel,
maxDecompressedFrameSize: self.maxDecompressedFrameSize
)
}
}
Expand All @@ -144,102 +147,133 @@ struct PerMessageDeflateExtensionBuilder: WebSocketExtensionBuilder {
/// Uses deflate to compress messages sent across a WebSocket
/// See RFC 7692 for more details https://www.rfc-editor.org/rfc/rfc7692
struct PerMessageDeflateExtension: WebSocketExtension {
enum SendState: Sendable {
case idle
case sendingMessage
}

struct Configuration: Sendable {
let receiveMaxWindow: Int?
let receiveNoContextTakeover: Bool
let sendMaxWindow: Int?
let sendNoContextTakeover: Bool
let compressionLevel: Int?
let memoryLevel: Int?
let maxDecompressedFrameSize: Int
}

/// Internal mutable state and referenced types, that cannot be set to Sendable
class InternalState {
actor Decompressor {
fileprivate let decompressor: any NIODecompressor

init(_ decompressor: any NIODecompressor) throws {
self.decompressor = decompressor
try self.decompressor.startStream()
}

func decompress(_ frame: WebSocketFrame, maxSize: Int, resetStream: Bool, context: some WebSocketContextProtocol) throws -> WebSocketFrame {
var frame = frame
precondition(frame.fin, "Only concatenated frames with fin set can be processed by the permessage-deflate extension")
// Reinstate last four bytes 0x00 0x00 0xff 0xff that were removed in the frame
// send (see https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.2).
frame.data.writeBytes([0, 0, 255, 255])
frame.data = try frame.data.decompressStream(with: self.decompressor, maxSize: maxSize, allocator: context.allocator)
if resetStream {
try self.decompressor.resetStream()
}
return frame
}

func shutdown() throws {
try self.decompressor.finishStream()
}
}

actor Compressor {
enum SendState: Sendable {
case idle
case sendingMessage
}

fileprivate let compressor: any NIOCompressor
fileprivate var sendState: SendState
var sendState: SendState

init(configuration: Configuration) throws {
self.decompressor = CompressionAlgorithm.deflate(
configuration: .init(
windowSize: numericCast(configuration.receiveMaxWindow ?? 15)
)
).decompressor
// compression level -1 will setup the default compression level, 8 is the default memory level
self.compressor = CompressionAlgorithm.deflate(
configuration: .init(
windowSize: numericCast(configuration.sendMaxWindow ?? 15),
compressionLevel: configuration.compressionLevel.map { numericCast($0) } ?? -1,
memoryLevel: configuration.memoryLevel.map { numericCast($0) } ?? 8
)
).compressor
init(_ compressor: any NIOCompressor) throws {
self.compressor = compressor
self.sendState = .idle
try self.decompressor.startStream()
try self.compressor.startStream()
}

func shutdown() {
try? self.compressor.finishStream()
try? self.decompressor.finishStream()
func compress(_ frame: WebSocketFrame, resetStream: Bool, context: some WebSocketContextProtocol) throws -> WebSocketFrame {
// if the frame is larger than 16 bytes, we haven't received a final frame or we are in the process of sending a message
// compress the data
let shouldWeCompress = frame.data.readableBytes > 16 || !frame.fin || self.sendState != .idle
if shouldWeCompress {
var newFrame = frame
if self.sendState == .idle {
newFrame.rsv1 = true
self.sendState = .sendingMessage
}
newFrame.data = try newFrame.data.compressStream(with: self.compressor, flush: .sync, allocator: context.allocator)
// if final frame then remove last four bytes 0x00 0x00 0xff 0xff
// (see https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.1)
if newFrame.fin {
newFrame.data = newFrame.data.getSlice(at: newFrame.data.readerIndex, length: newFrame.data.readableBytes - 4) ?? newFrame.data
self.sendState = .idle
if resetStream {
try self.compressor.resetStream()
}
}
return newFrame
}
return frame
}

func shutdown() throws {
try self.compressor.finishStream()
}
}

let name = "permessage-deflate"
let configuration: Configuration
let internalState: NIOLoopBound<InternalState>
let decompressor: Decompressor
let compressor: Compressor

init(configuration: Configuration, eventLoop: EventLoop) throws {
init(configuration: Configuration) throws {
self.configuration = configuration
self.internalState = try .init(.init(configuration: configuration), eventLoop: eventLoop)
self.decompressor = try .init(
CompressionAlgorithm.deflate(
configuration: .init(
windowSize: numericCast(configuration.receiveMaxWindow ?? 15)
)
).decompressor
)
self.compressor = try .init(
CompressionAlgorithm.deflate(
configuration: .init(
windowSize: numericCast(configuration.sendMaxWindow ?? 15),
compressionLevel: configuration.compressionLevel.map { numericCast($0) } ?? -1,
memoryLevel: configuration.memoryLevel.map { numericCast($0) } ?? 8
)
).compressor
)
}

func shutdown() {
self.internalState.value.shutdown()
func shutdown() async {
try? await self.decompressor.shutdown()
try? await self.compressor.shutdown()
}

func processReceivedFrame(_ frame: WebSocketFrame, ws: WebSocket) throws -> WebSocketFrame {
var frame = frame
func processReceivedFrame(_ frame: WebSocketFrame, context: some WebSocketContextProtocol) async throws -> WebSocketFrame {
if frame.rsv1 {
let state = self.internalState.value
precondition(frame.fin, "Only concatenated frames with fin set can be processed by the permessage-deflate extension")
// Reinstate last four bytes 0x00 0x00 0xff 0xff that were removed in the frame
// send (see https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.2).
frame.data.writeBytes([0, 0, 255, 255])
frame.data = try frame.data.decompressStream(with: state.decompressor, maxSize: ws.maxFrameSize, allocator: ws.channel.allocator)
if self.configuration.receiveNoContextTakeover {
try state.decompressor.resetStream()
}
return try await self.decompressor.decompress(
frame,
maxSize: self.configuration.maxDecompressedFrameSize,
resetStream: self.configuration.receiveNoContextTakeover,
context: context
)
}
return frame
}

func processFrameToSend(_ frame: WebSocketFrame, ws: WebSocket) throws -> WebSocketFrame {
let state = self.internalState.value
// if the frame is larger than 16 bytes, we haven't received a final frame or we are in the process of sending a message
// compress the data
let shouldWeCompress = frame.data.readableBytes > 16 || !frame.fin || state.sendState != .idle
func processFrameToSend(_ frame: WebSocketFrame, context: some WebSocketContextProtocol) async throws -> WebSocketFrame {
let isCorrectType = frame.opcode == .text || frame.opcode == .binary
if shouldWeCompress, isCorrectType {
var newFrame = frame
if state.sendState == .idle {
newFrame.rsv1 = true
state.sendState = .sendingMessage
}
newFrame.data = try newFrame.data.compressStream(with: state.compressor, flush: .sync, allocator: ws.channel.allocator)
// if final frame then remove last four bytes 0x00 0x00 0xff 0xff
// (see https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.1)
if newFrame.fin {
newFrame.data = newFrame.data.getSlice(at: newFrame.data.readerIndex, length: newFrame.data.readableBytes - 4) ?? newFrame.data
state.sendState = .idle
if self.configuration.sendNoContextTakeover {
try state.compressor.resetStream()
}
}
return newFrame
if isCorrectType {
return try await self.compressor.compress(frame, resetStream: self.configuration.sendNoContextTakeover, context: context)
}
return frame
}
Expand All @@ -250,15 +284,20 @@ extension WebSocketExtensionFactory {
/// - Parameters:
/// - maxWindow: Max window to be used for decompression and compression
/// - noContextTakeover: Should we reset window on every message
public static func perMessageDeflate(maxWindow: Int? = nil, noContextTakeover: Bool = false) -> WebSocketExtensionFactory {
public static func perMessageDeflate(
maxWindow: Int? = nil,
noContextTakeover: Bool = false,
maxDecompressedFrameSize: Int = 1 << 14
) -> WebSocketExtensionFactory {
return .init {
PerMessageDeflateExtensionBuilder(
clientMaxWindow: maxWindow,
clientNoContextTakeover: noContextTakeover,
serverMaxWindow: maxWindow,
serverNoContextTakeover: noContextTakeover,
compressionLevel: nil,
memoryLevel: nil
memoryLevel: nil,
maxDecompressedFrameSize: maxDecompressedFrameSize
)
}
}
Expand All @@ -279,7 +318,8 @@ extension WebSocketExtensionFactory {
serverMaxWindow: Int? = nil,
serverNoContextTakeover: Bool = false,
compressionLevel: Int? = nil,
memoryLevel: Int? = nil
memoryLevel: Int? = nil,
maxDecompressedFrameSize: Int = 1 << 14
) -> WebSocketExtensionFactory {
return .init {
PerMessageDeflateExtensionBuilder(
Expand All @@ -288,7 +328,8 @@ extension WebSocketExtensionFactory {
serverMaxWindow: serverMaxWindow,
serverNoContextTakeover: serverNoContextTakeover,
compressionLevel: compressionLevel,
memoryLevel: memoryLevel
memoryLevel: memoryLevel,
maxDecompressedFrameSize: maxDecompressedFrameSize
)
}
}
Expand Down

0 comments on commit 541324e

Please sign in to comment.