From 581bf92878dcb52ae8ea3633d4b3fcbb7054ff0f Mon Sep 17 00:00:00 2001 From: tmdiep Date: Fri, 13 Nov 2020 17:39:27 +1100 Subject: [PATCH] feat(pubsub): Allow Message and PublishResult to be used outside the package (#3200) Added NewMessage which can be provided a custom ack/nack handler. Added NewPublishResult which returns the set() func. --- pubsub/iterator.go | 13 +++--- pubsub/message.go | 83 ++++++++++++++++++++++++++--------- pubsub/streaming_pull_test.go | 19 ++++---- pubsub/subscription.go | 5 ++- pubsub/topic.go | 7 +++ 5 files changed, 90 insertions(+), 37 deletions(-) diff --git a/pubsub/iterator.go b/pubsub/iterator.go index 8f1d9316bbf..da888ad8676 100644 --- a/pubsub/iterator.go +++ b/pubsub/iterator.go @@ -218,14 +218,15 @@ func (it *messageIterator) receive(maxToPull int32) ([]*Message, error) { it.mu.Lock() now := time.Now() for _, m := range msgs { - m.receiveTime = now - addRecv(m.ID, m.ackID, now) - m.doneFunc = it.done - it.keepAliveDeadlines[m.ackID] = maxExt + ackh, _ := m.ackHandler() + ackh.receiveTime = now + addRecv(m.ID, ackh.ackID, now) + ackh.doneFunc = it.done + it.keepAliveDeadlines[ackh.ackID] = maxExt // Don't change the mod-ack if the message is going to be nacked. This is // possible if there are retries. - if !it.pendingNacks[m.ackID] { - ackIDs[m.ackID] = true + if !it.pendingNacks[ackh.ackID] { + ackIDs[ackh.ackID] = true } } deadline := it.ackDeadline() diff --git a/pubsub/message.go b/pubsub/message.go index 00076ef6463..54be363eefa 100644 --- a/pubsub/message.go +++ b/pubsub/message.go @@ -36,18 +36,12 @@ type Message struct { // labelled with. Attributes map[string]string - // ackID is the identifier to acknowledge this message. - ackID string - // PublishTime is the time at which the message was published. This is // populated by the server for Messages obtained from a subscription. // // This field is read-only. PublishTime time.Time - // receiveTime is the time the message was received by the client. - receiveTime time.Time - // DeliveryAttempt is the number of times a message has been delivered. // This is part of the dead lettering feature that forwards messages that // fail to be processed (from nack/ack deadline timeout) to a dead letter topic. @@ -59,19 +53,23 @@ type Message struct { // size is the approximate size of the message's data and attributes. size int - calledDone bool - - // The done method of the iterator that created this Message. - doneFunc func(string, bool, time.Time) - // OrderingKey identifies related messages for which publish order should // be respected. If empty string is used, message will be sent unordered. OrderingKey string + + // ackh handles Ack() or Nack(). + ackh ackHandler +} + +// NewMessage creates a message with a custom ack/nack handler, which should not +// be nil. +func NewMessage(ackh ackHandler) *Message { + return &Message{ackh: ackh} } func toMessage(resp *pb.ReceivedMessage) (*Message, error) { if resp.Message == nil { - return &Message{ackID: resp.AckId}, nil + return &Message{ackh: &psAckHandler{ackID: resp.AckId}}, nil } pubTime, err := ptypes.Timestamp(resp.Message.PublishTime) @@ -86,13 +84,13 @@ func toMessage(resp *pb.ReceivedMessage) (*Message, error) { } return &Message{ - ackID: resp.AckId, Data: resp.Message.Data, Attributes: resp.Message.Attributes, ID: resp.Message.MessageId, PublishTime: pubTime, DeliveryAttempt: deliveryAttempt, OrderingKey: resp.Message.OrderingKey, + ackh: &psAckHandler{ackID: resp.AckId}, }, nil } @@ -102,7 +100,9 @@ func toMessage(resp *pb.ReceivedMessage) (*Message, error) { // Client code must call Ack or Nack when finished for each received Message. // Calls to Ack or Nack have no effect after the first call. func (m *Message) Ack() { - m.done(true) + if m.ackh != nil { + m.ackh.OnAck() + } } // Nack indicates that the client will not or cannot process a Message passed to the Subscriber.Receive callback. @@ -111,15 +111,58 @@ func (m *Message) Ack() { // Client code must call Ack or Nack when finished for each received Message. // Calls to Ack or Nack have no effect after the first call. func (m *Message) Nack() { - m.done(false) + if m.ackh != nil { + m.ackh.OnNack() + } +} + +// ackHandler performs a safe cast of the message's ack handler to psAckHandler. +func (m *Message) ackHandler() (*psAckHandler, bool) { + ackh, ok := m.ackh.(*psAckHandler) + return ackh, ok +} + +func (m *Message) ackID() string { + if ackh, ok := m.ackh.(*psAckHandler); ok { + return ackh.ackID + } + return "" +} + +// ackHandler implements ack/nack handling. +type ackHandler interface { + OnAck() + OnNack() +} + +// psAckHandler handles ack/nack for the pubsub package. +type psAckHandler struct { + // ackID is the identifier to acknowledge this message. + ackID string + + // receiveTime is the time the message was received by the client. + receiveTime time.Time + + calledDone bool + + // The done method of the iterator that created this Message. + doneFunc func(string, bool, time.Time) +} + +func (ah *psAckHandler) OnAck() { + ah.done(true) +} + +func (ah *psAckHandler) OnNack() { + ah.done(false) } -func (m *Message) done(ack bool) { - if m.calledDone { +func (ah *psAckHandler) done(ack bool) { + if ah.calledDone { return } - m.calledDone = true - if m.doneFunc != nil { - m.doneFunc(m.ackID, ack, m.receiveTime) + ah.calledDone = true + if ah.doneFunc != nil { + ah.doneFunc(ah.ackID, ack, ah.receiveTime) } } diff --git a/pubsub/streaming_pull_test.go b/pubsub/streaming_pull_test.go index ba40a2d60d8..e70d2e7abff 100644 --- a/pubsub/streaming_pull_test.go +++ b/pubsub/streaming_pull_test.go @@ -67,7 +67,7 @@ func TestStreamingPullMultipleFetches(t *testing.T) { func testStreamingPullIteration(t *testing.T, client *Client, server *mockServer, msgs []*pb.ReceivedMessage) { sub := client.Subscription("S") gotMsgs, err := pullN(context.Background(), sub, len(msgs), func(_ context.Context, m *Message) { - id, err := strconv.Atoi(m.ackID) + id, err := strconv.Atoi(m.ackID()) if err != nil { panic(err) } @@ -83,20 +83,21 @@ func testStreamingPullIteration(t *testing.T, client *Client, server *mockServer } gotMap := map[string]*Message{} for _, m := range gotMsgs { - gotMap[m.ackID] = m + gotMap[m.ackID()] = m } for i, msg := range msgs { want, err := toMessage(msg) if err != nil { t.Fatal(err) } - want.calledDone = true - got := gotMap[want.ackID] + wantAckh, _ := want.ackHandler() + wantAckh.calledDone = true + got := gotMap[wantAckh.ackID] if got == nil { - t.Errorf("%d: no message for ackID %q", i, want.ackID) + t.Errorf("%d: no message for ackID %q", i, wantAckh.ackID) continue } - if !testutil.Equal(got, want, cmp.AllowUnexported(Message{}), cmpopts.IgnoreTypes(time.Time{}, func(string, bool, time.Time) {})) { + if !testutil.Equal(got, want, cmp.AllowUnexported(Message{}, psAckHandler{}), cmpopts.IgnoreTypes(time.Time{}, func(string, bool, time.Time) {})) { t.Errorf("%d: got\n%#v\nwant\n%#v", i, got, want) } } @@ -235,10 +236,10 @@ func TestStreamingPullConcurrent(t *testing.T) { } seen := map[string]bool{} for _, gm := range gotMsgs { - if seen[gm.ackID] { - t.Fatalf("duplicate ID %q", gm.ackID) + if seen[gm.ackID()] { + t.Fatalf("duplicate ID %q", gm.ackID()) } - seen[gm.ackID] = true + seen[gm.ackID()] = true } if len(seen) != nMessages { t.Fatalf("got %d messages, want %d", len(seen), nMessages) diff --git a/pubsub/subscription.go b/pubsub/subscription.go index 59f296e04cf..7e5f201145e 100644 --- a/pubsub/subscription.go +++ b/pubsub/subscription.go @@ -910,9 +910,10 @@ func (s *Subscription) Receive(ctx context.Context, f func(context.Context, *Mes // Return nil if the context is done, not err. return nil } - old := msg.doneFunc + ackh, _ := msg.ackHandler() + old := ackh.doneFunc msgLen := len(msg.Data) - msg.doneFunc = func(ackID string, ack bool, receiveTime time.Time) { + ackh.doneFunc = func(ackID string, ack bool, receiveTime time.Time) { defer fc.release(msgLen) old(ackID, ack, receiveTime) } diff --git a/pubsub/topic.go b/pubsub/topic.go index 07e392e0869..1dd50c00291 100644 --- a/pubsub/topic.go +++ b/pubsub/topic.go @@ -470,6 +470,13 @@ type PublishResult struct { err error } +// NewPublishResult returns the set() function to enable callers from outside +// this package to store and call it (e.g. unit tests). +func NewPublishResult() (*PublishResult, func(string, error)) { + result := &PublishResult{ready: make(chan struct{})} + return result, result.set +} + // Ready returns a channel that is closed when the result is ready. // When the Ready channel is closed, Get is guaranteed not to block. func (r *PublishResult) Ready() <-chan struct{} { return r.ready }