Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(pubsublite): wire user context to api clients #4318

Merged
merged 2 commits into from Jun 25, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
60 changes: 50 additions & 10 deletions pubsublite/pscompat/integration_test.go
Expand Up @@ -30,7 +30,9 @@ import (
"cloud.google.com/go/pubsublite/internal/wire"
"github.com/google/go-cmp/cmp/cmpopts"
"golang.org/x/sync/errgroup"
"golang.org/x/xerrors"
"google.golang.org/api/option"
"google.golang.org/grpc/codes"

vkit "cloud.google.com/go/pubsublite/apiv1"
pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1"
Expand Down Expand Up @@ -167,7 +169,7 @@ func partitionNumbers(partitionCount int) []int {

func publishMessages(t *testing.T, settings PublishSettings, topic wire.TopicPath, msgs ...*pubsub.Message) {
ctx := context.Background()
publisher := publisherClient(ctx, t, settings, topic)
publisher := publisherClient(context.Background(), t, settings, topic)
defer publisher.Stop()

var pubResults []*pubsub.PublishResult
Expand All @@ -179,7 +181,7 @@ func publishMessages(t *testing.T, settings PublishSettings, topic wire.TopicPat

func publishPrefixedMessages(t *testing.T, settings PublishSettings, topic wire.TopicPath, msgPrefix string, msgCount, msgSize int) []string {
ctx := context.Background()
publisher := publisherClient(ctx, t, settings, topic)
publisher := publisherClient(context.Background(), t, settings, topic)
defer publisher.Stop()

orderingSender := test.NewOrderingSender()
Expand Down Expand Up @@ -271,7 +273,7 @@ func receiveAllMessages(t *testing.T, msgTracker *test.MsgTracker, settings Rece
}
}

subscriber := subscriberClient(cctx, t, settings, subscription)
subscriber := subscriberClient(context.Background(), t, settings, subscription)
if err := subscriber.Receive(cctx, messageReceiver); err != nil {
t.Errorf("Receive() got err: %v", err)
}
Expand All @@ -298,7 +300,7 @@ func receiveAndVerifyMessage(t *testing.T, want *pubsub.Message, settings Receiv
}
}

subscriber := subscriberClient(cctx, t, settings, subscription)
subscriber := subscriberClient(context.Background(), t, settings, subscription)
if err := subscriber.Receive(cctx, messageReceiver); err != nil {
t.Errorf("Receive() got err: %v", err)
}
Expand Down Expand Up @@ -383,7 +385,7 @@ func TestIntegration_PublishSubscribeSinglePartition(t *testing.T) {
}
got.Nack()
}
subscriber := subscriberClient(cctx, t, recvSettings, subscriptionPath)
subscriber := subscriberClient(context.Background(), t, recvSettings, subscriptionPath)
if gotErr := subscriber.Receive(cctx, messageReceiver1); !test.ErrorEqual(gotErr, errNackCalled) {
t.Errorf("Receive() got err: (%v), want err: (%v)", gotErr, errNackCalled)
}
Expand All @@ -400,7 +402,7 @@ func TestIntegration_PublishSubscribeSinglePartition(t *testing.T) {
}
return fmt.Errorf("Received unexpected message: %q", truncateMsg(string(msg.Data)))
}
subscriber = subscriberClient(cctx, t, customSettings, subscriptionPath)
subscriber = subscriberClient(context.Background(), t, customSettings, subscriptionPath)

messageReceiver2 := func(ctx context.Context, got *pubsub.Message) {
got.Nack()
Expand Down Expand Up @@ -434,7 +436,7 @@ func TestIntegration_PublishSubscribeSinglePartition(t *testing.T) {
got.Ack()
stopSubscriber()
}
subscriber := subscriberClient(cctx, t, recvSettings, subscriptionPath)
subscriber := subscriberClient(context.Background(), t, recvSettings, subscriptionPath)

// The message receiver stops the subscriber after receiving the first
// message. However, the subscriber isn't guaranteed to immediately stop, so
Expand Down Expand Up @@ -485,7 +487,7 @@ func TestIntegration_PublishSubscribeSinglePartition(t *testing.T) {
// next test, which would receive an incorrect message.
got.Ack()
}
subscriber := subscriberClient(cctx, t, recvSettings, subscriptionPath)
subscriber := subscriberClient(context.Background(), t, recvSettings, subscriptionPath)

if err := subscriber.Receive(cctx, messageReceiver); err != nil {
t.Errorf("Receive() got err: %v", err)
Expand Down Expand Up @@ -539,6 +541,44 @@ func TestIntegration_PublishSubscribeSinglePartition(t *testing.T) {
receiveAllMessages(t, msgTracker, recvSettings, subscriptionPath)
})

// Verifies that cancelling the context passed to NewPublisherClient can shut
// down the publisher.
t.Run("CancelPublisherContext", func(t *testing.T) {
cctx, cancel := context.WithCancel(context.Background())
publisher := publisherClient(cctx, t, DefaultPublishSettings, topicPath)

cancel()

wantCode := codes.Canceled
result := publisher.Publish(ctx, &pubsub.Message{Data: []byte("cancel_publisher_context")})
if _, err := result.Get(ctx); !test.ErrorHasCode(err, wantCode) {
t.Errorf("Publish() got err: %v, want code: %v", err, wantCode)
}
if err := xerrors.Unwrap(publisher.Error()); !test.ErrorHasCode(err, wantCode) {
t.Errorf("Error() got err: %v, want code: %v", err, wantCode)
}
publisher.Stop()
})

// Verifies that cancelling the context passed to NewSubscriberClient can shut
// down the subscriber.
t.Run("CancelSubscriberContext", func(t *testing.T) {
msg := &pubsub.Message{Data: []byte("cancel_subscriber_context")}
publishMessages(t, DefaultPublishSettings, topicPath, msg)

cctx, cancel := context.WithCancel(context.Background())
subscriber := subscriberClient(cctx, t, recvSettings, subscriptionPath)

subsErr := subscriber.Receive(context.Background(), func(ctx context.Context, got *pubsub.Message) {
got.Ack()
cancel()
})

if err, wantCode := xerrors.Unwrap(subsErr), codes.Canceled; !test.ErrorHasCode(err, wantCode) {
t.Errorf("Receive() got err: %v, want code: %v", err, wantCode)
}
})

// NOTE: This should be the last test case.
// Verifies that increasing the number of topic partitions is handled
// correctly by publishers.
Expand All @@ -547,7 +587,7 @@ func TestIntegration_PublishSubscribeSinglePartition(t *testing.T) {
const pollPeriod = 5 * time.Second
pubSettings := DefaultPublishSettings
pubSettings.configPollPeriod = pollPeriod // Poll updates more frequently
publisher := publisherClient(ctx, t, pubSettings, topicPath)
publisher := publisherClient(context.Background(), t, pubSettings, topicPath)
defer publisher.Stop()

// Update the number of partitions.
Expand Down Expand Up @@ -661,7 +701,7 @@ func TestIntegration_PublishSubscribeMultiPartition(t *testing.T) {
for i := 0; i < subscriberCount; i++ {
// Subscribers must be started in a goroutine as Receive() blocks.
g.Go(func() error {
subscriber := subscriberClient(cctx, t, DefaultReceiveSettings, subscriptionPath)
subscriber := subscriberClient(context.Background(), t, DefaultReceiveSettings, subscriptionPath)
err := subscriber.Receive(cctx, messageReceiver)
if err != nil {
t.Errorf("Receive() got err: %v", err)
Expand Down
5 changes: 1 addition & 4 deletions pubsublite/pscompat/publisher.go
Expand Up @@ -82,10 +82,7 @@ func NewPublisherClientWithSettings(ctx context.Context, topic string, settings
return nil, err
}

// Note: ctx is not used to create the wire publisher, because if it is
// cancelled, the publisher will not be able to perform graceful shutdown
// (e.g. flush pending messages).
wirePub, err := wire.NewPublisher(context.Background(), settings.toWireSettings(), region, topic, opts...)
wirePub, err := wire.NewPublisher(ctx, settings.toWireSettings(), region, topic, opts...)
if err != nil {
return nil, err
}
Expand Down
23 changes: 13 additions & 10 deletions pubsublite/pscompat/subscriber.go
Expand Up @@ -72,7 +72,7 @@ func (ah *pslAckHandler) OnNack() {
// wireSubscriberFactory is a factory for creating wire subscribers, which can
// be overridden with a mock in unit tests.
type wireSubscriberFactory interface {
New(wire.MessageReceiverFunc) (wire.Subscriber, error)
New(context.Context, wire.MessageReceiverFunc) (wire.Subscriber, error)
}

type wireSubscriberFactoryImpl struct {
Expand All @@ -82,8 +82,8 @@ type wireSubscriberFactoryImpl struct {
options []option.ClientOption
}

func (f *wireSubscriberFactoryImpl) New(receiver wire.MessageReceiverFunc) (wire.Subscriber, error) {
return wire.NewSubscriber(context.Background(), f.settings, receiver, f.region, f.subscription.String(), f.options...)
func (f *wireSubscriberFactoryImpl) New(ctx context.Context, receiver wire.MessageReceiverFunc) (wire.Subscriber, error) {
return wire.NewSubscriber(ctx, f.settings, receiver, f.region, f.subscription.String(), f.options...)
}

type messageReceiverFunc = func(context.Context, *pubsub.Message)
Expand All @@ -103,19 +103,20 @@ type subscriberInstance struct {
err error
}

func newSubscriberInstance(ctx context.Context, factory wireSubscriberFactory, settings ReceiveSettings, receiver messageReceiverFunc) (*subscriberInstance, error) {
recvCtx, recvCancel := context.WithCancel(ctx)
func newSubscriberInstance(recvCtx, clientCtx context.Context, factory wireSubscriberFactory, settings ReceiveSettings, receiver messageReceiverFunc) (*subscriberInstance, error) {
recvCtx, recvCancel := context.WithCancel(recvCtx)
subInstance := &subscriberInstance{
settings: settings,
recvCtx: recvCtx,
recvCancel: recvCancel,
receiver: receiver,
}

// Note: ctx is not used to create the wire subscriber, because if it is
// cancelled, the subscriber will not be able to perform graceful shutdown
// (e.g. process acks and commit the final cursor offset).
wireSub, err := factory.New(subInstance.onMessage)
// Note: The context from Receive (recvCtx) should not be used, as when it is
// cancelled, the gRPC streams will be disconnected and the subscriber will
// not be able to process acks and commit the final cursor offset. Use the
// context from NewSubscriberClient (clientCtx) instead.
wireSub, err := factory.New(clientCtx, subInstance.onMessage)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -229,6 +230,7 @@ func (si *subscriberInstance) Wait(ctx context.Context) error {
// See https://cloud.google.com/pubsub/lite/docs/subscribing for more
// information about receiving messages.
type SubscriberClient struct {
clientCtx context.Context
settings ReceiveSettings
wireSubFactory wireSubscriberFactory

Expand Down Expand Up @@ -265,6 +267,7 @@ func NewSubscriberClientWithSettings(ctx context.Context, subscription string, s
options: opts,
}
subClient := &SubscriberClient{
clientCtx: ctx,
settings: settings,
wireSubFactory: factory,
}
Expand Down Expand Up @@ -303,7 +306,7 @@ func (s *SubscriberClient) Receive(ctx context.Context, f func(context.Context,
defer s.setReceiveActive(false)

// Initialize a subscriber instance.
subInstance, err := newSubscriberInstance(ctx, s.wireSubFactory, s.settings, f)
subInstance, err := newSubscriberInstance(ctx, s.clientCtx, s.wireSubFactory, s.settings, f)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions pubsublite/pscompat/subscriber_test.go
Expand Up @@ -113,7 +113,7 @@ func (ms *mockWireSubscriber) WaitStopped() error {

type mockWireSubscriberFactory struct{}

func (f *mockWireSubscriberFactory) New(receiver wire.MessageReceiverFunc) (wire.Subscriber, error) {
func (f *mockWireSubscriberFactory) New(ctx context.Context, receiver wire.MessageReceiverFunc) (wire.Subscriber, error) {
return &mockWireSubscriber{
receiver: receiver,
msgsC: make(chan *wire.ReceivedMessage, 10),
Expand All @@ -122,7 +122,7 @@ func (f *mockWireSubscriberFactory) New(receiver wire.MessageReceiverFunc) (wire
}

func newTestSubscriberInstance(ctx context.Context, settings ReceiveSettings, receiver messageReceiverFunc) *subscriberInstance {
sub, _ := newSubscriberInstance(ctx, new(mockWireSubscriberFactory), settings, receiver)
sub, _ := newSubscriberInstance(ctx, context.Background(), new(mockWireSubscriberFactory), settings, receiver)
return sub
}

Expand Down