diff --git a/bigquery/storage/managedwriter/appendresult.go b/bigquery/storage/managedwriter/appendresult.go index 2570d80d7a0..0854a407385 100644 --- a/bigquery/storage/managedwriter/appendresult.go +++ b/bigquery/storage/managedwriter/appendresult.go @@ -106,7 +106,7 @@ func newPendingWrite(appends [][]byte, offset int64) *pendingWrite { // markDone propagates finalization of an append request to associated // AppendResult references. -func (pw *pendingWrite) markDone(startOffset int64, err error) { +func (pw *pendingWrite) markDone(startOffset int64, err error, fc *flowController) { curOffset := startOffset for _, ar := range pw.results { if err != nil { @@ -124,4 +124,9 @@ func (pw *pendingWrite) markDone(startOffset int64, err error) { } // Clear the reference to the request. pw.request = nil + // if there's a flow controller, signal release. The only time this should be nil is when + // encountering issues with flow control during enqueuing the initial request. + if fc != nil { + fc.release(pw.reqSize) + } } diff --git a/bigquery/storage/managedwriter/appendresult_test.go b/bigquery/storage/managedwriter/appendresult_test.go index c8beeafa734..7ef86ea995f 100644 --- a/bigquery/storage/managedwriter/appendresult_test.go +++ b/bigquery/storage/managedwriter/appendresult_test.go @@ -45,7 +45,7 @@ func TestPendingWrite(t *testing.T) { if pending.request.GetOffset() != nil { t.Errorf("request should have no offset, but is present: %q", pending.request.GetOffset().GetValue()) } - pending.markDone(NoStreamOffset, nil) + pending.markDone(NoStreamOffset, nil, nil) for k, ar := range pending.results { if ar.offset != NoStreamOffset { t.Errorf("mismatch on completed AppendResult(%d) without offset: got %d want %d", k, ar.offset, NoStreamOffset) @@ -90,7 +90,7 @@ func TestPendingWrite(t *testing.T) { // verify completion behavior reportedOffset := int64(101) wantErr := fmt.Errorf("foo") - pending.markDone(reportedOffset, wantErr) + pending.markDone(reportedOffset, wantErr, nil) if pending.request != nil { t.Errorf("expected request to be cleared, is present: %#v", pending.request) diff --git a/bigquery/storage/managedwriter/client.go b/bigquery/storage/managedwriter/client.go index d8f60ed933b..4e6b66297b5 100644 --- a/bigquery/storage/managedwriter/client.go +++ b/bigquery/storage/managedwriter/client.go @@ -119,10 +119,13 @@ func (c *Client) buildManagedStream(ctx context.Context, streamFunc streamClient streamName = resp.GetName() } ms.streamSettings.streamID = streamName - // TODO(followup CLs): instantiate an appendstream client, flow controller, etc. } } - + if ms.streamSettings != nil { + ms.fc = newFlowController(ms.streamSettings.MaxInflightRequests, ms.streamSettings.MaxInflightBytes) + } else { + ms.fc = newFlowController(0, 0) + } return ms, nil } diff --git a/bigquery/storage/managedwriter/integration_test.go b/bigquery/storage/managedwriter/integration_test.go index 4b276b2e7aa..843d4f90db3 100644 --- a/bigquery/storage/managedwriter/integration_test.go +++ b/bigquery/storage/managedwriter/integration_test.go @@ -177,7 +177,7 @@ func TestIntegration_ManagedWriter_BasicOperation(t *testing.T) { t.Errorf("failed to marshal message %d: %v", k, err) } data := [][]byte{b} - results, err = ms.AppendRows(data, NoStreamOffset) + results, err = ms.AppendRows(ctx, data, NoStreamOffset) if err != nil { t.Errorf("single-row append %d failed: %v", k, err) } @@ -195,7 +195,7 @@ func TestIntegration_ManagedWriter_BasicOperation(t *testing.T) { t.Errorf("failed to marshal message %d: %v", k, err) } data := append(data, b) - results, err = ms.AppendRows(data, NoStreamOffset) + results, err = ms.AppendRows(ctx, data, NoStreamOffset) if err != nil { t.Errorf("grouped-row append failed: %v", err) } diff --git a/bigquery/storage/managedwriter/managed_stream.go b/bigquery/storage/managedwriter/managed_stream.go index 296e8b491fd..e1a526a6051 100644 --- a/bigquery/storage/managedwriter/managed_stream.go +++ b/bigquery/storage/managedwriter/managed_stream.go @@ -71,6 +71,7 @@ type ManagedStream struct { schemaDescriptor *descriptorpb.DescriptorProto destinationTable string c *Client + fc *flowController // aspects of the stream client ctx context.Context // retained context for the stream @@ -204,7 +205,7 @@ func (ms *ManagedStream) openWithRetry() (storagepb.BigQueryWrite_AppendRowsClie // The channel relationship with its ARC is 1:1. If we get a new ARC, create a new chan // and fire up the associated receive processor. ch := make(chan *pendingWrite) - go recvProcessor(ms.ctx, arc, ch) + go recvProcessor(ms.ctx, arc, ms.fc, ch) // Also, replace the sync.Once for setting up a new stream, as we need to do "special" work // for every new connection. ms.streamSetup = new(sync.Once) @@ -293,11 +294,17 @@ func (ms *ManagedStream) Close() error { } // AppendRows sends the append requests to the service, and returns one AppendResult per row. -func (ms *ManagedStream) AppendRows(data [][]byte, offset int64) ([]*AppendResult, error) { +func (ms *ManagedStream) AppendRows(ctx context.Context, data [][]byte, offset int64) ([]*AppendResult, error) { pw := newPendingWrite(data, offset) + // check flow control + if err := ms.fc.acquire(ctx, pw.reqSize); err != nil { + // in this case, we didn't acquire, so don't pass the flow controller reference to avoid a release. + pw.markDone(NoStreamOffset, err, nil) + } + // proceed to call if err := ms.append(pw); err != nil { - // pending write is DOA, mark it done. - pw.markDone(NoStreamOffset, err) + // pending write is DOA. + pw.markDone(NoStreamOffset, err, ms.fc) return nil, err } return pw.results, nil @@ -307,7 +314,7 @@ func (ms *ManagedStream) AppendRows(data [][]byte, offset int64) ([]*AppendResul // // The receive processor only deals with a single instance of a connection/channel, and thus should never interact // with the mutex lock. -func recvProcessor(ctx context.Context, arc storagepb.BigQueryWrite_AppendRowsClient, ch <-chan *pendingWrite) { +func recvProcessor(ctx context.Context, arc storagepb.BigQueryWrite_AppendRowsClient, fc *flowController, ch <-chan *pendingWrite) { // TODO: We'd like to re-send requests that are in an ambiguous state due to channel errors. For now, we simply // ensure that pending writes get acknowledged with a terminal state. for { @@ -319,7 +326,7 @@ func recvProcessor(ctx context.Context, arc storagepb.BigQueryWrite_AppendRowsCl if !ok { return } - pw.markDone(NoStreamOffset, ctx.Err()) + pw.markDone(NoStreamOffset, ctx.Err(), fc) } case nextWrite, ok := <-ch: if !ok { @@ -330,19 +337,20 @@ func recvProcessor(ctx context.Context, arc storagepb.BigQueryWrite_AppendRowsCl // block until we get a corresponding response or err from stream. resp, err := arc.Recv() if err != nil { - nextWrite.markDone(NoStreamOffset, err) + nextWrite.markDone(NoStreamOffset, err, fc) } if status := resp.GetError(); status != nil { - nextWrite.markDone(NoStreamOffset, grpcstatus.ErrorProto(status)) + fc.release(nextWrite.reqSize) + nextWrite.markDone(NoStreamOffset, grpcstatus.ErrorProto(status), fc) continue } success := resp.GetAppendResult() off := success.GetOffset() if off != nil { - nextWrite.markDone(off.GetValue(), nil) + nextWrite.markDone(off.GetValue(), nil, fc) } - nextWrite.markDone(NoStreamOffset, nil) + nextWrite.markDone(NoStreamOffset, nil, fc) } } } diff --git a/bigquery/storage/managedwriter/managed_stream_test.go b/bigquery/storage/managedwriter/managed_stream_test.go index ba9f12eb5aa..fac9139ab20 100644 --- a/bigquery/storage/managedwriter/managed_stream_test.go +++ b/bigquery/storage/managedwriter/managed_stream_test.go @@ -87,6 +87,8 @@ func TestManagedStream_OpenWithRetry(t *testing.T) { func TestManagedStream_FirstAppendBehavior(t *testing.T) { + ctx := context.Background() + var testARC *testAppendRowsClient testARC = &testAppendRowsClient{ recvF: func() (*storagepb.AppendRowsResponse, error) { @@ -104,12 +106,13 @@ func TestManagedStream_FirstAppendBehavior(t *testing.T) { } ms := &ManagedStream{ - ctx: context.Background(), + ctx: ctx, open: func() (storagepb.BigQueryWrite_AppendRowsClient, error) { testARC.openCount = testARC.openCount + 1 return testARC, nil }, streamSettings: defaultStreamSettings(), + fc: newFlowController(0, 0), } ms.streamSettings.streamID = "FOO" ms.streamSettings.TracePrefix = "TRACE" @@ -123,7 +126,7 @@ func TestManagedStream_FirstAppendBehavior(t *testing.T) { wantReqs := 3 for i := 0; i < wantReqs; i++ { - _, err := ms.AppendRows(fakeData, NoStreamOffset) + _, err := ms.AppendRows(ctx, fakeData, NoStreamOffset) if err != nil { t.Errorf("AppendRows; %v", err) }