Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: use saturating casts in lean_float_to_uint8 to avoid UB #1458

Merged
merged 1 commit into from Aug 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
41 changes: 27 additions & 14 deletions src/include/lean/lean.h
Expand Up @@ -1826,20 +1826,33 @@ static inline uint64_t lean_name_hash(b_lean_obj_arg n) {
}

/* float primitives */
static inline uint8_t lean_float_to_uint8(double a) { return (uint8_t)a; }
static inline uint16_t lean_float_to_uint16(double a) { return (uint16_t)a; }
static inline uint32_t lean_float_to_uint32(double a) { return (uint32_t)a; }
static inline uint64_t lean_float_to_uint64(double a) { return (uint64_t)a; }
static inline size_t lean_float_to_usize(double a) { return (size_t)a; }
static inline double lean_float_add(double a, double b) { return a + b; }
static inline double lean_float_sub(double a, double b) { return a - b; }
static inline double lean_float_mul(double a, double b) { return a * b; }
static inline double lean_float_div(double a, double b) { return a / b; }
static inline double lean_float_negate(double a) { return -a; }
static inline uint8_t lean_float_beq(double a, double b) { return a == b; }
static inline uint8_t lean_float_decLe(double a, double b) { return a <= b; }
static inline uint8_t lean_float_decLt(double a, double b) { return a < b; }
static inline double lean_uint64_to_float(uint64_t a) { return (double) a; }
static inline uint8_t lean_float_to_uint8(double a) {
return 0. <= a ? (a < 256. ? (uint8_t)a : UINT8_MAX) : 0;
}
static inline uint16_t lean_float_to_uint16(double a) {
return 0. <= a ? (a < 65536. ? (uint16_t)a : UINT16_MAX) : 0;
}
static inline uint32_t lean_float_to_uint32(double a) {
return 0. <= a ? (a < 4294967296. ? (uint32_t)a : UINT32_MAX) : 0;
}
static inline uint64_t lean_float_to_uint64(double a) {
return 0. <= a ? (a < 18446744073709551616. ? (uint64_t)a : UINT64_MAX) : 0;
}
static inline size_t lean_float_to_usize(double a) {
if (sizeof(size_t) == sizeof(uint64_t)) // NOLINT
return (size_t) lean_float_to_uint64(a); // NOLINT
else
return (size_t) lean_float_to_uint32(a); // NOLINT
}
static inline double lean_float_add(double a, double b) { return a + b; }
static inline double lean_float_sub(double a, double b) { return a - b; }
static inline double lean_float_mul(double a, double b) { return a * b; }
static inline double lean_float_div(double a, double b) { return a / b; }
static inline double lean_float_negate(double a) { return -a; }
static inline uint8_t lean_float_beq(double a, double b) { return a == b; }
static inline uint8_t lean_float_decLe(double a, double b) { return a <= b; }
static inline uint8_t lean_float_decLt(double a, double b) { return a < b; }
static inline double lean_uint64_to_float(uint64_t a) { return (double) a; }

#ifdef __cplusplus
}
Expand Down
72 changes: 40 additions & 32 deletions tests/compiler/float.lean
@@ -1,45 +1,53 @@

def tst1 : IO Unit := do
IO.println (1 : Float);
IO.println ((1 : Float) + 2);
IO.println ((2 : Float) - 3);
IO.println ((3 : Float) * 2);
IO.println ((3 : Float) / 2);
IO.println (decide ((3 : Float) < 2));
IO.println (decide ((3 : Float) < 4));
IO.println ((3 : Float) == 2);
IO.println ((2 : Float) == 2);
IO.println (decide ((3 : Float) ≤ 2));
IO.println (decide ((3 : Float) ≤ 3));
IO.println (decide ((3 : Float) ≤ 4));
IO.println (Float.ofInt 0)
IO.println (Float.ofInt 42)
IO.println (Float.ofInt (-42))
pure ()
IO.println (1 : Float)
IO.println ((1 : Float) + 2)
IO.println ((2 : Float) - 3)
IO.println ((3 : Float) * 2)
IO.println ((3 : Float) / 2)
IO.println (decide ((3 : Float) < 2))
IO.println (decide ((3 : Float) < 4))
IO.println ((3 : Float) == 2)
IO.println ((2 : Float) == 2)
IO.println (decide ((3 : Float) ≤ 2))
IO.println (decide ((3 : Float) ≤ 3))
IO.println (decide ((3 : Float) ≤ 4))
IO.println (Float.ofInt 0)
IO.println (Float.ofInt 42)
IO.println (Float.ofInt (-42))
IO.println (0 / 0 : Float).toUInt8
IO.println (0 / 0 : Float).toUInt16
IO.println (0 / 0 : Float).toUInt32
IO.println (0 / 0 : Float).toUInt64
IO.println (-1 : Float).toUInt8
IO.println (256 : Float).toUInt8
IO.println (1 / 0 : Float).toUInt8
IO.println (-1 : Float).toUInt64
IO.println (2^64 : Float).toUInt64
IO.println (1 / 0 : Float).toUInt64

structure Foo :=
(x : Nat)
(w : UInt64)
(y : Float)
(z : Float)
structure Foo where
x : Nat
w : UInt64
y : Float
z : Float

@[noinline] def mkFoo (x : Nat) : Foo :=
{ x := x, w := x.toUInt64, y := x.toFloat / 3, z := x.toFloat / 2 }
{ x := x, w := x.toUInt64, y := x.toFloat / 3, z := x.toFloat / 2 }

def tst2 (x : Nat) : IO Unit := do
let foo := mkFoo x;
IO.println foo.y;
IO.println foo.z
let foo := mkFoo x
IO.println foo.y
IO.println foo.z

@[noinline] def fMap (f : Float → Float) (xs : List Float) :=
xs.map f
xs.map f

def tst3 (xs : List Float) (y : Float) : IO Unit :=
IO.println (fMap (fun x => x / y) xs)
IO.println (fMap (fun x => x / y) xs)

def main : IO Unit := do
tst1;
IO.println "-----";
tst2 7;
tst3 [3, 4, 7, 8, 9, 11] 2;
pure ()
tst1
IO.println "-----"
tst2 7
tst3 [3, 4, 7, 8, 9, 11] 2
10 changes: 10 additions & 0 deletions tests/compiler/float.lean.expected.out
Expand Up @@ -13,6 +13,16 @@ true
0.000000
42.000000
-42.000000
0
0
0
0
0
255
255
0
18446744073709551615
18446744073709551615
-----
2.333333
3.500000
Expand Down