/
Nat.bosatsu
125 lines (104 loc) · 3.2 KB
/
Nat.bosatsu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
package Bosatsu/Nat
from Bosatsu/Predef import add as operator +, times as operator *
export Nat(), times2, add, mult, exp, to_Int, to_Nat, is_even, div2
# Nat is the traditional unary (Peano) encoding of the natural numbers:
# Zero is 0 and Succ(prev) is prev + 1. It is convenient when you need to
# step through every value from a number down to Zero, e.g. with `recur`.
enum Nat: Zero, Succ(prev: Nat)
# Doubles n. Each Succ layer of the input yields two layers in the output,
# so this is an O(n) operation.
def times2(n: Nat) -> Nat:
  recur n:
    case Zero: Zero
    case Succ(smaller):
      # 2 * (smaller + 1) = 2 * smaller + 2
      Succ(Succ(times2(smaller)))
# Adds two Nats. Both arguments lose one Succ per step, and the other
# argument is returned as soon as either side reaches Zero, so the
# recursion depth is bounded by the smaller input.
def add(n1: Nat, n2: Nat) -> Nat:
  recur n1:
    case Zero: n2
    case Succ(p1):
      match n2:
        case Zero: n1
        case Succ(p2):
          # (p1 + 1) + (p2 + 1) = (p1 + p2) + 2
          Succ(Succ(add(p1, p2)))
# Multiplies two Nats. One Succ is removed from each argument per step,
# using (m1 + 1) * (m2 + 1) = m1 * m2 + m1 + m2 + 1.
# Either argument being Zero makes the product Zero.
def mult(n1: Nat, n2: Nat) -> Nat:
  recur n1:
    case Zero: Zero
    case Succ(m1):
      match n2:
        case Zero: Zero
        case Succ(m2):
          # m1 * m2 + (m1 + m2), then the trailing + 1 as one more Succ
          Succ(add(mult(m1, m2), add(m1, m2)))
# Returns True when n is even. Walks n once, starting from True (0 is even)
# and flipping the flag for every Succ layer.
def is_even(n: Nat) -> Bool:
  def flip_loop(k: Nat, even: Bool) -> Bool:
    recur k:
      Zero: even
      Succ(smaller): flip_loop(smaller, even matches False)
  flip_loop(n, True)
# Computes floor(n / 2).
# A single pass over n carries the parity of the number of Succ layers
# peeled so far. One Succ is added to the result on every second layer.
# Tracking parity this way keeps div2 O(n). Calling is_even on the remainder
# at every step would cost O(n) per step, O(n^2) overall.
def div2(n: Nat) -> Nat:
  def loop(rest: Nat, odd: Bool, acc: Nat) -> Nat:
    recur rest:
      case Zero: acc
      case Succ(prev):
        # odd means an odd number of layers was peeled before this one, so
        # this layer completes a pair: (2k + 2) / 2 = k + 1
        if odd: loop(prev, False, Succ(acc))
        else: loop(prev, True, acc)
  loop(n, False, Zero)
# The Nat for 1; exp uses it as its starting accumulator and for the
# trivial cases.
one = Succ(Zero)

# Computes base ^ power.
# The special cases are 0 ^ 0 = 1, 0 ^ k = 0 for k > 0, and 1 ^ k = 1.
# For any larger base, an accumulator is multiplied by base once per Succ
# layer of power.
def exp(base: Nat, power: Nat) -> Nat:
  match base:
    case Zero:
      match power:
        case Zero: one
        case _: Zero
    case Succ(Zero): one
    case b:
      def pow_loop(p: Nat, prod: Nat) -> Nat:
        recur p:
          case Zero: prod
          case Succ(p_prev):
            # b ^ (k + 1) = (b ^ k) * b
            pow_loop(p_prev, mult(prod, b))
      pow_loop(power, one)
# Converts n to an Int by counting its Succ layers, starting from 0.
# `+` here is Predef Int addition (imported as operator +), not the Nat add.
def to_Int(n: Nat) -> Int:
  def count(total: Int, rest: Nat) -> Int:
    recur rest:
      case Zero: total
      case Succ(smaller): count(total + 1, smaller)
  count(0, n)
