Skip to content

Commit

Permalink
[Pass] Support inline function (#186)
Browse files Browse the repository at this point in the history
* .

* .
  • Loading branch information
yaoyaoding committed Apr 22, 2023
1 parent f361211 commit 9a65fa2
Show file tree
Hide file tree
Showing 16 changed files with 360 additions and 182 deletions.
34 changes: 8 additions & 26 deletions python/hidet/backend/codegen.py
Expand Up @@ -16,7 +16,7 @@
from hidet.ir import dtypes
from hidet.ir.type import DataType, PointerType, TensorPointerType, ReferenceType, TensorType, FuncType
from hidet.ir.type import VoidType
from hidet.ir.expr import Var, Expr, Add, Sub, Multiply, Div, Mod, FloorDiv, LessThan, Neg, NotEqual, Equal, LogicalAnd
from hidet.ir.expr import Var, Add, Sub, Multiply, Div, Mod, FloorDiv, LessThan, Neg, NotEqual, Equal, LogicalAnd
from hidet.ir.expr import LogicalOr, LogicalNot, BitwiseAnd, BitwiseOr, BitwiseXor, BitwiseNot, LeftShift, RightShift
from hidet.ir.expr import IfThenElse, Cast, Address, Reference, Dereference, Call, Let, Constant, TensorSlice, convert
from hidet.ir.expr import TensorElement
Expand Down Expand Up @@ -47,6 +47,9 @@ def __init__(self):
def canonize_funcname(name: str):
    """Return ``name`` with every '.' turned into '_' and prefixed by 'hidet_'."""
    flattened = '_'.join(name.split('.'))
    return 'hidet_' + flattened

def scalar_literal(self, value, dtype: DataType):
    """Produce the literal text for ``value`` of type ``dtype``.

    This version has no implementation and always raises
    NotImplementedError; CUDACodegen and CPUCodegen each define their own
    ``scalar_literal`` with the same signature.
    """
    raise NotImplementedError()

def param_declare(self, v: Var):
v_type = v.type
name_doc = self(v)
Expand Down Expand Up @@ -290,28 +293,9 @@ def visit_Call(self, e: Call):
func = self.ir_module.lookup(func_name)
func_name = Text(self.canonize_funcname(func_name))
if func.kind == 'cuda_kernel':
assert False
if isinstance(func.attrs['cuda_block_dim'], int) and func.attrs['cuda_block_dim'] > 1024:
raise ValueError('CUDA block dimension cannot be larger than 1024.')

def dim3_str(dims):
if isinstance(dims, (int, Expr)):
return self(dims)
else:
return Text('dim3(') + self(dims) + ')'

configs = [
dim3_str(func.attrs['cuda_grid_dim']), # grid dimension
dim3_str(func.attrs['cuda_block_dim']), # block dimension
func.attrs.get('cuda_dynamic_smem_bytes', 0), # dynamic shared memory size
# cuda stream (get_cuda_stream() function is defined in hidet/runtime.h)
'(cudaStream_t)get_cuda_stream()',
]
launch_config = Text('<<<') + doc_join([self(v) for v in configs], sep=', ') + Text('>>>')
else:
launch_config = []
raise RuntimeError('Call to cuda kernel should be lowered to LaunchKernelStmt.')
param_doc = Text('(') + doc_join([self(arg) for arg in e.args], Text(', ')) + ')'
return func_name + launch_config + param_doc
return func_name + param_doc
elif is_primitive_function(func_name):
entry = lookup_primitive_function(func_name)
if entry.function is not None:
Expand Down Expand Up @@ -538,8 +522,7 @@ def visit_AnyExpr(self, e: AnyExpr):
class CUDACodegen(Codegen):
# pylint: disable=abstract-method

@staticmethod
def scalar_literal(value, dtype: DataType):
def scalar_literal(self, value, dtype: DataType):
if dtype == dtypes.boolean:
ret = 'true' if value else 'false'
elif dtype == dtypes.float64:
Expand Down Expand Up @@ -680,8 +663,7 @@ def visit_IRModule(self, module: IRModule) -> Doc:
class CPUCodegen(Codegen):
# pylint: disable=abstract-method

@staticmethod
def scalar_literal(value, dtype: DataType):
def scalar_literal(self, value, dtype: DataType):
if dtype == dtypes.boolean:
ret = 'true' if value else 'false'
elif dtype == dtypes.float64:
Expand Down
7 changes: 6 additions & 1 deletion python/hidet/ir/dialects/pattern.py
Expand Up @@ -14,6 +14,7 @@
from hidet.ir.node import Node
from hidet.ir.type import TypeNode, DataType, TensorType, FuncType, data_type
from hidet.ir.expr import Expr, Constant, Add, Sub, Multiply, Div, Mod, FloorDiv, LessThan, Equal, LessEqual
from hidet.ir.expr import BitwiseXor
from hidet.ir.expr import TensorElement, IfThenElse, Call, Var, LogicalAnd, LogicalOr, BinaryOp, convert, var
from hidet.ir.compute import TensorNode, ScalarNode, ReduceOperation, ReduceCompute
from hidet.ir.stmt import DeclareScope
Expand Down Expand Up @@ -122,7 +123,10 @@ def __enter__(self):
self.dispatch[DataType](self.matcher, self.pattern, self.target)
else:
# noinspection PyArgumentList
self.dispatch[self.pattern.__class__](self.matcher, self.pattern, self.target)
dispatched = self.dispatch.get(self.pattern.__class__, None)
if dispatched is None:
raise NotImplementedError(f'Pattern {self.pattern} is not implemented')
dispatched(self.matcher, self.pattern, self.target)
except NotMatchedError as e:
# error from current <pattern, target>
del self.matched[self.pattern]
Expand Down Expand Up @@ -163,6 +167,7 @@ def dispatch_table():
Multiply: PatternMatcher.match_CommutativeBinary,
Div: PatternMatcher.match_Binary,
Mod: PatternMatcher.match_Binary,
BitwiseXor: PatternMatcher.match_Binary,
FloorDiv: PatternMatcher.match_Binary,
LessThan: PatternMatcher.match_Binary,
Equal: PatternMatcher.match_Binary,
Expand Down
19 changes: 18 additions & 1 deletion python/hidet/ir/functors/layout_functor.py
Expand Up @@ -9,7 +9,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from hidet.ir.layout import DataLayout, StridesLayout, LocalLayout, ComposedLayout, SwizzleLayout
from hidet.ir.layout import DataLayout, StridesLayout, LocalLayout, ComposedLayout, SwizzleLayout, ConcatLayout
from hidet.utils import same_list
from .base_functor import BaseFunctor, BaseVisitor, BaseRewriter

Expand All @@ -25,6 +25,8 @@ def visit_dispatch(self, node):
return self.visit_ComposedLayout(node)
elif isinstance(node, SwizzleLayout):
return self.visit_SwizzleLayout(node)
elif isinstance(node, ConcatLayout):
return self.visit_ConcatLayout(node)
else:
raise ValueError('Can not recognize layout {}'.format(node))
else:
Expand All @@ -42,6 +44,9 @@ def visit_ComposedLayout(self, layout: ComposedLayout):
def visit_SwizzleLayout(self, layout: SwizzleLayout):
    """Handle a SwizzleLayout node (the target of ``visit_dispatch`` for that type).

    No behaviour is provided here: the call always raises NotImplementedError.
    """
    raise NotImplementedError()

def visit_ConcatLayout(self, layout: ConcatLayout):
    """Handle a ConcatLayout node (the target of ``visit_dispatch`` for that type).

    No behaviour is provided here: the call always raises NotImplementedError.
    """
    raise NotImplementedError()


class LayoutVisitor(BaseVisitor, LayoutFunctor):
def visit_StridesLayout(self, layout: StridesLayout):
Expand All @@ -61,6 +66,10 @@ def visit_SwizzleLayout(self, layout: SwizzleLayout):
self.visit(layout.shape)
self.visit(layout.size)

def visit_ConcatLayout(self, layout: ConcatLayout):
    """Visit both operands of a concatenated layout, left one first."""
    for operand in (layout.lhs, layout.rhs):
        self.visit(operand)


class LayoutRewriter(BaseRewriter, LayoutFunctor):
def visit_StridesLayout(self, layout: StridesLayout):
Expand Down Expand Up @@ -93,3 +102,11 @@ def visit_SwizzleLayout(self, layout: SwizzleLayout):
return layout
else:
return SwizzleLayout(base, layout.dim, layout.regards_dim, layout.log_step)

def visit_ConcatLayout(self, layout: ConcatLayout):
    """Rewrite both operands of a concatenated layout.

    Returns the original ``layout`` object when neither operand was replaced
    (identity comparison), otherwise a new ConcatLayout built from the
    rewritten operands.
    """
    new_lhs = self.visit(layout.lhs)
    new_rhs = self.visit(layout.rhs)
    untouched = new_lhs is layout.lhs and new_rhs is layout.rhs
    return layout if untouched else ConcatLayout(new_lhs, new_rhs)
4 changes: 2 additions & 2 deletions python/hidet/ir/layout.py
Expand Up @@ -156,7 +156,7 @@ def product(outer, inner):
def concat(lhs, rhs):
    """Concatenate two layouts.

    Both operands are first normalised with ``to_data_layout`` and the result
    is a ``ConcatLayout`` wrapping them in the given order.
    """
    lhs = to_data_layout(lhs)
    rhs = to_data_layout(rhs)
    # The class is called ConcatLayout everywhere else (functor, printer and
    # imports); the stale return that still used the former name
    # ConcatDataLayout shadowed this one and has been dropped.
    return ConcatLayout(lhs, rhs)

@staticmethod
def local(shape: Sequence[Int]):
Expand Down Expand Up @@ -416,7 +416,7 @@ def global2cond(self, *args: Int) -> Bool:
return LogicalAnd(self.outer.within_bound(*outer_args), self.inner.within_bound(*inner_args))


class ConcatDataLayout(DataLayout):
class ConcatLayout(DataLayout):
def __init__(self, lhs: DataLayout, rhs: DataLayout):
super().__init__(shape=list(lhs.shape) + list(rhs.shape), size=lhs.size * rhs.size)
self.lhs = lhs
Expand Down
3 changes: 2 additions & 1 deletion python/hidet/ir/tools/__init__.py
Expand Up @@ -10,7 +10,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .type_infer import infer_type, TypeInfer
from .util_functors import rewrite, collect, clone
from .util_functors import collect, clone
from .rewriter import rewrite
from .free_var_collector import collect_free_vars
from .printer import astext
from .simplifier import simplify, simplify_to_int
Expand Down
109 changes: 49 additions & 60 deletions python/hidet/ir/tools/printer.py
Expand Up @@ -29,17 +29,20 @@
ContinueStmt,
)
from hidet.ir.stmt import BreakStmt, DeclareScope, LaunchKernelStmt
from hidet.ir.layout import StridesLayout, ConcatLayout, LocalLayout, SwizzleLayout, ComposedLayout, RowMajorLayout
from hidet.ir.layout import ColumnMajorLayout
from hidet.ir.mapping import RepeatTaskMapping, SpatialTaskMapping, ComposedTaskMapping
from hidet.ir.compute import TensorNode, GridCompute, ArgReduceCompute, ReduceCompute, TensorInput, ScalarInput
from hidet.ir.dialects.pattern import AnyExpr
from hidet.ir.layout import RowMajorLayout, ColumnMajorLayout
from hidet.ir.task import Task
from hidet.utils import same_list
from hidet.utils.doc import Doc, NewLine, Text, doc_join
from hidet.utils.namer import Namer

