Skip to content

Commit

Permalink
Avoid using deinit to fulfil the protocol negotiation promise (#2497)
Browse files Browse the repository at this point in the history
# Motivation

Fixes #2494

# Modification
This PR avoids using `deinit` to fulfil the protocol negotiation promise and opts to trap instead when it is being accessed before the handler is added. This allows us to use `handlerAdded` and `handlerRemoved`.

# Result
No more `deinit` usage that can be observed.
  • Loading branch information
FranzBusch committed Aug 8, 2023
1 parent 28eb2ac commit cf28163
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 31 deletions.
23 changes: 15 additions & 8 deletions Sources/NIOTLS/NIOTypedApplicationProtocolNegotiationHandler.swift
Expand Up @@ -49,10 +49,14 @@ public final class NIOTypedApplicationProtocolNegotiationHandler<NegotiationResu

@_spi(AsyncChannel)
public var protocolNegotiationResult: EventLoopFuture<NIOProtocolNegotiationResult<NegotiationResult>> {
self.negotiatedPromise.futureResult
return self.negotiatedPromise.futureResult
}

private let negotiatedPromise: EventLoopPromise<NIOProtocolNegotiationResult<NegotiationResult>>
private var negotiatedPromise: EventLoopPromise<NIOProtocolNegotiationResult<NegotiationResult>> {
precondition(self._negotiatedPromise != nil, "Tried to access the protocol negotiation result before the handler was added to a pipeline")
return self._negotiatedPromise!
}
private var _negotiatedPromise: EventLoopPromise<NIOProtocolNegotiationResult<NegotiationResult>>?

private let completionHandler: (ALPNResult, Channel) -> EventLoopFuture<NIOProtocolNegotiationResult<NegotiationResult>>
private var stateMachine = ProtocolNegotiationHandlerStateMachine<NIOProtocolNegotiationResult<NegotiationResult>>()
Expand All @@ -63,9 +67,8 @@ public final class NIOTypedApplicationProtocolNegotiationHandler<NegotiationResu
/// - Parameter alpnCompleteHandler: The closure that will fire when ALPN
/// negotiation has completed.
@_spi(AsyncChannel)
public init(eventLoop: EventLoop, alpnCompleteHandler: @escaping (ALPNResult, Channel) -> EventLoopFuture<NIOProtocolNegotiationResult<NegotiationResult>>) {
public init(alpnCompleteHandler: @escaping (ALPNResult, Channel) -> EventLoopFuture<NIOProtocolNegotiationResult<NegotiationResult>>) {
self.completionHandler = alpnCompleteHandler
self.negotiatedPromise = eventLoop.makePromise(of: NIOProtocolNegotiationResult<NegotiationResult>.self)
}

/// Create an `ApplicationProtocolNegotiationHandler` with the given completion
Expand All @@ -74,14 +77,18 @@ public final class NIOTypedApplicationProtocolNegotiationHandler<NegotiationResu
/// - Parameter alpnCompleteHandler: The closure that will fire when ALPN
/// negotiation has completed.
@_spi(AsyncChannel)
public convenience init(eventLoop: EventLoop, alpnCompleteHandler: @escaping (ALPNResult) -> EventLoopFuture<NIOProtocolNegotiationResult<NegotiationResult>>) {
self.init(eventLoop: eventLoop) { result, _ in
public convenience init(alpnCompleteHandler: @escaping (ALPNResult) -> EventLoopFuture<NIOProtocolNegotiationResult<NegotiationResult>>) {
self.init { result, _ in
alpnCompleteHandler(result)
}
}

deinit {
switch self.stateMachine.deinitHandler() {
public func handlerAdded(context: ChannelHandlerContext) {
self._negotiatedPromise = context.eventLoop.makePromise()
}

public func handlerRemoved(context: ChannelHandlerContext) {
switch self.stateMachine.handlerRemoved() {
case .failPromise:
self.negotiatedPromise.fail(ChannelError.inappropriateOperationForState)

Expand Down
11 changes: 3 additions & 8 deletions Sources/NIOTLS/ProtocolNegotiationHandlerStateMachine.swift
Expand Up @@ -30,21 +30,16 @@ struct ProtocolNegotiationHandlerStateMachine<NegotiationResult> {
private var state = State.initial

@usableFromInline
enum DeinitHandlerAction {
enum HandlerRemovedAction {
case failPromise
}

@inlinable
mutating func deinitHandler() -> DeinitHandlerAction? {
mutating func handlerRemoved() -> HandlerRemovedAction? {
switch self.state {
case .initial:
case .initial, .waitingForUser, .unbuffering:
return .failPromise

case .waitingForUser, .unbuffering:
// We are retaining the handler strongly while waiting and unbuffering
// so we should never hit the deinit.
fatalError("Unexpected state")

case .finished:
return .none
}
Expand Down
4 changes: 2 additions & 2 deletions Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift
Expand Up @@ -989,7 +989,7 @@ final class AsyncChannelBootstrapTests: XCTestCase {
try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder()))
try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder()))
try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler(proposedALPN: proposedOuterALPN))
let negotiationHandler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>(eventLoop: channel.eventLoop) { alpnResult, channel in
let negotiationHandler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult> { alpnResult, channel in
switch alpnResult {
case .negotiated(let alpn):
switch alpn {
Expand Down Expand Up @@ -1020,7 +1020,7 @@ final class AsyncChannelBootstrapTests: XCTestCase {

@discardableResult
private func addTypedApplicationProtocolNegotiationHandler(to channel: Channel) throws -> EventLoopFuture<NIOProtocolNegotiationResult<NegotiationResult>> {
let negotiationHandler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>(eventLoop: channel.eventLoop) { alpnResult, channel in
let negotiationHandler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult> { alpnResult, channel in
switch alpnResult {
case .negotiated(let alpn):
switch alpn {
Expand Down
Expand Up @@ -27,16 +27,15 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase {
private let negotiatedEvent: TLSUserEvent = .handshakeCompleted(negotiatedProtocol: "h2")
private let negotiatedResult: ALPNResult = .negotiated("h2")

func testPromiseIsCompleted() {
func testPromiseIsCompleted() throws {
let channel = EmbeddedChannel()
let eventLoop = channel.embeddedEventLoop

var handler: NIOTypedApplicationProtocolNegotiationHandler? = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>(eventLoop: eventLoop) { result, channel in
let handler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult> { result, channel in
return channel.eventLoop.makeSucceededFuture(.init(result: (.negotiated(result))))
}
let future = handler!.protocolNegotiationResult
handler = nil
XCTAssertThrowsError(try future.wait()) { error in
try channel.pipeline.addHandler(handler).wait()
try channel.pipeline.removeHandler(handler).wait()
XCTAssertThrowsError(try handler.protocolNegotiationResult.wait()) { error in
XCTAssertEqual(error as? ChannelError, .inappropriateOperationForState)
}
}
Expand All @@ -46,7 +45,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase {
let loop = emChannel.eventLoop as! EmbeddedEventLoop
var called = false

let handler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>(eventLoop: loop) { result, channel in
let handler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult> { result, channel in
called = true
XCTAssertEqual(result, self.negotiatedResult)
XCTAssertTrue(emChannel === channel)
Expand All @@ -64,7 +63,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase {
let channel = EmbeddedChannel()
let loop = channel.eventLoop as! EmbeddedEventLoop

let handler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>(eventLoop: loop) { result in
let handler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult> { result in
XCTFail("Negotiation fired")
return loop.makeSucceededFuture(.init(result: (.failed)))
}
Expand All @@ -85,7 +84,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase {
let channel = EmbeddedChannel()
let loop = channel.eventLoop as! EmbeddedEventLoop

let handler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>(eventLoop: loop) { result in
let handler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult> { result in
XCTFail("Should not be called")
return loop.makeSucceededFuture(.init(result: (.failed)))
}
Expand All @@ -104,7 +103,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase {
let loop = channel.eventLoop as! EmbeddedEventLoop
let continuePromise = loop.makePromise(of: NIOProtocolNegotiationResult<NegotiationResult>.self)

let handler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>(eventLoop: loop) { result in
let handler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult> { result in
return continuePromise.futureResult
}

Expand Down Expand Up @@ -135,7 +134,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase {
let loop = channel.eventLoop as! EmbeddedEventLoop
let continuePromise = loop.makePromise(of: NIOProtocolNegotiationResult<NegotiationResult>.self)

let handler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>(eventLoop: loop) { result in
let handler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult> { result in
continuePromise.futureResult
}
let eventCounterHandler = EventCounterHandler()
Expand All @@ -162,7 +161,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase {
let loop = channel.eventLoop as! EmbeddedEventLoop
let continuePromise = loop.makePromise(of: NIOProtocolNegotiationResult<NegotiationResult>.self)

let handler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>(eventLoop: loop) { result in
let handler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult> { result in
continuePromise.futureResult
}
let eventCounterHandler = EventCounterHandler()
Expand Down Expand Up @@ -193,7 +192,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase {
let loop = channel.eventLoop as! EmbeddedEventLoop
let continuePromise = loop.makePromise(of: NIOProtocolNegotiationResult<NegotiationResult>.self)

let handler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>(eventLoop: loop) { result in
let handler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult> { result in
continuePromise.futureResult
}
let eventCounterHandler = EventCounterHandler()
Expand Down

0 comments on commit cf28163

Please sign in to comment.