diff --git a/spanner/spannertest/README.md b/spanner/spannertest/README.md index 6f820f97e52..d05f9d277c6 100644 --- a/spanner/spannertest/README.md +++ b/spanner/spannertest/README.md @@ -19,16 +19,16 @@ by ascending esotericism: - expression functions - more aggregation functions -- INSERT/UPDATE DML statements - SELECT HAVING - case insensitivity -- FULL JOIN +- FULL JOIN, multiple joins - alternate literal types (esp. strings) - STRUCT types - transaction simulation - expression type casting, coercion - subselects - FOREIGN KEY and CHECK constraints +- INSERT DML statements - set operations (UNION, INTERSECT, EXCEPT) - partition support - conditional expressions diff --git a/spanner/spannertest/db.go b/spanner/spannertest/db.go index 810bcecae27..d1a9ef20df0 100644 --- a/spanner/spannertest/db.go +++ b/spanner/spannertest/db.go @@ -999,6 +999,65 @@ func (d *database) Execute(stmt spansql.DMLStmt, params queryParams) (int, error i++ } return n, nil + case *spansql.Update: + t, err := d.table(stmt.Table) + if err != nil { + return 0, err + } + + t.mu.Lock() + defer t.mu.Unlock() + + ec := evalContext{ + cols: t.cols, + params: params, + } + + // Build parallel slices of destination column index and expressions to evaluate. + var dstIndex []int + var expr []spansql.Expr + for _, ui := range stmt.Items { + i, err := ec.resolveColumnIndex(ui.Column) + if err != nil { + return 0, err + } + // TODO: Enforce "A column can appear only once in the SET clause.". + if i < t.pkCols { + return 0, status.Errorf(codes.InvalidArgument, "cannot update primary key %s", ui.Column) + } + dstIndex = append(dstIndex, i) + expr = append(expr, ui.Value) + } + + n := 0 + values := make(row, len(stmt.Items)) // scratch space for new values + for i := 0; i < len(t.rows); i++ { + ec.row = t.rows[i] + b, err := ec.evalBoolExpr(stmt.Where) + if err != nil { + return 0, err + } + if b != nil && *b { + // Compute every update item. + for j := range dstIndex { + if expr[j] == nil { // DEFAULT + values[j] = nil + continue + } + v, err := ec.evalExpr(expr[j]) + if err != nil { + return 0, err + } + values[j] = v + } + // Write them to the row. + for j, v := range values { + t.rows[i][dstIndex[j]] = v + } + n++ + } + } + return n, nil } } diff --git a/spanner/spannertest/integration_test.go b/spanner/spannertest/integration_test.go index adbe329808e..f70b01f84b8 100644 --- a/spanner/spannertest/integration_test.go +++ b/spanner/spannertest/integration_test.go @@ -412,7 +412,7 @@ func TestIntegration_ReadsAndQueries(t *testing.T) { "Staff", "PlayerStats", "JoinA", "JoinB", "JoinC", "JoinD", "JoinE", "JoinF", - "SomeStrings", + "SomeStrings", "Updateable", } errc := make(chan error) for _, table := range allTables { @@ -618,6 +618,11 @@ func TestIntegration_ReadsAndQueries(t *testing.T) { `CREATE TABLE JoinF ( y INT64, z STRING(MAX) ) PRIMARY KEY (y, z)`, // Some other test tables. `CREATE TABLE SomeStrings ( i INT64, str STRING(MAX) ) PRIMARY KEY (i)`, + `CREATE TABLE Updateable ( + id INT64, + first STRING(MAX), + last STRING(MAX), + ) PRIMARY KEY (id)`, ) if err != nil { t.Fatalf("Creating sample tables: %v", err) @@ -661,11 +666,39 @@ func TestIntegration_ReadsAndQueries(t *testing.T) { spanner.Insert("SomeStrings", []string{"i", "str"}, []interface{}{1, "abar"}), spanner.Insert("SomeStrings", []string{"i", "str"}, []interface{}{2, nil}), spanner.Insert("SomeStrings", []string{"i", "str"}, []interface{}{3, "bbar"}), + + spanner.Insert("Updateable", []string{"id", "first", "last"}, []interface{}{0, "joe", nil}), + spanner.Insert("Updateable", []string{"id", "first", "last"}, []interface{}{1, "doe", "joan"}), + spanner.Insert("Updateable", []string{"id", "first", "last"}, []interface{}{2, "wong", "wong"}), }) if err != nil { t.Fatalf("Inserting sample data: %v", err) } + // Perform UPDATE DML; the results are checked later on. + n = 0 + _, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *spanner.ReadWriteTransaction) error { + for _, u := range []string{ + `UPDATE Updateable SET last = "bloggs" WHERE id = 0`, + `UPDATE Updateable SET first = last, last = first WHERE id = 1`, + `UPDATE Updateable SET last = DEFAULT WHERE id = 2`, + `UPDATE Updateable SET first = "noname" WHERE id = 3`, // no id=3 + } { + nr, err := tx.Update(ctx, spanner.NewStatement(u)) + if err != nil { + return err + } + n += nr + } + return nil + }) + if err != nil { + t.Fatalf("Updating with DML: %v", err) + } + if n != 3 { + t.Errorf("Updating with DML affected %d rows, want 3", n) + } + // Do some complex queries. tests := []struct { q string @@ -976,6 +1009,16 @@ func TestIntegration_ReadsAndQueries(t *testing.T) { {int64(4), nil, "p"}, }, }, + // Check the output of the UPDATE DML. + { + `SELECT id, first, last FROM Updateable ORDER BY id`, + nil, + [][]interface{}{ + {int64(0), "joe", "bloggs"}, + {int64(1), "joan", "doe"}, + {int64(2), "wong", nil}, + }, + }, // Regression test for aggregating no rows; it used to return an empty row. // https://github.com/googleapis/google-cloud-go/issues/2793 {