diff --git a/bigquery/storage/managedwriter/client.go b/bigquery/storage/managedwriter/client.go index 014cd7dcd29..847aca65abe 100644 --- a/bigquery/storage/managedwriter/client.go +++ b/bigquery/storage/managedwriter/client.go @@ -25,6 +25,7 @@ import ( "google.golang.org/api/option" storagepb "google.golang.org/genproto/googleapis/cloud/bigquery/storage/v1beta2" "google.golang.org/grpc" + "google.golang.org/grpc/metadata" ) // Client is a managed BigQuery Storage write client scoped to a single project. @@ -81,8 +82,11 @@ func (c *Client) buildManagedStream(ctx context.Context, streamFunc streamClient c: c, ctx: ctx, cancel: cancel, - open: func() (storagepb.BigQueryWrite_AppendRowsClient, error) { - arc, err := streamFunc(ctx, gax.WithGRPCOptions(grpc.MaxCallRecvMsgSize(10*1024*1024))) + open: func(streamID string) (storagepb.BigQueryWrite_AppendRowsClient, error) { + arc, err := streamFunc( + // Bidi Streaming doesn't append stream ID as request metadata, so we must inject it manually. + metadata.AppendToOutgoingContext(ctx, "x-goog-request-params", fmt.Sprintf("write_stream=%s", streamID)), + gax.WithGRPCOptions(grpc.MaxCallRecvMsgSize(10*1024*1024))) if err != nil { return nil, err } diff --git a/bigquery/storage/managedwriter/integration_test.go b/bigquery/storage/managedwriter/integration_test.go index cfc60b455b1..b685e64440f 100644 --- a/bigquery/storage/managedwriter/integration_test.go +++ b/bigquery/storage/managedwriter/integration_test.go @@ -82,9 +82,9 @@ func getTestClients(ctx context.Context, t *testing.T, opts ...option.ClientOpti } // setupTestDataset generates a unique dataset for testing, and a cleanup that can be deferred. -func setupTestDataset(ctx context.Context, t *testing.T, bqc *bigquery.Client) (ds *bigquery.Dataset, cleanup func(), err error) { +func setupTestDataset(ctx context.Context, t *testing.T, bqc *bigquery.Client, location string) (ds *bigquery.Dataset, cleanup func(), err error) { dataset := bqc.Dataset(datasetIDs.New()) - if err := dataset.Create(ctx, nil); err != nil { + if err := dataset.Create(ctx, &bigquery.DatasetMetadata{Location: location}); err != nil { return nil, nil, err } return dataset, func() { @@ -117,7 +117,7 @@ func TestIntegration_ManagedWriter(t *testing.T) { defer mwClient.Close() defer bqClient.Close() - dataset, cleanup, err := setupTestDataset(context.Background(), t, bqClient) + dataset, cleanup, err := setupTestDataset(context.Background(), t, bqClient, "us-east1") if err != nil { t.Fatalf("failed to init test dataset: %v", err) } diff --git a/bigquery/storage/managedwriter/managed_stream.go b/bigquery/storage/managedwriter/managed_stream.go index e29d54395c4..1064d49314c 100644 --- a/bigquery/storage/managedwriter/managed_stream.go +++ b/bigquery/storage/managedwriter/managed_stream.go @@ -79,7 +79,7 @@ type ManagedStream struct { // aspects of the stream client ctx context.Context // retained context for the stream cancel context.CancelFunc - open func() (storagepb.BigQueryWrite_AppendRowsClient, error) // how we get a new connection + open func(streamID string) (storagepb.BigQueryWrite_AppendRowsClient, error) // how we get a new connection mu sync.Mutex arc *storagepb.BigQueryWrite_AppendRowsClient // current stream connection @@ -198,7 +198,11 @@ func (ms *ManagedStream) openWithRetry() (storagepb.BigQueryWrite_AppendRowsClie r := defaultRetryer{} for { recordStat(ms.ctx, AppendClientOpenCount, 1) - arc, err := ms.open() + streamID := "" + if ms.streamSettings != nil { + streamID = ms.streamSettings.streamID + } + arc, err := ms.open(streamID) bo, shouldRetry := r.Retry(err) if err != nil && shouldRetry { recordStat(ms.ctx, AppendClientOpenRetryCount, 1) diff --git a/bigquery/storage/managedwriter/managed_stream_test.go b/bigquery/storage/managedwriter/managed_stream_test.go index 23976f0b1aa..9df35f919f2 100644 --- a/bigquery/storage/managedwriter/managed_stream_test.go +++ b/bigquery/storage/managedwriter/managed_stream_test.go @@ -55,7 +55,7 @@ func TestManagedStream_OpenWithRetry(t *testing.T) { for _, tc := range testCases { ms := &ManagedStream{ ctx: context.Background(), - open: func() (storagepb.BigQueryWrite_AppendRowsClient, error) { + open: func(s string) (storagepb.BigQueryWrite_AppendRowsClient, error) { if len(tc.errors) == 0 { panic("out of errors") } @@ -107,7 +107,7 @@ func TestManagedStream_FirstAppendBehavior(t *testing.T) { ms := &ManagedStream{ ctx: ctx, - open: func() (storagepb.BigQueryWrite_AppendRowsClient, error) { + open: func(s string) (storagepb.BigQueryWrite_AppendRowsClient, error) { testARC.openCount = testARC.openCount + 1 return testARC, nil },