Skip to content

Commit

Permalink
WIP Removing hardcoded type size and circuits for operators
Browse files Browse the repository at this point in the history
  • Loading branch information
Gustavo2622 committed Apr 18, 2024
1 parent d3ec6d7 commit 278ee27
Show file tree
Hide file tree
Showing 9 changed files with 194 additions and 109 deletions.
3 changes: 2 additions & 1 deletion libs/lospecs/circuit_avx2.ml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ module FromSpec () : S = struct
spec

let func_from_spec (f: symbol) : (reg list) -> reg =
(fun regs -> Circuit_spec.circuit_of_spec regs (List.assoc f specs))
let fname = List.assoc f specs in
(fun regs -> Circuit_spec.circuit_of_spec regs fname)

(* ------------------------------------------------------------------ *)
let vpermd = List.assoc "VPERMD" specs
Expand Down
218 changes: 122 additions & 96 deletions src/ecBDep.ml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ module Map = Batteries.Map
module Hashtbl = Batteries.Hashtbl
module Set = Batteries.Set

(* CIRCUIT DEFINITIONS: *)
(* -------------------------------------------------------------------- *)
module C = struct
include Lospecs.Aig
Expand All @@ -31,6 +32,11 @@ let zpad (n: int) (r: C.reg) =
List.append r (List.init (n - (List.length r)) (fun _ -> C.false_))
else r

let symbol_of_qsymbol ((h,t): qsymbol) : symbol =
let v = List.fold_left (fun acc s -> Format.sprintf "%s.%s" acc s) (List.hd h) (List.tl h) in
Format.sprintf "%s.%s" v t

(* TYPE DEFINITIONS: *)
(* -------------------------------------------------------------------- *)
type width = int

Expand All @@ -46,14 +52,18 @@ and vsymbol =
and brhs =
| Const of width * zint
| Copy of vsymbol
| Op of symbol * bargs
| Op of qsymbol * bargs

and barg =
| Const of width * zint
| Var of vsymbol

and bargs = barg list

(* -------------------------------------------------------------------- *)
exception BDepError

(* PRETTY PRINTING: *)
(* -------------------------------------------------------------------- *)
let pp_barg (fmt : Format.formatter) (b : barg) =
match b with
Expand All @@ -73,7 +83,8 @@ let pp_brhs (fmt : Format.formatter) (rhs : brhs) =
Format.fprintf fmt "%s@%d" x w

| Op (op, args) ->
Format.fprintf fmt "%s%a"
Format.fprintf fmt "%a %a"
(fun fmt q -> List.iter (Format.fprintf fmt "%s.") (fst q); Format.fprintf fmt "%s\n" (snd q))
op
(Format.pp_print_list
(fun fmt a -> Format.fprintf fmt "@ %a" pp_barg a))
Expand All @@ -87,48 +98,7 @@ let pp_bstmt (fmt : Format.formatter) (((x, w), rhs) : bstmt) =
let pp_bprgm (fmt : Format.formatter) (bprgm : bprgm) =
List.iter (Format.fprintf fmt "%a;@." pp_bstmt) bprgm

(* -------------------------------------------------------------------- *)
let register_of_barg (env : env) (arg : barg) : C.reg =
match arg with
| Const (w, i) ->
C.of_bigint ~size:w (EcBigInt.to_zt i)

| Var (x, i) ->
Circ.lookup_circ x env

(* -------------------------------------------------------------------- *)
let registers_of_bargs (env : env) (args : bargs) : C.reg list =
List.map (register_of_barg env) args

(* -------------------------------------------------------------------- *)
let circuit_of_bstmt (env : env) (((v, s), rhs) : bstmt) : env * C.reg =
let r =
match rhs with
| Const (w, i) ->
C.of_bigint ~size:w (EcBigInt.to_zt i)

| Copy (x, w) -> Circ.lookup_circ x env

| Op (op, args) -> try
begin
match op with
| "OPP_8" -> C.opp (args |> registers_of_bargs env |> List.hd) (* FIXME: Needs to be in spec *)
| _ ->
args |> registers_of_bargs env |> (C.func_from_spec op)
end
with Not_found -> Format.eprintf "op %s not found@." op; assert false
in

let env = Circ.bind_circ v r env in

(env, r)

(* -------------------------------------------------------------------- *)
let circuit_from_bprgm (env: env) (prg : bprgm) =
List.fold_left_map circuit_of_bstmt env prg

(* -------------------------------------------------------------------- *)
(* FIXME : Fix printing later *)
(* NOT SO PRETTY PRINTING: *)

let print_deps ~name (env : env) (r : C.reg) =
let deps = C.deps r in
Expand Down Expand Up @@ -170,13 +140,78 @@ let print_deps_alt ~name (r : C.reg) =
deps
) deps


(* FIXME ^ fix above printing *)

(* -------------------------------------------------------------------- *)
let print_deps_ric (env : env) (r : symbol) =
let circ = Circ.lookup_circ r env in
print_deps env circ ~name:r



