From ff5f8c989cba2751dcc77745483ef3828e6df78c Mon Sep 17 00:00:00 2001 From: tmdiep Date: Thu, 27 May 2021 09:31:52 +1000 Subject: [PATCH] feat(pubsublite): detect stream reset signal (#4144) Modifies retryableStream to detect the Pub/Sub Lite RESET signal and forward the notification to the streamHandler implementation. --- pubsublite/internal/wire/rpc.go | 25 ++++++- pubsublite/internal/wire/rpc_test.go | 87 +++++++++++++++++++++++- pubsublite/internal/wire/streams.go | 35 +++++++--- pubsublite/internal/wire/streams_test.go | 81 ++++++++++++++++++++++ 4 files changed, 217 insertions(+), 11 deletions(-) diff --git a/pubsublite/internal/wire/rpc.go b/pubsublite/internal/wire/rpc.go index e5c6dfb9359..6519a4db225 100644 --- a/pubsublite/internal/wire/rpc.go +++ b/pubsublite/internal/wire/rpc.go @@ -22,6 +22,7 @@ import ( "google.golang.org/api/option" "google.golang.org/api/option/internaloption" + "google.golang.org/genproto/googleapis/rpc/errdetails" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/keepalive" @@ -130,7 +131,29 @@ func retryableReadOnlyCallOption() gax.CallOption { }) } -const pubsubLiteDefaultEndpoint = "-pubsublite.googleapis.com:443" +const ( + pubsubLiteDefaultEndpoint = "-pubsublite.googleapis.com:443" + pubsubLiteErrorDomain = "pubsublite.googleapis.com" + resetSignal = "RESET" +) + +// Pub/Sub Lite's RESET signal is a status containing error details that +// instructs streams to reset their state. 
+func isStreamResetSignal(err error) bool { + status, ok := status.FromError(err) + if !ok { + return false + } + if !isRetryableRecvCode(status.Code()) { + return false + } + for _, details := range status.Details() { + if errInfo, ok := details.(*errdetails.ErrorInfo); ok && errInfo.Reason == resetSignal && errInfo.Domain == pubsubLiteErrorDomain { + return true + } + } + return false +} func defaultClientOptions(region string) []option.ClientOption { return []option.ClientOption{ diff --git a/pubsublite/internal/wire/rpc_test.go b/pubsublite/internal/wire/rpc_test.go index 8fe0f5b88a7..850a946a76f 100644 --- a/pubsublite/internal/wire/rpc_test.go +++ b/pubsublite/internal/wire/rpc_test.go @@ -15,13 +15,98 @@ package wire import ( "encoding/base64" + "errors" + "log" "testing" "cloud.google.com/go/internal/testutil" - "github.com/golang/protobuf/proto" + "google.golang.org/genproto/googleapis/rpc/errdetails" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/structpb" + + spb "google.golang.org/genproto/googleapis/rpc/status" ) +func makeAny(msg proto.Message) *anypb.Any { + any, err := anypb.New(msg) + if err != nil { + log.Fatalf("Failed to make Any: %v", err) + } + return any +} + +func makeStreamResetSignal() error { + statuspb := &spb.Status{ + Code: int32(codes.Aborted), + Details: []*anypb.Any{makeAny(&errdetails.ErrorInfo{ + Reason: "RESET", Domain: "pubsublite.googleapis.com", + })}, + } + return status.ErrorProto(statuspb) +} + +func TestIsStreamResetSignal(t *testing.T) { + for _, tc := range []struct { + desc string + err error + want bool + }{ + { + desc: "reset signal", + err: makeStreamResetSignal(), + want: true, + }, + { + desc: "non-retryable code", + err: status.ErrorProto(&spb.Status{ + Code: int32(codes.FailedPrecondition), + Details: []*anypb.Any{makeAny(&errdetails.ErrorInfo{Reason: "RESET", Domain: 
"pubsublite.googleapis.com"})}, + }), + want: false, + }, + { + desc: "wrong domain", + err: status.ErrorProto(&spb.Status{ + Code: int32(codes.Aborted), + Details: []*anypb.Any{makeAny(&errdetails.ErrorInfo{Reason: "RESET"})}, + }), + want: false, + }, + { + desc: "wrong reason", + err: status.ErrorProto(&spb.Status{ + Code: int32(codes.Aborted), + Details: []*anypb.Any{makeAny(&errdetails.ErrorInfo{Domain: "pubsublite.googleapis.com"})}, + }), + want: false, + }, + { + desc: "missing details", + err: status.ErrorProto(&spb.Status{Code: int32(codes.Aborted)}), + want: false, + }, + { + desc: "nil error", + err: nil, + want: false, + }, + { + desc: "generic error", + err: errors.New(""), + want: false, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + if got := isStreamResetSignal(tc.err); got != tc.want { + t.Errorf("isStreamResetSignal() got: %v, want %v", got, tc.want) + } + }) + } +} + func TestPubsubMetadataAddClientInfo(t *testing.T) { for _, tc := range []struct { desc string diff --git a/pubsublite/internal/wire/streams.go b/pubsublite/internal/wire/streams.go index 752acb921c8..a8cc1edacba 100644 --- a/pubsublite/internal/wire/streams.go +++ b/pubsublite/internal/wire/streams.go @@ -35,11 +35,13 @@ type streamStatus int const ( streamUninitialized streamStatus = iota streamReconnecting + streamResetState streamConnected streamTerminated ) type initialResponseRequired bool +type notifyReset bool // streamHandler provides hooks for different Pub/Sub Lite streaming APIs // (e.g. publish, subscribe, streaming cursor, etc.) to use retryableStream. @@ -63,11 +65,15 @@ type streamHandler interface { validateInitialResponse(interface{}) error // onStreamStatusChange is used to notify stream handlers when the stream has - // changed state. A `streamReconnecting` status change is fired before - // attempting to connect a new stream. A `streamConnected` status change is - // fired when the stream is successfully connected. 
These are followed by - // onResponse() calls when responses are received from the server. These - // events are guaranteed to occur in this order. + // changed state. + // - A `streamReconnecting` status change is fired before attempting to + // connect a new stream. + // - A `streamResetState` status change may be fired if the stream should + // reset its state (due to receipt of the RESET signal from the server). + // - A `streamConnected` status change is fired when the stream is + // successfully connected. + // These are followed by onResponse() calls when responses are received from + // the server. These events are guaranteed to occur in this order. // // A final `streamTerminated` status change is fired when a permanent error // occurs. retryableStream.Error() returns the error that caused the stream to @@ -124,7 +130,7 @@ func (rs *retryableStream) Start() { defer rs.mu.Unlock() if rs.status == streamUninitialized { - go rs.connectStream() + go rs.connectStream(notifyReset(false)) } } @@ -151,7 +157,7 @@ func (rs *retryableStream) Send(request interface{}) (sent bool) { // stream. Nothing to do here. break case isRetryableSendError(err): - go rs.connectStream() + go rs.connectStream(notifyReset(false)) default: rs.unsafeTerminate(err) } @@ -206,7 +212,7 @@ func (rs *retryableStream) setCancel(cancel context.CancelFunc) { // terminated during reconnection. // // Intended to be called in a goroutine. It ends once the client stream closes. -func (rs *retryableStream) connectStream() { +func (rs *retryableStream) connectStream(notifyReset notifyReset) { canReconnect := func() bool { rs.mu.Lock() defer rs.mu.Unlock() @@ -225,7 +231,15 @@ func (rs *retryableStream) connectStream() { if !canReconnect() { return } + rs.handler.onStreamStatusChange(streamReconnecting) + if notifyReset { + rs.handler.onStreamStatusChange(streamResetState) + } + // Check whether handler terminated stream before reconnecting. 
+ if rs.Status() == streamTerminated { + return + } newStream, cancelFunc, err := rs.initNewStream() if err != nil { @@ -275,6 +289,9 @@ func (rs *retryableStream) initNewStream() (newStream grpc.ClientStream, cancelF if needsResponse { response := reflect.New(rs.responseType).Interface() if err = newStream.RecvMsg(response); err != nil { + if isStreamResetSignal(err) { + rs.handler.onStreamStatusChange(streamResetState) + } return r.RetryRecv(err) } if err = rs.handler.validateInitialResponse(response); err != nil { @@ -323,7 +340,7 @@ func (rs *retryableStream) listen(recvStream grpc.ClientStream) { } if err != nil { if isRetryableRecvError(err) { - go rs.connectStream() + go rs.connectStream(notifyReset(isStreamResetSignal(err))) } else { rs.terminate(err) } diff --git a/pubsublite/internal/wire/streams_test.go b/pubsublite/internal/wire/streams_test.go index 2079f8a8882..fb8e84e414e 100644 --- a/pubsublite/internal/wire/streams_test.go +++ b/pubsublite/internal/wire/streams_test.go @@ -358,3 +358,84 @@ func TestRetryableStreamSendReceive(t *testing.T) { t.Errorf("Stream final err: got (%v), want <nil>", gotErr) } } + +func TestRetryableStreamConnectReceivesResetSignal(t *testing.T) { + pub := newTestStreamHandler(t, defaultStreamTimeout) + + verifiers := test.NewVerifiers(t) + + stream1 := test.NewRPCVerifier(t) + // Reset signal received during stream initialization. 
+ stream1.Push(pub.InitialReq, nil, makeStreamResetSignal()) + verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream1) + + stream2 := test.NewRPCVerifier(t) + stream2.Push(pub.InitialReq, initPubResp(), nil) + verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream2) + + mockServer.OnTestStart(verifiers) + defer mockServer.OnTestEnd() + + pub.Stream.Start() + if got, want := pub.NextStatus(), streamReconnecting; got != want { + t.Errorf("Stream status change: got %d, want %d", got, want) + } + if got, want := pub.NextStatus(), streamResetState; got != want { + t.Errorf("Stream status change: got %d, want %d", got, want) + } + if got, want := pub.NextStatus(), streamConnected; got != want { + t.Errorf("Stream status change: got %d, want %d", got, want) + } + + pub.Stream.Stop() + if got, want := pub.NextStatus(), streamTerminated; got != want { + t.Errorf("Stream status change: got %d, want %d", got, want) + } + if gotErr := pub.Stream.Error(); gotErr != nil { + t.Errorf("Stream final err: got (%v), want <nil>", gotErr) + } +} + +func TestRetryableStreamDisconnectedWithResetSignal(t *testing.T) { + pub := newTestStreamHandler(t, defaultStreamTimeout) + + verifiers := test.NewVerifiers(t) + + stream1 := test.NewRPCVerifier(t) + stream1.Push(pub.InitialReq, initPubResp(), nil) + // Reset signal received after stream is connected. 
+ stream1.Push(nil, nil, makeStreamResetSignal()) + verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream1) + + stream2 := test.NewRPCVerifier(t) + stream2.Push(pub.InitialReq, initPubResp(), nil) + verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream2) + + mockServer.OnTestStart(verifiers) + defer mockServer.OnTestEnd() + + pub.Stream.Start() + if got, want := pub.NextStatus(), streamReconnecting; got != want { + t.Errorf("Stream status change: got %d, want %d", got, want) + } + if got, want := pub.NextStatus(), streamConnected; got != want { + t.Errorf("Stream status change: got %d, want %d", got, want) + } + if got, want := pub.NextStatus(), streamReconnecting; got != want { + t.Errorf("Stream status change: got %d, want %d", got, want) + } + if got, want := pub.NextStatus(), streamResetState; got != want { + t.Errorf("Stream status change: got %d, want %d", got, want) + } + if got, want := pub.NextStatus(), streamConnected; got != want { + t.Errorf("Stream status change: got %d, want %d", got, want) + } + + pub.Stream.Stop() + if got, want := pub.NextStatus(), streamTerminated; got != want { + t.Errorf("Stream status change: got %d, want %d", got, want) + } + if gotErr := pub.Stream.Error(); gotErr != nil { + t.Errorf("Stream final err: got (%v), want <nil>", gotErr) + } +}