Skip to content

Commit

Permalink
Merge pull request #14 from mroth/v2
Browse files Browse the repository at this point in the history
v2: conversion to utilize go1.18 generics
  • Loading branch information
mroth committed Nov 14, 2022
2 parents c32b594 + 54fe109 commit 0892404
Show file tree
Hide file tree
Showing 8 changed files with 76 additions and 66 deletions.
8 changes: 0 additions & 8 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,6 @@ jobs:
strategy:
matrix:
go:
- "1.10"
- "1.11"
- "1.12"
- "1.13"
- "1.14"
- "1.15"
- "1.16"
- "1.17"
- "1.18"
- "1.19"
name: Go ${{ matrix.go }} test
Expand Down
29 changes: 17 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,24 @@ element to be selected are not equal, but rather defined by relative "weights"
```go
import (
/* ...snip... */
wr "github.com/mroth/weightedrand"
"github.com/mroth/weightedrand/v2"
)

func main() {
rand.Seed(time.Now().UTC().UnixNano()) // always seed random!

chooser, _ := wr.NewChooser(
wr.Choice{Item: "🍒", Weight: 0},
wr.Choice{Item: "🍋", Weight: 1},
wr.Choice{Item: "🍊", Weight: 1},
wr.Choice{Item: "🍉", Weight: 3},
wr.Choice{Item: "🥑", Weight: 5},
chooser, _ := weightedrand.NewChooser(
weightedrand.NewChoice('🍒', 0),
weightedrand.NewChoice('🍋', 1),
weightedrand.NewChoice('🍊', 1),
weightedrand.NewChoice('🍉', 3),
weightedrand.NewChoice('🥑', 5),
)
/* The following will print 🍋 and 🍊 with 0.1 probability, 🍉 with 0.3
probability, and 🥑 with 0.5 probability. 🍒 will never be printed. (Note
the weights don't have to add up to 10, that was just done here to make the
example easier to read.) */
result := chooser.Pick().(string)
// The following will print 🍋 and 🍊 with 0.1 probability, 🍉 with 0.3
// probability, and 🥑 with 0.5 probability. 🍒 will never be printed. (Note
// the weights don't have to add up to 10, that was just done here to make
// the example easier to read.)
result := chooser.Pick()
fmt.Println(result)
}
```
Expand Down Expand Up @@ -73,6 +73,11 @@ right choice! If you are only picking from the same distribution once,
`randutil` will be faster. `weightedrand` optimizes for repeated calls at the
expense of some initialization time and memory storage.

## Requirements

weightedrand >= v2 requires go1.18 or greater. For support on earlier versions
of go, use weightedrand [v1](https://github.com/mroth/weightedrand/tree/v1).

## Credits

To better understand the algorithm used in this library (as well as the one used
Expand Down
8 changes: 4 additions & 4 deletions examples/compbench/bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"time"

"github.com/jmcvetta/randutil"
"github.com/mroth/weightedrand"
"github.com/mroth/weightedrand/v2"
)

const BMMinChoices = 10
Expand Down Expand Up @@ -98,9 +98,9 @@ func BenchmarkSingle(b *testing.B) {
})
}

