Skip to content

Commit

Permalink
feat(pubsub): Allow Message and PublishResult to be used outside the …
Browse files Browse the repository at this point in the history
…package (#3200)

Added NewMessage which can be provided a custom ack/nack handler. Added NewPublishResult which returns the set() func.
  • Loading branch information
tmdiep committed Nov 13, 2020
1 parent bd837fc commit 581bf92
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 37 deletions.
13 changes: 7 additions & 6 deletions pubsub/iterator.go
Expand Up @@ -218,14 +218,15 @@ func (it *messageIterator) receive(maxToPull int32) ([]*Message, error) {
it.mu.Lock()
now := time.Now()
for _, m := range msgs {
m.receiveTime = now
addRecv(m.ID, m.ackID, now)
m.doneFunc = it.done
it.keepAliveDeadlines[m.ackID] = maxExt
ackh, _ := m.ackHandler()
ackh.receiveTime = now
addRecv(m.ID, ackh.ackID, now)
ackh.doneFunc = it.done
it.keepAliveDeadlines[ackh.ackID] = maxExt
// Don't change the mod-ack if the message is going to be nacked. This is
// possible if there are retries.
if !it.pendingNacks[m.ackID] {
ackIDs[m.ackID] = true
if !it.pendingNacks[ackh.ackID] {
ackIDs[ackh.ackID] = true
}
}
deadline := it.ackDeadline()
Expand Down
83 changes: 63 additions & 20 deletions pubsub/message.go
Expand Up @@ -36,18 +36,12 @@ type Message struct {
// labelled with.
Attributes map[string]string

// ackID is the identifier to acknowledge this message.
ackID string

// PublishTime is the time at which the message was published. This is
// populated by the server for Messages obtained from a subscription.
//
// This field is read-only.
PublishTime time.Time

// receiveTime is the time the message was received by the client.
receiveTime time.Time

// DeliveryAttempt is the number of times a message has been delivered.
// This is part of the dead lettering feature that forwards messages that
// fail to be processed (from nack/ack deadline timeout) to a dead letter topic.
Expand All @@ -59,19 +53,23 @@ type Message struct {
// size is the approximate size of the message's data and attributes.
size int

calledDone bool

// The done method of the iterator that created this Message.
doneFunc func(string, bool, time.Time)

// OrderingKey identifies related messages for which publish order should
// be respected. If empty string is used, message will be sent unordered.
OrderingKey string

// ackh handles Ack() or Nack().
ackh ackHandler
}

// NewMessage creates a message with a custom ack/nack handler, which should not
// be nil.
func NewMessage(ackh ackHandler) *Message {
return &Message{ackh: ackh}
}

func toMessage(resp *pb.ReceivedMessage) (*Message, error) {
if resp.Message == nil {
return &Message{ackID: resp.AckId}, nil
return &Message{ackh: &psAckHandler{ackID: resp.AckId}}, nil
}

pubTime, err := ptypes.Timestamp(resp.Message.PublishTime)
Expand All @@ -86,13 +84,13 @@ func toMessage(resp *pb.ReceivedMessage) (*Message, error) {
}

return &Message{
ackID: resp.AckId,
Data: resp.Message.Data,
Attributes: resp.Message.Attributes,
ID: resp.Message.MessageId,
PublishTime: pubTime,
DeliveryAttempt: deliveryAttempt,
OrderingKey: resp.Message.OrderingKey,
ackh: &psAckHandler{ackID: resp.AckId},
}, nil
}

Expand All @@ -102,7 +100,9 @@ func toMessage(resp *pb.ReceivedMessage) (*Message, error) {
// Client code must call Ack or Nack when finished for each received Message.
// Calls to Ack or Nack have no effect after the first call.
func (m *Message) Ack() {
m.done(true)
if m.ackh != nil {
m.ackh.OnAck()
}
}

// Nack indicates that the client will not or cannot process a Message passed to the Subscriber.Receive callback.
Expand All @@ -111,15 +111,58 @@ func (m *Message) Ack() {
// Client code must call Ack or Nack when finished for each received Message.
// Calls to Ack or Nack have no effect after the first call.
func (m *Message) Nack() {
m.done(false)
if m.ackh != nil {
m.ackh.OnNack()
}
}

// ackHandler performs a safe cast of the message's ack handler to psAckHandler.
func (m *Message) ackHandler() (*psAckHandler, bool) {
ackh, ok := m.ackh.(*psAckHandler)
return ackh, ok
}

func (m *Message) ackID() string {
if ackh, ok := m.ackh.(*psAckHandler); ok {
return ackh.ackID
}
return ""
}

// ackHandler implements ack/nack handling.
type ackHandler interface {
OnAck()
OnNack()
}

// psAckHandler handles ack/nack for the pubsub package.
type psAckHandler struct {
// ackID is the identifier to acknowledge this message.
ackID string

// receiveTime is the time the message was received by the client.
receiveTime time.Time

calledDone bool

// The done method of the iterator that created this Message.
doneFunc func(string, bool, time.Time)
}

func (ah *psAckHandler) OnAck() {
ah.done(true)
}

func (ah *psAckHandler) OnNack() {
ah.done(false)
}

func (m *Message) done(ack bool) {
if m.calledDone {
func (ah *psAckHandler) done(ack bool) {
if ah.calledDone {
return
}
m.calledDone = true
if m.doneFunc != nil {
m.doneFunc(m.ackID, ack, m.receiveTime)
ah.calledDone = true
if ah.doneFunc != nil {
ah.doneFunc(ah.ackID, ack, ah.receiveTime)
}
}
19 changes: 10 additions & 9 deletions pubsub/streaming_pull_test.go
Expand Up @@ -67,7 +67,7 @@ func TestStreamingPullMultipleFetches(t *testing.T) {
func testStreamingPullIteration(t *testing.T, client *Client, server *mockServer, msgs []*pb.ReceivedMessage) {
sub := client.Subscription("S")
gotMsgs, err := pullN(context.Background(), sub, len(msgs), func(_ context.Context, m *Message) {
id, err := strconv.Atoi(m.ackID)
id, err := strconv.Atoi(m.ackID())
if err != nil {
panic(err)
}
Expand All @@ -83,20 +83,21 @@ func testStreamingPullIteration(t *testing.T, client *Client, server *mockServer
}
gotMap := map[string]*Message{}
for _, m := range gotMsgs {
gotMap[m.ackID] = m
gotMap[m.ackID()] = m
}
for i, msg := range msgs {
want, err := toMessage(msg)
if err != nil {
t.Fatal(err)
}
want.calledDone = true
got := gotMap[want.ackID]
wantAckh, _ := want.ackHandler()
wantAckh.calledDone = true
got := gotMap[wantAckh.ackID]
if got == nil {
t.Errorf("%d: no message for ackID %q", i, want.ackID)
t.Errorf("%d: no message for ackID %q", i, wantAckh.ackID)
continue
}
if !testutil.Equal(got, want, cmp.AllowUnexported(Message{}), cmpopts.IgnoreTypes(time.Time{}, func(string, bool, time.Time) {})) {
if !testutil.Equal(got, want, cmp.AllowUnexported(Message{}, psAckHandler{}), cmpopts.IgnoreTypes(time.Time{}, func(string, bool, time.Time) {})) {
t.Errorf("%d: got\n%#v\nwant\n%#v", i, got, want)
}
}
Expand Down Expand Up @@ -235,10 +236,10 @@ func TestStreamingPullConcurrent(t *testing.T) {
}
seen := map[string]bool{}
for _, gm := range gotMsgs {
if seen[gm.ackID] {
t.Fatalf("duplicate ID %q", gm.ackID)
if seen[gm.ackID()] {
t.Fatalf("duplicate ID %q", gm.ackID())
}
seen[gm.ackID] = true
seen[gm.ackID()] = true
}
if len(seen) != nMessages {
t.Fatalf("got %d messages, want %d", len(seen), nMessages)
Expand Down
5 changes: 3 additions & 2 deletions pubsub/subscription.go
Expand Up @@ -910,9 +910,10 @@ func (s *Subscription) Receive(ctx context.Context, f func(context.Context, *Mes
// Return nil if the context is done, not err.
return nil
}
old := msg.doneFunc
ackh, _ := msg.ackHandler()
old := ackh.doneFunc
msgLen := len(msg.Data)
msg.doneFunc = func(ackID string, ack bool, receiveTime time.Time) {
ackh.doneFunc = func(ackID string, ack bool, receiveTime time.Time) {
defer fc.release(msgLen)
old(ackID, ack, receiveTime)
}
Expand Down
7 changes: 7 additions & 0 deletions pubsub/topic.go
Expand Up @@ -470,6 +470,13 @@ type PublishResult struct {
err error
}

// NewPublishResult returns the set() function to enable callers from outside
// this package to store and call it (e.g. unit tests).
func NewPublishResult() (*PublishResult, func(string, error)) {
result := &PublishResult{ready: make(chan struct{})}
return result, result.set
}

// Ready returns a channel that is closed when the result is ready.
// When the Ready channel is closed, Get is guaranteed not to block.
func (r *PublishResult) Ready() <-chan struct{} { return r.ready }
Expand Down

0 comments on commit 581bf92

Please sign in to comment.