Skip to content

Commit

Permalink
feat(compiler): Inline all primitives (#1076)
Browse files Browse the repository at this point in the history
* feat(compiler): Inline all primitives

* snapshots
  • Loading branch information
ospencer committed Dec 31, 2021
1 parent dcd0f9e commit c227130
Show file tree
Hide file tree
Showing 20 changed files with 486 additions and 877 deletions.
220 changes: 3 additions & 217 deletions compiler/src/middle_end/analyze_inline_wasm.re
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
// This module still exists to support the --no-bulk-memory flag.

open Anftree;
open Anf_iterator;
open Grain_typed;
open Grain_utils;

type inline_type =
| WasmPrim1(prim1)
| WasmPrim2(prim2)
| WasmPrimN(primn);

type analysis +=
Expand Down Expand Up @@ -37,205 +37,15 @@ let get_primitive = prim =>
Option.bind(prim, prim => {
Translprim.(
switch (PrimMap.find_opt(prim_map, prim)) {
| Some(Primitive1(prim)) => Some(WasmPrim1(prim))
| Some(Primitive2(prim)) => Some(WasmPrim2(prim))
| Some(PrimitiveN(prim)) => Some(WasmPrimN(prim))
| None => None
| _ => None
}
)
});

let get_primitive_i32 = id => {
let prim =
switch (id) {
| "fromGrain" => Some("@wasm.fromGrain")
| "toGrain" => Some("@wasm.toGrain")
| "load" => Some("@wasm.load_int32")
| "load8S" => Some("@wasm.load_8_s_int32")
| "load8U" => Some("@wasm.load_8_u_int32")
| "load16S" => Some("@wasm.load_16_s_int32")
| "load16U" => Some("@wasm.load_16_u_int32")
| "store" => Some("@wasm.store_int32")
| "store8" => Some("@wasm.store_8_int32")
| "store16" => Some("@wasm.store_16_int32")
| "clz" => Some("@wasm.clz_int32")
| "ctz" => Some("@wasm.ctz_int32")
| "popcnt" => Some("@wasm.popcnt_int32")
| "eqz" => Some("@wasm.eq_z_int32")
| "add" => Some("@wasm.add_int32")
| "sub" => Some("@wasm.sub_int32")
| "mul" => Some("@wasm.mul_int32")
| "divS" => Some("@wasm.div_s_int32")
| "divU" => Some("@wasm.div_u_int32")
| "remS" => Some("@wasm.rem_s_int32")
| "remU" => Some("@wasm.rem_u_int32")
| "and" => Some("@wasm.and_int32")
| "or" => Some("@wasm.or_int32")
| "xor" => Some("@wasm.xor_int32")
| "shl" => Some("@wasm.shl_int32")
| "shrU" => Some("@wasm.shr_u_int32")
| "shrS" => Some("@wasm.shr_s_int32")
| "rotl" => Some("@wasm.rot_l_int32")
| "rotr" => Some("@wasm.rot_r_int32")
| "eq" => Some("@wasm.eq_int32")
| "ne" => Some("@wasm.ne_int32")
| "ltS" => Some("@wasm.lt_s_int32")
| "ltU" => Some("@wasm.lt_u_int32")
| "leS" => Some("@wasm.le_s_int32")
| "leU" => Some("@wasm.le_u_int32")
| "gtS" => Some("@wasm.gt_s_int32")
| "gtU" => Some("@wasm.gt_u_int32")
| "geS" => Some("@wasm.ge_s_int32")
| "geU" => Some("@wasm.ge_u_int32")
| "wrapI64" => Some("@wasm.wrap_int64")
| "truncF32S" => Some("@wasm.trunc_s_float32_to_int32")
| "truncF32U" => Some("@wasm.trunc_u_float32_to_int32")
| "truncF64S" => Some("@wasm.trunc_s_float64_to_int32")
| "truncF64U" => Some("@wasm.trunc_u_float64_to_int32")
| "reinterpretF32" => Some("@wasm.reinterpret_float32")
| "extendS8" => Some("@wasm.extend_s8_int32")
| "extendS16" => Some("@wasm.extend_s16_int32")
| _ => None
};

get_primitive(prim);
};
let get_primitive_i64 = id => {
let prim =
switch (id) {
| "load" => Some("@wasm.load_int64")
| "load8S" => Some("@wasm.load_8_s_int64")
| "load8U" => Some("@wasm.load_8_u_int64")
| "load16S" => Some("@wasm.load_16_s_int64")
| "load16U" => Some("@wasm.load_16_u_int64")
| "load32S" => Some("@wasm.load_32_s_int64")
| "load32U" => Some("@wasm.load_32_u_int64")
| "store" => Some("@wasm.store_int64")
| "store8" => Some("@wasm.store_8_int64")
| "store16" => Some("@wasm.store_16_int64")
| "store32" => Some("@wasm.store_32_int64")
| "clz" => Some("@wasm.clz_int64")
| "ctz" => Some("@wasm.ctz_int64")
| "popcnt" => Some("@wasm.popcnt_int64")
| "eqz" => Some("@wasm.eq_z_int64")
| "add" => Some("@wasm.add_int64")
| "sub" => Some("@wasm.sub_int64")
| "mul" => Some("@wasm.mul_int64")
| "divS" => Some("@wasm.div_s_int64")
| "divU" => Some("@wasm.div_u_int64")
| "remS" => Some("@wasm.rem_s_int64")
| "remU" => Some("@wasm.rem_u_int64")
| "and" => Some("@wasm.and_int64")
| "or" => Some("@wasm.or_int64")
| "xor" => Some("@wasm.xor_int64")
| "shl" => Some("@wasm.shl_int64")
| "shrU" => Some("@wasm.shr_u_int64")
| "shrS" => Some("@wasm.shr_s_int64")
| "rotl" => Some("@wasm.rot_l_int64")
| "rotr" => Some("@wasm.rot_r_int64")
| "eq" => Some("@wasm.eq_int64")
| "ne" => Some("@wasm.ne_int64")
| "ltS" => Some("@wasm.lt_s_int64")
| "ltU" => Some("@wasm.lt_u_int64")
| "leS" => Some("@wasm.le_s_int64")
| "leU" => Some("@wasm.le_u_int64")
| "gtS" => Some("@wasm.gt_s_int64")
| "gtU" => Some("@wasm.gt_u_int64")
| "geS" => Some("@wasm.ge_s_int64")
| "geU" => Some("@wasm.ge_u_int64")
| "extendI32S" => Some("@wasm.extend_s_int32")
| "extendI32U" => Some("@wasm.extend_u_int32")
| "truncF32S" => Some("@wasm.trunc_s_float32_to_int64")
| "truncF32U" => Some("@wasm.trunc_u_float32_to_int64")
| "truncF64S" => Some("@wasm.trunc_s_float64_to_int64")
| "truncF64U" => Some("@wasm.trunc_u_float64_to_int64")
| "reinterpretF64" => Some("@wasm.reinterpret_float64")
| "extendS8" => Some("@wasm.extend_s8_int64")
| "extendS16" => Some("@wasm.extend_s16_int64")
| "extendS32" => Some("@wasm.extend_s32_int64")
| _ => None
};

get_primitive(prim);
};
let get_primitive_f32 = id => {
let prim =
switch (id) {
| "load" => Some("@wasm.load_float32")
| "store" => Some("@wasm.store_float32")
| "neg" => Some("@wasm.neg_float32")
| "abs" => Some("@wasm.abs_float32")
| "ceil" => Some("@wasm.ceil_float32")
| "floor" => Some("@wasm.floor_float32")
| "trunc" => Some("@wasm.trunc_float32")
| "nearest" => Some("@wasm.nearest_float32")
| "sqrt" => Some("@wasm.sqrt_float32")
| "add" => Some("@wasm.add_float32")
| "sub" => Some("@wasm.sub_float32")
| "mul" => Some("@wasm.mul_float32")
| "div" => Some("@wasm.div_float32")
| "copySign" => Some("@wasm.copy_sign_float32")
| "min" => Some("@wasm.min_float32")
| "max" => Some("@wasm.max_float32")
| "eq" => Some("@wasm.eq_float32")
| "ne" => Some("@wasm.ne_float32")
| "lt" => Some("@wasm.lt_float32")
| "le" => Some("@wasm.le_float32")
| "gt" => Some("@wasm.gt_float32")
| "ge" => Some("@wasm.ge_float32")
| "reinterpretI32" => Some("@wasm.reinterpret_int32")
| "convertI32S" => Some("@wasm.convert_s_int32_to_float32")
| "convertI32U" => Some("@wasm.convert_u_int32_to_float32")
| "convertI64S" => Some("@wasm.convert_s_int64_to_float32")
| "convertI64U" => Some("@wasm.convert_u_int64_to_float32")
| "demoteF64" => Some("@wasm.demote_float64")
| _ => None
};

get_primitive(prim);
};
let get_primitive_f64 = id => {
let prim =
switch (id) {
| "load" => Some("@wasm.load_float64")
| "store" => Some("@wasm.store_float64")
| "neg" => Some("@wasm.neg_float64")
| "abs" => Some("@wasm.abs_float64")
| "ceil" => Some("@wasm.ceil_float64")
| "floor" => Some("@wasm.floor_float64")
| "trunc" => Some("@wasm.trunc_float64")
| "nearest" => Some("@wasm.nearest_float64")
| "sqrt" => Some("@wasm.sqrt_float64")
| "add" => Some("@wasm.add_float64")
| "sub" => Some("@wasm.sub_float64")
| "mul" => Some("@wasm.mul_float64")
| "div" => Some("@wasm.div_float64")
| "copySign" => Some("@wasm.copy_sign_float64")
| "min" => Some("@wasm.min_float64")
| "max" => Some("@wasm.max_float64")
| "eq" => Some("@wasm.eq_float64")
| "ne" => Some("@wasm.ne_float64")
| "lt" => Some("@wasm.lt_float64")
| "le" => Some("@wasm.le_float64")
| "gt" => Some("@wasm.gt_float64")
| "ge" => Some("@wasm.ge_float64")
| "reinterpretI64" => Some("@wasm.reinterpret_int64")
| "convertI32S" => Some("@wasm.convert_s_int32_to_float64")
| "convertI32U" => Some("@wasm.convert_u_int32_to_float64")
| "convertI64S" => Some("@wasm.convert_s_int64_to_float64")
| "convertI64U" => Some("@wasm.convert_u_int64_to_float64")
| "promoteF32" => Some("@wasm.promote_float32")
| _ => None
};

get_primitive(prim);
};
let get_primitive_memory = id => {
let prim =
switch (id) {
| "grow" => Some("@wasm.memory_grow")
| "size" => Some("@wasm.memory_size")
| "compare" => Some("@wasm.memory_compare")
| "copy" when Config.bulk_memory^ => Some("@wasm.memory_copy")
| "fill" when Config.bulk_memory^ => Some("@wasm.memory_fill")
| _ => None
Expand All @@ -249,30 +59,6 @@ let analyze = ({imports, body, analyses}) => {
mod_has_inlineable_wasm := false;
let process_import = ({imp_use_id, imp_desc}) => {
switch (imp_desc) {
| GrainValue("runtime/unsafe/wasmi32", name) =>
mod_has_inlineable_wasm := true;
switch (get_primitive_i32(name)) {
| Some(prim) => set_inlineable_wasm(imp_use_id, prim)
| None => ()
};
| GrainValue("runtime/unsafe/wasmi64", name) =>
mod_has_inlineable_wasm := true;
switch (get_primitive_i64(name)) {
| Some(prim) => set_inlineable_wasm(imp_use_id, prim)
| None => ()
};
| GrainValue("runtime/unsafe/wasmf32", name) =>
mod_has_inlineable_wasm := true;
switch (get_primitive_f32(name)) {
| Some(prim) => set_inlineable_wasm(imp_use_id, prim)
| None => ()
};
| GrainValue("runtime/unsafe/wasmf64", name) =>
mod_has_inlineable_wasm := true;
switch (get_primitive_f64(name)) {
| Some(prim) => set_inlineable_wasm(imp_use_id, prim)
| None => ()
};
| GrainValue("runtime/unsafe/memory", name) =>
mod_has_inlineable_wasm := true;
switch (get_primitive_memory(name)) {
Expand Down
2 changes: 0 additions & 2 deletions compiler/src/middle_end/analyze_inline_wasm.rei
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ open Anftree;
open Grain_typed;

type inline_type =
| WasmPrim1(prim1)
| WasmPrim2(prim2)
| WasmPrimN(primn);

let mod_has_inlineable_wasm: ref(bool);
Expand Down
62 changes: 27 additions & 35 deletions compiler/src/middle_end/linearize.re
Original file line number Diff line number Diff line change
Expand Up @@ -426,21 +426,19 @@ let rec transl_imm =
) =>
let (ans, ans_setup) = transl_comp_expression(e);
(Imm.trap(~loc, ~env, ()), ans_setup @ [BSeq(ans)]);
| TExpApp(
{exp_desc: TExpIdent(_, _, {val_kind: TValPrim("@and")})},
[arg1, arg2],
) =>
transl_imm({...e, exp_desc: TExpPrim2(And, arg1, arg2)})
| TExpApp(
{exp_desc: TExpIdent(_, _, {val_kind: TValPrim("@or")})},
[arg1, arg2],
) =>
transl_imm({...e, exp_desc: TExpPrim2(Or, arg1, arg2)})
| TExpApp(
{exp_desc: TExpIdent(_, _, {val_kind: TValPrim("@not")})},
[arg],
) =>
transl_imm({...e, exp_desc: TExpPrim1(Not, arg)})
| TExpApp({exp_desc: TExpIdent(_, _, {val_kind: TValPrim(prim)})}, args) =>
Translprim.(
switch (PrimMap.find_opt(prim_map, prim), args) {
| (Some(Primitive1(prim)), [arg]) =>
transl_imm({...e, exp_desc: TExpPrim1(prim, arg)})
| (Some(Primitive2(prim)), [arg1, arg2]) =>
transl_imm({...e, exp_desc: TExpPrim2(prim, arg1, arg2)})
| (Some(PrimitiveN(prim)), args) =>
transl_imm({...e, exp_desc: TExpPrimN(prim, args)})
| (Some(_), _) => failwith("transl_imm: invalid primitive arity")
| (None, _) => failwith("transl_imm: unknown primitive")
}
)
| TExpApp(func, args) =>
let tmp = gensym("app");
let (new_func, func_setup) = transl_imm(func);
Expand Down Expand Up @@ -1072,26 +1070,20 @@ and transl_comp_expression =
Comp.imm(~attributes, ~allocation_type, ~env, Imm.trap(~loc, ~env, ())),
ans_setup @ [BSeq(ans)],
);
| TExpApp(
{exp_desc: TExpIdent(_, _, {val_kind: TValPrim("@assert")})},
[arg],
) =>
transl_comp_expression({...e, exp_desc: TExpPrim1(Assert, arg)})
| TExpApp(
{exp_desc: TExpIdent(_, _, {val_kind: TValPrim("@and")})},
[arg1, arg2],
) =>
transl_comp_expression({...e, exp_desc: TExpPrim2(And, arg1, arg2)})
| TExpApp(
{exp_desc: TExpIdent(_, _, {val_kind: TValPrim("@or")})},
[arg1, arg2],
) =>
transl_comp_expression({...e, exp_desc: TExpPrim2(Or, arg1, arg2)})
| TExpApp(
{exp_desc: TExpIdent(_, _, {val_kind: TValPrim("@not")})},
[arg],
) =>
transl_comp_expression({...e, exp_desc: TExpPrim1(Not, arg)})
| TExpApp({exp_desc: TExpIdent(_, _, {val_kind: TValPrim(prim)})}, args) =>
Translprim.(
switch (PrimMap.find_opt(prim_map, prim), args) {
| (Some(Primitive1(prim)), [arg]) =>
transl_comp_expression({...e, exp_desc: TExpPrim1(prim, arg)})
| (Some(Primitive2(prim)), [arg1, arg2]) =>
transl_comp_expression({...e, exp_desc: TExpPrim2(prim, arg1, arg2)})
| (Some(PrimitiveN(prim)), args) =>
transl_comp_expression({...e, exp_desc: TExpPrimN(prim, args)})
| (Some(_), _) =>
failwith("transl_comp_expression: invalid primitive arity")
| (None, _) => failwith("transl_comp_expression: unknown primitive")
}
)
| TExpApp(func, args) =>
let (new_func, func_setup) = transl_imm(func);
let (new_args, new_setup) = List.split(List.map(transl_imm, args));
Expand Down
19 changes: 2 additions & 17 deletions compiler/src/middle_end/optimize_inline_wasm.re
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
// This module still exists to support the --no-bulk-memory flag.

open Anftree;
open Grain_typed;
open Analyze_inline_wasm;
Expand All @@ -7,28 +9,11 @@ module IWArg: Anf_mapper.MapArgument = {

let leave_comp_expression = ({comp_desc: desc} as c) => {
switch (desc) {
| CApp(({imm_desc: ImmId(id)}, _), [arg1], _)
when has_inline_wasm_type(id) =>
let prim1 =
switch (get_inline_wasm_type(id)) {
| WasmPrim1(prim1) => prim1
| _ => failwith("internal: WasmPrim1 was not found")
};
{...c, comp_desc: CPrim1(prim1, arg1)};
| CApp(({imm_desc: ImmId(id)}, _), [arg1, arg2], _)
when has_inline_wasm_type(id) =>
let prim2 =
switch (get_inline_wasm_type(id)) {
| WasmPrim2(prim2) => prim2
| _ => failwith("internal: WasmPrim2 was not found")
};
{...c, comp_desc: CPrim2(prim2, arg1, arg2)};
| CApp(({imm_desc: ImmId(id)}, _), args, _)
when has_inline_wasm_type(id) =>
let primn =
switch (get_inline_wasm_type(id)) {
| WasmPrimN(primn) => primn
| _ => failwith("internal: WasmPrimN was not found")
};
{...c, comp_desc: CPrimN(primn, args)};
| _ => c
Expand Down

0 comments on commit c227130

Please sign in to comment.