Skip to content

Commit

Permalink
feat(pubsublite): detect stream reset signal (#4144)
Browse files Browse the repository at this point in the history
Modifies retryableStream to detect the Pub/Sub Lite RESET signal and forward the notification to the streamHandler implementation.
  • Loading branch information
tmdiep committed May 26, 2021
1 parent 66c1176 commit ff5f8c9
Show file tree
Hide file tree
Showing 4 changed files with 217 additions and 11 deletions.
25 changes: 24 additions & 1 deletion pubsublite/internal/wire/rpc.go
Expand Up @@ -22,6 +22,7 @@ import (

"google.golang.org/api/option"
"google.golang.org/api/option/internaloption"
"google.golang.org/genproto/googleapis/rpc/errdetails"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/keepalive"
Expand Down Expand Up @@ -130,7 +131,29 @@ func retryableReadOnlyCallOption() gax.CallOption {
})
}

const pubsubLiteDefaultEndpoint = "-pubsublite.googleapis.com:443"
const (
pubsubLiteDefaultEndpoint = "-pubsublite.googleapis.com:443"
pubsubLiteErrorDomain = "pubsublite.googleapis.com"
resetSignal = "RESET"
)

// Pub/Sub Lite's RESET signal is a status containing error details that
// instructs streams to reset their state.
func isStreamResetSignal(err error) bool {
status, ok := status.FromError(err)
if !ok {
return false
}
if !isRetryableRecvCode(status.Code()) {
return false
}
for _, details := range status.Details() {
if errInfo, ok := details.(*errdetails.ErrorInfo); ok && errInfo.Reason == resetSignal && errInfo.Domain == pubsubLiteErrorDomain {
return true
}
}
return false
}

