diff --git a/pubsublite/internal/wire/assigner.go b/pubsublite/internal/wire/assigner.go index ea4be1b7e43..8f0ad4dc9cf 100644 --- a/pubsublite/internal/wire/assigner.go +++ b/pubsublite/internal/wire/assigner.go @@ -18,6 +18,7 @@ import ( "errors" "fmt" "reflect" + "sort" "github.com/google/uuid" "google.golang.org/grpc" @@ -26,26 +27,45 @@ import ( pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1" ) -// partitionSet is a set of partition numbers. -type partitionSet map[int]struct{} +// PartitionSet is a set of partition numbers. +type PartitionSet map[int]struct{} -func newPartitionSet(assignmentpb *pb.PartitionAssignment) partitionSet { +// NewPartitionSet creates a partition set initialized from the given partition +// numbers. +func NewPartitionSet(partitions []int) PartitionSet { var void struct{} - partitions := make(map[int]struct{}) + partitionSet := make(map[int]struct{}) + for _, p := range partitions { + partitionSet[p] = void + } + return partitionSet +} + +func newPartitionSet(assignmentpb *pb.PartitionAssignment) PartitionSet { + var partitions []int for _, p := range assignmentpb.GetPartitions() { - partitions[int(p)] = void + partitions = append(partitions, int(p)) } - return partitionSet(partitions) + return NewPartitionSet(partitions) } -func (ps partitionSet) Ints() (partitions []int) { +// Ints returns the partitions contained in this set as an unsorted slice. +func (ps PartitionSet) Ints() (partitions []int) { for p := range ps { partitions = append(partitions, p) } return } -func (ps partitionSet) Contains(partition int) bool { +// SortedInts returns the partitions contained in this set as a sorted slice. +func (ps PartitionSet) SortedInts() (partitions []int) { + partitions = ps.Ints() + sort.Ints(partitions) + return +} + +// Contains returns true if this set contains the specified partition. +func (ps PartitionSet) Contains(partition int) bool { _, exists := ps[partition] return exists } @@ -54,9 +74,8 @@ func (ps partitionSet) Contains(partition int) bool { type generateUUIDFunc func() (uuid.UUID, error) // partitionAssignmentReceiver must enact the received partition assignment from -// the server, or otherwise return an error, which will break the stream. The -// receiver must not call the assigner, as this would result in a deadlock. -type partitionAssignmentReceiver func(partitionSet) error +// the server, or otherwise return an error, which will break the stream. +type partitionAssignmentReceiver func(PartitionSet) error // assigner wraps the partition assignment stream and notifies a receiver when // the server sends a new set of partition assignments for a subscriber. diff --git a/pubsublite/internal/wire/assigner_test.go b/pubsublite/internal/wire/assigner_test.go index ad3ae010c82..8d98c8ac6e1 100644 --- a/pubsublite/internal/wire/assigner_test.go +++ b/pubsublite/internal/wire/assigner_test.go @@ -16,7 +16,6 @@ package wire import ( "context" "errors" - "sort" "testing" "time" @@ -46,9 +45,7 @@ func TestPartitionSet(t *testing.T) { } } - gotPartitions := partitions.Ints() - sort.Ints(gotPartitions) - if !testutil.Equal(gotPartitions, wantPartitions) { + if gotPartitions := partitions.SortedInts(); !testutil.Equal(gotPartitions, wantPartitions) { t.Errorf("Ints() got %v, want %v", gotPartitions, wantPartitions) } } @@ -91,9 +88,8 @@ func newTestAssigner(t *testing.T, subscription string, recvErr error) *testAssi return ta } -func (ta *testAssigner) receiveAssignment(partitions partitionSet) error { - p := partitions.Ints() - sort.Ints(p) +func (ta *testAssigner) receiveAssignment(partitions PartitionSet) error { + p := partitions.SortedInts() ta.partitions <- p if ta.recvError != nil { diff --git a/pubsublite/internal/wire/subscriber.go b/pubsublite/internal/wire/subscriber.go index b6860a6dc21..d8b84b433d8 100644 --- a/pubsublite/internal/wire/subscriber.go +++ b/pubsublite/internal/wire/subscriber.go @@ -415,13 +415,14 @@ func (f *singlePartitionSubscriberFactory) New(partition int) *singlePartitionSu // partitions. type multiPartitionSubscriber struct { // Immutable after creation. - subscribers []*singlePartitionSubscriber + subscribers map[int]*singlePartitionSubscriber apiClientService } func newMultiPartitionSubscriber(allClients apiClients, subFactory *singlePartitionSubscriberFactory) *multiPartitionSubscriber { ms := &multiPartitionSubscriber{ + subscribers: make(map[int]*singlePartitionSubscriber), apiClientService: apiClientService{clients: allClients}, } ms.init() @@ -429,7 +430,7 @@ func newMultiPartitionSubscriber(allClients apiClients, subFactory *singlePartit for _, partition := range subFactory.settings.Partitions { subscriber := subFactory.New(partition) ms.unsafeAddServices(subscriber) - ms.subscribers = append(ms.subscribers, subscriber) + ms.subscribers[partition] = subscriber } return ms } @@ -445,13 +446,23 @@ func (ms *multiPartitionSubscriber) Terminate() { } } +// PartitionActive returns whether the partition is active. +func (ms *multiPartitionSubscriber) PartitionActive(partition int) bool { + _, exists := ms.subscribers[partition] + return exists +} + +// ReassignmentHandlerFunc receives a partition assignment change. +type ReassignmentHandlerFunc func(before, after PartitionSet) error + // assigningSubscriber uses the Pub/Sub Lite partition assignment service to // listen to its assigned partition numbers and dynamically add/remove // singlePartitionSubscribers. type assigningSubscriber struct { // Immutable after creation. - subFactory *singlePartitionSubscriberFactory - assigner *assigner + reassignmentHandler ReassignmentHandlerFunc + subFactory *singlePartitionSubscriberFactory + assigner *assigner // Fields below must be guarded with mu. // Subscribers keyed by partition number. Updated as assignments change. @@ -460,11 +471,13 @@ type assigningSubscriber struct { apiClientService } -func newAssigningSubscriber(allClients apiClients, assignmentClient *vkit.PartitionAssignmentClient, genUUID generateUUIDFunc, subFactory *singlePartitionSubscriberFactory) (*assigningSubscriber, error) { +func newAssigningSubscriber(allClients apiClients, assignmentClient *vkit.PartitionAssignmentClient, reassignmentHandler ReassignmentHandlerFunc, + genUUID generateUUIDFunc, subFactory *singlePartitionSubscriberFactory) (*assigningSubscriber, error) { as := &assigningSubscriber{ - apiClientService: apiClientService{clients: allClients}, - subFactory: subFactory, - subscribers: make(map[int]*singlePartitionSubscriber), + apiClientService: apiClientService{clients: allClients}, + reassignmentHandler: reassignmentHandler, + subFactory: subFactory, + subscribers: make(map[int]*singlePartitionSubscriber), } as.init() @@ -477,12 +490,17 @@ func newAssigningSubscriber(allClients apiClients, assignmentClient *vkit.Partit return as, nil } -func (as *assigningSubscriber) handleAssignment(partitions partitionSet) error { - removedSubscribers, err := as.doHandleAssignment(partitions) +func (as *assigningSubscriber) handleAssignment(nextPartitions PartitionSet) error { + previousPartitions, removedSubscribers, err := as.doHandleAssignment(nextPartitions) if err != nil { return err } + // Notify the user reassignment handler. + if err := as.reassignmentHandler(previousPartitions, nextPartitions); err != nil { + return err + } + // Wait for removed subscribers to completely stop (which waits for commit // acknowledgments from the server) before acking the assignment. This avoids // commits racing with the new assigned client. @@ -492,17 +510,23 @@ func (as *assigningSubscriber) handleAssignment(partitions partitionSet) error { return nil } -func (as *assigningSubscriber) doHandleAssignment(partitions partitionSet) ([]*singlePartitionSubscriber, error) { +// Returns the previous set of partitions and removed subscribers. +func (as *assigningSubscriber) doHandleAssignment(nextPartitions PartitionSet) (PartitionSet, []*singlePartitionSubscriber, error) { as.mu.Lock() defer as.mu.Unlock() + var previousPartitions []int + for partition := range as.subscribers { + previousPartitions = append(previousPartitions, partition) + } + // Handle new partitions. - for _, partition := range partitions.Ints() { + for _, partition := range nextPartitions.Ints() { if _, exists := as.subscribers[partition]; !exists { subscriber := as.subFactory.New(partition) if err := as.unsafeAddServices(subscriber); err != nil { // Occurs when the assigningSubscriber is stopping/stopped. - return nil, err + return nil, nil, err } as.subscribers[partition] = subscriber } @@ -511,7 +535,7 @@ func (as *assigningSubscriber) doHandleAssignment(partitions partitionSet) ([]*s // Handle removed partitions. var removedSubscribers []*singlePartitionSubscriber for partition, subscriber := range as.subscribers { - if !partitions.Contains(partition) { + if !nextPartitions.Contains(partition) { // Ignore unacked messages from this point on to avoid conflicting with // the commits of the new subscriber that will be assigned this partition. subscriber.Terminate() @@ -523,7 +547,7 @@ func (as *assigningSubscriber) doHandleAssignment(partitions partitionSet) ([]*s delete(as.subscribers, partition) } } - return removedSubscribers, nil + return NewPartitionSet(previousPartitions), removedSubscribers, nil } // Terminate shuts down all singlePartitionSubscribers without waiting for @@ -537,6 +561,15 @@ func (as *assigningSubscriber) Terminate() { } } +// PartitionActive returns whether the partition is still active. +func (as *assigningSubscriber) PartitionActive(partition int) bool { + as.mu.Lock() + defer as.mu.Unlock() + + _, exists := as.subscribers[partition] + return exists +} + // Subscriber is the client interface exported from this package for receiving // messages. type Subscriber interface { @@ -545,10 +578,12 @@ type Subscriber interface { Stop() WaitStopped() error Terminate() + PartitionActive(int) bool } // NewSubscriber creates a new client for receiving messages. -func NewSubscriber(ctx context.Context, settings ReceiveSettings, receiver MessageReceiverFunc, region, subscriptionPath string, opts ...option.ClientOption) (Subscriber, error) { +func NewSubscriber(ctx context.Context, settings ReceiveSettings, receiver MessageReceiverFunc, reassignmentHandler ReassignmentHandlerFunc, + region, subscriptionPath string, opts ...option.ClientOption) (Subscriber, error) { if err := ValidateRegion(region); err != nil { return nil, err } @@ -588,5 +623,5 @@ func NewSubscriber(ctx context.Context, settings ReceiveSettings, receiver Messa return nil, err } allClients = append(allClients, partitionClient) - return newAssigningSubscriber(allClients, partitionClient, uuid.NewRandom, subFactory) + return newAssigningSubscriber(allClients, partitionClient, reassignmentHandler, uuid.NewRandom, subFactory) } diff --git a/pubsublite/internal/wire/subscriber_test.go b/pubsublite/internal/wire/subscriber_test.go index 6aa5fe77933..f7dca0359fb 100644 --- a/pubsublite/internal/wire/subscriber_test.go +++ b/pubsublite/internal/wire/subscriber_test.go @@ -15,6 +15,7 @@ package wire import ( "context" + "errors" "sort" "sync" "testing" @@ -30,17 +31,22 @@ import ( pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1" ) +const ( + maxMessages int = 10 + maxBytes int = 1000 +) + func testSubscriberSettings() ReceiveSettings { settings := testReceiveSettings() - settings.MaxOutstandingMessages = 10 - settings.MaxOutstandingBytes = 1000 + settings.MaxOutstandingMessages = maxMessages + settings.MaxOutstandingBytes = maxBytes return settings } // initFlowControlReq returns the first expected flow control request when // testSubscriberSettings are used. func initFlowControlReq() *pb.SubscribeRequest { - return flowControlSubReq(flowControlTokens{Bytes: 1000, Messages: 10}) + return flowControlSubReq(flowControlTokens{Bytes: int64(maxBytes), Messages: int64(maxMessages)}) } func partitionMsgs(partition int, msgs ...*pb.SequencedMessage) []*ReceivedMessage { @@ -929,6 +935,15 @@ func TestSinglePartitionSubscriberStopDuringAdminSeek(t *testing.T) { } } +func verifyPartitionsActive(t *testing.T, sub Subscriber, want bool, partitions ...int) { + t.Helper() + for _, p := range partitions { + if got := sub.PartitionActive(p); got != want { + t.Errorf("PartitionActive(%d) got %v, want %v", p, got, want) + } + } +} + func newTestMultiPartitionSubscriber(t *testing.T, receiverFunc MessageReceiverFunc, subscriptionPath string, partitions []int) *multiPartitionSubscriber { ctx := context.Background() subClient, err := newSubscriberClient(ctx, "ignored", testServer.ClientConn()) @@ -994,6 +1009,9 @@ func TestMultiPartitionSubscriberMultipleMessages(t *testing.T) { defer mockServer.OnTestEnd() sub := newTestMultiPartitionSubscriber(t, receiver.onMessage, subscription, []int{1, 2}) + verifyPartitionsActive(t, sub, true, 1, 2) + verifyPartitionsActive(t, sub, false, 0, 3) + if gotErr := sub.WaitStarted(); gotErr != nil { t.Errorf("Start() got err: (%v)", gotErr) } @@ -1056,18 +1074,6 @@ func TestMultiPartitionSubscriberPermanentError(t *testing.T) { receiver.VerifyNoMsgs() } -func (as *assigningSubscriber) Partitions() []int { - as.mu.Lock() - defer as.mu.Unlock() - - var partitions []int - for p := range as.subscribers { - partitions = append(partitions, p) - } - sort.Ints(partitions) - return partitions -} - func (as *assigningSubscriber) Subscribers() []*singlePartitionSubscriber { as.mu.Lock() defer as.mu.Unlock() @@ -1088,7 +1094,11 @@ func (as *assigningSubscriber) FlushCommits() { } } -func newTestAssigningSubscriber(t *testing.T, receiverFunc MessageReceiverFunc, subscriptionPath string) *assigningSubscriber { +func noopReassignmentHandler(_, _ PartitionSet) error { + return nil +} + +func newTestAssigningSubscriber(t *testing.T, receiverFunc MessageReceiverFunc, reassignmentHandler ReassignmentHandlerFunc, subscriptionPath string) *assigningSubscriber { ctx := context.Background() subClient, err := newSubscriberClient(ctx, "ignored", testServer.ClientConn()) if err != nil { @@ -1113,7 +1123,7 @@ func newTestAssigningSubscriber(t *testing.T, receiverFunc MessageReceiverFunc, receiver: receiverFunc, disableTasks: true, // Background tasks disabled to control event order } - sub, err := newAssigningSubscriber(allClients, assignmentClient, fakeGenerateUUID, f) + sub, err := newAssigningSubscriber(allClients, assignmentClient, reassignmentHandler, fakeGenerateUUID, f) if err != nil { t.Fatal(err) } @@ -1179,23 +1189,21 @@ func TestAssigningSubscriberAddRemovePartitions(t *testing.T) { mockServer.OnTestStart(verifiers) defer mockServer.OnTestEnd() - sub := newTestAssigningSubscriber(t, receiver.onMessage, subscription) + sub := newTestAssigningSubscriber(t, receiver.onMessage, noopReassignmentHandler, subscription) if gotErr := sub.WaitStarted(); gotErr != nil { t.Errorf("Start() got err: (%v)", gotErr) } // Partition assignments are initially {3, 6}. receiver.ValidateMsgs(join(partitionMsgs(3, msg1), partitionMsgs(6, msg3))) - if got, want := sub.Partitions(), []int{3, 6}; !testutil.Equal(got, want) { - t.Errorf("subscriber partitions: got %d, want %d", got, want) - } + verifyPartitionsActive(t, sub, true, 3, 6) + verifyPartitionsActive(t, sub, false, 1, 8) // Partition assignments will now be {3, 8}. assignmentBarrier1.Release() receiver.ValidateMsgs(partitionMsgs(8, msg5)) - if got, want := sub.Partitions(), []int{3, 8}; !testutil.Equal(got, want) { - t.Errorf("subscriber partitions: got %d, want %d", got, want) - } + verifyPartitionsActive(t, sub, true, 3, 8) + verifyPartitionsActive(t, sub, false, 2, 6) // msg2 is from partition 3 and should be received. msg4 is from partition 6 // (removed) and should be discarded. @@ -1255,7 +1263,7 @@ func TestAssigningSubscriberPermanentError(t *testing.T) { mockServer.OnTestStart(verifiers) defer mockServer.OnTestEnd() - sub := newTestAssigningSubscriber(t, receiver.onMessage, subscription) + sub := newTestAssigningSubscriber(t, receiver.onMessage, noopReassignmentHandler, subscription) if gotErr := sub.WaitStarted(); gotErr != nil { t.Errorf("Start() got err: (%v)", gotErr) } @@ -1298,7 +1306,7 @@ func TestAssigningSubscriberIgnoreOutstandingAcks(t *testing.T) { mockServer.OnTestStart(verifiers) defer mockServer.OnTestEnd() - sub := newTestAssigningSubscriber(t, receiver.onMessage, subscription) + sub := newTestAssigningSubscriber(t, receiver.onMessage, noopReassignmentHandler, subscription) if gotErr := sub.WaitStarted(); gotErr != nil { t.Errorf("Start() got err: (%v)", gotErr) } @@ -1330,6 +1338,116 @@ func TestAssigningSubscriberIgnoreOutstandingAcks(t *testing.T) { } } +func TestAssigningSubscriberStoppedWhileReassignmentHandlerActive(t *testing.T) { + const subscription = "projects/123456/locations/us-central1-b/subscriptions/my-sub" + receiver := newTestMessageReceiver(t) + + verifiers := test.NewVerifiers(t) + + // Assignment stream + asnStream := test.NewRPCVerifier(t) + asnStream.Push(initAssignmentReq(subscription, fakeUUID[:]), assignmentResp([]int64{1}), nil) + verifiers.AddAssignmentStream(subscription, asnStream) + + // Partition 1 + subStream := test.NewRPCVerifier(t) + subStream.Push(initSubReqCommit(subscriptionPartition{Path: subscription, Partition: 1}), initSubResp(), nil) + subBarrier := subStream.PushWithBarrier(initFlowControlReq(), nil, nil) + verifiers.AddSubscribeStream(subscription, 1, subStream) + + cmtStream := test.NewRPCVerifier(t) + cmtBarrier := cmtStream.PushWithBarrier(initCommitReq(subscriptionPartition{Path: subscription, Partition: 1}), initCommitResp(), nil) + verifiers.AddCommitStream(subscription, 1, cmtStream) + + mockServer.OnTestStart(verifiers) + defer mockServer.OnTestEnd() + + reassignmentHandlerCalled := test.NewCondition("reassignment handler called") + returnReassignmentHandler := test.NewCondition("return reassignment handler") + onReassignment := func(before, after PartitionSet) error { + if got, want := len(before.SortedInts()), 0; got != want { + t.Errorf("len(before): got %v, want %v", got, want) + } + if got, want := after.SortedInts(), []int{1}; !testutil.Equal(got, want) { + t.Errorf("after: got %v, want %v", got, want) + } + reassignmentHandlerCalled.SetDone() + returnReassignmentHandler.WaitUntilDone(t, serviceTestWaitTimeout) + return nil + } + + sub := newTestAssigningSubscriber(t, receiver.onMessage, onReassignment, subscription) + if gotErr := sub.WaitStarted(); gotErr != nil { + t.Errorf("Start() got err: (%v)", gotErr) + } + + // Used to control order of execution to ensure the test is deterministic. + subBarrier.Release() + cmtBarrier.Release() + + // Ensure there are no deadlocks if the reassignment handler blocks and the + // subscriber is stopped. + reassignmentHandlerCalled.WaitUntilDone(t, serviceTestWaitTimeout) + sub.Stop() + returnReassignmentHandler.SetDone() + + if gotErr := sub.WaitStopped(); gotErr != nil { + t.Errorf("WaitStopped() got err: (%v)", gotErr) + } +} + +func TestAssigningSubscriberReassignmentHandlerReturnsError(t *testing.T) { + const subscription = "projects/123456/locations/us-central1-b/subscriptions/my-sub" + receiver := newTestMessageReceiver(t) + + verifiers := test.NewVerifiers(t) + + // Assignment stream + asnStream := test.NewRPCVerifier(t) + asnStream.Push(initAssignmentReq(subscription, fakeUUID[:]), assignmentResp([]int64{1}), nil) + verifiers.AddAssignmentStream(subscription, asnStream) + + // Partition 1 + subStream := test.NewRPCVerifier(t) + subStream.Push(initSubReqCommit(subscriptionPartition{Path: subscription, Partition: 1}), initSubResp(), nil) + subBarrier := subStream.PushWithBarrier(initFlowControlReq(), nil, nil) + verifiers.AddSubscribeStream(subscription, 1, subStream) + + cmtStream := test.NewRPCVerifier(t) + cmtBarrier := cmtStream.PushWithBarrier(initCommitReq(subscriptionPartition{Path: subscription, Partition: 1}), initCommitResp(), nil) + verifiers.AddCommitStream(subscription, 1, cmtStream) + + mockServer.OnTestStart(verifiers) + defer mockServer.OnTestEnd() + + reassignmentErr := errors.New("reassignment handler error") + returnReassignmentErr := test.NewCondition("return reassignment error") + onAssignment := func(before, after PartitionSet) error { + if got, want := len(before.SortedInts()), 0; got != want { + t.Errorf("len(before): got %v, want %v", got, want) + } + if got, want := after.SortedInts(), []int{1}; !testutil.Equal(got, want) { + t.Errorf("after: got %v, want %v", got, want) + } + returnReassignmentErr.WaitUntilDone(t, serviceTestWaitTimeout) + return reassignmentErr + } + + sub := newTestAssigningSubscriber(t, receiver.onMessage, onAssignment, subscription) + if gotErr := sub.WaitStarted(); gotErr != nil { + t.Errorf("Start() got err: (%v)", gotErr) + } + + // Used to control order of execution to ensure the test is deterministic. + subBarrier.Release() + cmtBarrier.Release() + returnReassignmentErr.SetDone() + + if gotErr := sub.WaitStopped(); !test.ErrorEqual(gotErr, reassignmentErr) { + t.Errorf("WaitStopped() got err: (%v), want err: (%v)", gotErr, reassignmentErr) + } +} + func TestNewSubscriberValidatesSettings(t *testing.T) { const subscription = "projects/123456/locations/us-central1-b/subscriptions/my-sub" const region = "us-central1" @@ -1337,7 +1455,7 @@ func TestNewSubscriberValidatesSettings(t *testing.T) { settings := DefaultReceiveSettings settings.MaxOutstandingMessages = 0 - if _, err := NewSubscriber(context.Background(), settings, receiver.onMessage, region, subscription); err == nil { + if _, err := NewSubscriber(context.Background(), settings, receiver.onMessage, noopReassignmentHandler, region, subscription); err == nil { t.Error("NewSubscriber() did not return error") } } diff --git a/pubsublite/pscompat/integration_test.go b/pubsublite/pscompat/integration_test.go index 59aa779ba6d..86da439fc8f 100644 --- a/pubsublite/pscompat/integration_test.go +++ b/pubsublite/pscompat/integration_test.go @@ -17,6 +17,7 @@ import ( "context" "errors" "fmt" + "sort" "strings" "sync" "sync/atomic" @@ -729,12 +730,29 @@ func TestIntegration_PublishSubscribeMultiPartition(t *testing.T) { } } + // Verify partition reassignment notifications. + var allPartitions []int + var mu sync.Mutex + reassignmentHandler := func(before, after []int) error { + t.Logf("Partition assignments: before %v, after %v", before, after) + if got, want := len(before), 0; got != want { + t.Errorf("Partition assignments len(before): got %d, want %d", got, want) + } + mu.Lock() + allPartitions = append(allPartitions, after...) + mu.Unlock() + return nil + } + + receiveSettings := DefaultReceiveSettings + receiveSettings.ReassignmentHandler = reassignmentHandler + cctx, stopSubscribers := context.WithTimeout(context.Background(), defaultTestTimeout) g, _ := errgroup.WithContext(ctx) for i := 0; i < subscriberCount; i++ { // Subscribers must be started in a goroutine as Receive() blocks. g.Go(func() error { - subscriber := subscriberClient(context.Background(), t, DefaultReceiveSettings, subscriptionPath) + subscriber := subscriberClient(context.Background(), t, receiveSettings, subscriptionPath) err := subscriber.Receive(cctx, messageReceiver) if err != nil { t.Errorf("Receive() got err: %v", err) @@ -748,6 +766,13 @@ func TestIntegration_PublishSubscribeMultiPartition(t *testing.T) { stopSubscribers() // Wait until all subscribers have terminated. g.Wait() + + mu.Lock() + sort.Ints(allPartitions) + if got, want := allPartitions, partitionNumbers(partitionCount); !testutil.Equal(got, want) { + t.Errorf("Assigned partition numbers: got %v, want %v", got, want) + } + mu.Unlock() }) } diff --git a/pubsublite/pscompat/settings.go b/pubsublite/pscompat/settings.go index 59b84c84556..a4b8b51c817 100644 --- a/pubsublite/pscompat/settings.go +++ b/pubsublite/pscompat/settings.go @@ -164,6 +164,30 @@ type NackHandler func(*pubsub.Message) error // error and terminate. type ReceiveMessageTransformerFunc func(*pb.SequencedMessage, *pubsub.Message) error +// ReassignmentHandlerFunc is called any time a new partition assignment is +// received from the server. It will be called with both the previous and new +// partition numbers as decided by the server. Both slices of partition numbers +// are sorted in ascending order. +// +// When this handler is called, partitions that are being assigned away are +// stopping and new partitions are starting. Acks and nacks for messages from +// partitions that are being assigned away will have no effect, but message +// deliveries may still be in flight. +// +// The client library will not acknowledge the assignment until this handler +// returns. The server will not assign any of the partitions in +// `previousPartitions` to another client unless the assignment is acknowledged, +// or a client takes too long to acknowledge (currently 30 seconds from the time +// the assignment is sent from server's point of view). +// +// Because of the above, as long as reassignment handling is processed quickly, +// it can be used to abort outstanding operations on partitions which are being +// assigned away from this client. +// +// If this handler returns an error, the SubscriberClient will consider this a +// fatal error and terminate. +type ReassignmentHandlerFunc func(previousPartitions, nextPartitions []int) error + // ReceiveSettings configure the SubscriberClient. Flow control settings // (MaxOutstandingMessages, MaxOutstandingBytes) apply per partition. // @@ -210,6 +234,10 @@ type ReceiveSettings struct { // Optional custom function that transforms a SequencedMessage API proto to a // pubsub.Message. MessageTransformer ReceiveMessageTransformerFunc + + // Optional custom function that is called when a new partition assignment has + // been delivered to the client. + ReassignmentHandler ReassignmentHandlerFunc } // DefaultReceiveSettings holds the default values for ReceiveSettings. diff --git a/pubsublite/pscompat/subscriber.go b/pubsublite/pscompat/subscriber.go index 368d735e347..d0d00e0a3f1 100644 --- a/pubsublite/pscompat/subscriber.go +++ b/pubsublite/pscompat/subscriber.go @@ -40,6 +40,7 @@ func handleNack(_ *pubsub.Message) error { type pslAckHandler struct { ackh wire.AckConsumer msg *pubsub.Message + partition int nackh NackHandler subInstance *subscriberInstance } @@ -58,6 +59,11 @@ func (ah *pslAckHandler) OnNack() { return } + // Ignore nacks for partitions that have been assigned away. + if !ah.subInstance.wireSub.PartitionActive(ah.partition) { + return + } + err := ah.nackh(ah.msg) if err != nil { // If the NackHandler returns an error, shut down the subscriber client. @@ -72,7 +78,7 @@ func (ah *pslAckHandler) OnNack() { // wireSubscriberFactory is a factory for creating wire subscribers, which can // be overridden with a mock in unit tests. type wireSubscriberFactory interface { - New(context.Context, wire.MessageReceiverFunc) (wire.Subscriber, error) + New(context.Context, wire.MessageReceiverFunc, wire.ReassignmentHandlerFunc) (wire.Subscriber, error) } type wireSubscriberFactoryImpl struct { @@ -82,8 +88,8 @@ type wireSubscriberFactoryImpl struct { options []option.ClientOption } -func (f *wireSubscriberFactoryImpl) New(ctx context.Context, receiver wire.MessageReceiverFunc) (wire.Subscriber, error) { - return wire.NewSubscriber(ctx, f.settings, receiver, f.region, f.subscription.String(), f.options...) +func (f *wireSubscriberFactoryImpl) New(ctx context.Context, receiver wire.MessageReceiverFunc, onReassignment wire.ReassignmentHandlerFunc) (wire.Subscriber, error) { + return wire.NewSubscriber(ctx, f.settings, receiver, onReassignment, f.region, f.subscription.String(), f.options...) } type messageReceiverFunc = func(context.Context, *pubsub.Message) @@ -116,7 +122,7 @@ func newSubscriberInstance(recvCtx, clientCtx context.Context, factory wireSubsc // cancelled, the gRPC streams will be disconnected and the subscriber will // not be able to process acks and commit the final cursor offset. Use the // context from NewSubscriberClient (clientCtx) instead. - wireSub, err := factory.New(clientCtx, subInstance.onMessage) + wireSub, err := factory.New(clientCtx, subInstance.onMessage, subInstance.onReassignment) if err != nil { return nil, err } @@ -131,6 +137,13 @@ func newSubscriberInstance(recvCtx, clientCtx context.Context, factory wireSubsc return subInstance, nil } +func (si *subscriberInstance) onReassignment(before, after wire.PartitionSet) error { + if si.settings.ReassignmentHandler != nil { + return si.settings.ReassignmentHandler(before.SortedInts(), after.SortedInts()) + } + return nil +} + func (si *subscriberInstance) transformMessage(in *wire.ReceivedMessage, out *pubsub.Message) error { if err := si.settings.MessageTransformer(in.Msg, out); err != nil { return err @@ -147,6 +160,7 @@ func (si *subscriberInstance) onMessage(msg *wire.ReceivedMessage) { pslAckh := &pslAckHandler{ ackh: msg.Ack, nackh: si.settings.NackHandler, + partition: msg.Partition, subInstance: si, } psMsg := ipubsub.NewMessage(pslAckh) diff --git a/pubsublite/pscompat/subscriber_test.go b/pubsublite/pscompat/subscriber_test.go index 5c737f8ff42..6bd4ac8c788 100644 --- a/pubsublite/pscompat/subscriber_test.go +++ b/pubsublite/pscompat/subscriber_test.go @@ -30,7 +30,10 @@ import ( pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1" ) -const defaultSubscriberTestTimeout = 10 * time.Second +const ( + defaultSubscriberTestTimeout = 10 * time.Second + activePartition = 1 +) // mockAckConsumer is a mock implementation of the wire.AckConsumer interface. type mockAckConsumer struct { @@ -43,12 +46,14 @@ func (ac *mockAckConsumer) Ack() { // mockWireSubscriber is a mock implementation of the wire.Subscriber interface. type mockWireSubscriber struct { - receiver wire.MessageReceiverFunc - msgsC chan *wire.ReceivedMessage - stopC chan struct{} - err error - Stopped bool - Terminated bool + receiver wire.MessageReceiverFunc + onReassignment wire.ReassignmentHandlerFunc + activePartitions wire.PartitionSet + msgsC chan *wire.ReceivedMessage + stopC chan struct{} + err error + Stopped bool + Terminated bool } // DeliverMessages should be called from the test to simulate a message @@ -59,6 +64,14 @@ func (ms *mockWireSubscriber) DeliverMessages(msgs ...*wire.ReceivedMessage) { } } +// OnReassignment should be called from the test to simulate a partition +// reassignment. +func (ms *mockWireSubscriber) DeliverReassignment(before, after wire.PartitionSet) { + if err := ms.onReassignment(before, after); err != nil { + ms.SimulateFatalError(err) + } +} + // SimulateFatalError should be called from the test to simulate a fatal error // occurring in the wire subscriber. func (ms *mockWireSubscriber) SimulateFatalError(err error) { @@ -111,13 +124,19 @@ func (ms *mockWireSubscriber) WaitStopped() error { return ms.err } +func (ms *mockWireSubscriber) PartitionActive(partition int) bool { + return ms.activePartitions.Contains(partition) +} + type mockWireSubscriberFactory struct{} -func (f *mockWireSubscriberFactory) New(ctx context.Context, receiver wire.MessageReceiverFunc) (wire.Subscriber, error) { +func (f *mockWireSubscriberFactory) New(ctx context.Context, receiver wire.MessageReceiverFunc, onReassignment wire.ReassignmentHandlerFunc) (wire.Subscriber, error) { return &mockWireSubscriber{ - receiver: receiver, - msgsC: make(chan *wire.ReceivedMessage, 10), - stopC: make(chan struct{}), + receiver: receiver, + onReassignment: onReassignment, + activePartitions: wire.NewPartitionSet([]int{activePartition}), + msgsC: make(chan *wire.ReceivedMessage, 10), + stopC: make(chan struct{}), }, nil } @@ -289,6 +308,7 @@ func TestSubscriberInstanceNack(t *testing.T) { desc string // mutateSettings is passed a copy of DefaultReceiveSettings to mutate. mutateSettings func(settings *ReceiveSettings) + msgPartition int wantErr error wantAckCount int wantStopped bool @@ -297,10 +317,19 @@ func TestSubscriberInstanceNack(t *testing.T) { { desc: "default settings", mutateSettings: func(settings *ReceiveSettings) {}, + msgPartition: activePartition, wantErr: errNackCalled, wantAckCount: 0, wantTerminated: true, }, + { + desc: "message partition inactive", + mutateSettings: func(settings *ReceiveSettings) {}, + msgPartition: activePartition + 1, + wantErr: nil, + wantAckCount: 0, + wantStopped: true, + }, { desc: "nack handler returns nil", mutateSettings: func(settings *ReceiveSettings) { @@ -308,6 +337,7 @@ func TestSubscriberInstanceNack(t *testing.T) { return nil } }, + msgPartition: activePartition, wantErr: nil, wantAckCount: 1, wantStopped: true, @@ -319,6 +349,7 @@ func TestSubscriberInstanceNack(t *testing.T) { return nackErr } }, + msgPartition: activePartition, wantErr: nackErr, wantAckCount: 0, wantTerminated: true, @@ -329,7 +360,7 @@ func TestSubscriberInstanceNack(t *testing.T) { tc.mutateSettings(&settings) ack := &mockAckConsumer{} - msg := &wire.ReceivedMessage{Msg: msg, Ack: ack} + msg := &wire.ReceivedMessage{Msg: msg, Ack: ack, Partition: tc.msgPartition} cctx, stopSubscriber := context.WithTimeout(ctx, defaultSubscriberTestTimeout) messageReceiver := func(ctx context.Context, got *pubsub.Message) { @@ -363,6 +394,67 @@ func TestSubscriberInstanceNack(t *testing.T) { } } +func TestSubscriberInstanceReassignmentHandler(t *testing.T) { + reassignmentErr := errors.New("reassignment failure") + before := wire.NewPartitionSet([]int{3, 2, 1}) + after := wire.NewPartitionSet([]int{4, 5, 3}) + ctx := context.Background() + + for _, tc := range []struct { + desc string + // mutateSettings is passed a copy of DefaultReceiveSettings to mutate. + mutateSettings func(settings *ReceiveSettings) + wantErr error + }{ + { + desc: "default settings", + mutateSettings: func(settings *ReceiveSettings) {}, + }, + { + desc: "reassignment handler returns nil", + mutateSettings: func(settings *ReceiveSettings) { + settings.ReassignmentHandler = func(before, after []int) error { + if got, want := before, []int{1, 2, 3}; !testutil.Equal(got, want) { + t.Errorf("before: got %d, want %d", got, want) + } + if got, want := after, []int{3, 4, 5}; !testutil.Equal(got, want) { + t.Errorf("after: got %d, want %d", got, want) + } + return nil + } + }, + }, + { + desc: "reassignment handler returns error", + mutateSettings: func(settings *ReceiveSettings) { + settings.ReassignmentHandler = func(before, after []int) error { + return reassignmentErr + } + }, + wantErr: reassignmentErr, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + settings := DefaultReceiveSettings + tc.mutateSettings(&settings) + + cctx, stopSubscriber := context.WithTimeout(ctx, defaultSubscriberTestTimeout) + messageReceiver := func(ctx context.Context, got *pubsub.Message) { + t.Error("Message receiver should not be called") + } + subInstance := newTestSubscriberInstance(cctx, settings, messageReceiver) + subInstance.wireSub.(*mockWireSubscriber).DeliverReassignment(before, after) + if tc.wantErr == nil { + stopSubscriber() + } + + if gotErr := subInstance.Wait(cctx); !test.ErrorEqual(gotErr, tc.wantErr) { + t.Errorf("subscriberInstance.Wait() got err: (%v), want err: (%v)", gotErr, tc.wantErr) + } + }) + } +} + func TestSubscriberInstanceWireSubscriberFails(t *testing.T) { fatalErr := errors.New("server error")