Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/typed queries #375

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 13 additions & 1 deletion Package.swift
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// swift-tools-version:5.6
// swift-tools-version:5.9

import PackageDescription
import CompilerPluginSupport

let package = Package(
name: "postgres-nio",
Expand All @@ -20,6 +22,7 @@ let package = Package(
.package(url: "https://github.com/apple/swift-crypto.git", "1.0.0" ..< "3.0.0"),
.package(url: "https://github.com/apple/swift-metrics.git", from: "2.0.0"),
.package(url: "https://github.com/apple/swift-log.git", from: "1.5.2"),
.package(url: "https://github.com/apple/swift-syntax.git", branch: "main")
],
targets: [
.target(
Expand All @@ -36,6 +39,14 @@ let package = Package(
.product(name: "NIOTLS", package: "swift-nio"),
.product(name: "NIOSSL", package: "swift-nio-ssl"),
.product(name: "NIOFoundationCompat", package: "swift-nio"),
.target(name: "PostgresNIOMacros")
]
),
.macro(
name: "PostgresNIOMacros",
dependencies: [
.product(name: "SwiftSyntaxMacros", package: "swift-syntax"),
.product(name: "SwiftCompilerPlugin", package: "swift-syntax")
]
),
.testTarget(
Expand All @@ -44,6 +55,7 @@ let package = Package(
.target(name: "PostgresNIO"),
.product(name: "NIOEmbedded", package: "swift-nio"),
.product(name: "NIOTestUtils", package: "swift-nio"),
.product(name: "SwiftSyntaxMacrosTestSupport", package: "swift-syntax"),
]
),
.testTarget(
Expand Down
19 changes: 19 additions & 0 deletions Sources/PostgresNIO/Connection/PostgresConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,25 @@ extension PostgresConnection {
throw error // rethrow with more metadata
}
}

/// Run a query on the Postgres server the connection is connected to.
///
/// - Parameters:
/// - query: A ``PostgresTypedQuery`` to run
/// - logger: The `Logger` to log into for the query
/// - file: The file, the query was started in. Used for better error reporting.
/// - line: The line, the query was started in. Used for better error reporting.
/// - Returns: A ``PostgresTypedSequence`` containing typed rows the server sent as the query result.
@discardableResult
public func query<T: PostgresTypedQuery>(
_ query: T,
logger: Logger,
file: String = #fileID,
line: Int = #line
) async throws -> PostgresTypedSequence<T.Row> {
let rowSequence = try await self.query(query.sql, logger: logger, file: file, line: line)
return PostgresTypedSequence(rowSequence: rowSequence)
}
}

// MARK: EventLoopFuture interface
Expand Down
4 changes: 4 additions & 0 deletions Sources/PostgresNIO/Data/PostgresRow.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ public struct PostgresRow: Sendable {
}
}

public protocol PostgresTypedRow {
init(from row: PostgresRow) throws
}

