Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(pubsublite): prevent subscriber flow control token races #4060

Merged
merged 4 commits into from May 5, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 12 additions & 2 deletions pubsublite/internal/test/verifier.go
Expand Up @@ -48,8 +48,8 @@ func newBarrier() *Barrier {
}
}

// Release should be called by the test.
func (b *Barrier) Release() {
// ReleaseAfter releases the barrier, after invoking f provided by the test.
func (b *Barrier) ReleaseAfter(f func()) {
// Wait for the server to reach the barrier.
select {
case <-time.After(blockWaitTimeout):
Expand All @@ -58,10 +58,20 @@ func (b *Barrier) Release() {
case <-b.serverBlock:
}

// Run any test-specific code.
if f != nil {
f()
}

// Then close the client block.
close(b.clientBlock)
}

// Release should be called by the test.
func (b *Barrier) Release() {
b.ReleaseAfter(nil)
}

func (b *Barrier) serverWait() error {
if b.err != nil {
return b.err
Expand Down
24 changes: 16 additions & 8 deletions pubsublite/internal/wire/subscriber.go
Expand Up @@ -123,12 +123,13 @@ type subscribeStream struct {
metadata pubsubMetadata

// Fields below must be guarded with mu.
messageQueue *messageDeliveryQueue
stream *retryableStream
offsetTracker subscriberOffsetTracker
flowControl flowControlBatcher
pollFlowControl *periodicTask
seekInFlight bool
messageQueue *messageDeliveryQueue
stream *retryableStream
offsetTracker subscriberOffsetTracker
flowControl flowControlBatcher
pollFlowControl *periodicTask
seekInFlight bool
enableBatchFlowControl bool

abstractService
}
Expand Down Expand Up @@ -223,10 +224,14 @@ func (s *subscribeStream) onStreamStatusChange(status streamStatus) {
s.seekInFlight = true
}
s.unsafeSendFlowControl(s.flowControl.RequestForRestart())
s.enableBatchFlowControl = true
s.pollFlowControl.Start()

case streamReconnecting:
s.seekInFlight = false
// Ensure no batch flow control tokens are sent until the RequestForRestart
// is sent above when a new subscribe stream is initialized.
s.enableBatchFlowControl = false
s.pollFlowControl.Stop()

case streamTerminated:
Expand Down Expand Up @@ -301,12 +306,15 @@ func (s *subscribeStream) onAckAsync(msgBytes int64) {
func (s *subscribeStream) sendBatchFlowControl() {
s.mu.Lock()
defer s.mu.Unlock()
s.unsafeSendFlowControl(s.flowControl.ReleasePendingRequest())

if s.enableBatchFlowControl {
s.unsafeSendFlowControl(s.flowControl.ReleasePendingRequest())
}
}

func (s *subscribeStream) unsafeAllowFlow(allow flowControlTokens) {
s.flowControl.OnClientFlow(allow)
if s.flowControl.ShouldExpediteBatchRequest() {
if s.flowControl.ShouldExpediteBatchRequest() && s.enableBatchFlowControl {
s.unsafeSendFlowControl(s.flowControl.ReleasePendingRequest())
}
}
Expand Down
55 changes: 55 additions & 0 deletions pubsublite/internal/wire/subscriber_test.go
Expand Up @@ -224,6 +224,12 @@ func (ts *testSubscribeStream) SendBatchFlowControl() {
ts.sub.sendBatchFlowControl()
}

func (ts *testSubscribeStream) PendingFlowControlRequest() *pb.FlowControlRequest {
ts.sub.mu.Lock()
defer ts.sub.mu.Unlock()
return ts.sub.flowControl.pendingTokens.ToFlowControlRequest()
}

func TestSubscribeStreamReconnect(t *testing.T) {
subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-sub", 0}
acks := newAckTracker()
Expand Down Expand Up @@ -329,6 +335,55 @@ func TestSubscribeStreamExpediteFlowControl(t *testing.T) {
}
}

func TestSubscribeStreamDisableBatchFlowControl(t *testing.T) {
subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-sub", 0}
acks := newAckTracker()
// MaxOutstandingBytes = 1000, so this pushes the pending flow control bytes
// over the expediteBatchRequestRatio=50% threshold in flowControlBatcher.
msg := seqMsgWithOffsetAndSize(67, 800)
retryableErr := status.Error(codes.Unavailable, "unavailable")
serverErr := status.Error(codes.InvalidArgument, "verifies flow control received")

verifiers := test.NewVerifiers(t)

stream1 := test.NewRPCVerifier(t)
stream1.Push(initSubReq(subscription), initSubResp(), nil)
stream1.Push(initFlowControlReq(), msgSubResp(msg), nil)
// Break the stream immediately after sending the message.
stream1.Push(nil, nil, retryableErr)
verifiers.AddSubscribeStream(subscription.Path, subscription.Partition, stream1)

stream2 := test.NewRPCVerifier(t)
// The barrier is used to pause in the middle of stream reconnection.
barrier := stream2.PushWithBarrier(initSubReq(subscription), initSubResp(), nil)
stream2.Push(seekReq(68), seekResp(68), nil)
// Full flow control tokens should be sent after stream has connected.
stream2.Push(initFlowControlReq(), nil, serverErr)
verifiers.AddSubscribeStream(subscription.Path, subscription.Partition, stream2)

mockServer.OnTestStart(verifiers)
defer mockServer.OnTestEnd()

sub := newTestSubscribeStream(t, subscription, testSubscriberSettings(), acks)
if gotErr := sub.StartError(); gotErr != nil {
t.Errorf("Start() got err: (%v)", gotErr)
}

sub.Receiver.ValidateMsg(msg)
barrier.ReleaseAfter(func() {
// While the stream is not connected, the pending flow control request
// should not be released and sent to the stream.
sub.sub.onAckAsync(msg.SizeBytes)
if sub.PendingFlowControlRequest() == nil {
t.Errorf("Pending flow control request should not be cleared")
}
})

if gotErr := sub.FinalError(); !test.ErrorEqual(gotErr, serverErr) {
t.Errorf("Final err: (%v), want: (%v)", gotErr, serverErr)
}
}

func TestSubscribeStreamInvalidInitialResponse(t *testing.T) {
subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-sub", 0}
acks := newAckTracker()
Expand Down