from hidet.ir.functors import IRFunctor

_show_var_id = False


class IRPrinter(IRFunctor):
def __init__(self):
Expand Down Expand Up @@ -67,38 +70,34 @@ def visit_PyConstant(self, c: Union[str, int, float, None]):

def visit_Function(self, func: Function):
self.namer.clear()
doc = Doc()

# parameters
doc += 'fn('
param_docs = []
head_doc = Doc()
head_doc += Text('def ') + func.name + '('
for i, param in enumerate(func.params):
line = []
if i != 0:
line.append(NewLine())
line.extend([self(param), ': ', self(param.type)])
param_docs.append(line)
doc += doc_join(param_docs, Text(', '))
doc += ')'
doc = doc.indent(3)
head_doc += (NewLine() + self(param) + ': ' + self(param.type)).indent(4)
if i < len(func.params) - 1:
head_doc += ','
head_doc += NewLine() + ')'

# attributes
attr_doc = Doc()
for attr_name, attr_value in func.attrs.items():
doc += (NewLine() + '# {}: {}'.format(attr_name, attr_value)).indent(4)
attr_doc += (NewLine() + '# {}: {}'.format(attr_name, attr_value)).indent(4)

# body
doc += self(func.body).indent(4)
body_doc = self(func.body).indent(4)

