Skip to content
This repository has been archived by the owner on Jun 28, 2018. It is now read-only.

Use db.Conn to fix postgres lock/unlock #303

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 4 additions & 2 deletions .travis.yml
Expand Up @@ -5,6 +5,8 @@ go:
- 1.7
- 1.8
- 1.9
- 1.9.1
- 1.9.2

env:
- MIGRATE_TEST_CONTAINER_BOOT_DELAY=10
Expand Down Expand Up @@ -38,7 +40,7 @@ deploy:
secure: EFow50BI448HVb/uQ1Kk2Kq0xzmwIYq3V67YyymXIuqSCodvXEsMiBPUoLrxEknpPEIc67LEQTNdfHBgvyHk6oRINWAfie+7pr5tKrpOTF9ghyxoN1PlO8WKQCqwCvGMBCnc5ur5rvzp0bqfpV2rs5q9/nngy3kBuEvs12V7iho=
skip_cleanup: true
on:
go: 1.8
go: 1.9
repo: mattes/migrate
tags: true
file:
Expand All @@ -56,7 +58,7 @@ deploy:
package_glob: '*.deb'
skip_cleanup: true
on:
go: 1.8
go: 1.9
repo: mattes/migrate
tags: true

31 changes: 20 additions & 11 deletions database/postgres/postgres.go
@@ -1,3 +1,4 @@
// +build go1.9
package postgres

import (
Expand All @@ -10,6 +11,7 @@ import (
"github.com/lib/pq"
"github.com/mattes/migrate"
"github.com/mattes/migrate/database"
"context"
)

func init() {
Expand All @@ -33,7 +35,8 @@ type Config struct {
}

type Postgres struct {
db *sql.DB
// Locking and unlocking need to use the same connection
db *sql.Conn
isLocked bool

// Open and WithInstance need to garantuee that config is never nil
Expand Down Expand Up @@ -65,8 +68,14 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
config.MigrationsTable = DefaultMigrationsTable
}

conn, err := instance.Conn(context.Background())

if err != nil {
return nil, err
}

px := &Postgres{
db: instance,
db: conn,
config: config,
}

Expand Down Expand Up @@ -123,7 +132,7 @@ func (p *Postgres) Lock() error {
// or return false if the lock cannot be acquired immediately.
query := `SELECT pg_try_advisory_lock($1)`
var success bool
if err := p.db.QueryRow(query, aid).Scan(&success); err != nil {
if err := p.db.QueryRowContext(context.Background(), query, aid).Scan(&success); err != nil {
return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
}

Expand All @@ -146,7 +155,7 @@ func (p *Postgres) Unlock() error {
}

query := `SELECT pg_advisory_unlock($1)`
if _, err := p.db.Exec(query, aid); err != nil {
if _, err := p.db.ExecContext(context.Background(), query, aid); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
p.isLocked = false
Expand All @@ -161,7 +170,7 @@ func (p *Postgres) Run(migration io.Reader) error {

// run migration
query := string(migr[:])
if _, err := p.db.Exec(query); err != nil {
if _, err := p.db.ExecContext(context.Background(), query); err != nil {
// TODO: cast to postgress error and get line number
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
}
Expand All @@ -170,7 +179,7 @@ func (p *Postgres) Run(migration io.Reader) error {
}

func (p *Postgres) SetVersion(version int, dirty bool) error {
tx, err := p.db.Begin()
tx, err := p.db.BeginTx(context.Background(), &sql.TxOptions{})
if err != nil {
return &database.Error{OrigErr: err, Err: "transaction start failed"}
}
Expand Down Expand Up @@ -198,7 +207,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {

func (p *Postgres) Version() (version int, dirty bool, err error) {
query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1`
err = p.db.QueryRow(query).Scan(&version, &dirty)
err = p.db.QueryRowContext(context.Background(),query).Scan(&version, &dirty)
switch {
case err == sql.ErrNoRows:
return database.NilVersion, false, nil
Expand All @@ -219,7 +228,7 @@ func (p *Postgres) Version() (version int, dirty bool, err error) {
func (p *Postgres) Drop() error {
// select all tables in current schema
query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema())`
tables, err := p.db.Query(query)
tables, err := p.db.QueryContext(context.Background(),query)
if err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
Expand All @@ -241,7 +250,7 @@ func (p *Postgres) Drop() error {
// delete one by one ...
for _, t := range tableNames {
query = `DROP TABLE IF EXISTS ` + t + ` CASCADE`
if _, err := p.db.Exec(query); err != nil {
if _, err := p.db.ExecContext(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}
Expand All @@ -257,7 +266,7 @@ func (p *Postgres) ensureVersionTable() error {
// check if migration table exists
var count int
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
if err := p.db.QueryRow(query, p.config.MigrationsTable).Scan(&count); err != nil {
if err := p.db.QueryRowContext(context.Background(),query, p.config.MigrationsTable).Scan(&count); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
if count == 1 {
Expand All @@ -266,7 +275,7 @@ func (p *Postgres) ensureVersionTable() error {

// if not, create the empty migration table
query = `CREATE TABLE "` + p.config.MigrationsTable + `" (version bigint not null primary key, dirty boolean not null)`
if _, err := p.db.Exec(query); err != nil {
if _, err := p.db.ExecContext(context.Background(),query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
return nil
Expand Down
39 changes: 38 additions & 1 deletion database/postgres/postgres_test.go
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/lib/pq"
dt "github.com/mattes/migrate/database/testing"
mt "github.com/mattes/migrate/testing"
"context"
)

var versions = []mt.Version{
Expand Down Expand Up @@ -69,7 +70,7 @@ func TestMultiStatement(t *testing.T) {

// make sure second table exists
var exists bool
if err := d.(*Postgres).db.QueryRow("SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
if err := d.(*Postgres).db.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
t.Fatal(err)
}
if !exists {
Expand Down Expand Up @@ -148,3 +149,39 @@ func TestWithSchema(t *testing.T) {
func TestWithInstance(t *testing.T) {

}

func TestPostgres_Lock(t *testing.T) {
mt.ParallelTest(t, versions, isReady,
func(t *testing.T, i mt.Instance) {
p := &Postgres{}
addr := fmt.Sprintf("postgres://postgres@%v:%v/postgres?sslmode=disable", i.Host(), i.Port())
d, err := p.Open(addr)
if err != nil {
t.Fatalf("%v", err)
}

dt.Test(t, d, []byte("SELECT 1"))

ps := d.(*Postgres)

err = ps.Lock()
if err != nil {
t.Fatal(err)
}

err = ps.Unlock()
if err != nil {
t.Fatal(err)
}

err = ps.Lock()
if err != nil {
t.Fatal(err)
}

err = ps.Unlock()
if err != nil {
t.Fatal(err)
}
})
}