extension PostgresRow: Equatable {
public static func ==(lhs: Self, rhs: Self) -> Bool {
// we don't need to compare the lookup table here, as the looup table is only derived
Expand Down
37 changes: 37 additions & 0 deletions Sources/PostgresNIO/New/PostgresQuery.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,43 @@ public struct PostgresQuery: Sendable, Hashable {
}
}

struct PostgresMacroQuery: ExpressibleByStringInterpolation {
var sql: String

public init(stringInterpolation: StringInterpolation) {
sql = stringInterpolation.sql
}

public init(stringLiteral value: String) {
sql = value
}

struct StringInterpolation: StringInterpolationProtocol {
typealias StringLiteralType = String

var sql: String

init(literalCapacity: Int, interpolationCount: Int) {
sql = ""
}

mutating func appendLiteral(_ literal: String) {
sql.append(contentsOf: literal)
}

mutating func appendInterpolation<T: PostgresDecodable>(_ sql: String, type: T.Type) {}
}
}

@Query("SELECT \("id", type: Int.self) FROM users")
struct GetAllUsersQuery {}

public protocol PostgresTypedQuery {
associatedtype Row: PostgresTypedRow

var sql: PostgresQuery { get }
}

extension PostgresQuery: ExpressibleByStringInterpolation {
public init(stringInterpolation: StringInterpolation) {
self.sql = stringInterpolation.sql
Expand Down
29 changes: 28 additions & 1 deletion Sources/PostgresNIO/New/PostgresRowSequence.swift
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ extension PostgresRowSequence {
self.columns = columns
}

public mutating func next() async throws -> PostgresRow? {
public func next() async throws -> PostgresRow? {
if let dataRow = try await self.backing.next() {
return PostgresRow(
data: dataRow,
Expand All @@ -56,6 +56,33 @@ extension PostgresRowSequence {
}
}

public struct PostgresTypedSequence<T: PostgresTypedRow>: AsyncSequence {
public typealias Element = T

let rowSequence: PostgresRowSequence

init(rowSequence: PostgresRowSequence) {
self.rowSequence = rowSequence
}

public func makeAsyncIterator() -> AsyncIterator {
AsyncIterator(rowSequence: rowSequence.makeAsyncIterator())
}
}

extension PostgresTypedSequence {
public struct AsyncIterator: AsyncIteratorProtocol {
let rowSequence: PostgresRowSequence.AsyncIterator

public func next() async throws -> T? {
guard let row = try await self.rowSequence.next() else {
return nil
}
return try T.init(from: row)
}
}
}

extension PostgresRowSequence {
public func collect() async throws -> [PostgresRow] {
var result = [PostgresRow]()
Expand Down
5 changes: 5 additions & 0 deletions Sources/PostgresNIO/Utilities/Macros.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
@attached(member)
macro Query(_ query: PostgresMacroQuery) = #externalMacro(
module: "PostgresNIOMacros",
type: "PostgresTypedQueryMacro"
)
21 changes: 21 additions & 0 deletions Sources/PostgresNIOMacros/PostgresNIODiagnostic.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import SwiftDiagnostics

enum PostgresNIODiagnostic: String, DiagnosticMessage {
case wrongArgument
case notAStruct

var message: String {
switch self {
case .wrongArgument:
return "Invalid argument"
case .notAStruct:
return "Macro only works with structs"
}
}

var diagnosticID: SwiftDiagnostics.MessageID {
MessageID(domain: "PostgresNIOMacros", id: rawValue)
}

var severity: SwiftDiagnostics.DiagnosticSeverity { .error }
}
76 changes: 76 additions & 0 deletions Sources/PostgresNIOMacros/PostgresNIOMacro.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import SwiftCompilerPlugin
import SwiftSyntax
import SwiftSyntaxBuilder
import SwiftSyntaxMacros
import SwiftDiagnostics

public struct PostgresTypedQueryMacro: MemberMacro {
public static func expansion(
of node: AttributeSyntax,
providingMembersOf declaration: some DeclGroupSyntax,
in context: some MacroExpansionContext
) throws -> [DeclSyntax] {
guard declaration.is(StructDeclSyntax.self) else {
context.diagnose(Diagnostic(node: Syntax(node), message: PostgresNIODiagnostic.notAStruct))
return []
}

guard let elements = node.argument?.as(TupleExprElementListSyntax.self)?
.first?.as(TupleExprElementSyntax.self)?
.expression.as(StringLiteralExprSyntax.self)?.segments else {
// TODO: Be more specific about this error
context.diagnose(Diagnostic(node: Syntax(node), message: PostgresNIODiagnostic.wrongArgument))
return []
}



var outputTypes: [(String, String)] = []
for tup in elements {
if let expression = tup.as(ExpressionSegmentSyntax.self) {
outputTypes.append(extractColumnTypes(expression))
}
}

let rowStruct = try StructDeclSyntax("struct Row") {
for outputType in outputTypes {
MemberDeclListItemSyntax.init(decl: DeclSyntax(stringLiteral: "let \(outputType.0): \(outputType.1)"))
}
try InitializerDeclSyntax("init(from: PostgresRow) throws") {
// TODO: (id, name) = try row.decode((Int, String).self, context: .default)
}
}

return [
// DeclSyntax(rowStruct)
]
}

/// Returns ("name", "String")
private static func extractColumnTypes(_ node: ExpressionSegmentSyntax) -> (String, String) {
let tupleElements = node.expressions
guard tupleElements.count == 2 else {
fatalError("Expected tuple with exactly two elements")
}

// First element needs to be the column name
var iterator = tupleElements.makeIterator()
guard let columnName = iterator.next()?.expression.as(StringLiteralExprSyntax.self)?
.segments.first?.as(StringSegmentSyntax.self)?.content
.text else {
fatalError("Expected column name")
}

guard let columnType = iterator.next()?.expression.as(MemberAccessExprSyntax.self)?.base?.as(IdentifierExprSyntax.self)?.identifier.text else {
fatalError("Expected column type")
}
return (columnName, columnType)
}
}

@main
struct PostgresNIOMacros: CompilerPlugin {
let providingMacros: [Macro.Type] = [
PostgresTypedQueryMacro.self
]
}
47 changes: 47 additions & 0 deletions Tests/IntegrationTests/TypedQueriesTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import Logging
import XCTest
import PostgresNIO

final class TypedQueriesTests: XCTestCase {
func testTypedPostgresQuery() async throws {
struct MyQuery: PostgresTypedQuery {
struct Row: PostgresTypedRow {
let id: Int
let name: String

init(from row: PostgresRow) throws {
(id, name) = try row.decode((Int, String).self, context: .default)
}
}

var sql: PostgresQuery {
"SELECT id, name FROM users"
}
}

let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
let eventLoop = eventLoopGroup.next()

try await withTestConnection(on: eventLoop) { connection in
let createTableQuery = PostgresQuery(unsafeSQL: """
CREATE TABLE users (
id SERIAL PRIMARY KEY,
name character varying(255) NOT NULL
);
""")
let name = "foobar"

try await connection.query(createTableQuery, logger: .psqlTest)
try await connection.query("INSERT INTO users (name) VALUES (\(name));", logger: .psqlTest)

let rows = try await connection.query(MyQuery(), logger: .psqlTest)
for try await row in rows {
XCTAssertEqual(row.name, name)
}

let dropQuery = PostgresQuery(unsafeSQL: "DROP TABLE users")
try await connection.query(dropQuery, logger: .psqlTest)
}
}
}
45 changes: 45 additions & 0 deletions Tests/PostgresNIOTests/Macros/MacroTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import SwiftSyntaxMacros
import SwiftSyntaxMacrosTestSupport
import XCTest
import PostgresNIOMacros

let testMacros: [String: Macro.Type] = [
"Query": PostgresTypedQueryMacro.self,
]

final class MacrosTests: XCTestCase {
func testMacro() {
assertMacroExpansion(
#"""
@Query("SELECT \("id", Int.self), \("name", String.self) FROM users")
struct MyQuery {}
"""#,
expandedSource: #"""
struct MyQuery {
struct Row: PostgresTypedRow {
let id: Int
let name: String
}
}
"""#,
// expandedSource: #"""
// struct MyQuery: PostgresTypedQuery {
// struct Row: PostgresTypedRow {
// let id: Int
// let name: String
//
// init(from row: PostgresRow) throws {
// (id, name) = try row.decode((Int, String).self, context: .default)
// }
// }
//
// var sql: PostgresQuery {
// "SELECT id, name FROM users"
// }
// }
// """#,
macros: testMacros
)
}
}