Skip to content

Commit

Permalink
Adds rewrite patterns for arith.{cmpi,select} and tensor.splat as sources to a vector.transfer_read op.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 628561147
  • Loading branch information
jax authors committed Apr 27, 2024
1 parent 0b5f3f8 commit d9b7535
Showing 1 changed file with 142 additions and 23 deletions.
165 changes: 142 additions & 23 deletions jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#include <memory>
#include <utility>

#include "absl/algorithm/container.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h"
Expand Down Expand Up @@ -60,9 +61,9 @@ struct VectorizationPattern
}
};

template <typename DefiningOp>
LogicalResult matchAndRewriteTransferOfExpandOrCollapseShape(
vector::TransferReadOp op, PatternRewriter &rewriter) {
// Check preconditions for `vector.transfer_read` rewrite patterns.
LogicalResult checkPreconditions(vector::TransferReadOp op,
PatternRewriter &rewriter) {
if (op.hasOutOfBoundsDim()) {
return rewriter.notifyMatchFailure(op, "out of bounds transfer dim");
}
Expand All @@ -72,6 +73,39 @@ LogicalResult matchAndRewriteTransferOfExpandOrCollapseShape(
if (!op.getPermutationMap().isIdentity()) {
return rewriter.notifyMatchFailure(op, "non identity permutation map");
}
SmallVector<Value> indices = {op.getIndices().begin(), op.getIndices().end()};
if (absl::c_any_of(
indices, [](Value index) { return !isConstantIntValue(index, 0); })) {
return rewriter.notifyMatchFailure(op, "non zero indices");
}
return success();
}

// Builds a new `vector.transfer_read` that reads |source| (of type
// |source_ty|) from index zero along every dimension. The result vector has
// the same shape and element type as |source_ty|, the permutation map is the
// multi-dim identity, and every dimension is marked in-bounds. The location and
// context are taken from the original |op|; no explicit padding or mask is
// passed to the builder. Callers are expected to have already run
// checkPreconditions() on |op| successfully.
vector::TransferReadOp createTransferReadOp(vector::TransferReadOp op,
                                            Value source,
                                            RankedTensorType source_ty,
                                            PatternRewriter &rewriter) {
  const auto rank = source_ty.getRank();
  auto loc = op.getLoc();

  // Result vector type mirrors the tensor being read.
  auto vector_ty =
      VectorType::get(source_ty.getShape(), source_ty.getElementType());

  // A single zero index constant is shared by all `rank` index operands.
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  SmallVector<Value> zero_indices(rank, zero);

  auto identity_map = AffineMapAttr::get(
      AffineMap::getMultiDimIdentityMap(rank, op->getContext()));

  // We know from preconditions that there are no out of bound dims.
  SmallVector<bool> all_in_bounds(rank, true);

  return rewriter.create<vector::TransferReadOp>(
      loc, vector_ty, source, zero_indices, identity_map,
      rewriter.getBoolArrayAttr(all_in_bounds));
}

template <typename DefiningOp>
LogicalResult matchAndRewriteTransferOfExpandOrCollapseShape(
vector::TransferReadOp op, PatternRewriter &rewriter) {
if (failed(checkPreconditions(op, rewriter))) {
return failure();
}
auto expand = op.getSource().template getDefiningOp<DefiningOp>();
if (!expand) {
return rewriter.notifyMatchFailure(
Expand All @@ -82,28 +116,12 @@ LogicalResult matchAndRewriteTransferOfExpandOrCollapseShape(
result_type.getShape() != expand.getResultType().getShape()) {
return rewriter.notifyMatchFailure(op, "output type mismatch");
}
SmallVector<Value> indices = {op.getIndices().begin(), op.getIndices().end()};
if (absl::c_any_of(
indices, [](Value index) { return !isConstantIntValue(index, 0); })) {
return rewriter.notifyMatchFailure(op, "non zero indices");
}
auto expand_src_type = expand.getSrcType();
// We know from preconditions that there are no out of bound dims.
SmallVector<bool> in_bounds(expand_src_type.getRank(), true);
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
op, op.getType(),
rewriter.create<vector::TransferReadOp>(
op.getLoc(),
VectorType::get(expand_src_type.getShape(),
expand_src_type.getElementType()),
expand.getSrc(),
SmallVector<Value>(
expand_src_type.getRank(),
rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0)),
AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
expand_src_type.getRank(), op->getContext())),
op.getPadding(), /*mask=*/Value(),
rewriter.getBoolArrayAttr(in_bounds)));
createTransferReadOp(op, expand.getSrc(), expand_src_type, rewriter));
return success();
}

Expand Down Expand Up @@ -156,6 +174,107 @@ struct TransferReadOfConstant
}
};

