Skip to content

Commit

Permalink
Improvement to Trie's implementation (#2777)
Browse files Browse the repository at this point in the history
* Move trie implementation to new file
* Add benchmarks for trie impl
* Remove return val `error` from `trieNode.insert()`
* Add unit tests for trie impl
* Switch to iterative trie impl
* Remove lowercase op in `trieNode.contains()`
  • Loading branch information
eugercek committed Nov 28, 2022
1 parent 24f9871 commit 616e5e2
Show file tree
Hide file tree
Showing 4 changed files with 201 additions and 50 deletions.
54 changes: 4 additions & 50 deletions lib/types/hostnametrie.go
Expand Up @@ -108,6 +108,7 @@ func NewHostnameTrie(source []string) (*HostnameTrie, error) {
// Regex description of hostname pattern to enforce blocks by. Global var
// to avoid compilation penalty at runtime.
// based on regex from https://stackoverflow.com/a/106223/5427244
//
//nolint:lll
var validHostnamePattern *regexp.Regexp = regexp.MustCompile(`^(\*\.?)?((([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9]))?$`)

Expand All @@ -126,60 +127,13 @@ func (t *HostnameTrie) insert(s string) error {
return err
}

return t.trieNode.insert(s)
t.trieNode.insert(s)
return nil
}

// Contains returns whether s matches a pattern in the HostnameTrie
// along with the matching pattern, if one was found.
func (t *HostnameTrie) Contains(s string) (matchedPattern string, matchFound bool) {
return t.trieNode.contains(s)
}

type trieNode struct {
isLeaf bool
children map[rune]*trieNode
}

func (t *trieNode) insert(s string) error {
if len(s) == 0 {
t.isLeaf = true
return nil
}

// mask creation of the trie by initializing the root here
if t.children == nil {
t.children = make(map[rune]*trieNode)
}

rStr := []rune(s) // need to iterate by runes for intl' names
last := len(rStr) - 1
if c, ok := t.children[rStr[last]]; ok {
return c.insert(string(rStr[:last]))
}

t.children[rStr[last]] = &trieNode{children: make(map[rune]*trieNode)}
return t.children[rStr[last]].insert(string(rStr[:last]))
}

func (t *trieNode) contains(s string) (matchedPattern string, matchFound bool) {
s = strings.ToLower(s)
if len(s) == 0 {
if t.isLeaf {
return "", true
}
} else {
rStr := []rune(s)
last := len(rStr) - 1
if c, ok := t.children[rStr[last]]; ok {
if match, matched := c.contains(string(rStr[:last])); matched {
return match + string(rStr[last]), true
}
}
}

if _, wild := t.children['*']; wild {
return "*", true
}

return "", false
return t.trieNode.contains(s)
}
74 changes: 74 additions & 0 deletions lib/types/trie.go
@@ -0,0 +1,74 @@
package types

import "strings"

type trieNode struct {
isLeaf bool
children map[rune]*trieNode
}

func (t *trieNode) insert(s string) {
runes := []rune(s)

if t.children == nil {
t.children = map[rune]*trieNode{}
}

ptr := t
for i := len(runes) - 1; i >= 0; i-- {
c, ok := ptr.children[runes[i]]

if !ok {
ptr.children[runes[i]] = &trieNode{children: map[rune]*trieNode{}}
c = ptr.children[runes[i]]
}

ptr = c
}

ptr.isLeaf = true
}

func (t *trieNode) contains(s string) (string, bool) {
rs := []rune(s)

builder, wMatch := strings.Builder{}, ""
found := true

ptr := t
for i := len(rs) - 1; i >= 0; i-- {
child, ok := ptr.children[rs[i]]

if _, wOk := ptr.children['*']; wOk {
wMatch = builder.String() + string('*')
}

if !ok {
found = false
break
}

builder.WriteRune(rs[i])
ptr = child
}

if found && ptr.isLeaf {
return reverseString(builder.String()), true
}

if _, ok := ptr.children['*']; ok {
builder.WriteRune('*')
return reverseString(builder.String()), true
}

return reverseString(wMatch), wMatch != ""
}

func reverseString(s string) string {
rs := []rune(s)
for i, j := 0, len(rs)-1; i < len(rs)/2; i, j = i+1, j-1 {
rs[i], rs[j] = rs[j], rs[i]
}

return string(rs)
}
120 changes: 120 additions & 0 deletions lib/types/trie_test.go
@@ -0,0 +1,120 @@
package types

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestTrieInsert(t *testing.T) {
t.Parallel()

root := trieNode{}

const val = "k6.io"
root.insert(val)

ptr := &root
for i, rs := len(val)-1, []rune(val); i >= 0; i-- {
next, ok := ptr.children[rs[i]]
require.True(t, ok)
ptr = next
}

require.True(t, ptr.isLeaf)
}

func TestTrieContains(t *testing.T) {
t.Parallel()

root := trieNode{}
root.insert("k6.io")
root.insert("specific.k6.io")
root.insert("*.k6.io")

tcs := []struct {
query, expVal string
found bool
}{
// Trie functionality
{query: "k6.io", expVal: "k6.io", found: true},
{query: "io", expVal: "", found: false},
{query: "no.k6.no.io", expVal: "", found: false},
{query: "specific.k6.io", expVal: "specific.k6.io", found: true},
{query: "", expVal: "", found: false},
{query: "long.long.long.long.long.long.long.long.no.match", expVal: "", found: false},
{query: "pre.matching.long.long.long.long.test.k6.noio", expVal: "", found: false},

// Wildcard
{query: "foo.k6.io", expVal: "*.k6.io", found: true},
{query: "specific.k6.io", expVal: "specific.k6.io", found: true},
{query: "not.specific.k6.io", expVal: "*.k6.io", found: true},
}

for _, tc := range tcs {
tc := tc
t.Run(tc.query, func(t *testing.T) {
t.Parallel()

val, ok := root.contains(tc.query)

require.Equal(t, tc.found, ok)
require.Equal(t, tc.expVal, val)
})
}
}

func TestReverseString(t *testing.T) {
t.Parallel()

tcs := []struct{ str, rev string }{
{str: "even", rev: "neve"},
{str: "odd", rev: "ddo"},
{str: "", rev: ""},
}

for _, tc := range tcs {
tc := tc

t.Run(tc.str, func(t *testing.T) {
t.Parallel()
val := reverseString(tc.str)

require.Equal(t, tc.rev, val)
})
}
}

func BenchmarkTrieInsert(b *testing.B) {
arr := []string{
"k6.io", "*.sub.k6.io", "specific.sub.k6.io",
"grafana.com", "*.sub.sub.grafana.com", "test.sub.sub.grafana.com",
}
b.ResetTimer()

for i := 0; i < b.N; i++ {
root := trieNode{}
for _, v := range arr {
root.insert(v)
}
}
}

func BenchmarkTrieContains(b *testing.B) {
root := trieNode{}
arr := []string{
"k6.io", "*.sub.k6.io", "specific.sub.k6.io",
"grafana.com", "*.sub.sub.grafana.com", "test.sub.sub.grafana.com",
}

for _, v := range arr {
root.insert(v)
}
b.ResetTimer()

for i := 0; i < b.N; i++ {
for _, v := range arr {
root.contains(v)
}
}
}
3 changes: 3 additions & 0 deletions lib/types/types.go
@@ -1,3 +1,6 @@
// Package types contains types used in the codebase
// Most of the types have a Null prefix like gopkg.in/guregu/null.v3
// and UnmarshalJSON and MarshalJSON methods.
package types

import (
Expand Down

0 comments on commit 616e5e2

Please sign in to comment.