return doc
return head_doc + attr_doc + body_doc + NewLine()

def visit_IRModule(self, ir_module: IRModule):
doc = Doc()
self.ir_module = ir_module
if ir_module.task is not None:
doc += self(ir_module.task)
doc += NewLine()
for name, func in ir_module.functions.items():
doc += ['def ', name, ' ', self(func), NewLine(), NewLine()]
for func in ir_module.functions.values():
doc += self(func) + NewLine()
return doc

def visit_Add(self, e: Add):
Expand Down Expand Up @@ -216,6 +215,8 @@ def visit_Address(self, e: Address):
return Text('&') + self(e.expr)

def visit_Var(self, e: Var):
    """Print a variable using the name assigned by the namer.

    When the module-level ``_show_var_id`` flag is truthy, the variable's
    ``id`` is appended as ``<name>@<id>``.
    """
    label = self.namer.get_name(e)
    if _show_var_id:
        label = '{}@{}'.format(label, e.id)
    return Text(label)

def visit_Constant(self, e: Constant):
Expand Down Expand Up @@ -365,25 +366,23 @@ def visit_SeqStmt(self, stmt: SeqStmt):
def visit_ScalarType(self, t: DataType):
    """Print a scalar data type as its ``name``."""
    type_name = '{}'.format(t.name)
    return Text(type_name)

