Skip to content

Commit

Permalink
Merge pull request #21 from ptoffy/main
Browse files Browse the repository at this point in the history
Add support for Swift 5.5 and async/await
  • Loading branch information
0xTim committed Dec 20, 2021
2 parents f8d4517 + 6b06aff commit bd4d565
Show file tree
Hide file tree
Showing 8 changed files with 145 additions and 209 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Expand Up @@ -9,14 +9,14 @@ on:
jobs:
xenial:
container:
image: swift:5.2-xenial
image: swift:5.5-xenial
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- run: swift test --enable-test-discovery --enable-code-coverage
bionic:
container:
image: swift:5.4-bionic
image: swift:5.5-bionic
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
Expand Down
4 changes: 2 additions & 2 deletions Package.swift
@@ -1,10 +1,10 @@
// swift-tools-version:5.2
// swift-tools-version:5.5
import PackageDescription

let package = Package(
name: "LeafErrorMiddleware",
platforms: [
.macOS(.v10_15),
.macOS(.v12),
],
products: [
.library(name: "LeafErrorMiddleware", targets: ["LeafErrorMiddleware"]),
Expand Down
87 changes: 43 additions & 44 deletions Sources/LeafErrorMiddleware/LeafErrorMiddleware.swift
@@ -1,71 +1,70 @@
import Vapor

/// Captures all errors and transforms them into an internal server error.
public final class LeafErrorMiddleware<T: Encodable>: Middleware {
public final class LeafErrorMiddleware<T: Encodable>: AsyncMiddleware {
let contextGenerator: (HTTPStatus, Error, Request) async throws -> T

let contextGenerator: ((HTTPStatus, Error, Request) -> EventLoopFuture<T>)

public init(contextGenerator: @escaping ((HTTPStatus, Error, Request) -> EventLoopFuture<T>)) {
public init(contextGenerator: @escaping ((HTTPStatus, Error, Request) async throws -> T)) {
self.contextGenerator = contextGenerator
}

/// See `Middleware.respond`
public func respond(to request: Request, chainingTo next: Responder) -> EventLoopFuture<Response> {
return next.respond(to: request).flatMap { res in
public func respond(to request: Request, chainingTo next: AsyncResponder) async throws -> Response {
do {
let res = try await next.respond(to: request)
if res.status.code >= HTTPResponseStatus.badRequest.code {
return self.handleError(for: request, status: res.status, error: Abort(res.status))
return try await handleError(for: request, status: res.status, error: Abort(res.status))
} else {
return res.encodeResponse(for: request)
return try await res.encodeResponse(for: request)
}
}.flatMapError { error in
} catch {
request.logger.report(error: error)
switch (error) {
case let abort as AbortError:
guard
abort.status.representsError
switch error {
case let abort as AbortError:
guard
abort.status.representsError
else {
if let location = abort.headers[.location].first {
return request.eventLoop.future(request.redirect(to: location))
return request.redirect(to: location)
} else {
return self.handleError(for: request, status: abort.status, error: error)
return try await handleError(for: request, status: abort.status, error: error)
}
}
return self.handleError(for: request, status: abort.status, error: error)
default:
return self.handleError(for: request, status: .internalServerError, error: error)
}
return try await handleError(for: request, status: abort.status, error: error)
default:
return try await handleError(for: request, status: .internalServerError, error: error)
}
}
}
private func handleError(for req: Request, status: HTTPStatus, error: Error) -> EventLoopFuture<Response> {

private func handleError(for request: Request, status: HTTPStatus, error: Error) async throws -> Response {
if status == .notFound {
return contextGenerator(status, error, req).flatMap { context in
return req.view.render("404", context).encodeResponse(for: req).map { res in
res.status = status
return res
}
}.flatMapError { newError in
return self.renderServerErrorPage(for: status, request: req, error: newError)
do {
let context = try await contextGenerator(status, error, request)
let res = try await request.view.render("404", context).encodeResponse(for: request).get()
res.status = status
return res
} catch {
return try await renderServerErrorPage(for: status, request: request, error: error)
}
}
return renderServerErrorPage(for: status, request: req, error: error)
return try await renderServerErrorPage(for: status, request: request, error: error)
}
private func renderServerErrorPage(for status: HTTPStatus, request: Request, error: Error) -> EventLoopFuture<Response> {
return contextGenerator(status, error, request).flatMap { context in
request.logger.error("Internal server error. Status: \(status.code) - path: \(request.url)")
return request.view.render("serverError", context).encodeResponse(for: request).map { res in
res.status = status
return res
}
}.flatMapError { error -> EventLoopFuture<Response> in

private func renderServerErrorPage(for status: HTTPStatus, request: Request, error: Error) async throws -> Response {
do {
let context = try await contextGenerator(status, error, request)
request.logger.error("Internal server error. Status: \(status.code) - path: \(request.url)")
let res = try await request.view.render("serverError", context).encodeResponse(for: request).get()
res.status = status
return res
} catch {
let body = "<h1>Internal Error</h1><p>There was an internal error. Please try again later.</p>"
request.logger.error("Failed to render custom error page - \(error)")
return body.encodeResponse(for: request).map { res in
res.status = status
res.headers.replaceOrAdd(name: .contentType, value: "text/html; charset=utf-8")
return res
}
let res = try await body.encodeResponse(for: request)
res.status = status
res.headers.replaceOrAdd(name: .contentType, value: "text/html; charset=utf-8")
return res
}
}
}
Expand Down
@@ -1,33 +1,15 @@
import Vapor

@available(*, deprecated, renamed: "LeafErrorMiddlewareDefaultGenerator")
public enum LeafErorrMiddlewareDefaultGenerator {
static func generate(_ status: HTTPStatus, _ error: Error, _ req: Request) -> EventLoopFuture<DefaultContext> {
let reason: String?
if let abortError = error as? AbortError {
reason = abortError.reason
} else {
reason = nil
}
let context = DefaultContext(status: status.code.description, statusMessage: status.reasonPhrase, reason: reason)
return req.eventLoop.future(context )
}

public static func build() -> LeafErrorMiddleware<DefaultContext> {
LeafErrorMiddleware(contextGenerator: generate)
}
}

public enum LeafErrorMiddlewareDefaultGenerator {
static func generate(_ status: HTTPStatus, _ error: Error, _ req: Request) -> EventLoopFuture<DefaultContext> {
static func generate(_ status: HTTPStatus, _ error: Error, _ req: Request) async throws -> DefaultContext {
let reason: String?
if let abortError = error as? AbortError {
reason = abortError.reason
} else {
reason = nil
}
let context = DefaultContext(status: status.code.description, statusMessage: status.reasonPhrase, reason: reason)
return req.eventLoop.future(context )
return context
}

public static func build() -> LeafErrorMiddleware<DefaultContext> {
Expand Down
97 changes: 37 additions & 60 deletions Tests/LeafErrorMiddlewareTests/CustomGeneratorTests.swift
@@ -1,84 +1,76 @@
import XCTest
import LeafErrorMiddleware
import Vapor
@testable import Logging
import Vapor
import XCTest

struct AContext: Encodable {
let trigger: Bool
}

class CustomGeneratorTests: XCTestCase {

// MARK: - Properties

var app: Application!
var viewRenderer: ThrowingViewRenderer!
var logger = CapturingLogger()
var eventLoopGroup: EventLoopGroup!

// MARK: - Overrides

override func setUpWithError() throws {
eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
viewRenderer = ThrowingViewRenderer(eventLoop: eventLoopGroup.next())
LoggingSystem.bootstrapInternal { _ in
return self.logger
self.logger
}
app = Application(.testing, .shared(eventLoopGroup))

app.views.use { _ in
return self.viewRenderer
self.viewRenderer
}

func routes(_ router: RoutesBuilder) throws {
router.get("ok") { _ in
"ok"
}

router.get("ok") { req in
return "ok"
router.get("404") { _ -> HTTPStatus in
.notFound
}

router.get("404") { req -> HTTPStatus in
return .notFound
router.get("403") { _ -> Response in
throw Abort(.forbidden)
}

router.get("serverError") { req -> EventLoopFuture<Response> in
router.get("serverError") { _ -> Response in
throw Abort(.internalServerError)
}

router.get("unknownError") { req -> EventLoopFuture<Response> in
router.get("unknownError") { _ -> Response in
throw TestError()
}

router.get("unauthorized") { req -> EventLoopFuture<Response> in
router.get("unauthorized") { _ -> Response in
throw Abort(.unauthorized)
}

router.get("future404") { req -> EventLoopFuture<Response> in
return req.eventLoop.future(error: Abort(.notFound))
}

router.get("future403") { req -> EventLoopFuture<Response> in
return req.eventLoop.future(error: Abort(.forbidden))
}

router.get("future303") { req -> EventLoopFuture<Response> in
return req.eventLoop.future(error: Abort.redirect(to: "ok"))
}

router.get("future404NoAbort") { req -> EventLoopFuture<HTTPStatus> in
return req.eventLoop.future(.notFound)
router.get("303") { _ -> Response in
throw Abort.redirect(to: "ok")
}

router.get("404withReason") { req -> HTTPStatus in
router.get("404withReason") { _ -> HTTPStatus in
throw Abort(.notFound, reason: "Could not find it")
}

router.get("500withReason") { req -> HTTPStatus in
router.get("500withReason") { _ -> HTTPStatus in
throw Abort(.badGateway, reason: "I messed up")
}
}

try routes(app)

let leafMiddleware = LeafErrorMiddleware() { status, error, req -> EventLoopFuture<AContext> in
return req.eventLoop.future(AContext(trigger: true))
let leafMiddleware = LeafErrorMiddleware { status, error, req async throws -> AContext in
AContext(trigger: true)
}
app.middleware.use(leafMiddleware)
}
Expand Down Expand Up @@ -139,36 +131,24 @@ class CustomGeneratorTests: XCTestCase {
XCTAssertEqual(viewRenderer.leafPath, "serverError")
}

func testNonAbort404IsCaughtCorrectly() throws {
let response = try app.getResponse(to: "/404")
XCTAssertEqual(response.status, .notFound)
XCTAssertEqual(viewRenderer.leafPath, "404")
}

func testThatFuture404IsCaughtCorrectly() throws {
let response = try app.getResponse(to: "/future404")
XCTAssertEqual(response.status, .notFound)
XCTAssertEqual(viewRenderer.leafPath, "404")
func testThatRedirectIsNotCaught() throws {
let response = try app.getResponse(to: "/303")
XCTAssertEqual(response.status, .seeOther)
XCTAssertEqual(response.headers[.location].first, "ok")
}

func testFutureNonAbort404IsCaughtCorrectly() throws {
let response = try app.getResponse(to: "/future404NoAbort")
func testNonAbort404IsCaughtCorrectly() throws {
let response = try app.getResponse(to: "/404")
XCTAssertEqual(response.status, .notFound)
XCTAssertEqual(viewRenderer.leafPath, "404")
}

func testThatFuture403IsCaughtCorrectly() throws {
let response = try app.getResponse(to: "/future403")
func testThat403IsCaughtCorrectly() throws {
let response = try app.getResponse(to: "/403")
XCTAssertEqual(response.status, .forbidden)
XCTAssertEqual(viewRenderer.leafPath, "serverError")
}

func testThatRedirectIsNotCaught() throws {
let response = try app.getResponse(to: "/future303")
XCTAssertEqual(response.status, .seeOther)
XCTAssertEqual(response.headers[.location].first, "ok")
}

func testContextGeneratedOn404Page() throws {
let response = try app.getResponse(to: "/404")
XCTAssertEqual(response.status, .notFound)
Expand All @@ -189,20 +169,19 @@ class CustomGeneratorTests: XCTestCase {
app.shutdown()
app = Application(.testing, .shared(eventLoopGroup))
app.views.use { _ in
return self.viewRenderer
self.viewRenderer
}
let leafErrorMiddleware = LeafErrorMiddleware() { status, error, req -> EventLoopFuture<AContext> in
return req.eventLoop.future(error: Abort(.internalServerError))

let leafErrorMiddleware = LeafErrorMiddleware { _, _, _ -> AContext in
throw Abort(.internalServerError)
}
app.middleware = .init()
app.middleware.use(leafErrorMiddleware)

app.get("404") { req -> EventLoopFuture<Response> in
req.eventLoop.makeFailedFuture(Abort(.notFound))
app.get("404") { _ async throws -> Response in
throw Abort(.notFound)
}
app.get("500") { req -> EventLoopFuture<Response> in
req.eventLoop.makeFailedFuture(Abort(.internalServerError))
app.get("500") { _ async throws -> Response in
throw Abort(.internalServerError)
}

let response404 = try app.getResponse(to: "404")
Expand All @@ -213,6 +192,4 @@ class CustomGeneratorTests: XCTestCase {
XCTAssertEqual(response500.status, .internalServerError)
XCTAssertNil(viewRenderer.leafPath)
}


}

0 comments on commit bd4d565

Please sign in to comment.