Skip to content

Commit

Permalink
Merge pull request #23 from dannflor/customErrors
Browse files Browse the repository at this point in the history
Add custom mapping from errors to views
  • Loading branch information
0xTim committed Jul 2, 2022
2 parents 2f707d0 + 16f05a1 commit 272cc7b
Show file tree
Hide file tree
Showing 5 changed files with 344 additions and 11 deletions.
27 changes: 22 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
<br>
<br>
<a href="https://swift.org">
<img src="http://img.shields.io/badge/Swift-5.2-brightgreen.svg" alt="Language">
<img src="http://img.shields.io/badge/Swift-5.6-brightgreen.svg" alt="Language">
</a>
<a href="https://github.com/brokenhandsio/leaf-error-middleware/actions">
<img src="https://github.com/brokenhandsio/leaf-error-middleware/workflows/CI/badge.svg?branch=main" alt="Build Status">
Expand Down Expand Up @@ -71,14 +71,31 @@ The closure receives three parameters:
* `Error` - the error caught to be handled.
* `Request` - the request currently being handled. This can be used to log information, make external API calls or check the session.

# Setting Up
## Custom Mappings

You need to include two [Leaf](https://github.com/vapor/leaf) templates in your application:
By default, you need to include two [Leaf](https://github.com/vapor/leaf) templates in your application:

* `404.leaf`
* `serverError.leaf`

When Leaf Error Middleware catches a 404 error, it will return the `404.leaf` template. Any other error caught will return the `serverError.leaf` template.
However, you may elect to provide a dictionary mapping arbitrary error responses (i.e >= 400) to custom template names, like so:

```swift
let mappings: [HTTPStatus: String] = [
.notFound: "404",
.unauthorized: "401",
.forbidden: "403"
]
let leafMiddleware = LeafErrorMiddleware(errorMappings: mappings) { status, error, req async throws -> SomeContext in
SomeContext()
}

app.middleware.use(leafMiddleware)
// OR
app.middleware.use(LeafErrorMiddlewareDefaultGenerator.build(errorMappings: mapping))
```

By default, when Leaf Error Middleware catches a 404 error, it will return the `404.leaf` template. This particular mapping also allows returning a `401.leaf` or `403.leaf` template based on the error. Any other error caught will return the `serverError.leaf` template. By providing a mapping, you override the default 404 template and will need to respecify it if you want to use it.

## Default Context

Expand All @@ -88,4 +105,4 @@ If using the default context, the `serverError.leaf` template will be passed up
* `statusMessage` - a reason for the status code
* `reason` - the reason for the error, if known. Otherwise this won't be passed in.

The `404.leaf` template will get a `reason` parameter in the context if one is known.
The `404.leaf` template and any other custom error templates will get a `reason` parameter in the context if one is known.
12 changes: 8 additions & 4 deletions Sources/LeafErrorMiddleware/LeafErrorMiddleware.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@ import Vapor
/// Captures all errors and transforms them into an internal server error.
public final class LeafErrorMiddleware<T: Encodable>: AsyncMiddleware {
let contextGenerator: (HTTPStatus, Error, Request) async throws -> T

public init(contextGenerator: @escaping ((HTTPStatus, Error, Request) async throws -> T)) {
let errorMappings: [HTTPStatus: String]

/// Accepts an optional mapping of error statuses to template names for more granular error page templates
public init(errorMappings: [HTTPStatus: String] = [.notFound: "404"],
contextGenerator: @escaping ((HTTPStatus, Error, Request) async throws -> T)) {
self.contextGenerator = contextGenerator
self.errorMappings = errorMappings
}

/// See `Middleware.respond`
Expand Down Expand Up @@ -38,10 +42,10 @@ public final class LeafErrorMiddleware<T: Encodable>: AsyncMiddleware {
}

private func handleError(for request: Request, status: HTTPStatus, error: Error) async throws -> Response {
if status == .notFound {
if let viewMapping = errorMappings[status] {
do {
let context = try await contextGenerator(status, error, request)
let res = try await request.view.render("404", context).encodeResponse(for: request).get()
let res = try await request.view.render(viewMapping, context).encodeResponse(for: request).get()
res.status = status
return res
} catch {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@ public enum LeafErrorMiddlewareDefaultGenerator {
return context
}

public static func build() -> LeafErrorMiddleware<DefaultContext> {
LeafErrorMiddleware(contextGenerator: generate)
public static func build(errorMappings: [HTTPStatus: String]? = nil) -> LeafErrorMiddleware<DefaultContext> {
if let errorMappings = errorMappings {
return LeafErrorMiddleware(errorMappings: errorMappings, contextGenerator: generate)
}
else {
return LeafErrorMiddleware(contextGenerator: generate)
}
}
}
155 changes: 155 additions & 0 deletions Tests/LeafErrorMiddlewareTests/CustomMappingCustomGeneratorTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import LeafErrorMiddleware
@testable import Logging
import Vapor
import XCTest

class CustomMappingCustomGeneratorTests: XCTestCase {
// MARK: - Properties

var app: Application!
var viewRenderer: ThrowingViewRenderer!
var logger = CapturingLogger()
var eventLoopGroup: EventLoopGroup!

// MARK: - Overrides

override func setUpWithError() throws {
eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
viewRenderer = ThrowingViewRenderer(eventLoop: eventLoopGroup.next())
LoggingSystem.bootstrapInternal { _ in
self.logger
}
app = Application(.testing, .shared(eventLoopGroup))

app.views.use { _ in
self.viewRenderer
}

func routes(_ router: RoutesBuilder) throws {
router.get("ok") { _ in
"ok"
}

router.get("404") { _ -> HTTPStatus in
.notFound
}

router.get("403") { _ -> Response in
throw Abort(.forbidden)
}

router.get("serverError") { _ -> Response in
throw Abort(.internalServerError)
}

router.get("unknownError") { _ -> Response in
throw TestError()
}

router.get("unauthorized") { _ -> Response in
throw Abort(.unauthorized)
}

router.get("303") { _ -> Response in
throw Abort.redirect(to: "ok")
}

router.get("404withReason") { _ -> HTTPStatus in
throw Abort(.notFound, reason: "Could not find it")
}

router.get("500withReason") { _ -> HTTPStatus in
throw Abort(.badGateway, reason: "I messed up")
}
}

try routes(app)
let mappings: [HTTPStatus: String] = [
.notFound: "404",
.unauthorized: "401",
.forbidden: "403",
// Verify that non-error mappings are ignored
.seeOther: "303"
]
let leafMiddleware = LeafErrorMiddleware(errorMappings: mappings) { status, error, req async throws -> AContext in
AContext(trigger: true)
}
app.middleware.use(leafMiddleware)
}

override func tearDownWithError() throws {
app.shutdown()
try eventLoopGroup.syncShutdownGracefully()
}

// MARK: - Tests

func testThatValidEndpointWorks() throws {
let response = try app.getResponse(to: "/ok")
XCTAssertEqual(response.status, .ok)
}

func testThatRedirectIsNotCaught() throws {
let response = try app.getResponse(to: "/303")
XCTAssertEqual(response.status, .seeOther)
XCTAssertEqual(response.headers[.location].first, "ok")
XCTAssertNotEqual(viewRenderer.leafPath, "303")
}

func testNonAbort404IsCaughtCorrectly() throws {
let response = try app.getResponse(to: "/404")
XCTAssertEqual(response.status, .notFound)
XCTAssertEqual(viewRenderer.leafPath, "404")
}

func testThat403IsCaughtCorrectly() throws {
let response = try app.getResponse(to: "/403")
XCTAssertEqual(response.status, .forbidden)
XCTAssertEqual(viewRenderer.leafPath, "403")
let context = try XCTUnwrap(viewRenderer.capturedContext as? AContext)
XCTAssertTrue(context.trigger)
}

func testContextGeneratedOn401Page() throws {
let response = try app.getResponse(to: "/unauthorized")
XCTAssertEqual(response.status, .unauthorized)
XCTAssertEqual(viewRenderer.leafPath, "401")
let context = try XCTUnwrap(viewRenderer.capturedContext as? AContext)
XCTAssertTrue(context.trigger)
}
func testContextGeneratedOn500Page() throws {
let response = try app.getResponse(to: "/serverError")
XCTAssertEqual(response.status, .internalServerError)
XCTAssertEqual(viewRenderer.leafPath, "serverError")
let context = try XCTUnwrap(viewRenderer.capturedContext as? AContext)
XCTAssertTrue(context.trigger)
}

func testGetAResponseWhenGeneratorThrows() throws {
app.shutdown()
app = Application(.testing, .shared(eventLoopGroup))
app.views.use { _ in
self.viewRenderer
}
let leafErrorMiddleware = LeafErrorMiddleware { _, _, _ -> AContext in
throw Abort(.internalServerError)
}
app.middleware = .init()
app.middleware.use(leafErrorMiddleware)

app.get("404") { _ async throws -> Response in
throw Abort(.notFound)
}
app.get("500") { _ async throws -> Response in
throw Abort(.internalServerError)
}

let response404 = try app.getResponse(to: "404")
XCTAssertEqual(response404.status, .notFound)
XCTAssertNil(viewRenderer.leafPath)

let response500 = try app.getResponse(to: "500")
XCTAssertEqual(response500.status, .internalServerError)
XCTAssertNil(viewRenderer.leafPath)
}
}

0 comments on commit 272cc7b

Please sign in to comment.