From b0a341c5fba9543f768a276f14a90e8c7ae7e7e1 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Thu, 7 Mar 2024 17:41:33 +0000 Subject: [PATCH] 2.x.x: Fix multi-value query values and headers for APIGatewayV2 (#24) * Fix multi-value query values in APIGatewayV2 * Add support for multivalue headers in APIGatewayV2 --- .../HummingbirdLambda/APIGatewayLambda.swift | 34 +++++++++++++++- .../APIGatewayV2Lambda.swift | 11 ++++- .../Request+APIGateway.swift | 40 +++---------------- 3 files changed, 48 insertions(+), 37 deletions(-) diff --git a/Sources/HummingbirdLambda/APIGatewayLambda.swift b/Sources/HummingbirdLambda/APIGatewayLambda.swift index d0fe979..e7e17e6 100644 --- a/Sources/HummingbirdLambda/APIGatewayLambda.swift +++ b/Sources/HummingbirdLambda/APIGatewayLambda.swift @@ -57,7 +57,39 @@ extension HBLambda where Output == APIGatewayResponse { } // conform `APIGatewayRequest` to `APIRequest` so we can use HBRequest.init(context:application:from) -extension APIGatewayRequest: APIRequest {} +extension APIGatewayRequest: APIRequest { + var queryString: String { + func urlPercentEncoded(_ string: String) -> String { + return string.addingPercentEncoding(withAllowedCharacters: .urlQueryComponentAllowed) ?? string + } + var queryParams: [String] = [] + var queryStringParameters = self.queryStringParameters ?? [:] + // go through list of multi value query string params first, removing any + // from the single value list if they are found in the multi value list + self.multiValueQueryStringParameters?.forEach { multiValueQuery in + queryStringParameters[multiValueQuery.key] = nil + queryParams += multiValueQuery.value.map { "\(urlPercentEncoded(multiValueQuery.key))=\(urlPercentEncoded($0))" } + } + queryParams += queryStringParameters.map { + "\(urlPercentEncoded($0.key))=\(urlPercentEncoded($0.value))" + } + return queryParams.joined(separator: "&") + } + + var httpHeaders: [(name: String, value: String)] { + var headerValues = [(name: String, value: String)].init() + var originalHeaders = self.headers + headerValues.reserveCapacity(headers.count) + for header in self.multiValueHeaders { + originalHeaders[header.key] = nil + for value in header.value { + headerValues.append((name: header.key, value: value)) + } + } + headerValues.append(contentsOf: originalHeaders.map { (name: $0.key, value: $0.value) }) + return headerValues + } +} // conform `APIGatewayResponse` to `APIResponse` so we can use HBResponse.apiReponse() extension APIGatewayResponse: APIResponse {} diff --git a/Sources/HummingbirdLambda/APIGatewayV2Lambda.swift b/Sources/HummingbirdLambda/APIGatewayV2Lambda.swift index 28047c0..17763e6 100644 --- a/Sources/HummingbirdLambda/APIGatewayV2Lambda.swift +++ b/Sources/HummingbirdLambda/APIGatewayV2Lambda.swift @@ -63,8 +63,15 @@ extension APIGatewayV2Request: APIRequest { } var httpMethod: AWSLambdaEvents.HTTPMethod { context.http.method } - var multiValueQueryStringParameters: [String: [String]]? { nil } - var multiValueHeaders: HTTPMultiValueHeaders { [:] } + var queryString: String { self.rawQueryString } + var httpHeaders: [(name: String, value: String)] { + self.headers.flatMap { header in + let headers = header.value + .split(separator: ",") + .map { (name: header.key, value: String($0.drop(while: \.isWhitespace))) } + return headers + } + } } // conform `APIGatewayV2Response` to `APIResponse` so we can use HBResponse.apiReponse() diff --git a/Sources/HummingbirdLambda/Request+APIGateway.swift b/Sources/HummingbirdLambda/Request+APIGateway.swift index 1f420da..d8c4a6e 100644 --- a/Sources/HummingbirdLambda/Request+APIGateway.swift +++ b/Sources/HummingbirdLambda/Request+APIGateway.swift @@ -23,10 +23,8 @@ import NIOCore protocol APIRequest { var path: String { get } var httpMethod: AWSLambdaEvents.HTTPMethod { get } - var queryStringParameters: [String: String]? { get } - var multiValueQueryStringParameters: [String: [String]]? { get } - var headers: AWSLambdaEvents.HTTPHeaders { get } - var multiValueHeaders: HTTPMultiValueHeaders { get } + var queryString: String { get } + var httpHeaders: [(name: String, value: String)] { get } var body: String? { get } var isBase64Encoded: Bool { get } } @@ -44,23 +42,12 @@ extension HBRequest { // construct URI with query parameters var uri = from.path - var queryParams: [String] = [] - var queryStringParameters = from.queryStringParameters ?? [:] - // go through list of multi value query string params first, removing any - // from the single value list if they are found in the multi value list - from.multiValueQueryStringParameters?.forEach { multiValueQuery in - queryStringParameters[multiValueQuery.key] = nil - queryParams += multiValueQuery.value.map { "\(urlPercentEncoded(multiValueQuery.key))=\(urlPercentEncoded($0))" } - } - queryParams += queryStringParameters.map { - "\(urlPercentEncoded($0.key))=\(urlPercentEncoded($0.value))" - } - if queryParams.count > 0 { - uri += "?\(queryParams.joined(separator: "&"))" + if from.queryString.count > 0 { + uri += "?\(from.queryString)" } // construct headers var authority: String? - let headers = HTTPFields(headers: from.headers, multiValueHeaders: from.multiValueHeaders, authority: &authority) + let headers = HTTPFields(headers: from.httpHeaders, authority: &authority) // get body let body: ByteBuffer? @@ -94,24 +81,10 @@ extension HTTPFields { /// - headers: headers /// - multiValueHeaders: multi-value headers /// - authority: reference to authority string - init(headers: AWSLambdaEvents.HTTPHeaders, multiValueHeaders: HTTPMultiValueHeaders, authority: inout String?) { + init(headers: [(name: String, value: String)], authority: inout String?) { self.init() self.reserveCapacity(headers.count) var firstHost = true - for (name, values) in multiValueHeaders { - if firstHost, name.lowercased() == "host" { - if let value = values.first { - firstHost = false - authority = value - continue - } - } - if let fieldName = HTTPField.Name(name) { - for value in values { - self.append(HTTPField(name: fieldName, value: value)) - } - } - } for (name, value) in headers { if firstHost, name.lowercased() == "host" { firstHost = false @@ -119,7 +92,6 @@ extension HTTPFields { continue } if let fieldName = HTTPField.Name(name) { - if self[fieldName] != nil { continue } self.append(HTTPField(name: fieldName, value: value)) } }