diff --git a/pubsub/pstest/fake.go b/pubsub/pstest/fake.go index 2f50f97908f..fcc546bee84 100644 --- a/pubsub/pstest/fake.go +++ b/pubsub/pstest/fake.go @@ -60,6 +60,11 @@ type ServerReactorOption struct { Reactor Reactor } +type publishResponse struct { + resp *pb.PublishResponse + err error +} + // For testing. Note that even though changes to the now variable are atomic, a call // to the stored function can race with a change to that function. This could be a // problem if tests are run in parallel, or even if concurrent parts of the same test @@ -98,6 +103,13 @@ type GServer struct { streamTimeout time.Duration timeNowFunc func() time.Time reactorOptions ReactorOptions + + // PublishResponses is a channel of responses to use for Publish. + publishResponses chan *publishResponse + // autoPublishResponse enables the server to automatically generate + // PublishResponse when publish is called. Otherwise, responses + // are generated from the publishResponses channel. + autoPublishResponse bool } // NewServer creates a new fake server running in the current process. @@ -114,11 +126,13 @@ func NewServer(opts ...ServerReactorOption) *Server { srv: srv, Addr: srv.Addr, GServer: GServer{ - topics: map[string]*topic{}, - subs: map[string]*subscription{}, - msgsByID: map[string]*Message{}, - timeNowFunc: timeNow, - reactorOptions: reactorOptions, + topics: map[string]*topic{}, + subs: map[string]*subscription{}, + msgsByID: map[string]*Message{}, + timeNowFunc: timeNow, + reactorOptions: reactorOptions, + publishResponses: make(chan *publishResponse, 100), + autoPublishResponse: true, }, } pb.RegisterPublisherServer(srv.Gsrv, &s.GServer) @@ -168,6 +182,37 @@ func (s *Server) PublishOrdered(topic string, data []byte, attrs map[string]stri return res.MessageIds[0] } +// AddPublishResponse adds a new publish response to the channel used for +// responding to publish requests. +func (s *Server) AddPublishResponse(pbr *pb.PublishResponse, err error) { + s.GServer.mu.Lock() + defer s.GServer.mu.Unlock() + pr := &publishResponse{} + if err != nil { + pr.err = err + } else { + pr.resp = pbr + } + s.GServer.publishResponses <- pr +} + +// SetAutoPublishResponse controls whether to automatically respond +// to messages published or to use user-added responses from the +// publishResponses channel. +func (s *Server) SetAutoPublishResponse(autoPublishResponse bool) { + s.GServer.mu.Lock() + defer s.GServer.mu.Unlock() + s.GServer.autoPublishResponse = autoPublishResponse +} + +// ResetPublishResponses resets the buffered publishResponses channel +// with a new buffered channel with the given size. +func (s *Server) ResetPublishResponses(size int) { + s.GServer.mu.Lock() + defer s.GServer.mu.Unlock() + s.GServer.publishResponses = make(chan *publishResponse, size) +} + // SetStreamTimeout sets the amount of time a stream will be active before it shuts // itself down. This mimics the real service's behavior of closing streams after 30 // minutes. If SetStreamTimeout is never called or is passed zero, streams never shut @@ -613,6 +658,15 @@ func (s *GServer) Publish(_ context.Context, req *pb.PublishRequest) (*pb.Publis if top == nil { return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic) } + + if !s.autoPublishResponse { + r := <-s.publishResponses + if r.err != nil { + return nil, r.err + } + return r.resp, nil + } + var ids []string for _, pm := range req.Messages { id := fmt.Sprintf("m%d", s.nextID) diff --git a/pubsub/pstest/fake_test.go b/pubsub/pstest/fake_test.go index 6b75c2fd4a3..4546f46368d 100644 --- a/pubsub/pstest/fake_test.go +++ b/pubsub/pstest/fake_test.go @@ -1045,3 +1045,35 @@ func TestErrorInjection(t *testing.T) { } } } + +func TestPublishResponse(t *testing.T) { + ctx := context.Background() + _, _, srv, cleanup := newFake(ctx, t) + defer cleanup() + + // By default, autoPublishResponse is true so this should succeed immediately. + got := srv.Publish("projects/p/topics/t", []byte("msg1"), nil) + if want := "m0"; got != want { + t.Fatalf("srv.Publish(): got %v, want %v", got, want) + } + + // After disabling autoPublishResponse, publish() operations + // will read from the channel instead of auto generating messages. + srv.SetAutoPublishResponse(false) + + srv.AddPublishResponse(&pb.PublishResponse{ + MessageIds: []string{"1"}, + }, nil) + got = srv.Publish("projects/p/topics/t", []byte("msg2"), nil) + if want := "1"; got != want { + t.Fatalf("srv.Publish(): got %v, want %v", got, want) + } + + srv.AddPublishResponse(&pb.PublishResponse{ + MessageIds: []string{"2"}, + }, nil) + got = srv.Publish("projects/p/topics/t", []byte("msg3"), nil) + if want := "2"; got != want { + t.Fatalf("srv.Publish(): got %v, want %v", got, want) + } +}