Skip to content

Commit

Permalink
all: imp code
Browse files Browse the repository at this point in the history
  • Loading branch information
schzhn committed May 6, 2024
1 parent 3f23c9a commit 54212e9
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 13 deletions.
38 changes: 34 additions & 4 deletions internal/client/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@ import (
"fmt"
"net"
"net/netip"
"slices"
"strings"

"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/golibs/errors"
"golang.org/x/exp/maps"
)

// macKey contains MAC as byte array of 6, 8, or 20 bytes.
Expand Down Expand Up @@ -91,8 +94,13 @@ func (ci *Index) Add(c *Persistent) {

// ClashesUID returns existing persistent client with the same UID as c. Note
// that this is only possible when configuration contains duplicate fields.
func (ci *Index) ClashesUID(c *Persistent) (existing *Persistent) {
return ci.uidToClient[c.UID]
func (ci *Index) ClashesUID(c *Persistent) (err error) {
p, ok := ci.uidToClient[c.UID]
if ok {
return fmt.Errorf("another client %q uses the same uid", p.Name)
}

return nil
}

// Clashes returns an error if the index contains a different persistent client
Expand Down Expand Up @@ -329,15 +337,37 @@ func (ci *Index) Range(f func(c *Persistent) (cont bool)) {
}
}

// SortedRange is like [Index.Range] but sorts the keys before iterating
// ensuring a predictable order.
func (ci *Index) SortedRange(
s func(a, b *Persistent) (n int),
f func(c *Persistent) (cont bool),
) {
cs := maps.Values(ci.uidToClient)
slices.SortFunc(cs, s)

for _, c := range cs {
if !f(c) {
break
}
}
}

// CloseUpstreams closes upstream configurations of persistent clients.
func (ci *Index) CloseUpstreams() (err error) {
sortFunc := func(a, b *Persistent) (n int) {
return strings.Compare(a.Name, b.Name)
}

var errs []error
for _, c := range ci.uidToClient {
ci.SortedRange(sortFunc, func(c *Persistent) (cont bool) {
err = c.CloseUpstreams()
if err != nil {
errs = append(errs, err)
}
}

return true
})

return errors.Join(errs...)
}
24 changes: 22 additions & 2 deletions internal/client/index_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package client
import (
"net"
"net/netip"
"slices"
"strings"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -71,13 +73,14 @@ func TestClientIndex(t *testing.T) {
}
)

ci := newIDIndex([]*Persistent{
clients := []*Persistent{
clientWithBothFams,
clientWithSubnet,
clientWithMAC,
clientWithID,
clientLinkLocal,
})
}
ci := newIDIndex(clients)

testCases := []struct {
want *Persistent
Expand Down Expand Up @@ -120,6 +123,23 @@ func TestClientIndex(t *testing.T) {
_, ok := ci.Find(cliIPNone)
assert.False(t, ok)
})

t.Run("sorted_range", func(t *testing.T) {
sortFunc := func(a, b *Persistent) (n int) {
return strings.Compare(a.Name, b.Name)
}

slices.SortFunc(clients, sortFunc)

got := []*Persistent{}
ci.SortedRange(sortFunc, func(c *Persistent) (cont bool) {
got = append(got, c)

return true
})

assert.Equal(t, clients, got)
})
}

func TestClientIndex_Clashes(t *testing.T) {
Expand Down
11 changes: 4 additions & 7 deletions internal/home/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,13 +286,10 @@ func (clients *clientsContainer) addFromConfig(
return fmt.Errorf("clients: init persistent client at index %d: %w", i, err)
}

if p := clients.clientIndex.ClashesUID(cli); p != nil {
return fmt.Errorf(
"clients: client %s at index %d has duplicate uid as %s",
cli.Name,
i,
p.Name,
)
// TODO(s.chzhen): Consider moving to the client index constructor.
err = clients.clientIndex.ClashesUID(cli)
if err != nil {
return fmt.Errorf("adding client %s at index %d: %w", cli.Name, i, err)
}

err = clients.add(cli)
Expand Down

0 comments on commit 54212e9

Please sign in to comment.