diff --git a/pubsublite/internal/test/mock.go b/pubsublite/internal/test/mock.go new file mode 100644 index 00000000000..0c8c5a4aa7d --- /dev/null +++ b/pubsublite/internal/test/mock.go @@ -0,0 +1,245 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and + +package test + +import ( + "context" + "fmt" + "io" + "reflect" + "sync" + + "cloud.google.com/go/internal/testutil" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1" +) + +// Server is a mock Pub/Sub Lite server that can be used for unit testing. +type Server struct { + LiteServer MockServer + gRPCServer *testutil.Server +} + +// MockServer is an in-memory mock implementation of a Pub/Sub Lite service, +// which allows unit tests to inspect requests received by the server and send +// fake responses. +// This is the interface that should be used by tests. +type MockServer interface { + // OnTestStart must be called at the start of each test to clear any existing + // state and set the verifier for unary RPCs. + OnTestStart(globalVerifier *RPCVerifier) + // OnTestEnd should be called at the end of each test to flush the verifiers + // (i.e. check whether any expected requests were not sent to the server). + OnTestEnd() + // AddPublishStream adds a verifier for a publish stream of a topic partition. + AddPublishStream(topic string, partition int, streamVerifier *RPCVerifier) +} + +// NewServer creates a new mock Pub/Sub Lite server. +func NewServer() (*Server, error) { + srv, err := testutil.NewServer() + if err != nil { + return nil, err + } + liteServer := newMockLiteServer() + pb.RegisterAdminServiceServer(srv.Gsrv, liteServer) + pb.RegisterPublisherServiceServer(srv.Gsrv, liteServer) + srv.Start() + return &Server{LiteServer: liteServer, gRPCServer: srv}, nil +} + +// Addr returns the address that the server is listening on. +func (s *Server) Addr() string { + return s.gRPCServer.Addr +} + +// Close shuts down the server and releases all resources. +func (s *Server) Close() { + s.gRPCServer.Close() +} + +type streamHolder struct { + stream grpc.ServerStream + verifier *RPCVerifier +} + +// mockLiteServer implements the MockServer interface. +type mockLiteServer struct { + pb.AdminServiceServer + pb.PublisherServiceServer + + mu sync.Mutex + + // Global list of verifiers for all unary RPCs. This should be set before the + // test begins. + globalVerifier *RPCVerifier + + // Publish stream verifiers by topic & partition. + publishVerifiers *keyedStreamVerifiers + + nextStreamID int + activeStreams map[int]*streamHolder + testActive bool +} + +func key(path string, partition int) string { + return fmt.Sprintf("%s:%d", path, partition) +} + +func newMockLiteServer() *mockLiteServer { + return &mockLiteServer{ + publishVerifiers: newKeyedStreamVerifiers(), + activeStreams: make(map[int]*streamHolder), + } +} + +func (s *mockLiteServer) startStream(stream grpc.ServerStream, verifier *RPCVerifier) (id int) { + s.mu.Lock() + defer s.mu.Unlock() + + id = s.nextStreamID + s.nextStreamID++ + s.activeStreams[id] = &streamHolder{stream: stream, verifier: verifier} + return +} + +func (s *mockLiteServer) endStream(id int) { + s.mu.Lock() + defer s.mu.Unlock() + + delete(s.activeStreams, id) +} + +func (s *mockLiteServer) popStreamVerifier(key string, keyedVerifiers *keyedStreamVerifiers) (*RPCVerifier, error) { + s.mu.Lock() + defer s.mu.Unlock() + + return keyedVerifiers.Pop(key) +} + +func (s *mockLiteServer) handleStream(stream grpc.ServerStream, req interface{}, requestType reflect.Type, key string, keyedVerifiers *keyedStreamVerifiers) (err error) { + verifier, err := s.popStreamVerifier(key, keyedVerifiers) + if err != nil { + return err + } + + id := s.startStream(stream, verifier) + + // Verify initial request. + retResponse, retErr := verifier.Pop(req) + var ok bool + + for { + if retErr != nil { + err = retErr + break + } + if err = stream.SendMsg(retResponse); err != nil { + err = status.Errorf(codes.FailedPrecondition, "mockserver: stream send error: %v", err) + break + } + + // Check whether the next response isn't blocked on a request. + ok, retResponse, retErr = verifier.TryPop() + if ok { + continue + } + + req = reflect.New(requestType).Interface() + if err = stream.RecvMsg(req); err == io.EOF { + break + } else if err != nil { + err = status.Errorf(codes.FailedPrecondition, "mockserver: stream recv error: %v", err) + break + } + retResponse, retErr = verifier.Pop(req) + } + + // Check whether the stream ended prematurely. + verifier.Flush() + s.endStream(id) + return +} + +// MockServer implementation. + +func (s *mockLiteServer) OnTestStart(globalVerifier *RPCVerifier) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.testActive { + panic("mockserver is already in use by another test") + } + + s.testActive = true + s.globalVerifier = globalVerifier + s.publishVerifiers.Reset() + s.activeStreams = make(map[int]*streamHolder) +} + +func (s *mockLiteServer) OnTestEnd() { + s.mu.Lock() + defer s.mu.Unlock() + + s.testActive = false + if s.globalVerifier != nil { + s.globalVerifier.Flush() + } + + for _, as := range s.activeStreams { + as.verifier.Flush() + } +} + +func (s *mockLiteServer) AddPublishStream(topic string, partition int, streamVerifier *RPCVerifier) { + s.mu.Lock() + defer s.mu.Unlock() + s.publishVerifiers.Push(key(topic, partition), streamVerifier) +} + +// PublisherService implementation. + +func (s *mockLiteServer) Publish(stream pb.PublisherService_PublishServer) error { + req, err := stream.Recv() + if err != nil { + return status.Errorf(codes.FailedPrecondition, "mockserver: stream recv error before initial request: %v", err) + } + if len(req.GetInitialRequest().GetTopic()) == 0 { + return status.Errorf(codes.InvalidArgument, "mockserver: received invalid initial publish request: %v", req) + } + + initReq := req.GetInitialRequest() + k := key(initReq.GetTopic(), int(initReq.GetPartition())) + return s.handleStream(stream, req, reflect.TypeOf(pb.PublishRequest{}), k, s.publishVerifiers) +} + +// AdminService implementation. + +func (s *mockLiteServer) GetTopicPartitions(ctx context.Context, req *pb.GetTopicPartitionsRequest) (*pb.TopicPartitions, error) { + s.mu.Lock() + defer s.mu.Unlock() + + retResponse, retErr := s.globalVerifier.Pop(req) + if retErr != nil { + return nil, retErr + } + resp, ok := retResponse.(*pb.TopicPartitions) + if !ok { + return nil, status.Errorf(codes.FailedPrecondition, "mockserver: invalid response type %v", reflect.TypeOf(retResponse)) + } + return resp, nil +} diff --git a/pubsublite/internal/test/util.go b/pubsublite/internal/test/util.go new file mode 100644 index 00000000000..5486c13100e --- /dev/null +++ b/pubsublite/internal/test/util.go @@ -0,0 +1,48 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and + +package test + +import ( + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// ErrorEqual compares two errors for equivalence. +func ErrorEqual(got, want error) bool { + if got == want { + return true + } + return cmp.Equal(got, want, cmpopts.EquateErrors()) +} + +// ErrorHasCode returns true if an error has the desired canonical code. +func ErrorHasCode(got error, wantCode codes.Code) bool { + if s, ok := status.FromError(got); ok { + return s.Code() == wantCode + } + return false +} + +// FakeSource is a fake source that returns a configurable constant. +type FakeSource struct { + Ret int64 +} + +// Int63 returns the configured fake random number. +func (f *FakeSource) Int63() int64 { return f.Ret } + +// Seed is unimplemented. +func (f *FakeSource) Seed(seed int64) {} diff --git a/pubsublite/internal/test/verifier.go b/pubsublite/internal/test/verifier.go new file mode 100644 index 00000000000..a1e2681c8c3 --- /dev/null +++ b/pubsublite/internal/test/verifier.go @@ -0,0 +1,226 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and + +package test + +import ( + "container/list" + "sync" + "testing" + "time" + + "cloud.google.com/go/internal/testutil" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +const ( + // blockWaitTimeout is the timeout for any wait operations to ensure no + // deadlocks. + blockWaitTimeout = 30 * time.Second +) + +type rpcMetadata struct { + wantRequest interface{} + retResponse interface{} + retErr error + blockResponse chan struct{} +} + +// wait until the `blockResponse` is released by the test, or a timeout occurs. +// Returns immediately if there was no block. +func (r *rpcMetadata) wait() error { + if r.blockResponse == nil { + return nil + } + select { + case <-time.After(blockWaitTimeout): + // Note: avoid returning a retryable code to quickly terminate the test. + return status.Errorf(codes.FailedPrecondition, "mockserver: test did not unblock response within %v", blockWaitTimeout) + case <-r.blockResponse: + return nil + } +} + +// RPCVerifier stores an queue of requests expected from the client, and the +// corresponding response or error to return. +type RPCVerifier struct { + t *testing.T + mu sync.Mutex + rpcs *list.List // Value = *rpcMetadata + numCalls int +} + +// NewRPCVerifier creates a new verifier for requests received by the server. +func NewRPCVerifier(t *testing.T) *RPCVerifier { + return &RPCVerifier{ + t: t, + rpcs: list.New(), + numCalls: -1, + } +} + +// Push appends a new {request, response, error} tuple. +func (v *RPCVerifier) Push(wantRequest interface{}, retResponse interface{}, retErr error) { + v.mu.Lock() + defer v.mu.Unlock() + + v.rpcs.PushBack(&rpcMetadata{ + wantRequest: wantRequest, + retResponse: retResponse, + retErr: retErr, + }) +} + +// PushWithBlock is like Push, but returns a channel that the test should close +// when it would like the response to be sent to the client. This is useful for +// synchronizing with work that needs to be done on the client. +func (v *RPCVerifier) PushWithBlock(wantRequest interface{}, retResponse interface{}, retErr error) chan struct{} { + v.mu.Lock() + defer v.mu.Unlock() + + block := make(chan struct{}) + v.rpcs.PushBack(&rpcMetadata{ + wantRequest: wantRequest, + retResponse: retResponse, + retErr: retErr, + blockResponse: block, + }) + return block +} + +// Pop validates the received request with the next {request, response, error} +// tuple. +func (v *RPCVerifier) Pop(gotRequest interface{}) (interface{}, error) { + v.mu.Lock() + defer v.mu.Unlock() + + v.numCalls++ + elem := v.rpcs.Front() + if elem == nil { + v.t.Errorf("call(%d): unexpected request:\n%v", v.numCalls, gotRequest) + return nil, status.Error(codes.FailedPrecondition, "mockserver: got unexpected request") + } + + rpc, _ := elem.Value.(*rpcMetadata) + v.rpcs.Remove(elem) + + if !testutil.Equal(gotRequest, rpc.wantRequest) { + v.t.Errorf("call(%d): got request: %v\nwant request: %v", v.numCalls, gotRequest, rpc.wantRequest) + } + if err := rpc.wait(); err != nil { + return nil, err + } + return rpc.retResponse, rpc.retErr +} + +// TryPop should be used only for streams. It checks whether the request in the +// next tuple is nil, in which case the response or error should be returned to +// the client without waiting for a request. Useful for streams where the server +// continuously sends data (e.g. subscribe stream). +func (v *RPCVerifier) TryPop() (bool, interface{}, error) { + v.mu.Lock() + defer v.mu.Unlock() + + elem := v.rpcs.Front() + if elem == nil { + return false, nil, nil + } + + rpc, _ := elem.Value.(*rpcMetadata) + if rpc.wantRequest != nil { + return false, nil, nil + } + + v.rpcs.Remove(elem) + if err := rpc.wait(); err != nil { + return true, nil, err + } + return true, rpc.retResponse, rpc.retErr +} + +// Flush logs an error for any remaining {request, response, error} tuples, in +// case the client terminated early. +func (v *RPCVerifier) Flush() { + v.mu.Lock() + defer v.mu.Unlock() + + for elem := v.rpcs.Front(); elem != nil; elem = elem.Next() { + v.numCalls++ + rpc, _ := elem.Value.(*rpcMetadata) + v.t.Errorf("call(%d): did not receive expected request:\n%v", v.numCalls, rpc.wantRequest) + } + v.rpcs.Init() +} + +// streamVerifiers stores a queue of verifiers for unique stream connections. +type streamVerifiers struct { + t *testing.T + verifiers *list.List // Value = *RPCVerifier + numStreams int +} + +func newStreamVerifiers(t *testing.T) *streamVerifiers { + return &streamVerifiers{ + t: t, + verifiers: list.New(), + numStreams: -1, + } +} + +func (sv *streamVerifiers) Push(v *RPCVerifier) { + sv.verifiers.PushBack(v) +} + +func (sv *streamVerifiers) Pop() (*RPCVerifier, error) { + sv.numStreams++ + elem := sv.verifiers.Front() + if elem == nil { + sv.t.Errorf("stream(%d): unexpected connection with no verifiers", sv.numStreams) + return nil, status.Error(codes.FailedPrecondition, "mockserver: got unexpected stream connection") + } + + v, _ := elem.Value.(*RPCVerifier) + sv.verifiers.Remove(elem) + return v, nil +} + +// keyedStreamVerifiers stores indexed streamVerifiers. +type keyedStreamVerifiers struct { + verifiers map[string]*streamVerifiers +} + +func newKeyedStreamVerifiers() *keyedStreamVerifiers { + return &keyedStreamVerifiers{verifiers: make(map[string]*streamVerifiers)} +} + +func (kv *keyedStreamVerifiers) Reset() { + kv.verifiers = make(map[string]*streamVerifiers) +} + +func (kv *keyedStreamVerifiers) Push(key string, v *RPCVerifier) { + sv, ok := kv.verifiers[key] + if !ok { + sv = newStreamVerifiers(v.t) + kv.verifiers[key] = sv + } + sv.Push(v) +} + +func (kv *keyedStreamVerifiers) Pop(key string) (*RPCVerifier, error) { + sv, ok := kv.verifiers[key] + if !ok { + return nil, status.Error(codes.FailedPrecondition, "mockserver: unexpected connection with no configured responses") + } + return sv.Pop() +}