From 8eede8411a5521f45a5c3f8091c42b3c5407ea90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Thu, 4 Feb 2021 09:04:03 +0100 Subject: [PATCH] fix(spanner/spannertest): support queries in ExecuteSql (#3640) Normal queries from the Spanner client use the ExecuteStreamingSql method, while DML statements use ExecuteSql. This distinction was also built into spannertest where ExecuteSql would only support DML statements and required a transaction to be specified. The session pool however uses ExecuteSql to execute a simple `SELECT 1` query without specifying any transaction. This would cause a nil pointer dereference. This PR introduces support for queries in the ExecuteSql method. The current logic assumes that the statement is a query if the transaction is a single- use read-only transaction. Fixes #3639 --- spanner/spannertest/inmem.go | 93 +++++++++++++++++++------ spanner/spannertest/integration_test.go | 52 ++++++++++++-- 2 files changed, 119 insertions(+), 26 deletions(-) diff --git a/spanner/spannertest/inmem.go b/spanner/spannertest/inmem.go index eb4103c857e..41705abd12f 100644 --- a/spanner/spannertest/inmem.go +++ b/spanner/spannertest/inmem.go @@ -485,9 +485,19 @@ func (s *server) readTx(ctx context.Context, session string, tsel *spannerpb.Tra } func (s *server) ExecuteSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest) (*spannerpb.ResultSet, error) { - // Assume this is probably a DML statement. Queries tend to use ExecuteStreamingSql. + // Assume this is probably a DML statement or a ping from the session pool. + // Queries normally use ExecuteStreamingSql. // TODO: Expand this to support more things. + // If it is a single-use transaction we assume it is a query. + if req.Transaction.GetSelector() == nil || req.Transaction.GetSingleUse().GetReadOnly() != nil { + ri, err := s.executeQuery(req) + if err != nil { + return nil, err + } + return s.resultSet(ri) + } + obj, ok := req.Transaction.Selector.(*spannerpb.TransactionSelector_Id) if !ok { return nil, fmt.Errorf("unsupported transaction type %T", req.Transaction.Selector) @@ -527,15 +537,23 @@ func (s *server) ExecuteStreamingSql(req *spannerpb.ExecuteSqlRequest, stream sp } defer cleanup() + ri, err := s.executeQuery(req) + if err != nil { + return err + } + return s.readStream(stream.Context(), tx, stream.Send, ri) +} + +func (s *server) executeQuery(req *spannerpb.ExecuteSqlRequest) (ri rowIter, err error) { q, err := spansql.ParseQuery(req.Sql) if err != nil { // TODO: check what code the real Spanner returns here. - return status.Errorf(codes.InvalidArgument, "bad query: %v", err) + return nil, status.Errorf(codes.InvalidArgument, "bad query: %v", err) } params, err := parseQueryParams(req.GetParams(), req.ParamTypes) if err != nil { - return err + return nil, err } s.logf("Querying: %s", q.SQL()) @@ -543,11 +561,7 @@ func (s *server) ExecuteStreamingSql(req *spannerpb.ExecuteSqlRequest, stream sp s.logf(" ▹ %v", params) } - ri, err := s.db.Query(q, params) - if err != nil { - return err - } - return s.readStream(stream.Context(), tx, stream.Send, ri) + return s.db.Query(q, params) } // TODO: Read @@ -591,21 +605,39 @@ func (s *server) StreamingRead(req *spannerpb.ReadRequest, stream spannerpb.Span return s.readStream(stream.Context(), tx, stream.Send, ri) } -func (s *server) readStream(ctx context.Context, tx *transaction, send func(*spannerpb.PartialResultSet) error, ri rowIter) error { - // Build the result set metadata. - rsm := &spannerpb.ResultSetMetadata{ - RowType: &spannerpb.StructType{}, - // TODO: transaction info? +func (s *server) resultSet(ri rowIter) (*spannerpb.ResultSet, error) { + rsm, err := s.buildResultSetMetadata(ri) + if err != nil { + return nil, err } - for _, ci := range ri.Cols() { - st, err := spannerTypeFromType(ci.Type) - if err != nil { - return err + rs := &spannerpb.ResultSet{ + Metadata: rsm, + } + for { + row, err := ri.Next() + if err == io.EOF { + break + } else if err != nil { + return nil, err } - rsm.RowType.Fields = append(rsm.RowType.Fields, &spannerpb.StructType_Field{ - Name: string(ci.Name), - Type: st, - }) + + values := make([]*structpb.Value, len(row)) + for i, x := range row { + v, err := spannerValueFromValue(x) + if err != nil { + return nil, err + } + values[i] = v + } + rs.Rows = append(rs.Rows, &structpb.ListValue{Values: values}) + } + return rs, nil +} + +func (s *server) readStream(ctx context.Context, tx *transaction, send func(*spannerpb.PartialResultSet) error, ri rowIter) error { + rsm, err := s.buildResultSetMetadata(ri) + if err != nil { + return err } for { @@ -640,6 +672,25 @@ func (s *server) readStream(ctx context.Context, tx *transaction, send func(*spa return nil } +func (s *server) buildResultSetMetadata(ri rowIter) (*spannerpb.ResultSetMetadata, error) { + // Build the result set metadata. + rsm := &spannerpb.ResultSetMetadata{ + RowType: &spannerpb.StructType{}, + // TODO: transaction info? + } + for _, ci := range ri.Cols() { + st, err := spannerTypeFromType(ci.Type) + if err != nil { + return nil, err + } + rsm.RowType.Fields = append(rsm.RowType.Fields, &spannerpb.StructType_Field{ + Name: string(ci.Name), + Type: st, + }) + } + return rsm, nil +} + func (s *server) BeginTransaction(ctx context.Context, req *spannerpb.BeginTransactionRequest) (*spannerpb.Transaction, error) { //s.logf("BeginTransaction(%v)", req) diff --git a/spanner/spannertest/integration_test.go b/spanner/spannertest/integration_test.go index 3ab6a5821fc..6c20fed33e0 100644 --- a/spanner/spannertest/integration_test.go +++ b/spanner/spannertest/integration_test.go @@ -36,6 +36,7 @@ import ( "cloud.google.com/go/civil" "cloud.google.com/go/spanner" dbadmin "cloud.google.com/go/spanner/admin/database/apiv1" + v1 "cloud.google.com/go/spanner/apiv1" "google.golang.org/api/iterator" "google.golang.org/api/option" "google.golang.org/grpc" @@ -56,7 +57,7 @@ func dbName() string { return "projects/fake-proj/instances/fake-instance/databases/fake-db" } -func makeClient(t *testing.T) (*spanner.Client, *dbadmin.DatabaseAdminClient, func()) { +func makeClient(t *testing.T) (*spanner.Client, *dbadmin.DatabaseAdminClient, *v1.Client, func()) { // Despite the docs, this context is also used for auth, // so it needs to be long-lived. ctx := context.Background() @@ -73,7 +74,13 @@ func makeClient(t *testing.T) (*spanner.Client, *dbadmin.DatabaseAdminClient, fu client.Close() t.Fatalf("Connecting DB admin client: %v", err) } - return client, adminClient, func() { client.Close(); adminClient.Close() } + gapicClient, err := v1.NewClient(ctx, dialOpt) + if err != nil { + client.Close() + adminClient.Close() + t.Fatalf("Connecting Spanner generated client: %v", err) + } + return client, adminClient, gapicClient, func() { client.Close(); adminClient.Close(); gapicClient.Close() } } // Don't use SPANNER_EMULATOR_HOST because we need the raw connection for @@ -102,16 +109,23 @@ func makeClient(t *testing.T) (*spanner.Client, *dbadmin.DatabaseAdminClient, fu srv.Close() t.Fatalf("Connecting to in-memory fake DB admin: %v", err) } - return client, adminClient, func() { + gapicClient, err := v1.NewClient(ctx, option.WithGRPCConn(conn)) + if err != nil { + srv.Close() + t.Fatalf("Connecting to in-memory fake generated Spanner client: %v", err) + } + + return client, adminClient, gapicClient, func() { client.Close() adminClient.Close() + gapicClient.Close() conn.Close() srv.Close() } } func TestIntegration_SpannerBasics(t *testing.T) { - client, adminClient, cleanup := makeClient(t) + client, adminClient, generatedClient, cleanup := makeClient(t) defer cleanup() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) @@ -137,6 +151,34 @@ func TestIntegration_SpannerBasics(t *testing.T) { } it.Stop() + // Try to execute the equivalent of a session pool ping. + // This used to cause a panic as ExecuteSql did not expect any requests + // that would execute a query without a transaction selector. + // https://github.com/googleapis/google-cloud-go/issues/3639 + s, err := generatedClient.CreateSession(ctx, &spannerpb.CreateSessionRequest{Database: dbName()}) + if err != nil { + t.Fatalf("Creating session: %v", err) + } + rs, err := generatedClient.ExecuteSql(ctx, &spannerpb.ExecuteSqlRequest{ + Session: s.Name, + Sql: "SELECT 1", + }) + if err != nil { + t.Fatalf("Executing ping: %v", err) + } + if len(rs.Rows) != 1 { + t.Fatalf("Ping gave %v rows, want 1", len(rs.Rows)) + } + if len(rs.Rows[0].Values) != 1 { + t.Fatalf("Ping gave %v cols, want 1", len(rs.Rows[0].Values)) + } + if rs.Rows[0].Values[0].GetStringValue() != "1" { + t.Fatalf("Ping gave value %v, want '1'", rs.Rows[0].Values[0].GetStringValue()) + } + if err = generatedClient.DeleteSession(ctx, &spannerpb.DeleteSessionRequest{Name: s.Name}); err != nil { + t.Fatalf("Deleting session: %v", err) + } + // Drop any previous test table/index, and make a fresh one in a few stages. const tableName = "Characters" err = updateDDL(t, adminClient, "DROP INDEX AgeIndex") @@ -400,7 +442,7 @@ func TestIntegration_SpannerBasics(t *testing.T) { } func TestIntegration_ReadsAndQueries(t *testing.T) { - client, adminClient, cleanup := makeClient(t) + client, adminClient, _, cleanup := makeClient(t) defer cleanup() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)