# Converts i to a Nat by wrapping Zero in one Succ per unit of i.
# Each step hands int_loop a counter one smaller than before.
# NOTE(review): the tests expect negative inputs to give Zero, which
# presumes int_loop stops immediately when the counter is not positive.
def to_Nat(i: Int) -> Nat:
  def step(remaining: Int, nat: Nat):
    (remaining.sub(1), Succ(nat))
  int_loop(i, Zero, step)
################
# Test code below
################
# Small Nat fixtures for the tests: nK is the Nat with K Succ layers.
n1 = Succ(Zero)
n2 = Succ(Succ(Zero))
n3 = Succ(Succ(Succ(Zero)))
n4 = Succ(Succ(Succ(Succ(Zero))))
n5 = Succ(Succ(Succ(Succ(Succ(Zero)))))
# Int equality used by the tests, defined in terms of cmp_Int.
def operator ==(i0: Int, i1: Int) -> Bool:
  match cmp_Int(i0, i1):
    case EQ: True
    case _: False
# Checks that Nat addition agrees with Int addition on (n1, n2).
def addLaw(n1: Nat, n2: Nat, label: String) -> Test:
  nat_sum = to_Int(add(n1, n2))
  int_sum = to_Int(n1) + to_Int(n2)
  Assertion(nat_sum == int_sum, label)
# Checks that Nat multiplication agrees with Int multiplication on (n1, n2).
def multLaw(n1: Nat, n2: Nat, label: String) -> Test:
  nat_product = to_Int(mult(n1, n2))
  int_product = to_Int(n1) * to_Int(n2)
  Assertion(nat_product == int_product, label)
# Checks that converting i to a Nat and back yields i again.
def from_to_law(i: Int, message: String) -> Test:
  round_trip = to_Int(to_Nat(i))
  Assertion(round_trip == i, message)
# Round-trip checks between Int and Nat. The suite asserts that negative
# Ints come back as 0, and that 0, 1, 10 and 42 round-trip unchanged.
from_to_suite = TestSuite("to_Nat/to_Int tests", [
  Assertion(to_Int(to_Nat(-1)) == 0, "-1 -> 0"),
  Assertion(to_Int(to_Nat(-42)) == 0, "-42 -> 0"),
  from_to_law(0, "0"),
  from_to_law(1, "1"),
  from_to_law(10, "10"),
  from_to_law(42, "42"),
])
# Top-level suite for this package. It covers the add/mult laws against
# Int arithmetic, the Int round-trip suite, and direct checks of every
# other exported function: exp (including its 0 and 1 base cases),
# times2, is_even and div2 (both parities).
tests = TestSuite("Nat tests",
  [
    addLaw(Zero, Zero, "0 + 0"),
    addLaw(Zero, n1, "0 + 1"),
    addLaw(n1, Zero, "1 + 0"),
    addLaw(n1, n2, "1 + 2"),
    addLaw(n2, n1, "2 + 1"),
    addLaw(n3, n4, "3 + 4"),
    multLaw(Zero, Zero, "0 * 0"),
    multLaw(Zero, n1, "0 * 1"),
    multLaw(n1, Zero, "1 * 0"),
    multLaw(n1, n2, "1 * 2"),
    multLaw(n2, n1, "2 * 1"),
    multLaw(n3, n4, "3 * 4"),
    from_to_suite,
    Assertion(exp(n2, n5).to_Int() matches 32, "exp(2, 5) == 32"),
    Assertion(exp(Zero, Zero).to_Int() == 1, "exp(0, 0) == 1"),
    Assertion(exp(Zero, n3).to_Int() == 0, "exp(0, 3) == 0"),
    Assertion(exp(n1, n5).to_Int() == 1, "exp(1, 5) == 1"),
    Assertion(exp(n3, n2).to_Int() == 9, "exp(3, 2) == 9"),
    Assertion(times2(Zero).to_Int() == 0, "times2(0) == 0"),
    Assertion(times2(n3).to_Int() == 6, "times2(3) == 6"),
    Assertion(is_even(Zero), "is_even(0)"),
    Assertion(is_even(n1) matches False, "is_even(1) is False"),
    Assertion(is_even(n4), "is_even(4)"),
    Assertion(div2(Zero).to_Int() == 0, "div2(0) == 0"),
    Assertion(div2(n1).to_Int() == 0, "div2(1) == 0"),
    Assertion(div2(n4).to_Int() == 2, "div2(4) == 2"),
    Assertion(div2(n5).to_Int() == 2, "div2(5) == 2"),
  ])