diff --git a/pubsublite/internal/wire/acks.go b/pubsublite/internal/wire/acks.go
index 0d17baf3953..2e55fd7a688 100644
--- a/pubsublite/internal/wire/acks.go
+++ b/pubsublite/internal/wire/acks.go
@@ -169,6 +169,13 @@ func (at *ackTracker) unsafeProcessAcks() {
 	}
 }
 
+// Empty returns true if there are no outstanding acks.
+func (at *ackTracker) Empty() bool {
+	at.mu.Lock()
+	defer at.mu.Unlock()
+	return at.outstandingAcks.Front() == nil
+}
+
 // commitCursorTracker tracks pending and last successful committed offsets.
 // It is only accessed by the committer.
 type commitCursorTracker struct {
diff --git a/pubsublite/internal/wire/committer.go b/pubsublite/internal/wire/committer.go
index 63c7b487774..89eb2d9f997 100644
--- a/pubsublite/internal/wire/committer.go
+++ b/pubsublite/internal/wire/committer.go
@@ -90,8 +90,8 @@ func (c *committer) Start() {
 	}
 }
 
-// Stop initiates shutdown of the committer. The final commit offset will be
-// send to the server, but acks that arrive afterward will be discarded.
+// Stop initiates shutdown of the committer. It will wait for outstanding acks
+// and send the final commit offset to the server.
 func (c *committer) Stop() {
 	c.mu.Lock()
 	defer c.mu.Unlock()
@@ -194,23 +194,22 @@ func (c *committer) unsafeInitiateShutdown(targetStatus serviceStatus, err error
 		return
 	}
 
-	// Discard outstanding acks.
-	c.acks.Release()
-
 	// If it's a graceful shutdown, expedite sending final commits to the stream.
 	if targetStatus == serviceTerminating {
 		c.unsafeCommitOffsetToStream()
 		c.unsafeCheckDone()
 		return
 	}
-	// Otherwise immediately terminate the stream.
+
+	// Otherwise discard outstanding acks and immediately terminate the stream.
+	c.acks.Release()
 	c.unsafeTerminate()
 }
 
 func (c *committer) unsafeCheckDone() {
 	// The commit stream can be closed once the final commit offset has been
-	// confirmed.
-	if c.status == serviceTerminating && c.cursorTracker.UpToDate() {
+	// confirmed and there are no outstanding acks.
+	if c.status == serviceTerminating && c.cursorTracker.UpToDate() && c.acks.Empty() {
 		c.unsafeTerminate()
 	}
 }
diff --git a/pubsublite/internal/wire/committer_test.go b/pubsublite/internal/wire/committer_test.go
index 610a48e7827..2f0144ea0da 100644
--- a/pubsublite/internal/wire/committer_test.go
+++ b/pubsublite/internal/wire/committer_test.go
@@ -102,6 +102,7 @@ func TestCommitterStopFlushesCommits(t *testing.T) {
 	stream := test.NewRPCVerifier(t)
 	stream.Push(initCommitReq(subscription), initCommitResp(), nil)
 	stream.Push(commitReq(34), commitResp(1), nil)
+	stream.Push(commitReq(56), commitResp(1), nil)
 	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)
 
 	mockServer.OnTestStart(verifiers)
@@ -114,7 +115,7 @@ func TestCommitterStopFlushesCommits(t *testing.T) {
 	ack1.Ack()
 	cmt.Stop() // Stop should flush the first offset
-	ack2.Ack() // Acks after Stop() are discarded
+	ack2.Ack() // Acks after Stop() are processed
 	cmt.SendBatchCommit()
 
 	// Committer terminates when all acks are processed.
 	if gotErr := cmt.FinalError(); gotErr != nil {
diff --git a/pubsublite/internal/wire/subscriber.go b/pubsublite/internal/wire/subscriber.go
index deb488b4f8d..9f1d2c6a3de 100644
--- a/pubsublite/internal/wire/subscriber.go
+++ b/pubsublite/internal/wire/subscriber.go
@@ -17,7 +17,6 @@ import (
 	"context"
 	"errors"
 	"reflect"
-	"sync"
 	"time"
 
 	"google.golang.org/grpc"
@@ -40,42 +39,37 @@ type ReceivedMessage struct {
 	Ack AckConsumer
 }
 
-// MessageReceiverFunc receives a batch of Pub/Sub messages from a topic
-// partition.
-type MessageReceiverFunc func([]*ReceivedMessage)
+// MessageReceiverFunc receives a Pub/Sub message from a topic partition.
+type MessageReceiverFunc func(*ReceivedMessage)
 
-const maxMessagesBufferSize = 1000
+const maxMessageBufferSize = 10000
 
 // messageDeliveryQueue delivers received messages to the client-provided
 // MessageReceiverFunc sequentially.
 type messageDeliveryQueue struct {
 	receiver  MessageReceiverFunc
-	messagesC chan []*ReceivedMessage
+	messagesC chan *ReceivedMessage
 	stopC     chan struct{}
-
-	// Fields below must be guarded with mu.
-	mu     sync.Mutex
-	status serviceStatus
+	acks      *ackTracker
+	status    serviceStatus
 }
 
-func newMessageDeliveryQueue(receiver MessageReceiverFunc, bufferSize int) *messageDeliveryQueue {
-	// The buffer size is based on ReceiveSettings.MaxOutstandingMessages to
-	// handle the worst case of single messages. But ensure there's a reasonable
-	// limit as channel buffer capacity is allocated on creation.
-	if bufferSize > maxMessagesBufferSize {
-		bufferSize = maxMessagesBufferSize
+func newMessageDeliveryQueue(acks *ackTracker, receiver MessageReceiverFunc, bufferSize int) *messageDeliveryQueue {
+	// The buffer size is based on ReceiveSettings.MaxOutstandingMessages. But
+	// ensure there's a reasonable limit as channel buffer capacity is allocated
+	// on creation.
+	if bufferSize > maxMessageBufferSize {
+		bufferSize = maxMessageBufferSize
 	}
 	return &messageDeliveryQueue{
+		acks:      acks,
 		receiver:  receiver,
-		messagesC: make(chan []*ReceivedMessage, bufferSize),
+		messagesC: make(chan *ReceivedMessage, bufferSize),
 		stopC:     make(chan struct{}),
 	}
 }
 
 func (mq *messageDeliveryQueue) Start() {
-	mq.mu.Lock()
-	defer mq.mu.Unlock()
-
 	if mq.status == serviceUninitialized {
 		go mq.deliverMessages()
 		mq.status = serviceActive
@@ -83,9 +77,6 @@ func (mq *messageDeliveryQueue) Start() {
 	}
 }
 
 func (mq *messageDeliveryQueue) Stop() {
-	mq.mu.Lock()
-	defer mq.mu.Unlock()
-
 	if mq.status < serviceTerminated {
 		close(mq.stopC)
 		mq.status = serviceTerminated
@@ -93,11 +84,10 @@ func (mq *messageDeliveryQueue) Stop() {
 	}
 }
 
 func (mq *messageDeliveryQueue) Add(messages []*ReceivedMessage) {
-	mq.mu.Lock()
-	defer mq.mu.Unlock()
-
 	if mq.status == serviceActive {
-		mq.messagesC <- messages
+		for _, msg := range messages {
+			mq.messagesC <- msg
+		}
 	}
 }
 
@@ -113,8 +103,11 @@ func (mq *messageDeliveryQueue) deliverMessages() {
 		select {
 		case <-mq.stopC:
 			return // Ends the goroutine.
-		case msgs := <-mq.messagesC:
-			mq.receiver(msgs)
+		case msg := <-mq.messagesC:
+			// Register outstanding acks, which are primarily handled by the
+			// `committer`.
+			mq.acks.Push(msg.Ack.(*ackConsumer))
+			mq.receiver(msg)
 		}
 	}
 }
@@ -138,7 +131,6 @@ type subscribeStream struct {
 
 	// Fields below must be guarded with mu.
 	stream          *retryableStream
-	acks            *ackTracker
 	offsetTracker   subscriberOffsetTracker
 	flowControl     flowControlBatcher
 	pollFlowControl *periodicTask
@@ -162,8 +154,7 @@ func newSubscribeStream(ctx context.Context, subClient *vkit.SubscriberClient, s
 				},
 			},
 		},
-		messageQueue: newMessageDeliveryQueue(receiver, settings.MaxOutstandingMessages),
-		acks:         acks,
+		messageQueue: newMessageDeliveryQueue(acks, receiver, settings.MaxOutstandingMessages),
 	}
 	s.stream = newRetryableStream(ctx, s, settings.Timeout, reflect.TypeOf(pb.SubscribeResponse{}))
@@ -290,12 +281,7 @@ func (s *subscribeStream) unsafeOnMessageResponse(response *pb.MessageResponse)
 
 	var receivedMsgs []*ReceivedMessage
 	for _, msg := range response.Messages {
-		// Register outstanding acks, which are primarily handled by the
-		// `committer`.
 		ack := newAckConsumer(msg.GetCursor().GetOffset(), msg.GetSizeBytes(), s.onAck)
-		if err := s.acks.Push(ack); err != nil {
-			return err
-		}
 		receivedMsgs = append(receivedMsgs, &ReceivedMessage{Msg: msg, Ack: ack})
 	}
 	s.messageQueue.Add(receivedMsgs)
diff --git a/pubsublite/internal/wire/subscriber_test.go b/pubsublite/internal/wire/subscriber_test.go
index ef7dd98a3a2..9749a51e8a0 100644
--- a/pubsublite/internal/wire/subscriber_test.go
+++ b/pubsublite/internal/wire/subscriber_test.go
@@ -21,9 +21,9 @@ import (
 
 	"cloud.google.com/go/internal/testutil"
 	"cloud.google.com/go/pubsublite/internal/test"
-	"github.com/golang/protobuf/proto"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
+	"google.golang.org/protobuf/proto"
 
 	pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1"
 )
@@ -53,10 +53,8 @@ func newTestMessageReceiver(t *testing.T) *testMessageReceiver {
 	}
 }
 
-func (tr *testMessageReceiver) onMessages(msgs []*ReceivedMessage) {
-	for _, msg := range msgs {
-		tr.received <- msg
-	}
+func (tr *testMessageReceiver) onMessage(msg *ReceivedMessage) {
+	tr.received <- msg
 }
 
 func (tr *testMessageReceiver) ValidateMsg(want *pb.SequencedMessage) AckConsumer {
@@ -126,45 +124,56 @@ func newTestBlockingMessageReceiver(t *testing.T) *testBlockingMessageReceiver {
 	}
 }
 
-// onMessages is the message receiver func and blocks until there is a call to
+// onMessage is the message receiver func and blocks until there is a call to
 // Return().
-func (tr *testBlockingMessageReceiver) onMessages(msgs []*ReceivedMessage) {
-	tr.testMessageReceiver.onMessages(msgs)
+func (tr *testBlockingMessageReceiver) onMessage(msg *ReceivedMessage) {
+	tr.testMessageReceiver.onMessage(msg)
 	<-tr.blockReceive
 }
 
-// Return signals onMessages to return.
+// Return signals onMessage to return.
 func (tr *testBlockingMessageReceiver) Return() {
 	var void struct{}
 	tr.blockReceive <- void
 }
 
-func TestNewMessageDeliveryQueue(t *testing.T) {
+func TestMessageDeliveryQueue(t *testing.T) {
+	acks := newAckTracker()
 	receiver := newTestMessageReceiver(t)
-	messageQueue := newMessageDeliveryQueue(receiver.onMessages, 10)
+	messageQueue := newMessageDeliveryQueue(acks, receiver.onMessage, 10)
 
 	t.Run("Add before start", func(t *testing.T) {
-		messageQueue.Add([]*ReceivedMessage{{Msg: seqMsgWithOffset(3)}})
+		msg1 := seqMsgWithOffset(1)
+		ack1 := newAckConsumer(1, 0, nil)
+		messageQueue.Add([]*ReceivedMessage{{Msg: msg1, Ack: ack1}})
 
 		receiver.VerifyNoMsgs()
 	})
 
 	t.Run("Add after start", func(t *testing.T) {
-		msg1 := seqMsgWithOffset(1)
 		msg2 := seqMsgWithOffset(2)
+		ack2 := newAckConsumer(2, 0, nil)
+		msg3 := seqMsgWithOffset(3)
+		ack3 := newAckConsumer(3, 0, nil)
 
 		messageQueue.Start()
 		messageQueue.Start() // Check duplicate starts
-		messageQueue.Add([]*ReceivedMessage{{Msg: msg1}, {Msg: msg2}})
+		messageQueue.Add([]*ReceivedMessage{
+			{Msg: msg2, Ack: ack2},
+			{Msg: msg3, Ack: ack3},
+		})
 
-		receiver.ValidateMsg(msg1)
 		receiver.ValidateMsg(msg2)
+		receiver.ValidateMsg(msg3)
 	})
 
 	t.Run("Add after stop", func(t *testing.T) {
+		msg4 := seqMsgWithOffset(4)
+		ack4 := newAckConsumer(4, 0, nil)
+
 		messageQueue.Stop()
 		messageQueue.Stop() // Check duplicate stop
-		messageQueue.Add([]*ReceivedMessage{{Msg: seqMsgWithOffset(4)}})
+		messageQueue.Add([]*ReceivedMessage{{Msg: msg4, Ack: ack4}})
 
 		receiver.VerifyNoMsgs()
 	})
@@ -189,7 +198,7 @@ func newTestSubscribeStream(t *testing.T, subscription subscriptionPartition, se
 		Receiver: newTestMessageReceiver(t),
 		t:        t,
 	}
-	ts.sub = newSubscribeStream(ctx, subClient, settings, ts.Receiver.onMessages, subscription, acks, true)
+	ts.sub = newSubscribeStream(ctx, subClient, settings, ts.Receiver.onMessage, subscription, acks, true)
 	ts.initAndStart(t, ts.sub, "Subscriber")
 	return ts
 }
@@ -486,7 +495,7 @@ func TestSinglePartitionSubscriberStartStop(t *testing.T) {
 	mockServer.OnTestStart(verifiers)
 	defer mockServer.OnTestEnd()
 
-	sub := newTestSinglePartitionSubscriber(t, receiver.onMessages, subscription)
+	sub := newTestSinglePartitionSubscriber(t, receiver.onMessage, subscription)
 	if gotErr := sub.WaitStarted(); gotErr != nil {
 		t.Errorf("Start() got err: (%v)", gotErr)
 	}
@@ -518,7 +527,7 @@ func TestSinglePartitionSubscriberSimpleMsgAck(t *testing.T) {
 	mockServer.OnTestStart(verifiers)
 	defer mockServer.OnTestEnd()
 
-	sub := newTestSinglePartitionSubscriber(t, receiver.onMessages, subscription)
+	sub := newTestSinglePartitionSubscriber(t, receiver.onMessage, subscription)
 	if gotErr := sub.WaitStarted(); gotErr != nil {
 		t.Errorf("Start() got err: (%v)", gotErr)
 	}
@@ -557,29 +566,32 @@ func TestSinglePartitionSubscriberMessageQueue(t *testing.T) {
 
 	cmtStream := test.NewRPCVerifier(t)
 	cmtStream.Push(initCommitReq(subscription), initCommitResp(), nil)
+	cmtStream.Push(commitReq(4), commitResp(1), nil)
 	verifiers.AddCommitStream(subscription.Path, subscription.Partition, cmtStream)
 
 	mockServer.OnTestStart(verifiers)
 	defer mockServer.OnTestEnd()
 
-	sub := newTestSinglePartitionSubscriber(t, receiver.onMessages, subscription)
+	sub := newTestSinglePartitionSubscriber(t, receiver.onMessage, subscription)
 	if gotErr := sub.WaitStarted(); gotErr != nil {
 		t.Errorf("Start() got err: (%v)", gotErr)
 	}
 
 	// Verifies that messageDeliveryQueue delivers messages sequentially and waits
-	// for the client message receiver func to return.
-	receiver.ValidateMsg(msg1)
-	receiver.VerifyNoMsgs()
-	receiver.Return()
-
-	receiver.ValidateMsg(msg2)
-	receiver.VerifyNoMsgs()
-	receiver.Return()
+	// for the client message receiver func to return before delivering the next
+	// message.
+	var acks []AckConsumer
+	for _, msg := range []*pb.SequencedMessage{msg1, msg2, msg3} {
+		ack := receiver.ValidateMsg(msg)
+		acks = append(acks, ack)
+		receiver.VerifyNoMsgs()
+		receiver.Return()
+	}
 
-	receiver.ValidateMsg(msg3)
-	receiver.VerifyNoMsgs()
-	receiver.Return()
+	// Ack all messages so that the committer terminates.
+	for _, ack := range acks {
+		ack.Ack()
+	}
 
 	sub.Stop()
 	if gotErr := sub.WaitStopped(); gotErr != nil {
@@ -597,24 +609,23 @@ func TestSinglePartitionSubscriberStopDuringReceive(t *testing.T) {
 	subStream := test.NewRPCVerifier(t)
 	subStream.Push(initSubReq(subscription), initSubResp(), nil)
-	subStream.Push(initFlowControlReq(), msgSubResp(msg1), nil)
-	subStream.Push(nil, msgSubResp(msg2), nil)
+	subStream.Push(initFlowControlReq(), msgSubResp(msg1, msg2), nil)
 	verifiers.AddSubscribeStream(subscription.Path, subscription.Partition, subStream)
 
 	cmtStream := test.NewRPCVerifier(t)
 	cmtStream.Push(initCommitReq(subscription), initCommitResp(), nil)
+	cmtStream.Push(commitReq(2), commitResp(1), nil)
 	verifiers.AddCommitStream(subscription.Path, subscription.Partition, cmtStream)
 
 	mockServer.OnTestStart(verifiers)
 	defer mockServer.OnTestEnd()
 
-	sub := newTestSinglePartitionSubscriber(t, receiver.onMessages, subscription)
+	sub := newTestSinglePartitionSubscriber(t, receiver.onMessage, subscription)
 	if gotErr := sub.WaitStarted(); gotErr != nil {
 		t.Errorf("Start() got err: (%v)", gotErr)
 	}
 
-	receiver.ValidateMsg(msg1)
-	receiver.VerifyNoMsgs()
+	receiver.ValidateMsg(msg1).Ack()
 
 	// Stop the subscriber before returning from the message receiver func.
 	sub.Stop()
 	if gotErr := sub.WaitStopped(); gotErr != nil {
@@ -653,7 +664,7 @@ func newTestMultiPartitionSubscriber(t *testing.T, receiverFunc MessageReceiverF
 }
 
 func TestMultiPartitionSubscriberMultipleMessages(t *testing.T) {
-	subscription := "projects/123456/locations/us-central1-b/subscriptions/my-sub"
+	const subscription = "projects/123456/locations/us-central1-b/subscriptions/my-sub"
 	receiver := newTestMessageReceiver(t)
 	msg1 := seqMsgWithOffsetAndSize(22, 100)
 	msg2 := seqMsgWithOffsetAndSize(23, 200)
@@ -689,7 +700,7 @@ func TestMultiPartitionSubscriberMultipleMessages(t *testing.T) {
 	mockServer.OnTestStart(verifiers)
 	defer mockServer.OnTestEnd()
 
-	sub := newTestMultiPartitionSubscriber(t, receiver.onMessages, subscription, []int{1, 2})
+	sub := newTestMultiPartitionSubscriber(t, receiver.onMessage, subscription, []int{1, 2})
 	if gotErr := sub.WaitStarted(); gotErr != nil {
 		t.Errorf("Start() got err: (%v)", gotErr)
 	}
@@ -701,7 +712,7 @@ func TestMultiPartitionSubscriberMultipleMessages(t *testing.T) {
 }
 
 func TestMultiPartitionSubscriberPermanentError(t *testing.T) {
-	subscription := "projects/123456/locations/us-central1-b/subscriptions/my-sub"
+	const subscription = "projects/123456/locations/us-central1-b/subscriptions/my-sub"
 	receiver := newTestMessageReceiver(t)
 	msg1 := seqMsgWithOffsetAndSize(22, 100)
 	msg2 := seqMsgWithOffsetAndSize(23, 200)
@@ -737,12 +748,12 @@ func TestMultiPartitionSubscriberPermanentError(t *testing.T) {
 	mockServer.OnTestStart(verifiers)
 	defer mockServer.OnTestEnd()
 
-	sub := newTestMultiPartitionSubscriber(t, receiver.onMessages, subscription, []int{1, 2})
+	sub := newTestMultiPartitionSubscriber(t, receiver.onMessage, subscription, []int{1, 2})
 	if gotErr := sub.WaitStarted(); gotErr != nil {
 		t.Errorf("Start() got err: (%v)", gotErr)
 	}
 	receiver.ValidateMsgs([]*pb.SequencedMessage{msg1, msg3})
-	errorBarrier.Release() // Send server error
+	errorBarrier.Release() // Release the server error now to ensure the test is deterministic
 	if gotErr := sub.WaitStopped(); !test.ErrorEqual(gotErr, serverErr) {
 		t.Errorf("Final error got: (%v), want: (%v)", gotErr, serverErr)
 	}
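
Reviewer note: the API-visible piece of this change is the MessageReceiverFunc signature, which now delivers one *ReceivedMessage at a time instead of a batch, with ack registration moved into messageDeliveryQueue.deliverMessages(). A minimal sketch of a conforming receiver, assuming the package internals shown in this diff (exampleReceiver and its logging are illustrative, not part of the change):

package wire

import "log"

// exampleReceiver is a hypothetical MessageReceiverFunc. The delivery queue
// calls it with one message at a time, after registering the message's ack
// with the ackTracker, and does not deliver the next message until it returns.
func exampleReceiver(msg *ReceivedMessage) {
	log.Printf("received message at offset %d", msg.Msg.GetCursor().GetOffset())

	// Acking now matters for shutdown: committer.Stop() no longer discards
	// outstanding acks, so the commit stream stays open until every
	// delivered message has been acked.
	msg.Ack.Ack()
}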
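
The committer side of the new contract reads most clearly against the updated TestCommitterStopFlushesCommits expectations. In outline (reusing the fixture names from that test; a sequence sketch, not standalone code):

ack1.Ack()
cmt.Stop()            // flushes the first commit offset (34); the stream stays open
ack2.Ack()            // processed rather than discarded after Stop()
cmt.SendBatchCommit() // sends the final offset (56); unsafeCheckDone() then sees
                      // cursorTracker.UpToDate() && acks.Empty() and terminates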