Skip to content

Commit

Permalink
Cancel DispatchSource before closing socket (#4791)
Browse files Browse the repository at this point in the history
Extends socket lifetime enough to let DispatchSource cancel properly.
Also prevents from creating new DispatchSources while other are in the
middle of cancelling.

Also includes tests (see #4854 for test details).
  • Loading branch information
lxbndr committed Dec 29, 2023
1 parent 7258a8d commit f9a54f3
Show file tree
Hide file tree
Showing 5 changed files with 290 additions and 28 deletions.
4 changes: 4 additions & 0 deletions CoreFoundation/URL.subproj/CFURLSessionInterface.c
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ CFURLSessionEasyCode CFURLSession_easy_setopt_tc(CFURLSessionEasyHandle _Nonnull
return MakeEasyCode(curl_easy_setopt(curl, option.value, a));
}

CFURLSessionEasyCode CFURLSession_easy_setopt_scl(CFURLSessionEasyHandle _Nonnull curl, CFURLSessionOption option, CFURLSessionCloseSocketCallback * _Nullable a) {
return MakeEasyCode(curl_easy_setopt(curl, option.value, a));
}

CFURLSessionEasyCode CFURLSession_easy_getinfo_long(CFURLSessionEasyHandle _Nonnull curl, CFURLSessionInfo info, long *_Nonnull a) {
return MakeEasyCode(curl_easy_getinfo(curl, info.value, a));
}
Expand Down
2 changes: 2 additions & 0 deletions CoreFoundation/URL.subproj/CFURLSessionInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,8 @@ typedef int (CFURLSessionSeekCallback)(void *_Nullable userp, long long offset,
CF_EXPORT CFURLSessionEasyCode CFURLSession_easy_setopt_seek(CFURLSessionEasyHandle _Nonnull curl, CFURLSessionOption option, CFURLSessionSeekCallback * _Nullable a);
typedef int (CFURLSessionTransferInfoCallback)(void *_Nullable userp, long long dltotal, long long dlnow, long long ultotal, long long ulnow);
CF_EXPORT CFURLSessionEasyCode CFURLSession_easy_setopt_tc(CFURLSessionEasyHandle _Nonnull curl, CFURLSessionOption option, CFURLSessionTransferInfoCallback * _Nullable a);
typedef int (CFURLSessionCloseSocketCallback)(void *_Nullable clientp, CFURLSession_socket_t item);
CF_EXPORT CFURLSessionEasyCode CFURLSession_easy_setopt_scl(CFURLSessionEasyHandle _Nonnull curl, CFURLSessionOption option, CFURLSessionCloseSocketCallback * _Nullable a);

CF_EXPORT CFURLSessionEasyCode CFURLSession_easy_getinfo_long(CFURLSessionEasyHandle _Nonnull curl, CFURLSessionInfo info, long *_Nonnull a);
CF_EXPORT CFURLSessionEasyCode CFURLSession_easy_getinfo_double(CFURLSessionEasyHandle _Nonnull curl, CFURLSessionInfo info, double *_Nonnull a);
Expand Down
152 changes: 140 additions & 12 deletions Sources/FoundationNetworking/URLSession/libcurl/MultiHandle.swift
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ extension URLSession {
let queue: DispatchQueue
let group = DispatchGroup()
fileprivate var easyHandles: [_EasyHandle] = []
fileprivate var socketReferences: [CFURLSession_socket_t: _SocketReference] = [:]
fileprivate var timeoutSource: _TimeoutSource? = nil
private var reentrantInUpdateTimeoutTimer = false

Expand Down Expand Up @@ -127,13 +128,14 @@ fileprivate extension URLSession._MultiHandle {
if let opaque = socketSourcePtr {
Unmanaged<_SocketSources>.fromOpaque(opaque).release()
}
socketSources?.tearDown(handle: self, socket: socket, queue: queue)
socketSources = nil
}
if let ss = socketSources {
let handler = DispatchWorkItem { [weak self] in
self?.performAction(for: socket)
}
ss.createSources(with: action, socket: socket, queue: queue, handler: handler)
ss.createSources(with: action, handle: self, socket: socket, queue: queue, handler: handler)
}
return 0
}
Expand Down Expand Up @@ -161,9 +163,104 @@ extension Collection where Element == _EasyHandle {
}
}

private extension URLSession._MultiHandle {
class _SocketReference {
let socket: CFURLSession_socket_t
var shouldClose: Bool
var workItem: DispatchWorkItem?

init(socket: CFURLSession_socket_t) {
self.socket = socket
shouldClose = false
}

deinit {
if shouldClose {
#if os(Windows)
closesocket(socket)
#else
close(socket)
#endif
}
}
}

/// Creates and stores socket reference. Reentrancy is not supported.
/// Trying to begin operation for same socket twice would mean something
/// went horribly wrong, or our assumptions about CURL register/unregister
/// action flow are nor correct.
func beginOperation(for socket: CFURLSession_socket_t) -> _SocketReference {
let reference = _SocketReference(socket: socket)
precondition(socketReferences.updateValue(reference, forKey: socket) == nil, "Reentrancy is not supported for socket operations")
return reference
}

/// Removes socket reference from the shared store. If there is work item scheduled,
/// executes it on the current thread.
func endOperation(for socketReference: _SocketReference) {
precondition(socketReferences.removeValue(forKey: socketReference.socket) != nil, "No operation associated with the socket")
if let workItem = socketReference.workItem, !workItem.isCancelled {
// CURL never asks for socket close without unregistering first, and
// we should cancel pending work when unregister action is requested.
precondition(!socketReference.shouldClose, "Socket close was scheduled, but there is some pending work left")
workItem.perform()
}
}

/// Marks this reference to close socket on deinit. This allows us
/// to extend socket lifecycle by keeping the reference alive.
func scheduleClose(for socket: CFURLSession_socket_t) {
let reference = socketReferences[socket] ?? _SocketReference(socket: socket)
reference.shouldClose = true
}

/// Schedules work to be performed when an operation ends for the socket,
/// or performs it immediately if there is no operation in progress.
///
/// We're using this to postpone Dispatch Source creation when
/// previous Dispatch Source is not cancelled yet.
func schedule(_ workItem: DispatchWorkItem, for socket: CFURLSession_socket_t) {
guard let socketReference = socketReferences[socket] else {
workItem.perform()
return
}
// CURL never asks for register without pairing it with unregister later,
// and we're cancelling pending work item on unregister.
// But it is safe to just drop existing work item anyway,
// and replace it with the new one.
socketReference.workItem = workItem
}

/// Cancels pending work for socket operation. Does nothing if
/// there is no operation in progress or no pending work item.
///
/// CURL may become not interested in Dispatch Sources
/// we have planned to create. In this case we should just cancel
/// scheduled work.
func cancelWorkItem(for socket: CFURLSession_socket_t) {
guard let socketReference = socketReferences[socket] else {
return
}
socketReference.workItem?.cancel()
socketReference.workItem = nil
}

}

internal extension URLSession._MultiHandle {
/// Add an easy handle -- start its transfer.
func add(_ handle: _EasyHandle) {
// Set CLOSESOCKETFUNCTION. Note that while the option belongs to easy_handle,
// the connection cache is managed by CURL multi_handle, and sockets can actually
// outlive easy_handle (even after curl_easy_cleanup call). That's why
// socket management lives in _MultiHandle.
try! CFURLSession_easy_setopt_ptr(handle.rawHandle, CFURLSessionOptionCLOSESOCKETDATA, UnsafeMutableRawPointer(Unmanaged.passUnretained(self).toOpaque())).asError()
try! CFURLSession_easy_setopt_scl(handle.rawHandle, CFURLSessionOptionCLOSESOCKETFUNCTION) { (clientp: UnsafeMutableRawPointer?, item: CFURLSession_socket_t) in
guard let handle = URLSession._MultiHandle.from(callbackUserData: clientp) else { fatalError() }
handle.scheduleClose(for: item)
return 0
}.asError()

// If this is the first handle being added, we need to `kick` the
// underlying multi handle by calling `timeoutTimerFired` as
// described in
Expand Down Expand Up @@ -448,25 +545,56 @@ fileprivate class _SocketSources {
s.resume()
}

func tearDown() {
if let s = readSource {
s.cancel()
func tearDown(handle: URLSession._MultiHandle, socket: CFURLSession_socket_t, queue: DispatchQueue) {
handle.cancelWorkItem(for: socket) // There could be pending register action which needs to be cancelled

guard readSource != nil || writeSource != nil else {
// This means that we have posponed (and already abandoned)
// sources creation.
return
}
readSource = nil
if let s = writeSource {
s.cancel()

// Socket is guaranteed to not to be closed as long as we keeping
// the reference.
let socketReference = handle.beginOperation(for: socket)
let cancelHandlerGroup = DispatchGroup()
[readSource, writeSource].compactMap({ $0 }).forEach { source in
cancelHandlerGroup.enter()
source.setCancelHandler {
cancelHandlerGroup.leave()
}
source.cancel()
}
cancelHandlerGroup.notify(queue: queue) {
handle.endOperation(for: socketReference)
}

readSource = nil
writeSource = nil
}
}
extension _SocketSources {
/// Create a read and/or write source as specified by the action.
func createSources(with action: URLSession._MultiHandle._SocketRegisterAction, socket: CFURLSession_socket_t, queue: DispatchQueue, handler: DispatchWorkItem) {
if action.needsReadSource {
createReadSource(socket: socket, queue: queue, handler: handler)
func createSources(with action: URLSession._MultiHandle._SocketRegisterAction, handle: URLSession._MultiHandle, socket: CFURLSession_socket_t, queue: DispatchQueue, handler: DispatchWorkItem) {
// CURL casually requests to unregister and register handlers for same
// socket in a row. There is (pretty low) chance of overlapping tear-down operation
// with "register" request. Bad things could happen if we create
// a new Dispatch Source while other is being cancelled for the same socket.
// We're using `_MultiHandle.schedule(_:for:)` here to postpone sources creation until
// pending operation is finished (if there is none, submitted work item is performed
// immediately).
// Also, CURL may request unregister even before we perform any postponed work,
// so we have to cancel such work in such case. See
let createSources = DispatchWorkItem {
if action.needsReadSource {
self.createReadSource(socket: socket, queue: queue, handler: handler)
}
if action.needsWriteSource {
self.createWriteSource(socket: socket, queue: queue, handler: handler)
}
}
if action.needsWriteSource {
createWriteSource(socket: socket, queue: queue, handler: handler)
if action.needsReadSource || action.needsWriteSource {
handle.schedule(createSources, for: socket)
}
}
}
Expand Down
54 changes: 43 additions & 11 deletions Tests/Foundation/HTTPServer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class _TCPSocket: CustomStringConvertible {
listening = false
}

init(port: UInt16?) throws {
init(port: UInt16?, backlog: Int32) throws {
listening = true
self.port = 0

Expand All @@ -124,7 +124,7 @@ class _TCPSocket: CustomStringConvertible {
try socketAddress.withMemoryRebound(to: sockaddr.self, capacity: MemoryLayout<sockaddr>.size, {
let addr = UnsafePointer<sockaddr>($0)
_ = try attempt("bind", valid: isZero, bind(_socket, addr, socklen_t(MemoryLayout<sockaddr>.size)))
_ = try attempt("listen", valid: isZero, listen(_socket, SOMAXCONN))
_ = try attempt("listen", valid: isZero, listen(_socket, backlog))
})

var actualSA = sockaddr_in()
Expand Down Expand Up @@ -295,8 +295,8 @@ class _HTTPServer: CustomStringConvertible {
let tcpSocket: _TCPSocket
var port: UInt16 { tcpSocket.port }

init(port: UInt16?) throws {
tcpSocket = try _TCPSocket(port: port)
init(port: UInt16?, backlog: Int32 = SOMAXCONN) throws {
tcpSocket = try _TCPSocket(port: port, backlog: backlog)
}

init(socket: _TCPSocket) {
Expand Down Expand Up @@ -1094,15 +1094,32 @@ enum InternalServerError : Error {
case badHeaders
}

extension LoopbackServerTest {
struct Options {
var serverBacklog: Int32
var isAsynchronous: Bool

static let `default` = Options(serverBacklog: SOMAXCONN, isAsynchronous: true)
}
}

class LoopbackServerTest : XCTestCase {
private static let staticSyncQ = DispatchQueue(label: "org.swift.TestFoundation.HTTPServer.StaticSyncQ")

private static var _serverPort: Int = -1
private static var _serverActive = false
private static var testServer: _HTTPServer? = nil


private static var _options: Options = .default

static var options: Options {
get {
return staticSyncQ.sync { _options }
}
set {
staticSyncQ.sync { _options = newValue }
}
}

static var serverPort: Int {
get {
return staticSyncQ.sync { _serverPort }
Expand All @@ -1119,27 +1136,42 @@ class LoopbackServerTest : XCTestCase {

override class func setUp() {
super.setUp()
Self.startServer()
}

override class func tearDown() {
Self.stopServer()
super.tearDown()
}

static func startServer() {
var _serverPort = 0
let dispatchGroup = DispatchGroup()

func runServer() throws {
testServer = try _HTTPServer(port: nil)
testServer = try _HTTPServer(port: nil, backlog: options.serverBacklog)
_serverPort = Int(testServer!.port)
serverActive = true
dispatchGroup.leave()

while serverActive {
do {
let httpServer = try testServer!.listen()
globalDispatchQueue.async {

func handleRequest() {
let subServer = TestURLSessionServer(httpServer: httpServer)
do {
try subServer.readAndRespond()
} catch {
NSLog("readAndRespond: \(error)")
}
}

if options.isAsynchronous {
globalDispatchQueue.async(execute: handleRequest)
} else {
handleRequest()
}
} catch {
if (serverActive) { // Ignore errors thrown on shutdown
NSLog("httpServer: \(error)")
Expand All @@ -1165,11 +1197,11 @@ class LoopbackServerTest : XCTestCase {
fatalError("Timedout waiting for server to be ready")
}
serverPort = _serverPort
debugLog("Listening on \(serverPort)")
}

override class func tearDown() {
static func stopServer() {
serverActive = false
try? testServer?.stop()
super.tearDown()
}
}

0 comments on commit f9a54f3

Please sign in to comment.