Skip to content

Commit

Permalink
Merge pull request #213 from erizocosmico/fix/byte-array-scan
Browse files Browse the repository at this point in the history
scan bytea as []byte and vice-versa
  • Loading branch information
erizocosmico committed Aug 1, 2017
2 parents 96a9f67 + 92523fd commit 0cd6899
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 102 deletions.
3 changes: 2 additions & 1 deletion README.md
Expand Up @@ -724,7 +724,8 @@ kallax migrate up --dir ./my-migrations --dsn 'user:pass@localhost:5432/dbname?s
| `url.URL` | `text` |
| `time.Time` | `timestamptz` |
| `time.Duration` | `bigint` |
| `[]T` | `T'[]` * where `T'` is the SQL type of type `T` |
| `[]byte` | `bytea` |
| `[]T` | `T'[]` * where `T'` is the SQL type of type `T`, except for `T` = `byte` |
| `map[K]V` | `jsonb` |
| `struct` | `jsonb` |
| `*struct` | `jsonb` |
Expand Down
9 changes: 8 additions & 1 deletion generator/migration.go
Expand Up @@ -197,6 +197,7 @@ func (s *ColumnSchema) String() string {
type ColumnType string

const (
ByteaColumn ColumnType = "bytea"
SmallIntColumn ColumnType = "smallint"
IntegerColumn ColumnType = "integer"
BigIntColumn ColumnType = "bigint"
Expand Down Expand Up @@ -225,6 +226,7 @@ func ArrayColumn(typ ColumnType) ColumnType {
if strings.HasSuffix(string(typ), "[]") {
return typ
}

return typ + "[]"
}

Expand Down Expand Up @@ -833,7 +835,12 @@ func (t *packageTransformer) transformType(f *Field, pk bool) (ColumnType, error
}

if f.Kind == Array || f.Kind == Slice {
return ArrayColumn(typeMappings[removeTypePrefix(f.Type)]), nil
typ := removeTypePrefix(f.Type)
if typ == "byte" {
return ByteaColumn, nil
}

return ArrayColumn(typeMappings[typ]), nil
}

if pk {
Expand Down
2 changes: 2 additions & 0 deletions generator/migration_test.go
Expand Up @@ -527,6 +527,7 @@ type Profile struct {
// should be added anyway
// should be added as bigint, as it is not a pk
Metadata ProfileMetadata
SomeData []byte
}
type ProfileMetadata struct {
Expand Down Expand Up @@ -569,6 +570,7 @@ func (s *PackageTransformerSuite) TestTransform() {
mkCol("background", TextColumn, false, true, nil),
mkCol("user_id", UUIDColumn, false, false, mkRef("users", "id", true)),
mkCol("spouse", UUIDColumn, false, false, nil),
mkCol("some_data", ByteaColumn, false, true, nil),
),
mkTable(
"metadata",
Expand Down
66 changes: 14 additions & 52 deletions types/slices.go
Expand Up @@ -71,10 +71,10 @@ func Slice(v interface{}) SQLType {
return (*Int8Array)(&v)
case *[]int8:
return (*Int8Array)(v)
case []uint8:
return (*Uint8Array)(&v)
case *[]uint8:
return (*Uint8Array)(v)
case []byte:
return (*ByteArray)(&v)
case *[]byte:
return (*ByteArray)(v)
case *[]float32:
return (*Float32Array)(v)
case []float32:
Expand Down Expand Up @@ -646,67 +646,29 @@ func (a Int8Array) Value() (driver.Value, error) {
return "{}", nil
}

// Uint8Array represents a one-dimensional array of the PostgreSQL unsigned integer type.
type Uint8Array []uint8
// ByteArray represents a byte array `bytea`.
type ByteArray []uint8

// Scan implements the sql.Scanner interface.
func (a *Uint8Array) Scan(src interface{}) error {
func (a *ByteArray) Scan(src interface{}) error {
switch src := src.(type) {
case []byte:
return a.scanBytes(src)
*(*[]byte)(a) = src
return nil
case string:
return a.scanBytes([]byte(src))
*(*[]byte)(a) = []byte(src)
return nil
case nil:
*a = nil
return nil
}

return fmt.Errorf("kallax: cannot convert %T to Uint8Array", src)
}

func (a *Uint8Array) scanBytes(src []byte) error {
elems, err := scanLinearArray(src, []byte{','}, "Uint8Array")
if err != nil {
return err
}
if *a != nil && len(elems) == 0 {
*a = (*a)[:0]
} else {
b := make(Uint8Array, len(elems))
for i, v := range elems {
val, err := strconv.ParseUint(string(v), 10, 8)
if err != nil {
return fmt.Errorf("kallax: parsing array element index %d: %v", i, err)
}
b[i] = uint8(val)
}
*a = b
}
return nil
return fmt.Errorf("kallax: cannot convert %T to ByteArray", src)
}

// Value implements the driver.Valuer interface.
func (a Uint8Array) Value() (driver.Value, error) {
if a == nil {
return nil, nil
}

if n := len(a); n > 0 {
// There will be at least two curly brackets, N bytes of values,
// and N-1 bytes of delimiters.
b := make([]byte, 1, 1+2*n)
b[0] = '{'

b = strconv.AppendUint(b, uint64(a[0]), 10)
for i := 1; i < n; i++ {
b = append(b, ',')
b = strconv.AppendUint(b, uint64(a[i]), 10)
}

return string(append(b, '}')), nil
}

return "{}", nil
func (a ByteArray) Value() (driver.Value, error) {
return ([]byte)(a), nil
}

// Float32Array represents a one-dimensional array of the PostgreSQL real type.
Expand Down
103 changes: 55 additions & 48 deletions types/slices_test.go
Expand Up @@ -14,8 +14,6 @@ import (
)

func TestSlice(t *testing.T) {
require := require.New(t)

cases := []struct {
v interface{}
input interface{}
Expand Down Expand Up @@ -76,16 +74,6 @@ func TestSlice(t *testing.T) {
[]int8{1, 3, 4},
&([]int8{}),
},
{
&([]uint8{1, 3, 4}),
[]uint8{1, 3, 4},
&([]uint8{}),
},
{
&([]byte{1, 3, 4}),
[]byte{1, 3, 4},
&([]byte{}),
},
{
&([]float32{1., 3., .4}),
[]float32{1., 3., .4},
Expand All @@ -94,22 +82,35 @@ func TestSlice(t *testing.T) {
}

for _, c := range cases {
arr := Slice(c.v)
val, err := arr.Value()
require.Nil(err)
t.Run(reflect.TypeOf(c.input).String(), func(t *testing.T) {
require := require.New(t)
arr := Slice(c.v)
val, err := arr.Value()
require.NoError(err)

pqArr := pq.Array(c.input)
pqVal, err := pqArr.Value()
require.NoError(err)

require.Equal(pqVal, val)
require.NoError(Slice(c.dest).Scan(val))
require.Equal(c.v, c.dest)
})
}

pqArr := pq.Array(c.input)
pqVal, err := pqArr.Value()
require.Nil(err)
t.Run("[]byte", func(t *testing.T) {
require := require.New(t)
arr := Slice([]byte{1, 2, 3})
val, err := arr.Value()
require.NoError(err)

require.Equal(pqVal, val)
require.Nil(Slice(c.dest).Scan(val))
require.Equal(c.v, c.dest)
}
var b []byte
require.NoError(Slice(&b).Scan(val))
require.Equal([]byte{1, 2, 3}, b)
})
}

func TestSlice_Integration(t *testing.T) {
s := require.New(t)
cases := []struct {
name string
typ string
Expand All @@ -118,85 +119,91 @@ func TestSlice_Integration(t *testing.T) {
}{
{
"int8",
"smallint",
"smallint[]",
[]int8{math.MaxInt8, math.MinInt8},
&([]int8{}),
},
{
"unsigned int8",
"smallint",
[]uint8{math.MaxUint8, 0},
&([]uint8{}),
"byte",
"bytea",
[]byte{math.MaxUint8, 0},
&([]byte{}),
},
{
"int16",
"smallint",
"smallint[]",
[]int16{math.MaxInt16, math.MinInt16},
&([]int16{}),
},
{
"unsigned int16",
"integer",
"integer[]",
[]uint16{math.MaxUint16, 0},
&([]uint16{}),
},
{
"int32",
"integer",
"integer[]",
[]int32{math.MaxInt32, math.MinInt32},
&([]int32{}),
},
{
"unsigned int32",
"bigint",
"bigint[]",
[]uint32{math.MaxUint32, 0},
&([]uint32{}),
},
{
"int/int64",
"bigint",
"bigint[]",
[]int{math.MaxInt64, math.MinInt64},
&([]int{}),
},
{
"unsigned int/int64",
"numeric(20)",
"numeric(20)[]",
[]uint{math.MaxUint64, 0},
&([]uint{}),
},
{
"float32",
"decimal(10,3)",
"decimal(10,3)[]",
[]float32{.3, .6},
&([]float32{.3, .6}),
},
}

db, err := openTestDB()
s.Nil(err)
require.NoError(t, err)

defer func() {
_, err = db.Exec("DROP TABLE IF EXISTS foo")
s.Nil(err)
require.NoError(t, err)

s.Nil(db.Close())
require.NoError(t, db.Close())
}()

for _, c := range cases {
_, err := db.Exec(fmt.Sprintf(`CREATE TABLE foo (
testcol %s[]
t.Run(c.name, func(t *testing.T) {
require := require.New(t)

_, err := db.Exec(fmt.Sprintf(`CREATE TABLE foo (
testcol %s
)`, c.typ))
s.Nil(err, c.name)
require.NoError(err, c.name)

_, err = db.Exec("INSERT INTO foo (testcol) VALUES ($1)", Slice(c.input))
s.Nil(err, c.name)
defer func() {
_, err := db.Exec("DROP TABLE foo")
require.NoError(err)
}()

s.Nil(db.QueryRow("SELECT testcol FROM foo LIMIT 1").Scan(Slice(c.dst)), c.name)
slice := reflect.ValueOf(c.dst).Elem().Interface()
s.Equal(c.input, slice, c.name)
_, err = db.Exec("INSERT INTO foo (testcol) VALUES ($1)", Slice(c.input))
require.NoError(err, c.name)

_, err = db.Exec("DROP TABLE foo")
s.Nil(err, c.name)
require.NoError(db.QueryRow("SELECT testcol FROM foo LIMIT 1").Scan(Slice(c.dst)), c.name)
slice := reflect.ValueOf(c.dst).Elem().Interface()
require.Equal(c.input, slice, c.name)
})
}
}

Expand Down

0 comments on commit 0cd6899

Please sign in to comment.