[TVMScript][Relax] Use tir.SizeVar for shape variables #16949

Open · wants to merge 1 commit into main
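For context: the semantic difference between the two variable classes is that TVM's arithmetic analyzer treats a `tir.SizeVar` as implicitly non-negative, while a plain `tir.Var` is unconstrained. A minimal sketch of that difference, illustrative only and not part of this PR's diff:

```python
import tvm
from tvm import tir

analyzer = tvm.arith.Analyzer()

n_var = tir.Var("n", "int64")
n_size = tir.SizeVar("n", "int64")

# A SizeVar carries an implicit "n >= 0" bound, so comparisons against
# zero can be discharged; for a plain Var nothing can be proven.
assert analyzer.can_prove(n_size >= 0)
assert not analyzer.can_prove(n_var >= 0)

# The same bound lets simplifications such as max(n, 0) -> n fire.
print(analyzer.simplify(tir.max(n_size, tir.const(0, "int64"))))  # n
print(analyzer.simplify(tir.max(n_var, tir.const(0, "int64"))))   # T.max(n, 0)
```

This implicit bound is what enables the struct-inference simplifications exercised by the tests below.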
22 changes: 19 additions & 3 deletions python/tvm/script/parser/relax/entry.py
@@ -150,6 +150,9 @@ def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> Struc
     def get_symbolic_vars(self) -> Set[str]:
         return {}
 
+    def get_symbolic_size_vars(self) -> Set[str]:
+        return self.get_symbolic_vars()
+
     def asobject(self):
         return self.as_struct_info(None)

@@ -172,9 +175,6 @@ class ObjectProxy(StructInfoProxy):
     def __init__(self) -> None:
         pass
 
-    def get_symbolic_vars(self) -> Set[str]:
-        return set()
-
     def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> ShapeStructInfo:
         return ObjectStructInfo()

@@ -327,6 +327,12 @@ def get_symbolic_vars(self) -> Set[str]:
         else:
             return set().union(*[p.get_symbolic_vars() for p in self.params])
 
+    def get_symbolic_size_vars(self) -> Set[str]:
+        if self.params is None:
+            return set()
+        else:
+            return set().union(*[p.get_symbolic_size_vars() for p in self.params])
+
     def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> FuncStructInfo:
         if self.ret is None:
             ret = None
@@ -377,6 +383,9 @@ def __init__(
     def get_symbolic_vars(self) -> Set[str]:
         return set().union(*[f.get_symbolic_vars() for f in self.fields])
 
+    def get_symbolic_size_vars(self) -> Set[str]:
+        return set().union(*[f.get_symbolic_size_vars() for f in self.fields])
+
     def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> TupleStructInfo:
         fields = [field.as_struct_info(dict_globals) for field in self.fields]
         return TupleStructInfo(fields)
@@ -463,6 +472,13 @@ def get_symbolic_vars(self) -> Set[str]:
         else:
             return set()
 
+    def get_symbolic_size_vars(self) -> Set[str]:
+        # While variables defined by R.Shape and R.Tensor arguments
+        # are known to be non-negative, R.Prim arguments may be
+        # negative. Override the default implementation of
+        # `get_symbolic_size_vars()` to exclude them.
+        return set()
+
     def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> ShapeStructInfo:
         if self.value is None:
             return PrimStructInfo(dtype=self.dtype)
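The `PrimProxy` override above keeps `R.Prim`-defined variables out of the `SizeVar` set, while shape-defined variables gain the implicit non-negativity that lets strided-slice extents simplify. A sketch of that simplification using TVM's arithmetic analyzer (the `sliced_extent` helper is hypothetical, and the exact printed forms may vary across TVM versions):

```python
import tvm
from tvm import tir

analyzer = tvm.arith.Analyzer()

def sliced_extent(extent):
    # Hypothetical helper, not a TVM API. Computes the length of
    # arr[0:extent] for a 16-element arr under Python's negative-index
    # convention, mirroring the expression in the tests added by this PR.
    length = tir.if_then_else(extent < 0, extent + 16, extent)
    clamped = tir.min(tir.max(length, tir.const(0, "int64")), tir.const(16, "int64"))
    return analyzer.simplify(clamped)

# With a SizeVar, "extent < 0" is provably false, so the guards drop away:
print(sliced_extent(tir.SizeVar("extent", "int64")))  # T.min(extent, 16)

# With a plain Var, nothing can be discharged:
print(sliced_extent(tir.Var("extent", "int64")))
```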
10 changes: 9 additions & 1 deletion python/tvm/script/parser/relax/parser.py
@@ -147,15 +147,23 @@ def is_recursive(node: doc.FunctionDef) -> bool:
 def collect_symbolic_var_from_params(self: Parser, node: doc.FunctionDef) -> None:
     # Collect symbolic vars from parameters
     symbolic_vars = set()
+    symbolic_size_vars = set()
     for arg in node.args.args:
         if arg.annotation is None:
             self.report_error(arg, "Type annotation is required for function parameters.")
         param_sinfo_proxy = eval_struct_info_proxy(self, arg.annotation)
         symbolic_vars.update(param_sinfo_proxy.get_symbolic_vars())
+        symbolic_size_vars.update(param_sinfo_proxy.get_symbolic_size_vars())
+
+    assert len(symbolic_size_vars - symbolic_vars) == 0, (
+        "Internal error: "
+        "All collected tir.SizeVar names must also appear in the list of tir.Var names"
+    )
 
     # Define symbolic vars to the current var_table frame
     for var_name in symbolic_vars:
-        self.var_table.add(var_name, tir.Var(var_name, "int64"), allow_shadowing=False)
+        var_cls = tir.SizeVar if var_name in symbolic_size_vars else tir.Var
+        self.var_table.add(var_name, var_cls(var_name, "int64"), allow_shadowing=False)
 
 
 @dispatch.register(token="relax", type_name="FunctionDef")
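With the parser change above, a symbolic variable first defined by an `R.Tensor` or `R.Shape` annotation is bound as a `tir.SizeVar`. A small usage sketch, assuming a hypothetical function `f` and mirroring the assertions in the tests below:

```python
import tvm
from tvm import tir
from tvm.script import relax as R

# Hypothetical function, for illustration only.
@R.function(private=True)
def f(x: R.Tensor(["n"], "float32")) -> R.Tensor(["n"], "float32"):
    return x

# "n" was introduced by an R.Tensor annotation, so the parser now
# binds it as a tir.SizeVar rather than a plain tir.Var.
n = f.params[0].struct_info.shape[0]
assert isinstance(n, tir.SizeVar)
```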
151 changes: 150 additions & 1 deletion tests/python/relax/test_tvmscript_parser.py
@@ -2293,7 +2293,6 @@ def subroutine(x: R.Tensor, _: R.Shape(["m", "n"])) -> R.Tensor(["m", "n"]):
     assert func.attrs is not None
 
 
-@pytest.mark.xfail(reason="Bug: Implicit bounds not provided when parsing")
 def test_function_symbolic_variables_are_annotated():
     """Symbolic variables must be exposed for struct inference
 
@@ -2317,5 +2316,155 @@ def expected(A: R.Tensor(["extent"])) -> R.Tensor(["extent-1"]):
     tvm.ir.assert_structural_equal(inferred_sinfo, expected)
 
 
+def test_symbolic_shape_variables_are_size_var():
+    """Symbolic variables inferred from shapes are SizeVar
+
+    The indices in `R.strided_slice` follow Python's conventions for
+    negative indices. Absent any additional information, a slice
+    `arr[0:i]` would either have length `i` when `i >= 0`, or length
+    `len(arr) + i` when `i < 0`.
+
+    In this case, though, the dynamic `extent` variable is known to be
+    non-negative, because negative values may not be used as the
+    dimensions of `R.Tensor` or `R.Shape`. Because Relax struct
+    inference is performed while TVMScript is being parsed, this
+    constraint must be exposed during TVMScript parsing in order to
+    correctly infer the resulting StructInfo.
+
+    """
+
+    @R.function(private=True)
+    def inferred_sinfo(A: R.Tensor(["extent"])):
+        extent = T.int64()
+        output = R.strided_slice(A, [0], [0], [extent])
+        return output
+
+    @R.function(private=True)
+    def expected(A: R.Tensor(["extent"])) -> R.Tensor(["extent"]):
+        extent = T.int64()
+        output: R.Tensor([extent]) = R.strided_slice(A, [0], [0], [extent])
+        return output
+
+    tvm.ir.assert_structural_equal(inferred_sinfo, expected)
+
+    assert isinstance(inferred_sinfo.params[0].struct_info.shape[0], tir.SizeVar)
+
+
+def test_symbolic_variables_from_prim_value_may_be_negative():
+    """Symbolic variables inferred from R.Prim are Var
+
+    Not all symbolic variables represent shapes. While a
+    `relax::PrimValue` can be the source of definition for a TIR
+    variable, a `relax::PrimValue` may not represent a shape, and may
+    be negative.
+
+    This test is similar to
+    `test_symbolic_shape_variables_are_size_var`, except that the
+    `extent` variable is defined by an `R.Prim` argument, and not by an
+    `R.Tensor` argument. As a result, we do not know whether `extent`
+    is negative, and cannot simplify expressions that depend on
+    whether `extent < 0`.
+
+    """
+
+    @R.function(private=True)
+    def inferred_sinfo(A: R.Tensor([16]), _: R.Prim(value="extent")):
+        extent = T.int64()
+        output = R.strided_slice(A, [0], [0], [extent])
+        return output
+
+    @R.function(private=True)
+    def expected(A: R.Tensor([16]), _: R.Prim(value="extent")):
+        extent = T.int64()
+        output: R.Tensor(
+            [T.min(T.max(T.if_then_else(extent < 0, extent + 16, extent), 0), 16)]
+        ) = R.strided_slice(A, [0], [0], [extent])
+        return output
+
+    tvm.ir.assert_structural_equal(inferred_sinfo, expected)
+
+    assert not isinstance(inferred_sinfo.params[1].struct_info.value, tir.SizeVar)
+
+
+def test_other_arguments_may_cause_prim_value_to_define_size_var():
+    """Other arguments may cause R.Prim to hold SizeVar
+
+    This test is similar to
+    `test_symbolic_variables_from_prim_value_may_be_negative`, except
+    that `extent` also appears in an `R.Shape`. While the
+    `R.Prim(value="extent")` occurs first in the parameter list, and
+    is the source of definition, the presence of `extent` in the
+    `R.Shape` parameter shows that it is a `SizeVar`.
+
+    """
+
+    @R.function(private=True)
+    def inferred_sinfo(
+        A: R.Tensor([16]),
+        _prim: R.Prim(value="extent"),
+        _shape: R.Shape(
+            ["extent"],
+        ),
+    ):
+        extent = T.int64()
+        output = R.strided_slice(A, [0], [0], [extent])
+        return output
+
+    @R.function(private=True)
+    def expected(
+        A: R.Tensor([16]),
+        _prim: R.Prim(value="extent"),
+        _shape: R.Shape(["extent"]),
+    ):
+        extent = T.int64()
+        output: R.Tensor([T.min(extent, 16)]) = R.strided_slice(A, [0], [0], [extent])
+        return output
+
+    tvm.ir.assert_structural_equal(inferred_sinfo, expected)
+
+    assert isinstance(inferred_sinfo.params[1].struct_info.value, tir.SizeVar)
+
+
+@pytest.mark.xfail(reason="Bug: Implicit bounds not provided when parsing")
+def test_known_positive_expressions():
+    """Expressions may be known to be non-negative
+
+    The variable `N` is not defined as a shape variable, and may be
+    either positive or negative. However, the expression `N+16` is
+    used as the shape of a tensor, and is therefore known not to be
+    negative. Later use of the expression `N+16 < 0` may therefore be
+    simplified.
+
+    This test is currently marked as failing. When
+    `relax::BlockBuilder::VisitWithNewScope` is provided with
+    parameters, it can mark shape expressions as non-negative, in
+    addition to individual variables. However, this is not currently
+    used for TVMScript parsing.
+
+    """
+
+    @R.function(private=True)
+    def inferred_sinfo(
+        A: R.Tensor(["N + 16"]),
+        _: R.Prim(value="N"),
+    ):
+        N = T.int64()
+        output = R.strided_slice(A, [0], [0], [N + 16])
+        return output
+
+    @R.function(private=True)
+    def expected(
+        A: R.Tensor(["N + 16"]),
+        _: R.Prim(value="N"),
+    ):
+        N = T.int64()
+        output: R.Tensor([N + 16]) = R.strided_slice(A, [0], [0], [N + 16])
+        return output
+
+    tvm.ir.assert_structural_equal(inferred_sinfo, expected)
+
+    assert not isinstance(inferred_sinfo.params[1].struct_info.value, tir.SizeVar)
+
+
 if __name__ == "__main__":
     tvm.testing.main()