[XLA:GPU] Add support for destructuring collapsing reshapes in symbolic tiles. #67802
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
[XLA:GPU] Add support for destructuring collapsing reshapes in symbolic tiles.
Take a reshape
[6,8] reshape([48])
. The indexing map going through thisreshape from output to input would be
(d0, d1) -> (d0 * 8 + d1)
.Previously to this change, there was no support for extracting sizes and
strides from a multivariate expression.
Deriving a
size_map
for the indexing map above is as simple as creating a map(s0, s1) -> (s0 * s1)
,which is easy enough. However, deriving a
stride_map
for the expression iscomplicated. Conceptually, the stride of the composite expression should
correspond to the stride of the minormost dimension involved in the reshape
along which we capture more than a single element.
This holds because there are restrictions on what a valid tiling can be when
going through a collapse. Let
s
be ann
-dimensional shape that is fullycollapsed. In order to be propagated successfully through the collapse, the
pattern of the tiling of
s
has to look like(1*, partial_dim?, full_dims*, 1*)
where
full_dims
are dimensions that are captured completely, andpartial_dim
is a dimension that can be captured with an arbitrary tile.This restriction is necessary to ensure that the gap between two elements
captured in the expression is always the same (i.e., the set of elements can
be described using a single
stride
and is thus a tile).Based on the above, algorithm for extracting the
stride
could therefore berepresented as a series of nested
if
statementsif size0 != 1 then stride0 else (if size1 != 1 then stride1 else ...)
where
{size,stride}i
corresponds to thei-th
major {size,stride}.We implement a utility function that allows us to generate
if
statements asaffine expressions.