From d084c8aebc3e336bf2a48b4164ea783cc4becac7 Mon Sep 17 00:00:00 2001 From: Varun Santhanam Date: Mon, 7 Jun 2021 10:06:07 -0400 Subject: [PATCH] Create custom `Publisher` type ahead of support for more advanced behaviors --- .../Ombi/Core/Internal/RequestPublisher.swift | 293 ++++++++++++++++++ .../Internal/ResponsePublisherProviding.swift | 2 +- .../Ombi/Core/Internal/URLSession+Ombi.swift | 11 +- Sources/Ombi/Core/RequestManager.swift | 141 ++------- Sources/Ombi/Wrappers/AnyRequestable.swift | 2 +- Tests/OmbiTests/RequestManagerTests.swift | 2 +- 6 files changed, 320 insertions(+), 131 deletions(-) create mode 100644 Sources/Ombi/Core/Internal/RequestPublisher.swift diff --git a/Sources/Ombi/Core/Internal/RequestPublisher.swift b/Sources/Ombi/Core/Internal/RequestPublisher.swift new file mode 100644 index 0000000..97c2c64 --- /dev/null +++ b/Sources/Ombi/Core/Internal/RequestPublisher.swift @@ -0,0 +1,293 @@ +// Ombi +// RequestPublisher.swift +// +// MIT License +// +// Copyright (c) 2021 Varun Santhanam +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the Software), to deal +// +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED AS IS, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +import Combine +import Foundation +import os.log + +final class RequestPublisher: Publisher where ResponseError: Error { + + // MARK: - Initializers + + init(session: URLSession, + request: T, + host: String, + injectedHeaders: RequestHeaders, + backupAuthentication: RequestAuthentication?, + log: OSLog?) where T: Requestable, T.RequestBody == RequestBody, T.ResponseBody == ResponseBody, T.ResponseError == ResponseError { + self.session = session + self.request = .init(request) + self.host = host + self.injectedHeaders = injectedHeaders + self.backupAuthentication = backupAuthentication + self.log = log + } + + // MARK: - Publisher + + func receive(subscriber: S) where S: Subscriber, RequestError == S.Failure, RequestResponse == S.Input { + subscriber.receive(subscription: RequestSubscription(requestPublisher: self, downstream: subscriber)) + } + + typealias Output = RequestResponse + + typealias Failure = RequestError + + // MARK: - Private + + fileprivate let session: URLSession + fileprivate let request: AnyRequestable + fileprivate let host: String + fileprivate let injectedHeaders: RequestHeaders + fileprivate let backupAuthentication: RequestAuthentication? + fileprivate let log: OSLog? + +} + +private class RequestSubscription: Subscription where Downstream: Subscriber, Downstream.Input == RequestPublisher.Output, Downstream.Failure == RequestPublisher.Failure, ResponseError: Error { + + // MARK: - Initializers + + init(requestPublisher: RequestPublisher, downstream: Downstream) { + lock = .init() + self.requestPublisher = requestPublisher + self.downstream = downstream + demand = .max(0) + } + + // MARK: - Subscription + + func request(_ demand: Subscribers.Demand) { + lock.lock() + guard let requestPublisher = requestPublisher else { + lock.unlock() + return + } + + guard var urlComponents = URLComponents(string: requestPublisher.host) else { + self.demand += 1 + failWithMalformedComponents() + return + } + + urlComponents.path = "\(urlComponents.path)\(requestPublisher.request.path)" + if !requestPublisher.request.query.isEmpty { + urlComponents.queryItems = requestPublisher.request.query + } + + guard let finalURL = urlComponents.url else { + self.demand += 1 + failWithMalformedComponents() + return + } + + var request = URLRequest(url: finalURL) + request.httpMethod = requestPublisher.request.method.rawValue + if let body = requestPublisher.request.body { + do { + request.httpBody = try requestPublisher.request.requestEncoder.encode(body) + } catch { + self.demand += 1 + failWithMalformedComponents() + return + } + } + + var headers = requestPublisher.request.headers + .reduce([String: String]()) { prev, pair in + let (key, value) = pair + var next = prev + next[key.description] = value.description + return next + } + + headers = requestPublisher.injectedHeaders + .reduce(headers) { prev, pair in + let (key, value) = pair + var next = prev + next[key.description] = value.description + return next + } + + if let authentication = requestPublisher.request.authentication ?? requestPublisher.backupAuthentication { + headers[authentication.headerKey.description] = authentication.headerValue.description + } + + request.allHTTPHeaderFields = headers + + request.timeoutInterval = requestPublisher.request.timeoutInterval + + if self.task == nil { + let task = requestPublisher.session.dataTask(with: request, + completionHandler: handleResponse(data:response:error:)) + self.task = task + } + + guard let log = requestPublisher.log else { return } + var message = "Making Request" + if let urlString = request.url?.absoluteString { + message.append("\nURL: \(urlString)") + } + if let method = request.requestMethod { + message.append("\nMethod: \(method)") + } + if let headers = request.allHTTPHeaderFields { + message.append("\nHeaders:\n\(headers.description)") + } + if let body = requestPublisher.request.body { + message.append("\nBody:\n\(String(describing: body))") + } + if let encodedBody = request.httpBody { + message.append("\nEncoded Body:\n\(String(describing: encodedBody))") + } + os_log(.debug, log: log, "%@", message) + + self.demand += 1 + let task = self.task! + lock.unlock() + task.resume() + } + + func cancel() { + lock.lock() + guard requestPublisher != nil else { + lock.unlock() + return + } + requestPublisher = nil + downstream = nil + demand = .max(0) + let task = self.task + self.task = nil + lock.unlock() + task?.cancel() + } + + // MARK: - Private + + private let lock: Lock + private var requestPublisher: RequestPublisher? + private var downstream: Downstream? + private var demand: Subscribers.Demand + private var task: URLSessionDataTask! + + private func failWithMalformedComponents() { + lock.lock() + guard demand > 0, + requestPublisher != nil, + let downstream = downstream else { + lock.unlock() + return + } + + requestPublisher = nil + self.downstream = nil + demand = .max(0) + task = nil + lock.unlock() + downstream.receive(completion: .failure(.malformedRequest)) + } + + private func handleResponse(data: Data?, response: URLResponse?, error: Error?) { + lock.lock() + guard demand > 0, + requestPublisher != nil, + let downstream = downstream, + let request = requestPublisher?.request else { + lock.unlock() + return + } + + let log = requestPublisher?.log + requestPublisher = nil + self.downstream = nil + + demand = .max(0) + task = nil + lock.unlock() + + if let response = response, + error == nil { + do { + let body = try request.responseDecoder.decode(data) + let finalResponse: RequestResponse + if let response = response as? HTTPURLResponse { + let headers = response.allHeaderFields.reduce(RequestHeaders()) { headers, pair in + let (field, value) = pair + var next = headers + next[.init(String(describing: field))] = .init(String(describing: value)) + return next + } + finalResponse = .init(url: response.url, headers: headers, statusCode: response.statusCode, body: body) + } else { + finalResponse = .init(url: response.url, headers: nil, statusCode: nil, body: body) + } + + guard let log = log else { return } + var message = "" + if let urlString = response.url?.absoluteString { + message.append("Received Response from \(urlString)") + } else { + message.append("Received Response") + } + if let code = finalResponse.statusCode { + message.append("\nStatus Code: \(code.description)") + } + if let headers = finalResponse.headers { + message.append("\nHeaders:\n\(headers.description)") + } + if let body = finalResponse.body { + message.append("\nBody:\n\(String(describing: body))") + } + os_log(.debug, log: log, "%@", message) + + _ = downstream.receive(finalResponse) + } catch { + downstream.receive(completion: .failure(.decodingError(error))) + } + } else { + if let urlError = error as? URLError { + if urlError.code == .timedOut { + downstream.receive(completion: .failure(.timedOut)) + } + downstream.receive(completion: .failure(.urlSessionFailed(urlError))) + } else { + downstream.receive(completion: .failure(.unknownError)) + } + } + } +} + +private final class Lock { + private var isLocked: Bool = false + + func lock() { + isLocked = true + } + + func unlock() { + isLocked = false + } +} diff --git a/Sources/Ombi/Core/Internal/ResponsePublisherProviding.swift b/Sources/Ombi/Core/Internal/ResponsePublisherProviding.swift index 4ad9203..01d81d4 100644 --- a/Sources/Ombi/Core/Internal/ResponsePublisherProviding.swift +++ b/Sources/Ombi/Core/Internal/ResponsePublisherProviding.swift @@ -27,5 +27,5 @@ import Combine import Foundation protocol ResponsePublisherProviding { - func publisher(for urlRequest: URLRequest) -> AnyPublisher<(data: Data, response: URLResponse), URLError> + func requestPublisher(for urlRequest: URLRequest) -> AnyPublisher<(data: Data, response: URLResponse), URLError> } diff --git a/Sources/Ombi/Core/Internal/URLSession+Ombi.swift b/Sources/Ombi/Core/Internal/URLSession+Ombi.swift index 1990698..982a51d 100644 --- a/Sources/Ombi/Core/Internal/URLSession+Ombi.swift +++ b/Sources/Ombi/Core/Internal/URLSession+Ombi.swift @@ -26,9 +26,8 @@ import Combine import Foundation -extension URLSession: ResponsePublisherProviding { - func publisher(for urlRequest: URLRequest) -> AnyPublisher<(data: Data, response: URLResponse), URLError> { - dataTaskPublisher(for: urlRequest) - .eraseToAnyPublisher() - } -} +// extension URLSession: ResponsePublisherProviding { +// func requestPublisher(for urlRequest: URLRequest) -> AnyPublisher<(data: Data, response: URLResponse), URLError> { +// RequestPublisher(session: self, request: urlRequest).eraseToAnyPublisher() +// } +// } diff --git a/Sources/Ombi/Core/RequestManager.swift b/Sources/Ombi/Core/RequestManager.swift index 49d7add..36acb14 100644 --- a/Sources/Ombi/Core/RequestManager.swift +++ b/Sources/Ombi/Core/RequestManager.swift @@ -123,6 +123,7 @@ open class RequestManager { /// Headers to add to every request open var additionalHeaders: RequestHeaders = [:] + /// Authentication to add to every request open var requestAuthentication: RequestAuthentication? = nil /// Whether or not to inject Ombi's default headers @@ -143,7 +144,7 @@ open class RequestManager { fallback: T.Response? = nil) -> AnyPublisher where T: Requestable, S: Scheduler { publisher(for: requestable, scheduler: scheduler, - authentication: requestable.authentication ?? authentication ?? requestAuthentication) + authentication: authentication) .retry(retries) .timeout(sla, scheduler: scheduler, @@ -198,7 +199,7 @@ open class RequestManager { // MARK: - Private init(host: String, - session: ResponsePublisherProviding, + session: URLSession, log: OSLog?) { self.host = host self.session = session @@ -214,7 +215,16 @@ open class RequestManager { self.init(host: host, session: session, log: log) } - private let session: ResponsePublisherProviding + private let session: URLSession + + private var injectedHeaders: RequestHeaders { + additionalHeaders.reduce(defaultHeaders) { prev, pair in + var copy = prev + let (key, value) = pair + copy[key] = value + return copy + } + } private var defaultHeaders: RequestHeaders { guard shouldInjectDefaultHeaders else { @@ -226,122 +236,13 @@ open class RequestManager { } private func publisher(for requestable: T, scheduler: S, authentication: RequestAuthentication?) -> AnyPublisher where T: Requestable, S: Scheduler { - typealias InstantFailure = Fail - guard var urlComponents = URLComponents(string: host) else { - return InstantFailure(error: .malformedRequest) - .eraseToAnyPublisher() - } - urlComponents.path = "\(urlComponents.path)\(requestable.path)" - if !requestable.query.isEmpty { - urlComponents.queryItems = requestable.query - } - guard let finalURL = urlComponents.url else { - return InstantFailure(error: .malformedRequest) - .eraseToAnyPublisher() - } - var request = URLRequest(url: finalURL) - request.httpMethod = requestable.method.rawValue - if let body = requestable.body { - do { - request.httpBody = try requestable.requestEncoder.encode(body) - } catch { - return InstantFailure(error: .malformedRequest) - .eraseToAnyPublisher() - } - } - var dict = defaultHeaders - .reduce([String: String]()) { prev, pair in - let (key, value) = pair - var next = prev - next[key.description] = value.description - return next - } - - dict = requestable.headers - .reduce(dict) { prev, pair in - let (key, value) = pair - var next = prev - next[key.description] = value.description - return next - } - - if let authentication = authentication { - dict[authentication.headerKey.description] = authentication.headerValue.description - } - - request.allHTTPHeaderFields = additionalHeaders - .reduce(dict) { prev, pair in - let (key, value) = pair - var next = prev - next[key.description] = value.description - return next - } - request.timeoutInterval = requestable.timeoutInterval - - return session.publisher(for: request) - .mapError { error -> T.Failure in - if error.code == URLError.timedOut { - return T.Failure.timedOut - } - return T.Failure.urlSessionFailed(error) - } - .tryMap { (data: Data, response: URLResponse) -> T.Response in - do { - let body = try requestable.responseDecoder.decode(data) - if let response = response as? HTTPURLResponse { - let headers = response.allHeaderFields.reduce(RequestHeaders()) { headers, pair in - let (field, value) = pair - var next = headers - next[.init(String(describing: field))] = .init(String(describing: value)) - return next - } - return .init(url: response.url, headers: headers, statusCode: response.statusCode, body: body) - } - return .init(url: response.url, headers: nil, statusCode: nil, body: body) - } catch { - throw T.Failure.decodingError(error) - } - } - .handleEvents(receiveOutput: { [log] response in - guard let log = log else { return } - var message = "" - if let urlString = response.url?.absoluteString { - message.append("Received Response from \(urlString)") - } else { - message.append("Received Response") - } - if let code = response.statusCode { - message.append("\nStatus Code: \(code.description)") - } - if let headers = response.headers { - message.append("\nHeaders:\n\(headers.description)") - } - if let body = response.body { - message.append("\nBody:\n\(String(describing: body))") - } - os_log(.debug, log: log, "%@", message) - }) + RequestPublisher(session: session, + request: requestable, + host: host, + injectedHeaders: injectedHeaders, + backupAuthentication: authentication ?? requestAuthentication, + log: log) .validate(using: requestable.responseValidator) - .handleEvents(receiveSubscription: { [log] _ in - guard let log = log else { return } - var message = "Making Request" - if let urlString = request.url?.absoluteString { - message.append("\nURL: \(urlString)") - } - if let method = request.requestMethod { - message.append("\nMethod: \(method)") - } - if let headers = request.allHTTPHeaderFields { - message.append("\nHeaders:\n\(headers.description)") - } - if let body = requestable.body { - message.append("\nBody:\n\(String(describing: body))") - } - if let encodedBody = request.httpBody { - message.append("\nEncoded Body:\n\(String(describing: encodedBody))") - } - os_log(.debug, log: log, "%@", message) - }) .eraseToAnyPublisher() } @@ -370,10 +271,6 @@ open class RequestManager { return "tvOS" #elseif os(macOS) return "macOS" - #elseif os(Linux) - return "Linux" - #elseif os(Windows) - return "Windows" #else return "Unknown" #endif diff --git a/Sources/Ombi/Wrappers/AnyRequestable.swift b/Sources/Ombi/Wrappers/AnyRequestable.swift index 0b4bb8e..f7063d7 100644 --- a/Sources/Ombi/Wrappers/AnyRequestable.swift +++ b/Sources/Ombi/Wrappers/AnyRequestable.swift @@ -60,7 +60,7 @@ public struct AnyRequestable: Requesta public var authentication: RequestAuthentication? { authenticationClosure() } - public var fallbackResponse: Response? { fallbackResponseClosure() } + public var fallbackResponse: RequestResponse? { fallbackResponseClosure() } public var requestEncoder: BodyEncoder { requestEncoderClosure() } diff --git a/Tests/OmbiTests/RequestManagerTests.swift b/Tests/OmbiTests/RequestManagerTests.swift index 29cb4fe..d69606f 100644 --- a/Tests/OmbiTests/RequestManagerTests.swift +++ b/Tests/OmbiTests/RequestManagerTests.swift @@ -436,7 +436,7 @@ private class ResponsePublisherProvidingMock: ResponsePublisherProviding { self.delay = delay } - func publisher(for urlRequest: URLRequest) -> AnyPublisher<(data: Data, response: URLResponse), URLError> { + func requestPublisher(for urlRequest: URLRequest) -> AnyPublisher<(data: Data, response: URLResponse), URLError> { validateClosure(urlRequest) return subject .delay(for: .seconds(delay), scheduler: scheduler)