diff --git a/spanner/batch.go b/spanner/batch.go index 0e38768ae1c..82346bfcebd 100644 --- a/spanner/batch.go +++ b/spanner/batch.go @@ -94,9 +94,25 @@ func (t *BatchReadOnlyTransaction) PartitionRead(ctx context.Context, table stri return t.PartitionReadUsingIndex(ctx, table, "", keys, columns, opt) } +// PartitionReadWithOptions returns a list of Partitions that can be used to +// read rows from the database. These partitions can be executed across multiple +// processes, even across different machines. The partition size and count hints +// can be configured using PartitionOptions. Pass a ReadOptions to modify the +// read operation. +func (t *BatchReadOnlyTransaction) PartitionReadWithOptions(ctx context.Context, table string, keys KeySet, columns []string, opt PartitionOptions, readOptions ReadOptions) ([]*Partition, error) { + return t.PartitionReadUsingIndexWithOptions(ctx, table, "", keys, columns, opt, readOptions) +} + // PartitionReadUsingIndex returns a list of Partitions that can be used to read // rows from the database using an index. func (t *BatchReadOnlyTransaction) PartitionReadUsingIndex(ctx context.Context, table, index string, keys KeySet, columns []string, opt PartitionOptions) ([]*Partition, error) { + return t.PartitionReadUsingIndexWithOptions(ctx, table, index, keys, columns, opt, ReadOptions{}) +} + +// PartitionReadUsingIndexWithOptions returns a list of Partitions that can be +// used to read rows from the database using an index. Pass a ReadOptions to +// modify the read operation. +func (t *BatchReadOnlyTransaction) PartitionReadUsingIndexWithOptions(ctx context.Context, table, index string, keys KeySet, columns []string, opt PartitionOptions, readOptions ReadOptions) ([]*Partition, error) { sh, ts, err := t.acquire(ctx) if err != nil { return nil, err @@ -123,12 +139,13 @@ func (t *BatchReadOnlyTransaction) PartitionReadUsingIndex(ctx context.Context, }) // Prepare ReadRequest. req := &sppb.ReadRequest{ - Session: sid, - Transaction: ts, - Table: table, - Index: index, - Columns: columns, - KeySet: kset, + Session: sid, + Transaction: ts, + Table: table, + Index: index, + Columns: columns, + KeySet: kset, + RequestOptions: createRequestOptions(readOptions.Priority, readOptions.RequestTag, ""), } // Generate partitions. for _, p := range resp.GetPartitions() { @@ -177,12 +194,13 @@ func (t *BatchReadOnlyTransaction) partitionQuery(ctx context.Context, statement // prepare ExecuteSqlRequest r := &sppb.ExecuteSqlRequest{ - Session: sid, - Transaction: ts, - Sql: statement.SQL, - Params: params, - ParamTypes: paramTypes, - QueryOptions: qOpts.Options, + Session: sid, + Transaction: ts, + Sql: statement.SQL, + Params: params, + ParamTypes: paramTypes, + QueryOptions: qOpts.Options, + RequestOptions: createRequestOptions(qOpts.Priority, qOpts.RequestTag, ""), } // generate Partitions @@ -270,6 +288,7 @@ func (t *BatchReadOnlyTransaction) Execute(ctx context.Context, p *Partition) *R Columns: p.rreq.Columns, KeySet: p.rreq.KeySet, PartitionToken: p.pt, + RequestOptions: p.rreq.RequestOptions, ResumeToken: resumeToken, }) } @@ -283,6 +302,7 @@ func (t *BatchReadOnlyTransaction) Execute(ctx context.Context, p *Partition) *R ParamTypes: p.qreq.ParamTypes, QueryOptions: p.qreq.QueryOptions, PartitionToken: p.pt, + RequestOptions: p.qreq.RequestOptions, ResumeToken: resumeToken, }) } diff --git a/spanner/client_test.go b/spanner/client_test.go index 099f1825c39..faae43dbe84 100644 --- a/spanner/client_test.go +++ b/spanner/client_test.go @@ -2665,6 +2665,54 @@ func TestClient_Apply_Tagging(t *testing.T) { checkCommitForExpectedRequestOptions(t, server.TestSpanner, sppb.RequestOptions{TransactionTag: "tx-tag"}) } +func TestClient_PartitionQuery_RequestOptions(t *testing.T) { + t.Parallel() + + server, client, teardown := setupMockedTestServer(t) + defer teardown() + + for _, qo := range []QueryOptions{ + {}, + {Priority: sppb.RequestOptions_PRIORITY_LOW}, + {RequestTag: "batch-query-tag"}, + {Priority: sppb.RequestOptions_PRIORITY_MEDIUM, RequestTag: "batch-query-with-medium-prio"}, + } { + ctx := context.Background() + txn, _ := client.BatchReadOnlyTransaction(ctx, StrongRead()) + partitions, _ := txn.PartitionQueryWithOptions(ctx, NewStatement(SelectFooFromBar), PartitionOptions{MaxPartitions: 10}, qo) + for _, p := range partitions { + iter := txn.Execute(ctx, p) + iter.Next() + iter.Stop() + } + checkRequestsForExpectedRequestOptions(t, server.TestSpanner, len(partitions), sppb.RequestOptions{RequestTag: qo.RequestTag, Priority: qo.Priority}) + } +} + +func TestClient_PartitionRead_RequestOptions(t *testing.T) { + t.Parallel() + + server, client, teardown := setupMockedTestServer(t) + defer teardown() + + for _, ro := range []ReadOptions{ + {}, + {Priority: sppb.RequestOptions_PRIORITY_LOW}, + {RequestTag: "batch-read-tag"}, + {Priority: sppb.RequestOptions_PRIORITY_MEDIUM, RequestTag: "batch-read-with-medium-prio"}, + } { + ctx := context.Background() + txn, _ := client.BatchReadOnlyTransaction(ctx, StrongRead()) + partitions, _ := txn.PartitionReadWithOptions(ctx, "Albums", KeySets(Key{"foo"}), []string{"SingerId", "AlbumId", "AlbumTitle"}, PartitionOptions{MaxPartitions: 10}, ro) + for _, p := range partitions { + iter := txn.Execute(ctx, p) + iter.Next() + iter.Stop() + } + checkRequestsForExpectedRequestOptions(t, server.TestSpanner, len(partitions), sppb.RequestOptions{RequestTag: ro.RequestTag, Priority: ro.Priority}) + } +} + func checkRequestsForExpectedRequestOptions(t *testing.T, server InMemSpannerServer, reqCount int, ro sppb.RequestOptions) { reqs := drainRequestsFromServer(server) reqOptions := []*sppb.RequestOptions{} @@ -2686,12 +2734,10 @@ func checkRequestsForExpectedRequestOptions(t *testing.T, server InMemSpannerSer } for _, opts := range reqOptions { - var got sppb.RequestOptions_Priority - if opts != nil { - got = opts.Priority + if opts == nil { + opts = &sppb.RequestOptions{} } - want := ro.Priority - if got != want { + if got, want := opts.Priority, ro.Priority; got != want { t.Fatalf("Request priority mismatch\nGot: %v\nWant: %v", got, want) } if got, want := opts.RequestTag, ro.RequestTag; got != want { diff --git a/spanner/internal/testutil/inmem_spanner_server.go b/spanner/internal/testutil/inmem_spanner_server.go index c2b1b8894e4..ac6f9bdfd04 100644 --- a/spanner/internal/testutil/inmem_spanner_server.go +++ b/spanner/internal/testutil/inmem_spanner_server.go @@ -1057,14 +1057,17 @@ func (s *inMemSpannerServer) PartitionQuery(ctx context.Context, req *spannerpb. } func (s *inMemSpannerServer) PartitionRead(ctx context.Context, req *spannerpb.PartitionReadRequest) (*spannerpb.PartitionResponse, error) { - s.mu.Lock() - if s.stopped { - s.mu.Unlock() - return nil, gstatus.Error(codes.Unavailable, "server has been stopped") - } - s.receivedRequests <- req - s.mu.Unlock() - return nil, gstatus.Error(codes.Unimplemented, "Method not yet implemented") + return s.PartitionQuery(ctx, &spannerpb.PartitionQueryRequest{ + Session: req.Session, + Transaction: req.Transaction, + PartitionOptions: req.PartitionOptions, + // KeySet is currently ignored. + Sql: fmt.Sprintf( + "SELECT %s FROM %s", + strings.Join(req.Columns, ", "), + req.Table, + ), + }) } // EncodeResumeToken return mock resume token encoding for an uint64 integer.