fix(pubsub): fix memory leak issue in publish scheduler (#4282)
hongalex committed Jun 21, 2021
1 parent 24c4451 commit 22ffc18
Showing 6 changed files with 95 additions and 35 deletions.
44 changes: 26 additions & 18 deletions pubsub/internal/scheduler/publish_scheduler.go
@@ -39,8 +39,8 @@ type PublishScheduler struct {
BufferedByteLimit int

mu sync.Mutex
bundlers map[string]*bundler.Bundler
outstanding map[string]int
bundlers sync.Map // keys -> *bundler.Bundler
outstanding sync.Map // keys -> num outstanding messages

keysMu sync.RWMutex
// keysWithErrors tracks ordering keys that cannot accept new messages.
@@ -76,8 +76,6 @@ func NewPublishScheduler(workers int, handle func(bundle interface{})) *PublishScheduler {
}

s := PublishScheduler{
bundlers: make(map[string]*bundler.Bundler),
outstanding: make(map[string]int),
keysWithErrors: make(map[string]struct{}),
workers: make(chan struct{}, workers),
handle: handle,
@@ -106,20 +104,23 @@ func (s *PublishScheduler) Add(key string, item interface{}, size int) error {

s.mu.Lock()
defer s.mu.Unlock()
b, ok := s.bundlers[key]
var b *bundler.Bundler
bInterface, ok := s.bundlers.Load(key)

if !ok {
s.outstanding[key] = 1
s.outstanding.Store(key, 1)
b = bundler.NewBundler(item, func(bundle interface{}) {
s.workers <- struct{}{}
s.handle(bundle)
<-s.workers

nlen := reflect.ValueOf(bundle).Len()
s.mu.Lock()
s.outstanding[key] -= nlen
if s.outstanding[key] == 0 {
delete(s.outstanding, key)
delete(s.bundlers, key)
outsInterface, _ := s.outstanding.Load(key)
s.outstanding.Store(key, outsInterface.(int)-nlen)
if v, _ := s.outstanding.Load(key); v == 0 {
s.outstanding.Delete(key)
s.bundlers.Delete(key)
}
s.mu.Unlock()
})
@@ -142,32 +143,39 @@ func (s *PublishScheduler) Add(key string, item interface{}, size int) error {
b.HandlerLimit = 1
}

s.bundlers[key] = b
s.bundlers.Store(key, b)
} else {
b = bInterface.(*bundler.Bundler)
oi, _ := s.outstanding.Load(key)
s.outstanding.Store(key, oi.(int)+1)
}
s.outstanding[key]++

return b.Add(item, size)
}

// FlushAndStop begins flushing items from bundlers and from the scheduler. It
// blocks until all items have been flushed.
func (s *PublishScheduler) FlushAndStop() {
close(s.done)
for _, b := range s.bundlers {
b.Flush()
}
s.bundlers.Range(func(_, bi interface{}) bool {
bi.(*bundler.Bundler).Flush()
return true
})
}

// Flush waits until all bundlers are sent.
func (s *PublishScheduler) Flush() {
var wg sync.WaitGroup
for _, b := range s.bundlers {
s.bundlers.Range(func(_, bi interface{}) bool {
wg.Add(1)
go func(b *bundler.Bundler) {
defer wg.Done()
b.Flush()
}(b)
}
}(bi.(*bundler.Bundler))
return true
})
wg.Wait()

}

// IsPaused checks if the bundler associated with an ordering keys is
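
The hunks above replace the scheduler's plain maps with sync.Map fields; a key's entries are deleted as soon as its outstanding count drops to zero, so idle ordering keys no longer linger in the scheduler. A minimal standalone sketch of that cleanup pattern (hypothetical names, not the actual PublishScheduler) could look like:

package main

import (
	"fmt"
	"sync"
)

// tracker mimics the scheduler's bookkeeping: a mutex still guards the
// read-modify-write, while per-key counts live in a sync.Map whose entries
// are deleted once they drain.
type tracker struct {
	mu          sync.Mutex
	outstanding sync.Map // key -> int (number of outstanding items)
}

func (t *tracker) add(key string) {
	t.mu.Lock()
	defer t.mu.Unlock()
	n, _ := t.outstanding.LoadOrStore(key, 0)
	t.outstanding.Store(key, n.(int)+1)
}

func (t *tracker) done(key string) {
	t.mu.Lock()
	defer t.mu.Unlock()
	n, ok := t.outstanding.Load(key)
	if !ok {
		return
	}
	if n.(int) <= 1 {
		// Last outstanding item for this key: drop the entry entirely so
		// idle keys cannot pile up (the leak this commit addresses).
		t.outstanding.Delete(key)
		return
	}
	t.outstanding.Store(key, n.(int)-1)
}

func main() {
	var t tracker
	t.add("ordering-key-1")
	t.done("ordering-key-1")
	_, stillThere := t.outstanding.Load("ordering-key-1")
	fmt.Println("entry present after drain:", stillThere) // false
}
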
4 changes: 3 additions & 1 deletion pubsub/loadtest/benchmark_test.go
@@ -46,14 +46,16 @@ const (
batchDuration = 50 * time.Millisecond
serverDelay = 200 * time.Millisecond
maxOutstandingPublishes = 1600 // max_outstanding_messages in run.py
useOrdered = true
)

func BenchmarkPublishThroughput(b *testing.B) {
b.SetBytes(nMessages * messageSize)
client := perfClient(serverDelay, 1, b)

lts := &PubServer{ID: "xxx"}
lts.init(client, "t", messageSize, batchSize, batchDuration)
lts.init(client, "t", messageSize, batchSize, batchDuration, useOrdered)

b.ResetTimer()
for i := 0; i < b.N; i++ {
runOnce(lts)
23 changes: 16 additions & 7 deletions pubsub/loadtest/loadtest.go
@@ -22,6 +22,7 @@ import (
"bytes"
"context"
"errors"
"fmt"
"log"
"runtime"
"strconv"
@@ -38,6 +39,7 @@ type pubServerConfig struct {
topic *pubsub.Topic
msgData []byte
batchSize int32
ordered bool
}

// PubServer is a dummy Pub/Sub server for load testing.
@@ -56,23 +58,26 @@ func (l *PubServer) Start(ctx context.Context, req *pb.StartRequest) (*pb.StartResponse, error) {
return nil, err
}
dur := req.PublishBatchDuration.AsDuration()
l.init(c, req.Topic, req.MessageSize, req.PublishBatchSize, dur)
l.init(c, req.Topic, req.MessageSize, req.PublishBatchSize, dur, false)
log.Println("started")
return &pb.StartResponse{}, nil
}

func (l *PubServer) init(c *pubsub.Client, topicName string, msgSize, batchSize int32, batchDur time.Duration) {
func (l *PubServer) init(c *pubsub.Client, topicName string, msgSize, batchSize int32, batchDur time.Duration, ordered bool) {
topic := c.Topic(topicName)
topic.PublishSettings = pubsub.PublishSettings{
DelayThreshold: batchDur,
CountThreshold: 950,
ByteThreshold: 9500000,
DelayThreshold: batchDur,
CountThreshold: 950,
ByteThreshold: 9500000,
BufferedByteLimit: 2e9,
}
topic.EnableMessageOrdering = ordered

l.cfg.Store(pubServerConfig{
topic: topic,
msgData: bytes.Repeat([]byte{'A'}, int(msgSize)),
batchSize: batchSize,
ordered: ordered,
})
}

@@ -101,14 +106,18 @@ func (l *PubServer) publishBatch() ([]int64, error) {

rs := make([]*pubsub.PublishResult, cfg.batchSize)
for i := int32(0); i < cfg.batchSize; i++ {
rs[i] = cfg.topic.Publish(context.TODO(), &pubsub.Message{
msg := &pubsub.Message{
Data: cfg.msgData,
Attributes: map[string]string{
"sendTime": startStr,
"clientId": l.ID,
"sequenceNumber": strconv.Itoa(int(seqNum + i)),
},
})
}
if cfg.ordered {
msg.OrderingKey = fmt.Sprintf("key-%d", seqNum+i)
}
rs[i] = cfg.topic.Publish(context.TODO(), msg)
}
for i, r := range rs {
_, err := r.Get(context.Background())
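
The loadtest changes thread an ordered flag through init and stamp each message with an OrderingKey when it is set. Outside of the load test, a caller would opt into ordered publishing roughly as in the following sketch (project and topic names are placeholders):

package main

import (
	"context"
	"fmt"
	"log"

	"cloud.google.com/go/pubsub"
)

func main() {
	ctx := context.Background()
	client, err := pubsub.NewClient(ctx, "my-project") // placeholder project ID
	if err != nil {
		log.Fatal(err)
	}
	defer client.Close()

	topic := client.Topic("my-topic") // placeholder topic name
	// Ordering must be enabled on the topic handle before publishing
	// messages that carry an OrderingKey.
	topic.EnableMessageOrdering = true
	defer topic.Stop()

	for i := 0; i < 3; i++ {
		r := topic.Publish(ctx, &pubsub.Message{
			Data:        []byte("payload"),
			OrderingKey: fmt.Sprintf("key-%d", i),
		})
		if _, err := r.Get(ctx); err != nil {
			log.Printf("publish %d failed: %v", i, err)
		}
	}
}
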
2 changes: 0 additions & 2 deletions pubsub/pstest/fake.go
@@ -185,8 +185,6 @@ func (s *Server) PublishOrdered(topic string, data []byte, attrs map[string]string
// AddPublishResponse adds a new publish response to the channel used for
// responding to publish requests.
func (s *Server) AddPublishResponse(pbr *pb.PublishResponse, err error) {
s.GServer.mu.Lock()
defer s.GServer.mu.Unlock()
pr := &publishResponse{}
if err != nil {
pr.err = err
11 changes: 11 additions & 0 deletions pubsub/pstest/fake_test.go
@@ -1076,4 +1076,15 @@ func TestPublishResponse(t *testing.T) {
if want := "2"; got != want {
t.Fatalf("srv.Publish(): got %v, want %v", got, want)
}

go func() {
got = srv.Publish("projects/p/topics/t", []byte("msg4"), nil)
if want := "3"; got != want {
fmt.Printf("srv.Publish(): got %v, want %v", got, want)
}
}()
time.Sleep(5 * time.Second)
srv.AddPublishResponse(&pb.PublishResponse{
MessageIds: []string{"3"},
}, nil)
}
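
With the lock removed from AddPublishResponse in the previous file, a response can now be queued while a Publish call is blocked waiting for one, which is what the new test exercises. A compact sketch of that flow against pstest, using the import paths of this commit's era, might be:

package main

import (
	"fmt"

	"cloud.google.com/go/pubsub/pstest"
	pb "google.golang.org/genproto/googleapis/pubsub/v1"
)

func main() {
	srv := pstest.NewServer()
	defer srv.Close()

	// Disable automatic responses so Publish blocks until one is supplied.
	srv.SetAutoPublishResponse(false)

	done := make(chan string)
	go func() {
		// Blocks until AddPublishResponse below queues a response.
		done <- srv.Publish("projects/p/topics/t", []byte("msg"), nil)
	}()

	srv.AddPublishResponse(&pb.PublishResponse{MessageIds: []string{"1"}}, nil)
	fmt.Println("published message ID:", <-done)
}
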
46 changes: 39 additions & 7 deletions pubsub/topic_test.go
@@ -321,31 +321,63 @@ func TestFlushStopTopic(t *testing.T) {

// Subsequent publishes after a flush should succeed.
topic.Flush()
r := topic.Publish(ctx, &Message{
r1 := topic.Publish(ctx, &Message{
Data: []byte("hello"),
})
_, err = r.Get(ctx)
_, err = r1.Get(ctx)
if err != nil {
t.Errorf("got err: %v", err)
}

// Publishing after a flush should succeed.
topic.Flush()
r = topic.Publish(ctx, &Message{
r2 := topic.Publish(ctx, &Message{
Data: []byte("world"),
})
_, err = r.Get(ctx)
_, err = r2.Get(ctx)
if err != nil {
t.Errorf("got err: %v", err)
}

// Publishing after a temporarily blocked flush should succeed.
srv.SetAutoPublishResponse(false)

r3 := topic.Publish(ctx, &Message{
Data: []byte("blocking message publish"),
})
go func() {
topic.Flush()
}()

// Wait a second between publishes to ensure messages are not bundled together.
time.Sleep(1 * time.Second)
r4 := topic.Publish(ctx, &Message{
Data: []byte("message published after flush"),
})

// Wait 5 seconds to simulate network delay.
time.Sleep(5 * time.Second)
srv.AddPublishResponse(&pubsubpb.PublishResponse{
MessageIds: []string{"1"},
}, nil)
srv.AddPublishResponse(&pubsubpb.PublishResponse{
MessageIds: []string{"2"},
}, nil)

if _, err = r3.Get(ctx); err != nil {
t.Errorf("got err: %v", err)
}
if _, err = r4.Get(ctx); err != nil {
t.Errorf("got err: %v", err)
}

// Publishing after Stop should fail.
srv.SetAutoPublishResponse(true)
topic.Stop()
r = topic.Publish(ctx, &Message{
r5 := topic.Publish(ctx, &Message{
Data: []byte("this should fail"),
})
_, err = r.Get(ctx)
if err != errTopicStopped {
if _, err := r5.Get(ctx); err != errTopicStopped {
t.Errorf("got %v, want errTopicStopped", err)
}
}
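
The extended test pins down the publisher-side semantics: Flush blocks until buffered messages are sent but leaves the topic usable, while Stop makes any later Publish fail. From an application's point of view that behaves roughly like this sketch (placeholder project and topic names):

package main

import (
	"context"
	"log"

	"cloud.google.com/go/pubsub"
)

func main() {
	ctx := context.Background()
	client, err := pubsub.NewClient(ctx, "my-project") // placeholder project ID
	if err != nil {
		log.Fatal(err)
	}
	defer client.Close()

	topic := client.Topic("my-topic") // placeholder topic name

	r1 := topic.Publish(ctx, &pubsub.Message{Data: []byte("before flush")})
	topic.Flush() // blocks until buffered messages, including r1, are sent
	if _, err := r1.Get(ctx); err != nil {
		log.Printf("publish failed: %v", err)
	}

	// The topic remains usable after Flush.
	r2 := topic.Publish(ctx, &pubsub.Message{Data: []byte("after flush")})
	if _, err := r2.Get(ctx); err != nil {
		log.Printf("publish failed: %v", err)
	}

	// After Stop, further publishes fail immediately.
	topic.Stop()
	r3 := topic.Publish(ctx, &pubsub.Message{Data: []byte("after stop")})
	if _, err := r3.Get(ctx); err != nil {
		log.Printf("expected error after Stop: %v", err)
	}
}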
