Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA:GPU] Add support for destructuring collapsing reshapes in symbolic tiles. #67802

Merged
merged 1 commit into from
May 17, 2024

Conversation

copybara-service[bot]
Copy link

@copybara-service copybara-service bot commented May 17, 2024

[XLA:GPU] Add support for destructuring collapsing reshapes in symbolic tiles.

Take a reshape [6,8] reshape([48]). The indexing map going through this
reshape from output to input would be
    (d0, d1) -> (d0 * 8 + d1).
Previously to this change, there was no support for extracting sizes and
strides from a multivariate expression.

Deriving a size_map for the indexing map above is as simple as creating a map
    (s0, s1) -> (s0 * s1),
which is easy enough. However, deriving a stride_map for the expression is
complicated. Conceptually, the stride of the composite expression should
correspond to the stride of the minormost dimension involved in the reshape
along which we capture more than a single element.

This holds because there are restrictions on what a valid tiling can be when
going through a collapse. Let s be an n-dimensional shape that is fully
collapsed. In order to be propagated successfully through the collapse, the
pattern of the tiling of s has to look like
    (1*, partial_dim?, full_dims*, 1*)
where full_dims are dimensions that are captured completely, and
partial_dim is a dimension that can be captured with an arbitrary tile.
This restriction is necessary to ensure that the gap between two elements
captured in the expression is always the same (i.e., the set of elements can
be described using a single stride and is thus a tile).

Based on the above, algorithm for extracting the stride could therefore be
represented as a series of nested if statements
    if size0 != 1 then stride0 else (if size1 != 1 then stride1 else ...)
where {size,stride}i corresponds to the i-th major {size,stride}.

We implement a utility function that allows us to generate if statements as
affine expressions.

…ic tiles.

Take a reshape `[6,8] reshape([48])`. The indexing map going through this\
reshape from output to input would be\
    `(d0, d1) -> (d0 * 8 + d1)`.\
Previously to this change, there was no support for extracting sizes and\
strides from a multivariate expression.

Deriving a `size_map` for the indexing map above is as simple as creating a map\
    `(s0, s1) -> (s0 * s1)`,\
which is easy enough. However, deriving a `stride_map` for the expression is\
complicated. Conceptually, the stride of the composite expression should\
correspond to the stride of the minormost dimension involved in the reshape\
along which we capture more than a single element.

This holds because there are restrictions on what a valid tiling can be when\
going through a collapse. Let `s` be an `n`-dimensional shape that is fully\
collapsed. In order to be propagated successfully through the collapse, the\
pattern of the tiling of `s` has to look like\
    `(1*, partial_dim?, full_dims*, 1*)`\
where `full_dims` are dimensions that are captured completely, and\
`partial_dim` is a dimension that can be captured with an arbitrary tile.\
This restriction is necessary to ensure that the gap between two elements\
captured in the expression is always the same (i.e., the set of elements can\
be described using a single `stride` and is thus a tile).

Based on the above, algorithm for extracting the `stride` could therefore be\
represented as a series of nested `if` statements\
    `if size0 != 1 then stride0 else (if size1 != 1 then stride1 else ...)`\
where `{size,stride}i` corresponds to the `i-th` major {size,stride}.

We implement a utility function that allows us to generate `if` statements as\
affine expressions.

PiperOrigin-RevId: 634784163
@copybara-service copybara-service bot merged commit 2a6c0d6 into master May 17, 2024
2 checks passed
@copybara-service copybara-service bot deleted the exported_pr_618164952 branch May 17, 2024 15:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

1 participant