From ae34396b1a2a970a0d871cd5496527294f3310d4 Mon Sep 17 00:00:00 2001 From: tmdiep Date: Sat, 26 Jun 2021 07:17:55 +1000 Subject: [PATCH] fix(pubsublite): wire user context to api clients (#4318) --- pubsublite/pscompat/integration_test.go | 60 ++++++++++++++++++++----- pubsublite/pscompat/publisher.go | 5 +-- pubsublite/pscompat/subscriber.go | 23 +++++----- pubsublite/pscompat/subscriber_test.go | 4 +- 4 files changed, 66 insertions(+), 26 deletions(-) diff --git a/pubsublite/pscompat/integration_test.go b/pubsublite/pscompat/integration_test.go index a22358ead8e..ab14ad7b973 100644 --- a/pubsublite/pscompat/integration_test.go +++ b/pubsublite/pscompat/integration_test.go @@ -30,7 +30,9 @@ import ( "cloud.google.com/go/pubsublite/internal/wire" "github.com/google/go-cmp/cmp/cmpopts" "golang.org/x/sync/errgroup" + "golang.org/x/xerrors" "google.golang.org/api/option" + "google.golang.org/grpc/codes" vkit "cloud.google.com/go/pubsublite/apiv1" pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1" @@ -167,7 +169,7 @@ func partitionNumbers(partitionCount int) []int { func publishMessages(t *testing.T, settings PublishSettings, topic wire.TopicPath, msgs ...*pubsub.Message) { ctx := context.Background() - publisher := publisherClient(ctx, t, settings, topic) + publisher := publisherClient(context.Background(), t, settings, topic) defer publisher.Stop() var pubResults []*pubsub.PublishResult @@ -179,7 +181,7 @@ func publishMessages(t *testing.T, settings PublishSettings, topic wire.TopicPat func publishPrefixedMessages(t *testing.T, settings PublishSettings, topic wire.TopicPath, msgPrefix string, msgCount, msgSize int) []string { ctx := context.Background() - publisher := publisherClient(ctx, t, settings, topic) + publisher := publisherClient(context.Background(), t, settings, topic) defer publisher.Stop() orderingSender := test.NewOrderingSender() @@ -271,7 +273,7 @@ func receiveAllMessages(t *testing.T, msgTracker *test.MsgTracker, settings Rece } } - subscriber := subscriberClient(cctx, t, settings, subscription) + subscriber := subscriberClient(context.Background(), t, settings, subscription) if err := subscriber.Receive(cctx, messageReceiver); err != nil { t.Errorf("Receive() got err: %v", err) } @@ -298,7 +300,7 @@ func receiveAndVerifyMessage(t *testing.T, want *pubsub.Message, settings Receiv } } - subscriber := subscriberClient(cctx, t, settings, subscription) + subscriber := subscriberClient(context.Background(), t, settings, subscription) if err := subscriber.Receive(cctx, messageReceiver); err != nil { t.Errorf("Receive() got err: %v", err) } @@ -383,7 +385,7 @@ func TestIntegration_PublishSubscribeSinglePartition(t *testing.T) { } got.Nack() } - subscriber := subscriberClient(cctx, t, recvSettings, subscriptionPath) + subscriber := subscriberClient(context.Background(), t, recvSettings, subscriptionPath) if gotErr := subscriber.Receive(cctx, messageReceiver1); !test.ErrorEqual(gotErr, errNackCalled) { t.Errorf("Receive() got err: (%v), want err: (%v)", gotErr, errNackCalled) } @@ -400,7 +402,7 @@ func TestIntegration_PublishSubscribeSinglePartition(t *testing.T) { } return fmt.Errorf("Received unexpected message: %q", truncateMsg(string(msg.Data))) } - subscriber = subscriberClient(cctx, t, customSettings, subscriptionPath) + subscriber = subscriberClient(context.Background(), t, customSettings, subscriptionPath) messageReceiver2 := func(ctx context.Context, got *pubsub.Message) { got.Nack() @@ -434,7 +436,7 @@ func TestIntegration_PublishSubscribeSinglePartition(t *testing.T) { got.Ack() stopSubscriber() } - subscriber := subscriberClient(cctx, t, recvSettings, subscriptionPath) + subscriber := subscriberClient(context.Background(), t, recvSettings, subscriptionPath) // The message receiver stops the subscriber after receiving the first // message. However, the subscriber isn't guaranteed to immediately stop, so @@ -485,7 +487,7 @@ func TestIntegration_PublishSubscribeSinglePartition(t *testing.T) { // next test, which would receive an incorrect message. got.Ack() } - subscriber := subscriberClient(cctx, t, recvSettings, subscriptionPath) + subscriber := subscriberClient(context.Background(), t, recvSettings, subscriptionPath) if err := subscriber.Receive(cctx, messageReceiver); err != nil { t.Errorf("Receive() got err: %v", err) @@ -539,6 +541,44 @@ func TestIntegration_PublishSubscribeSinglePartition(t *testing.T) { receiveAllMessages(t, msgTracker, recvSettings, subscriptionPath) }) + // Verifies that cancelling the context passed to NewPublisherClient can shut + // down the publisher. + t.Run("CancelPublisherContext", func(t *testing.T) { + cctx, cancel := context.WithCancel(context.Background()) + publisher := publisherClient(cctx, t, DefaultPublishSettings, topicPath) + + cancel() + + wantCode := codes.Canceled + result := publisher.Publish(ctx, &pubsub.Message{Data: []byte("cancel_publisher_context")}) + if _, err := result.Get(ctx); !test.ErrorHasCode(err, wantCode) { + t.Errorf("Publish() got err: %v, want code: %v", err, wantCode) + } + if err := xerrors.Unwrap(publisher.Error()); !test.ErrorHasCode(err, wantCode) { + t.Errorf("Error() got err: %v, want code: %v", err, wantCode) + } + publisher.Stop() + }) + + // Verifies that cancelling the context passed to NewSubscriberClient can shut + // down the subscriber. + t.Run("CancelSubscriberContext", func(t *testing.T) { + msg := &pubsub.Message{Data: []byte("cancel_subscriber_context")} + publishMessages(t, DefaultPublishSettings, topicPath, msg) + + cctx, cancel := context.WithCancel(context.Background()) + subscriber := subscriberClient(cctx, t, recvSettings, subscriptionPath) + + subsErr := subscriber.Receive(context.Background(), func(ctx context.Context, got *pubsub.Message) { + got.Ack() + cancel() + }) + + if err, wantCode := xerrors.Unwrap(subsErr), codes.Canceled; !test.ErrorHasCode(err, wantCode) { + t.Errorf("Receive() got err: %v, want code: %v", err, wantCode) + } + }) + // NOTE: This should be the last test case. // Verifies that increasing the number of topic partitions is handled // correctly by publishers. @@ -547,7 +587,7 @@ func TestIntegration_PublishSubscribeSinglePartition(t *testing.T) { const pollPeriod = 5 * time.Second pubSettings := DefaultPublishSettings pubSettings.configPollPeriod = pollPeriod // Poll updates more frequently - publisher := publisherClient(ctx, t, pubSettings, topicPath) + publisher := publisherClient(context.Background(), t, pubSettings, topicPath) defer publisher.Stop() // Update the number of partitions. @@ -661,7 +701,7 @@ func TestIntegration_PublishSubscribeMultiPartition(t *testing.T) { for i := 0; i < subscriberCount; i++ { // Subscribers must be started in a goroutine as Receive() blocks. g.Go(func() error { - subscriber := subscriberClient(cctx, t, DefaultReceiveSettings, subscriptionPath) + subscriber := subscriberClient(context.Background(), t, DefaultReceiveSettings, subscriptionPath) err := subscriber.Receive(cctx, messageReceiver) if err != nil { t.Errorf("Receive() got err: %v", err) diff --git a/pubsublite/pscompat/publisher.go b/pubsublite/pscompat/publisher.go index 6f4e9415a82..b9f8c82e19a 100644 --- a/pubsublite/pscompat/publisher.go +++ b/pubsublite/pscompat/publisher.go @@ -82,10 +82,7 @@ func NewPublisherClientWithSettings(ctx context.Context, topic string, settings return nil, err } - // Note: ctx is not used to create the wire publisher, because if it is - // cancelled, the publisher will not be able to perform graceful shutdown - // (e.g. flush pending messages). - wirePub, err := wire.NewPublisher(context.Background(), settings.toWireSettings(), region, topic, opts...) + wirePub, err := wire.NewPublisher(ctx, settings.toWireSettings(), region, topic, opts...) if err != nil { return nil, err } diff --git a/pubsublite/pscompat/subscriber.go b/pubsublite/pscompat/subscriber.go index dafb25ee526..d76cc9d290b 100644 --- a/pubsublite/pscompat/subscriber.go +++ b/pubsublite/pscompat/subscriber.go @@ -72,7 +72,7 @@ func (ah *pslAckHandler) OnNack() { // wireSubscriberFactory is a factory for creating wire subscribers, which can // be overridden with a mock in unit tests. type wireSubscriberFactory interface { - New(wire.MessageReceiverFunc) (wire.Subscriber, error) + New(context.Context, wire.MessageReceiverFunc) (wire.Subscriber, error) } type wireSubscriberFactoryImpl struct { @@ -82,8 +82,8 @@ type wireSubscriberFactoryImpl struct { options []option.ClientOption } -func (f *wireSubscriberFactoryImpl) New(receiver wire.MessageReceiverFunc) (wire.Subscriber, error) { - return wire.NewSubscriber(context.Background(), f.settings, receiver, f.region, f.subscription.String(), f.options...) +func (f *wireSubscriberFactoryImpl) New(ctx context.Context, receiver wire.MessageReceiverFunc) (wire.Subscriber, error) { + return wire.NewSubscriber(ctx, f.settings, receiver, f.region, f.subscription.String(), f.options...) } type messageReceiverFunc = func(context.Context, *pubsub.Message) @@ -103,8 +103,8 @@ type subscriberInstance struct { err error } -func newSubscriberInstance(ctx context.Context, factory wireSubscriberFactory, settings ReceiveSettings, receiver messageReceiverFunc) (*subscriberInstance, error) { - recvCtx, recvCancel := context.WithCancel(ctx) +func newSubscriberInstance(recvCtx, clientCtx context.Context, factory wireSubscriberFactory, settings ReceiveSettings, receiver messageReceiverFunc) (*subscriberInstance, error) { + recvCtx, recvCancel := context.WithCancel(recvCtx) subInstance := &subscriberInstance{ settings: settings, recvCtx: recvCtx, @@ -112,10 +112,11 @@ func newSubscriberInstance(ctx context.Context, factory wireSubscriberFactory, s receiver: receiver, } - // Note: ctx is not used to create the wire subscriber, because if it is - // cancelled, the subscriber will not be able to perform graceful shutdown - // (e.g. process acks and commit the final cursor offset). - wireSub, err := factory.New(subInstance.onMessage) + // Note: The context from Receive (recvCtx) should not be used, as when it is + // cancelled, the gRPC streams will be disconnected and the subscriber will + // not be able to process acks and commit the final cursor offset. Use the + // context from NewSubscriberClient (clientCtx) instead. + wireSub, err := factory.New(clientCtx, subInstance.onMessage) if err != nil { return nil, err } @@ -229,6 +230,7 @@ func (si *subscriberInstance) Wait(ctx context.Context) error { // See https://cloud.google.com/pubsub/lite/docs/subscribing for more // information about receiving messages. type SubscriberClient struct { + clientCtx context.Context settings ReceiveSettings wireSubFactory wireSubscriberFactory @@ -265,6 +267,7 @@ func NewSubscriberClientWithSettings(ctx context.Context, subscription string, s options: opts, } subClient := &SubscriberClient{ + clientCtx: ctx, settings: settings, wireSubFactory: factory, } @@ -303,7 +306,7 @@ func (s *SubscriberClient) Receive(ctx context.Context, f func(context.Context, defer s.setReceiveActive(false) // Initialize a subscriber instance. - subInstance, err := newSubscriberInstance(ctx, s.wireSubFactory, s.settings, f) + subInstance, err := newSubscriberInstance(ctx, s.clientCtx, s.wireSubFactory, s.settings, f) if err != nil { return err } diff --git a/pubsublite/pscompat/subscriber_test.go b/pubsublite/pscompat/subscriber_test.go index 429c9f55141..5c737f8ff42 100644 --- a/pubsublite/pscompat/subscriber_test.go +++ b/pubsublite/pscompat/subscriber_test.go @@ -113,7 +113,7 @@ func (ms *mockWireSubscriber) WaitStopped() error { type mockWireSubscriberFactory struct{} -func (f *mockWireSubscriberFactory) New(receiver wire.MessageReceiverFunc) (wire.Subscriber, error) { +func (f *mockWireSubscriberFactory) New(ctx context.Context, receiver wire.MessageReceiverFunc) (wire.Subscriber, error) { return &mockWireSubscriber{ receiver: receiver, msgsC: make(chan *wire.ReceivedMessage, 10), @@ -122,7 +122,7 @@ func (f *mockWireSubscriberFactory) New(receiver wire.MessageReceiverFunc) (wire } func newTestSubscriberInstance(ctx context.Context, settings ReceiveSettings, receiver messageReceiverFunc) *subscriberInstance { - sub, _ := newSubscriberInstance(ctx, new(mockWireSubscriberFactory), settings, receiver) + sub, _ := newSubscriberInstance(ctx, context.Background(), new(mockWireSubscriberFactory), settings, receiver) return sub }