Skip to content

Commit

Permalink
Add property checks for Nat (#1175)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnynek committed Mar 17, 2024
1 parent 39adf19 commit 96163d4
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 24 deletions.
77 changes: 56 additions & 21 deletions test_workspace/Nat.bosatsu
@@ -1,37 +1,72 @@
package Bosatsu/Nat

from Bosatsu/Predef import add as operator +, times as operator *
export Nat(), times2, add, mult, exp, to_Int, to_Nat, is_even, div2
export Nat(), times2, add, sub_Nat, mult, exp, to_Int, to_Nat, is_even, div2, cmp_Nat

# This is the traditional encoding of natural numbers
# it is useful when you are iterating on all values
enum Nat: Zero, Succ(prev: Nat)

def cmp_Nat(a: Nat, b: Nat) -> Comparison:
recur a:
case Zero:
match b:
case Zero: EQ
case _: LT
case Succ(n):
match b:
case Zero: GT
case Succ(m): cmp_Nat(n, m)

# This is an O(n) operation
def times2(n: Nat) -> Nat:
recur n:
Zero: Zero
Succ(prev):
# 2*(n + 1) = 2*n + 1 + 1
Succ(Succ(times2(prev)))
def loop(n: Nat, acc: Nat):
recur n:
Zero: acc
Succ(prev):
# 2*(n + 1) = 2*n + 1 + 1
loop(prev, Succ(Succ(acc)))
loop(n, Zero)

def add(n1: Nat, n2: Nat) -> Nat:
recur n1:
Zero: n2
Succ(prev_n1):
match n2:
Zero: n1
Succ(prev_n2): Succ(Succ(add(prev_n1, prev_n2)))

# (n1 + 1) * (n2 + 1) = n1 * n2 + n1 + n2 + 1
def loop(n1: Nat, n2: Nat):
recur n1:
Zero: n2
Succ(prev_n1): loop(prev_n1, Succ(n2))

if n2 matches Zero: n1
else: loop(n1, n2)

def sub_Nat(n1: Nat, n2: Nat) -> Nat:
recur n2:
Zero: n1
Succ(prev_n2):
# (n1 + 1) - (n2 + 1) == (n1 - n2)
match n1:
case Zero: Zero
case Succ(prev_n1):
sub_Nat(prev_n1, prev_n2)

def mult(n1: Nat, n2: Nat) -> Nat:
recur n1:
Zero: Zero
Succ(n1):
match n2:
Zero: Zero
Succ(n2):
Succ(mult(n1, n2).add(add(n1, n2)))
# return n1 * n2 + c
# note: (n1 + 1) * (n2 + 1) + c ==
# n1 * n2 + (n1 + n2 + 1 + c)
def loop(n1: Nat, n2: Nat, c: Nat):
recur n1:
Zero: c
Succ(n1):
match n2:
Zero: c
Succ(n2):
c1 = Succ(add(add(n1, n2), c))
loop(n1, n2, c1)

# we repeatedly do add(n, m)
# where we keep stepping down by one on both sides.
# add is more efficient if the lhs is smaller
# so we check that first
if cmp_Nat(n1, n2) matches GT: loop(n2, n1, Zero)
else: loop(n1, n2, Zero)

def is_even(n: Nat) -> Bool:
def loop(n: Nat, res: Bool) -> Bool:
Expand Down
111 changes: 111 additions & 0 deletions test_workspace/NumberProps.bosatsu
@@ -0,0 +1,111 @@
package Bosatsu/NumberProps

from Bosatsu/BinNat import (BinNat, toBinNat as int_to_BinNat)
from Bosatsu/Nat import (Nat, Zero as NZero, Succ as NSucc, to_Nat as int_to_Nat, is_even as is_even_Nat,
times2 as times2_Nat, div2 as div2_Nat, cmp_Nat, to_Int as nat_to_Int, add as add_Nat,
mult as mult_Nat, exp as exp_Nat, sub_Nat)
from Bosatsu/Properties import (Prop, suite_Prop, forall_Prop, run_Prop)
from Bosatsu/Rand import (Rand, from_pair, geometric_Int, int_range, map_Rand, prod_Rand)

export (rand_Int, rand_Nat, rand_BinNat)

# Property checks for Nat, BinNat, Int

#external def todo(ignore: x) -> forall a. a

rand_Int: Rand[Int] = from_pair(int_range(128), geometric_Int)
rand_Nat: Rand[Nat] = rand_Int.map_Rand(int_to_Nat)
rand_BinNat: Rand[BinNat] = rand_Int.map_Rand(int_to_BinNat)

eq_Int = (a, b) -> a.cmp_Int(b) matches EQ

int_props = suite_Prop(
"Int props",
[
forall_Prop(prod_Rand(rand_Int, rand_Int), "divmod law", ((a, b)) -> (
adivb = a.div(b)
amodb = a.mod_Int(b)
a1 = adivb.times(b).add(amodb)
Assertion(eq_Int(a1, a), "check")
)),
]
)

def cmp_Comparison(c1: Comparison, c2: Comparison) -> Comparison:
match c1:
case LT:
match c2:
case LT: EQ
case _: LT
case EQ:
match c2:
case LT: GT
case EQ: EQ
case GT: LT
case GT:
match c2:
case GT: EQ
case _: GT

def exp_Int(base: Int, power: Int) -> Int:
int_loop(power, 1, (p, acc) -> (p.sub(1), acc.times(base)))

small_rand_Nat: Rand[Nat] = int_range(7).map_Rand(int_to_Nat)

nat_props = suite_Prop(
"Nat props",
[
forall_Prop(rand_Nat, "if is_even(n) then times2(div2(n)) == n", n -> (
if is_even_Nat(n):
n1 = times2_Nat(div2_Nat(n))
Assertion(cmp_Nat(n1, n) matches EQ, "times2/div2")
else:
# we return the previous number
n1 = times2_Nat(div2_Nat(n))
Assertion(cmp_Nat(NSucc(n1), n) matches EQ, "times2/div2")
)),
forall_Prop(prod_Rand(rand_Nat, rand_Nat), "cmp_Nat matches cmp_Int", ((n1, n2)) -> (
cmp_n = cmp_Nat(n1, n2)
cmp_i = cmp_Int(n1.nat_to_Int(), n2.nat_to_Int())
Assertion(cmp_Comparison(cmp_n, cmp_i) matches EQ, "cmp_Nat")
)),
forall_Prop(prod_Rand(rand_Nat, rand_Nat), "add homomorphism", ((n1, n2)) -> (
n3 = add_Nat(n1, n2)
i3 = add(n1.nat_to_Int(), n2.nat_to_Int())
Assertion(cmp_Int(n3.nat_to_Int(), i3) matches EQ, "add homomorphism")
)),
forall_Prop(prod_Rand(rand_Nat, rand_Nat), "sub_Nat homomorphism", ((n1, n2)) -> (
n3 = sub_Nat(n1, n2)
i1 = n1.nat_to_Int()
i2 = n2.nat_to_Int()
match cmp_Int(i1, i2):
case EQ | GT:
i3 = sub(i1, i2)
Assertion(cmp_Int(n3.nat_to_Int(), i3) matches EQ, "sub_Nat homomorphism")
case LT:
Assertion(n3 matches NZero, "sub to zero")
)),
forall_Prop(prod_Rand(rand_Nat, rand_Nat), "mult homomorphism", ((n1, n2)) -> (
n3 = mult_Nat(n1, n2)
i3 = times(n1.nat_to_Int(), n2.nat_to_Int())
Assertion(cmp_Int(n3.nat_to_Int(), i3) matches EQ, "mult homomorphism")
)),
forall_Prop(prod_Rand(small_rand_Nat, small_rand_Nat), "exp homomorphism", ((n1, n2)) -> (
n3 = exp_Nat(n1, n2)
i3 = exp_Int(n1.nat_to_Int(), n2.nat_to_Int())
Assertion(cmp_Int(n3.nat_to_Int(), i3) matches EQ, "exp homomorphism")
)),
forall_Prop(rand_Nat, "times2 == x -> mult(2, x)", n -> (
t2 = n.times2_Nat()
t2_2 = mult_Nat(n, NSucc(NSucc(NZero)))
Assertion(cmp_Nat(t2, t2_2) matches EQ, "times2 == mult(2, _)")
)),
]
)

all_props = [int_props, nat_props]

seed = 123456
test = TestSuite("properties", [
run_Prop(p, 100, seed) for p in all_props
])
4 changes: 2 additions & 2 deletions test_workspace/Properties.bosatsu
Expand Up @@ -73,5 +73,5 @@ all_props = suite_Prop(
shift_unshift_law,
positive_and_law,
])

