diff --git a/pubsublite/go.mod b/pubsublite/go.mod index d707a1af497..a143bc0d9cd 100644 --- a/pubsublite/go.mod +++ b/pubsublite/go.mod @@ -9,6 +9,7 @@ require ( github.com/google/go-cmp v0.5.4 github.com/google/uuid v1.1.4 github.com/googleapis/gax-go/v2 v2.0.5 + golang.org/x/sync v0.0.0-20201207232520-09787c993a3a golang.org/x/tools v0.0.0-20210105210202-9ed45478a130 // indirect golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 google.golang.org/api v0.36.0 diff --git a/pubsublite/internal/wire/committer.go b/pubsublite/internal/wire/committer.go index e6cf4be0079..0e32b95a346 100644 --- a/pubsublite/internal/wire/committer.go +++ b/pubsublite/internal/wire/committer.go @@ -98,6 +98,16 @@ func (c *committer) Stop() { c.unsafeInitiateShutdown(serviceTerminating, nil) } +// Terminate will discard outstanding acks and send the final commit offset to +// the server. +func (c *committer) Terminate() { + c.mu.Lock() + defer c.mu.Unlock() + + c.acks.Release() + c.unsafeInitiateShutdown(serviceTerminating, nil) +} + func (c *committer) newStream(ctx context.Context) (grpc.ClientStream, error) { return c.cursorClient.StreamingCommitCursor(ctx) } @@ -201,18 +211,18 @@ func (c *committer) unsafeInitiateShutdown(targetStatus serviceStatus, err error // Otherwise discard outstanding acks and immediately terminate the stream. c.acks.Release() - c.unsafeTerminate() + c.unsafeOnTerminated() } func (c *committer) unsafeCheckDone() { // The commit stream can be closed once the final commit offset has been // confirmed and there are no outstanding acks. if c.status == serviceTerminating && c.cursorTracker.UpToDate() && c.acks.Empty() { - c.unsafeTerminate() + c.unsafeOnTerminated() } } -func (c *committer) unsafeTerminate() { +func (c *committer) unsafeOnTerminated() { c.pollCommits.Stop() c.stream.Stop() } diff --git a/pubsublite/internal/wire/committer_test.go b/pubsublite/internal/wire/committer_test.go index 2f0144ea0da..6ebd4fbaf18 100644 --- a/pubsublite/internal/wire/committer_test.go +++ b/pubsublite/internal/wire/committer_test.go @@ -48,6 +48,10 @@ func (tc *testCommitter) SendBatchCommit() { tc.cmt.commitOffsetToStream() } +func (tc *testCommitter) Terminate() { + tc.cmt.Terminate() +} + func TestCommitterStreamReconnect(t *testing.T) { subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0} ack1 := newAckConsumer(33, 0, nil) @@ -123,6 +127,37 @@ func TestCommitterStopFlushesCommits(t *testing.T) { } } +func TestCommitterTerminateDiscardsOutstandingAcks(t *testing.T) { + subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0} + ack1 := newAckConsumer(33, 0, nil) + ack2 := newAckConsumer(55, 0, nil) + acks := newAckTracker() + acks.Push(ack1) + acks.Push(ack2) + + verifiers := test.NewVerifiers(t) + stream := test.NewRPCVerifier(t) + stream.Push(initCommitReq(subscription), initCommitResp(), nil) + stream.Push(commitReq(34), commitResp(1), nil) + verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream) + + mockServer.OnTestStart(verifiers) + defer mockServer.OnTestEnd() + + cmt := newTestCommitter(t, subscription, acks) + if gotErr := cmt.StartError(); gotErr != nil { + t.Errorf("Start() got err: (%v)", gotErr) + } + + ack1.Ack() + cmt.Terminate() // Terminate should flush the first offset + ack2.Ack() // Acks after Terminate() are discarded + cmt.SendBatchCommit() // Should do nothing (server does not expect second commit) + if gotErr := cmt.FinalError(); gotErr != nil { + t.Errorf("Final err: (%v), want: ", gotErr) + } +} + func TestCommitterPermanentStreamError(t *testing.T) { subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0} acks := newAckTracker() diff --git a/pubsublite/internal/wire/subscriber.go b/pubsublite/internal/wire/subscriber.go index 5965257609b..b7800d08d11 100644 --- a/pubsublite/internal/wire/subscriber.go +++ b/pubsublite/internal/wire/subscriber.go @@ -353,6 +353,13 @@ type singlePartitionSubscriber struct { compositeService } +// Terminate shuts down the singlePartitionSubscriber without waiting for +// outstanding acks. Alternatively, Stop() will wait for outstanding acks. +func (s *singlePartitionSubscriber) Terminate() { + s.subscriber.Stop() + s.committer.Terminate() +} + type singlePartitionSubscriberFactory struct { ctx context.Context subClient *vkit.SubscriberClient @@ -380,6 +387,8 @@ func (f *singlePartitionSubscriberFactory) New(partition int) *singlePartitionSu // multiPartitionSubscriber receives messages from a fixed set of topic // partitions. type multiPartitionSubscriber struct { + subscribers []*singlePartitionSubscriber + compositeService } @@ -390,10 +399,22 @@ func newMultiPartitionSubscriber(subFactory *singlePartitionSubscriberFactory) * for _, partition := range subFactory.settings.Partitions { subscriber := subFactory.New(partition) ms.unsafeAddServices(subscriber) + ms.subscribers = append(ms.subscribers, subscriber) } return ms } +// Terminate shuts down all singlePartitionSubscribers without waiting for +// outstanding acks. Alternatively, Stop() will wait for outstanding acks. +func (ms *multiPartitionSubscriber) Terminate() { + ms.mu.Lock() + defer ms.mu.Unlock() + + for _, sub := range ms.subscribers { + sub.Terminate() + } +} + // assigningSubscriber uses the Pub/Sub Lite partition assignment service to // listen to its assigned partition numbers and dynamically add/remove // singlePartitionSubscribers. @@ -453,6 +474,17 @@ func (as *assigningSubscriber) handleAssignment(partitions partitionSet) error { return nil } +// Terminate shuts down all singlePartitionSubscribers without waiting for +// outstanding acks. Alternatively, Stop() will wait for outstanding acks. +func (as *assigningSubscriber) Terminate() { + as.mu.Lock() + defer as.mu.Unlock() + + for _, sub := range as.subscribers { + sub.Terminate() + } +} + // Subscriber is the client interface exported from this package for receiving // messages. type Subscriber interface { @@ -460,6 +492,7 @@ type Subscriber interface { WaitStarted() error Stop() WaitStopped() error + Terminate() } // NewSubscriber creates a new client for receiving messages. diff --git a/pubsublite/ps/example_test.go b/pubsublite/ps/example_test.go index 957ffdd79f6..93f23a657ad 100644 --- a/pubsublite/ps/example_test.go +++ b/pubsublite/ps/example_test.go @@ -24,7 +24,12 @@ import ( func ExamplePublisherClient_Publish() { ctx := context.Background() - topic := pubsublite.TopicPath{Project: "project-id", Zone: "zone", TopicID: "topic-id"} + topic := pubsublite.TopicPath{ + Project: "project-id", + Zone: "zone", + TopicID: "topic-id", + } + // NOTE: DefaultPublishSettings and empty PublishSettings{} are equivalent. publisher, err := ps.NewPublisherClient(ctx, ps.DefaultPublishSettings, topic) if err != nil { // TODO: Handle error. @@ -48,7 +53,11 @@ func ExamplePublisherClient_Publish() { func ExamplePublisherClient_Error() { ctx := context.Background() - topic := pubsublite.TopicPath{Project: "project-id", Zone: "zone", TopicID: "topic-id"} + topic := pubsublite.TopicPath{ + Project: "project-id", + Zone: "zone", + TopicID: "topic-id", + } publisher, err := ps.NewPublisherClient(ctx, ps.DefaultPublishSettings, topic) if err != nil { // TODO: Handle error. @@ -72,3 +81,62 @@ func ExamplePublisherClient_Error() { fmt.Printf("Published a message with a message ID: %s\n", id) } } + +func ExampleSubscriberClient_Receive() { + ctx := context.Background() + subscription := pubsublite.SubscriptionPath{ + Project: "project-id", + Zone: "zone", + SubscriptionID: "subscription-id", + } + // NOTE: DefaultReceiveSettings and empty ReceiveSettings{} are equivalent. + subscriber, err := ps.NewSubscriberClient(ctx, ps.DefaultReceiveSettings, subscription) + if err != nil { + // TODO: Handle error. + } + cctx, cancel := context.WithCancel(ctx) + err = subscriber.Receive(cctx, func(ctx context.Context, m *pubsub.Message) { + // TODO: Handle message. + // NOTE: May be called concurrently; synchronize access to shared memory. + m.Ack() + }) + if err != nil { + // TODO: Handle error. + } + + // Call cancel from callback, or another goroutine. + cancel() +} + +// This example shows how to throttle SubscriberClient.Receive, which aims for +// high throughput by default. By limiting the number of messages and/or bytes +// being processed at once, you can bound your program's resource consumption. +// Note that ReceiveSettings apply per partition, so keep in mind the number of +// partitions in the associated topic. +func ExampleSubscriberClient_Receive_maxOutstanding() { + ctx := context.Background() + subscription := pubsublite.SubscriptionPath{ + Project: "project-id", + Zone: "zone", + SubscriptionID: "subscription-id", + } + settings := ps.DefaultReceiveSettings + settings.MaxOutstandingMessages = 5 + settings.MaxOutstandingBytes = 10e6 + subscriber, err := ps.NewSubscriberClient(ctx, settings, subscription) + if err != nil { + // TODO: Handle error. + } + cctx, cancel := context.WithCancel(ctx) + err = subscriber.Receive(cctx, func(ctx context.Context, m *pubsub.Message) { + // TODO: Handle message. + // NOTE: May be called concurrently; synchronize access to shared memory. + m.Ack() + }) + if err != nil { + // TODO: Handle error. + } + + // Call cancel from callback, or another goroutine. + cancel() +} diff --git a/pubsublite/ps/settings.go b/pubsublite/ps/settings.go index 6a8be9d1bb5..6b463296724 100644 --- a/pubsublite/ps/settings.go +++ b/pubsublite/ps/settings.go @@ -41,9 +41,9 @@ type KeyExtractorFunc func(*pubsub.Message) []byte // terminate. type PublishMessageTransformerFunc func(*pubsub.Message, *pb.PubSubMessage) error -// PublishSettings configure the PublisherClient. These settings apply per -// partition. If BufferedByteLimit is being used to bound memory usage, keep in -// mind the number of partitions in the topic. +// PublishSettings configure the PublisherClient. Batching settings +// (DelayThreshold, CountThreshold, ByteThreshold, BufferedByteLimit) apply per +// partition. // // A zero PublishSettings will result in values equivalent to // DefaultPublishSettings. @@ -76,6 +76,10 @@ type PublishSettings struct { // returning ErrOverflow. If BufferedByteLimit is 0, it will be treated as // DefaultPublishSettings.BufferedByteLimit. Otherwise must be > 0. // + // Note that this setting applies per partition. If BufferedByteLimit is being + // used to bound memory usage, keep in mind the number of partitions in the + // topic. + // // Note that Pub/Sub Lite topics are provisioned a publishing throughput // capacity, per partition, shared by all publisher clients. Setting a large // buffer size can mitigate transient publish spikes. However, consistently @@ -146,9 +150,8 @@ type NackHandler func(*pubsub.Message) error // will consider this a fatal error and terminate. type ReceiveMessageTransformerFunc func(*pb.SequencedMessage, *pubsub.Message) error -// ReceiveSettings configure the SubscriberClient. These settings apply per -// partition. If MaxOutstandingBytes is being used to bound memory usage, keep -// in mind the number of partitions in the associated topic. +// ReceiveSettings configure the SubscriberClient. Flow control settings +// (MaxOutstandingMessages, MaxOutstandingBytes) apply per partition. // // A zero ReceiveSettings will result in values equivalent to // DefaultReceiveSettings. @@ -161,6 +164,10 @@ type ReceiveSettings struct { // MaxOutstandingBytes is the maximum size (in quota bytes) of unacknowledged // messages. If MaxOutstandingBytes is 0, it will be treated as // DefaultReceiveSettings.MaxOutstandingBytes. Otherwise must be > 0. + // + // Note that this setting applies per partition. If MaxOutstandingBytes is + // being used to bound memory usage, keep in mind the number of partitions in + // the associated topic. MaxOutstandingBytes int // The maximum time that the client will attempt to establish a subscribe diff --git a/pubsublite/ps/subscriber.go b/pubsublite/ps/subscriber.go new file mode 100644 index 00000000000..8e58ba9d119 --- /dev/null +++ b/pubsublite/ps/subscriber.go @@ -0,0 +1,309 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and + +package ps + +import ( + "context" + "errors" + "sync" + + "cloud.google.com/go/pubsub" + "cloud.google.com/go/pubsublite" + "cloud.google.com/go/pubsublite/internal/wire" + "google.golang.org/api/option" + + ipubsub "cloud.google.com/go/internal/pubsub" +) + +var ( + errNackCalled = errors.New("pubsublite: subscriber client does not support nack. See NackHandler for how to customize nack handling") + errDuplicateReceive = errors.New("pubsublite: receive is already in progress for this subscriber client") +) + +// handleNack is the default NackHandler implementation. +func handleNack(_ *pubsub.Message) error { + return errNackCalled +} + +// pslAckHandler is the AckHandler for Pub/Sub Lite. +type pslAckHandler struct { + ackh wire.AckConsumer + msg *pubsub.Message + nackh NackHandler + subInstance *subscriberInstance +} + +func (ah *pslAckHandler) OnAck() { + if ah.subInstance == nil { + return + } + + ah.ackh.Ack() + ah.subInstance = nil +} + +func (ah *pslAckHandler) OnNack() { + if ah.subInstance == nil { + return + } + + err := ah.nackh(ah.msg) + if err != nil { + // If the NackHandler returns an error, shut down the subscriber client. + ah.subInstance.Terminate(err) + } else { + // If the NackHandler succeeds, just ack the message. + ah.ackh.Ack() + } + ah.subInstance = nil +} + +// wireSubscriberFactory is a factory for creating wire subscribers, which can +// be overridden with a mock in unit tests. +type wireSubscriberFactory interface { + New(wire.MessageReceiverFunc) (wire.Subscriber, error) +} + +type wireSubscriberFactoryImpl struct { + settings wire.ReceiveSettings + region string + subscription pubsublite.SubscriptionPath + options []option.ClientOption +} + +func (f *wireSubscriberFactoryImpl) New(receiver wire.MessageReceiverFunc) (wire.Subscriber, error) { + return wire.NewSubscriber(context.Background(), f.settings, receiver, f.region, f.subscription.String(), f.options...) +} + +// subscriberInstance wraps an instance of a wire.Subscriber. A new instance is +// created for each invocation of SubscriberClient.Receive(). +type subscriberInstance struct { + settings ReceiveSettings + receiver MessageReceiverFunc + recvCtx context.Context // Context passed to the receiver + recvCancel context.CancelFunc // Corresponding cancel func for recvCtx + wireSub wire.Subscriber + activeReceivers sync.WaitGroup + + // Fields below must be guarded with mu. + mu sync.Mutex + err error +} + +func newSubscriberInstance(ctx context.Context, factory wireSubscriberFactory, settings ReceiveSettings, receiver MessageReceiverFunc) (*subscriberInstance, error) { + recvCtx, recvCancel := context.WithCancel(ctx) + subInstance := &subscriberInstance{ + settings: settings, + recvCtx: recvCtx, + recvCancel: recvCancel, + receiver: receiver, + } + + // Note: ctx is not used to create the wire subscriber, because if it is + // cancelled, the subscriber will not be able to perform graceful shutdown + // (e.g. process acks and commit the final cursor offset). + wireSub, err := factory.New(subInstance.onMessage) + if err != nil { + return nil, err + } + + subInstance.wireSub = wireSub + if subInstance.settings.MessageTransformer == nil { + subInstance.settings.MessageTransformer = transformReceivedMessage + } + if subInstance.settings.NackHandler == nil { + subInstance.settings.NackHandler = handleNack + } + return subInstance, nil +} + +func (si *subscriberInstance) onMessage(msg *wire.ReceivedMessage) { + pslAckh := &pslAckHandler{ + ackh: msg.Ack, + nackh: si.settings.NackHandler, + subInstance: si, + } + psMsg := ipubsub.NewMessage(pslAckh) + pslAckh.msg = psMsg + if err := si.settings.MessageTransformer(msg.Msg, psMsg); err != nil { + si.Terminate(err) + return + } + + si.activeReceivers.Add(1) + si.receiver(si.recvCtx, psMsg) + si.activeReceivers.Done() +} + +// shutdown starts shutting down the subscriber client. The wire subscriber can +// optionally wait for all outstanding messages to be acked/nacked. +func (si *subscriberInstance) shutdown(waitForAcks bool, err error) { + si.mu.Lock() + defer si.mu.Unlock() + + // Don't clobber original error. + if si.err == nil { + si.err = err + } + + // Cancel recvCtx to notify message receiver funcs of shutdown. + si.recvCancel() + + // Either wait for acks, or terminate quickly upon fatal error. + if waitForAcks { + si.wireSub.Stop() + } else { + si.wireSub.Terminate() + } +} + +// Terminate shuts down the subscriber client without waiting for outstanding +// messages to be acked/nacked. +func (si *subscriberInstance) Terminate(err error) { + si.shutdown(false, err) +} + +// Wait for the subscriber to stop, or the context is done, whichever occurs +// first. +func (si *subscriberInstance) Wait(ctx context.Context) error { + si.wireSub.Start() + if err := si.wireSub.WaitStarted(); err != nil { + return err + } + + // Start a goroutine to monitor when the context is done. + subscriberStopped := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + si.shutdown(true, nil) + case <-subscriberStopped: + } + }() + err := si.wireSub.WaitStopped() + + // End goroutine above if the wire subscriber terminated due to fatal error + // and ctx is not done. + close(subscriberStopped) + // And also wait for all the receivers to finish. + si.recvCancel() + si.activeReceivers.Wait() + + si.mu.Lock() + defer si.mu.Unlock() + + if si.err != nil { + return si.err + } + return err +} + +// MessageReceiverFunc handles messages sent by the Cloud Pub/Sub Lite service. +// +// The implementation must arrange for pubsub.Message.Ack() or +// pubsub.Message.Nack() to be called after processing the message. +// +// The receiver func will be called from multiple goroutines if the subscriber +// is connected to multiple partitions. Only one call from any connected +// partition will be outstanding at a time, and blocking in this receiver +// callback will block the delivery of subsequent messages for the partition. +type MessageReceiverFunc func(context.Context, *pubsub.Message) + +// SubscriberClient is a Cloud Pub/Sub Lite client to receive messages for a +// given subscription. +// +// See https://cloud.google.com/pubsub/lite/docs/subscribing for more +// information about receiving messages. +type SubscriberClient struct { + settings ReceiveSettings + wireSubFactory wireSubscriberFactory + + // Fields below must be guarded with mu. + mu sync.Mutex + receiveActive bool +} + +// NewSubscriberClient creates a new Cloud Pub/Sub Lite client to receive +// messages for a given subscription. +// +// See https://cloud.google.com/pubsub/lite/docs/subscribing for more +// information about receiving messages. +func NewSubscriberClient(ctx context.Context, settings ReceiveSettings, subscription pubsublite.SubscriptionPath, opts ...option.ClientOption) (*SubscriberClient, error) { + region, err := pubsublite.ZoneToRegion(subscription.Zone) + if err != nil { + return nil, err + } + factory := &wireSubscriberFactoryImpl{ + settings: settings.toWireSettings(), + region: region, + subscription: subscription, + options: opts, + } + subClient := &SubscriberClient{ + settings: settings, + wireSubFactory: factory, + } + return subClient, nil +} + +// Receive calls f with the messages from the subscription. It blocks until ctx +// is done, or the service returns a non-retryable error. +// +// The standard way to terminate a Receive is to cancel its context: +// +// cctx, cancel := context.WithCancel(ctx) +// err := sub.Receive(cctx, callback) +// // Call cancel from callback, or another goroutine. +// +// If there is a fatal service error, Receive returns that error after all of +// the outstanding calls to f have returned. If ctx is done, Receive returns nil +// after all of the outstanding calls to f have returned and all messages have +// been acknowledged. +// +// Receive calls f concurrently from multiple goroutines if the SubscriberClient +// is connected to multiple partitions. All messages received by f must be ACKed +// or NACKed. Failure to do so can prevent Receive from returning. +// +// The context passed to f will be canceled when ctx is Done or there is a fatal +// service error. +// +// Each SubscriberClient may have only one invocation of Receive active at a +// time. +func (s *SubscriberClient) Receive(ctx context.Context, f MessageReceiverFunc) error { + if err := s.setReceiveActive(true); err != nil { + return err + } + defer s.setReceiveActive(false) + + // Initialize a subscriber instance. + subInstance, err := newSubscriberInstance(ctx, s.wireSubFactory, s.settings, f) + if err != nil { + return err + } + + // Wait for the subscriber without mutex held. Overlapping Receive invocations + // will return an error. + return subInstance.Wait(ctx) +} + +func (s *SubscriberClient) setReceiveActive(active bool) error { + s.mu.Lock() + defer s.mu.Unlock() + + if active && s.receiveActive { + return errDuplicateReceive + } + s.receiveActive = active + return nil +} diff --git a/pubsublite/ps/subscriber_test.go b/pubsublite/ps/subscriber_test.go new file mode 100644 index 00000000000..9a8ebdb5a04 --- /dev/null +++ b/pubsublite/ps/subscriber_test.go @@ -0,0 +1,407 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and + +package ps + +import ( + "context" + "errors" + "testing" + "time" + + pubsub "cloud.google.com/go/internal/pubsub" + "cloud.google.com/go/internal/testutil" + "cloud.google.com/go/pubsublite/internal/test" + "cloud.google.com/go/pubsublite/internal/wire" + "github.com/google/go-cmp/cmp/cmpopts" + "golang.org/x/sync/errgroup" + + tspb "github.com/golang/protobuf/ptypes/timestamp" + pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1" +) + +const defaultSubscriberTestTimeout = 10 * time.Second + +// mockAckConsumer is a mock implementation of the wire.AckConsumer interface. +type mockAckConsumer struct { + AckCount int +} + +func (ac *mockAckConsumer) Ack() { + ac.AckCount++ +} + +// mockWireSubscriber is a mock implementation of the wire.Subscriber interface. +type mockWireSubscriber struct { + receiver wire.MessageReceiverFunc + msgsC chan *wire.ReceivedMessage + stopC chan struct{} + err error + Stopped bool + Terminated bool +} + +// DeliverMessages should be called from the test to simulate a message +// delivery. +func (ms *mockWireSubscriber) DeliverMessages(msgs ...*wire.ReceivedMessage) { + for _, m := range msgs { + ms.msgsC <- m + } +} + +// SimulateFatalError should be called from the test to simulate a fatal error +// occurring in the wire subscriber. +func (ms *mockWireSubscriber) SimulateFatalError(err error) { + ms.err = err + close(ms.stopC) +} + +// wire.Subscriber implementation + +func (ms *mockWireSubscriber) Start() { + go func() { + for { + // Ensure stop has higher priority. + select { + case <-ms.stopC: + return // Exit goroutine + default: + } + + select { + case <-ms.stopC: + return // Exit goroutine + case msg := <-ms.msgsC: + ms.receiver(msg) + } + } + }() +} + +func (ms *mockWireSubscriber) WaitStarted() error { + return nil +} + +func (ms *mockWireSubscriber) Stop() { + if !ms.Stopped && !ms.Terminated { + ms.Stopped = true + close(ms.stopC) + } +} + +func (ms *mockWireSubscriber) Terminate() { + if !ms.Stopped && !ms.Terminated { + ms.Terminated = true + close(ms.stopC) + } +} + +func (ms *mockWireSubscriber) WaitStopped() error { + <-ms.stopC // Wait until Stopped + return ms.err +} + +type mockWireSubscriberFactory struct{} + +func (f *mockWireSubscriberFactory) New(receiver wire.MessageReceiverFunc) (wire.Subscriber, error) { + return &mockWireSubscriber{ + receiver: receiver, + msgsC: make(chan *wire.ReceivedMessage, 10), + stopC: make(chan struct{}), + }, nil +} + +func newTestSubscriberInstance(ctx context.Context, settings ReceiveSettings, receiver MessageReceiverFunc) *subscriberInstance { + sub, _ := newSubscriberInstance(ctx, new(mockWireSubscriberFactory), settings, receiver) + return sub +} + +func TestSubscriberInstanceTransformMessage(t *testing.T) { + ctx := context.Background() + input := &pb.SequencedMessage{ + Message: &pb.PubSubMessage{ + Data: []byte("data"), + Key: []byte("key"), + Attributes: map[string]*pb.AttributeValues{ + "attr": {Values: [][]byte{[]byte("value")}}, + }, + }, + Cursor: &pb.Cursor{Offset: 123}, + PublishTime: &tspb.Timestamp{ + Seconds: 1577836800, + Nanos: 900800700, + }, + } + + for _, tc := range []struct { + desc string + // mutateSettings is passed a copy of DefaultReceiveSettings to mutate. + mutateSettings func(settings *ReceiveSettings) + want *pubsub.Message + }{ + { + desc: "default settings", + mutateSettings: func(settings *ReceiveSettings) {}, + want: &pubsub.Message{ + Data: []byte("data"), + OrderingKey: "key", + Attributes: map[string]string{"attr": "value"}, + ID: "123", + PublishTime: time.Unix(1577836800, 900800700), + }, + }, + { + desc: "custom message transformer", + mutateSettings: func(settings *ReceiveSettings) { + settings.MessageTransformer = func(from *pb.SequencedMessage, to *pubsub.Message) error { + // Swaps data and key. + to.OrderingKey = string(from.Message.Data) + to.Data = from.Message.Key + return nil + } + }, + want: &pubsub.Message{ + Data: []byte("key"), + OrderingKey: "data", + }, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + settings := DefaultReceiveSettings + tc.mutateSettings(&settings) + + ack := &mockAckConsumer{} + msg := &wire.ReceivedMessage{Msg: input, Ack: ack} + + cctx, stopSubscriber := context.WithTimeout(ctx, defaultSubscriberTestTimeout) + messageReceiver := func(ctx context.Context, got *pubsub.Message) { + if diff := testutil.Diff(got, tc.want, cmpopts.IgnoreUnexported(pubsub.Message{}), cmpopts.EquateEmpty()); diff != "" { + t.Errorf("Received message got: -, want: +\n%s", diff) + } + got.Ack() + got.Nack() // Should be ignored + stopSubscriber() + } + subInstance := newTestSubscriberInstance(cctx, settings, messageReceiver) + subInstance.wireSub.(*mockWireSubscriber).DeliverMessages(msg) + + if err := subInstance.Wait(cctx); err != nil { + t.Errorf("subscriberInstance.Wait() got err: %v", err) + } + if got, want := ack.AckCount, 1; got != want { + t.Errorf("mockAckConsumer.AckCount: got %d, want %d", got, want) + } + if got, want := subInstance.recvCtx.Err(), context.Canceled; !test.ErrorEqual(got, want) { + t.Errorf("subscriberInstance.recvCtx.Err(): got (%v), want (%v)", got, want) + } + if got, want := subInstance.wireSub.(*mockWireSubscriber).Stopped, true; got != want { + t.Errorf("mockWireSubscriber.Stopped: got %v, want %v", got, want) + } + }) + } +} + +func TestSubscriberInstanceTransformMessageError(t *testing.T) { + wantErr := errors.New("message could not be converted") + + settings := DefaultReceiveSettings + settings.MessageTransformer = func(_ *pb.SequencedMessage, _ *pubsub.Message) error { + return wantErr + } + + ctx := context.Background() + ack := &mockAckConsumer{} + msg := &wire.ReceivedMessage{ + Ack: ack, + Msg: &pb.SequencedMessage{ + Message: &pb.PubSubMessage{Data: []byte("data")}, + }, + } + + cctx, _ := context.WithTimeout(ctx, defaultSubscriberTestTimeout) + messageReceiver := func(ctx context.Context, got *pubsub.Message) { + t.Errorf("Received unexpected message: %v", got) + got.Nack() + } + subInstance := newTestSubscriberInstance(cctx, settings, messageReceiver) + subInstance.wireSub.(*mockWireSubscriber).DeliverMessages(msg) + + if gotErr := subInstance.Wait(cctx); !test.ErrorEqual(gotErr, wantErr) { + t.Errorf("subscriberInstance.Wait() got err: (%v), want: (%v)", gotErr, wantErr) + } + if got, want := ack.AckCount, 0; got != want { + t.Errorf("mockAckConsumer.AckCount: got %d, want %d", got, want) + } + if got, want := subInstance.recvCtx.Err(), context.Canceled; !test.ErrorEqual(got, want) { + t.Errorf("subscriberInstance.recvCtx.Err(): got (%v), want (%v)", got, want) + } + if got, want := subInstance.wireSub.(*mockWireSubscriber).Terminated, true; got != want { + t.Errorf("mockWireSubscriber.Terminated: got %v, want %v", got, want) + } +} + +func TestSubscriberInstanceNack(t *testing.T) { + nackErr := errors.New("message nacked") + + ctx := context.Background() + msg := &pb.SequencedMessage{ + Message: &pb.PubSubMessage{ + Data: []byte("data"), + Key: []byte("key"), + }, + } + + for _, tc := range []struct { + desc string + // mutateSettings is passed a copy of DefaultReceiveSettings to mutate. + mutateSettings func(settings *ReceiveSettings) + wantErr error + wantAckCount int + wantStopped bool + wantTerminated bool + }{ + { + desc: "default settings", + mutateSettings: func(settings *ReceiveSettings) {}, + wantErr: errNackCalled, + wantAckCount: 0, + wantTerminated: true, + }, + { + desc: "nack handler returns nil", + mutateSettings: func(settings *ReceiveSettings) { + settings.NackHandler = func(_ *pubsub.Message) error { + return nil + } + }, + wantErr: nil, + wantAckCount: 1, + wantStopped: true, + }, + { + desc: "nack handler returns error", + mutateSettings: func(settings *ReceiveSettings) { + settings.NackHandler = func(_ *pubsub.Message) error { + return nackErr + } + }, + wantErr: nackErr, + wantAckCount: 0, + wantTerminated: true, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + settings := DefaultReceiveSettings + tc.mutateSettings(&settings) + + ack := &mockAckConsumer{} + msg := &wire.ReceivedMessage{Msg: msg, Ack: ack} + + cctx, stopSubscriber := context.WithTimeout(ctx, defaultSubscriberTestTimeout) + messageReceiver := func(ctx context.Context, got *pubsub.Message) { + got.Nack() + + // Only need to stop the subscriber when the nack handler actually acks + // the message. For other cases, the subscriber is forcibly terminated. + if tc.wantErr == nil { + stopSubscriber() + } + } + subInstance := newTestSubscriberInstance(cctx, settings, messageReceiver) + subInstance.wireSub.(*mockWireSubscriber).DeliverMessages(msg) + + if gotErr := subInstance.Wait(cctx); !test.ErrorEqual(gotErr, tc.wantErr) { + t.Errorf("subscriberInstance.Wait() got err: (%v), want: (%v)", gotErr, tc.wantErr) + } + if got, want := ack.AckCount, tc.wantAckCount; got != want { + t.Errorf("mockAckConsumer.AckCount: got %d, want %d", got, want) + } + if got, want := subInstance.recvCtx.Err(), context.Canceled; !test.ErrorEqual(got, want) { + t.Errorf("subscriberInstance.recvCtx.Err(): got (%v), want (%v)", got, want) + } + if got, want := subInstance.wireSub.(*mockWireSubscriber).Stopped, tc.wantStopped; got != want { + t.Errorf("mockWireSubscriber.Stopped: got %v, want %v", got, want) + } + if got, want := subInstance.wireSub.(*mockWireSubscriber).Terminated, tc.wantTerminated; got != want { + t.Errorf("mockWireSubscriber.Terminated: got %v, want %v", got, want) + } + }) + } +} + +func TestSubscriberInstanceWireSubscriberFails(t *testing.T) { + fatalErr := errors.New("server error") + + ctx := context.Background() + msg := &wire.ReceivedMessage{ + Ack: &mockAckConsumer{}, + Msg: &pb.SequencedMessage{ + Message: &pb.PubSubMessage{Data: []byte("data")}, + }, + } + + cctx, _ := context.WithTimeout(ctx, defaultSubscriberTestTimeout) + messageReceiver := func(ctx context.Context, got *pubsub.Message) { + // Verifies that receivers are notified via ctx.Done when the subscriber is + // shutting down. + select { + case <-time.After(defaultSubscriberTestTimeout): + t.Errorf("MessageReceiverFunc context not closed within %v", defaultSubscriberTestTimeout) + case <-ctx.Done(): + } + } + subInstance := newTestSubscriberInstance(cctx, DefaultReceiveSettings, messageReceiver) + subInstance.wireSub.(*mockWireSubscriber).DeliverMessages(msg) + time.AfterFunc(100*time.Millisecond, func() { + // Simulates a fatal server error that causes the wire subscriber to + // terminate from within. + subInstance.wireSub.(*mockWireSubscriber).SimulateFatalError(fatalErr) + }) + + if gotErr := subInstance.Wait(cctx); !test.ErrorEqual(gotErr, fatalErr) { + t.Errorf("subscriberInstance.Wait() got err: (%v), want: (%v)", gotErr, fatalErr) + } + if got, want := subInstance.recvCtx.Err(), context.Canceled; !test.ErrorEqual(got, want) { + t.Errorf("subscriberInstance.recvCtx.Err(): got (%v), want (%v)", got, want) + } + if got, want := subInstance.wireSub.(*mockWireSubscriber).Stopped, false; got != want { + t.Errorf("mockWireSubscriber.Stopped: got %v, want %v", got, want) + } + if got, want := subInstance.wireSub.(*mockWireSubscriber).Terminated, false; got != want { + t.Errorf("mockWireSubscriber.Terminated: got %v, want %v", got, want) + } +} + +func TestSubscriberClientDuplicateReceive(t *testing.T) { + ctx := context.Background() + subClient := &SubscriberClient{ + settings: DefaultReceiveSettings, + wireSubFactory: new(mockWireSubscriberFactory), + } + + messageReceiver := func(_ context.Context, got *pubsub.Message) { + t.Errorf("No messages expected, got: %v", got) + } + + g, gctx := errgroup.WithContext(ctx) + for i := 0; i < 3; i++ { + // Receive() is blocking, so we must start them in goroutines. Passing gctx + // to Receive will stop the subscribers once the first error occurs. + g.Go(func() error { + return subClient.Receive(gctx, messageReceiver) + }) + } + if gotErr, wantErr := g.Wait(), errDuplicateReceive; !test.ErrorEqual(gotErr, wantErr) { + t.Errorf("SubscriberClient.Receive() got err: (%v), want: (%v)", gotErr, wantErr) + } +}