(* CITCUIT CONSTRUCTION AUXILIARY FUNCTIONS: *)
let trans_wtype (ty: ty) (env: env): width =
match (EcEnv.Ty.hnorm ty env).ty_node with
| Tconstr (p, []) -> begin
let q = EcPath.toqsymbol p in
match EcEnv.Circ.lookup_bitstring q env with
| Some w -> w
| None -> Format.eprintf "Unknown type: ";
List.iter (Format.eprintf "%s.") (fst q);
Format.eprintf "%s\n" (snd q);
raise BDepError
end
| _ -> Format.eprintf "Unsupported type variant\n"; raise BDepError


(* CIRCUIT CONSTRUCTION FUNCTIONS: *)
(* -------------------------------------------------------------------- *)
let register_of_barg (env : env) (arg : barg) : C.reg =
match arg with
| Const (w, i) ->
C.of_bigint ~size:w (EcBigInt.to_zt i)

| Var (x, i) ->
Circ.lookup_circ x env

(* -------------------------------------------------------------------- *)
let registers_of_bargs (env : env) (args : bargs) : C.reg list =
List.map (register_of_barg env) args

(* -------------------------------------------------------------------- *)
let circuit_of_bstmt (env : env) (((v, s), rhs) : bstmt) : env * C.reg =
let r =
match rhs with
| Const (w, i) ->
C.of_bigint ~size:w (EcBigInt.to_zt i)

| Copy (x, w) -> Circ.lookup_circ x env

| Op (op, args) ->
begin
match EcEnv.Circ.lookup_op op env with
| Some op -> args |> registers_of_bargs env |> op
| None -> Format.eprintf "Unregistered circuit for operator: %s@." (symbol_of_qsymbol op);
raise BDepError
end
(* try begin
match op with
| "OPP_8" -> C.opp (args |> registers_of_bargs env |> List.hd) (* FIXME: Needs to be in spec *)
| _ ->
args |> registers_of_bargs env |> (C.func_from_spec op)
end (* FIXME : add EcEnv op compatibility *)
with Not_found -> Format.eprintf "op %s not found@." op; assert false *)
in

let env = Circ.bind_circ v r env in

(env, r)

(* -------------------------------------------------------------------- *)
let circuit_from_bprgm (env: env) (prg : bprgm) =
List.fold_left_map circuit_of_bstmt env prg

(* -------------------------------------------------------------------- *)
let circ_dep_split (r : C.reg) : C.reg list =
let deps = C.deps r in
Expand Down Expand Up @@ -351,12 +386,11 @@ let bruteforce_equiv (r1 : C.reg) (r2 : C.reg) (range: int) : bool =
if res1 = res2 then true
else (Format.eprintf "i: %d | r1: %d | r2: %d@." i res1 res2; false)) |> Enum.fold (&&) true

(* -------------------------------------------------------------------- *)
exception BDepError

(* -------------------------------------------------------------------- *)
let decode_op (p : path) : symbol =
match EcPath.toqsymbol p with
let decode_op (p : path) : qsymbol =
EcPath.toqsymbol p
(* match EcPath.toqsymbol p with
| ["Top"; "JWord"; "W16u16"], ("VPSUB_16u16" as op)
| ["Top"; "JWord"; "W16u16"], ("VPSRA_16u16" as op)
| ["Top"; "JWord"; "W16u16"], ("VPADD_16u16" as op)
Expand Down Expand Up @@ -387,29 +421,14 @@ let decode_op (p : path) : symbol =
| _ ->
Format.eprintf "%s@." (EcPath.tostring p);
raise BDepError
raise BDepError *)



let rec circuit_of_form (env: env) (f : EcAst.form) : C.reg =
let trans_wtype (ty : ty) : width =
match (EcEnv.Ty.hnorm ty env).ty_node with
| Tconstr (p, []) -> begin
match EcPath.toqsymbol p with
| (["Top"; "JWord"; "W256"], "t") -> 256
| (["Top"; "JWord"; "W128"], "t") -> 128
| (["Top"; "JWord"; "W64" ], "t") -> 64
| (["Top"; "JWord"; "W32" ], "t") -> 32
| (["Top"; "JWord"; "W16" ], "t") -> 16
| (["Top"; "JWord"; "W8" ], "t") -> 8
| (["Top"; "Pervasive"], "int") -> 256
(* DEBUG PRINT V
| (qs, q) -> List.iter (Format.eprintf "%s ") qs; Format.eprintf "@. %s@." q; raise BDepError*)
| _ -> raise BDepError
end

| _ ->
raise BDepError in
trans_wtype ty env
in

