diff --git a/spanner/integration_test.go b/spanner/integration_test.go index 531e77670d4..00351f244b7 100644 --- a/spanner/integration_test.go +++ b/spanner/integration_test.go @@ -1715,9 +1715,63 @@ func TestIntegration_BasicTypes(t *testing.T) { {col: "NumericArray", val: []NullNumeric(nil)}, {col: "NumericArray", val: []NullNumeric{}}, {col: "NumericArray", val: []NullNumeric{{n1, true}, {n2, true}, {}}}, + {col: "String", val: nil, want: NullString{}}, + {col: "StringArray", val: nil, want: []NullString(nil)}, + {col: "Bytes", val: nil, want: []byte(nil)}, + {col: "BytesArray", val: nil, want: [][]byte(nil)}, + {col: "Int64a", val: nil, want: NullInt64{}}, + {col: "Int64Array", val: nil, want: []NullInt64(nil)}, + {col: "Bool", val: nil, want: NullBool{}}, + {col: "BoolArray", val: nil, want: []NullBool(nil)}, + {col: "Float64", val: nil, want: NullFloat64{}}, + {col: "Float64Array", val: nil, want: []NullFloat64(nil)}, + {col: "Date", val: nil, want: NullDate{}}, + {col: "DateArray", val: nil, want: []NullDate(nil)}, + {col: "Timestamp", val: nil, want: NullTime{}}, + {col: "TimestampArray", val: nil, want: []NullTime(nil)}, + {col: "Numeric", val: nil, want: NullNumeric{}}, + {col: "NumericArray", val: nil, want: []NullNumeric(nil)}, + } + + // Write rows into table first using DML. Only do this on real Spanner + // as the emulator does not support untyped parameters. + // TODO: Remove when the emulator supports untyped parameters. + if !isEmulatorEnvSet() { + statements := make([]Statement, 0) + for i, test := range tests { + stmt := NewStatement(fmt.Sprintf("INSERT INTO Types (RowId, `%s`) VALUES (@id, @value)", test.col)) + // Note: We are not setting the parameter type here to ensure that it + // can be automatically recognized when it is actually needed. + stmt.Params["id"] = i + stmt.Params["value"] = test.val + statements = append(statements, stmt) + } + _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + rowCounts, err := tx.BatchUpdate(ctx, statements) + if err != nil { + return err + } + if len(rowCounts) != len(tests) { + return fmt.Errorf("rowCounts length mismatch\nGot: %v\nWant: %v", len(rowCounts), len(tests)) + } + for i, c := range rowCounts { + if c != 1 { + return fmt.Errorf("row count mismatch for row %v:\nGot: %v\nWant: %v", i, c, 1) + } + } + return nil + }) + if err != nil { + t.Fatalf("failed to insert values using DML: %v", err) + } + // Delete all the rows so we can insert them using mutations as well. + _, err = client.Apply(ctx, []*Mutation{Delete("Types", AllKeys())}) + if err != nil { + t.Fatalf("failed to delete all rows: %v", err) + } } - // Write rows into table first. + // Verify that we can insert the rows using mutations. var muts []*Mutation for i, test := range tests { muts = append(muts, InsertOrUpdate("Types", []string{"RowID", test.col}, []interface{}{i, test.val})) diff --git a/spanner/statement.go b/spanner/statement.go index f875df00f21..bb0ae5a6d5f 100644 --- a/spanner/statement.go +++ b/spanner/statement.go @@ -17,7 +17,6 @@ limitations under the License. package spanner import ( - "errors" "fmt" proto3 "github.com/golang/protobuf/ptypes/struct" @@ -48,11 +47,6 @@ func NewStatement(sql string) Statement { return Statement{SQL: sql, Params: map[string]interface{}{}} } -var ( - errNilParam = errors.New("use T(nil), not nil") - errNoType = errors.New("no type information") -) - // convertParams converts a statement's parameters into proto Param and // ParamTypes. func (s *Statement) convertParams() (*structpb.Struct, map[string]*sppb.Type, error) { @@ -61,18 +55,14 @@ func (s *Statement) convertParams() (*structpb.Struct, map[string]*sppb.Type, er } paramTypes := map[string]*sppb.Type{} for k, v := range s.Params { - if v == nil { - return nil, nil, errBindParam(k, v, errNilParam) - } val, t, err := encodeValue(v) if err != nil { return nil, nil, errBindParam(k, v, err) } - if t == nil { // should not happen, because of nil check above - return nil, nil, errBindParam(k, v, errNoType) - } params.Fields[k] = val - paramTypes[k] = t + if t != nil { + paramTypes[k] = t + } } return params, paramTypes, nil diff --git a/spanner/statement_test.go b/spanner/statement_test.go index 08d49c7a0f3..8243b5e2baa 100644 --- a/spanner/statement_test.go +++ b/spanner/statement_test.go @@ -159,6 +159,12 @@ func TestConvertParams(t *testing.T) { listProto(listProto(intProto(10)), listProto(intProto(20))), listType(structType(mkField("field", intType()))), }, + // Untyped null + { + nil, + nullProto(), + nil, + }, } { st.Params["var"] = test.val gotParams, gotParamTypes, gotErr := st.convertParams() @@ -179,23 +185,6 @@ func TestConvertParams(t *testing.T) { t.Errorf("%#v: got %v, want %v\n", test.val, gotParamType, test.wantField) } } - - // Verify type error reporting. - for _, test := range []struct { - val interface{} - wantErr error - }{ - { - nil, - errBindParam("var", nil, errNilParam), - }, - } { - st.Params["var"] = test.val - _, _, gotErr := st.convertParams() - if !testEqual(gotErr, test.wantErr) { - t.Errorf("value %#v:\ngot: %v\nwant: %v", test.val, gotErr, test.wantErr) - } - } } func TestNewStatement(t *testing.T) {