Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migration groups #6

Merged
merged 5 commits into from Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
65 changes: 65 additions & 0 deletions Sources/HummingbirdPostgres/Migration.swift
@@ -0,0 +1,65 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2024 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import Logging
@_spi(ConnectionPool) import PostgresNIO

/// Protocol for a database migration
///
/// Requires two functions one to apply the database migration and one to revert it.
public protocol HBPostgresMigration {
/// Apply database migration
func apply(connection: PostgresConnection, logger: Logger) async throws
/// Revert database migration
func revert(connection: PostgresConnection, logger: Logger) async throws
/// Migration name
var name: String { get }
/// Group migration belongs to
var group: HBMigrationGroup { get }
}

extension HBPostgresMigration {
/// Default implementaion of name
public var name: String { String(describing: Self.self) }
/// Default group is default
public var group: HBMigrationGroup { .default }
}

/// Group identifier for a group of migrations.
///
/// Migrations in one group are treated independently of migrations in other groups. You can add a
/// migration to a group and it will not affect any subsequent migrations not in that group. By default
/// all migrations belong to the ``HBMigrationGroup.default`` group.
///
/// To add a migration to a separate group you first need to define the group by adding a static variable
/// to `HBMigrationGroup`.
/// ```
/// extension HBMigrationGroup {
/// public static var `myGroup`: Self { .init("myGroup") }
/// }
/// ```
/// After that use to ``HBPostgresMigration.group`` set the group for a migration.
///
/// Only use a group different from `.default` if you are certain that the database elements you are
/// creating within that group will always be independent of everything else in the database. Groups
/// are useful for libraries that use migrations to setup their database elements.
public struct HBMigrationGroup: Hashable, Equatable {
let name: String

public init(_ name: String) {
self.name = name
}

public static var `default`: Self { .init("_hb_default") }
}
33 changes: 33 additions & 0 deletions Sources/HummingbirdPostgres/MigrationError.swift
@@ -0,0 +1,33 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2024 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

/// Error thrown by migration code
public struct HBPostgresMigrationError: Error, Equatable {
enum _Internal {
case requiresChanges
case cannotRevertMigration
}

private let value: _Internal

private init(_ value: _Internal) {
self.value = value
}

/// The database requires a migration before the application can run
static var requiresChanges: Self { .init(.requiresChanges) }
/// Cannot revert a migration as we do not have its details. Add it to the revert list using
/// HBPostgresMigrations.add(revert:)
static var cannotRevertMigration: Self { .init(.cannotRevertMigration) }
}
170 changes: 85 additions & 85 deletions Sources/HummingbirdPostgres/Migrations.swift
Expand Up @@ -15,40 +15,6 @@
import Logging
@_spi(ConnectionPool) import PostgresNIO

/// Protocol for a database migration
///
/// Requires two functions one to apply the database migration and one to revert it.
public protocol HBPostgresMigration {
/// Apply database migration
func apply(connection: PostgresConnection, logger: Logger) async throws
/// Revert database migration
func revert(connection: PostgresConnection, logger: Logger) async throws
}

extension HBPostgresMigration {
public var name: String { String(describing: Self.self) }
}

/// Error thrown by migration code
public struct HBPostgresMigrationError: Error, Equatable {
enum _Internal {
case requiresChanges
case cannotRevertMigration
}

private let value: _Internal

private init(_ value: _Internal) {
self.value = value
}

/// The database requires a migration before the application can run
static var requiresChanges: Self { .init(.requiresChanges) }
/// Cannot revert a migration as we do not have its details. Add it to the revert list using
/// HBPostgresMigrations.add(revert:)
static var cannotRevertMigration: Self { .init(.cannotRevertMigration) }
}