let trans_jops (pth: qsymbol) : C.reg list -> C.reg =
(* TODO: Check if we need regs to be of correct size or not (semi-done) *)
Expand Down Expand Up @@ -496,9 +515,12 @@ let rec circuit_of_form (env: env) (f : EcAst.form) : C.reg =
| [a; b] -> [C.ugt b a]
| _ -> raise BDepError
end
| _ -> List.iter (Format.eprintf "%s ") (fst pth);
| q -> begin match EcEnv.Circ.lookup_op q env with
| Some op -> op
| None -> List.iter (Format.eprintf "%s ") (fst pth);
Format.eprintf "%s@.Not implemented yet@." (snd pth);
raise BDepError
end
in

match f.f_node with
Expand Down Expand Up @@ -582,26 +604,43 @@ and int_of_form (env: env) (f: EcAst.form) : int =
trans_jops (EcPath.toqsymbol pth) fs_c
| _ -> failwith "Cant apply to non op"
end

| _ -> failwith "Form cannot be converted to int"


(* EXPORTED FUNCTIONS *)
(* -------------------------------------------------------------------- *)
let bind_circuit (env: env) (op: psymbol) (c: string) : env =
(* check input types *)
(* check output types *)
(* match widths *)
let c = C.func_from_spec c in
assert false (* bind op to c *)
let c = begin
try C.func_from_spec c
with Not_found -> Format.eprintf "Unknown circuit\n"; raise BDepError
end in
let (op, t) = EcEnv.Op.lookup ([], op.pl_desc) env in
let fmt = EcPrinting.pp_type (EcPrinting.PPEnv.ofenv env) in
let rec uncurry_ty (ty: ty) : ty list =
match ty.ty_node with
| Tfun (t1, t2) -> (uncurry_ty t1) @ (uncurry_ty t2)
| _ -> [ty]
in
let tys = List.rev (uncurry_ty t.op_ty) in
let in_tys = (List.tl tys |> List.rev) in
let out_ty = (List.hd tys) in
begin
Format.eprintf "Out: %d bits@." (trans_wtype out_ty env);
Format.eprintf "In:@.";
List.iter (Format.eprintf "%d bits@.") (List.map (fun t -> trans_wtype t env) in_tys);
(* check input types *)
(* check output types *)
(* match widths *)
EcEnv.Circ.bind_op (EcPath.toqsymbol op) c env
end
(* update and return scope *)



(* Maybe change this to EcCommands *)
(* FIXME might need changing to be able to actually query it when needed *)
let bind_bitstring (env: env) (tq: pqsymbol) (w: width) : env =
assert false
(* add binding to env *)
(* update scope *)
(* return scope *)
let q = tq.pl_desc in
EcEnv.Circ.bind_bitstring q w env

(* -------------------------------------------------------------------- *)
let bdep (env : env) (p : pgamepath) (f: psymbol) (n : int) (m : int) (vs : string list) (pcond: psymbol) : unit =
Expand Down Expand Up @@ -654,23 +693,10 @@ let bdep (env : env) (p : pgamepath) (f: psymbol) (n : int) (m : int) (vs : stri
| (["Top"; "JWord"; "W8" ], "of_int") -> 8
| _ -> raise BDepError in

let trans_wtype (ty : ty) : width =
match (EcEnv.Ty.hnorm ty env).ty_node with
| Tconstr (p, []) -> begin
match EcPath.toqsymbol p with
| (["Top"; "JWord"; "W256"], "t") -> 256
| (["Top"; "JWord"; "W128"], "t") -> 128
| (["Top"; "JWord"; "W64" ], "t") -> 64
| (["Top"; "JWord"; "W32" ], "t") -> 32
| (["Top"; "JWord"; "W16" ], "t") -> 16
| (["Top"; "JWord"; "W8" ], "t") -> 8
| (["Top"; "Pervasive"], "int") -> 256
| _ -> raise BDepError
end

| _ ->
raise BDepError in

let trans_wtype (ty : ty) : width =
trans_wtype ty env
in

let trans_arg (e : expr) : barg =
match e.e_node with
| Evar (PVloc y) ->
Expand Down
8 changes: 5 additions & 3 deletions src/ecCommands.ml
Original file line number Diff line number Diff line change
Expand Up @@ -640,11 +640,13 @@ and process_bdep (scope : EcScope.scope) ((p, f, n, m, vs, pc) : pgamepath * psy

and process_bind_bitstring (scope : EcScope.scope) (tq: pqsymbol) (w: int) =
let env = EcBDep.bind_bitstring (EcScope.env scope) tq w
in assert false (* update and return scope *)

in EcScope.Circ.update_env env scope
and process_bind_circuit (scope: EcScope.scope) (op: psymbol) (c: string) =
let env = EcBDep.bind_circuit (EcScope.env scope) op c in
{scope with sc_env=EcScope.initial env} (* update and return scope, how? *)
EcScope.Circ.update_env env scope


(* -------------------------------------------------------------------- *)
and process (ld : Loader.loader) (scope : EcScope.scope) g =
let loc = g.pl_loc in
Expand Down

0 comments on commit 278ee27

Please sign in to comment.