diff --git a/spanner/value.go b/spanner/value.go index 4f0ca8cfc37..02d67ccaa19 100644 --- a/spanner/value.go +++ b/spanner/value.go @@ -24,6 +24,7 @@ import ( "math/big" "reflect" "strconv" + "strings" "time" "cloud.google.com/go/civil" @@ -49,6 +50,26 @@ const ( NumericScaleDigits = 9 ) +// LossOfPrecisionHandlingOption describes the option to deal with loss of +// precision on numeric values. +type LossOfPrecisionHandlingOption int + +const ( + // NumericRound automatically rounds a numeric value that has a higher + // precision than what is supported by Spanner, e.g., 0.1234567895 rounds + // to 0.123456790. + NumericRound LossOfPrecisionHandlingOption = iota + // NumericError returns an error for numeric values that have a higher + // precision than what is supported by Spanner. E.g. the client returns an + // error if the application tries to insert the value 0.1234567895. + NumericError +) + +// LossOfPrecisionHandling configures how to deal with loss of precision on +// numeric values. The value of this configuration is global and will be used +// for all Spanner clients. +var LossOfPrecisionHandling LossOfPrecisionHandlingOption + // NumericString returns a string representing a *big.Rat in a format compatible // with Spanner SQL. It returns a floating-point literal with 9 digits after the // decimal point. @@ -56,6 +77,29 @@ func NumericString(r *big.Rat) string { return r.FloatString(NumericScaleDigits) } +// validateNumeric returns nil if there are no errors. It will return an error +// when the numeric number is not valid. +func validateNumeric(r *big.Rat) error { + if r == nil { + return nil + } + // Add one more digit to the scale component to find out if there are more + // digits than required. + strRep := r.FloatString(NumericScaleDigits + 1) + strRep = strings.TrimRight(strRep, "0") + strRep = strings.TrimLeft(strRep, "-") + s := strings.Split(strRep, ".") + whole := s[0] + scale := s[1] + if len(scale) > NumericScaleDigits { + return fmt.Errorf("max scale for a numeric is %d. The requested numeric has more", NumericScaleDigits) + } + if len(whole) > NumericPrecisionDigits-NumericScaleDigits { + return fmt.Errorf("max precision for the whole component of a numeric is %d. The requested numeric has a whole component with precision %d", NumericPrecisionDigits-NumericScaleDigits, len(whole)) + } + return nil +} + var ( // CommitTimestamp is a special value used to tell Cloud Spanner to insert // the commit timestamp of the transaction into a column. It can be used in @@ -2671,6 +2715,15 @@ func encodeValue(v interface{}) (*proto3.Value, *sppb.Type, error) { } pt = listType(floatType()) case big.Rat: + switch LossOfPrecisionHandling { + case NumericError: + err = validateNumeric(&v) + if err != nil { + return nil, nil, err + } + case NumericRound: + // pass + } pb.Kind = stringKind(NumericString(&v)) pt = numericType() case []big.Rat: @@ -2695,6 +2748,15 @@ func encodeValue(v interface{}) (*proto3.Value, *sppb.Type, error) { } pt = listType(numericType()) case *big.Rat: + switch LossOfPrecisionHandling { + case NumericError: + err = validateNumeric(v) + if err != nil { + return nil, nil, err + } + case NumericRound: + // pass + } if v != nil { pb.Kind = stringKind(NumericString(v)) } diff --git a/spanner/value_test.go b/spanner/value_test.go index 5dccbe9d738..a689b779e64 100644 --- a/spanner/value_test.go +++ b/spanner/value_test.go @@ -228,6 +228,8 @@ func TestEncodeValue(t *testing.T) { numValuePtr := big.NewRat(12345, 1e3) var numNilPtr *big.Rat num2ValuePtr := big.NewRat(12345, 1e4) + maxNumValuePtr, _ := (&big.Rat{}).SetString("99999999999999999999999999999.999999999") + minNumValuePtr, _ := (&big.Rat{}).SetString("-99999999999999999999999999999.999999999") var ( tString = stringType() @@ -296,6 +298,8 @@ func TestEncodeValue(t *testing.T) { // NUMERIC / NUMERIC ARRAY {*numValuePtr, numericProto(numValuePtr), tNumeric, "big.Rat"}, {numValuePtr, numericProto(numValuePtr), tNumeric, "*big.Rat"}, + {maxNumValuePtr, numericProto(maxNumValuePtr), tNumeric, "max numeric"}, + {minNumValuePtr, numericProto(minNumValuePtr), tNumeric, "min numeric"}, {numNilPtr, nullProto(), tNumeric, "*big.Rat with null"}, {NullNumeric{*numValuePtr, true}, numericProto(numValuePtr), tNumeric, "NullNumeric with value"}, {NullNumeric{*numValuePtr, false}, nullProto(), tNumeric, "NullNumeric with null"}, @@ -429,6 +433,40 @@ func TestEncodeValue(t *testing.T) { } } +// Test encoding invalid values. +func TestEncodeInvalidValues(t *testing.T) { + type CustomNumeric big.Rat + + invalidNumPtr1 := big.NewRat(11234567891, 1e10) + invalidNumPtr2, _ := (&big.Rat{}).SetString("199999999999999999999999999999.999999999") + + // Enable error mode. + LossOfPrecisionHandling = NumericError + + for i, test := range []struct { + desc string + in interface{} + errMsg string + }{ + // NUMERIC + {desc: "numeric pointer with invalid scale component", in: invalidNumPtr1, errMsg: "max scale for a numeric is 9. The requested numeric has more"}, + {desc: "numeric pointer with invalid whole component", in: invalidNumPtr2, errMsg: "max precision for the whole component of a numeric is 29. The requested numeric has a whole component with precision 30"}, + {desc: "numeric with invalid scale component", in: *invalidNumPtr1, errMsg: "max scale for a numeric is 9. The requested numeric has more"}, + {desc: "numeric with invalid whole component", in: *invalidNumPtr2, errMsg: "max precision for the whole component of a numeric is 29. The requested numeric has a whole component with precision 30"}, + // CUSTOM NUMERIC + {desc: "custom numeric type with invalid scale component", in: CustomNumeric(*invalidNumPtr1), errMsg: "max scale for a numeric is 9. The requested numeric has more"}, + {desc: "custom numeric type with invalid whole component", in: CustomNumeric(*invalidNumPtr2), errMsg: "max precision for the whole component of a numeric is 29. The requested numeric has a whole component with precision 30"}, + } { + _, _, err := encodeValue(test.in) + if err == nil { + t.Fatalf("#%d (%s): want error during encoding, but got nil", i, test.desc) + } + if err.Error() != test.errMsg { + t.Errorf("#%d (%s): incorrect error message, got %v, want %v", i, test.desc, err, test.errMsg) + } + } +} + type encodeTest struct { desc string in interface{}