diff --git a/pubsublite/internal/test/msg_tracker.go b/pubsublite/internal/test/msg_tracker.go new file mode 100644 index 000000000000..f89adeb67f2b --- /dev/null +++ b/pubsublite/internal/test/msg_tracker.go @@ -0,0 +1,90 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and + +package test + +import ( + "fmt" + "sync" + "time" +) + +// MsgTracker is a helper for checking whether a set of messages make a full +// round trip from publisher to subscriber. +// +// Add() registers published messages. Remove() should be called when messages +// are received by subscribers. Call Wait() to block until all tracked messages +// are received. The same MsgTracker instance can be reused to repeat this +// sequence for multiple test cycles. +// +// Add() and Remove() calls should not be interleaved. +type MsgTracker struct { + msgMap map[string]bool + done chan struct{} + mu sync.Mutex +} + +// NewMsgTracker creates a new message tracker. +func NewMsgTracker() *MsgTracker { + return &MsgTracker{ + msgMap: make(map[string]bool), + done: make(chan struct{}, 1), + } +} + +// Add a set of tracked messages. +func (mt *MsgTracker) Add(msgs ...string) { + mt.mu.Lock() + defer mt.mu.Unlock() + + for _, msg := range msgs { + mt.msgMap[msg] = true + } +} + +// Remove and return true if `msg` is tracked. Signals the `done` channel once +// all messages have been received. +func (mt *MsgTracker) Remove(msg string) bool { + mt.mu.Lock() + defer mt.mu.Unlock() + + if _, exists := mt.msgMap[msg]; exists { + delete(mt.msgMap, msg) + if len(mt.msgMap) == 0 { + var void struct{} + mt.done <- void + } + return true + } + return false +} + +// Wait up to `timeout` to receive all tracked messages. +func (mt *MsgTracker) Wait(timeout time.Duration) error { + mt.mu.Lock() + totalCount := len(mt.msgMap) + mt.mu.Unlock() + + select { + case <-time.After(timeout): + mt.mu.Lock() + receivedCount := totalCount - len(mt.msgMap) + err := fmt.Errorf("received %d of %d messages", receivedCount, totalCount) + mt.msgMap = make(map[string]bool) + mt.mu.Unlock() + return err + + case <-mt.done: + return nil + } +} diff --git a/pubsublite/internal/test/msg_tracker_test.go b/pubsublite/internal/test/msg_tracker_test.go new file mode 100644 index 000000000000..642bfd3c18d2 --- /dev/null +++ b/pubsublite/internal/test/msg_tracker_test.go @@ -0,0 +1,56 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and + +package test + +import ( + "testing" + "time" +) + +func TestMsgTrackerWaitSuccess(t *testing.T) { + msgs := []string{"a", "b", "c"} + + msgTracker := NewMsgTracker() + msgTracker.Add(msgs...) + + for _, msg := range msgs { + if got, want := msgTracker.Remove(msg), true; got != want { + t.Errorf("MsgTracker.Remove(%q) got %v, want %v", msg, got, want) + } + } + for _, msg := range []string{"d", "e"} { + if got, want := msgTracker.Remove(msg), false; got != want { + t.Errorf("MsgTracker.Remove(%q) got %v, want %v", msg, got, want) + } + } + if gotErr := msgTracker.Wait(time.Millisecond); gotErr != nil { + t.Errorf("MsgTracker.Wait() got err: %v", gotErr) + } +} + +func TestMsgTrackerWaitTimeout(t *testing.T) { + msgs := []string{"a", "b", "c"} + + msgTracker := NewMsgTracker() + msgTracker.Add(msgs...) + + for _, msg := range []string{"a", "c"} { + if got, want := msgTracker.Remove(msg), true; got != want { + t.Errorf("MsgTracker.Remove(%q) got %v, want %v", msg, got, want) + } + } + if gotErr, wantMsg := msgTracker.Wait(time.Millisecond), "received 2 of 3 messages"; ErrorHasMsg(gotErr, wantMsg) { + t.Errorf("MsgTracker.Wait() got err: %v, want msg: %q", gotErr, wantMsg) + } +} diff --git a/pubsublite/internal/test/msg_validators.go b/pubsublite/internal/test/msg_validators.go new file mode 100644 index 000000000000..323e314d2697 --- /dev/null +++ b/pubsublite/internal/test/msg_validators.go @@ -0,0 +1,181 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and + +package test + +import ( + "fmt" + "strconv" + "strings" + "sync" +) + +// OrderingSender generates strings containing a message index to use for +// verifying message ordering. It is used on conjunction with Publishers. +type OrderingSender struct { + TotalMsgCount int64 +} + +// NewOrderingSender creats a new OrderingSender. +func NewOrderingSender() *OrderingSender { + return new(OrderingSender) +} + +// Next generates the next string to publish. +func (os *OrderingSender) Next(prefix string) string { + os.TotalMsgCount++ + return fmt.Sprintf("%s/%d", prefix, os.TotalMsgCount) +} + +// OrderingReceiver consumes a message string generated by OrderingSender and +// verifies that messages in a partition are ordered. It is used in conjunction +// with Subscribers. +type OrderingReceiver struct { + mu sync.Mutex + // Map of key and last received message index. Messages are only guaranteed to + // be received in order within a partition. + received map[string]int64 +} + +// NewOrderingReceiver creates a new OrderingReceiver. +func NewOrderingReceiver() *OrderingReceiver { + return &OrderingReceiver{ + received: make(map[string]int64), + } +} + +func parseMsgIndex(msg string) int64 { + pos := strings.LastIndex(msg, "/") + if pos >= 0 { + if n, err := strconv.ParseInt(msg[pos+1:], 10, 64); err == nil { + return n + } + } + return -1 +} + +// Receive checks the given message data and key and returns an error if +// unordered messages are detected. +// +// Note: a normal scenario resulting in unordered messages is when the Publish +// stream breaks while there are in-flight batches, which are resent upon +// stream reconnect. +func (or *OrderingReceiver) Receive(data, key string) error { + or.mu.Lock() + defer or.mu.Unlock() + + idx := parseMsgIndex(data) + if idx < 0 { + return fmt.Errorf("failed to parse index from message: %q", data) + } + + // Verify non-decreasing ordering. Allow duplicates, which can be verified + // with DuplicateMsgDetector. + lastIdx, exists := or.received[key] + if exists && idx < lastIdx { + return fmt.Errorf("message ordering failed for key %s, expected message idx >= %d, got %d", key, lastIdx, idx) + } + or.received[key] = idx + return nil +} + +var void struct{} + +type msgMetadata struct { + offsets map[int64]struct{} +} + +func newMsgMetadata() *msgMetadata { + return &msgMetadata{ + offsets: make(map[int64]struct{}), + } +} + +func (mm *msgMetadata) ContainsOffset(offset int64) bool { + _, exists := mm.offsets[offset] + return exists +} + +func (mm *msgMetadata) AddOffset(offset int64) { + mm.offsets[offset] = void +} + +// DuplicateMsgDetector can be used to detect duplicate messages, either due to +// duplicate publishes or receives. +type DuplicateMsgDetector struct { + mu sync.Mutex + // Map of Pub/Sub message data and associated metadata. + msgs map[string]*msgMetadata + duplicatePublishCount int64 + duplicateReceiveCount int64 +} + +// NewDuplicateMsgDetector creates a new DuplicateMsgDetector. +func NewDuplicateMsgDetector() *DuplicateMsgDetector { + return &DuplicateMsgDetector{ + msgs: make(map[string]*msgMetadata), + } +} + +// Receive checks the given message data and offset. +func (dm *DuplicateMsgDetector) Receive(data string, offset int64) { + dm.mu.Lock() + defer dm.mu.Unlock() + + if metadata, exists := dm.msgs[data]; exists { + if metadata.ContainsOffset(offset) { + // If the message contains the same offset, it means it was received + // multiple times. This is not expected within a single test run. But it + // is normal when processes are stopped & restarted without committing + // cursors. + dm.duplicateReceiveCount++ + } else { + // If the message contains a different offset, it means a message was + // republished, which can occur when a publish stream reconnects with + // in-flight published messages. + dm.duplicatePublishCount++ + metadata.AddOffset(offset) + } + } else { + metadata = newMsgMetadata() + metadata.AddOffset(offset) + dm.msgs[data] = metadata + } +} + +// Status returns a non-empty status string if there were duplicates detected. +func (dm *DuplicateMsgDetector) Status() string { + dm.mu.Lock() + defer dm.mu.Unlock() + + if (dm.duplicateReceiveCount + dm.duplicatePublishCount) == 0 { + return "" + } + return fmt.Sprintf("duplicate publish count = %d, receive count = %d", dm.duplicatePublishCount, dm.duplicateReceiveCount) +} + +// HasPublishDuplicates returns true if duplicate published messages were +// detected. +func (dm *DuplicateMsgDetector) HasPublishDuplicates() bool { + dm.mu.Lock() + defer dm.mu.Unlock() + return dm.duplicatePublishCount > 0 +} + +// HasReceiveDuplicates returns true if duplicate received messages were +// detected. +func (dm *DuplicateMsgDetector) HasReceiveDuplicates() bool { + dm.mu.Lock() + defer dm.mu.Unlock() + return dm.duplicateReceiveCount > 0 +} diff --git a/pubsublite/internal/test/msg_validators_test.go b/pubsublite/internal/test/msg_validators_test.go new file mode 100644 index 000000000000..5def0c620cbf --- /dev/null +++ b/pubsublite/internal/test/msg_validators_test.go @@ -0,0 +1,140 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and + +package test + +import ( + "testing" +) + +func TestOrderingSender(t *testing.T) { + sender := NewOrderingSender() + if got, want := sender.Next("prefix"), "prefix/1"; got != want { + t.Errorf("OrderingSender.Next() got %q, want %q", got, want) + } + if got, want := sender.Next("prefix"), "prefix/2"; got != want { + t.Errorf("OrderingSender.Next() got %q, want %q", got, want) + } + if got, want := sender.Next("foo"), "foo/3"; got != want { + t.Errorf("OrderingSender.Next() got %q, want %q", got, want) + } +} + +func TestOrderingReceiver(t *testing.T) { + receiver := NewOrderingReceiver() + + t.Run("Invalid message", func(t *testing.T) { + if gotErr, wantMsg := receiver.Receive("invalid", "ignored"), "failed to parse index"; !ErrorHasMsg(gotErr, wantMsg) { + t.Errorf("OrderingReceiver.Receive() got err: %v, want msg: %q", gotErr, wantMsg) + } + }) + + t.Run("Key=foo", func(t *testing.T) { + if gotErr := receiver.Receive("foo/1", "foo"); gotErr != nil { + t.Errorf("OrderingReceiver.Receive() got err: %v", gotErr) + } + if gotErr := receiver.Receive("foo/3", "foo"); gotErr != nil { + t.Errorf("OrderingReceiver.Receive() got err: %v", gotErr) + } + if gotErr := receiver.Receive("foo/3", "foo"); gotErr != nil { + t.Errorf("OrderingReceiver.Receive() got err: %v", gotErr) + } + if gotErr := receiver.Receive("foo/4", "foo"); gotErr != nil { + t.Errorf("OrderingReceiver.Receive() got err: %v", gotErr) + } + if gotErr, wantMsg := receiver.Receive("foo/3", "foo"), "expected message idx >= 4, got 3"; !ErrorHasMsg(gotErr, wantMsg) { + t.Errorf("OrderingReceiver.Receive() got err: %v, want msg: %q", gotErr, wantMsg) + } + }) + + t.Run("Key=bar", func(t *testing.T) { + if gotErr := receiver.Receive("bar/30", "bar"); gotErr != nil { + t.Errorf("OrderingReceiver.Receive() got err: %v", gotErr) + } + if gotErr := receiver.Receive("bar/30", "bar"); gotErr != nil { + t.Errorf("OrderingReceiver.Receive() got err: %v", gotErr) + } + if gotErr, wantMsg := receiver.Receive("bar/29", "bar"), "expected message idx >= 30, got 29"; !ErrorHasMsg(gotErr, wantMsg) { + t.Errorf("OrderingReceiver.Receive() got err: %v, want msg: %q", gotErr, wantMsg) + } + }) +} + +func TestDuplicateMsgDetector(t *testing.T) { + t.Run("No duplicates", func(t *testing.T) { + duplicateDetector := NewDuplicateMsgDetector() + duplicateDetector.Receive("foo", 10) + duplicateDetector.Receive("bar", 20) + + if got, want := duplicateDetector.duplicatePublishCount, int64(0); got != want { + t.Errorf("DuplicateMsgDetector.duplicatePublishCount() got %v, want %v", got, want) + } + if got, want := duplicateDetector.duplicateReceiveCount, int64(0); got != want { + t.Errorf("DuplicateMsgDetector.duplicateReceiveCount got %v, want %v", got, want) + } + if got, want := duplicateDetector.HasPublishDuplicates(), false; got != want { + t.Errorf("DuplicateMsgDetector.HasPublishDuplicates() got %v, want %v", got, want) + } + if got, want := duplicateDetector.HasReceiveDuplicates(), false; got != want { + t.Errorf("DuplicateMsgDetector.HasReceiveDuplicates() got %v, want %v", got, want) + } + if got, want := duplicateDetector.Status(), ""; got != want { + t.Errorf("DuplicateMsgDetector.Status() got %q, want %q", got, want) + } + }) + + t.Run("Duplicate publish", func(t *testing.T) { + duplicateDetector := NewDuplicateMsgDetector() + duplicateDetector.Receive("foo", 10) + duplicateDetector.Receive("foo", 11) + duplicateDetector.Receive("foo", 12) + + if got, want := duplicateDetector.duplicatePublishCount, int64(2); got != want { + t.Errorf("DuplicateMsgDetector.duplicatePublishCount() got %v, want %v", got, want) + } + if got, want := duplicateDetector.duplicateReceiveCount, int64(0); got != want { + t.Errorf("DuplicateMsgDetector.duplicateReceiveCount got %v, want %v", got, want) + } + if got, want := duplicateDetector.HasPublishDuplicates(), true; got != want { + t.Errorf("DuplicateMsgDetector.HasPublishDuplicates() got %v, want %v", got, want) + } + if got, want := duplicateDetector.HasReceiveDuplicates(), false; got != want { + t.Errorf("DuplicateMsgDetector.HasReceiveDuplicates() got %v, want %v", got, want) + } + if got := duplicateDetector.Status(); got == "" { + t.Errorf("DuplicateMsgDetector.Status() got %q, want status string", got) + } + }) + + t.Run("Duplicate receive", func(t *testing.T) { + duplicateDetector := NewDuplicateMsgDetector() + duplicateDetector.Receive("bar", 20) + duplicateDetector.Receive("bar", 20) + + if got, want := duplicateDetector.duplicatePublishCount, int64(0); got != want { + t.Errorf("DuplicateMsgDetector.duplicatePublishCount() got %v, want %v", got, want) + } + if got, want := duplicateDetector.duplicateReceiveCount, int64(1); got != want { + t.Errorf("DuplicateMsgDetector.duplicateReceiveCount got %v, want %v", got, want) + } + if got, want := duplicateDetector.HasPublishDuplicates(), false; got != want { + t.Errorf("DuplicateMsgDetector.HasPublishDuplicates() got %v, want %v", got, want) + } + if got, want := duplicateDetector.HasReceiveDuplicates(), true; got != want { + t.Errorf("DuplicateMsgDetector.HasReceiveDuplicates() got %v, want %v", got, want) + } + if got := duplicateDetector.Status(); got == "" { + t.Errorf("DuplicateMsgDetector.Status() got %q, want status string", got) + } + }) +} diff --git a/pubsublite/internal/wire/message_router.go b/pubsublite/internal/wire/message_router.go index 23edebb9af82..58027c9c678d 100644 --- a/pubsublite/internal/wire/message_router.go +++ b/pubsublite/internal/wire/message_router.go @@ -19,14 +19,9 @@ import ( "math/rand" ) -// messageRouter outputs a partition number, given an ordering key. Results are -// undefined when: -// - setPartitionCount() is called with count <= 0. -// - route() is called before setPartitionCount() to initialize the router. -// -// Message routers need to accommodate topic partition resizing. +// messageRouter outputs a partition number, given an ordering key (which it may +// ignore). type messageRouter interface { - SetPartitionCount(count int) Route(orderingKey []byte) int } @@ -38,12 +33,15 @@ type roundRobinMsgRouter struct { nextPartition int } -func (r *roundRobinMsgRouter) SetPartitionCount(count int) { - r.partitionCount = count - r.nextPartition = int(r.rng.Int63n(int64(count))) +func newRoundRobinMsgRouter(rng *rand.Rand, count int) *roundRobinMsgRouter { + return &roundRobinMsgRouter{ + rng: rng, + partitionCount: count, + nextPartition: int(rng.Int63n(int64(count))), + } } -func (r *roundRobinMsgRouter) Route(orderingKey []byte) (partition int) { +func (r *roundRobinMsgRouter) Route(_ []byte) (partition int) { partition = r.nextPartition r.nextPartition = (partition + 1) % r.partitionCount return @@ -58,8 +56,10 @@ type hashingMsgRouter struct { partitionCount *big.Int } -func (r *hashingMsgRouter) SetPartitionCount(count int) { - r.partitionCount = big.NewInt(int64(count)) +func newHashingMsgRouter(count int) *hashingMsgRouter { + return &hashingMsgRouter{ + partitionCount: big.NewInt(int64(count)), + } } func (r *hashingMsgRouter) Route(orderingKey []byte) int { @@ -79,11 +79,6 @@ type compositeMsgRouter struct { keylessRouter messageRouter } -func (r *compositeMsgRouter) SetPartitionCount(count int) { - r.keyedRouter.SetPartitionCount(count) - r.keylessRouter.SetPartitionCount(count) -} - func (r *compositeMsgRouter) Route(orderingKey []byte) int { if len(orderingKey) > 0 { return r.keyedRouter.Route(orderingKey) @@ -91,11 +86,19 @@ func (r *compositeMsgRouter) Route(orderingKey []byte) int { return r.keylessRouter.Route(orderingKey) } -// defaultMessageRouter returns a compositeMsgRouter that uses hashingMsgRouter -// for messages with ordering key and roundRobinMsgRouter for messages without. -func newDefaultMessageRouter(rng *rand.Rand) messageRouter { +type messageRouterFactory struct { + rng *rand.Rand +} + +func newMessageRouterFactory(rng *rand.Rand) *messageRouterFactory { + return &messageRouterFactory{rng: rng} +} + +// New returns a compositeMsgRouter that uses hashingMsgRouter for messages with +// ordering key and roundRobinMsgRouter for messages without. +func (f *messageRouterFactory) New(partitionCount int) messageRouter { return &compositeMsgRouter{ - keyedRouter: &hashingMsgRouter{}, - keylessRouter: &roundRobinMsgRouter{rng: rng}, + keyedRouter: newHashingMsgRouter(partitionCount), + keylessRouter: newRoundRobinMsgRouter(f.rng, partitionCount), } } diff --git a/pubsublite/internal/wire/message_router_test.go b/pubsublite/internal/wire/message_router_test.go index 0f847074afae..79d2159bec5c 100644 --- a/pubsublite/internal/wire/message_router_test.go +++ b/pubsublite/internal/wire/message_router_test.go @@ -21,25 +21,7 @@ import ( "cloud.google.com/go/pubsublite/internal/test" ) -type fakeMsgRouter struct { - multiplier int - partitionCount int -} - -func (f *fakeMsgRouter) SetPartitionCount(count int) { - f.partitionCount = count -} - -func (f *fakeMsgRouter) Route(orderingKey []byte) int { - return f.partitionCount * f.multiplier -} - func TestRoundRobinMsgRouter(t *testing.T) { - // Using the same msgRouter for each test run ensures that it reinitializes - // when the partition count changes. - source := &test.FakeSource{} - msgRouter := &roundRobinMsgRouter{rng: rand.New(source)} - for _, tc := range []struct { partitionCount int source int64 @@ -57,8 +39,9 @@ func TestRoundRobinMsgRouter(t *testing.T) { }, } { t.Run(fmt.Sprintf("partitionCount=%d", tc.partitionCount), func(t *testing.T) { - source.Ret = tc.source - msgRouter.SetPartitionCount(tc.partitionCount) + source := &test.FakeSource{Ret: tc.source} + msgRouter := newRoundRobinMsgRouter(rand.New(source), tc.partitionCount) + for i, want := range tc.want { got := msgRouter.Route([]byte("IGNORED")) if got != want { @@ -70,10 +53,6 @@ func TestRoundRobinMsgRouter(t *testing.T) { } func TestHashingMsgRouter(t *testing.T) { - // Using the same msgRouter for each test run ensures that it reinitializes - // when the partition count changes. - msgRouter := &hashingMsgRouter{} - keys := [][]byte{ []byte("foo1"), []byte("foo2"), @@ -89,7 +68,7 @@ func TestHashingMsgRouter(t *testing.T) { {partitionCount: 5}, } { t.Run(fmt.Sprintf("partitionCount=%d", tc.partitionCount), func(t *testing.T) { - msgRouter.SetPartitionCount(tc.partitionCount) + msgRouter := newHashingMsgRouter(tc.partitionCount) for _, key := range keys { p1 := msgRouter.Route(key) p2 := msgRouter.Route(key) @@ -104,14 +83,16 @@ func TestHashingMsgRouter(t *testing.T) { } } -func TestCompositeMsgRouter(t *testing.T) { - keyedRouter := &fakeMsgRouter{multiplier: 10} - keylessRouter := &fakeMsgRouter{multiplier: 100} - msgRouter := &compositeMsgRouter{ - keyedRouter: keyedRouter, - keylessRouter: keylessRouter, - } +type fakeMsgRouter struct { + multiplier int + partitionCount int +} +func (f *fakeMsgRouter) Route(orderingKey []byte) int { + return f.partitionCount * f.multiplier +} + +func TestCompositeMsgRouter(t *testing.T) { for _, tc := range []struct { desc string partitionCount int @@ -138,7 +119,17 @@ func TestCompositeMsgRouter(t *testing.T) { }, } { t.Run(tc.desc, func(t *testing.T) { - msgRouter.SetPartitionCount(tc.partitionCount) + msgRouter := &compositeMsgRouter{ + keyedRouter: &fakeMsgRouter{ + multiplier: 10, + partitionCount: tc.partitionCount, + }, + keylessRouter: &fakeMsgRouter{ + multiplier: 100, + partitionCount: tc.partitionCount, + }, + } + if got := msgRouter.Route(tc.key); got != tc.want { t.Errorf("Route() = %d, want = %d", got, tc.want) }