From 9a65fa2e8f1ea4e8e1a5a1da374735240518907c Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Sat, 22 Apr 2023 15:32:26 -0400 Subject: [PATCH] [Pass] Support inline function (#186) * . * . --- python/hidet/backend/codegen.py | 34 +--- python/hidet/ir/dialects/pattern.py | 7 +- python/hidet/ir/functors/layout_functor.py | 19 ++- python/hidet/ir/layout.py | 4 +- python/hidet/ir/tools/__init__.py | 3 +- python/hidet/ir/tools/printer.py | 109 ++++++------ python/hidet/ir/tools/rewriter.py | 81 +++++++++ python/hidet/ir/tools/util_functors.py | 20 +-- python/hidet/ir/type.py | 12 ++ python/hidet/ir/utils/call_graph.py | 4 +- python/hidet/lang/__init__.py | 2 +- python/hidet/transforms/__init__.py | 3 + python/hidet/transforms/explicit_unroll.py | 60 +------ .../hidet/transforms/flatten_tensor_index.py | 7 +- python/hidet/transforms/inline_function.py | 156 ++++++++++++++++++ .../hidet/transforms/rule_based_simplifier.py | 21 +-- 16 files changed, 360 insertions(+), 182 deletions(-) create mode 100644 python/hidet/ir/tools/rewriter.py create mode 100644 python/hidet/transforms/inline_function.py diff --git a/python/hidet/backend/codegen.py b/python/hidet/backend/codegen.py index cedb7052b..2edd600bf 100644 --- a/python/hidet/backend/codegen.py +++ b/python/hidet/backend/codegen.py @@ -16,7 +16,7 @@ from hidet.ir import dtypes from hidet.ir.type import DataType, PointerType, TensorPointerType, ReferenceType, TensorType, FuncType from hidet.ir.type import VoidType -from hidet.ir.expr import Var, Expr, Add, Sub, Multiply, Div, Mod, FloorDiv, LessThan, Neg, NotEqual, Equal, LogicalAnd +from hidet.ir.expr import Var, Add, Sub, Multiply, Div, Mod, FloorDiv, LessThan, Neg, NotEqual, Equal, LogicalAnd from hidet.ir.expr import LogicalOr, LogicalNot, BitwiseAnd, BitwiseOr, BitwiseXor, BitwiseNot, LeftShift, RightShift from hidet.ir.expr import IfThenElse, Cast, Address, Reference, Dereference, Call, Let, Constant, TensorSlice, convert from hidet.ir.expr import TensorElement @@ -47,6 +47,9 @@ def __init__(self): def canonize_funcname(name: str): return 'hidet_' + name.replace('.', '_') + def scalar_literal(self, value, dtype: DataType): + raise NotImplementedError() + def param_declare(self, v: Var): v_type = v.type name_doc = self(v) @@ -290,28 +293,9 @@ def visit_Call(self, e: Call): func = self.ir_module.lookup(func_name) func_name = Text(self.canonize_funcname(func_name)) if func.kind == 'cuda_kernel': - assert False - if isinstance(func.attrs['cuda_block_dim'], int) and func.attrs['cuda_block_dim'] > 1024: - raise ValueError('CUDA block dimension cannot be larger than 1024.') - - def dim3_str(dims): - if isinstance(dims, (int, Expr)): - return self(dims) - else: - return Text('dim3(') + self(dims) + ')' - - configs = [ - dim3_str(func.attrs['cuda_grid_dim']), # grid dimension - dim3_str(func.attrs['cuda_block_dim']), # block dimension - func.attrs.get('cuda_dynamic_smem_bytes', 0), # dynamic shared memory size - # cuda stream (get_cuda_stream() function is defined in hidet/runtime.h) - '(cudaStream_t)get_cuda_stream()', - ] - launch_config = Text('<<<') + doc_join([self(v) for v in configs], sep=', ') + Text('>>>') - else: - launch_config = [] + raise RuntimeError('Call to cuda kernel should be lowered to LaunchKernelStmt.') param_doc = Text('(') + doc_join([self(arg) for arg in e.args], Text(', ')) + ')' - return func_name + launch_config + param_doc + return func_name + param_doc elif is_primitive_function(func_name): entry = lookup_primitive_function(func_name) if entry.function is not None: @@ -538,8 +522,7 @@ def visit_AnyExpr(self, e: AnyExpr): class CUDACodegen(Codegen): # pylint: disable=abstract-method - @staticmethod - def scalar_literal(value, dtype: DataType): + def scalar_literal(self, value, dtype: DataType): if dtype == dtypes.boolean: ret = 'true' if value else 'false' elif dtype == dtypes.float64: @@ -680,8 +663,7 @@ def visit_IRModule(self, module: IRModule) -> Doc: class CPUCodegen(Codegen): # pylint: disable=abstract-method - @staticmethod - def scalar_literal(value, dtype: DataType): + def scalar_literal(self, value, dtype: DataType): if dtype == dtypes.boolean: ret = 'true' if value else 'false' elif dtype == dtypes.float64: diff --git a/python/hidet/ir/dialects/pattern.py b/python/hidet/ir/dialects/pattern.py index 375a65c23..b2941993e 100644 --- a/python/hidet/ir/dialects/pattern.py +++ b/python/hidet/ir/dialects/pattern.py @@ -14,6 +14,7 @@ from hidet.ir.node import Node from hidet.ir.type import TypeNode, DataType, TensorType, FuncType, data_type from hidet.ir.expr import Expr, Constant, Add, Sub, Multiply, Div, Mod, FloorDiv, LessThan, Equal, LessEqual +from hidet.ir.expr import BitwiseXor from hidet.ir.expr import TensorElement, IfThenElse, Call, Var, LogicalAnd, LogicalOr, BinaryOp, convert, var from hidet.ir.compute import TensorNode, ScalarNode, ReduceOperation, ReduceCompute from hidet.ir.stmt import DeclareScope @@ -122,7 +123,10 @@ def __enter__(self): self.dispatch[DataType](self.matcher, self.pattern, self.target) else: # noinspection PyArgumentList - self.dispatch[self.pattern.__class__](self.matcher, self.pattern, self.target) + dispatched = self.dispatch.get(self.pattern.__class__, None) + if dispatched is None: + raise NotImplementedError(f'Pattern {self.pattern} is not implemented') + dispatched(self.matcher, self.pattern, self.target) except NotMatchedError as e: # error from current del self.matched[self.pattern] @@ -163,6 +167,7 @@ def dispatch_table(): Multiply: PatternMatcher.match_CommutativeBinary, Div: PatternMatcher.match_Binary, Mod: PatternMatcher.match_Binary, + BitwiseXor: PatternMatcher.match_Binary, FloorDiv: PatternMatcher.match_Binary, LessThan: PatternMatcher.match_Binary, Equal: PatternMatcher.match_Binary, diff --git a/python/hidet/ir/functors/layout_functor.py b/python/hidet/ir/functors/layout_functor.py index cf378820e..87433101f 100644 --- a/python/hidet/ir/functors/layout_functor.py +++ b/python/hidet/ir/functors/layout_functor.py @@ -9,7 +9,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from hidet.ir.layout import DataLayout, StridesLayout, LocalLayout, ComposedLayout, SwizzleLayout +from hidet.ir.layout import DataLayout, StridesLayout, LocalLayout, ComposedLayout, SwizzleLayout, ConcatLayout from hidet.utils import same_list from .base_functor import BaseFunctor, BaseVisitor, BaseRewriter @@ -25,6 +25,8 @@ def visit_dispatch(self, node): return self.visit_ComposedLayout(node) elif isinstance(node, SwizzleLayout): return self.visit_SwizzleLayout(node) + elif isinstance(node, ConcatLayout): + return self.visit_ConcatLayout(node) else: raise ValueError('Can not recognize layout {}'.format(node)) else: @@ -42,6 +44,9 @@ def visit_ComposedLayout(self, layout: ComposedLayout): def visit_SwizzleLayout(self, layout: SwizzleLayout): raise NotImplementedError() + def visit_ConcatLayout(self, layout: ConcatLayout): + raise NotImplementedError() + class LayoutVisitor(BaseVisitor, LayoutFunctor): def visit_StridesLayout(self, layout: StridesLayout): @@ -61,6 +66,10 @@ def visit_SwizzleLayout(self, layout: SwizzleLayout): self.visit(layout.shape) self.visit(layout.size) + def visit_ConcatLayout(self, layout: ConcatLayout): + self.visit(layout.lhs) + self.visit(layout.rhs) + class LayoutRewriter(BaseRewriter, LayoutFunctor): def visit_StridesLayout(self, layout: StridesLayout): @@ -93,3 +102,11 @@ def visit_SwizzleLayout(self, layout: SwizzleLayout): return layout else: return SwizzleLayout(base, layout.dim, layout.regards_dim, layout.log_step) + + def visit_ConcatLayout(self, layout: ConcatLayout): + lhs = self.visit(layout.lhs) + rhs = self.visit(layout.rhs) + if lhs is layout.lhs and rhs is layout.rhs: + return layout + else: + return ConcatLayout(lhs, rhs) diff --git a/python/hidet/ir/layout.py b/python/hidet/ir/layout.py index ab8c74619..583ff6baa 100644 --- a/python/hidet/ir/layout.py +++ b/python/hidet/ir/layout.py @@ -156,7 +156,7 @@ def product(outer, inner): def concat(lhs, rhs): lhs = to_data_layout(lhs) rhs = to_data_layout(rhs) - return ConcatDataLayout(lhs, rhs) + return ConcatLayout(lhs, rhs) @staticmethod def local(shape: Sequence[Int]): @@ -416,7 +416,7 @@ def global2cond(self, *args: Int) -> Bool: return LogicalAnd(self.outer.within_bound(*outer_args), self.inner.within_bound(*inner_args)) -class ConcatDataLayout(DataLayout): +class ConcatLayout(DataLayout): def __init__(self, lhs: DataLayout, rhs: DataLayout): super().__init__(shape=list(lhs.shape) + list(rhs.shape), size=lhs.size * rhs.size) self.lhs = lhs diff --git a/python/hidet/ir/tools/__init__.py b/python/hidet/ir/tools/__init__.py index 6562ecc67..740ece429 100644 --- a/python/hidet/ir/tools/__init__.py +++ b/python/hidet/ir/tools/__init__.py @@ -10,7 +10,8 @@ # See the License for the specific language governing permissions and # limitations under the License. from .type_infer import infer_type, TypeInfer -from .util_functors import rewrite, collect, clone +from .util_functors import collect, clone +from .rewriter import rewrite from .free_var_collector import collect_free_vars from .printer import astext from .simplifier import simplify, simplify_to_int diff --git a/python/hidet/ir/tools/printer.py b/python/hidet/ir/tools/printer.py index 0debec8f2..a989f980b 100644 --- a/python/hidet/ir/tools/printer.py +++ b/python/hidet/ir/tools/printer.py @@ -29,10 +29,11 @@ ContinueStmt, ) from hidet.ir.stmt import BreakStmt, DeclareScope, LaunchKernelStmt +from hidet.ir.layout import StridesLayout, ConcatLayout, LocalLayout, SwizzleLayout, ComposedLayout, RowMajorLayout +from hidet.ir.layout import ColumnMajorLayout from hidet.ir.mapping import RepeatTaskMapping, SpatialTaskMapping, ComposedTaskMapping from hidet.ir.compute import TensorNode, GridCompute, ArgReduceCompute, ReduceCompute, TensorInput, ScalarInput from hidet.ir.dialects.pattern import AnyExpr -from hidet.ir.layout import RowMajorLayout, ColumnMajorLayout from hidet.ir.task import Task from hidet.utils import same_list from hidet.utils.doc import Doc, NewLine, Text, doc_join @@ -40,6 +41,8 @@ from hidet.ir.functors import IRFunctor +_show_var_id = False + class IRPrinter(IRFunctor): def __init__(self): @@ -67,29 +70,25 @@ def visit_PyConstant(self, c: Union[str, int, float, None]): def visit_Function(self, func: Function): self.namer.clear() - doc = Doc() # parameters - doc += 'fn(' - param_docs = [] + head_doc = Doc() + head_doc += Text('def ') + func.name + '(' for i, param in enumerate(func.params): - line = [] - if i != 0: - line.append(NewLine()) - line.extend([self(param), ': ', self(param.type)]) - param_docs.append(line) - doc += doc_join(param_docs, Text(', ')) - doc += ')' - doc = doc.indent(3) + head_doc += (NewLine() + self(param) + ': ' + self(param.type)).indent(4) + if i < len(func.params) - 1: + head_doc += ',' + head_doc += NewLine() + ')' # attributes + attr_doc = Doc() for attr_name, attr_value in func.attrs.items(): - doc += (NewLine() + '# {}: {}'.format(attr_name, attr_value)).indent(4) + attr_doc += (NewLine() + '# {}: {}'.format(attr_name, attr_value)).indent(4) # body - doc += self(func.body).indent(4) + body_doc = self(func.body).indent(4) - return doc + return head_doc + attr_doc + body_doc + NewLine() def visit_IRModule(self, ir_module: IRModule): doc = Doc() @@ -97,8 +96,8 @@ def visit_IRModule(self, ir_module: IRModule): if ir_module.task is not None: doc += self(ir_module.task) doc += NewLine() - for name, func in ir_module.functions.items(): - doc += ['def ', name, ' ', self(func), NewLine(), NewLine()] + for func in ir_module.functions.values(): + doc += self(func) + NewLine() return doc def visit_Add(self, e: Add): @@ -216,6 +215,8 @@ def visit_Address(self, e: Address): return Text('&') + self(e.expr) def visit_Var(self, e: Var): + if _show_var_id: + return Text('{}@{}'.format(self.namer.get_name(e), e.id)) return Text(self.namer.get_name(e)) def visit_Constant(self, e: Constant): @@ -365,25 +366,23 @@ def visit_SeqStmt(self, stmt: SeqStmt): def visit_ScalarType(self, t: DataType): return Text('{}'.format(t.name)) - def visit_TensorType(self, t: TensorType): + def _tensor_type(self, t: TensorType): items = [self(t.dtype), '[' + self(t.shape) + ']'] - if isinstance(t.layout, RowMajorLayout): + if isinstance(t.layout, RowMajorLayout) or t.layout is None: # default layout, do not print pass - elif isinstance(t.layout, ColumnMajorLayout): - items.append(Text('col_major')) - elif t.layout is None: - # skip None - pass else: - items.append(Text(type(t.layout).__name__)) - return Text('tensor(') + doc_join(items, ', ') + ')' + items.append(self(t.layout)) + return doc_join(items, ', ') + + def visit_TensorType(self, t: TensorType): + return Text('tensor(') + self._tensor_type(t) + ')' def visit_PointerType(self, t: PointerType): return Text('PointerType(') + self(t.base_type) + ')' def visit_TensorPointerType(self, t: TensorPointerType): - return Text('TensorPointerType(') + self(t.tensor_type) + ')' + return Text('tensor_pointer(') + self._tensor_type(t.tensor_type) + ')' def visit_ReferenceType(self, t: ReferenceType): return Text('ReferenceType(') + self(t.base_type) + ')' @@ -429,8 +428,6 @@ def visit_Task(self, e: Task): Text('computations: ') + self.print_tensor_nodes(e.outputs).indent(), Text('attributes: {') + self({k: str(v) for k, v in e.attrs.items()}) + '}', ] - # if len(e.task_graph.nodes) > 1: - # lines.append(Text('task_graph: ') + self(e.task_graph)) front_part = doc_join(lines, NewLine()) inverse_map_doc = Doc() if e.inverse_map: @@ -440,37 +437,6 @@ def visit_Task(self, e: Task): inverse_map_doc += (NewLine() + self.namer.get_name(tensor) + ': ' + inverse_map_body).indent() return Text('Task(') + (NewLine() + front_part + inverse_map_doc).indent() + NewLine() + ')' - # def visit_TaskGraph(self, task_graph: TaskGraph): - # head = Text('TaskGraph(') + self(task_graph.input_tensors) + ') {' - # body = [] - # for task in task_graph.nodes: - # arg_items = [] - # for task_input in task.inputs: - # if task_input in task_graph.consume: - # arg_items.append(self(task_input) + '=' + self(task_graph.consume[task_input])) - # else: - # arg_items.append(self(task_input)) - # for name, value in task.attributes.items(): - # arg_items.append(self(name) + '=' + self(str(value))) - # args = doc_join(arg_items, ', ') - # assign_line = self(task.outputs) + ' = ' + task.name + '(' + args + ')' - # if task is task_graph.anchor: - # assign_line = assign_line + ' [anchor]' - # if task is task_graph.anchor: - # compute_body = Doc() - # else: - # compute_body = self.print_tensor_nodes(task.outputs, exclude_nodes=task.inputs).indent() - # body.append(assign_line + compute_body) - # - # body.append( - # 'return ' - # + self([task_graph.consume[v] if v in task_graph.consume else v for v in task_graph.output_tensors]) - # ) - # - # body = (NewLine() + doc_join(body, NewLine())).indent() - # tail = NewLine() + '}' - # return head + body + tail - def visit_TensorNode(self, e: TensorNode): return self.namer.get_name(e) @@ -506,6 +472,29 @@ def visit_RepeatTaskMapping(self, mapping: RepeatTaskMapping): def visit_ComposedTaskMapping(self, mapping: ComposedTaskMapping): return self(mapping.outer) + '.' + self(mapping.inner) + def visit_StridesLayout(self, layout: StridesLayout): + if isinstance(layout, RowMajorLayout): + return Text('row(') + self(layout.shape) + ')' + elif isinstance(layout, ColumnMajorLayout): + return Text('column(') + self(layout.shape) + ')' + else: + return Text('strides(') + self(layout.strides) + ')' + + def visit_SwizzleLayout(self, layout: SwizzleLayout): + items = [self(layout.base), Text('dim=') + self(layout.dim), Text('regards=') + self(layout.regards_dim)] + if layout.log_step != 0: + items.append(Text('log_step=') + self(layout.log_step)) + return Text('swizzle(') + doc_join(items, ', ') + ')' + + def visit_LocalLayout(self, layout: LocalLayout): + return Text('local(') + self(layout.shape) + ')' + + def visit_ComposedLayout(self, layout: ComposedLayout): + return self(layout.outer) + ' * ' + self(layout.inner) + + def visit_ConcatLayout(self, layout: ConcatLayout): + return Text('concat(') + self(layout.lhs) + ', ' + self(layout.rhs) + ')' + def astext(obj: Node) -> str: if isinstance(obj, Node): diff --git a/python/hidet/ir/tools/rewriter.py b/python/hidet/ir/tools/rewriter.py new file mode 100644 index 000000000..6e2885326 --- /dev/null +++ b/python/hidet/ir/tools/rewriter.py @@ -0,0 +1,81 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, List, Union, Mapping + +from hidet.ir.expr import Let, Var +from hidet.ir.functors import IRRewriter +from hidet.ir.node import Node +from hidet.ir.stmt import ForMappingStmt, DeclareStmt, ForStmt +from hidet.ir.stmt import LetStmt + + +class MapBasedRewriter(IRRewriter): + def __init__(self, rmap): + super().__init__() + self.memo.update(rmap) + + +class CloneRewriter(IRRewriter): + """ + A rewriter that will create a new var for each statement/expr that will declare vars + """ + + def __init__(self, remap: Dict[Node, Node]): + super().__init__() + self.memo.update(remap) + + def process_var(self, v: Var): + visited_v = self.visit(v) + if visited_v is v: + new_var = Var(v.hint, type=v.type, name=v.name) + else: + new_var = visited_v + self.memo[v] = new_var + return new_var + + def visit_ForStmt(self, stmt: ForStmt): + loop_var = self.process_var(stmt.loop_var) + extent = self.visit(stmt.extent) + body = self.visit(stmt.body) + return ForStmt(loop_var, extent, body, attr=stmt.attr) + + def visit_ForTaskStmt(self, stmt: ForMappingStmt): + loop_vars: List[Var] = [self.process_var(v) for v in stmt.loop_vars] + worker = self.visit(stmt.worker) + body = self.visit(stmt.body) + return ForMappingStmt(loop_vars=loop_vars, mapping=stmt.mapping, worker=worker, body=body) + + def visit_LetStmt(self, stmt: LetStmt): + bind_vars = [self.process_var(v) for v in stmt.bind_vars] + bind_values = [self.visit(bind_value) for bind_value in stmt.bind_values] + body = self.visit(stmt.body) + return LetStmt(bind_vars, bind_values, body) + + def visit_DeclareStmt(self, stmt: DeclareStmt): + v = self.process_var(stmt.var) + init = self.visit(stmt.init) if stmt.init is not None else None + return DeclareStmt(v, init, stmt.is_static) + + def visit_Let(self, e: Let): + v = self.process_var(e.var) + value = self.visit(e.value) + body = self.visit(e.body) + return Let(v, value, body) + + +def rewrite(node: Union[Node, tuple, list, dict], rewrite_map: Mapping[Node, Node], clone_internal_var=False): + assert isinstance(rewrite_map, dict) + if clone_internal_var: + rewriter = CloneRewriter(rewrite_map) + else: + rewriter = MapBasedRewriter(rewrite_map) + return rewriter.rewrite(node) diff --git a/python/hidet/ir/tools/util_functors.py b/python/hidet/ir/tools/util_functors.py index 22eca2558..821ab2223 100644 --- a/python/hidet/ir/tools/util_functors.py +++ b/python/hidet/ir/tools/util_functors.py @@ -9,18 +9,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union, Mapping -from hidet.ir.type import TypeNode +from typing import Union + from hidet.ir.expr import Let, Var, Expr from hidet.ir.func import Function -from hidet.ir.stmt import Stmt, LetStmt from hidet.ir.functors import IRVisitor, IRRewriter - - -class MapBasedRewriter(IRRewriter): - def __init__(self, rmap): - super().__init__() - self.memo.update(rmap) +from hidet.ir.stmt import Stmt, LetStmt class IRCollector(IRVisitor): @@ -67,14 +61,6 @@ def visit_Let(self, e: Let): return Let(v, self(e.value), self(e.body)) -def rewrite( - node: Union[Function, Expr, Stmt, TypeNode, tuple, list], rewrite_map: Mapping[Union[Stmt, Expr], Union[Stmt, Expr]] -): - assert isinstance(rewrite_map, dict) - rewriter = MapBasedRewriter(rewrite_map) - return rewriter.rewrite(node) - - def collect(node: Union[Function, Expr, Stmt, list, tuple], node_types, stop_when_found=False) -> list: """ Collect sub-nodes in given node with specific types. diff --git a/python/hidet/ir/type.py b/python/hidet/ir/type.py index 6b02bc86e..2503ae751 100644 --- a/python/hidet/ir/type.py +++ b/python/hidet/ir/type.py @@ -32,6 +32,18 @@ def __invert__(self) -> TypeNode: else: raise ValueError('Can not recognize type {}'.format(self)) + def is_void(self): + return isinstance(self, VoidType) + + def is_tensor(self): + return isinstance(self, TensorType) + + def is_pointer(self): + return isinstance(self, (PointerType, TensorPointerType)) + + def is_data_type(self): + return isinstance(self, DataType) + class DataType(TypeNode): """ diff --git a/python/hidet/ir/utils/call_graph.py b/python/hidet/ir/utils/call_graph.py index e6ac4c13b..9659b4f72 100644 --- a/python/hidet/ir/utils/call_graph.py +++ b/python/hidet/ir/utils/call_graph.py @@ -32,7 +32,7 @@ def add_callee(self, callee): class CallGraph: - def __init__(self, ir_module: IRModule): + def __init__(self, ir_module: IRModule, allow_missing: bool = False): # pylint: disable=import-outside-toplevel from hidet.ir.primitives import is_primitive_function, lookup_primitive_function @@ -54,6 +54,8 @@ def __init__(self, ir_module: IRModule): entry = lookup_primitive_function(call.func_var.hint) if entry.function is not None: name = call.func_var.hint + if name not in ir_module.functions and allow_missing: + continue callee = ir_module.lookup(name) else: continue diff --git a/python/hidet/lang/__init__.py b/python/hidet/lang/__init__.py index a60d4f6ad..5dd14389f 100644 --- a/python/hidet/lang/__init__.py +++ b/python/hidet/lang/__init__.py @@ -11,7 +11,7 @@ # limitations under the License. from typing import Union, Sequence, Optional, List from hidet.ir.type import TypeNode, DataType, TensorType, PointerType, VoidType, ReferenceType, void_p, data_type -from hidet.ir.expr import Expr, Var, cast, view, Dereference +from hidet.ir.expr import Expr, Var, cast, view, address, Dereference from hidet.ir.mapping import row_spatial, row_repeat, col_repeat, col_spatial, TaskMapping, auto_map from hidet.ir.layout import DataLayout from hidet.ir.primitives import printf diff --git a/python/hidet/transforms/__init__.py b/python/hidet/transforms/__init__.py index 16e8b61a8..8ee7adf3c 100644 --- a/python/hidet/transforms/__init__.py +++ b/python/hidet/transforms/__init__.py @@ -22,6 +22,7 @@ from .simplify_stmt import simplify_stmt_pass from .expand_let_expr import expand_let_expr_pass from .resolve_generic_primitive_function import resolve_primitive_func_pass +from .inline_function import inline_function_pass from .add_explicit_cast import add_explicit_cast_pass from .inline_let_stmt import inline_let_stmt_pass from .rule_based_simplifier import rule_based_simplify_pass @@ -46,6 +47,7 @@ def lower(ir_module: IRModule) -> IRModule: explicit_unroll_pass(), flatten_tensor_index_pass(), lower_special_cast_pass(), + inline_function_pass(), resolve_primitive_func_pass(), import_primitive_functions_pass(), resolve_primitive_func_pass(), @@ -57,6 +59,7 @@ def lower(ir_module: IRModule) -> IRModule: expand_let_expr_pass(), inline_let_stmt_pass(inline_all=False), rule_based_simplify_pass(), + inline_let_stmt_pass(inline_all=False), simplify_stmt_pass(), ] diff --git a/python/hidet/transforms/explicit_unroll.py b/python/hidet/transforms/explicit_unroll.py index 0e60213bc..2fde3b8f9 100644 --- a/python/hidet/transforms/explicit_unroll.py +++ b/python/hidet/transforms/explicit_unroll.py @@ -9,68 +9,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union, Dict +from typing import List, Union -from hidet.ir.node import Node from hidet.ir.dtypes import int32 -from hidet.ir.expr import Constant, Let, Var -from hidet.ir.stmt import ForMappingStmt, LetStmt, DeclareStmt, Stmt, ForStmt, Expr, SeqStmt +from hidet.ir.expr import Constant from hidet.ir.functors import IRRewriter -from hidet.ir.tools import simplify +from hidet.ir.stmt import Stmt, ForStmt, Expr, SeqStmt +from hidet.ir.tools import simplify, rewrite from hidet.transforms.base import Pass, FunctionBodyPass Int = Union[Expr, int] TaskIndex = List[Int] -class CloneRewriter(IRRewriter): - """ - A rewriter that will create a new var for each statement/expr that will declare vars - """ - - def __init__(self, remap: Dict[Node, Node]): - super().__init__() - self.memo.update(remap) - - def process_var(self, v: Var): - visited_v = self.visit(v) - if visited_v is v: - new_var = Var(v.hint, type=v.type, name=v.name) - else: - new_var = visited_v - self.memo[v] = new_var - return new_var - - def visit_ForStmt(self, stmt: ForStmt): - loop_var = self.process_var(stmt.loop_var) - extent = self.visit(stmt.extent) - body = self.visit(stmt.body) - return ForStmt(loop_var, extent, body, attr=stmt.attr) - - def visit_ForTaskStmt(self, stmt: ForMappingStmt): - loop_vars: List[Var] = [self.process_var(v) for v in stmt.loop_vars] - worker = self.visit(stmt.worker) - body = self.visit(stmt.body) - return ForMappingStmt(loop_vars=loop_vars, mapping=stmt.mapping, worker=worker, body=body) - - def visit_LetStmt(self, stmt: LetStmt): - bind_vars = [self.process_var(v) for v in stmt.bind_vars] - bind_values = [self.visit(bind_value) for bind_value in stmt.bind_values] - body = self.visit(stmt.body) - return LetStmt(bind_vars, bind_values, body) - - def visit_DeclareStmt(self, stmt: DeclareStmt): - v = self.process_var(stmt.var) - init = self.visit(stmt.init) if stmt.init is not None else None - return DeclareStmt(v, init, stmt.is_static) - - def visit_Let(self, e: Let): - v = self.process_var(e.var) - value = self.visit(e.value) - body = self.visit(e.body) - return Let(v, value, body) - - class ExplicitUnrollRewriter(IRRewriter): def visit_ForStmt(self, stmt: ForStmt): if stmt.attr.unroll and stmt.attr.explicit_unroll: @@ -86,8 +37,7 @@ def visit_ForStmt(self, stmt: ForStmt): seq: List[Stmt] = [] for i in range(extent_int): - clone_rewriter = CloneRewriter(remap={stmt.loop_var: int32(i)}) - seq.append(clone_rewriter(body)) + seq.append(rewrite(body, {stmt.loop_var: int32(i)}, clone_internal_var=True)) if len(seq) == 1: return seq[0] else: diff --git a/python/hidet/transforms/flatten_tensor_index.py b/python/hidet/transforms/flatten_tensor_index.py index 8c2a5454b..196832b68 100644 --- a/python/hidet/transforms/flatten_tensor_index.py +++ b/python/hidet/transforms/flatten_tensor_index.py @@ -9,7 +9,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from hidet.ir.type import TensorType, tensor_type, PointerType, TensorPointerType +from hidet.ir.type import TensorType, tensor_type, tensor_pointer_type, PointerType, TensorPointerType from hidet.ir.expr import Var, TensorElement, TensorSlice, Constant from hidet.ir.stmt import BufferStoreStmt, DeclareStmt from hidet.ir.func import Function @@ -28,9 +28,10 @@ def visit_Function(self, func: Function): for var in func.params: if isinstance(var.type, TensorType): size = simplify(var.type.layout.size) - self.memo[var] = Var(var.hint, tensor_type(var.type.dtype, [size], DataLayout.row_major([size]))) + self.memo[var] = Var(var.hint, tensor_pointer_type(var.type.dtype, [size])) elif isinstance(var.type, TensorPointerType): - self.memo[var] = var + size = simplify(var.type.tensor_type.layout.size) + self.memo[var] = Var(var.hint, tensor_pointer_type(var.type.tensor_type.dtype, [size])) body = self(func.body) params = [self(p) for p in func.params] return Function( diff --git a/python/hidet/transforms/inline_function.py b/python/hidet/transforms/inline_function.py new file mode 100644 index 000000000..c8ba531d4 --- /dev/null +++ b/python/hidet/transforms/inline_function.py @@ -0,0 +1,156 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Dict, Set + +from hidet.ir.expr import Call, Var, Expr +from hidet.ir.func import IRModule, Function +from hidet.ir.functors import IRRewriter +from hidet.ir.stmt import Stmt, SeqStmt, DeclareStmt, EvaluateStmt, ReturnStmt +from hidet.ir.tools import collect, rewrite +from hidet.ir.type import ReferenceType, TensorType +from hidet.ir.utils.call_graph import CallGraph, CallGraphNode +from hidet.transforms import Pass + + +class InlineFunctionRewriter(IRRewriter): + def __init__(self, updated_ir_module: IRModule): + super().__init__(use_memo=False) + self.ir_module = updated_ir_module + self.stmts: List[Stmt] = [] + self.should_inline_cache: Dict[Function, bool] = {} + + def should_inline(self, callee: Function): + """ + Check if a function should be inlined. + + Currently, we only inline functions that + 1. have no return value + 2. have no return statement + 3. have no reference type and tensor type arguments + + Parameters + ---------- + callee: Function + The function to be checked + + Returns + ------- + ret: bool + True if the function should be inlined + """ + if callee in self.should_inline_cache: + return self.should_inline_cache[callee] + + if not callee.ret_type.is_void(): + ret = False + elif callee.kind in ['packed_func', 'host_kernel', 'cuda_kernel']: + ret = False + elif any(isinstance(arg.type, (ReferenceType, TensorType)) for arg in callee.params): + ret = False + elif len(collect(callee.body, ReturnStmt)) > 0: + ret = False + else: + ret = True + self.should_inline_cache[callee] = ret + return ret + + def inline(self, caller: Function) -> Function: + return self.visit(caller) + + def visit(self, node): + if isinstance(node, Stmt): + ret = super().visit(node) + if len(self.stmts) > 0: + # the inlined statements should be inserted before the current statement + ret = SeqStmt(self.stmts + [ret]) + self.stmts.clear() + return ret + else: + return super().visit(node) + + def visit_Call(self, e: Call): + if e.func_var.hint not in self.ir_module.functions: + # primitive function that has not been imported yet + return super().visit_Call(e) + + callee: Function = self.ir_module.functions[e.func_var.hint] + if self.should_inline(callee): + assert len(e.args) == len(callee.params) + args: List[Expr] = [self.visit(arg) for arg in e.args] + param_vars: List[Var] = [] + remap: Dict[Var, Expr] = {} + for arg, param in zip(args, callee.params): + param_var = Var(param.hint, rewrite(param.type, remap, clone_internal_var=True)) + param_vars.append(param_var) + self.stmts.append(DeclareStmt(param_var, init=arg)) + remap[param] = param_var + callee_body = rewrite(callee.body, remap, clone_internal_var=True) + self.stmts.append(callee_body) + return None + else: + return super().visit_Call(e) + + def visit_EvaluateStmt(self, stmt: EvaluateStmt): + expr = self.visit(stmt.expr) + if expr is None: + return SeqStmt([]) + else: + if expr is stmt.expr: + return stmt + else: + return EvaluateStmt(expr) + + +def inline_callees(caller: Function, updated_ir_module: IRModule) -> Function: + rewriter = InlineFunctionRewriter(updated_ir_module) + return rewriter.inline(caller) + + +class PruneUnusedFunctionRewriter(IRRewriter): + def visit_IRModule(self, module: IRModule): + call_graph = CallGraph(module, allow_missing=True) + unused_func_names: Set[str] = set() + for node in call_graph.nodes: + func: Function = node.func + if func.kind in ['packed_func', 'host_kernel', 'cuda_kernel']: + continue + if len(node.callers) == 0: + unused_func_names.add(func.name) + for func_name in unused_func_names: + del module.functions[func_name] + if func_name in module.global_vars: + del module.global_vars[func_name] + + return module + + +def prune_unused_functions(ir_module: IRModule): + rewriter = PruneUnusedFunctionRewriter() + return rewriter.visit(ir_module) + + +class InlineFunctionPass(Pass): + def process_module(self, ir_module: IRModule) -> IRModule: + call_graph = CallGraph(ir_module, allow_missing=True) + updated_ir_module = IRModule(task=ir_module.task) + for node in call_graph.reversed_order: + assert isinstance(node, CallGraphNode) + func = inline_callees(node.func, updated_ir_module) + updated_ir_module.functions[func.name] = func + + updated_ir_module = prune_unused_functions(updated_ir_module) + + return updated_ir_module + + +def inline_function_pass(): + return InlineFunctionPass() diff --git a/python/hidet/transforms/rule_based_simplifier.py b/python/hidet/transforms/rule_based_simplifier.py index f1f830768..21e2be98f 100644 --- a/python/hidet/transforms/rule_based_simplifier.py +++ b/python/hidet/transforms/rule_based_simplifier.py @@ -14,20 +14,8 @@ from itertools import product from hidet.ir.dialects.pattern import AnyExpr, match -from hidet.ir.expr import ( - Add, - convert, - Sub, - Multiply, - Mod, - LessThan, - LessEqual, - Equal, - BinaryOp, - LogicalAnd, - IfThenElse, - LogicalOr, -) +from hidet.ir.expr import Add, convert, Sub, Multiply, Mod, LessThan, LessEqual, Equal, BinaryOp, LogicalAnd, IfThenElse +from hidet.ir.expr import LogicalOr, BitwiseXor, BitwiseAnd, BitwiseOr, BitwiseNot from hidet.ir.expr import Div, Constant, Expr from hidet.ir.functors import IRRewriter from hidet.ir.tools import rewrite @@ -61,6 +49,10 @@ class ConstExprSimplifier(IRRewriter): Sub: operator.sub, Multiply: operator.mul, Div: c_div, + BitwiseOr: operator.or_, + BitwiseAnd: operator.and_, + BitwiseXor: operator.xor, + BitwiseNot: operator.invert, Mod: operator.mod, LessThan: operator.lt, LessEqual: operator.le, @@ -117,6 +109,7 @@ def __init__(self): (e1 * one, e1), (e1 * zero, zero), (e1 // one, e1), + (e1 ^ zero, e1), # add ((c1 + e1) + e2, (e1 + e2) + c1), ((e1 + c1) + c2, e1 + (c1 + c2)),