Skip to content

Commit

Permalink
fix: use saturating casts in lean_float_to_uint8 to avoid UB
Browse files Browse the repository at this point in the history
  • Loading branch information
digama0 authored and leodemoura committed Aug 11, 2022
1 parent 9ac4cf9 commit d8c6c82
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 46 deletions.
41 changes: 27 additions & 14 deletions src/include/lean/lean.h
Expand Up @@ -1826,20 +1826,33 @@ static inline uint64_t lean_name_hash(b_lean_obj_arg n) {
}

/* float primitives */
static inline uint8_t lean_float_to_uint8(double a) { return (uint8_t)a; }
static inline uint16_t lean_float_to_uint16(double a) { return (uint16_t)a; }
static inline uint32_t lean_float_to_uint32(double a) { return (uint32_t)a; }
static inline uint64_t lean_float_to_uint64(double a) { return (uint64_t)a; }
static inline size_t lean_float_to_usize(double a) { return (size_t)a; }
static inline double lean_float_add(double a, double b) { return a + b; }
static inline double lean_float_sub(double a, double b) { return a - b; }
static inline double lean_float_mul(double a, double b) { return a * b; }
static inline double lean_float_div(double a, double b) { return a / b; }
static inline double lean_float_negate(double a) { return -a; }
static inline uint8_t lean_float_beq(double a, double b) { return a == b; }
static inline uint8_t lean_float_decLe(double a, double b) { return a <= b; }
static inline uint8_t lean_float_decLt(double a, double b) { return a < b; }
static inline double lean_uint64_to_float(uint64_t a) { return (double) a; }
static inline uint8_t lean_float_to_uint8(double a) {
return 0. <= a ? (a < 256. ? (uint8_t)a : UINT8_MAX) : 0;
}
static inline uint16_t lean_float_to_uint16(double a) {
return 0. <= a ? (a < 65536. ? (uint16_t)a : UINT16_MAX) : 0;
}
static inline uint32_t lean_float_to_uint32(double a) {
return 0. <= a ? (a < 4294967296. ? (uint32_t)a : UINT32_MAX) : 0;
}
static inline uint64_t lean_float_to_uint64(double a) {
return 0. <= a ? (a < 18446744073709551616. ? (uint64_t)a : UINT64_MAX) : 0;
}
static inline size_t lean_float_to_usize(double a) {
if (sizeof(size_t) == sizeof(uint64_t)) // NOLINT
return (size_t) lean_float_to_uint64(a); // NOLINT
else
return (size_t) lean_float_to_uint32(a); // NOLINT
}
static inline double lean_float_add(double a, double b) { return a + b; }
static inline double lean_float_sub(double a, double b) { return a - b; }
static inline double lean_float_mul(double a, double b) { return a * b; }
static inline double lean_float_div(double a, double b) { return a / b; }
static inline double lean_float_negate(double a) { return -a; }
static inline uint8_t lean_float_beq(double a, double b) { return a == b; }
static inline uint8_t lean_float_decLe(double a, double b) { return a <= b; }
static inline uint8_t lean_float_decLt(double a, double b) { return a < b; }
static inline double lean_uint64_to_float(uint64_t a) { return (double) a; }

#ifdef __cplusplus
}
Expand Down
72 changes: 40 additions & 32 deletions tests/compiler/float.lean
@@ -1,45 +1,53 @@

def tst1 : IO Unit := do
IO.println (1 : Float);
IO.println ((1 : Float) + 2);
IO.println ((2 : Float) - 3);
IO.println ((3 : Float) * 2);
IO.println ((3 : Float) / 2);
IO.println (decide ((3 : Float) < 2));
IO.println (decide ((3 : Float) < 4));
IO.println ((3 : Float) == 2);
IO.println ((2 : Float) == 2);
IO.println (decide ((3 : Float) ≤ 2));
IO.println (decide ((3 : Float) ≤ 3));
IO.println (decide ((3 : Float) ≤ 4));
IO.println (Float.ofInt 0)
IO.println (Float.ofInt 42)
IO.println (Float.ofInt (-42))
pure ()
IO.println (1 : Float)
IO.println ((1 : Float) + 2)
IO.println ((2 : Float) - 3)
IO.println ((3 : Float) * 2)
IO.println ((3 : Float) / 2)
IO.println (decide ((3 : Float) < 2))
IO.println (decide ((3 : Float) < 4))
IO.println ((3 : Float) == 2)
IO.println ((2 : Float) == 2)
IO.println (decide ((3 : Float) ≤ 2))
IO.println (decide ((3 : Float) ≤ 3))
IO.println (decide ((3 : Float) ≤ 4))
IO.println (Float.ofInt 0)
IO.println (Float.ofInt 42)
IO.println (Float.ofInt (-42))
IO.println (0 / 0 : Float).toUInt8
IO.println (0 / 0 : Float).toUInt16
IO.println (0 / 0 : Float).toUInt32
IO.println (0 / 0 : Float).toUInt64
IO.println (-1 : Float).toUInt8
IO.println (256 : Float).toUInt8
IO.println (1 / 0 : Float).toUInt8
IO.println (-1 : Float).toUInt64
IO.println (2^64 : Float).toUInt64
IO.println (1 / 0 : Float).toUInt64

structure Foo :=
(x : Nat)
(w : UInt64)
(y : Float)
(z : Float)
structure Foo where
x : Nat
w : UInt64
y : Float
z : Float

@[noinline] def mkFoo (x : Nat) : Foo :=
{ x := x, w := x.toUInt64, y := x.toFloat / 3, z := x.toFloat / 2 }
{ x := x, w := x.toUInt64, y := x.toFloat / 3, z := x.toFloat / 2 }

def tst2 (x : Nat) : IO Unit := do
let foo := mkFoo x;
IO.println foo.y;
IO.println foo.z
let foo := mkFoo x
IO.println foo.y
IO.println foo.z

@[noinline] def fMap (f : Float → Float) (xs : List Float) :=
xs.map f
xs.map f

def tst3 (xs : List Float) (y : Float) : IO Unit :=
IO.println (fMap (fun x => x / y) xs)
IO.println (fMap (fun x => x / y) xs)

def main : IO Unit := do
tst1;
IO.println "-----";
tst2 7;
tst3 [3, 4, 7, 8, 9, 11] 2;
pure ()
tst1
IO.println "-----"
tst2 7
tst3 [3, 4, 7, 8, 9, 11] 2
10 changes: 10 additions & 0 deletions tests/compiler/float.lean.expected.out
Expand Up @@ -13,6 +13,16 @@ true
0.000000
42.000000
-42.000000
0
0
0
0
0
255
255
0
18446744073709551615
18446744073709551615
-----
2.333333
3.500000
Expand Down

0 comments on commit d8c6c82

Please sign in to comment.