def visit_TensorType(self, t: TensorType):
def _tensor_type(self, t: TensorType):
items = [self(t.dtype), '[' + self(t.shape) + ']']
if isinstance(t.layout, RowMajorLayout):
if isinstance(t.layout, RowMajorLayout) or t.layout is None:
# default layout, do not print
pass
elif isinstance(t.layout, ColumnMajorLayout):
items.append(Text('col_major'))
elif t.layout is None:
# skip None
pass
else:
items.append(Text(type(t.layout).__name__))
return Text('tensor(') + doc_join(items, ', ') + ')'
items.append(self(t.layout))
return doc_join(items, ', ')

def visit_TensorType(self, t: TensorType):
return Text('tensor(') + self._tensor_type(t) + ')'

def visit_PointerType(self, t: PointerType):
    """Print a pointer type as ``PointerType(<base type>)``."""
    opening = Text('PointerType(')
    base_doc = self(t.base_type)
    return opening + base_doc + ')'

def visit_TensorPointerType(self, t: TensorPointerType):
return Text('TensorPointerType(') + self(t.tensor_type) + ')'
return Text('tensor_pointer(') + self._tensor_type(t.tensor_type) + ')'

def visit_ReferenceType(self, t: ReferenceType):
    """Print a reference type as ``ReferenceType(<base type>)``."""
    opening = Text('ReferenceType(')
    base_doc = self(t.base_type)
    return opening + base_doc + ')'
Expand Down Expand Up @@ -429,8 +428,6 @@ def visit_Task(self, e: Task):
Text('computations: ') + self.print_tensor_nodes(e.outputs).indent(),
Text('attributes: {') + self({k: str(v) for k, v in e.attrs.items()}) + '}',
]
# if len(e.task_graph.nodes) > 1:
# lines.append(Text('task_graph: ') + self(e.task_graph))
front_part = doc_join(lines, NewLine())
inverse_map_doc = Doc()
if e.inverse_map:
Expand All @@ -440,37 +437,6 @@ def visit_Task(self, e: Task):
inverse_map_doc += (NewLine() + self.namer.get_name(tensor) + ': ' + inverse_map_body).indent()
return Text('Task(') + (NewLine() + front_part + inverse_map_doc).indent() + NewLine() + ')'

