diff --git a/pubsublite/internal/wire/periodic_task.go b/pubsublite/internal/wire/periodic_task.go new file mode 100644 index 00000000000..4b5810b5755 --- /dev/null +++ b/pubsublite/internal/wire/periodic_task.go @@ -0,0 +1,70 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and + +package wire + +import ( + "time" +) + +// periodicTask manages a recurring background task. +type periodicTask struct { + period time.Duration + task func() + ticker *time.Ticker + stop chan struct{} +} + +func newPeriodicTask(period time.Duration, task func()) *periodicTask { + return &periodicTask{ + period: period, + task: task, + } +} + +// Start the polling goroutine. No-op if the goroutine is already running. +// The task is executed after the polling period. +func (pt *periodicTask) Start() { + if pt.ticker != nil { + return + } + + pt.ticker = time.NewTicker(pt.period) + pt.stop = make(chan struct{}) + go pt.poll(pt.ticker, pt.stop) +} + +// Stop/pause the periodic task. +func (pt *periodicTask) Stop() { + if pt.ticker == nil { + return + } + + pt.ticker.Stop() + close(pt.stop) + + pt.ticker = nil + pt.stop = nil +} + +func (pt *periodicTask) poll(ticker *time.Ticker, stop chan struct{}) { + for { + select { + case <-stop: + // Ends the goroutine. + return + case <-ticker.C: + pt.task() + } + } +} diff --git a/pubsublite/internal/wire/periodic_task_test.go b/pubsublite/internal/wire/periodic_task_test.go new file mode 100644 index 00000000000..c3b04ab44b9 --- /dev/null +++ b/pubsublite/internal/wire/periodic_task_test.go @@ -0,0 +1,67 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and + +package wire + +import ( + "sync/atomic" + "testing" + "time" +) + +func TestPeriodicTask(t *testing.T) { + var callCount int32 + values := make(chan int32) + task := func() { + values <- atomic.AddInt32(&callCount, 1) + } + ptask := newPeriodicTask(10*time.Millisecond, task) + + t.Run("Start1", func(t *testing.T) { + ptask.Start() + ptask.Start() // Tests duplicate start + + if got, want := <-values, int32(1); got != want { + t.Errorf("got %d, want %d", got, want) + } + }) + + t.Run("Stop1", func(t *testing.T) { + ptask.Stop() + ptask.Stop() // Tests duplicate stop + + select { + case got := <-values: + t.Errorf("got unexpected value %d", got) + default: + } + }) + + t.Run("Start2", func(t *testing.T) { + ptask.Start() + + if got, want := <-values, int32(2); got != want { + t.Errorf("got %d, want %d", got, want) + } + }) + + t.Run("Stop2", func(t *testing.T) { + ptask.Stop() + + select { + case got := <-values: + t.Errorf("got unexpected value %d", got) + default: + } + }) +}