diff --git a/lib/cache/loadable.go b/lib/cache/loadable.go index 2612376..e2e642e 100644 --- a/lib/cache/loadable.go +++ b/lib/cache/loadable.go @@ -2,9 +2,12 @@ package cache import ( "context" + "errors" + "fmt" "sync" "github.com/eko/gocache/lib/v4/store" + "golang.org/x/sync/singleflight" ) const ( @@ -21,19 +24,21 @@ type LoadFunction[T any] func(ctx context.Context, key any) (T, error) // LoadableCache represents a cache that uses a function to load data type LoadableCache[T any] struct { - loadFunc LoadFunction[T] - cache CacheInterface[T] - setChannel chan *loadableKeyValue[T] - setterWg *sync.WaitGroup + singleFlight singleflight.Group + loadFunc LoadFunction[T] + cache CacheInterface[T] + setChannel chan *loadableKeyValue[T] + setterWg *sync.WaitGroup } -// NewLoadable instanciates a new cache that uses a function to load data +// NewLoadable instantiates a new cache that uses a function to load data func NewLoadable[T any](loadFunc LoadFunction[T], cache CacheInterface[T]) *LoadableCache[T] { loadable := &LoadableCache[T]{ - loadFunc: loadFunc, - cache: cache, - setChannel: make(chan *loadableKeyValue[T], 10000), - setterWg: &sync.WaitGroup{}, + singleFlight: singleflight.Group{}, + loadFunc: loadFunc, + cache: cache, + setChannel: make(chan *loadableKeyValue[T], 10000), + setterWg: &sync.WaitGroup{}, } loadable.setterWg.Add(1) @@ -47,6 +52,9 @@ func (c *LoadableCache[T]) setter() { for item := range c.setChannel { c.Set(context.Background(), item.key, item.value) + + cacheKey := c.getCacheKey(item.key) + c.singleFlight.Forget(cacheKey) } } @@ -60,9 +68,24 @@ func (c *LoadableCache[T]) Get(ctx context.Context, key any) (T, error) { } // Unable to find in cache, try to load it from load function - object, err = c.loadFunc(ctx, key) + cacheKey := c.getCacheKey(key) + zero := *new(T) + + loadedResult, err, _ := c.singleFlight.Do( + cacheKey, + func() (any, error) { + return c.loadFunc(ctx, key) + }, + ) if err != nil { - return object, err + return zero, err + } + + var ok bool + if object, ok = loadedResult.(T); !ok { + return zero, errors.New( + fmt.Sprintf("returned value can't be cast to %T", zero), + ) } // Then, put it back in cache @@ -102,3 +125,17 @@ func (c *LoadableCache[T]) Close() error { return nil } + +// getCacheKey returns the cache key for the given key object by returning +// the key if type is string or by computing a checksum of key structure +// if its type is other than string +func (c *LoadableCache[T]) getCacheKey(key any) string { + switch v := key.(type) { + case string: + return v + case CacheKeyGenerator: + return v.GetCacheKey() + default: + return checksum(key) + } +} diff --git a/lib/cache/loadable_test.go b/lib/cache/loadable_test.go index 9e0e487..b7c535c 100644 --- a/lib/cache/loadable_test.go +++ b/lib/cache/loadable_test.go @@ -3,6 +3,8 @@ package cache import ( "context" "errors" + "sync" + "sync/atomic" "testing" "time" @@ -98,25 +100,50 @@ func TestLoadableGetWhenAvailableInLoadFunc(t *testing.T) { // Cache 1 cache1 := NewMockSetterCacheInterface[any](ctrl) cache1.EXPECT().Get(ctx, "my-key").Return(nil, errors.New("unable to find in cache 1")) + cache1.EXPECT().Get(ctx, "my-key").Return(nil, errors.New("unable to find in cache 1")) + cache1.EXPECT().Get(ctx, "my-key").Return(nil, errors.New("unable to find in cache 1")) cache1.EXPECT().Set(ctx, "my-key", cacheValue).AnyTimes().Return(nil) + var loadCallCount int32 + pauseLoadFn := make(chan struct{}) + loadFunc := func(_ context.Context, key any) (any, error) { + atomic.AddInt32(&loadCallCount, 1) + <-pauseLoadFn + time.Sleep(time.Millisecond * 10) return cacheValue, nil } cache := NewLoadable[any](loadFunc, cache1) - // When - value, err := cache.Get(ctx, "my-key") - - // Wait for data to be processed - for len(cache.setChannel) > 0 { - time.Sleep(1 * time.Millisecond) + const numRequests = 3 + var started sync.WaitGroup + started.Add(numRequests) + var finished sync.WaitGroup + finished.Add(numRequests) + for i := 0; i < numRequests; i++ { + go func() { + defer finished.Done() + started.Done() + // When + value, err := cache.Get(ctx, "my-key") + + // Wait for data to be processed + for len(cache.setChannel) > 0 { + time.Sleep(1 * time.Millisecond) + } + + // Then + assert.Nil(t, err) + assert.Equal(t, cacheValue, value) + }() } - // Then - assert.Nil(t, err) - assert.Equal(t, cacheValue, value) + started.Wait() + close(pauseLoadFn) + finished.Wait() + + assert.Equal(t, int32(1), loadCallCount) } func TestLoadableDelete(t *testing.T) { diff --git a/lib/go.mod b/lib/go.mod index ea5a466..bf24f67 100644 --- a/lib/go.mod +++ b/lib/go.mod @@ -8,6 +8,7 @@ require ( github.com/stretchr/testify v1.8.1 github.com/vmihailenco/msgpack/v5 v5.3.5 golang.org/x/exp v0.0.0-20221126150942-6ab00d035af9 + golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f ) require ( diff --git a/lib/go.sum b/lib/go.sum index 8dd57fa..846614a 100644 --- a/lib/go.sum +++ b/lib/go.sum @@ -304,6 +304,8 @@ golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f h1:Ax0t5p6N38Ga0dThY21weqDEyz2oklo4IvDkpigvkD8= +golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=