Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
groue committed Oct 3, 2022
1 parent c8e0ee5 commit 6ab8b55
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 24 deletions.
2 changes: 1 addition & 1 deletion README.md
Expand Up @@ -19,7 +19,7 @@ Task {
semaphore.signal()
```

The `wait()` method has a `waitUntilTaskCancellation()` variant that throws `CancellationError` if the task is canceled before a signal occurs.
The `wait()` method has a `waitUnlessCancelled()` variant that throws `CancellationError` if the task is cancelled before a signal occurs.

For a nice introduction to semaphores, see [The Beauty of Semaphores in Swift 🚦](https://medium.com/@roykronenfeld/semaphores-in-swift-e296ea80f860). The article discusses [`DispatchSemaphore`], but it can easily be ported to Swift concurrency: see the [demo playground](Demo/SemaphorePlayground.playground/Contents.swift) of this package.

Expand Down
53 changes: 35 additions & 18 deletions Sources/Semaphore/Semaphore.swift
Expand Up @@ -37,19 +37,24 @@ import Foundation
///
/// - ``signal()``
///
/// ### Blocking on the Semaphore
/// ### Waiting for the Semaphore
///
/// - ``wait()``
/// - ``waitUntilTaskCancellation()``
/// - ``waitUnlessCancelled()``
public final class Semaphore {
/// The semaphore value.
private var value: Int

/// "Waiting for a signal" is easily said, but several possible states exist.
private class Suspension {
enum State {
/// Initial state. Next is suspended, or cancelled.
case pending
case suspendedUntilTaskCancellation(UnsafeContinuation<Void, Error>)

/// Waiting for a signal, with support for cancellation.
case suspendedUnlessCancelled(UnsafeContinuation<Void, Error>)

/// Waiting for a signal, with no support for cancellation.
case suspended(UnsafeContinuation<Void, Never>)

/// Cancelled before we have started waiting.
case cancelled
}

Expand All @@ -64,15 +69,25 @@ public final class Semaphore {
}
}

// MARK: - Internal State

/// The semaphore value.
private var value: Int

/// As many elements as there are suspended tasks waiting for a signal.
/// We store `Suspension` instances instead of `UnsafeContinuation`, because
/// we support cancellation by removing `Suspension` instances from
/// this array.
private var suspensions: [Suspension] = []

/// This lock would be required even if ``Semaphore`` were made an actor,
/// because `withUnsafeContinuation` suspends before it runs its closure
/// argument. Also, by making ``Semaphore`` a plain class, we can expose a
/// non-async ``signal()`` method. The lock is recursive in order to handle
/// cancellation (see the implementation of ``wait()``).
/// The lock that protects `value` and `suspensions`.
///
/// It is recursive in order to handle cancellation (see the implementation
/// of ``waitUnlessCancelled()``).
private let lock = NSRecursiveLock()

// MARK: - Creating a Semaphore

/// Creates a semaphore.
///
/// - parameter value: The starting value for the semaphore. Do not pass a
Expand All @@ -86,6 +101,8 @@ public final class Semaphore {
precondition(suspensions.isEmpty, "Semaphore is deallocated while some task(s) are suspended waiting for a signal.")
}

// MARK: - Waiting for the Semaphore

/// Waits for, or decrements, a semaphore.
///
/// Decrement the counting semaphore. If the resulting value is less than
Expand Down Expand Up @@ -119,7 +136,7 @@ public final class Semaphore {
///
/// - Throws: If the task is canceled before a signal occurs, this function
/// throws `CancellationError`.
public func waitUntilTaskCancellation() async throws {
public func waitUnlessCancelled() async throws {
lock.lock()

value -= 1
Expand Down Expand Up @@ -148,7 +165,7 @@ public final class Semaphore {
// The first suspended task will be the first task resumed by `signal`.
// This is not intended to be a strong fifo guarantee, but just
// an attempt at some fairness.
suspension.state = .suspendedUntilTaskCancellation(continuation)
suspension.state = .suspendedUnlessCancelled(continuation)
suspensions.insert(suspension, at: 0)
lock.unlock()
}
Expand All @@ -168,7 +185,7 @@ public final class Semaphore {
suspensions.remove(at: index)
}

if case let .suspendedUntilTaskCancellation(continuation) = suspension.state {
if case let .suspendedUnlessCancelled(continuation) = suspension.state {
// Task is cancelled while suspended: resume with a CancellationError.
continuation.resume(throwing: CancellationError())
} else {
Expand All @@ -179,6 +196,8 @@ public final class Semaphore {
}
}

// MARK: - Signaling the Semaphore

/// Signals (increments) a semaphore.
///
/// Increment the counting semaphore. If the previous value was less than
Expand All @@ -194,16 +213,14 @@ public final class Semaphore {
value += 1

switch suspensions.popLast()?.state {
case let .suspendedUntilTaskCancellation(continuation):
case let .suspendedUnlessCancelled(continuation):
continuation.resume()
return true
case let .suspended(continuation):
continuation.resume()
return true
default:
break
return false
}

return false
}
}
10 changes: 5 additions & 5 deletions Tests/SemaphoreTests/SemaphoreTests.swift
Expand Up @@ -124,7 +124,7 @@ final class SemaphoreTests: XCTestCase {
let ex = expectation(description: "cancellation")
let task = Task {
do {
try await sem.waitUntilTaskCancellation()
try await sem.waitUnlessCancelled()
XCTFail("Expected CancellationError")
} catch is CancellationError {
} catch {
Expand All @@ -148,7 +148,7 @@ final class SemaphoreTests: XCTestCase {
}
}
do {
try await sem.waitUntilTaskCancellation()
try await sem.waitUnlessCancelled()
XCTFail("Expected CancellationError")
} catch is CancellationError {
} catch {
Expand All @@ -164,7 +164,7 @@ final class SemaphoreTests: XCTestCase {
// Given a task cancelled while suspended on a semaphore,
let sem = Semaphore(value: 0)
let task = Task {
try await sem.waitUntilTaskCancellation()
try await sem.waitUnlessCancelled()
}
try await Task.sleep(nanoseconds: 100_000_000)
task.cancel()
Expand Down Expand Up @@ -197,7 +197,7 @@ final class SemaphoreTests: XCTestCase {
continuation.resume()
}
}
try await sem.waitUntilTaskCancellation()
try await sem.waitUnlessCancelled()
}
task.cancel()

Expand Down Expand Up @@ -279,7 +279,7 @@ final class SemaphoreTests: XCTestCase {
await withThrowingTaskGroup(of: Void.self) { group in
for _ in 0..<(maxCount * 2) {
group.addTask {
try await sem.waitUntilTaskCancellation()
try await sem.waitUnlessCancelled()
await runner.run()
sem.signal()
}
Expand Down

0 comments on commit 6ab8b55

Please sign in to comment.