diff --git a/pubsublite/internal/wire/acks.go b/pubsublite/internal/wire/acks.go index 0a7c208e522..411a3eab8ee 100644 --- a/pubsublite/internal/wire/acks.go +++ b/pubsublite/internal/wire/acks.go @@ -83,7 +83,7 @@ const nilCursorOffset int64 = -1 // ackTracker manages outstanding message acks, i.e. messages that have been // delivered to the user, but not yet acked. It is used by the committer and -// wireSubscriber, so requires its own mutex. +// subscribeStream, so requires its own mutex. type ackTracker struct { // Guards access to fields below. mu sync.Mutex @@ -163,6 +163,13 @@ func (at *ackTracker) Release() { at.outstandingAcks.Init() } +// Empty when there are no outstanding acks. +func (at *ackTracker) Empty() bool { + at.mu.Lock() + defer at.mu.Unlock() + return at.outstandingAcks.Len() == 0 +} + // commitCursorTracker tracks pending and last successful committed offsets. // It is only accessed by the committer. type commitCursorTracker struct { diff --git a/pubsublite/internal/wire/committer.go b/pubsublite/internal/wire/committer.go new file mode 100644 index 00000000000..d540412c2e3 --- /dev/null +++ b/pubsublite/internal/wire/committer.go @@ -0,0 +1,219 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package wire + +import ( + "context" + "errors" + "fmt" + "reflect" + "time" + + "google.golang.org/grpc" + + vkit "cloud.google.com/go/pubsublite/apiv1" + pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1" +) + +var ( + errInvalidInitialCommitResponse = errors.New("pubsublite: first response from server was not an initial response for streaming commit") + errInvalidCommitResponse = errors.New("pubsublite: received invalid commit response from server") +) + +// The frequency of batched cursor commits. +const commitCursorPeriod = 50 * time.Millisecond + +// committer wraps a commit cursor stream for a subscription and partition. +// A background task periodically reads the latest desired cursor offset from +// the `ackTracker` and sends a commit request to the stream if the cursor +// needs to be updated. The `commitCursorTracker` is used to manage in-flight +// commit requests. +type committer struct { + // Immutable after creation. + cursorClient *vkit.CursorClient + initialReq *pb.StreamingCommitCursorRequest + + // Fields below must be guarded with mutex. 
+ stream *retryableStream + acks *ackTracker + cursorTracker *commitCursorTracker + pollCommits *periodicTask + + abstractService +} + +func newCommitter(ctx context.Context, cursor *vkit.CursorClient, settings ReceiveSettings, + subscription subscriptionPartition, acks *ackTracker, disableTasks bool) *committer { + + c := &committer{ + cursorClient: cursor, + initialReq: &pb.StreamingCommitCursorRequest{ + Request: &pb.StreamingCommitCursorRequest_Initial{ + Initial: &pb.InitialCommitCursorRequest{ + Subscription: subscription.Path, + Partition: int64(subscription.Partition), + }, + }, + }, + acks: acks, + cursorTracker: newCommitCursorTracker(acks), + } + c.stream = newRetryableStream(ctx, c, settings.Timeout, reflect.TypeOf(pb.StreamingCommitCursorResponse{})) + + backgroundTask := c.commitOffsetToStream + if disableTasks { + backgroundTask = func() {} + } + c.pollCommits = newPeriodicTask(commitCursorPeriod, backgroundTask) + return c +} + +// Start attempts to establish a streaming commit cursor connection. +func (c *committer) Start() { + c.mu.Lock() + defer c.mu.Unlock() + + if c.unsafeUpdateStatus(serviceStarting, nil) { + c.stream.Start() + c.pollCommits.Start() + } +} + +// Stop initiates shutdown of the committer. The commit stream remains open to +// process all outstanding acks and send the final commit offset. 
+func (c *committer) Stop() { + c.mu.Lock() + defer c.mu.Unlock() + c.unsafeInitiateShutdown(serviceTerminating, nil) +} + +func (c *committer) newStream(ctx context.Context) (grpc.ClientStream, error) { + return c.cursorClient.StreamingCommitCursor(ctx) +} + +func (c *committer) initialRequest() (req interface{}, needsResp bool) { + req = c.initialReq + needsResp = true + return +} + +func (c *committer) validateInitialResponse(response interface{}) error { + commitResponse, _ := response.(*pb.StreamingCommitCursorResponse) + if commitResponse.GetInitial() == nil { + return errInvalidInitialCommitResponse + } + return nil +} + +func (c *committer) onStreamStatusChange(status streamStatus) { + c.mu.Lock() + defer c.mu.Unlock() + + switch status { + case streamConnected: + c.unsafeUpdateStatus(serviceActive, nil) + // Once the stream connects, clear unconfirmed commits and immediately send + // the latest desired commit offset. + c.cursorTracker.ClearPending() + c.unsafeCommitOffsetToStream() + c.pollCommits.Start() + + case streamReconnecting: + c.pollCommits.Stop() + + case streamTerminated: + c.unsafeInitiateShutdown(serviceTerminated, c.stream.Error()) + } +} + +func (c *committer) onResponse(response interface{}) { + c.mu.Lock() + defer c.mu.Unlock() + + // If an inconsistency is detected in the server's responses, immediately + // terminate the committer, as correct processing of commits cannot be + // guaranteed. 
+ processResponse := func() error { + commitResponse, _ := response.(*pb.StreamingCommitCursorResponse) + if commitResponse.GetCommit() == nil { + return errInvalidCommitResponse + } + numAcked := commitResponse.GetCommit().GetAcknowledgedCommits() + if numAcked <= 0 { + return fmt.Errorf("pubsublite: server acknowledged an invalid commit count: %d", numAcked) + } + if err := c.cursorTracker.ConfirmOffsets(numAcked); err != nil { + return err + } + c.unsafeCheckDone() + return nil + } + if err := processResponse(); err != nil { + c.unsafeInitiateShutdown(serviceTerminated, err) + } +} + +// commitOffsetToStream is called by the periodic background task. +func (c *committer) commitOffsetToStream() { + c.mu.Lock() + defer c.mu.Unlock() + c.unsafeCommitOffsetToStream() +} + +func (c *committer) unsafeCommitOffsetToStream() { + nextOffset := c.cursorTracker.NextOffset() + if nextOffset == nilCursorOffset { + return + } + + req := &pb.StreamingCommitCursorRequest{ + Request: &pb.StreamingCommitCursorRequest_Commit{ + Commit: &pb.SequencedCommitCursorRequest{ + Cursor: &pb.Cursor{Offset: nextOffset}, + }, + }, + } + if c.stream.Send(req) { + c.cursorTracker.AddPending(nextOffset) + } +} + +func (c *committer) unsafeInitiateShutdown(targetStatus serviceStatus, err error) { + if !c.unsafeUpdateStatus(targetStatus, err) { + return + } + + // If it's a graceful shutdown, expedite sending final commits to the stream. + if targetStatus == serviceTerminating { + c.unsafeCommitOffsetToStream() + c.unsafeCheckDone() + return + } + // Otherwise immediately terminate the stream. + c.unsafeTerminate() +} + +func (c *committer) unsafeCheckDone() { + // If the user stops the subscriber, they will no longer receive messages, but + // the commit stream remains open to process acks for outstanding messages. 
+ if c.status == serviceTerminating && c.cursorTracker.Done() && c.acks.Empty() { + c.unsafeTerminate() + } +} + +func (c *committer) unsafeTerminate() { + c.acks.Release() + c.pollCommits.Stop() + c.stream.Stop() +} diff --git a/pubsublite/internal/wire/committer_test.go b/pubsublite/internal/wire/committer_test.go new file mode 100644 index 00000000000..1c1ea2e8e79 --- /dev/null +++ b/pubsublite/internal/wire/committer_test.go @@ -0,0 +1,252 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and + +package wire + +import ( + "context" + "testing" + + "cloud.google.com/go/pubsublite/internal/test" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// testCommitter wraps a committer for ease of testing. +type testCommitter struct { + cmt *committer + serviceTestProxy +} + +func newTestCommitter(t *testing.T, subscription subscriptionPartition, acks *ackTracker) *testCommitter { + ctx := context.Background() + cursorClient, err := newCursorClient(ctx, "ignored", testClientOpts...) + if err != nil { + t.Fatal(err) + } + + tc := &testCommitter{ + cmt: newCommitter(ctx, cursorClient, testReceiveSettings(), subscription, acks, true), + } + tc.initAndStart(t, tc.cmt, "Committer") + return tc +} + +// SendBatchCommit invokes the periodic background batch commit. Note that the +// periodic task is disabled in tests. 
+func (tc *testCommitter) SendBatchCommit() { + tc.cmt.commitOffsetToStream() +} + +func TestCommitterStreamReconnect(t *testing.T) { + subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0} + ack1 := newAckConsumer(33, 0, nil) + ack2 := newAckConsumer(55, 0, nil) + acks := newAckTracker() + acks.Push(ack1) + acks.Push(ack2) + + verifiers := test.NewVerifiers(t) + + // Simulate a transient error that results in a reconnect. + stream1 := test.NewRPCVerifier(t) + stream1.Push(initCommitReq(subscription), initCommitResp(), nil) + barrier := stream1.PushWithBarrier(commitReq(34), nil, status.Error(codes.Unavailable, "server unavailable")) + verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream1) + + // When the stream reconnects, the latest commit offset should be sent to the + // server. + stream2 := test.NewRPCVerifier(t) + stream2.Push(initCommitReq(subscription), initCommitResp(), nil) + stream2.Push(commitReq(56), commitResp(1), nil) + verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream2) + + mockServer.OnTestStart(verifiers) + defer mockServer.OnTestEnd() + + cmt := newTestCommitter(t, subscription, acks) + if gotErr := cmt.StartError(); gotErr != nil { + t.Errorf("Start() got err: (%v)", gotErr) + } + + // Send 2 commits. + ack1.Ack() + cmt.SendBatchCommit() + ack2.Ack() + cmt.SendBatchCommit() + + // Then send the retryable error, which results in reconnect. 
+ barrier.Release() + cmt.StopVerifyNoError() +} + +func TestCommitterStopFlushesCommits(t *testing.T) { + subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0} + ack1 := newAckConsumer(33, 0, nil) + ack2 := newAckConsumer(55, 0, nil) + acks := newAckTracker() + acks.Push(ack1) + acks.Push(ack2) + + verifiers := test.NewVerifiers(t) + stream := test.NewRPCVerifier(t) + stream.Push(initCommitReq(subscription), initCommitResp(), nil) + stream.Push(commitReq(34), commitResp(1), nil) + stream.Push(commitReq(56), commitResp(1), nil) + verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream) + + mockServer.OnTestStart(verifiers) + defer mockServer.OnTestEnd() + + cmt := newTestCommitter(t, subscription, acks) + if gotErr := cmt.StartError(); gotErr != nil { + t.Errorf("Start() got err: (%v)", gotErr) + } + + ack1.Ack() + cmt.Stop() // Stop should flush the first offset + ack2.Ack() // Acks after Stop() are still processed + cmt.SendBatchCommit() + // Committer terminates when all acks are processed. 
+ if gotErr := cmt.FinalError(); gotErr != nil { + t.Errorf("Final err: (%v), want: ", gotErr) + } +} + +func TestCommitterPermanentStreamError(t *testing.T) { + subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0} + acks := newAckTracker() + wantErr := status.Error(codes.FailedPrecondition, "failed") + + verifiers := test.NewVerifiers(t) + stream := test.NewRPCVerifier(t) + stream.Push(initCommitReq(subscription), nil, wantErr) + verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream) + + mockServer.OnTestStart(verifiers) + defer mockServer.OnTestEnd() + + cmt := newTestCommitter(t, subscription, acks) + if gotErr := cmt.StartError(); !test.ErrorEqual(gotErr, wantErr) { + t.Errorf("Start() got err: (%v), want: (%v)", gotErr, wantErr) + } +} + +func TestCommitterInvalidInitialResponse(t *testing.T) { + subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0} + acks := newAckTracker() + + verifiers := test.NewVerifiers(t) + stream := test.NewRPCVerifier(t) + stream.Push(initCommitReq(subscription), commitResp(1234), nil) // Invalid initial response + verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream) + + mockServer.OnTestStart(verifiers) + defer mockServer.OnTestEnd() + + cmt := newTestCommitter(t, subscription, acks) + + wantErr := errInvalidInitialCommitResponse + if gotErr := cmt.StartError(); !test.ErrorEqual(gotErr, wantErr) { + t.Errorf("Start() got err: (%v), want: (%v)", gotErr, wantErr) + } + if gotErr := cmt.FinalError(); !test.ErrorEqual(gotErr, wantErr) { + t.Errorf("Final err: (%v), want: (%v)", gotErr, wantErr) + } +} + +func TestCommitterInvalidCommitResponse(t *testing.T) { + subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0} + ack := newAckConsumer(33, 0, nil) + acks := newAckTracker() + acks.Push(ack) + + verifiers := test.NewVerifiers(t) + stream 
:= test.NewRPCVerifier(t) + stream.Push(initCommitReq(subscription), initCommitResp(), nil) + stream.Push(commitReq(34), initCommitResp(), nil) // Invalid commit response + verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream) + + mockServer.OnTestStart(verifiers) + defer mockServer.OnTestEnd() + + cmt := newTestCommitter(t, subscription, acks) + if gotErr := cmt.StartError(); gotErr != nil { + t.Errorf("Start() got err: (%v)", gotErr) + } + + ack.Ack() + cmt.SendBatchCommit() + + if gotErr, wantErr := cmt.FinalError(), errInvalidCommitResponse; !test.ErrorEqual(gotErr, wantErr) { + t.Errorf("Final err: (%v), want: (%v)", gotErr, wantErr) + } +} + +func TestCommitterExcessConfirmedOffsets(t *testing.T) { + subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0} + ack := newAckConsumer(33, 0, nil) + acks := newAckTracker() + acks.Push(ack) + + verifiers := test.NewVerifiers(t) + stream := test.NewRPCVerifier(t) + stream.Push(initCommitReq(subscription), initCommitResp(), nil) + stream.Push(commitReq(34), commitResp(2), nil) // More confirmed offsets than committed + verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream) + + mockServer.OnTestStart(verifiers) + defer mockServer.OnTestEnd() + + cmt := newTestCommitter(t, subscription, acks) + if gotErr := cmt.StartError(); gotErr != nil { + t.Errorf("Start() got err: (%v)", gotErr) + } + + ack.Ack() + cmt.SendBatchCommit() + + wantMsg := "server acknowledged 2 cursor commits" + if gotErr := cmt.FinalError(); !test.ErrorHasMsg(gotErr, wantMsg) { + t.Errorf("Final err: (%v), want msg: (%v)", gotErr, wantMsg) + } +} + +func TestCommitterZeroConfirmedOffsets(t *testing.T) { + subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0} + ack := newAckConsumer(33, 0, nil) + acks := newAckTracker() + acks.Push(ack) + + verifiers := test.NewVerifiers(t) + stream := test.NewRPCVerifier(t) 
+ stream.Push(initCommitReq(subscription), initCommitResp(), nil) + stream.Push(commitReq(34), commitResp(0), nil) // Zero confirmed offsets (invalid) + verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream) + + mockServer.OnTestStart(verifiers) + defer mockServer.OnTestEnd() + + cmt := newTestCommitter(t, subscription, acks) + if gotErr := cmt.StartError(); gotErr != nil { + t.Errorf("Start() got err: (%v)", gotErr) + } + + ack.Ack() + cmt.SendBatchCommit() + + wantMsg := "server acknowledged an invalid commit count" + if gotErr := cmt.FinalError(); !test.ErrorHasMsg(gotErr, wantMsg) { + t.Errorf("Final err: (%v), want msg: (%v)", gotErr, wantMsg) + } +} diff --git a/pubsublite/internal/wire/service_util_test.go b/pubsublite/internal/wire/service_util_test.go new file mode 100644 index 00000000000..e55d1609ba5 --- /dev/null +++ b/pubsublite/internal/wire/service_util_test.go @@ -0,0 +1,91 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and + +package wire + +import ( + "fmt" + "testing" + "time" +) + +func testReceiveSettings() ReceiveSettings { + settings := DefaultReceiveSettings + settings.Timeout = 5 * time.Second + return settings +} + +const serviceTestWaitTimeout = 30 * time.Second + +// serviceTestProxy wraps a `service` and provides some convenience methods for +// testing. 
+type serviceTestProxy struct { + t *testing.T + service service + name string + started chan struct{} + terminated chan struct{} +} + +func (sp *serviceTestProxy) initAndStart(t *testing.T, s service, name string) { + sp.t = t + sp.service = s + sp.name = name + sp.started = make(chan struct{}) + sp.terminated = make(chan struct{}) + s.AddStatusChangeReceiver(nil, sp.onStatusChange) + s.Start() +} + +func (sp *serviceTestProxy) onStatusChange(_ serviceHandle, status serviceStatus, _ error) { + if status == serviceActive { + close(sp.started) + } + if status == serviceTerminated { + close(sp.terminated) + } +} + +func (sp *serviceTestProxy) Start() { sp.service.Start() } +func (sp *serviceTestProxy) Stop() { sp.service.Stop() } + +// StartError waits for the service to start and returns the error. +func (sp *serviceTestProxy) StartError() error { + select { + case <-time.After(serviceTestWaitTimeout): + return fmt.Errorf("%s did not start within %v", sp.name, serviceTestWaitTimeout) + case <-sp.terminated: + return sp.service.Error() + case <-sp.started: + return sp.service.Error() + } +} + +// FinalError waits for the service to terminate and returns the error. +func (sp *serviceTestProxy) FinalError() error { + select { + case <-time.After(serviceTestWaitTimeout): + return fmt.Errorf("%s did not terminate within %v", sp.name, serviceTestWaitTimeout) + case <-sp.terminated: + return sp.service.Error() + } +} + +// StopVerifyNoError stops the service, waits for it to terminate and verifies +// that there is no error. +func (sp *serviceTestProxy) StopVerifyNoError() { + sp.service.Stop() + if gotErr := sp.FinalError(); gotErr != nil { + sp.t.Errorf("%s final err: (%v), want: ", sp.name, gotErr) + } +}