Skip to content

Commit

Permalink
refactor(pubsublite): allow subscribe stream trackers to be reset (#4145
Browse files Browse the repository at this point in the history
)

Enables the flowControlBatcher, subscriberOffsetTracker and messageDeliveryQueue to be reset.
  • Loading branch information
tmdiep committed Jun 3, 2021
1 parent cac48ea commit 94485e6
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 23 deletions.
12 changes: 12 additions & 0 deletions pubsublite/internal/wire/flow_control.go
Expand Up @@ -95,6 +95,13 @@ func exceedsExpediteRatio(pending, client int64) bool {
return client > 0 && (float64(pending)/float64(client)) >= expediteBatchRequestRatio
}

// Reset client tokens to the given values and reset pending tokens.
func (fc *flowControlBatcher) Reset(tokens flowControlTokens) {
fc.clientTokens.Reset()
fc.clientTokens.Add(tokens)
fc.pendingTokens.Reset()
}

// OnClientFlow increments flow control tokens. This occurs when:
// - Initialization from ReceiveSettings.
// - The user acks messages.
Expand Down Expand Up @@ -147,6 +154,11 @@ type subscriberOffsetTracker struct {
minNextOffset int64
}

// Reset the offset tracker to the initial state.
func (ot *subscriberOffsetTracker) Reset() {
ot.minNextOffset = 0
}

// RequestForRestart returns the seek request to send when a new subscribe
// stream reconnects. Returns nil if the subscriber has just started, in which
// case the server returns the offset of the last committed cursor.
Expand Down
22 changes: 22 additions & 0 deletions pubsublite/internal/wire/flow_control_test.go
Expand Up @@ -222,6 +222,28 @@ func TestFlowControlBatcher(t *testing.T) {
})
}

func TestFlowControlBatcherReset(t *testing.T) {
var batcher flowControlBatcher

initialTokens := flowControlTokens{Bytes: 400, Messages: 40}
batcher.OnClientFlow(initialTokens)
if got, want := batcher.clientTokens.ToFlowControlRequest(), flowControlReq(initialTokens); !proto.Equal(got, want) {
t.Errorf("flowControlBatcher.clientTokens.ToFlowControlRequest(): got %v, want %v", got, want)
}
if got, want := batcher.pendingTokens.ToFlowControlRequest(), flowControlReq(initialTokens); !proto.Equal(got, want) {
t.Errorf("flowControlBatcher.pendingTokens.ToFlowControlRequest(): got %v, want %v", got, want)
}

updatedTokens := flowControlTokens{Bytes: 500, Messages: 50}
batcher.Reset(updatedTokens)
if got, want := batcher.clientTokens.ToFlowControlRequest(), flowControlReq(updatedTokens); !proto.Equal(got, want) {
t.Errorf("flowControlBatcher.clientTokens.ToFlowControlRequest(): got %v, want %v", got, want)
}
if got, want := batcher.pendingTokens.ToFlowControlRequest(), (*pb.FlowControlRequest)(nil); !proto.Equal(got, want) {
t.Errorf("flowControlBatcher.pendingTokens.ToFlowControlRequest(): got %v, want %v", got, want)
}
}

