From e75262cf5eba845271965eab3c28c0a23bec14c4 Mon Sep 17 00:00:00 2001
From: tmdiep
Date: Tue, 29 Jun 2021 06:46:12 +1000
Subject: [PATCH] fix(pubsublite): ensure timeout settings are respected (#4329)

Fixes for retryableStream and partitionCountWatcher to ensure
PublisherSettings.Timeout and ReceiveSettings.Timeout are respected.
---
 pubsublite/internal/wire/partition_count.go   | 40 +++++++---
 .../internal/wire/partition_count_test.go     | 56 ++++++++++++-
 pubsublite/internal/wire/publisher_test.go    |  5 +-
 pubsublite/internal/wire/request_timer.go     | 78 +++++++++++++++++++
 .../internal/wire/request_timer_test.go       | 61 +++++++++++++++
 pubsublite/internal/wire/streams.go           | 65 +++++++++++-----
 pubsublite/internal/wire/streams_test.go      | 41 ++++++++++
 7 files changed, 311 insertions(+), 35 deletions(-)
 create mode 100644 pubsublite/internal/wire/request_timer.go
 create mode 100644 pubsublite/internal/wire/request_timer_test.go
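For context on the timeout settings named in the commit message: PublishSettings.Timeout (and ReceiveSettings.Timeout for subscribers) is the user-facing knob that, with this patch, also bounds the initial partition count lookup and each stream initialization attempt. A rough illustration of where that value is set, using the public pscompat API; the constructor shown here is an assumption of this sketch, not something that appears in the diffs below:

package main

import (
    "context"
    "log"
    "time"

    "cloud.google.com/go/pubsublite/pscompat"
)

func main() {
    ctx := context.Background()
    const topic = "projects/my-project/locations/us-central1-a/topics/my-topic"

    // Timeout bounds how long the client keeps trying to establish
    // connections; with this fix it also bounds the first topic partition
    // lookup and each stream initialization attempt.
    settings := pscompat.PublishSettings{
        Timeout: 2 * time.Minute,
    }
    publisher, err := pscompat.NewPublisherClientWithSettings(ctx, topic, settings)
    if err != nil {
        log.Fatalf("NewPublisherClientWithSettings: %v", err)
    }
    defer publisher.Stop()
}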
diff --git a/pubsublite/internal/wire/partition_count.go b/pubsublite/internal/wire/partition_count.go
index c4319a536f0..4255a647443 100644
--- a/pubsublite/internal/wire/partition_count.go
+++ b/pubsublite/internal/wire/partition_count.go
@@ -16,6 +16,9 @@ package wire
 import (
     "context"
     "fmt"
+    "time"
+
+    "golang.org/x/xerrors"

     vkit "cloud.google.com/go/pubsublite/apiv1"
     gax "github.com/googleapis/gax-go/v2"
@@ -30,11 +33,13 @@ type partitionCountReceiver func(partitionCount int)
 // topic and notifies a receiver if it changes.
 type partitionCountWatcher struct {
     // Immutable after creation.
-    ctx         context.Context
-    adminClient *vkit.AdminClient
-    topicPath   string
-    receiver    partitionCountReceiver
-    callOption  gax.CallOption
+    ctx            context.Context
+    adminClient    *vkit.AdminClient
+    topicPath      string
+    receiver       partitionCountReceiver
+    callOption     gax.CallOption
+    initialTimeout time.Duration
+    pollPeriod     time.Duration

     // Fields below must be guarded with mu.
     partitionCount int
@@ -47,11 +52,13 @@ func newPartitionCountWatcher(ctx context.Context, adminClient *vkit.AdminClient
     settings PublishSettings, topicPath string, receiver partitionCountReceiver) *partitionCountWatcher {

     p := &partitionCountWatcher{
-        ctx:         ctx,
-        adminClient: adminClient,
-        topicPath:   topicPath,
-        receiver:    receiver,
-        callOption:  resourceExhaustedRetryer(),
+        ctx:            ctx,
+        adminClient:    adminClient,
+        topicPath:      topicPath,
+        receiver:       receiver,
+        callOption:     resourceExhaustedRetryer(),
+        initialTimeout: settings.Timeout,
+        pollPeriod:     settings.ConfigPollPeriod,
     }

     // Polling the topic partition count can be disabled in settings if the period
@@ -88,8 +95,17 @@ func (p *partitionCountWatcher) updatePartitionCount() {
     p.mu.Unlock()

     newPartitionCount, err := func() (int, error) {
+        // Ensure the first update respects PublishSettings.Timeout.
+        timeout := p.initialTimeout
+        if prevPartitionCount > 0 {
+            timeout = p.pollPeriod
+        }
+        cctx, cancel := context.WithCancel(p.ctx)
+        rt := newRequestTimer(timeout, cancel, ErrBackendUnavailable)
+
         req := &pb.GetTopicPartitionsRequest{Name: p.topicPath}
-        resp, err := p.adminClient.GetTopicPartitions(p.ctx, req, p.callOption)
+        resp, err := p.adminClient.GetTopicPartitions(cctx, req, p.callOption)
+        rt.Stop()

         p.mu.Lock()
         defer p.mu.Unlock()
@@ -105,7 +121,7 @@ func (p *partitionCountWatcher) updatePartitionCount() {
                 // TODO: Log the error.
                 return p.partitionCount, nil
             }
-            err = fmt.Errorf("pubsublite: failed to update topic partition count: %v", err)
+            err = xerrors.Errorf("pubsublite: failed to update topic partition count: %w", rt.ResolveError(err))
             p.unsafeInitiateShutdown(err)
             return 0, err
         }
diff --git a/pubsublite/internal/wire/partition_count_test.go b/pubsublite/internal/wire/partition_count_test.go
index aef4387e0b9..b203425eecf 100644
--- a/pubsublite/internal/wire/partition_count_test.go
+++ b/pubsublite/internal/wire/partition_count_test.go
@@ -16,6 +16,7 @@ package wire
 import (
     "context"
     "testing"
+    "time"

     "cloud.google.com/go/internal/testutil"
     "cloud.google.com/go/pubsublite/internal/test"
@@ -54,7 +55,7 @@ func newTestPartitionCountWatcher(t *testing.T, topicPath string, settings Publi
     tw := &testPartitionCountWatcher{
         t: t,
     }
-    tw.watcher = newPartitionCountWatcher(ctx, adminClient, testPublishSettings(), topicPath, tw.onCountChanged)
+    tw.watcher = newPartitionCountWatcher(ctx, adminClient, settings, topicPath, tw.onCountChanged)
     tw.initAndStart(t, tw.watcher, "PartitionCountWatcher", adminClient)
     return tw
 }
@@ -95,6 +96,59 @@ func TestPartitionCountWatcherZeroPartitionCountFails(t *testing.T) {
     watcher.VerifyCounts(nil)
 }

+func TestPartitionCountWatcherInitialRequestTimesOut(t *testing.T) {
+    const topic = "projects/123456/locations/us-central1-b/topics/my-topic"
+
+    verifiers := test.NewVerifiers(t)
+    barrier := verifiers.GlobalVerifier.PushWithBarrier(topicPartitionsReq(topic), topicPartitionsResp(1), nil)
+
+    mockServer.OnTestStart(verifiers)
+    defer mockServer.OnTestEnd()
+
+    settings := testPublishSettings()
+    settings.Timeout = 20 * time.Millisecond // Set low timeout for initial request
+    watcher := newTestPartitionCountWatcher(t, topic, settings)
+
+    if gotErr, wantErr := watcher.StartError(), ErrBackendUnavailable; !test.ErrorEqual(gotErr, wantErr) {
+        t.Errorf("Start() got err: (%v), want err: (%v)", gotErr, wantErr)
+    }
+    barrier.Release()
+    watcher.VerifyCounts(nil)
+}
+
+func TestPartitionCountWatcherUpdateLongerTimeouts(t *testing.T) {
+    const topic = "projects/123456/locations/us-central1-b/topics/my-topic"
+    wantPartitionCount1 := 1
+    wantPartitionCount2 := 2
+
+    verifiers := test.NewVerifiers(t)
+    verifiers.GlobalVerifier.Push(topicPartitionsReq(topic), topicPartitionsResp(wantPartitionCount1), nil)
+    // Barrier used to delay response.
+    barrier := verifiers.GlobalVerifier.PushWithBarrier(topicPartitionsReq(topic), topicPartitionsResp(wantPartitionCount2), nil)
+
+    mockServer.OnTestStart(verifiers)
+    defer mockServer.OnTestEnd()
+
+    watcher := newTestPartitionCountWatcher(t, topic, testPublishSettings())
+    if gotErr := watcher.StartError(); gotErr != nil {
+        t.Errorf("Start() got err: (%v)", gotErr)
+    }
+    watcher.VerifyCounts([]int{wantPartitionCount1})
+
+    // Override the initial timeout after the first request to verify that it
+    // isn't used. If set at creation, the first request will fail.
+    const timeout = time.Millisecond
+    watcher.watcher.initialTimeout = timeout
+    go func() {
+        barrier.ReleaseAfter(func() {
+            time.Sleep(5 * timeout)
+        })
+    }()
+    watcher.UpdatePartitionCount()
+    watcher.VerifyCounts([]int{wantPartitionCount1, wantPartitionCount2})
+    watcher.StopVerifyNoError()
+}
+
 func TestPartitionCountWatcherPartitionCountUnchanged(t *testing.T) {
     const topic = "projects/123456/locations/us-central1-b/topics/my-topic"
     wantPartitionCount1 := 4
diff --git a/pubsublite/internal/wire/publisher_test.go b/pubsublite/internal/wire/publisher_test.go
index b407c651dc4..47eda71f5be 100644
--- a/pubsublite/internal/wire/publisher_test.go
+++ b/pubsublite/internal/wire/publisher_test.go
@@ -34,8 +34,9 @@ func testPublishSettings() PublishSettings {
     // Send messages with minimal delay to speed up tests.
     settings.DelayThreshold = time.Millisecond
     settings.Timeout = 5 * time.Second
-    // Disable topic partition count background polling.
-    settings.ConfigPollPeriod = 0
+    // Set long poll period to prevent background update, but still have non-zero
+    // request timeout.
+    settings.ConfigPollPeriod = 1 * time.Minute
     return settings
 }

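The new request_timer.go below captures a pattern used in both partition_count.go above and streams.go further down: start a timer that cancels the request's context if it fires first, then report a clearer sentinel error instead of the generic cancellation error. A minimal, self-contained sketch of that pattern (boundedCall, slowRPC and errTimedOut are illustrative stand-ins, not code from this patch):

package main

import (
    "context"
    "errors"
    "fmt"
    "sync/atomic"
    "time"
)

// errTimedOut plays the role of the sentinels used in this patch
// (ErrBackendUnavailable, errStreamInitTimeout).
var errTimedOut = errors.New("request timed out")

// boundedCall mirrors what requestTimer encapsulates: a timer cancels the
// request context if it fires first, and the caller then surfaces a clear
// sentinel instead of the generic context.Canceled error.
func boundedCall(ctx context.Context, timeout time.Duration, call func(context.Context) error) error {
    cctx, cancel := context.WithCancel(ctx)
    defer cancel()

    var fired int32
    timer := time.AfterFunc(timeout, func() {
        atomic.StoreInt32(&fired, 1)
        cancel()
    })
    err := call(cctx)
    timer.Stop()

    if atomic.LoadInt32(&fired) == 1 {
        return errTimedOut // roughly what requestTimer.ResolveError does
    }
    return err
}

func main() {
    slowRPC := func(ctx context.Context) error {
        select {
        case <-time.After(time.Second): // simulate a hung backend
            return nil
        case <-ctx.Done():
            return ctx.Err()
        }
    }
    // Prints "request timed out" rather than "context canceled".
    fmt.Println(boundedCall(context.Background(), 50*time.Millisecond, slowRPC))
}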
diff --git a/pubsublite/internal/wire/request_timer.go b/pubsublite/internal/wire/request_timer.go
new file mode 100644
index 00000000000..ee3a1565e54
--- /dev/null
+++ b/pubsublite/internal/wire/request_timer.go
@@ -0,0 +1,78 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package wire
+
+import (
+    "sync"
+    "time"
+)
+
+type requestTimerStatus int
+
+const (
+    requestTimerNew requestTimerStatus = iota
+    requestTimerStopped
+    requestTimerTriggered
+)
+
+// requestTimer bounds the duration of a request and executes `onTimeout` if
+// the timer is triggered.
+type requestTimer struct {
+    onTimeout  func()
+    timeoutErr error
+    timer      *time.Timer
+    mu         sync.Mutex
+    status     requestTimerStatus
+}
+
+func newRequestTimer(duration time.Duration, onTimeout func(), timeoutErr error) *requestTimer {
+    rt := &requestTimer{
+        onTimeout:  onTimeout,
+        timeoutErr: timeoutErr,
+        status:     requestTimerNew,
+    }
+    rt.timer = time.AfterFunc(duration, rt.onTriggered)
+    return rt
+}
+
+func (rt *requestTimer) onTriggered() {
+    rt.mu.Lock()
+    defer rt.mu.Unlock()
+    if rt.status == requestTimerNew {
+        rt.status = requestTimerTriggered
+        rt.onTimeout()
+    }
+}
+
+// Stop should be called upon a successful request to prevent the timer from
+// expiring.
+func (rt *requestTimer) Stop() {
+    rt.mu.Lock()
+    defer rt.mu.Unlock()
+    if rt.status == requestTimerNew {
+        rt.status = requestTimerStopped
+        rt.timer.Stop()
+    }
+}
+
+// ResolveError returns `timeoutErr` if the timer triggered, or otherwise
+// `originalErr`.
+func (rt *requestTimer) ResolveError(originalErr error) error {
+    rt.mu.Lock()
+    defer rt.mu.Unlock()
+    if rt.status == requestTimerTriggered {
+        return rt.timeoutErr
+    }
+    return originalErr
+}
diff --git a/pubsublite/internal/wire/request_timer_test.go b/pubsublite/internal/wire/request_timer_test.go
new file mode 100644
index 00000000000..86c96c2b9dd
--- /dev/null
+++ b/pubsublite/internal/wire/request_timer_test.go
@@ -0,0 +1,61 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package wire
+
+import (
+    "errors"
+    "testing"
+    "time"
+
+    "cloud.google.com/go/pubsublite/internal/test"
+)
+
+func TestRequestTimerStop(t *testing.T) {
+    const timeout = 5 * time.Millisecond
+    onTimeout := func() {
+        t.Error("onTimeout should not be called")
+    }
+
+    rt := newRequestTimer(timeout, onTimeout, errors.New("unused"))
+    rt.Stop()
+    time.Sleep(2 * timeout)
+
+    if err := rt.ResolveError(nil); err != nil {
+        t.Errorf("ResolveError() got gotErr: %v", err)
+    }
+    wantErr := errors.New("original error")
+    if gotErr := rt.ResolveError(wantErr); !test.ErrorEqual(gotErr, wantErr) {
+        t.Errorf("ResolveError() got err: %v, want err: %v", gotErr, wantErr)
+    }
+}
+
+func TestRequestTimerExpires(t *testing.T) {
+    const timeout = 5 * time.Millisecond
+    timeoutErr := errors.New("on timeout")
+
+    expired := test.NewCondition("request timer expired")
+    onTimeout := func() {
+        expired.SetDone()
+    }
+
+    rt := newRequestTimer(timeout, onTimeout, timeoutErr)
+    expired.WaitUntilDone(t, serviceTestWaitTimeout)
+
+    if gotErr := rt.ResolveError(nil); !test.ErrorEqual(gotErr, timeoutErr) {
+        t.Errorf("ResolveError() got err: %v, want err: %v", gotErr, timeoutErr)
+    }
+    if gotErr := rt.ResolveError(errors.New("ignored")); !test.ErrorEqual(gotErr, timeoutErr) {
+        t.Errorf("ResolveError() got err: %v, want err: %v", gotErr, timeoutErr)
+    }
+}
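One detail of the streams.go change below: errStreamInitTimeout is constructed with codes.DeadlineExceeded, so when an initialization attempt is cut short the error flows into the existing RetryRecv/RetrySend path like other transient gRPC errors and the stream is reconnected (see TestRetryableStreamInitTimeout further down). A small, self-contained sketch of that idea; isTransient is an illustrative predicate, not the package's actual retryer:

package main

import (
    "fmt"

    "google.golang.org/grpc/codes"
    "google.golang.org/grpc/status"
)

// isTransient stands in for a retry predicate that inspects only the gRPC
// status code of an error.
func isTransient(err error) bool {
    switch status.Code(err) {
    case codes.DeadlineExceeded, codes.Unavailable, codes.Aborted:
        return true
    default:
        return false
    }
}

func main() {
    initTimeout := status.Error(codes.DeadlineExceeded, "pubsublite: stream initialization timed out")
    fmt.Println(isTransient(initTimeout)) // true: the attempt is retried rather than terminating the stream
}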
diff --git a/pubsublite/internal/wire/streams.go b/pubsublite/internal/wire/streams.go
index a8cc1edacba..76624178c07 100644
--- a/pubsublite/internal/wire/streams.go
+++ b/pubsublite/internal/wire/streams.go
@@ -22,6 +22,8 @@ import (
     "golang.org/x/xerrors"
     "google.golang.org/grpc"
+    "google.golang.org/grpc/codes"
+    "google.golang.org/grpc/status"

     gax "github.com/googleapis/gax-go/v2"
 )

@@ -40,6 +42,11 @@ const (
     streamTerminated
 )

+// Abort a stream initialization attempt after this duration to mitigate delays.
+const defaultInitTimeout = 2 * time.Minute
+
+var errStreamInitTimeout = status.Error(codes.DeadlineExceeded, "pubsublite: stream initialization timed out")
+
 type initialResponseRequired bool
 type notifyReset bool

@@ -95,10 +102,11 @@ type streamHandler interface {
 // are private implementation.
 type retryableStream struct {
     // Immutable after creation.
-    ctx          context.Context
-    handler      streamHandler
-    responseType reflect.Type
-    timeout      time.Duration
+    ctx            context.Context
+    handler        streamHandler
+    responseType   reflect.Type
+    connectTimeout time.Duration
+    initTimeout    time.Duration

     // Guards access to fields below.
     mu sync.Mutex
@@ -115,11 +123,16 @@ type retryableStream struct {
 // maximum duration for reconnection. `responseType` is the type of the response
 // proto received on the stream.
 func newRetryableStream(ctx context.Context, handler streamHandler, timeout time.Duration, responseType reflect.Type) *retryableStream {
+    initTimeout := defaultInitTimeout
+    if timeout < defaultInitTimeout {
+        initTimeout = timeout
+    }
     return &retryableStream{
-        ctx:          ctx,
-        handler:      handler,
-        responseType: responseType,
-        timeout:      timeout,
+        ctx:            ctx,
+        handler:        handler,
+        responseType:   responseType,
+        connectTimeout: timeout,
+        initTimeout:    initTimeout,
     }
 }

@@ -198,12 +211,14 @@ func (rs *retryableStream) unsafeClearStream() {
     }
 }

-func (rs *retryableStream) setCancel(cancel context.CancelFunc) {
+func (rs *retryableStream) newStreamContext() (ctx context.Context, cancel context.CancelFunc) {
     rs.mu.Lock()
     defer rs.mu.Unlock()

     rs.unsafeClearStream()
+    ctx, cancel = context.WithCancel(rs.ctx)
     rs.cancelStream = cancel
+    return
 }

 // connectStream attempts to establish a valid connection with the server. Due
@@ -241,7 +256,7 @@ func (rs *retryableStream) connectStream(notifyReset notifyReset) {
         return
     }
-    newStream, cancelFunc, err := rs.initNewStream()
+    newStream, err := rs.initNewStream()
     if err != nil {
         rs.terminate(err)
         return
     }
@@ -257,7 +272,6 @@ func (rs *retryableStream) connectStream(notifyReset notifyReset) {
         }
         rs.status = streamConnected
         rs.stream = newStream
-        rs.cancelStream = cancelFunc
         return true
     }
     if !connected() {
@@ -268,27 +282,32 @@ func (rs *retryableStream) connectStream(notifyReset notifyReset) {

     rs.listen(newStream)
 }

-func (rs *retryableStream) initNewStream() (newStream grpc.ClientStream, cancelFunc context.CancelFunc, err error) {
-    r := newStreamRetryer(rs.timeout)
+func (rs *retryableStream) newInitTimer(cancelFunc func()) *requestTimer {
+    return newRequestTimer(rs.initTimeout, cancelFunc, errStreamInitTimeout)
+}
+
+func (rs *retryableStream) initNewStream() (newStream grpc.ClientStream, err error) {
+    var cancelFunc context.CancelFunc
+    r := newStreamRetryer(rs.connectTimeout)
     for {
         backoff, shouldRetry := func() (time.Duration, bool) {
             var cctx context.Context
-            cctx, cancelFunc = context.WithCancel(rs.ctx)
-            // Store the cancel func to quickly cancel reconnecting if the stream is
-            // terminated.
-            rs.setCancel(cancelFunc)
+            cctx, cancelFunc = rs.newStreamContext()
+            // Bound the duration of the stream initialization attempt.
+            it := rs.newInitTimer(cancelFunc)
+            defer it.Stop()
             newStream, err = rs.handler.newStream(cctx)
-            if err != nil {
+            if err = it.ResolveError(err); err != nil {
                 return r.RetryRecv(err)
             }
             initReq, needsResponse := rs.handler.initialRequest()
-            if err = newStream.SendMsg(initReq); err != nil {
+            if err = it.ResolveError(newStream.SendMsg(initReq)); err != nil {
                 return r.RetrySend(err)
             }
             if needsResponse {
                 response := reflect.New(rs.responseType).Interface()
-                if err = newStream.RecvMsg(response); err != nil {
+                if err = it.ResolveError(newStream.RecvMsg(response)); err != nil {
                     if isStreamResetSignal(err) {
                         rs.handler.onStreamStatusChange(streamResetState)
                     }
@@ -301,6 +320,12 @@ func (rs *retryableStream) initNewStream() (newStream grpc.ClientStream, cancelF
             }
         }

+            // If the init timer fired due to a race, the stream would be unusable.
+            it.Stop()
+            if err = it.ResolveError(nil); err != nil {
+                return r.RetryRecv(err)
+            }
+
             // We have a valid connection and should break from the outer loop.
             return 0, false
         }()
diff --git a/pubsublite/internal/wire/streams_test.go b/pubsublite/internal/wire/streams_test.go
index fb8e84e414e..368f39dacaf 100644
--- a/pubsublite/internal/wire/streams_test.go
+++ b/pubsublite/internal/wire/streams_test.go
@@ -284,6 +284,7 @@ func TestRetryableStreamConnectTimeout(t *testing.T) {
     // Set a very low timeout to ensure no retries.
     timeout := time.Millisecond
     pub := newTestStreamHandler(t, timeout)
+    pub.Stream.initTimeout = defaultInitTimeout
     wantErr := status.Error(codes.DeadlineExceeded, "timeout")

     verifiers := test.NewVerifiers(t)
@@ -314,6 +315,46 @@ func TestRetryableStreamConnectTimeout(t *testing.T) {
     }
 }

+func TestRetryableStreamInitTimeout(t *testing.T) {
+    const streamInitTimeout = 50 * time.Millisecond
+    const streamResponseDelay = 75 * time.Millisecond
+
+    pub := newTestStreamHandler(t, defaultStreamTimeout)
+    pub.Stream.initTimeout = streamInitTimeout
+
+    verifiers := test.NewVerifiers(t)
+
+    // First stream will have a delayed response.
+    stream1 := test.NewRPCVerifier(t)
+    barrier := stream1.PushWithBarrier(pub.InitialReq, initPubResp(), nil)
+    verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream1)
+
+    // Second stream should succeed.
+    stream2 := test.NewRPCVerifier(t)
+    stream2.Push(pub.InitialReq, initPubResp(), nil)
+    verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream2)
+
+    mockServer.OnTestStart(verifiers)
+    defer mockServer.OnTestEnd()
+
+    pub.Stream.Start()
+    if got, want := pub.NextStatus(), streamReconnecting; got != want {
+        t.Errorf("Stream status change: got %d, want %d", got, want)
+    }
+
+    barrier.ReleaseAfter(func() {
+        time.Sleep(streamResponseDelay)
+    })
+    if got, want := pub.NextStatus(), streamConnected; got != want {
+        t.Errorf("Stream status change: got %d, want %d", got, want)
+    }
+
+    pub.Stream.Stop()
+    if got, want := pub.NextStatus(), streamTerminated; got != want {
+        t.Errorf("Stream status change: got %d, want %d", got, want)
+    }
+}
+
 func TestRetryableStreamSendReceive(t *testing.T) {
     pub := newTestStreamHandler(t, defaultStreamTimeout)
     req := msgPubReq(&pb.PubSubMessage{Data: []byte("msg")})