Merge pull request #243 from advoretsky/pr242_avoid_parallel_requests
avoid running the load function in parallel for the same key #242
eko committed Apr 16, 2024
2 parents 6f5376b + dcc921e commit f8d51e4
Showing 4 changed files with 87 additions and 20 deletions.
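
In short, this change wraps the load path in a golang.org/x/sync/singleflight.Group so that concurrent cache misses for the same key share one in-flight load instead of each invoking the load function. The standalone sketch below (not part of the diff) illustrates the singleflight behavior the commit relies on; the key name, goroutine count, and timings are illustrative only.

// Standalone sketch: how singleflight.Do collapses concurrent calls
// for the same key into a single execution of the work function.
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"

	"golang.org/x/sync/singleflight"
)

func main() {
	var (
		group singleflight.Group
		calls int32
		wg    sync.WaitGroup
	)

	// Ten goroutines request the same key at roughly the same time.
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			value, _, shared := group.Do("my-key", func() (any, error) {
				atomic.AddInt32(&calls, 1)
				time.Sleep(50 * time.Millisecond) // simulate a slow load
				return "loaded-value", nil
			})
			_, _ = value, shared // shared is true for callers that reused the in-flight result
		}()
	}
	wg.Wait()

	// Callers that arrive while the first load is in flight share its result,
	// so this typically prints 1 rather than 10.
	fmt.Println("load executed", atomic.LoadInt32(&calls), "time(s)")
}
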
59 changes: 48 additions & 11 deletions lib/cache/loadable.go
@@ -2,9 +2,12 @@ package cache

import (
"context"
"errors"
"fmt"
"sync"

"github.com/eko/gocache/lib/v4/store"
"golang.org/x/sync/singleflight"
)

const (
@@ -21,19 +24,21 @@ type LoadFunction[T any] func(ctx context.Context, key any) (T, error)

// LoadableCache represents a cache that uses a function to load data
type LoadableCache[T any] struct {
loadFunc LoadFunction[T]
cache CacheInterface[T]
setChannel chan *loadableKeyValue[T]
setterWg *sync.WaitGroup
singleFlight singleflight.Group
loadFunc LoadFunction[T]
cache CacheInterface[T]
setChannel chan *loadableKeyValue[T]
setterWg *sync.WaitGroup
}

// NewLoadable instanciates a new cache that uses a function to load data
// NewLoadable instantiates a new cache that uses a function to load data
func NewLoadable[T any](loadFunc LoadFunction[T], cache CacheInterface[T]) *LoadableCache[T] {
loadable := &LoadableCache[T]{
loadFunc: loadFunc,
cache: cache,
setChannel: make(chan *loadableKeyValue[T], 10000),
setterWg: &sync.WaitGroup{},
singleFlight: singleflight.Group{},
loadFunc: loadFunc,
cache: cache,
setChannel: make(chan *loadableKeyValue[T], 10000),
setterWg: &sync.WaitGroup{},
}

loadable.setterWg.Add(1)
@@ -47,6 +52,9 @@ func (c *LoadableCache[T]) setter() {

for item := range c.setChannel {
c.Set(context.Background(), item.key, item.value)

cacheKey := c.getCacheKey(item.key)
c.singleFlight.Forget(cacheKey)
}
}

@@ -60,9 +68,24 @@ func (c *LoadableCache[T]) Get(ctx context.Context, key any) (T, error) {
}

// Unable to find in cache, try to load it from load function
object, err = c.loadFunc(ctx, key)
cacheKey := c.getCacheKey(key)
zero := *new(T)

loadedResult, err, _ := c.singleFlight.Do(
cacheKey,
func() (any, error) {
return c.loadFunc(ctx, key)
},
)
if err != nil {
return object, err
return zero, err
}

var ok bool
if object, ok = loadedResult.(T); !ok {
return zero, errors.New(
fmt.Sprintf("returned value can't be cast to %T", zero),
)
}

// Then, put it back in cache
@@ -102,3 +125,17 @@ func (c *LoadableCache[T]) Close() error {

return nil
}

// getCacheKey returns the cache key for the given key object by returning
// the key if type is string or by computing a checksum of key structure
// if its type is other than string
func (c *LoadableCache[T]) getCacheKey(key any) string {
switch v := key.(type) {
case string:
return v
case CacheKeyGenerator:
return v.GetCacheKey()
default:
return checksum(key)
}
}
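
With the changes above, concurrent Get calls for the same missing key should result in a single invocation of the load function, whose result is shared and then written back to the underlying cache. The usage sketch below is illustrative and not part of the diff; the store wiring follows the go-cache adapter pattern from the gocache README, and the package paths, versions, and timings are assumptions to adapt to your own module.

// Hedged usage sketch: concurrent misses for "my-key" are collapsed into one
// call of loadFunc by the singleflight-backed LoadableCache.
package main

import (
	"context"
	"fmt"
	"sync"
	"sync/atomic"
	"time"

	"github.com/eko/gocache/lib/v4/cache"
	gocache_store "github.com/eko/gocache/store/go_cache/v4"
	gocache "github.com/patrickmn/go-cache"
)

func main() {
	// In-memory store backed by patrickmn/go-cache (wiring assumed from the README).
	client := gocache.New(5*time.Minute, 10*time.Minute)
	inner := cache.New[string](gocache_store.NewGoCache(client))

	var loadCalls int32
	loadFunc := func(_ context.Context, key any) (string, error) {
		atomic.AddInt32(&loadCalls, 1)
		time.Sleep(50 * time.Millisecond) // simulate a slow backend
		return fmt.Sprintf("value-for-%v", key), nil
	}

	loadable := cache.NewLoadable[string](loadFunc, inner)

	// Ten goroutines miss the cache for the same key at roughly the same time.
	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			value, err := loadable.Get(context.Background(), "my-key")
			_, _ = value, err
		}()
	}
	wg.Wait()

	// Before this commit each concurrent miss could call loadFunc; with
	// singleflight the misses are collapsed, so this typically prints 1.
	fmt.Println("loadFunc calls:", atomic.LoadInt32(&loadCalls))
}
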
45 changes: 36 additions & 9 deletions lib/cache/loadable_test.go
@@ -3,6 +3,8 @@ package cache
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"

@@ -98,25 +100,50 @@ func TestLoadableGetWhenAvailableInLoadFunc(t *testing.T) {
// Cache 1
cache1 := NewMockSetterCacheInterface[any](ctrl)
cache1.EXPECT().Get(ctx, "my-key").Return(nil, errors.New("unable to find in cache 1"))
cache1.EXPECT().Get(ctx, "my-key").Return(nil, errors.New("unable to find in cache 1"))
cache1.EXPECT().Get(ctx, "my-key").Return(nil, errors.New("unable to find in cache 1"))
cache1.EXPECT().Set(ctx, "my-key", cacheValue).AnyTimes().Return(nil)

var loadCallCount int32
pauseLoadFn := make(chan struct{})

loadFunc := func(_ context.Context, key any) (any, error) {
atomic.AddInt32(&loadCallCount, 1)
<-pauseLoadFn
time.Sleep(time.Millisecond * 10)
return cacheValue, nil
}

cache := NewLoadable[any](loadFunc, cache1)

// When
value, err := cache.Get(ctx, "my-key")

// Wait for data to be processed
for len(cache.setChannel) > 0 {
time.Sleep(1 * time.Millisecond)
const numRequests = 3
var started sync.WaitGroup
started.Add(numRequests)
var finished sync.WaitGroup
finished.Add(numRequests)
for i := 0; i < numRequests; i++ {
go func() {
defer finished.Done()
started.Done()
// When
value, err := cache.Get(ctx, "my-key")

// Wait for data to be processed
for len(cache.setChannel) > 0 {
time.Sleep(1 * time.Millisecond)
}

// Then
assert.Nil(t, err)
assert.Equal(t, cacheValue, value)
}()
}

// Then
assert.Nil(t, err)
assert.Equal(t, cacheValue, value)
started.Wait()
close(pauseLoadFn)
finished.Wait()

assert.Equal(t, int32(1), loadCallCount)
}

func TestLoadableDelete(t *testing.T) {
1 change: 1 addition & 0 deletions lib/go.mod
@@ -8,6 +8,7 @@ require (
github.com/stretchr/testify v1.8.1
github.com/vmihailenco/msgpack/v5 v5.3.5
golang.org/x/exp v0.0.0-20221126150942-6ab00d035af9
golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f
)

require (
2 changes: 2 additions & 0 deletions lib/go.sum
@@ -304,6 +304,8 @@ golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f h1:Ax0t5p6N38Ga0dThY21weqDEyz2oklo4IvDkpigvkD8=
golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
