New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add property checks for BinNat #1176
Changes from 4 commits
08e8a73
8992a98
a808d21
62104d5
fb3e341
d32020f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,8 @@ package Bosatsu/BinNat | |
|
||
from Bosatsu/Nat import Nat, Zero as NatZero, Succ as NatSucc, times2 as times2_Nat | ||
|
||
export BinNat(), toInt, toNat, toBinNat, next, add_BinNat, times2, div2, prev | ||
export (BinNat(), toInt, toNat, toBinNat, next, add_BinNat, times2, div2, | ||
prev, times_BinNat, exp, cmp_BinNat, is_even, sub_BinNat, eq_BinNat) | ||
# a natural number with three variants: | ||
# Zero = 0 | ||
# Odd(n) = 2n + 1 | ||
|
@@ -11,6 +12,9 @@ export BinNat(), toInt, toNat, toBinNat, next, add_BinNat, times2, div2, prev | |
# Zero, Odd(Zero), Even(Zero), Odd(Odd(Zero)), Even(Odd(Zero)) | ||
enum BinNat: Zero, Odd(half: BinNat), Even(half1: BinNat) | ||
|
||
def is_even(b: BinNat) -> Bool: | ||
b matches Zero | Even(_) | ||
|
||
# Convert a BinNat into the equivalent Int | ||
# this is O(log(b)) operation | ||
def toInt(b: BinNat) -> Int: | ||
|
@@ -37,7 +41,51 @@ def toBinNat(n: Int) -> BinNat: | |
(dec(n), fns) | ||
) | ||
# Now apply all the transformations | ||
fns.foldLeft(Zero, \n, fn -> fn(n)) | ||
fns.foldLeft(Zero, (n, fn) -> fn(n)) | ||
|
||
def cmp_BinNat(a: BinNat, b: BinNat) -> Comparison: | ||
recur a: | ||
case Zero: | ||
match b: | ||
case Odd(_) | Even(_): LT | ||
case Zero: EQ | ||
case Odd(a1): | ||
match b: | ||
case Odd(b1): cmp_BinNat(a1, b1) | ||
case Even(b1): | ||
# 2n + 1 <> 2m + 2 | ||
# if n <= m, LT | ||
# if n > m GT | ||
match cmp_BinNat(a1, b1): | ||
case LT | EQ: LT | ||
case GT: GT | ||
case Zero: GT | ||
case Even(a1): | ||
match b: | ||
case Even(b1): cmp_BinNat(a1, b1) | ||
case Odd(b1): | ||
# 2n + 2 <> 2m + 1 | ||
# if n >= m, GT | ||
# if n < m LT | ||
match cmp_BinNat(a1, b1): | ||
case GT | EQ: GT | ||
case LT: LT | ||
case Zero: GT | ||
|
||
# this is more efficient potentially than cmp_BinNat | ||
# because at the first difference we can stop. In the worst | ||
# case of equality, the cost is the same. | ||
def eq_BinNat(a: BinNat, b: BinNat) -> Bool: | ||
recur a: | ||
case Zero: b matches Zero | ||
case Odd(n): | ||
match b: | ||
case Odd(m): eq_BinNat(n, m) | ||
case _: False | ||
case Even(n): | ||
match b: | ||
case Even(m): eq_BinNat(n, m) | ||
case _: False | ||
|
||
# Return the next number | ||
def next(b: BinNat) -> BinNat: | ||
|
@@ -53,14 +101,11 @@ def next(b: BinNat) -> BinNat: | |
# Return the previous number if the number is > 0, else return 0 | ||
def prev(b: BinNat) -> BinNat: | ||
recur b: | ||
Zero: Zero | ||
Odd(Zero): | ||
# This breaks the law below because 0 - 1 = 0 in this function | ||
Zero | ||
Odd(half): | ||
case Zero | Odd(Zero): Zero | ||
case Odd(half): | ||
# (2n + 1) - 1 = 2n = 2(n-1 + 1) | ||
Even(prev(half)) | ||
Even(half1): | ||
case Even(half1): | ||
# 2(n + 1) - 1 = 2n + 1 | ||
Odd(half1) | ||
|
||
|
@@ -97,11 +142,51 @@ def times2(b: BinNat) -> BinNat: | |
#2(2(n + 1)) = 2((2n + 1) + 1) | ||
Even(Odd(n)) | ||
|
||
def sub_BinNat(left: BinNat, right: BinNat) -> BinNat: | ||
# invariant: left > right | ||
# we always return >= 1 | ||
def loop(left, right): | ||
recur left: | ||
case Odd(left) as odd: | ||
match right: | ||
case Zero: odd | ||
case Odd(right): | ||
# (2n + 1) - (2m + 1) = 2(n - m) | ||
times2(loop(left, right)) | ||
case Even(right): | ||
# (2n + 1) - (2m + 2) = 2(n - m) - 1 | ||
# note if (2n + 1) > (2m + 2), then n > m | ||
times2(loop(left, right)).prev() | ||
case Even(left) as even: | ||
match right: | ||
case Zero: even | ||
case Odd(right): | ||
# (2n + 2) - (2m + 1) | ||
# if n >= m: 2(n - m) + 1 | ||
# if (2n + 2) > (2m + 1), then n > m | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is wrong... if (2n + 2) > (2m + 1) then n >= m. So, that means the invariant must actually be left >= right. So that means we only know that (2n + 2) >= (2m + 1), but that's okay, because we know an odd number can't equal an even number, so we recover n >= m. |
||
Odd(loop(left, right)) | ||
case Even(right): | ||
# (2n + 2) - (2m + 2) = 2(n - m) | ||
times2(loop(left, right)) | ||
case Zero: | ||
# this is last because it is actually unreachable | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this isn't true. The invariant is left >= right |
||
# due to the invariant that left > right | ||
# but the type system can't express this | ||
# we put it last to improve performance | ||
Zero | ||
|
||
if cmp_BinNat(left, right) matches (LT | EQ): Zero | ||
else: loop(left, right) | ||
|
||
def div2(b: BinNat) -> BinNat: | ||
match b: | ||
case Zero: Zero | ||
case Odd(n): n | ||
case Even(n): prev(n) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this was a bug! yay for property checks! |
||
case Odd(n): | ||
# (2n + 1)/2 = n | ||
n | ||
case Even(n): | ||
# (2n + 2)/2 = n + 1 | ||
next(n) | ||
|
||
# multiply two BinNat together | ||
def times_BinNat(left: BinNat, right: BinNat) -> BinNat: | ||
|
@@ -122,6 +207,21 @@ def times_BinNat(left: BinNat, right: BinNat) -> BinNat: | |
prod = times_BinNat(left, right) | ||
times2(prod.add_BinNat(right)) | ||
|
||
one = Odd(Zero) | ||
|
||
def exp(base: BinNat, power: BinNat) -> BinNat: | ||
recur power: | ||
case Zero: one | ||
case Odd(n): | ||
# b^(2n + 1) == (b^n) * (b^n) * b | ||
bn = exp(base, n) | ||
bn.times_BinNat(bn).times_BinNat(base) | ||
case Even(n): | ||
# b^(2n + 2) = (b^n * b)^2 | ||
bn = exp(base, n) | ||
bn1 = bn.times_BinNat(base) | ||
bn1.times_BinNat(bn1) | ||
|
||
# fold(fn, a, Zero) = a | ||
# fold(fn, a, n) = fold(fn, fn(a, n - 1), n - 1) | ||
def fold_left_BinNat(fn: (a, BinNat) -> a, init: a, cnt: BinNat) -> a: | ||
|
@@ -158,7 +258,6 @@ def next_law(i, msg): | |
def times2_law(i, msg): | ||
Assertion(i.toBinNat().times2().toInt().eq_Int(i.times(2)), msg) | ||
|
||
one = Odd(Zero) | ||
two = one.next() | ||
three = two.next() | ||
four = three.next() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
see below: invariant: left >= right, so the comment below is wrong, we can return Zero.