diff --git a/pubsublite/internal/wire/publish_batcher.go b/pubsublite/internal/wire/publish_batcher.go index 63999b15919..e50d3af8b31 100644 --- a/pubsublite/internal/wire/publish_batcher.go +++ b/pubsublite/internal/wire/publish_batcher.go @@ -41,6 +41,7 @@ type messageHolder struct { // MessagePublishRequest. type publishBatch struct { msgHolders []*messageHolder + totalSize int } func (b *publishBatch) ToPublishRequest() *pb.PublishRequest { @@ -93,7 +94,11 @@ func newPublishMessageBatcher(settings *PublishSettings, partition int, onNewBat // singlePartitionPublisher.onNewBatch() receives the new batch from the // Bundler, which calls publishMessageBatcher.AddBatch(). Only the // publisher's mutex is required. - onNewBatch(&publishBatch{msgHolders: msgs}) + batch := &publishBatch{msgHolders: msgs} + for _, msg := range batch.msgHolders { + batch.totalSize += msg.size + } + onNewBatch(batch) }) msgBundler.DelayThreshold = settings.DelayThreshold msgBundler.BundleCountThreshold = settings.CountThreshold @@ -164,10 +169,24 @@ func (b *publishMessageBatcher) OnPermanentError(err error) { func (b *publishMessageBatcher) InFlightBatches() []*publishBatch { var batches []*publishBatch - for elem := b.publishQueue.Front(); elem != nil; elem = elem.Next() { - if batch, ok := elem.Value.(*publishBatch); ok { - batches = append(batches, batch) + for elem := b.publishQueue.Front(); elem != nil; { + batch := elem.Value.(*publishBatch) + if elem.Prev() != nil { + // Merge current batch with previous if within max bytes and count limits. + prevBatch := elem.Prev().Value.(*publishBatch) + totalSize := prevBatch.totalSize + batch.totalSize + totalLen := len(prevBatch.msgHolders) + len(batch.msgHolders) + if totalSize <= MaxPublishRequestBytes && totalLen <= MaxPublishRequestCount { + prevBatch.totalSize = totalSize + prevBatch.msgHolders = append(prevBatch.msgHolders, batch.msgHolders...) + removeElem := elem + elem = elem.Next() + b.publishQueue.Remove(removeElem) + continue + } } + batches = append(batches, batch) + elem = elem.Next() } return batches } diff --git a/pubsublite/internal/wire/publish_batcher_test.go b/pubsublite/internal/wire/publish_batcher_test.go index 690c7edeab7..014bd0d6c05 100644 --- a/pubsublite/internal/wire/publish_batcher_test.go +++ b/pubsublite/internal/wire/publish_batcher_test.go @@ -128,8 +128,8 @@ func (br *testPublishBatchReceiver) ValidateBatches(want []*publishBatch) { } } - if !testutil.Equal(got, want, cmp.AllowUnexported(publishBatch{}, messageHolder{})) { - br.t.Errorf("Batches got: %v\nwant: %v", got, want) + if diff := testutil.Diff(got, want, cmp.AllowUnexported(publishBatch{}, messageHolder{})); diff != "" { + br.t.Errorf("Batches got: -, want: +\n%s", diff) } } @@ -144,6 +144,15 @@ func makeMsgHolder(msg *pb.PubSubMessage, receiver ...*testPublishResultReceiver return h } +func makePublishBatch(msgs ...*messageHolder) *publishBatch { + batch := new(publishBatch) + for _, msg := range msgs { + batch.msgHolders = append(batch.msgHolders, msg) + batch.totalSize += msg.size + } + return batch +} + func TestPublishBatcherAddMessage(t *testing.T) { const initAvailableBytes = MaxPublishRequestBytes settings := DefaultPublishSettings @@ -199,22 +208,16 @@ func TestPublishBatcherBundlerCountThreshold(t *testing.T) { // Batch 1 msg1 := &pb.PubSubMessage{Data: []byte{'1'}} msg2 := &pb.PubSubMessage{Data: []byte{'2'}} - wantBatch1 := &publishBatch{ - []*messageHolder{makeMsgHolder(msg1), makeMsgHolder(msg2)}, - } + wantBatch1 := makePublishBatch(makeMsgHolder(msg1), makeMsgHolder(msg2)) // Batch 2 msg3 := &pb.PubSubMessage{Data: []byte{'3'}} msg4 := &pb.PubSubMessage{Data: []byte{'4'}} - wantBatch2 := &publishBatch{ - []*messageHolder{makeMsgHolder(msg3), makeMsgHolder(msg4)}, - } + wantBatch2 := makePublishBatch(makeMsgHolder(msg3), makeMsgHolder(msg4)) // Batch 3 msg5 := &pb.PubSubMessage{Data: []byte{'5'}} - wantBatch3 := &publishBatch{ - []*messageHolder{makeMsgHolder(msg5)}, - } + wantBatch3 := makePublishBatch(makeMsgHolder(msg5)) receiver := newTestPublishBatchReceiver(t) batcher := newPublishMessageBatcher(&settings, 0, receiver.onNewBatch) @@ -236,15 +239,11 @@ func TestPublishBatcherBundlerBatchingDelay(t *testing.T) { // Batch 1 msg1 := &pb.PubSubMessage{Data: []byte{'1'}} - wantBatch1 := &publishBatch{ - []*messageHolder{makeMsgHolder(msg1)}, - } + wantBatch1 := makePublishBatch(makeMsgHolder(msg1)) // Batch 2 msg2 := &pb.PubSubMessage{Data: []byte{'2'}} - wantBatch2 := &publishBatch{ - []*messageHolder{makeMsgHolder(msg2)}, - } + wantBatch2 := makePublishBatch(makeMsgHolder(msg2)) receiver := newTestPublishBatchReceiver(t) batcher := newPublishMessageBatcher(&settings, 0, receiver.onNewBatch) @@ -271,12 +270,7 @@ func TestPublishBatcherBundlerOnPermanentError(t *testing.T) { msg2 := &pb.PubSubMessage{Data: []byte{'2'}} pubResult1 := newTestPublishResultReceiver(t, msg1) pubResult2 := newTestPublishResultReceiver(t, msg2) - batcher.AddBatch(&publishBatch{ - []*messageHolder{ - makeMsgHolder(msg1, pubResult1), - makeMsgHolder(msg2, pubResult2), - }, - }) + batcher.AddBatch(makePublishBatch(makeMsgHolder(msg1, pubResult1), makeMsgHolder(msg2, pubResult2))) wantErr := status.Error(codes.FailedPrecondition, "failed") batcher.OnPermanentError(wantErr) @@ -306,17 +300,8 @@ func TestPublishBatcherBundlerOnPublishResponse(t *testing.T) { pubResult2 := newTestPublishResultReceiver(t, msg2) pubResult3 := newTestPublishResultReceiver(t, msg3) - batcher.AddBatch(&publishBatch{ - []*messageHolder{ - makeMsgHolder(msg1, pubResult1), - makeMsgHolder(msg2, pubResult2), - }, - }) - batcher.AddBatch(&publishBatch{ - []*messageHolder{ - makeMsgHolder(msg3, pubResult3), - }, - }) + batcher.AddBatch(makePublishBatch(makeMsgHolder(msg1, pubResult1), makeMsgHolder(msg2, pubResult2))) + batcher.AddBatch(makePublishBatch(makeMsgHolder(msg3, pubResult3))) if err := batcher.OnPublishResponse(70); err != nil { t.Errorf("OnPublishResponse() got err: %v", err) } @@ -332,14 +317,126 @@ func TestPublishBatcherBundlerOnPublishResponse(t *testing.T) { t.Run("inconsistent offset", func(t *testing.T) { msg := &pb.PubSubMessage{Data: []byte{'4'}} pubResult := newTestPublishResultReceiver(t, msg) - batcher.AddBatch(&publishBatch{ - []*messageHolder{ - makeMsgHolder(msg, pubResult), - }, - }) + batcher.AddBatch(makePublishBatch(makeMsgHolder(msg, pubResult))) if gotErr, wantMsg := batcher.OnPublishResponse(80), "inconsistent start offset = 80"; !test.ErrorHasMsg(gotErr, wantMsg) { t.Errorf("OnPublishResponse() got err: %v, want err msg: %q", gotErr, wantMsg) } }) } + +func TestPublishBatcherRebatching(t *testing.T) { + const partition = 2 + receiver := newTestPublishBatchReceiver(t) + + t.Run("single batch", func(t *testing.T) { + msg1 := &pb.PubSubMessage{Data: []byte{'1'}} + + batcher := newPublishMessageBatcher(&DefaultPublishSettings, partition, receiver.onNewBatch) + batcher.AddBatch(makePublishBatch(makeMsgHolder(msg1))) + + got := batcher.InFlightBatches() + want := []*publishBatch{ + makePublishBatch(makeMsgHolder(msg1)), + } + if diff := testutil.Diff(got, want, cmp.AllowUnexported(publishBatch{}, messageHolder{})); diff != "" { + t.Errorf("Batches got: -, want: +\n%s", diff) + } + }) + + t.Run("merge into single batch", func(t *testing.T) { + msg1 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{1}, 100)} + msg2 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{2}, 200)} + msg3 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{3}, 300)} + msg4 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{4}, 400)} + + batcher := newPublishMessageBatcher(&DefaultPublishSettings, partition, receiver.onNewBatch) + batcher.AddBatch(makePublishBatch(makeMsgHolder(msg1))) + batcher.AddBatch(makePublishBatch(makeMsgHolder(msg2), makeMsgHolder(msg3))) + batcher.AddBatch(makePublishBatch(makeMsgHolder(msg4))) + + got := batcher.InFlightBatches() + want := []*publishBatch{ + makePublishBatch(makeMsgHolder(msg1), makeMsgHolder(msg2), makeMsgHolder(msg3), makeMsgHolder(msg4)), + } + if diff := testutil.Diff(got, want, cmp.AllowUnexported(publishBatch{}, messageHolder{})); diff != "" { + t.Errorf("Batches got: -, want: +\n%s", diff) + } + }) + + t.Run("no rebatching", func(t *testing.T) { + msg1 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{1}, MaxPublishRequestBytes-10)} + msg2 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{2}, MaxPublishRequestBytes/2)} + msg3 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{3}, MaxPublishRequestBytes/2)} + + batcher := newPublishMessageBatcher(&DefaultPublishSettings, partition, receiver.onNewBatch) + batcher.AddBatch(makePublishBatch(makeMsgHolder(msg1))) + batcher.AddBatch(makePublishBatch(makeMsgHolder(msg2))) + batcher.AddBatch(makePublishBatch(makeMsgHolder(msg3))) + + got := batcher.InFlightBatches() + want := []*publishBatch{ + makePublishBatch(makeMsgHolder(msg1)), + makePublishBatch(makeMsgHolder(msg2)), + makePublishBatch(makeMsgHolder(msg3)), + } + if diff := testutil.Diff(got, want, cmp.AllowUnexported(publishBatch{}, messageHolder{})); diff != "" { + t.Errorf("Batches got: -, want: +\n%s", diff) + } + }) + + t.Run("mixed rebatching", func(t *testing.T) { + // Should be merged into a single batch. + msg1 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{1}, MaxPublishRequestBytes/2)} + msg2 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{2}, 200)} + msg3 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{3}, 300)} + // Not merged due to byte limit. + msg4 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{4}, MaxPublishRequestBytes-500)} + msg5 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{5}, 500)} + + batcher := newPublishMessageBatcher(&DefaultPublishSettings, partition, receiver.onNewBatch) + batcher.AddBatch(makePublishBatch(makeMsgHolder(msg1))) + batcher.AddBatch(makePublishBatch(makeMsgHolder(msg2))) + batcher.AddBatch(makePublishBatch(makeMsgHolder(msg3))) + batcher.AddBatch(makePublishBatch(makeMsgHolder(msg4))) + batcher.AddBatch(makePublishBatch(makeMsgHolder(msg5))) + + got := batcher.InFlightBatches() + want := []*publishBatch{ + makePublishBatch(makeMsgHolder(msg1), makeMsgHolder(msg2), makeMsgHolder(msg3)), + makePublishBatch(makeMsgHolder(msg4)), + makePublishBatch(makeMsgHolder(msg5)), + } + if diff := testutil.Diff(got, want, cmp.AllowUnexported(publishBatch{}, messageHolder{})); diff != "" { + t.Errorf("Batches got: -, want: +\n%s", diff) + } + }) + + t.Run("max count", func(t *testing.T) { + var msgs []*pb.PubSubMessage + var batch1 []*messageHolder + var batch2 []*messageHolder + batcher := newPublishMessageBatcher(&DefaultPublishSettings, partition, receiver.onNewBatch) + for i := 0; i <= MaxPublishRequestCount; i++ { + msg := &pb.PubSubMessage{Data: []byte{'0'}} + msgs = append(msgs, msg) + + msgHolder := makeMsgHolder(msg) + if i < MaxPublishRequestCount { + batch1 = append(batch1, msgHolder) + } else { + batch2 = append(batch2, msgHolder) + } + batcher.AddBatch(makePublishBatch(msgHolder)) + } + + got := batcher.InFlightBatches() + want := []*publishBatch{ + makePublishBatch(batch1...), + makePublishBatch(batch2...), + } + if diff := testutil.Diff(got, want, cmp.AllowUnexported(publishBatch{}, messageHolder{})); diff != "" { + t.Errorf("Batches got: -, want: +\n%s", diff) + } + }) +} diff --git a/pubsublite/internal/wire/publisher_test.go b/pubsublite/internal/wire/publisher_test.go index 1f58d3a0dcc..600d0dd8e11 100644 --- a/pubsublite/internal/wire/publisher_test.go +++ b/pubsublite/internal/wire/publisher_test.go @@ -196,8 +196,7 @@ func TestSinglePartitionPublisherResendMessages(t *testing.T) { // The publisher should resend all in-flight batches to the second stream. stream2 := test.NewRPCVerifier(t) stream2.Push(initPubReq(topic), initPubResp(), nil) - stream2.Push(msgPubReq(msg1), msgPubResp(0), nil) - stream2.Push(msgPubReq(msg2), msgPubResp(1), nil) + stream2.Push(msgPubReq(msg1, msg2), msgPubResp(0), nil) stream2.Push(msgPubReq(msg3), msgPubResp(2), nil) verifiers.AddPublishStream(topic.Path, topic.Partition, stream2)