# def visit_TaskGraph(self, task_graph: TaskGraph):
# head = Text('TaskGraph(') + self(task_graph.input_tensors) + ') {'
# body = []
# for task in task_graph.nodes:
# arg_items = []
# for task_input in task.inputs:
# if task_input in task_graph.consume:
# arg_items.append(self(task_input) + '=' + self(task_graph.consume[task_input]))
# else:
# arg_items.append(self(task_input))
# for name, value in task.attributes.items():
# arg_items.append(self(name) + '=' + self(str(value)))
# args = doc_join(arg_items, ', ')
# assign_line = self(task.outputs) + ' = ' + task.name + '(' + args + ')'
# if task is task_graph.anchor:
# assign_line = assign_line + ' [anchor]'
# if task is task_graph.anchor:
# compute_body = Doc()
# else:
# compute_body = self.print_tensor_nodes(task.outputs, exclude_nodes=task.inputs).indent()
# body.append(assign_line + compute_body)
#
# body.append(
# 'return '
# + self([task_graph.consume[v] if v in task_graph.consume else v for v in task_graph.output_tensors])
# )
#
# body = (NewLine() + doc_join(body, NewLine())).indent()
# tail = NewLine() + '}'
# return head + body + tail

def visit_TensorNode(self, e: TensorNode):
    """Print a tensor node as the name the namer assigns to it."""
    node_name = self.namer.get_name(e)
    return node_name

Expand Down Expand Up @@ -506,6 +472,29 @@ def visit_RepeatTaskMapping(self, mapping: RepeatTaskMapping):
def visit_ComposedTaskMapping(self, mapping: ComposedTaskMapping):
    """Print a composed task mapping as ``<outer>.<inner>``."""
    outer_doc = self(mapping.outer)
    inner_doc = self(mapping.inner)
    return outer_doc + '.' + inner_doc

def visit_StridesLayout(self, layout: StridesLayout):
    """Print a strides-based layout.

    Row-major and column-major layouts are printed as ``row(<shape>)`` and
    ``column(<shape>)``; any other strides layout is printed as
    ``strides(<strides>)``.
    """
    if isinstance(layout, RowMajorLayout):
        opening, payload = 'row(', layout.shape
    elif isinstance(layout, ColumnMajorLayout):
        opening, payload = 'column(', layout.shape
    else:
        opening, payload = 'strides(', layout.strides
    return Text(opening) + self(payload) + ')'

def visit_SwizzleLayout(self, layout: SwizzleLayout):
    """Print a swizzle layout as ``swizzle(<base>, dim=.., regards=..[, log_step=..])``.

    The ``log_step`` field is included only when it is non-zero.
    """
    fields = [self(layout.base)]
    fields.append(Text('dim=') + self(layout.dim))
    fields.append(Text('regards=') + self(layout.regards_dim))
    if layout.log_step != 0:
        fields.append(Text('log_step=') + self(layout.log_step))
    return Text('swizzle(') + doc_join(fields, ', ') + ')'

def visit_LocalLayout(self, layout: LocalLayout):
    """Print a local layout as ``local(<shape>)``."""
    opening = Text('local(')
    shape_doc = self(layout.shape)
    return opening + shape_doc + ')'

def visit_ComposedLayout(self, layout: ComposedLayout):
    """Print a composed layout as ``<outer> * <inner>``."""
    outer_doc = self(layout.outer)
    inner_doc = self(layout.inner)
    return outer_doc + ' * ' + inner_doc

def visit_ConcatLayout(self, layout: ConcatLayout):
    """Print a concatenated layout as ``concat(<lhs>, <rhs>)``."""
    opening = Text('concat(')
    lhs_doc = self(layout.lhs)
    rhs_doc = self(layout.rhs)
    return opening + lhs_doc + ', ' + rhs_doc + ')'


def astext(obj: Node) -> str:
if isinstance(obj, Node):
Expand Down

0 comments on commit 9a65fa2

Please sign in to comment.