func defaultClientOptions(region string) []option.ClientOption {
return []option.ClientOption{
Expand Down
87 changes: 86 additions & 1 deletion pubsublite/internal/wire/rpc_test.go
Expand Up @@ -15,13 +15,98 @@ package wire

import (
"encoding/base64"
"errors"
"log"
"testing"

"cloud.google.com/go/internal/testutil"
"github.com/golang/protobuf/proto"
"google.golang.org/genproto/googleapis/rpc/errdetails"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/structpb"

spb "google.golang.org/genproto/googleapis/rpc/status"
)

func makeAny(msg proto.Message) *anypb.Any {
any, err := anypb.New(msg)
if err != nil {
log.Fatalf("Failed to make Any: %v", err)
}
return any
}

func makeStreamResetSignal() error {
statuspb := &spb.Status{
Code: int32(codes.Aborted),
Details: []*anypb.Any{makeAny(&errdetails.ErrorInfo{
Reason: "RESET", Domain: "pubsublite.googleapis.com",
})},
}
return status.ErrorProto(statuspb)
}

func TestIsStreamResetSignal(t *testing.T) {
for _, tc := range []struct {
desc string
err error
want bool
}{
{
desc: "reset signal",
err: makeStreamResetSignal(),
want: true,
},
{
desc: "non-retryable code",
err: status.ErrorProto(&spb.Status{
Code: int32(codes.FailedPrecondition),
Details: []*anypb.Any{makeAny(&errdetails.ErrorInfo{Reason: "RESET", Domain: "pubsublite.googleapis.com"})},
}),
want: false,
},
{
desc: "wrong domain",
err: status.ErrorProto(&spb.Status{
Code: int32(codes.Aborted),
Details: []*anypb.Any{makeAny(&errdetails.ErrorInfo{Reason: "RESET"})},
}),
want: false,
},
{
desc: "wrong reason",
err: status.ErrorProto(&spb.Status{
Code: int32(codes.Aborted),
Details: []*anypb.Any{makeAny(&errdetails.ErrorInfo{Domain: "pubsublite.googleapis.com"})},
}),
want: false,
},
{
desc: "missing details",
err: status.ErrorProto(&spb.Status{Code: int32(codes.Aborted)}),
want: false,
},
{
desc: "nil error",
err: nil,
want: false,
},
{
desc: "generic error",
err: errors.New(""),
want: false,
},
} {
t.Run(tc.desc, func(t *testing.T) {
if got := isStreamResetSignal(tc.err); got != tc.want {
t.Errorf("isStreamResetSignal() got: %v, want %v", got, tc.want)
}
})
}
}

func TestPubsubMetadataAddClientInfo(t *testing.T) {
for _, tc := range []struct {
desc string
Expand Down
35 changes: 26 additions & 9 deletions pubsublite/internal/wire/streams.go
Expand Up @@ -35,11 +35,13 @@ type streamStatus int
const (
streamUninitialized streamStatus = iota
streamReconnecting
streamResetState
streamConnected
streamTerminated
)

type initialResponseRequired bool
type notifyReset bool

// streamHandler provides hooks for different Pub/Sub Lite streaming APIs
// (e.g. publish, subscribe, streaming cursor, etc.) to use retryableStream.
Expand All @@ -63,11 +65,15 @@ type streamHandler interface {
validateInitialResponse(interface{}) error

// onStreamStatusChange is used to notify stream handlers when the stream has
// changed state. A `streamReconnecting` status change is fired before
// attempting to connect a new stream. A `streamConnected` status change is
// fired when the stream is successfully connected. These are followed by
// onResponse() calls when responses are received from the server. These
// events are guaranteed to occur in this order.
// changed state.
// - A `streamReconnecting` status change is fired before attempting to
// connect a new stream.
// - A `streamResetState` status change may be fired if the stream should
// reset its state (due to receipt of the RESET signal from the server).
// - A `streamConnected` status change is fired when the stream is
// successfully connected.
// These are followed by onResponse() calls when responses are received from
// the server. These events are guaranteed to occur in this order.
//
// A final `streamTerminated` status change is fired when a permanent error
// occurs. retryableStream.Error() returns the error that caused the stream to
Expand Down Expand Up @@ -124,7 +130,7 @@ func (rs *retryableStream) Start() {
defer rs.mu.Unlock()

if rs.status == streamUninitialized {
go rs.connectStream()
go rs.connectStream(notifyReset(false))
}
}

Expand All @@ -151,7 +157,7 @@ func (rs *retryableStream) Send(request interface{}) (sent bool) {
// stream. Nothing to do here.
break
case isRetryableSendError(err):
go rs.connectStream()
go rs.connectStream(notifyReset(false))
default:
rs.unsafeTerminate(err)
}
Expand Down Expand Up @@ -206,7 +212,7 @@ func (rs *retryableStream) setCancel(cancel context.CancelFunc) {
// terminated during reconnection.
//
// Intended to be called in a goroutine. It ends once the client stream closes.
func (rs *retryableStream) connectStream() {
func (rs *retryableStream) connectStream(notifyReset notifyReset) {
canReconnect := func() bool {
rs.mu.Lock()
defer rs.mu.Unlock()
Expand All @@ -225,7 +231,15 @@ func (rs *retryableStream) connectStream() {
if !canReconnect() {
return
}

rs.handler.onStreamStatusChange(streamReconnecting)
if notifyReset {
rs.handler.onStreamStatusChange(streamResetState)
}
// Check whether handler terminated stream before reconnecting.
if rs.Status() == streamTerminated {
return
}

newStream, cancelFunc, err := rs.initNewStream()
if err != nil {
Expand Down Expand Up @@ -275,6 +289,9 @@ func (rs *retryableStream) initNewStream() (newStream grpc.ClientStream, cancelF
if needsResponse {
response := reflect.New(rs.responseType).Interface()
if err = newStream.RecvMsg(response); err != nil {
if isStreamResetSignal(err) {
rs.handler.onStreamStatusChange(streamResetState)
}
return r.RetryRecv(err)
}
if err = rs.handler.validateInitialResponse(response); err != nil {
Expand Down Expand Up @@ -323,7 +340,7 @@ func (rs *retryableStream) listen(recvStream grpc.ClientStream) {
}
if err != nil {
if isRetryableRecvError(err) {
go rs.connectStream()
go rs.connectStream(notifyReset(isStreamResetSignal(err)))
} else {
rs.terminate(err)
}
Expand Down
81 changes: 81 additions & 0 deletions pubsublite/internal/wire/streams_test.go
Expand Up @@ -358,3 +358,84 @@ func TestRetryableStreamSendReceive(t *testing.T) {
t.Errorf("Stream final err: got (%v), want <nil>", gotErr)
}
}

func TestRetryableStreamConnectReceivesResetSignal(t *testing.T) {
pub := newTestStreamHandler(t, defaultStreamTimeout)

verifiers := test.NewVerifiers(t)

stream1 := test.NewRPCVerifier(t)
// Reset signal received during stream initialization.
stream1.Push(pub.InitialReq, nil, makeStreamResetSignal())
verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream1)

stream2 := test.NewRPCVerifier(t)
stream2.Push(pub.InitialReq, initPubResp(), nil)
verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream2)

mockServer.OnTestStart(verifiers)
defer mockServer.OnTestEnd()

pub.Stream.Start()
if got, want := pub.NextStatus(), streamReconnecting; got != want {
t.Errorf("Stream status change: got %d, want %d", got, want)
}
if got, want := pub.NextStatus(), streamResetState; got != want {
t.Errorf("Stream status change: got %d, want %d", got, want)
}
if got, want := pub.NextStatus(), streamConnected; got != want {
t.Errorf("Stream status change: got %d, want %d", got, want)
}

pub.Stream.Stop()
if got, want := pub.NextStatus(), streamTerminated; got != want {
t.Errorf("Stream status change: got %d, want %d", got, want)
}
if gotErr := pub.Stream.Error(); gotErr != nil {
t.Errorf("Stream final err: got (%v), want <nil>", gotErr)
}
}

func TestRetryableStreamDisconnectedWithResetSignal(t *testing.T) {
pub := newTestStreamHandler(t, defaultStreamTimeout)

verifiers := test.NewVerifiers(t)

stream1 := test.NewRPCVerifier(t)
stream1.Push(pub.InitialReq, initPubResp(), nil)
// Reset signal received after stream is connected.
stream1.Push(nil, nil, makeStreamResetSignal())
verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream1)

stream2 := test.NewRPCVerifier(t)
stream2.Push(pub.InitialReq, initPubResp(), nil)
verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream2)

mockServer.OnTestStart(verifiers)
defer mockServer.OnTestEnd()

pub.Stream.Start()
if got, want := pub.NextStatus(), streamReconnecting; got != want {
t.Errorf("Stream status change: got %d, want %d", got, want)
}
if got, want := pub.NextStatus(), streamConnected; got != want {
t.Errorf("Stream status change: got %d, want %d", got, want)
}
if got, want := pub.NextStatus(), streamReconnecting; got != want {
t.Errorf("Stream status change: got %d, want %d", got, want)
}
if got, want := pub.NextStatus(), streamResetState; got != want {
t.Errorf("Stream status change: got %d, want %d", got, want)
}
if got, want := pub.NextStatus(), streamConnected; got != want {
t.Errorf("Stream status change: got %d, want %d", got, want)
}

pub.Stream.Stop()
if got, want := pub.NextStatus(), streamTerminated; got != want {
t.Errorf("Stream status change: got %d, want %d", got, want)
}
if gotErr := pub.Stream.Error(); gotErr != nil {
t.Errorf("Stream final err: got (%v), want <nil>", gotErr)
}
}

0 comments on commit ff5f8c9

Please sign in to comment.