Skip to content
This repository has been archived by the owner on Jun 28, 2018. It is now read-only.

Return a new instance of Driver from registry's GetDriver() function #195

Open
wants to merge 19 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
@@ -1,3 +1,7 @@
__[v3.0 in the making](https://github.com/mattes/migrate/tree/v3.0-prev)__

---

# migrate

[![Build Status](https://travis-ci.org/mattes/migrate.svg?branch=master)](https://travis-ci.org/mattes/migrate)
Expand Down
2 changes: 1 addition & 1 deletion driver/bash/bash.go
Expand Up @@ -32,5 +32,5 @@ func (driver *Driver) Version() (uint64, error) {
}

func init() {
driver.RegisterDriver("bash", &Driver{})
driver.RegisterDriver("bash", Driver{})
}
97 changes: 41 additions & 56 deletions driver/cassandra/cassandra.go
Expand Up @@ -14,44 +14,16 @@ import (
"github.com/mattes/migrate/migrate/direction"
)

// Driver implements migrate Driver interface
type Driver struct {
session *gocql.Session
}

const (
tableName = "schema_migrations"
versionRow = 1
tableName = "schema_migrations"
)

type counterStmt bool

func (c counterStmt) Exec(session *gocql.Session) error {
var version int64
if err := session.Query("SELECT version FROM "+tableName+" WHERE versionRow = ?", versionRow).Scan(&version); err != nil {
return err
}

if bool(c) {
version++
} else {
version--
}

return session.Query("UPDATE "+tableName+" SET version = ? WHERE versionRow = ?", version, versionRow).Exec()
}

const (
up counterStmt = true
down counterStmt = false
)

// Cassandra Driver URL format:
// cassandra://host:port/keyspace?protocol=version&consistency=level
//
// Examples:
// cassandra://localhost/SpaceOfKeys?protocol=4
// cassandra://localhost/SpaceOfKeys?protocol=4&consistency=all
// cassandra://localhost/SpaceOfKeys?consistency=quorum
// Initialize will be called first
func (driver *Driver) Initialize(rawurl string) error {
u, err := url.Parse(rawurl)
if err != nil {
Expand All @@ -68,7 +40,8 @@ func (driver *Driver) Initialize(rawurl string) error {
cluster.Timeout = 1 * time.Minute

if len(u.Query().Get("consistency")) > 0 {
consistency, err := parseConsistency(u.Query().Get("consistency"))
var consistency gocql.Consistency
consistency, err = parseConsistency(u.Query().Get("consistency"))
if err != nil {
return err
}
Expand All @@ -77,7 +50,8 @@ func (driver *Driver) Initialize(rawurl string) error {
}

if len(u.Query().Get("protocol")) > 0 {
protoversion, err := strconv.Atoi(u.Query().Get("protocol"))
var protoversion int
protoversion, err = strconv.Atoi(u.Query().Get("protocol"))
if err != nil {
return err
}
Expand All @@ -90,7 +64,7 @@ func (driver *Driver) Initialize(rawurl string) error {
password, passwordSet := u.User.Password()

if passwordSet == false {
return fmt.Errorf("Missing password. Please provide password.")
return fmt.Errorf("Missing password. Please provide password")
}

cluster.Authenticator = gocql.PasswordAuthenticator{
Expand All @@ -112,61 +86,71 @@ func (driver *Driver) Initialize(rawurl string) error {
return nil
}

// Close last function to be called. Closes cassandra session
func (driver *Driver) Close() error {
driver.session.Close()
return nil
}

func (driver *Driver) ensureVersionTableExists() error {
err := driver.session.Query("CREATE TABLE IF NOT EXISTS " + tableName + " (version int, versionRow bigint primary key);").Exec()
err := driver.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id uuid primary key, version bigint)", tableName)).Exec()
if err != nil {
return err
}

_, err = driver.Version()
if err != nil {
if err.Error() == "not found" {
return driver.session.Query("UPDATE "+tableName+" SET version = ? WHERE versionRow = ?", 1, versionRow).Exec()
}
if _, err = driver.Version(); err != nil {
return err
}

return nil
}

// FilenameExtension return extension of migrations files
func (driver *Driver) FilenameExtension() string {
return "cql"
}

func (driver *Driver) version(d direction.Direction, invert bool) error {
var stmt counterStmt
switch d {
case direction.Up:
stmt = up
case direction.Down:
stmt = down
func (driver *Driver) updateVersion(version uint64, dir direction.Direction) error {
var ids []string
var id string
var err error
iter := driver.session.Query(fmt.Sprintf("SELECT id FROM %s WHERE version >= ? ALLOW FILTERING", tableName), version).Iter()
for iter.Scan(&id) {
ids = append(ids, id)
}
if len(ids) > 0 {
err = driver.session.Query(fmt.Sprintf("DELETE FROM %s WHERE id IN ?", tableName), ids).Exec()
if err != nil {
return err
}
}
if invert {
stmt = !stmt
if dir == direction.Up {
return driver.session.Query(fmt.Sprintf("INSERT INTO %s (id, version) VALUES (uuid(), ?)", tableName), version).Exec()
}
return stmt.Exec(driver.session)
return nil
}

// Migrate run migration file. Restore previous version in case of fail
func (driver *Driver) Migrate(f file.File, pipe chan interface{}) {
var err error
previousVersion, err := driver.Version()
if err != nil {
close(pipe)
return
}
defer func() {
if err != nil {
// Invert version direction if we couldn't apply the changes for some reason.
if err := driver.version(f.Direction, true); err != nil {
pipe <- err
if updErr := driver.updateVersion(previousVersion, direction.Up); updErr != nil {
pipe <- updErr
}
pipe <- err
}
close(pipe)
}()

pipe <- f
if err = driver.version(f.Direction, false); err != nil {
if err = driver.updateVersion(f.Version, f.Direction); err != nil {
return
}

Expand All @@ -186,14 +170,15 @@ func (driver *Driver) Migrate(f file.File, pipe chan interface{}) {
}
}

// Version return current version
func (driver *Driver) Version() (uint64, error) {
var version int64
err := driver.session.Query("SELECT version FROM "+tableName+" WHERE versionRow = ?", versionRow).Scan(&version)
return uint64(version) - 1, err
err := driver.session.Query(fmt.Sprintf("SELECT max(version) FROM %s", tableName)).Scan(&version)
return uint64(version), err
}

func init() {
driver.RegisterDriver("cassandra", &Driver{})
driver.RegisterDriver("cassandra", Driver{})
}

// ParseConsistency wraps gocql.ParseConsistency to return an error
Expand Down
16 changes: 9 additions & 7 deletions driver/cassandra/cassandra_test.go
Expand Up @@ -21,10 +21,10 @@ func TestMigrate(t *testing.T) {

host := os.Getenv("CASSANDRA_PORT_9042_TCP_ADDR")
port := os.Getenv("CASSANDRA_PORT_9042_TCP_PORT")
driverUrl := "cassandra://" + host + ":" + port + "/system"
driverURL := "cassandra://" + host + ":" + port + "/system"

// prepare a clean test database
u, err := url.Parse(driverUrl)
u, err := url.Parse(driverURL)
if err != nil {
t.Fatal(err)
}
Expand All @@ -35,23 +35,25 @@ func TestMigrate(t *testing.T) {
cluster.Timeout = 1 * time.Minute

session, err = cluster.CreateSession()

if err != nil {
t.Fatal(err)
}

if err := session.Query(`DROP KEYSPACE IF EXISTS migrate;`).Exec(); err != nil {
if err = session.Query(`DROP KEYSPACE IF EXISTS migrate;`).Exec(); err != nil {
t.Fatal(err)
}
if err := session.Query(`CREATE KEYSPACE IF NOT EXISTS migrate WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': 1};`).Exec(); err != nil {
if err = session.Query(`CREATE KEYSPACE IF NOT EXISTS migrate WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': 1};`).Exec(); err != nil {
t.Fatal(err)
}
cluster.Keyspace = "migrate"
session, err = cluster.CreateSession()
driverUrl = "cassandra://" + host + ":" + port + "/migrate"
if err != nil {
t.Fatal(err)
}
driverURL = "cassandra://" + host + ":" + port + "/migrate"

d := &Driver{}
if err := d.Initialize(driverUrl); err != nil {
if err := d.Initialize(driverURL); err != nil {
t.Fatal(err)
}

Expand Down
6 changes: 3 additions & 3 deletions driver/crate/crate.go
Expand Up @@ -13,7 +13,7 @@ import (
)

func init() {
driver.RegisterDriver("crate", &Driver{})
driver.RegisterDriver("crate", Driver{})
}

type Driver struct {
Expand Down Expand Up @@ -97,8 +97,8 @@ func (driver *Driver) Migrate(f file.File, pipe chan interface{}) {
func splitContent(content string) []string {
lines := strings.Split(content, ";")
resultLines := make([]string, 0, len(lines))
for i, line := range lines {
line = strings.Replace(lines[i], ";", "", -1)
for i := range lines {
line := strings.Replace(lines[i], ";", "", -1)
line = strings.TrimSpace(line)
if line != "" {
resultLines = append(resultLines, line)
Expand Down
18 changes: 17 additions & 1 deletion driver/driver.go
Expand Up @@ -43,7 +43,7 @@ func New(url string) (Driver, error) {

d := GetDriver(u.Scheme)
if d == nil {
return nil, fmt.Errorf("Driver '%s' not found.", u.Scheme)
return nil, fmt.Errorf("Driver '%s' not found", u.Scheme)
}
verifyFilenameExtension(u.Scheme, d)
if err := d.Initialize(url); err != nil {
Expand All @@ -53,6 +53,22 @@ func New(url string) (Driver, error) {
return d, nil
}

// FilenameExtensionFromURL return extension for migration files. Used for create migrations
func FilenameExtensionFromURL(url string) (string, error) {
u, err := neturl.Parse(url)
if err != nil {
return "", err
}

d := GetDriver(u.Scheme)
if d == nil {
return "", fmt.Errorf("Driver '%s' not found", u.Scheme)
}
verifyFilenameExtension(u.Scheme, d)

return d.FilenameExtension(), nil
}

// verifyFilenameExtension panics if the driver's filename extension
// is not correct or empty.
func verifyFilenameExtension(driverName string, d Driver) {
Expand Down
7 changes: 4 additions & 3 deletions driver/mongodb/example/mongodb_test.go
Expand Up @@ -6,13 +6,14 @@ import (
"github.com/mattes/migrate/file"
"github.com/mattes/migrate/migrate/direction"

"os"
"reflect"
"time"

"github.com/mattes/migrate/driver"
"github.com/mattes/migrate/driver/mongodb"
"github.com/mattes/migrate/driver/mongodb/gomethods"
pipep "github.com/mattes/migrate/pipe"
"os"
"reflect"
"time"
)

type ExpectedMigrationResult struct {
Expand Down
3 changes: 2 additions & 1 deletion driver/mongodb/example/sample_mongdb_migrator.go
@@ -1,11 +1,12 @@
package example

import (
"time"

"github.com/mattes/migrate/driver/mongodb/gomethods"
_ "github.com/mattes/migrate/driver/mongodb/gomethods"
"gopkg.in/mgo.v2"
"gopkg.in/mgo.v2/bson"
"time"

"github.com/mattes/migrate/driver/mongodb"
)
Expand Down
5 changes: 3 additions & 2 deletions driver/mongodb/gomethods/gomethods_migrator.go
Expand Up @@ -3,11 +3,12 @@ package gomethods
import (
"bufio"
"fmt"
"github.com/mattes/migrate/driver"
"github.com/mattes/migrate/file"
"os"
"path"
"strings"

"github.com/mattes/migrate/driver"
"github.com/mattes/migrate/file"
)

type MethodNotFoundError string
Expand Down
3 changes: 2 additions & 1 deletion driver/mongodb/gomethods/gomethods_registry.go
Expand Up @@ -2,8 +2,9 @@ package gomethods

import (
"fmt"
"github.com/mattes/migrate/driver"
"sync"

"github.com/mattes/migrate/driver"
)

var methodsReceiversMu sync.Mutex
Expand Down
7 changes: 4 additions & 3 deletions driver/mongodb/mongodb.go
Expand Up @@ -2,14 +2,15 @@ package mongodb

import (
"errors"
"reflect"
"strings"

"github.com/mattes/migrate/driver"
"github.com/mattes/migrate/driver/mongodb/gomethods"
"github.com/mattes/migrate/file"
"github.com/mattes/migrate/migrate/direction"
"gopkg.in/mgo.v2"
"gopkg.in/mgo.v2/bson"
"reflect"
"strings"
)

type UnregisteredMethodsReceiverError string
Expand Down Expand Up @@ -55,7 +56,7 @@ func (d *Driver) SetMethodsReceiver(r interface{}) error {
}

func init() {
driver.RegisterDriver("mongodb", &Driver{})
driver.RegisterDriver("mongodb", Driver{})
}

type DbMigration struct {
Expand Down
2 changes: 1 addition & 1 deletion driver/mysql/mysql.go
Expand Up @@ -181,5 +181,5 @@ func (driver *Driver) Version() (uint64, error) {
}

func init() {
driver.RegisterDriver("mysql", &Driver{})
driver.RegisterDriver("mysql", Driver{})
}