all_laws = run_Prop(all_props, 100, 42)
test = run_Prop(all_props, 100, 42)
6 changes: 5 additions & 1 deletion test_workspace/Rand.bosatsu
Expand Up @@ -7,7 +7,7 @@ from Bosatsu/BinNat import (BinNat, Zero as BZero, Odd, Even,
next as next_BinNat, prev as prev_BinNat, toInt as binNat_to_Int)

export (Rand, run_Rand, prod_Rand, map_Rand, flat_map_Rand, const_Rand,
int_range, sequence_Rand, bool_Rand, geometric_Int, from_pair, one_of)
int_range, nat_range, sequence_Rand, bool_Rand, geometric_Int, from_pair, one_of)

struct State(s0: Int, s1: Int, s2: Int, s3: Int)
struct UInt64(toInt: Int)
Expand Down Expand Up @@ -171,6 +171,9 @@ def int_range(high: Int) -> Rand[Int]:
resample(rand_Int, high, uint_count)
else: const0

def nat_range(high: Nat) -> Rand[Nat]:
int_range(nat_to_Int(high)).map_Rand(int_to_Nat)

def geometric(depth: Nat, acc: Int) -> Rand[Int]:
recur depth:
case Zero: const_Rand(acc)
Expand All @@ -182,6 +185,7 @@ def geometric(depth: Nat, acc: Int) -> Rand[Int]:
case False: geometric(prev, acc + 1)
))

# geometric distribution with mean 1
geometric_Int: Rand[Int] = geometric(nat30, 0)

def len[a](list: List[a]) -> BinNat:
Expand Down

0 comments on commit 96163d4

Please sign in to comment.