From 94485e6a5df08ef1ef668c71e4fcf916b40af852 Mon Sep 17 00:00:00 2001 From: tmdiep Date: Thu, 3 Jun 2021 10:00:39 +1000 Subject: [PATCH] refactor(pubsublite): allow subscribe stream trackers to be reset (#4145) Enables the flowControlBatcher, subscriberOffsetTracker and messageDeliveryQueue to be reset. --- pubsublite/internal/wire/flow_control.go | 12 ++++ pubsublite/internal/wire/flow_control_test.go | 22 +++++++ pubsublite/internal/wire/subscriber.go | 62 ++++++++++++------- pubsublite/internal/wire/subscriber_test.go | 48 +++++++++++++- 4 files changed, 121 insertions(+), 23 deletions(-) diff --git a/pubsublite/internal/wire/flow_control.go b/pubsublite/internal/wire/flow_control.go index 3ac3a848bd0..7a0421392e0 100644 --- a/pubsublite/internal/wire/flow_control.go +++ b/pubsublite/internal/wire/flow_control.go @@ -95,6 +95,13 @@ func exceedsExpediteRatio(pending, client int64) bool { return client > 0 && (float64(pending)/float64(client)) >= expediteBatchRequestRatio } +// Reset client tokens to the given values and reset pending tokens. +func (fc *flowControlBatcher) Reset(tokens flowControlTokens) { + fc.clientTokens.Reset() + fc.clientTokens.Add(tokens) + fc.pendingTokens.Reset() +} + // OnClientFlow increments flow control tokens. This occurs when: // - Initialization from ReceiveSettings. // - The user acks messages. @@ -147,6 +154,11 @@ type subscriberOffsetTracker struct { minNextOffset int64 } +// Reset the offset tracker to the initial state. +func (ot *subscriberOffsetTracker) Reset() { + ot.minNextOffset = 0 +} + // RequestForRestart returns the seek request to send when a new subscribe // stream reconnects. Returns nil if the subscriber has just started, in which // case the server returns the offset of the last committed cursor. diff --git a/pubsublite/internal/wire/flow_control_test.go b/pubsublite/internal/wire/flow_control_test.go index f38255b3ee4..d070f1fde0b 100644 --- a/pubsublite/internal/wire/flow_control_test.go +++ b/pubsublite/internal/wire/flow_control_test.go @@ -222,6 +222,28 @@ func TestFlowControlBatcher(t *testing.T) { }) } +func TestFlowControlBatcherReset(t *testing.T) { + var batcher flowControlBatcher + + initialTokens := flowControlTokens{Bytes: 400, Messages: 40} + batcher.OnClientFlow(initialTokens) + if got, want := batcher.clientTokens.ToFlowControlRequest(), flowControlReq(initialTokens); !proto.Equal(got, want) { + t.Errorf("flowControlBatcher.clientTokens.ToFlowControlRequest(): got %v, want %v", got, want) + } + if got, want := batcher.pendingTokens.ToFlowControlRequest(), flowControlReq(initialTokens); !proto.Equal(got, want) { + t.Errorf("flowControlBatcher.pendingTokens.ToFlowControlRequest(): got %v, want %v", got, want) + } + + updatedTokens := flowControlTokens{Bytes: 500, Messages: 50} + batcher.Reset(updatedTokens) + if got, want := batcher.clientTokens.ToFlowControlRequest(), flowControlReq(updatedTokens); !proto.Equal(got, want) { + t.Errorf("flowControlBatcher.clientTokens.ToFlowControlRequest(): got %v, want %v", got, want) + } + if got, want := batcher.pendingTokens.ToFlowControlRequest(), (*pb.FlowControlRequest)(nil); !proto.Equal(got, want) { + t.Errorf("flowControlBatcher.pendingTokens.ToFlowControlRequest(): got %v, want %v", got, want) + } +} + func TestOffsetTrackerRequestForRestart(t *testing.T) { for _, tc := range []struct { desc string diff --git a/pubsublite/internal/wire/subscriber.go b/pubsublite/internal/wire/subscriber.go index 56344cae0d2..cad68b4d2ae 100644 --- a/pubsublite/internal/wire/subscriber.go +++ b/pubsublite/internal/wire/subscriber.go @@ -17,6 +17,7 @@ import ( "context" "errors" "reflect" + "sync" "time" "github.com/google/uuid" @@ -46,57 +47,74 @@ type ReceivedMessage struct { type MessageReceiverFunc func(*ReceivedMessage) // messageDeliveryQueue delivers received messages to the client-provided -// MessageReceiverFunc sequentially. +// MessageReceiverFunc sequentially. It is only accessed by the subscribeStream. type messageDeliveryQueue struct { - receiver MessageReceiverFunc - messagesC chan *ReceivedMessage - stopC chan struct{} - acks *ackTracker - status serviceStatus + bufferSize int + acks *ackTracker + receiver MessageReceiverFunc + messagesC chan *ReceivedMessage + stopC chan struct{} + active sync.WaitGroup } func newMessageDeliveryQueue(acks *ackTracker, receiver MessageReceiverFunc, bufferSize int) *messageDeliveryQueue { return &messageDeliveryQueue{ - acks: acks, - receiver: receiver, - messagesC: make(chan *ReceivedMessage, bufferSize), - stopC: make(chan struct{}), + bufferSize: bufferSize, + acks: acks, + receiver: receiver, } } +// Start the message delivery, if not already started. func (mq *messageDeliveryQueue) Start() { - if mq.status == serviceUninitialized { - go mq.deliverMessages() - mq.status = serviceActive + if mq.stopC != nil { + return } + + mq.stopC = make(chan struct{}) + mq.messagesC = make(chan *ReceivedMessage, mq.bufferSize) + mq.active.Add(1) + go mq.deliverMessages(mq.messagesC, mq.stopC) } +// Stop message delivery and discard undelivered messages. func (mq *messageDeliveryQueue) Stop() { - if mq.status < serviceTerminated { - close(mq.stopC) - mq.status = serviceTerminated + if mq.stopC == nil { + return } + + close(mq.stopC) + mq.stopC = nil + mq.messagesC = nil +} + +// Wait until the message delivery goroutine has terminated. +func (mq *messageDeliveryQueue) Wait() { + mq.active.Wait() } func (mq *messageDeliveryQueue) Add(msg *ReceivedMessage) { - if mq.status == serviceActive { + if mq.messagesC != nil { mq.messagesC <- msg } } -func (mq *messageDeliveryQueue) deliverMessages() { +func (mq *messageDeliveryQueue) deliverMessages(messagesC chan *ReceivedMessage, stopC chan struct{}) { + // Notify the wait group that the goroutine has terminated upon exit. + defer mq.active.Done() + for { // stopC has higher priority. select { - case <-mq.stopC: + case <-stopC: return // Ends the goroutine. default: } select { - case <-mq.stopC: + case <-stopC: return // Ends the goroutine. - case msg := <-mq.messagesC: + case msg := <-messagesC: // Register outstanding acks, which are primarily handled by the // `committer`. mq.acks.Push(msg.Ack.(*ackConsumer)) @@ -175,7 +193,7 @@ func (s *subscribeStream) Start() { s.pollFlowControl.Start() s.messageQueue.Start() - s.flowControl.OnClientFlow(flowControlTokens{ + s.flowControl.Reset(flowControlTokens{ Bytes: int64(s.settings.MaxOutstandingBytes), Messages: int64(s.settings.MaxOutstandingMessages), }) diff --git a/pubsublite/internal/wire/subscriber_test.go b/pubsublite/internal/wire/subscriber_test.go index b29aa0c3233..a5b417c8b1d 100644 --- a/pubsublite/internal/wire/subscriber_test.go +++ b/pubsublite/internal/wire/subscriber_test.go @@ -154,7 +154,7 @@ func (tr *testBlockingMessageReceiver) Return() { tr.blockReceive <- void } -func TestMessageDeliveryQueue(t *testing.T) { +func TestMessageDeliveryQueueStartStop(t *testing.T) { acks := newAckTracker() receiver := newTestMessageReceiver(t) messageQueue := newMessageDeliveryQueue(acks, receiver.onMessage, 10) @@ -189,9 +189,55 @@ func TestMessageDeliveryQueue(t *testing.T) { messageQueue.Stop() messageQueue.Stop() // Check duplicate stop messageQueue.Add(&ReceivedMessage{Msg: msg4, Ack: ack4}) + messageQueue.Wait() receiver.VerifyNoMsgs() }) + + t.Run("Restart", func(t *testing.T) { + msg5 := seqMsgWithOffset(5) + ack5 := newAckConsumer(5, 0, nil) + + messageQueue.Start() + messageQueue.Add(&ReceivedMessage{Msg: msg5, Ack: ack5}) + + receiver.ValidateMsg(msg5) + }) + + t.Run("Stop", func(t *testing.T) { + messageQueue.Stop() + messageQueue.Wait() + + receiver.VerifyNoMsgs() + }) +} + +func TestMessageDeliveryQueueDiscardMessages(t *testing.T) { + acks := newAckTracker() + blockingReceiver := newTestBlockingMessageReceiver(t) + messageQueue := newMessageDeliveryQueue(acks, blockingReceiver.onMessage, 10) + + msg1 := seqMsgWithOffset(1) + ack1 := newAckConsumer(1, 0, nil) + msg2 := seqMsgWithOffset(2) + ack2 := newAckConsumer(2, 0, nil) + + messageQueue.Start() + messageQueue.Add(&ReceivedMessage{Msg: msg1, Ack: ack1}) + messageQueue.Add(&ReceivedMessage{Msg: msg2, Ack: ack2}) + + // The blocking receiver suspends after receiving msg1. + blockingReceiver.ValidateMsg(msg1) + // Stopping the message queue should discard undelivered msg2. + messageQueue.Stop() + + // Unsuspend the blocking receiver and verify msg2 is not received. + blockingReceiver.Return() + messageQueue.Wait() + blockingReceiver.VerifyNoMsgs() + if got, want := acks.outstandingAcks.Len(), 1; got != want { + t.Errorf("ackTracker.outstandingAcks.Len() got %v, want %v", got, want) + } } // testSubscribeStream wraps a subscribeStream for ease of testing.