From 593a3ccbccfc9404761897af632ab9155248d524 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Tue, 5 Mar 2024 16:35:27 +0000 Subject: [PATCH] Add support for multivalue headers in APIGatewayV2 --- .../HummingbirdLambda/APIGatewayLambda.swift | 14 ++++++++++++ .../APIGatewayV2Lambda.swift | 10 +++++++-- .../Request+APIGateway.swift | 22 +++---------------- 3 files changed, 25 insertions(+), 21 deletions(-) diff --git a/Sources/HummingbirdLambda/APIGatewayLambda.swift b/Sources/HummingbirdLambda/APIGatewayLambda.swift index 5d09191..e7e17e6 100644 --- a/Sources/HummingbirdLambda/APIGatewayLambda.swift +++ b/Sources/HummingbirdLambda/APIGatewayLambda.swift @@ -75,6 +75,20 @@ extension APIGatewayRequest: APIRequest { } return queryParams.joined(separator: "&") } + + var httpHeaders: [(name: String, value: String)] { + var headerValues = [(name: String, value: String)].init() + var originalHeaders = self.headers + headerValues.reserveCapacity(headers.count) + for header in self.multiValueHeaders { + originalHeaders[header.key] = nil + for value in header.value { + headerValues.append((name: header.key, value: value)) + } + } + headerValues.append(contentsOf: originalHeaders.map { (name: $0.key, value: $0.value) }) + return headerValues + } } // conform `APIGatewayResponse` to `APIResponse` so we can use HBResponse.apiReponse() diff --git a/Sources/HummingbirdLambda/APIGatewayV2Lambda.swift b/Sources/HummingbirdLambda/APIGatewayV2Lambda.swift index 5995976..17763e6 100644 --- a/Sources/HummingbirdLambda/APIGatewayV2Lambda.swift +++ b/Sources/HummingbirdLambda/APIGatewayV2Lambda.swift @@ -63,9 +63,15 @@ extension APIGatewayV2Request: APIRequest { } var httpMethod: AWSLambdaEvents.HTTPMethod { context.http.method } - var multiValueQueryStringParameters: [String: [String]]? { nil } - var multiValueHeaders: HTTPMultiValueHeaders { [:] } var queryString: String { self.rawQueryString } + var httpHeaders: [(name: String, value: String)] { + self.headers.flatMap { header in + let headers = header.value + .split(separator: ",") + .map { (name: header.key, value: String($0.drop(while: \.isWhitespace))) } + return headers + } + } } // conform `APIGatewayV2Response` to `APIResponse` so we can use HBResponse.apiReponse() diff --git a/Sources/HummingbirdLambda/Request+APIGateway.swift b/Sources/HummingbirdLambda/Request+APIGateway.swift index 0f00851..d8c4a6e 100644 --- a/Sources/HummingbirdLambda/Request+APIGateway.swift +++ b/Sources/HummingbirdLambda/Request+APIGateway.swift @@ -24,8 +24,7 @@ protocol APIRequest { var path: String { get } var httpMethod: AWSLambdaEvents.HTTPMethod { get } var queryString: String { get } - var headers: AWSLambdaEvents.HTTPHeaders { get } - var multiValueHeaders: HTTPMultiValueHeaders { get } + var httpHeaders: [(name: String, value: String)] { get } var body: String? { get } var isBase64Encoded: Bool { get } } @@ -48,7 +47,7 @@ extension HBRequest { } // construct headers var authority: String? - let headers = HTTPFields(headers: from.headers, multiValueHeaders: from.multiValueHeaders, authority: &authority) + let headers = HTTPFields(headers: from.httpHeaders, authority: &authority) // get body let body: ByteBuffer? @@ -82,24 +81,10 @@ extension HTTPFields { /// - headers: headers /// - multiValueHeaders: multi-value headers /// - authority: reference to authority string - init(headers: AWSLambdaEvents.HTTPHeaders, multiValueHeaders: HTTPMultiValueHeaders, authority: inout String?) { + init(headers: [(name: String, value: String)], authority: inout String?) { self.init() self.reserveCapacity(headers.count) var firstHost = true - for (name, values) in multiValueHeaders { - if firstHost, name.lowercased() == "host" { - if let value = values.first { - firstHost = false - authority = value - continue - } - } - if let fieldName = HTTPField.Name(name) { - for value in values { - self.append(HTTPField(name: fieldName, value: value)) - } - } - } for (name, value) in headers { if firstHost, name.lowercased() == "host" { firstHost = false @@ -107,7 +92,6 @@ extension HTTPFields { continue } if let fieldName = HTTPField.Name(name) { - if self[fieldName] != nil { continue } self.append(HTTPField(name: fieldName, value: value)) } }