func TestOffsetTrackerRequestForRestart(t *testing.T) {
for _, tc := range []struct {
desc string
Expand Down
62 changes: 40 additions & 22 deletions pubsublite/internal/wire/subscriber.go
Expand Up @@ -17,6 +17,7 @@ import (
"context"
"errors"
"reflect"
"sync"
"time"

"github.com/google/uuid"
Expand Down Expand Up @@ -46,57 +47,74 @@ type ReceivedMessage struct {
type MessageReceiverFunc func(*ReceivedMessage)

// messageDeliveryQueue delivers received messages to the client-provided
// MessageReceiverFunc sequentially.
// MessageReceiverFunc sequentially. It is only accessed by the subscribeStream.
type messageDeliveryQueue struct {
receiver MessageReceiverFunc
messagesC chan *ReceivedMessage
stopC chan struct{}
acks *ackTracker
status serviceStatus
bufferSize int
acks *ackTracker
receiver MessageReceiverFunc
messagesC chan *ReceivedMessage
stopC chan struct{}
active sync.WaitGroup
}

func newMessageDeliveryQueue(acks *ackTracker, receiver MessageReceiverFunc, bufferSize int) *messageDeliveryQueue {
return &messageDeliveryQueue{
acks: acks,
receiver: receiver,
messagesC: make(chan *ReceivedMessage, bufferSize),
stopC: make(chan struct{}),
bufferSize: bufferSize,
acks: acks,
receiver: receiver,
}
}

// Start the message delivery, if not already started.
func (mq *messageDeliveryQueue) Start() {
if mq.status == serviceUninitialized {
go mq.deliverMessages()
mq.status = serviceActive
if mq.stopC != nil {
return
}

mq.stopC = make(chan struct{})
mq.messagesC = make(chan *ReceivedMessage, mq.bufferSize)
mq.active.Add(1)
go mq.deliverMessages(mq.messagesC, mq.stopC)
}

// Stop message delivery and discard undelivered messages.
func (mq *messageDeliveryQueue) Stop() {
if mq.status < serviceTerminated {
close(mq.stopC)
mq.status = serviceTerminated
if mq.stopC == nil {
return
}

close(mq.stopC)
mq.stopC = nil
mq.messagesC = nil
}

// Wait until the message delivery goroutine has terminated.
func (mq *messageDeliveryQueue) Wait() {
mq.active.Wait()
}

func (mq *messageDeliveryQueue) Add(msg *ReceivedMessage) {
if mq.status == serviceActive {
if mq.messagesC != nil {
mq.messagesC <- msg
}
}

func (mq *messageDeliveryQueue) deliverMessages() {
func (mq *messageDeliveryQueue) deliverMessages(messagesC chan *ReceivedMessage, stopC chan struct{}) {
// Notify the wait group that the goroutine has terminated upon exit.
defer mq.active.Done()

for {
// stopC has higher priority.
select {
case <-mq.stopC:
case <-stopC:
return // Ends the goroutine.
default:
}

select {
case <-mq.stopC:
case <-stopC:
return // Ends the goroutine.
case msg := <-mq.messagesC:
case msg := <-messagesC:
// Register outstanding acks, which are primarily handled by the
// `committer`.
mq.acks.Push(msg.Ack.(*ackConsumer))
Expand Down Expand Up @@ -175,7 +193,7 @@ func (s *subscribeStream) Start() {
s.pollFlowControl.Start()
s.messageQueue.Start()

s.flowControl.OnClientFlow(flowControlTokens{
s.flowControl.Reset(flowControlTokens{
Bytes: int64(s.settings.MaxOutstandingBytes),
Messages: int64(s.settings.MaxOutstandingMessages),
})
Expand Down
48 changes: 47 additions & 1 deletion pubsublite/internal/wire/subscriber_test.go
Expand Up @@ -154,7 +154,7 @@ func (tr *testBlockingMessageReceiver) Return() {
tr.blockReceive <- void
}

func TestMessageDeliveryQueue(t *testing.T) {
func TestMessageDeliveryQueueStartStop(t *testing.T) {
acks := newAckTracker()
receiver := newTestMessageReceiver(t)
messageQueue := newMessageDeliveryQueue(acks, receiver.onMessage, 10)
Expand Down Expand Up @@ -189,9 +189,55 @@ func TestMessageDeliveryQueue(t *testing.T) {
messageQueue.Stop()
messageQueue.Stop() // Check duplicate stop
messageQueue.Add(&ReceivedMessage{Msg: msg4, Ack: ack4})
messageQueue.Wait()

receiver.VerifyNoMsgs()
})

t.Run("Restart", func(t *testing.T) {
msg5 := seqMsgWithOffset(5)
ack5 := newAckConsumer(5, 0, nil)

messageQueue.Start()
messageQueue.Add(&ReceivedMessage{Msg: msg5, Ack: ack5})

receiver.ValidateMsg(msg5)
})

t.Run("Stop", func(t *testing.T) {
messageQueue.Stop()
messageQueue.Wait()

receiver.VerifyNoMsgs()
})
}

func TestMessageDeliveryQueueDiscardMessages(t *testing.T) {
acks := newAckTracker()
blockingReceiver := newTestBlockingMessageReceiver(t)
messageQueue := newMessageDeliveryQueue(acks, blockingReceiver.onMessage, 10)

msg1 := seqMsgWithOffset(1)
ack1 := newAckConsumer(1, 0, nil)
msg2 := seqMsgWithOffset(2)
ack2 := newAckConsumer(2, 0, nil)

messageQueue.Start()
messageQueue.Add(&ReceivedMessage{Msg: msg1, Ack: ack1})
messageQueue.Add(&ReceivedMessage{Msg: msg2, Ack: ack2})

// The blocking receiver suspends after receiving msg1.
blockingReceiver.ValidateMsg(msg1)
// Stopping the message queue should discard undelivered msg2.
messageQueue.Stop()

// Unsuspend the blocking receiver and verify msg2 is not received.
blockingReceiver.Return()
messageQueue.Wait()
blockingReceiver.VerifyNoMsgs()
if got, want := acks.outstandingAcks.Len(), 1; got != want {
t.Errorf("ackTracker.outstandingAcks.Len() got %v, want %v", got, want)
}
}

// testSubscribeStream wraps a subscribeStream for ease of testing.
Expand Down

0 comments on commit 94485e6

Please sign in to comment.