From 22ffc18e522c0f943db57f8c943e7356067bedfd Mon Sep 17 00:00:00 2001
From: Alex Hong <9397363+hongalex@users.noreply.github.com>
Date: Mon, 21 Jun 2021 10:01:40 -0700
Subject: [PATCH] fix(pubsub): fix memory leak issue in publish scheduler (#4282)

---
 .../internal/scheduler/publish_scheduler.go   | 44 ++++++++++--------
 pubsub/loadtest/benchmark_test.go             |  4 +-
 pubsub/loadtest/loadtest.go                   | 23 +++++++---
 pubsub/pstest/fake.go                         |  2 -
 pubsub/pstest/fake_test.go                    | 11 +++++
 pubsub/topic_test.go                          | 46 ++++++++++++++++---
 6 files changed, 95 insertions(+), 35 deletions(-)

diff --git a/pubsub/internal/scheduler/publish_scheduler.go b/pubsub/internal/scheduler/publish_scheduler.go
index 9dbf51c1cfd..cba172b79db 100644
--- a/pubsub/internal/scheduler/publish_scheduler.go
+++ b/pubsub/internal/scheduler/publish_scheduler.go
@@ -39,8 +39,8 @@ type PublishScheduler struct {
 	BufferedByteLimit int
 
 	mu          sync.Mutex
-	bundlers    map[string]*bundler.Bundler
-	outstanding map[string]int
+	bundlers    sync.Map // keys -> *bundler.Bundler
+	outstanding sync.Map // keys -> num outstanding messages
 
 	keysMu sync.RWMutex
 	// keysWithErrors tracks ordering keys that cannot accept new messages.
@@ -76,8 +76,6 @@ func NewPublishScheduler(workers int, handle func(bundle interface{})) *PublishS
 	}
 
 	s := PublishScheduler{
-		bundlers:       make(map[string]*bundler.Bundler),
-		outstanding:    make(map[string]int),
 		keysWithErrors: make(map[string]struct{}),
 		workers:        make(chan struct{}, workers),
 		handle:         handle,
@@ -106,9 +104,11 @@ func (s *PublishScheduler) Add(key string, item interface{}, size int) error {
 
 	s.mu.Lock()
 	defer s.mu.Unlock()
-	b, ok := s.bundlers[key]
+	var b *bundler.Bundler
+	bInterface, ok := s.bundlers.Load(key)
+
 	if !ok {
-		s.outstanding[key] = 1
+		s.outstanding.Store(key, 1)
 		b = bundler.NewBundler(item, func(bundle interface{}) {
 			s.workers <- struct{}{}
 			s.handle(bundle)
@@ -116,10 +116,11 @@ func (s *PublishScheduler) Add(key string, item interface{}, size int) error {
 
 			nlen := reflect.ValueOf(bundle).Len()
 			s.mu.Lock()
-			s.outstanding[key] -= nlen
-			if s.outstanding[key] == 0 {
-				delete(s.outstanding, key)
-				delete(s.bundlers, key)
+			outsInterface, _ := s.outstanding.Load(key)
+			s.outstanding.Store(key, outsInterface.(int)-nlen)
+			if v, _ := s.outstanding.Load(key); v == 0 {
+				s.outstanding.Delete(key)
+				s.bundlers.Delete(key)
 			}
 			s.mu.Unlock()
 		})
@@ -142,9 +143,13 @@ func (s *PublishScheduler) Add(key string, item interface{}, size int) error {
 			b.HandlerLimit = 1
 		}
 
-		s.bundlers[key] = b
+		s.bundlers.Store(key, b)
+	} else {
+		b = bInterface.(*bundler.Bundler)
+		oi, _ := s.outstanding.Load(key)
+		s.outstanding.Store(key, oi.(int)+1)
 	}
-	s.outstanding[key]++
+
 	return b.Add(item, size)
 }
 
@@ -152,22 +157,25 @@ func (s *PublishScheduler) Add(key string, item interface{}, size int) error {
 // blocks until all items have been flushed.
 func (s *PublishScheduler) FlushAndStop() {
 	close(s.done)
-	for _, b := range s.bundlers {
-		b.Flush()
-	}
+	s.bundlers.Range(func(_, bi interface{}) bool {
+		bi.(*bundler.Bundler).Flush()
+		return true
+	})
 }
 
 // Flush waits until all bundlers are sent.
 func (s *PublishScheduler) Flush() {
 	var wg sync.WaitGroup
-	for _, b := range s.bundlers {
+	s.bundlers.Range(func(_, bi interface{}) bool {
 		wg.Add(1)
 		go func(b *bundler.Bundler) {
 			defer wg.Done()
 			b.Flush()
-		}(b)
-	}
+		}(bi.(*bundler.Bundler))
+		return true
+	})
 	wg.Wait()
+
 }
 
 // IsPaused checks if the bundler associated with an ordering keys is
diff --git a/pubsub/loadtest/benchmark_test.go b/pubsub/loadtest/benchmark_test.go
index 885caa2d5bd..611002e6a38 100644
--- a/pubsub/loadtest/benchmark_test.go
+++ b/pubsub/loadtest/benchmark_test.go
@@ -46,6 +46,7 @@ const (
 	batchDuration           = 50 * time.Millisecond
 	serverDelay             = 200 * time.Millisecond
 	maxOutstandingPublishes = 1600 // max_outstanding_messages in run.py
+	useOrdered              = true
 )
 
 func BenchmarkPublishThroughput(b *testing.B) {
@@ -53,7 +54,8 @@ func BenchmarkPublishThroughput(b *testing.B) {
 	client := perfClient(serverDelay, 1, b)
 
 	lts := &PubServer{ID: "xxx"}
-	lts.init(client, "t", messageSize, batchSize, batchDuration)
+	lts.init(client, "t", messageSize, batchSize, batchDuration, useOrdered)
+
 	b.ResetTimer()
 	for i := 0; i < b.N; i++ {
 		runOnce(lts)
diff --git a/pubsub/loadtest/loadtest.go b/pubsub/loadtest/loadtest.go
index d6e9c1b1d62..cc1c734432b 100644
--- a/pubsub/loadtest/loadtest.go
+++ b/pubsub/loadtest/loadtest.go
@@ -22,6 +22,7 @@ import (
 	"bytes"
 	"context"
 	"errors"
+	"fmt"
 	"log"
 	"runtime"
 	"strconv"
@@ -38,6 +39,7 @@ type pubServerConfig struct {
 	topic     *pubsub.Topic
 	msgData   []byte
 	batchSize int32
+	ordered   bool
 }
 
 // PubServer is a dummy Pub/Sub server for load testing.
@@ -56,23 +58,26 @@ func (l *PubServer) Start(ctx context.Context, req *pb.StartRequest) (*pb.StartR
 		return nil, err
 	}
 	dur := req.PublishBatchDuration.AsDuration()
-	l.init(c, req.Topic, req.MessageSize, req.PublishBatchSize, dur)
+	l.init(c, req.Topic, req.MessageSize, req.PublishBatchSize, dur, false)
 	log.Println("started")
 	return &pb.StartResponse{}, nil
 }
 
-func (l *PubServer) init(c *pubsub.Client, topicName string, msgSize, batchSize int32, batchDur time.Duration) {
+func (l *PubServer) init(c *pubsub.Client, topicName string, msgSize, batchSize int32, batchDur time.Duration, ordered bool) {
 	topic := c.Topic(topicName)
 	topic.PublishSettings = pubsub.PublishSettings{
-		DelayThreshold: batchDur,
-		CountThreshold: 950,
-		ByteThreshold:  9500000,
+		DelayThreshold:    batchDur,
+		CountThreshold:    950,
+		ByteThreshold:     9500000,
+		BufferedByteLimit: 2e9,
 	}
+	topic.EnableMessageOrdering = ordered
 
 	l.cfg.Store(pubServerConfig{
 		topic:     topic,
 		msgData:   bytes.Repeat([]byte{'A'}, int(msgSize)),
 		batchSize: batchSize,
+		ordered:   ordered,
 	})
 }
 
@@ -101,14 +106,18 @@ func (l *PubServer) publishBatch() ([]int64, error) {
 
 	rs := make([]*pubsub.PublishResult, cfg.batchSize)
 	for i := int32(0); i < cfg.batchSize; i++ {
-		rs[i] = cfg.topic.Publish(context.TODO(), &pubsub.Message{
+		msg := &pubsub.Message{
 			Data: cfg.msgData,
 			Attributes: map[string]string{
 				"sendTime":       startStr,
 				"clientId":       l.ID,
 				"sequenceNumber": strconv.Itoa(int(seqNum + i)),
 			},
-		})
+		}
+		if cfg.ordered {
+			msg.OrderingKey = fmt.Sprintf("key-%d", seqNum+i)
+		}
+		rs[i] = cfg.topic.Publish(context.TODO(), msg)
 	}
 	for i, r := range rs {
 		_, err := r.Get(context.Background())
diff --git a/pubsub/pstest/fake.go b/pubsub/pstest/fake.go
index fcc546bee84..23b5bc57101 100644
--- a/pubsub/pstest/fake.go
+++ b/pubsub/pstest/fake.go
@@ -185,8 +185,6 @@ func (s *Server) PublishOrdered(topic string, data []byte, attrs map[string]stri
 // AddPublishResponse adds a new publish response to the channel used for
 // responding to publish requests.
 func (s *Server) AddPublishResponse(pbr *pb.PublishResponse, err error) {
-	s.GServer.mu.Lock()
-	defer s.GServer.mu.Unlock()
 	pr := &publishResponse{}
 	if err != nil {
 		pr.err = err
diff --git a/pubsub/pstest/fake_test.go b/pubsub/pstest/fake_test.go
index 4546f46368d..279bcc0a972 100644
--- a/pubsub/pstest/fake_test.go
+++ b/pubsub/pstest/fake_test.go
@@ -1076,4 +1076,15 @@ func TestPublishResponse(t *testing.T) {
 	if want := "2"; got != want {
 		t.Fatalf("srv.Publish(): got %v, want %v", got, want)
 	}
+
+	go func() {
+		got = srv.Publish("projects/p/topics/t", []byte("msg4"), nil)
+		if want := "3"; got != want {
+			fmt.Printf("srv.Publish(): got %v, want %v", got, want)
+		}
+	}()
+	time.Sleep(5 * time.Second)
+	srv.AddPublishResponse(&pb.PublishResponse{
+		MessageIds: []string{"3"},
+	}, nil)
 }
diff --git a/pubsub/topic_test.go b/pubsub/topic_test.go
index 1707ba16ede..d216cd38e96 100644
--- a/pubsub/topic_test.go
+++ b/pubsub/topic_test.go
@@ -321,31 +321,63 @@ func TestFlushStopTopic(t *testing.T) {
 
 	// Subsequent publishes after a flush should succeed.
 	topic.Flush()
-	r := topic.Publish(ctx, &Message{
+	r1 := topic.Publish(ctx, &Message{
 		Data: []byte("hello"),
 	})
-	_, err = r.Get(ctx)
+	_, err = r1.Get(ctx)
 	if err != nil {
 		t.Errorf("got err: %v", err)
 	}
 
 	// Publishing after a flush should succeed.
 	topic.Flush()
-	r = topic.Publish(ctx, &Message{
+	r2 := topic.Publish(ctx, &Message{
 		Data: []byte("world"),
 	})
-	_, err = r.Get(ctx)
+	_, err = r2.Get(ctx)
 	if err != nil {
 		t.Errorf("got err: %v", err)
 	}
 
+	// Publishing after a temporarily blocked flush should succeed.
+	srv.SetAutoPublishResponse(false)
+
+	r3 := topic.Publish(ctx, &Message{
+		Data: []byte("blocking message publish"),
+	})
+	go func() {
+		topic.Flush()
+	}()
+
+	// Wait a second between publishes to ensure messages are not bundled together.
+	time.Sleep(1 * time.Second)
+	r4 := topic.Publish(ctx, &Message{
+		Data: []byte("message published after flush"),
+	})
+
+	// Wait 5 seconds to simulate network delay.
+	time.Sleep(5 * time.Second)
+	srv.AddPublishResponse(&pubsubpb.PublishResponse{
+		MessageIds: []string{"1"},
+	}, nil)
+	srv.AddPublishResponse(&pubsubpb.PublishResponse{
+		MessageIds: []string{"2"},
+	}, nil)
+
+	if _, err = r3.Get(ctx); err != nil {
+		t.Errorf("got err: %v", err)
+	}
+	if _, err = r4.Get(ctx); err != nil {
+		t.Errorf("got err: %v", err)
+	}
+
 	// Publishing after Stop should fail.
+	srv.SetAutoPublishResponse(true)
 	topic.Stop()
-	r = topic.Publish(ctx, &Message{
+	r5 := topic.Publish(ctx, &Message{
 		Data: []byte("this should fail"),
 	})
-	_, err = r.Get(ctx)
-	if err != errTopicStopped {
+	if _, err := r5.Get(ctx); err != errTopicStopped {
 		t.Errorf("got %v, want errTopicStopped", err)
 	}
 }
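Note on the scheduler change, for readers skimming the diff: after this patch, PublishScheduler keeps its per-ordering-key bundlers and outstanding counts in sync.Map values, deletes both entries once a key's outstanding count drains to zero, and still takes s.mu around each Load/Store pair so the read-modify-write stays atomic. The standalone sketch below illustrates that bookkeeping pattern in isolation; it is not part of the pubsub package, and keyTracker, add, and done are names invented for this note.

package main

import (
	"fmt"
	"sync"
)

// keyTracker mirrors the bookkeeping pattern in the patched PublishScheduler:
// a sync.Map of per-key outstanding counts, with the load-modify-store
// sequence guarded by a mutex, and the entry deleted once the count drains
// to zero so an idle key no longer pins any memory.
type keyTracker struct {
	mu          sync.Mutex
	outstanding sync.Map // key -> int
}

// add records one more outstanding item for key.
func (t *keyTracker) add(key string) {
	t.mu.Lock()
	defer t.mu.Unlock()
	if v, ok := t.outstanding.Load(key); ok {
		t.outstanding.Store(key, v.(int)+1)
		return
	}
	t.outstanding.Store(key, 1)
}

// done records that n items for key were handled. Callers must balance add
// and done calls; when the count reaches zero the entry is deleted (the real
// scheduler also deletes the per-key bundler at this point).
func (t *keyTracker) done(key string, n int) {
	t.mu.Lock()
	defer t.mu.Unlock()
	v, _ := t.outstanding.Load(key)
	remaining := v.(int) - n
	if remaining == 0 {
		t.outstanding.Delete(key)
		return
	}
	t.outstanding.Store(key, remaining)
}

func main() {
	var t keyTracker
	t.add("ordering-key-1")
	t.add("ordering-key-1")
	t.done("ordering-key-1", 2)
	_, stillTracked := t.outstanding.Load("ordering-key-1")
	fmt.Println("entry retained after drain:", stillTracked) // false: nothing left behind
}

Running the sketch prints "entry retained after drain: false". The mutex is still needed because sync.Map only makes individual Load, Store, and Delete calls safe, not the count update as a whole, which is the same reason the patch keeps s.mu around those sequences in Add and in the bundler callback.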