-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
72 lines (65 loc) · 2.51 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import unittest
import random
import time
from flatbush import FlatBush, Box
class TestFlatBush(unittest.TestCase):
    """Tests for the ``FlatBush`` spatial index.

    ``test_validate`` cross-checks ``FlatBush.search`` against a brute-force
    scan over every inserted box; ``test_benchmark`` times bulk insertion and
    many small-window queries and prints the timings.
    """

    # Fixed seed so that a failing random query window can be reproduced.
    RANDOM_SEED = 12345
    # Side length of the square grid of boxes built by test_validate.
    VALIDATE_DIM = 100
    # Number of random query windows issued by test_validate.
    VALIDATE_QUERIES = 1000
    # Largest width/height of a random query window in test_validate.
    MAX_QUERY_WINDOW = 5
    # Grid side length and number of queries used by test_benchmark.
    BENCH_DIM = 500
    BENCH_QUERIES = 100 * 1000

    def test_validate(self):
        """Check that ``search`` never misses a box found by brute force.

        A ``dim`` x ``dim`` grid of boxes is inserted, each box spanning
        ``(x + 0.1, y + 0.1)``-``(x + 0.9, y + 0.9)``; ``add`` must return
        sequential indices starting at 0. Random query windows are then
        issued, and every box for which ``Box.positive_union`` reports a hit
        against the query box must appear in the search result. Only false
        negatives are checked; false positives are not.
        """
        # unittest assertions (not bare ``assert``) so the checks still run
        # under ``python -O``; a local seeded RNG keeps failures reproducible.
        rng = random.Random(self.RANDOM_SEED)
        f = FlatBush()
        dim = self.VALIDATE_DIM
        boxes = []
        for x in range(dim):
            for y in range(dim):
                index = len(boxes)
                # Keep our own copy of the box for brute force validation
                boxes.append(Box(index, x + 0.1, y + 0.1, x + 0.9, y + 0.9))
                check_index = f.add(x + 0.1, y + 0.1, x + 0.9, y + 0.9)
                self.assertEqual(check_index, index,
                                 "add() should return sequential indices")
        f.finish()

        n_checked = 0
        for _ in range(self.VALIDATE_QUERIES):
            minx = rng.uniform(0, dim)
            miny = rng.uniform(0, dim)
            maxx = minx + rng.uniform(0, self.MAX_QUERY_WINDOW)
            maxy = miny + rng.uniform(0, self.MAX_QUERY_WINDOW)
            # Build the set once per query so each membership test is O(1).
            # Assumes search() yields hashable indices (the original compared
            # them with == against int Box.index values) — TODO confirm.
            found = set(f.search(minx, miny, maxx, maxy))
            qbox = Box(0, minx, miny, maxx, maxy)
            for box in boxes:
                if box.positive_union(qbox):
                    # If the object crosses the query rectangle, it must be
                    # included in the result set.
                    n_checked += 1
                    self.assertIn(
                        box.index, found,
                        f"Failed to find box in result set "
                        f"(query {minx}, {miny}, {maxx}, {maxy})")
        self.assertGreater(n_checked, 0,
                           "Unit test didn't actually check anything")

    def test_benchmark(self):
        """Time bulk insertion and many small-window queries; print results.

        This test makes no assertions; it only reports timings on stdout.
        Each query is a 3 x 3 window whose corner advances by 5 in x and 7
        in y (modulo the grid size) on every iteration.
        """
        f = FlatBush()
        dim = self.BENCH_DIM
        start = time.perf_counter()
        for x in range(dim):
            for y in range(dim):
                f.add(x + 0.1, y + 0.1, x + 0.9, y + 0.9)
        f.finish()
        elapsed = time.perf_counter() - start
        print(f"Time to insert {dim * dim} elements: {1000 * elapsed} milliseconds")

        nquery = self.BENCH_QUERIES
        sx = 0
        sy = 0
        nresults = 0
        start = time.perf_counter()
        for _ in range(nquery):
            minx = sx % dim
            miny = sy % dim
            nresults += len(f.search(minx, miny, minx + 3.0, miny + 3.0))
            # x and y advance by different strides, so successive windows
            # are not confined to the grid diagonal.
            sx += 5
            sy += 7
        elapsed = time.perf_counter() - start
        print(f"Time per query returning average of {round(nresults / nquery)} elements: {(1000000.0 / nquery) * elapsed} microseconds")
if __name__ == "__main__":
    # Script entry point: hand off to the unittest command-line runner,
    # which discovers and runs the TestCase classes in this module.
    unittest.main()