diff --git a/pubsublite/go.mod b/pubsublite/go.mod index c5d7c5091d1..9150918a4ff 100644 --- a/pubsublite/go.mod +++ b/pubsublite/go.mod @@ -5,6 +5,7 @@ go 1.11 require ( cloud.google.com/go v0.71.0 github.com/golang/protobuf v1.4.3 + github.com/google/go-cmp v0.5.2 github.com/googleapis/gax-go/v2 v2.0.5 golang.org/x/tools v0.0.0-20201102212025-f46e4245211d // indirect google.golang.org/api v0.34.0 diff --git a/pubsublite/internal/wire/flow_control.go b/pubsublite/internal/wire/flow_control.go new file mode 100644 index 00000000000..785cdd097bb --- /dev/null +++ b/pubsublite/internal/wire/flow_control.go @@ -0,0 +1,180 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and + +package wire + +import ( + "errors" + "fmt" + "math" + + pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1" +) + +var ( + errTokenCounterBytesNegative = errors.New("pubsublite: received messages that account for more bytes than were requested") + errTokenCounterMessagesNegative = errors.New("pubsublite: received more messages than were requested") + errOutOfOrderMessages = errors.New("pubsublite: server delivered messages out of order") +) + +type flowControlTokens struct { + Bytes int64 + Messages int64 +} + +// A TokenCounter stores the amount of outstanding byte and message flow control +// tokens that the client believes exists for the stream. +type tokenCounter struct { + Bytes int64 + Messages int64 +} + +func saturatedAdd(sum, delta int64) int64 { + remainder := math.MaxInt64 - sum + if delta >= remainder { + return math.MaxInt64 + } + return sum + delta +} + +func (tc *tokenCounter) Add(delta flowControlTokens) { + tc.Bytes = saturatedAdd(tc.Bytes, delta.Bytes) + tc.Messages = saturatedAdd(tc.Messages, delta.Messages) +} + +func (tc *tokenCounter) Sub(delta flowControlTokens) error { + if delta.Bytes > tc.Bytes { + return errTokenCounterBytesNegative + } + if delta.Messages > tc.Messages { + return errTokenCounterMessagesNegative + } + tc.Bytes -= delta.Bytes + tc.Messages -= delta.Messages + return nil +} + +func (tc *tokenCounter) Reset() { + tc.Bytes = 0 + tc.Messages = 0 +} + +func (tc *tokenCounter) ToFlowControlRequest() *pb.FlowControlRequest { + if tc.Bytes <= 0 && tc.Messages <= 0 { + return nil + } + return &pb.FlowControlRequest{ + AllowedBytes: tc.Bytes, + AllowedMessages: tc.Messages, + } +} + +// flowControlBatcher tracks flow control tokens and manages batching of flow +// control requests to avoid overwhelming the server. It is only accessed by +// the wireSubscriber. +type flowControlBatcher struct { + // The current amount of outstanding byte and message flow control tokens. + clientTokens tokenCounter + // The pending batch flow control request that needs to be sent to the stream. + pendingTokens tokenCounter +} + +const expediteBatchRequestRatio = 0.5 + +func exceedsExpediteRatio(pending, client int64) bool { + return client > 0 && (float64(pending)/float64(client)) >= expediteBatchRequestRatio +} + +// OnClientFlow increments flow control tokens. This occurs when: +// - Initialization from ReceiveSettings. +// - The user acks messages. +func (fc *flowControlBatcher) OnClientFlow(tokens flowControlTokens) { + fc.clientTokens.Add(tokens) + fc.pendingTokens.Add(tokens) +} + +// OnMessages decrements flow control tokens when messages are received from the +// server. +func (fc *flowControlBatcher) OnMessages(msgs []*pb.SequencedMessage) error { + var totalBytes int64 + for _, msg := range msgs { + totalBytes += msg.GetSizeBytes() + } + return fc.clientTokens.Sub(flowControlTokens{Bytes: totalBytes, Messages: int64(len(msgs))}) +} + +// RequestForRestart returns a FlowControlRequest that should be sent when a new +// subscriber stream is connected. May return nil. +func (fc *flowControlBatcher) RequestForRestart() *pb.FlowControlRequest { + fc.pendingTokens.Reset() + return fc.clientTokens.ToFlowControlRequest() +} + +// ReleasePendingRequest returns a non-nil request when there is a batch +// FlowControlRequest to send to the stream. +func (fc *flowControlBatcher) ReleasePendingRequest() *pb.FlowControlRequest { + req := fc.pendingTokens.ToFlowControlRequest() + fc.pendingTokens.Reset() + return req +} + +// ShouldExpediteBatchRequest returns true if a batch FlowControlRequest should +// be sent ASAP to avoid starving the client of messages. This occurs when the +// client is rapidly acking messages. +func (fc *flowControlBatcher) ShouldExpediteBatchRequest() bool { + if exceedsExpediteRatio(fc.pendingTokens.Bytes, fc.clientTokens.Bytes) { + return true + } + if exceedsExpediteRatio(fc.pendingTokens.Messages, fc.clientTokens.Messages) { + return true + } + return false +} + +// subscriberOffsetTracker tracks the expected offset of the next message +// received from the server. It is only accessed by the wireSubscriber. +type subscriberOffsetTracker struct { + minNextOffset int64 +} + +// RequestForRestart returns the seek request to send when a new subscribe +// stream reconnects. Returns nil if the subscriber has just started, in which +// case the server returns the offset of the last committed cursor. +func (ot *subscriberOffsetTracker) RequestForRestart() *pb.SeekRequest { + if ot.minNextOffset <= 0 { + return nil + } + return &pb.SeekRequest{ + Target: &pb.SeekRequest_Cursor{ + Cursor: &pb.Cursor{Offset: ot.minNextOffset}, + }, + } +} + +// OnMessages verifies that messages are delivered in order and updates the next +// expected offset. +func (ot *subscriberOffsetTracker) OnMessages(msgs []*pb.SequencedMessage) error { + nextOffset := ot.minNextOffset + for i, msg := range msgs { + offset := msg.GetCursor().GetOffset() + if offset < nextOffset { + if i == 0 { + return fmt.Errorf("pubsublite: server delivered messages with start offset = %d, expected >= %d", offset, ot.minNextOffset) + } + return errOutOfOrderMessages + } + nextOffset = offset + 1 + } + ot.minNextOffset = nextOffset + return nil +} diff --git a/pubsublite/internal/wire/flow_control_test.go b/pubsublite/internal/wire/flow_control_test.go new file mode 100644 index 00000000000..ca6d0661510 --- /dev/null +++ b/pubsublite/internal/wire/flow_control_test.go @@ -0,0 +1,325 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and + +package wire + +import ( + "math" + "testing" + + "cloud.google.com/go/internal/testutil" + "cloud.google.com/go/pubsublite/internal/test" + "github.com/golang/protobuf/proto" + "github.com/google/go-cmp/cmp" + + pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1" +) + +func flowControlReq(tokens flowControlTokens) *pb.FlowControlRequest { + return &pb.FlowControlRequest{ + AllowedBytes: tokens.Bytes, + AllowedMessages: tokens.Messages, + } +} + +func seqMsgWithOffset(offset int64) *pb.SequencedMessage { + return &pb.SequencedMessage{ + Cursor: &pb.Cursor{Offset: offset}, + } +} + +func seqMsgWithSizeBytes(size int64) *pb.SequencedMessage { + return &pb.SequencedMessage{ + SizeBytes: size, + } +} + +func TestTokenCounterAdd(t *testing.T) { + // Note: tests are applied to this counter instance in order. + counter := tokenCounter{} + + for _, tc := range []struct { + desc string + delta flowControlTokens + want tokenCounter + }{ + { + desc: "Initialize", + delta: flowControlTokens{Bytes: 9876543, Messages: 1234}, + want: tokenCounter{Bytes: 9876543, Messages: 1234}, + }, + { + desc: "Add delta", + delta: flowControlTokens{Bytes: 1, Messages: 2}, + want: tokenCounter{Bytes: 9876544, Messages: 1236}, + }, + { + desc: "Overflow", + delta: flowControlTokens{Bytes: math.MaxInt64, Messages: math.MaxInt64}, + want: tokenCounter{Bytes: math.MaxInt64, Messages: math.MaxInt64}, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + counter.Add(tc.delta) + if !testutil.Equal(counter, tc.want) { + t.Errorf("tokenCounter.Add(%v): got %v, want %v", tc.delta, counter, tc.want) + } + }) + } +} + +func TestTokenCounterSub(t *testing.T) { + for _, tc := range []struct { + desc string + counter tokenCounter + delta flowControlTokens + want tokenCounter + wantErr error + }{ + { + desc: "Result zero", + counter: tokenCounter{Bytes: 9876543, Messages: 1234}, + delta: flowControlTokens{Bytes: 9876543, Messages: 1234}, + want: tokenCounter{Bytes: 0, Messages: 0}, + }, + { + desc: "Result non-zero", + counter: tokenCounter{Bytes: 9876543, Messages: 1234}, + delta: flowControlTokens{Bytes: 9876500, Messages: 1200}, + want: tokenCounter{Bytes: 43, Messages: 34}, + }, + { + desc: "Bytes negative", + counter: tokenCounter{Bytes: 9876543, Messages: 1234}, + delta: flowControlTokens{Bytes: 9876544, Messages: 1234}, + want: tokenCounter{Bytes: 9876543, Messages: 1234}, + wantErr: errTokenCounterBytesNegative, + }, + { + desc: "Messages negative", + counter: tokenCounter{Bytes: 9876543, Messages: 1234}, + delta: flowControlTokens{Bytes: 9876543, Messages: 1235}, + want: tokenCounter{Bytes: 9876543, Messages: 1234}, + wantErr: errTokenCounterMessagesNegative, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + gotErr := tc.counter.Sub(tc.delta) + if !testutil.Equal(tc.counter, tc.want) { + t.Errorf("tokenCounter.Sub(%v): got %v, want %v", tc.delta, tc.counter, tc.want) + } + if !test.ErrorEqual(gotErr, tc.wantErr) { + t.Errorf("tokenCounter.Sub(%v) error: got %v, want %v", tc.delta, gotErr, tc.wantErr) + } + }) + } +} + +func TestTokenCounterToFlowControlRequest(t *testing.T) { + for _, tc := range []struct { + desc string + counter tokenCounter + want *pb.FlowControlRequest + }{ + { + desc: "Uninitialized counter", + counter: tokenCounter{}, + want: nil, + }, + { + desc: "Bytes non-zero", + counter: tokenCounter{Bytes: 1}, + want: &pb.FlowControlRequest{AllowedBytes: 1}, + }, + { + desc: "Messages non-zero", + counter: tokenCounter{Messages: 1}, + want: &pb.FlowControlRequest{AllowedMessages: 1}, + }, + { + desc: "Messages and bytes", + counter: tokenCounter{Bytes: 56, Messages: 32}, + want: &pb.FlowControlRequest{AllowedBytes: 56, AllowedMessages: 32}, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + got := tc.counter.ToFlowControlRequest() + if !proto.Equal(got, tc.want) { + t.Errorf("tokenCounter(%v).ToFlowControlRequest(): got %v, want %v", tc.counter, got, tc.want) + } + }) + } +} + +func TestFlowControlBatcher(t *testing.T) { + var batcher flowControlBatcher + + t.Run("Uninitialized", func(t *testing.T) { + if got, want := batcher.ShouldExpediteBatchRequest(), false; got != want { + t.Errorf("flowControlBatcher.ShouldExpediteBatchRequest(): got %v, want %v", got, want) + } + if got, want := batcher.ReleasePendingRequest(), (*pb.FlowControlRequest)(nil); !proto.Equal(got, want) { + t.Errorf("flowControlBatcher.ReleasePendingRequest(): got %v, want %v", got, want) + } + if got, want := batcher.RequestForRestart(), (*pb.FlowControlRequest)(nil); !proto.Equal(got, want) { + t.Errorf("flowControlBatcher.RequestForRestart(): got %v, want %v", got, want) + } + }) + + t.Run("OnClientFlow-1", func(t *testing.T) { + deltaTokens := flowControlTokens{Bytes: 500, Messages: 10} + batcher.OnClientFlow(deltaTokens) + + if got, want := batcher.ShouldExpediteBatchRequest(), true; got != want { + t.Errorf("flowControlBatcher.ShouldExpediteBatchRequest(): got %v, want %v", got, want) + } + if got, want := batcher.ReleasePendingRequest(), flowControlReq(deltaTokens); !proto.Equal(got, want) { + t.Errorf("flowControlBatcher.ReleasePendingRequest(): got %v, want %v", got, want) + } + if got, want := batcher.RequestForRestart(), flowControlReq(deltaTokens); !proto.Equal(got, want) { + t.Errorf("flowControlBatcher.RequestForRestart(): got %v, want %v", got, want) + } + }) + + t.Run("OnClientFlow-2", func(t *testing.T) { + deltaTokens := flowControlTokens{Bytes: 100, Messages: 1} + batcher.OnClientFlow(deltaTokens) + + if got, want := batcher.ShouldExpediteBatchRequest(), false; got != want { + t.Errorf("flowControlBatcher.ShouldExpediteBatchRequest(): got %v, want %v", got, want) + } + if got, want := batcher.ReleasePendingRequest(), flowControlReq(deltaTokens); !proto.Equal(got, want) { + t.Errorf("flowControlBatcher.ReleasePendingRequest(): got %v, want %v", got, want) + } + if got, want := batcher.RequestForRestart(), flowControlReq(flowControlTokens{Bytes: 600, Messages: 11}); !proto.Equal(got, want) { + t.Errorf("flowControlBatcher.RequestForRestart(): got %v, want %v", got, want) + } + }) + + t.Run("OnMessages-Valid", func(t *testing.T) { + msgs := []*pb.SequencedMessage{seqMsgWithSizeBytes(10), seqMsgWithSizeBytes(20)} + if gotErr := batcher.OnMessages(msgs); gotErr != nil { + t.Errorf("flowControlBatcher.OnMessages(): got err (%v), want err ", gotErr) + } + + if got, want := batcher.ShouldExpediteBatchRequest(), false; got != want { + t.Errorf("flowControlBatcher.ShouldExpediteBatchRequest(): got %v, want %v", got, want) + } + if got, want := batcher.ReleasePendingRequest(), (*pb.FlowControlRequest)(nil); !proto.Equal(got, want) { + t.Errorf("flowControlBatcher.ReleasePendingRequest(): got %v, want %v", got, want) + } + if got, want := batcher.RequestForRestart(), flowControlReq(flowControlTokens{Bytes: 570, Messages: 9}); !proto.Equal(got, want) { + t.Errorf("flowControlBatcher.RequestForRestart(): got %v, want %v", got, want) + } + }) + + t.Run("OnMessages-Underflow", func(t *testing.T) { + msgs := []*pb.SequencedMessage{seqMsgWithSizeBytes(400), seqMsgWithSizeBytes(200)} + if gotErr, wantErr := batcher.OnMessages(msgs), errTokenCounterBytesNegative; !test.ErrorEqual(gotErr, wantErr) { + t.Errorf("flowControlBatcher.OnMessages(): got err (%v), want err (%v)", gotErr, wantErr) + } + + if got, want := batcher.ShouldExpediteBatchRequest(), false; got != want { + t.Errorf("flowControlBatcher.ShouldExpediteBatchRequest(): got %v, want %v", got, want) + } + if got, want := batcher.ReleasePendingRequest(), (*pb.FlowControlRequest)(nil); !proto.Equal(got, want) { + t.Errorf("flowControlBatcher.ReleasePendingRequest(): got %v, want %v", got, want) + } + if got, want := batcher.RequestForRestart(), flowControlReq(flowControlTokens{Bytes: 570, Messages: 9}); !proto.Equal(got, want) { + t.Errorf("flowControlBatcher.RequestForRestart(): got %v, want %v", got, want) + } + }) +} + +func TestOffsetTrackerRequestForRestart(t *testing.T) { + for _, tc := range []struct { + desc string + tracker subscriberOffsetTracker + want *pb.SeekRequest + }{ + { + desc: "Uninitialized tracker", + tracker: subscriberOffsetTracker{}, + want: nil, + }, + { + desc: "Next offset positive", + tracker: subscriberOffsetTracker{minNextOffset: 1}, + want: &pb.SeekRequest{ + Target: &pb.SeekRequest_Cursor{ + Cursor: &pb.Cursor{Offset: 1}, + }, + }, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + got := tc.tracker.RequestForRestart() + if !proto.Equal(got, tc.want) { + t.Errorf("subscriberOffsetTracker(%v).RequestForRestart(): got %v, want %v", tc.tracker, got, tc.want) + } + }) + } +} + +func TestOffsetTrackerOnMessages(t *testing.T) { + for _, tc := range []struct { + desc string + tracker subscriberOffsetTracker + msgs []*pb.SequencedMessage + want subscriberOffsetTracker + wantErr bool + }{ + { + desc: "Uninitialized tracker", + tracker: subscriberOffsetTracker{}, + msgs: []*pb.SequencedMessage{seqMsgWithOffset(0)}, + want: subscriberOffsetTracker{minNextOffset: 1}, + }, + { + desc: "Consecutive message offsets", + tracker: subscriberOffsetTracker{minNextOffset: 5}, + msgs: []*pb.SequencedMessage{seqMsgWithOffset(5), seqMsgWithOffset(6), seqMsgWithOffset(7)}, + want: subscriberOffsetTracker{minNextOffset: 8}, + }, + { + desc: "Skip message offsets", + tracker: subscriberOffsetTracker{minNextOffset: 5}, + msgs: []*pb.SequencedMessage{seqMsgWithOffset(10), seqMsgWithOffset(15)}, + want: subscriberOffsetTracker{minNextOffset: 16}, + }, + { + desc: "Start offset before minNextOffset", + tracker: subscriberOffsetTracker{minNextOffset: 5}, + msgs: []*pb.SequencedMessage{seqMsgWithOffset(4)}, + want: subscriberOffsetTracker{minNextOffset: 5}, + wantErr: true, + }, + { + desc: "Unordered messages", + tracker: subscriberOffsetTracker{minNextOffset: 5}, + msgs: []*pb.SequencedMessage{seqMsgWithOffset(5), seqMsgWithOffset(10), seqMsgWithOffset(9)}, + want: subscriberOffsetTracker{minNextOffset: 5}, + wantErr: true, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + err := tc.tracker.OnMessages(tc.msgs) + if !testutil.Equal(tc.tracker, tc.want, cmp.AllowUnexported(subscriberOffsetTracker{})) { + t.Errorf("subscriberOffsetTracker().OnMessages(): got %v, want %v", tc.tracker, tc.want) + } + if gotErr := err != nil; gotErr != tc.wantErr { + t.Errorf("subscriberOffsetTracker().OnMessages() error: got (%v), want err=%v", err, tc.wantErr) + } + }) + } +}