Skip to content

Commit

Permalink
Add suspend, resume and reconnect methods (#64)
Browse files Browse the repository at this point in the history
Signed-off-by: Piotr Piotrowski <piotr@synadia.com>
  • Loading branch information
piotrpio committed Apr 17, 2024
1 parent 7ebee0d commit 86326af
Show file tree
Hide file tree
Showing 8 changed files with 234 additions and 21 deletions.
22 changes: 20 additions & 2 deletions README.md
Expand Up @@ -20,7 +20,6 @@ Currently, the client supports **Core NATS** with auth, TLS, lame duck mode and

JetStream, KV, Object Store, Service API are on the roadmap.


## Support

Join the [#swift](https://natsio.slack.com/channels/swift) channel on nats.io Slack.
Expand Down Expand Up @@ -127,7 +126,6 @@ specific subjects, facilitating asynchronous communication patterns. This exampl
will guide you through creating a subscription to a subject, allowing your application to process
incoming messages as they are received.


```swift
let subscription = try await nats.subscribe(subject: "foo.>")

Expand Down Expand Up @@ -175,6 +173,26 @@ nats.on(.connected) { event in
}
```

### AppDelegate or SceneDelegate Integration

In order to make sure the connection is managed properly in your
AppDelegate.swift or SceneDelegate.swift, integrate the NatsClient connection
management as follows:

```swift
func sceneDidBecomeActive(_ scene: UIScene) {
Task {
try await self.natsClient.resume()
}
}

func sceneWillResignActive(_ scene: UIScene) {
Task {
try await self.natsClient.suspend()
}
}
```

## Attribution

This library is based on excellent work in https://github.com/aus-der-Technik/SwiftyNats
27 changes: 27 additions & 0 deletions Sources/Nats/NatsClient/NatsClient.swift
Expand Up @@ -26,6 +26,7 @@ public enum NatsState {
case connected
case disconnected
case closed
case suspended
}

public struct Auth {
Expand Down Expand Up @@ -85,6 +86,8 @@ extension NatsClient {
}
if !connectionHandler.retryOnFailedConnect {
try await connectionHandler.connect()
connectionHandler.state = .connected
connectionHandler.fire(.connected)
} else {
connectionHandler.handleReconnect()
}
Expand All @@ -98,6 +101,30 @@ extension NatsClient {
try await connectionHandler.close()
}

public func suspend() async throws {
logger.debug("suspend")
guard let connectionHandler = self.connectionHandler else {
throw NatsClientError("internal error: empty connection handler")
}
try await connectionHandler.suspend()
}

public func resume() async throws {
logger.debug("resume")
guard let connectionHandler = self.connectionHandler else {
throw NatsClientError("internal error: empty connection handler")
}
try await connectionHandler.resume()
}

public func reconnect() async throws {
logger.debug("resume")
guard let connectionHandler = self.connectionHandler else {
throw NatsClientError("internal error: empty connection handler")
}
try await connectionHandler.reconnect()
}

public func publish(
_ payload: Data, subject: String, reply: String? = nil, headers: NatsHeaderMap? = nil
) async throws {
Expand Down
1 change: 1 addition & 0 deletions Sources/Nats/NatsClient/NatsClientOptions.swift
Expand Up @@ -13,6 +13,7 @@

import Dispatch
import Foundation
import Logging
import NIO
import NIOFoundationCompat

Expand Down
100 changes: 89 additions & 11 deletions Sources/Nats/NatsConnection.swift
Expand Up @@ -47,7 +47,9 @@ class ConnectionHandler: ChannelInboundHandler {
private var clientKey: URL?

typealias InboundIn = ByteBuffer
private var state: NatsState = .pending
private let stateLock = NSLock()
internal var state: NatsState = .pending

private var subscriptions: [UInt64: Subscription]
private var subscriptionCounter = ManagedAtomic<UInt64>(0)
private var serverInfo: ServerInfo?
Expand All @@ -56,6 +58,7 @@ class ConnectionHandler: ChannelInboundHandler {
private var pingTask: RepeatedTask?
private var outstandingPings = ManagedAtomic<UInt8>(0)
private var reconnectAttempts = 0
private var reconnectTask: Task<(), Never>? = nil

private var group: MultiThreadedEventLoopGroup

Expand Down Expand Up @@ -219,7 +222,6 @@ class ConnectionHandler: ChannelInboundHandler {
// if there are more reconnect attempts than the number of servers,
// we are after the initial connect, so sleep between servers
let shouldSleep = self.reconnectAttempts >= self.urls.count
logger.debug("reconnect attempts: \(self.reconnectAttempts)")
for s in servers {
if let maxReconnects {
if reconnectAttempts >= maxReconnects {
Expand Down Expand Up @@ -249,8 +251,6 @@ class ConnectionHandler: ChannelInboundHandler {
throw lastErr
}
self.reconnectAttempts = 0
self.state = .connected
self.fire(.connected)
guard let channel = self.channel else {
throw NatsClientError("internal error: empty channel")
}
Expand Down Expand Up @@ -530,16 +530,71 @@ class ConnectionHandler: ChannelInboundHandler {
}

func close() async throws {
self.state = .closed
try await disconnect()
self.reconnectTask?.cancel()
await self.reconnectTask?.value

guard let eventLoop = self.channel?.eventLoop else {
throw NatsClientError("internal error: channel should not be nil")
}
let promise = eventLoop.makePromise(of: Void.self)

eventLoop.execute { // This ensures the code block runs on the event loop
self.state = .closed
self.pingTask?.cancel()
self.channel?.close(mode: .all, promise: promise)
}

try await promise.futureResult.get()
self.fire(.closed)
}

func disconnect() async throws {
private func disconnect() async throws {
self.pingTask?.cancel()
try await self.channel?.close().get()
}

func suspend() async throws {
self.reconnectTask?.cancel()
_ = await self.reconnectTask?.value

guard let eventLoop = self.channel?.eventLoop else {
throw NatsClientError("internal error: channel should not be nil")
}
let promise = eventLoop.makePromise(of: Void.self)

eventLoop.execute { // This ensures the code block runs on the event loop
if self.state == .connected {
self.state = .suspended
self.pingTask?.cancel()
self.channel?.close(mode: .all, promise: promise)
} else {
self.state = .suspended
promise.succeed()
}
}

try await promise.futureResult.get()
self.fire(.suspended)
}

func resume() async throws {
guard let eventLoop = self.channel?.eventLoop else {
throw NatsClientError("internal error: channel should not be nil")
}
try await eventLoop.submit {
guard self.state == .suspended else {
throw NatsClientError(
"unable to resume connection - connection is not in suspended state")
}
self.handleReconnect()
}.get()
}

func reconnect() async throws {
try await suspend()
try await resume()
}

internal func sendPing(_ rttCommand: RttCommand? = nil) async {
let pingsOut = self.outstandingPings.wrappingIncrementThenLoad(
ordering: AtomicUpdateOrdering.relaxed)
Expand Down Expand Up @@ -627,19 +682,30 @@ class ConnectionHandler: ChannelInboundHandler {
}

func handleReconnect() {
Task {
while maxReconnects == nil || self.reconnectAttempts < maxReconnects! {
reconnectTask = Task {
var reconnected = false
while !Task.isCancelled
&& (maxReconnects == nil || self.reconnectAttempts < maxReconnects!)
{
do {
try await self.connect()
} catch _ as CancellationError {
// task cancelled
return
} catch {
// TODO(pp): add option to set this to exponential backoff (with jitter)
logger.debug("could not reconnect: \(error)")
continue
}
logger.debug("reconnected")
reconnected = true
break
}
if self.state != .connected {
// if task was cancelled when establishing connection, do not attempt to recreate subscriptions
if Task.isCancelled {
return
}
if !reconnected && !Task.isCancelled {
logger.error("could not reconnect; maxReconnects exceeded")
logger.debug("closing connection")
do {
Expand All @@ -651,7 +717,15 @@ class ConnectionHandler: ChannelInboundHandler {
return
}
for (sid, sub) in self.subscriptions {
try await write(operation: ClientOp.subscribe((sid, sub.subject, nil)))
do {
try await write(operation: ClientOp.subscribe((sid, sub.subject, nil)))
} catch {
logger.error("error recreating subscription \(sid): \(error)")
}
}
self.channel?.eventLoop.execute {
self.state = .connected
self.fire(.connected)
}
}
}
Expand Down Expand Up @@ -741,6 +815,7 @@ public enum NatsEventKind: String {
case connected = "connected"
case disconnected = "disconnected"
case closed = "closed"
case suspended = "suspended"
case lameDuckMode = "lameDuckMode"
case error = "error"
static let all = [connected, disconnected, closed, lameDuckMode, error]
Expand All @@ -749,6 +824,7 @@ public enum NatsEventKind: String {
public enum NatsEvent {
case connected
case disconnected
case suspended
case closed
case lameDuckMode
case error(NatsError)
Expand All @@ -759,6 +835,8 @@ public enum NatsEvent {
return .connected
case .disconnected:
return .disconnected
case .suspended:
return .suspended
case .closed:
return .closed
case .lameDuckMode:
Expand Down
2 changes: 1 addition & 1 deletion Sources/Nats/NatsSubscription.swift
Expand Up @@ -62,7 +62,7 @@ public class Subscription: AsyncSequence {
}
}

func complete() {
internal func complete() {
lock.withLock {
closed = true
if let continuation {
Expand Down

0 comments on commit 86326af

Please sign in to comment.