Skip to content

Commit

Permalink
Merge pull request #1380 from cubiest/feature/unmarshal_text
Browse files Browse the repository at this point in the history
Implement (Un-)marshalText for Decimal and NullDecimal
  • Loading branch information
stephenafamo committed Apr 20, 2024
2 parents 028caa1 + 4676800 commit 20df20a
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 0 deletions.
43 changes: 43 additions & 0 deletions types/decimal.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,24 @@ func (d *Decimal) UnmarshalJSON(data []byte) error {
return d.Big.UnmarshalJSON(data)
}

// MarshalText marshals a decimal value
func (d Decimal) MarshalText() ([]byte, error) {
if d.Big == nil {
return nullBytes, nil
}

return d.Big.MarshalText()
}

// UnmarshalText allows marshalling text into a null pointer
func (d *Decimal) UnmarshalText(data []byte) error {
if d.Big == nil {
d.Big = new(decimal.Big)
}

return d.Big.UnmarshalText(data)
}

// Randomize implements sqlboiler's randomize interface
func (d *Decimal) Randomize(nextInt func() int64, fieldType string, shouldBeNull bool) {
d.Big = randomDecimal(nextInt, fieldType, false)
Expand Down Expand Up @@ -118,6 +136,31 @@ func (n *NullDecimal) UnmarshalJSON(data []byte) error {
return n.Big.UnmarshalJSON(data)
}

// MarshalText marshals a decimal value
func (n NullDecimal) MarshalText() ([]byte, error) {
if n.Big == nil {
return nullBytes, nil
}

return n.Big.MarshalText()
}

// UnmarshalText allows marshalling text into a null pointer
func (n *NullDecimal) UnmarshalText(data []byte) error {
if bytes.Equal(data, nullBytes) {
if n != nil {
n.Big = nil
}
return nil
}

if n.Big == nil {
n.Big = decimal.WithContext(DecimalContext)
}

return n.Big.UnmarshalText(data)
}

// String impl
func (n NullDecimal) String() string {
if n.Big == nil {
Expand Down
65 changes: 65 additions & 0 deletions types/decimal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,38 @@ func TestDecimal_JSON(t *testing.T) {
}
}

func TestDecimal_Text(t *testing.T) {
t.Parallel()

d := new(Decimal)

err := d.UnmarshalText([]byte(`54.45`))
if err != nil {
t.Error(err)
}

want, _ := new(decimal.Big).SetString("54.45")
if d.Cmp(want) != 0 {
t.Error("D was wrong:", d)
}
}

func TestDecimal_TextNil(t *testing.T) {
t.Parallel()

var n Decimal
b, _ := n.MarshalText()
if string(b) != `null` {
t.Errorf("want: null, got: %s", b)
}

n2 := new(Decimal)
b, _ = n2.MarshalText()
if string(b) != `null` {
t.Errorf("want: null, got: %s", b)
}
}

func TestNullDecimal_JSON(t *testing.T) {
t.Parallel()

Expand All @@ -185,6 +217,23 @@ func TestNullDecimal_JSON(t *testing.T) {
}
}

func TestNullDecimal_Text(t *testing.T) {
t.Parallel()

n := new(NullDecimal)

err := n.UnmarshalText([]byte(`54.45`))
if err != nil {
t.Error(err)
}

want, _ := new(decimal.Big).SetString("54.45")
if n.Cmp(want) != 0 {
fmt.Println(want, n)
t.Error("N was wrong:", n)
}
}

func TestNullDecimal_JSONNil(t *testing.T) {
t.Parallel()

Expand All @@ -201,6 +250,22 @@ func TestNullDecimal_JSONNil(t *testing.T) {
}
}

func TestNullDecimal_TextNil(t *testing.T) {
t.Parallel()

var n NullDecimal
b, _ := n.MarshalText()
if string(b) != `null` {
t.Errorf("want: null, got: %s", b)
}

n2 := new(NullDecimal)
b, _ = n2.MarshalText()
if string(b) != `null` {
t.Errorf("want: null, got: %s", b)
}
}

func TestNullDecimal_IsZero(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit 20df20a

Please sign in to comment.