// Rewrite `vector.transfer_read(arith.select)` as `arith.select` with
// `transfer_read` applied to its operands.
//
// The pattern only fires when checkPreconditions() succeeds on the read, the
// read's source is produced by an `arith.select`, and both the select's
// condition and its true value are ranked tensors. Otherwise a match failure
// is reported and the IR is left untouched.
struct TransferReadOfSelect
    : public ::mlir::OpRewritePattern<::mlir::vector::TransferReadOp> {
  using OpRewritePattern<::mlir::vector::TransferReadOp>::OpRewritePattern;

  ::mlir::LogicalResult matchAndRewrite(
      ::mlir::vector::TransferReadOp op,
      ::mlir::PatternRewriter &rewriter) const override {
    if (failed(checkPreconditions(op, rewriter))) {
      return failure();
    }
    auto select = op.getSource().getDefiningOp<::mlir::arith::SelectOp>();
    if (!select) {
      return rewriter.notifyMatchFailure(op, "source not an arith.select");
    }
    auto true_value_ty =
        dyn_cast<RankedTensorType>(select.getTrueValue().getType());
    if (!true_value_ty) {
      return rewriter.notifyMatchFailure(
          op, "true value is not a ranked tensor type");
    }
    auto condition_type =
        dyn_cast<RankedTensorType>(select.getCondition().getType());
    if (!condition_type) {
      return rewriter.notifyMatchFailure(
          op, "condition is not a ranked tensor type");
    }
    // Create the operand reads in separate statements. Creating them inside
    // the argument list of a single call would leave their relative order in
    // the IR up to the compiler, since argument evaluation order is
    // unspecified.
    Value condition = createTransferReadOp(op, select.getCondition(),
                                           condition_type, rewriter);
    Value true_value = createTransferReadOp(op, select.getTrueValue(),
                                            true_value_ty, rewriter);
    // The select verifier enforces that the types of true_value, false_value,
    // and result match, so the already-checked type of the true value is
    // reused here instead of an unchecked dyn_cast on the false value's type.
    Value false_value = createTransferReadOp(op, select.getFalseValue(),
                                             true_value_ty, rewriter);
    rewriter.replaceOpWithNewOp<::mlir::arith::SelectOp>(op, condition,
                                                         true_value,
                                                         false_value);
    return ::mlir::success();
  }
};

// Rewrite `vector.transfer_read(arith.cmpi)` as `arith.cmpi` with
// `transfer_read` applied to its operands.
//
// The pattern only fires when checkPreconditions() succeeds on the read, the
// read's source is produced by an `arith.cmpi`, and both compared operands are
// ranked tensors. The comparison predicate is carried over unchanged.
struct TransferReadOfCmpI
    : public ::mlir::OpRewritePattern<::mlir::vector::TransferReadOp> {
  using OpRewritePattern<::mlir::vector::TransferReadOp>::OpRewritePattern;

  ::mlir::LogicalResult matchAndRewrite(
      ::mlir::vector::TransferReadOp op,
      ::mlir::PatternRewriter &rewriter) const override {
    if (failed(checkPreconditions(op, rewriter))) {
      return failure();
    }
    auto cmp = op.getSource().getDefiningOp<::mlir::arith::CmpIOp>();
    if (!cmp) {
      return rewriter.notifyMatchFailure(op, "source not an arith.cmpi");
    }
    auto lhs_type = dyn_cast<RankedTensorType>(cmp.getLhs().getType());
    if (!lhs_type) {
      return rewriter.notifyMatchFailure(op, "lhs is not a ranked tensor type");
    }
    auto rhs_type = dyn_cast<RankedTensorType>(cmp.getRhs().getType());
    if (!rhs_type) {
      return rewriter.notifyMatchFailure(op, "rhs is not a ranked tensor type");
    }
    // Create the operand reads in separate statements (lhs first, then rhs).
    // Creating them inside the argument list of a single call would leave
    // their relative order in the IR up to the compiler, since argument
    // evaluation order is unspecified.
    Value lhs = createTransferReadOp(op, cmp.getLhs(), lhs_type, rewriter);
    Value rhs = createTransferReadOp(op, cmp.getRhs(), rhs_type, rewriter);
    rewriter.replaceOpWithNewOp<::mlir::arith::CmpIOp>(op, cmp.getPredicate(),
                                                       lhs, rhs);
    return ::mlir::success();
  }
};

// Rewrite `vector.transfer_read(tensor.splat)` as `vector.broadcast`.
//
// When the read's source is a statically shaped `tensor.splat`, the read is
// replaced by a `vector.broadcast` of the splat's input to the read's own
// vector type. checkPreconditions() must succeed on the read first.
struct TransferReadOfSplat
    : public ::mlir::OpRewritePattern<::mlir::vector::TransferReadOp> {
  using OpRewritePattern<::mlir::vector::TransferReadOp>::OpRewritePattern;

  ::mlir::LogicalResult matchAndRewrite(
      ::mlir::vector::TransferReadOp op,
      ::mlir::PatternRewriter &rewriter) const override {
    if (failed(checkPreconditions(op, rewriter))) {
      return failure();
    }

    // The source tensor must come directly from a tensor.splat.
    auto splat_op = op.getSource().getDefiningOp<::mlir::tensor::SplatOp>();
    if (!splat_op) {
      return rewriter.notifyMatchFailure(op, "source not a tensor.splat");
    }

    // Dynamically shaped splats are rejected.
    const bool is_static = splat_op.getType().hasStaticShape();
    if (!is_static) {
      return rewriter.notifyMatchFailure(op, "not statically shaped");
    }

    // Broadcast the splatted input straight to the vector type of the read.
    auto broadcast_ty = op.getVectorType();
    auto splat_input = splat_op.getInput();
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, broadcast_ty,
                                                     splat_input);
    return ::mlir::success();
  }
};

struct LinalgVectorizationPass
: public impl::LinalgVectorizationPassBase<LinalgVectorizationPass> {
LinalgVectorizationPass() = default;
Expand All @@ -177,9 +296,9 @@ struct LinalgVectorizationPass
vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
patterns.add<TransferReadOfExpandShape>(ctx);
patterns.add<TransferReadOfCollapseShape>(ctx);
patterns.add<TransferReadOfConstant>(ctx);
patterns.add<TransferReadOfCmpI, TransferReadOfCollapseShape,
TransferReadOfConstant, TransferReadOfExpandShape,
TransferReadOfSelect, TransferReadOfSplat>(ctx);

// We do not want to apply the vector patterns above to the ops that are
// unrelated to the original linalg op.
Expand Down

0 comments on commit d9b7535

Please sign in to comment.