/// Database migration support
public final class HBPostgresMigrations {
var migrations: [HBPostgresMigration]
Expand All @@ -74,60 +40,79 @@ public final class HBPostgresMigrations {
/// Apply database migrations
@_spi(ConnectionPool)
@MainActor
public func apply(client: PostgresClient, logger: Logger, dryRun: Bool) async throws {
try await self.migrate(client: client, migrations: self.migrations, logger: logger, dryRun: dryRun)
public func apply(client: PostgresClient, groups: [HBMigrationGroup] = [], logger: Logger, dryRun: Bool) async throws {
try await self.migrate(client: client, migrations: self.migrations, groups: groups, logger: logger, dryRun: dryRun)
}

@_spi(ConnectionPool)
@MainActor
public func revert(client: PostgresClient, logger: Logger, dryRun: Bool) async throws {
try await self.migrate(client: client, migrations: [], logger: logger, dryRun: dryRun)
public func revert(client: PostgresClient, groups: [HBMigrationGroup] = [], logger: Logger, dryRun: Bool) async throws {
try await self.migrate(client: client, migrations: [], groups: groups, logger: logger, dryRun: dryRun)
}

private func migrate(client: PostgresClient, migrations: [HBPostgresMigration], logger: Logger, dryRun: Bool) async throws {
private func migrate(
client: PostgresClient,
migrations: [HBPostgresMigration],
groups: [HBMigrationGroup],
logger: Logger,
dryRun: Bool
) async throws {
let repository = HBPostgresMigrationRepository(client: client)
_ = try await repository.withContext(logger: logger) { context in
// setup migration repository (create table)
try await repository.setup(context: context)
var requiresChanges = false
// get migrations currently applied in the order they were applied
let appliedMigrations = try await repository.getAll(context: context)
let minMigrationCount = min(migrations.count, appliedMigrations.count)
var i = 0
while i < minMigrationCount, appliedMigrations[i] == migrations[i].name {
i += 1
}
// Revert deleted migrations, and any migrations after a deleted migration
for j in (i..<appliedMigrations.count).reversed() {
let migrationName = appliedMigrations[j]
// look for migration to revert in migration list and revert dictionary. NB we are looking in the migration
// array belonging to the type, not the one supplied to the function
guard let migration = self.migrations.first(where: { $0.name == migrationName }) ?? self.reverts[migrationName] else {
throw HBPostgresMigrationError.cannotRevertMigration
do {
_ = try await repository.withContext(logger: logger) { context in
// setup migration repository (create table)
try await repository.setup(context: context)
var requiresChanges = false
// get migrations currently applied in the order they were applied
let appliedMigrations = try await repository.getAll(context: context)
// if groups array passed in is empty then work out list of migration groups by combining
// list of groups from migrations and applied migrations
let groups = groups.count == 0
? (migrations.map(\.group) + appliedMigrations.map(\.group)).uniqueElements
: groups
// for each group apply/revert migrations
for group in groups {
let groupMigrations = migrations.filter { $0.group == group }
let appliedGroupMigrations = appliedMigrations.filter { $0.group == group }

let minMigrationCount = min(groupMigrations.count, appliedGroupMigrations.count)
var i = 0
while i < minMigrationCount, appliedGroupMigrations[i].name == groupMigrations[i].name {
i += 1
}
// Revert deleted migrations, and any migrations after a deleted migration
for j in (i..<appliedGroupMigrations.count).reversed() {
let migrationName = appliedGroupMigrations[j].name
// look for migration to revert in migration list and revert dictionary. NB we are looking in the migration
// array belonging to the type, not the one supplied to the function
guard let migration = self.migrations.first(where: { $0.name == migrationName }) ?? self.reverts[migrationName] else {
throw HBPostgresMigrationError.cannotRevertMigration
}
logger.info("Reverting \(migration.name) from group \(group.name) \(dryRun ? " (dry run)" : "")")
if !dryRun {
try await migration.revert(connection: context.connection, logger: context.logger)
try await repository.remove(migration, context: context)
} else {
requiresChanges = true
}
}
// Apply migration
for j in i..<groupMigrations.count {
let migration = groupMigrations[j]
logger.info("Migrating \(migration.name) from group \(group.name) \(dryRun ? " (dry run)" : "")")
if !dryRun {
try await migration.apply(connection: context.connection, logger: context.logger)
try await repository.add(migration, context: context)
} else {
requiresChanges = true
}
}
}
logger.info("Reverting \(migration.name)\(dryRun ? " (dry run)" : "")")
if !dryRun {
try await migration.revert(connection: context.connection, logger: context.logger)
try await repository.remove(migration, context: context)
} else {
requiresChanges = true
// if changes are required
guard requiresChanges == false else {
throw HBPostgresMigrationError.requiresChanges
}
}
// Apply migration
for j in i..<migrations.count {
let migration = migrations[j]
logger.info("Migrating \(migration.name)\(dryRun ? " (dry run)" : "")")
if !dryRun {
try await migration.apply(connection: context.connection, logger: context.logger)
try await repository.add(migration, context: context)
} else {
requiresChanges = true
}
}
// if changes are required
guard requiresChanges == false else {
throw HBPostgresMigrationError.requiresChanges
}
}
}
}
Expand All @@ -153,7 +138,7 @@ struct HBPostgresMigrationRepository {

func add(_ migration: HBPostgresMigration, context: Context) async throws {
try await context.connection.query(
"INSERT INTO _hb_migrations (name) VALUES (\(migration.name))",
"INSERT INTO _hb_migrations (\"name\", \"group\") VALUES (\(migration.name), \(migration.group.name))",
logger: context.logger
)
}
Expand All @@ -165,14 +150,14 @@ struct HBPostgresMigrationRepository {
)
}

func getAll(context: Context) async throws -> [String] {
func getAll(context: Context) async throws -> [(name: String, group: HBMigrationGroup)] {
let stream = try await context.connection.query(
"SELECT name FROM _hb_migrations ORDER BY \"order\"",
"SELECT \"name\", \"group\" FROM _hb_migrations ORDER BY \"order\"",
logger: context.logger
)
var result: [String] = []
for try await name in stream.decode(String.self, context: .default) {
result.append(name)
var result: [(String, HBMigrationGroup)] = []
for try await (name, group) in stream.decode((String, String).self, context: .default) {
result.append((name, .init(group)))
}
return result
}
Expand All @@ -182,10 +167,25 @@ struct HBPostgresMigrationRepository {
"""
CREATE TABLE IF NOT EXISTS _hb_migrations (
"order" SERIAL PRIMARY KEY,
"name" text
"name" text,
"group" text
)
""",
logger: logger
)
}
}

extension Array where Element: Equatable {
/// The list of unique elements in the list, in the order they are found
var uniqueElements: [Element] {
self.reduce([]) { result, name in
if result.first(where: { $0 == name }) == nil {
var result = result
result.append(name)
return result
}
return result
}
}
}