diff --git a/spanner/spannertest/inmem.go b/spanner/spannertest/inmem.go index eb4103c857e..41705abd12f 100644 --- a/spanner/spannertest/inmem.go +++ b/spanner/spannertest/inmem.go @@ -485,9 +485,19 @@ func (s *server) readTx(ctx context.Context, session string, tsel *spannerpb.Tra } func (s *server) ExecuteSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest) (*spannerpb.ResultSet, error) { - // Assume this is probably a DML statement. Queries tend to use ExecuteStreamingSql. + // Assume this is probably a DML statement or a ping from the session pool. + // Queries normally use ExecuteStreamingSql. // TODO: Expand this to support more things. + // If it is a single-use transaction we assume it is a query. + if req.Transaction.GetSelector() == nil || req.Transaction.GetSingleUse().GetReadOnly() != nil { + ri, err := s.executeQuery(req) + if err != nil { + return nil, err + } + return s.resultSet(ri) + } + obj, ok := req.Transaction.Selector.(*spannerpb.TransactionSelector_Id) if !ok { return nil, fmt.Errorf("unsupported transaction type %T", req.Transaction.Selector) @@ -527,15 +537,23 @@ func (s *server) ExecuteStreamingSql(req *spannerpb.ExecuteSqlRequest, stream sp } defer cleanup() + ri, err := s.executeQuery(req) + if err != nil { + return err + } + return s.readStream(stream.Context(), tx, stream.Send, ri) +} + +func (s *server) executeQuery(req *spannerpb.ExecuteSqlRequest) (ri rowIter, err error) { q, err := spansql.ParseQuery(req.Sql) if err != nil { // TODO: check what code the real Spanner returns here. - return status.Errorf(codes.InvalidArgument, "bad query: %v", err) + return nil, status.Errorf(codes.InvalidArgument, "bad query: %v", err) } params, err := parseQueryParams(req.GetParams(), req.ParamTypes) if err != nil { - return err + return nil, err } s.logf("Querying: %s", q.SQL()) @@ -543,11 +561,7 @@ func (s *server) ExecuteStreamingSql(req *spannerpb.ExecuteSqlRequest, stream sp s.logf(" ▹ %v", params) } - ri, err := s.db.Query(q, params) - if err != nil { - return err - } - return s.readStream(stream.Context(), tx, stream.Send, ri) + return s.db.Query(q, params) } // TODO: Read @@ -591,21 +605,39 @@ func (s *server) StreamingRead(req *spannerpb.ReadRequest, stream spannerpb.Span return s.readStream(stream.Context(), tx, stream.Send, ri) } -func (s *server) readStream(ctx context.Context, tx *transaction, send func(*spannerpb.PartialResultSet) error, ri rowIter) error { - // Build the result set metadata. - rsm := &spannerpb.ResultSetMetadata{ - RowType: &spannerpb.StructType{}, - // TODO: transaction info? +func (s *server) resultSet(ri rowIter) (*spannerpb.ResultSet, error) { + rsm, err := s.buildResultSetMetadata(ri) + if err != nil { + return nil, err } - for _, ci := range ri.Cols() { - st, err := spannerTypeFromType(ci.Type) - if err != nil { - return err + rs := &spannerpb.ResultSet{ + Metadata: rsm, + } + for { + row, err := ri.Next() + if err == io.EOF { + break + } else if err != nil { + return nil, err } - rsm.RowType.Fields = append(rsm.RowType.Fields, &spannerpb.StructType_Field{ - Name: string(ci.Name), - Type: st, - }) + + values := make([]*structpb.Value, len(row)) + for i, x := range row { + v, err := spannerValueFromValue(x) + if err != nil { + return nil, err + } + values[i] = v + } + rs.Rows = append(rs.Rows, &structpb.ListValue{Values: values}) + } + return rs, nil +} + +func (s *server) readStream(ctx context.Context, tx *transaction, send func(*spannerpb.PartialResultSet) error, ri rowIter) error { + rsm, err := s.buildResultSetMetadata(ri) + if err != nil { + return err } for { @@ -640,6 +672,25 @@ func (s *server) readStream(ctx context.Context, tx *transaction, send func(*spa return nil } +func (s *server) buildResultSetMetadata(ri rowIter) (*spannerpb.ResultSetMetadata, error) { + // Build the result set metadata. + rsm := &spannerpb.ResultSetMetadata{ + RowType: &spannerpb.StructType{}, + // TODO: transaction info? + } + for _, ci := range ri.Cols() { + st, err := spannerTypeFromType(ci.Type) + if err != nil { + return nil, err + } + rsm.RowType.Fields = append(rsm.RowType.Fields, &spannerpb.StructType_Field{ + Name: string(ci.Name), + Type: st, + }) + } + return rsm, nil +} + func (s *server) BeginTransaction(ctx context.Context, req *spannerpb.BeginTransactionRequest) (*spannerpb.Transaction, error) { //s.logf("BeginTransaction(%v)", req) diff --git a/spanner/spannertest/integration_test.go b/spanner/spannertest/integration_test.go index 3ab6a5821fc..6c20fed33e0 100644 --- a/spanner/spannertest/integration_test.go +++ b/spanner/spannertest/integration_test.go @@ -36,6 +36,7 @@ import ( "cloud.google.com/go/civil" "cloud.google.com/go/spanner" dbadmin "cloud.google.com/go/spanner/admin/database/apiv1" + v1 "cloud.google.com/go/spanner/apiv1" "google.golang.org/api/iterator" "google.golang.org/api/option" "google.golang.org/grpc" @@ -56,7 +57,7 @@ func dbName() string { return "projects/fake-proj/instances/fake-instance/databases/fake-db" } -func makeClient(t *testing.T) (*spanner.Client, *dbadmin.DatabaseAdminClient, func()) { +func makeClient(t *testing.T) (*spanner.Client, *dbadmin.DatabaseAdminClient, *v1.Client, func()) { // Despite the docs, this context is also used for auth, // so it needs to be long-lived. ctx := context.Background() @@ -73,7 +74,13 @@ func makeClient(t *testing.T) (*spanner.Client, *dbadmin.DatabaseAdminClient, fu client.Close() t.Fatalf("Connecting DB admin client: %v", err) } - return client, adminClient, func() { client.Close(); adminClient.Close() } + gapicClient, err := v1.NewClient(ctx, dialOpt) + if err != nil { + client.Close() + adminClient.Close() + t.Fatalf("Connecting Spanner generated client: %v", err) + } + return client, adminClient, gapicClient, func() { client.Close(); adminClient.Close(); gapicClient.Close() } } // Don't use SPANNER_EMULATOR_HOST because we need the raw connection for @@ -102,16 +109,23 @@ func makeClient(t *testing.T) (*spanner.Client, *dbadmin.DatabaseAdminClient, fu srv.Close() t.Fatalf("Connecting to in-memory fake DB admin: %v", err) } - return client, adminClient, func() { + gapicClient, err := v1.NewClient(ctx, option.WithGRPCConn(conn)) + if err != nil { + srv.Close() + t.Fatalf("Connecting to in-memory fake generated Spanner client: %v", err) + } + + return client, adminClient, gapicClient, func() { client.Close() adminClient.Close() + gapicClient.Close() conn.Close() srv.Close() } } func TestIntegration_SpannerBasics(t *testing.T) { - client, adminClient, cleanup := makeClient(t) + client, adminClient, generatedClient, cleanup := makeClient(t) defer cleanup() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) @@ -137,6 +151,34 @@ func TestIntegration_SpannerBasics(t *testing.T) { } it.Stop() + // Try to execute the equivalent of a session pool ping. + // This used to cause a panic as ExecuteSql did not expect any requests + // that would execute a query without a transaction selector. + // https://github.com/googleapis/google-cloud-go/issues/3639 + s, err := generatedClient.CreateSession(ctx, &spannerpb.CreateSessionRequest{Database: dbName()}) + if err != nil { + t.Fatalf("Creating session: %v", err) + } + rs, err := generatedClient.ExecuteSql(ctx, &spannerpb.ExecuteSqlRequest{ + Session: s.Name, + Sql: "SELECT 1", + }) + if err != nil { + t.Fatalf("Executing ping: %v", err) + } + if len(rs.Rows) != 1 { + t.Fatalf("Ping gave %v rows, want 1", len(rs.Rows)) + } + if len(rs.Rows[0].Values) != 1 { + t.Fatalf("Ping gave %v cols, want 1", len(rs.Rows[0].Values)) + } + if rs.Rows[0].Values[0].GetStringValue() != "1" { + t.Fatalf("Ping gave value %v, want '1'", rs.Rows[0].Values[0].GetStringValue()) + } + if err = generatedClient.DeleteSession(ctx, &spannerpb.DeleteSessionRequest{Name: s.Name}); err != nil { + t.Fatalf("Deleting session: %v", err) + } + // Drop any previous test table/index, and make a fresh one in a few stages. const tableName = "Characters" err = updateDDL(t, adminClient, "DROP INDEX AgeIndex") @@ -400,7 +442,7 @@ func TestIntegration_SpannerBasics(t *testing.T) { } func TestIntegration_ReadsAndQueries(t *testing.T) { - client, adminClient, cleanup := makeClient(t) + client, adminClient, _, cleanup := makeClient(t) defer cleanup() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)