func mockChoices(tb testing.TB, n int) []weightedrand.Choice {
func mockChoices(tb testing.TB, n int) []weightedrand.Choice[rune, uint] {
tb.Helper()
choices := make([]weightedrand.Choice, 0, n)
choices := make([]weightedrand.Choice[rune, uint], 0, n)
for i := 0; i < n; i++ {
s := '🥑'
w := rand.Intn(10)
Expand All @@ -110,7 +110,7 @@ func mockChoices(tb testing.TB, n int) []weightedrand.Choice {
return choices
}

func convertChoices(tb testing.TB, cs []weightedrand.Choice) []randutil.Choice {
func convertChoices(tb testing.TB, cs []weightedrand.Choice[rune, uint]) []randutil.Choice {
tb.Helper()
res := make([]randutil.Choice, len(cs))
for i, c := range cs {
Expand Down
6 changes: 3 additions & 3 deletions examples/compbench/go.mod
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
module github.com/mroth/weightedrand/examples/compbench

go 1.15
go 1.18

require (
github.com/jmcvetta/randutil v0.0.0-20150817122601-2bb1b664bcff
github.com/mroth/weightedrand v0.0.0
github.com/mroth/weightedrand/v2 v2.0.0
)

replace github.com/mroth/weightedrand => ../..
replace github.com/mroth/weightedrand/v2 => ../..
16 changes: 8 additions & 8 deletions examples/frequency/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,18 @@ import (
"math/rand"
"time"

wr "github.com/mroth/weightedrand"
"github.com/mroth/weightedrand/v2"
)

func main() {
rand.Seed(time.Now().UTC().UnixNano()) // always seed random!

c, err := wr.NewChooser(
wr.Choice{Item: '🍒', Weight: 0}, // alternatively: wr.NewChoice('🍒', 0)
wr.Choice{Item: '🍋', Weight: 1},
wr.Choice{Item: '🍊', Weight: 1},
wr.Choice{Item: '🍉', Weight: 3},
wr.Choice{Item: '🥑', Weight: 5},
c, err := weightedrand.NewChooser(
weightedrand.NewChoice('🍒', 0),
weightedrand.NewChoice('🍋', 1),
weightedrand.NewChoice('🍊', 1),
weightedrand.NewChoice('🍉', 3),
weightedrand.NewChoice('🥑', 5),
)
if err != nil {
log.Fatal(err)
Expand All @@ -26,7 +26,7 @@ func main() {
/* Let's pick a bunch of fruits so we can see the distribution in action! */
fruits := make([]rune, 40*18)
for i := 0; i < len(fruits); i++ {
fruits[i] = c.Pick().(rune)
fruits[i] = c.Pick()
}
fmt.Println(string(fruits))

Expand Down
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
module github.com/mroth/weightedrand
module github.com/mroth/weightedrand/v2

go 1.10
go 1.18
30 changes: 19 additions & 11 deletions weightedrand.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,30 @@ import (
)

// Choice is a generic wrapper that can be used to add weights for any item.
type Choice struct {
Item interface{}
Weight uint
type Choice[T any, W integer] struct {
Item T
Weight W
}

type integer interface {
~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr
}

// NewChoice creates a new Choice with specified item and weight.
func NewChoice(item interface{}, weight uint) Choice {
return Choice{Item: item, Weight: weight}
func NewChoice[T any, W integer](item T, weight W) Choice[T, W] {
return Choice[T, W]{Item: item, Weight: weight}
}

// A Chooser caches many possible Choices in a structure designed to improve
// performance on repeated calls for weighted random selection.
type Chooser struct {
data []Choice
type Chooser[T any, W integer] struct {
data []Choice[T, W]
totals []int
max int
}

// NewChooser initializes a new Chooser for picking from the provided choices.
func NewChooser(choices ...Choice) (*Chooser, error) {
func NewChooser[T any, W integer](choices ...Choice[T, W]) (*Chooser[T, W], error) {
sort.Slice(choices, func(i, j int) bool {
return choices[i].Weight < choices[j].Weight
})
Expand All @@ -45,6 +49,10 @@ func NewChooser(choices ...Choice) (*Chooser, error) {
runningTotal := 0
for i, c := range choices {
weight := int(c.Weight)
if weight < 0 {
continue // ignore negative weights, can never be picked
}

if (maxInt - runningTotal) <= weight {
return nil, errWeightOverflow
}
Expand All @@ -56,7 +64,7 @@ func NewChooser(choices ...Choice) (*Chooser, error) {
return nil, errNoValidChoices
}

return &Chooser{data: choices, totals: totals, max: runningTotal}, nil
return &Chooser[T, W]{data: choices, totals: totals, max: runningTotal}, nil
}

const (
Expand All @@ -80,7 +88,7 @@ var (
// Pick returns a single weighted random Choice.Item from the Chooser.
//
// Utilizes global rand as the source of randomness.
func (c Chooser) Pick() interface{} {
func (c Chooser[T, W]) Pick() T {
r := rand.Intn(c.max) + 1
i := searchInts(c.totals, r)
return c.data[i].Item
Expand All @@ -95,7 +103,7 @@ func (c Chooser) Pick() interface{} {
//
// It is the responsibility of the caller to ensure the provided rand.Source is
// free from thread safety issues.
func (c Chooser) PickSource(rs *rand.Rand) interface{} {
func (c Chooser[T, W]) PickSource(rs *rand.Rand) T {
r := rs.Intn(c.max) + 1
i := searchInts(c.totals, r)
return c.data[i].Item
Expand Down
41 changes: 23 additions & 18 deletions weightedrand_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func Example() {
NewChoice('🍉', 0),
NewChoice('🥑', 42),
)
fruit := chooser.Pick().(rune)
fruit := chooser.Pick()
fmt.Printf("%c", fruit)
//Output: 🥑
}
Expand All @@ -45,32 +45,37 @@ const (
func TestNewChooser(t *testing.T) {
tests := []struct {
name string
cs []Choice
cs []Choice[rune, int]
wantErr error
}{
{
name: "zero choices",
cs: []Choice{},
cs: []Choice[rune, int]{},
wantErr: errNoValidChoices,
},
{
name: "no choices with positive weight",
cs: []Choice{{Item: 'a', Weight: 0}, {Item: 'b', Weight: 0}},
cs: []Choice[rune, int]{{Item: 'a', Weight: 0}, {Item: 'b', Weight: 0}},
wantErr: errNoValidChoices,
},
{
name: "choice with weight equals 1",
cs: []Choice{{Item: 'a', Weight: 1}},
cs: []Choice[rune, int]{{Item: 'a', Weight: 1}},
wantErr: nil,
},
{
name: "weight overflow",
cs: []Choice{{Item: 'a', Weight: maxInt/2 + 1}, {Item: 'b', Weight: maxInt/2 + 1}},
cs: []Choice[rune, int]{{Item: 'a', Weight: maxInt/2 + 1}, {Item: 'b', Weight: maxInt/2 + 1}},
wantErr: errWeightOverflow,
},
{
name: "nominal case",
cs: []Choice{{Item: 'a', Weight: 1}, {Item: 'b', Weight: 2}},
cs: []Choice[rune, int]{{Item: 'a', Weight: 1}, {Item: 'b', Weight: 2}},
wantErr: nil,
},
{
name: "negative weight case",
cs: []Choice[rune, int]{{Item: 'a', Weight: 3}, {Item: 'b', Weight: -2}},
wantErr: nil,
},
}
Expand Down Expand Up @@ -100,7 +105,7 @@ func TestChooser_Pick(t *testing.T) {
counts := make(map[int]int)
for i := 0; i < testIterations; i++ {
c := chooser.Pick()
counts[c.(int)]++
counts[c]++
}

verifyFrequencyCounts(t, counts, choices)
Expand All @@ -127,7 +132,7 @@ func TestChooser_PickSource(t *testing.T) {
rs := rand.New(rand.NewSource(time.Now().UTC().UnixNano()))
for i := 0; i < testIterations/2; i++ {
c := chooser.PickSource(rs)
counts[c.(int)]++
counts[c]++
}
}
go checker(counts1)
Expand All @@ -140,19 +145,19 @@ func TestChooser_PickSource(t *testing.T) {

// Similar to what is used in randutil test, but in randomized order to avoid
// any issues with algorithms that are accidentally dependant on presorted data.
func mockFrequencyChoices(t *testing.T, n int) []Choice {
func mockFrequencyChoices(t *testing.T, n int) []Choice[int, int] {
t.Helper()
choices := make([]Choice, 0, n)
choices := make([]Choice[int, int], 0, n)
list := rand.Perm(n)
for _, v := range list {
c := NewChoice(v, uint(v))
c := NewChoice(v, v)
choices = append(choices, c)
}
t.Log("mocked choices of", choices)
return choices
}

func verifyFrequencyCounts(t *testing.T, counts map[int]int, choices []Choice) {
func verifyFrequencyCounts(t *testing.T, counts map[int]int, choices []Choice[int, int]) {
t.Helper()

// Ensure weight 0 results in no results
Expand Down Expand Up @@ -202,7 +207,7 @@ func BenchmarkPick(b *testing.B) {
b.ResetTimer()

for i := 0; i < b.N; i++ {
_ = chooser.Pick().(rune)
_ = chooser.Pick()
}
})
}
Expand All @@ -220,19 +225,19 @@ func BenchmarkPickParallel(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
rs := rand.New(rand.NewSource(time.Now().UTC().UnixNano()))
for pb.Next() {
_ = chooser.PickSource(rs).(rune)
_ = chooser.PickSource(rs)
}
})
})
}
}

func mockChoices(n int) []Choice {
choices := make([]Choice, 0, n)
func mockChoices(n int) []Choice[rune, int] {
choices := make([]Choice[rune, int], 0, n)
for i := 0; i < n; i++ {
s := '🥑'
w := rand.Intn(10)
c := NewChoice(s, uint(w))
c := NewChoice(s, w)
choices = append(choices, c)
}
return choices
Expand Down

0 comments on commit 0892404

Please sign in to comment.