Skip to content

Commit

Permalink
Add pattern to canonicalize for loop bounds
Browse files Browse the repository at this point in the history
- add pattern to canonicalize affine.for loop bounds (using
  canonicalizeMapAndOperands)
- rename AffineForLoopBoundFolder -> AffineForOpBoundFolder for
  consistency

Signed-off-by: Uday Bondhugula <uday@polymagelabs.com>

Closes tensorflow#111

COPYBARA_INTEGRATE_REVIEW=tensorflow/mlir#111 from bondhugula:bound-canonicalize ee8fb7f43a7ffd45f6df3f53c95098d8b7e494c7
PiperOrigin-RevId: 269041220
  • Loading branch information
bondhugula authored and tensorflower-gardener committed Sep 14, 2019
1 parent bda59b1 commit f5de230
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 22 deletions.
24 changes: 12 additions & 12 deletions examples/Linalg/Linalg3/Example.cpp
Expand Up @@ -71,7 +71,7 @@ TEST_FUNC(matmul_as_matvec) {
// CHECK-LABEL: func @matmul_as_matvec(%{{.*}}: memref<?x?xf32>, %{{.*}}: memref<?x?xf32>, %{{.*}}: memref<?x?xf32>) {
// CHECK: %[[N:.*]] = dim %{{.*}}, 1 : memref<?x?xf32>
// CHECK: %[[vA:.*]] = linalg.view %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
// CHECK: affine.for %{{.*}} = 0 to (d0) -> (d0)(%[[N]]) {
// CHECK: affine.for %{{.*}} = 0 to %[[N]] {
// CHECK: %[[vB:.*]] = linalg.view %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
// CHECK: %[[vC:.*]] = linalg.view %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
// CHECK: linalg.matvec(%[[vA]], %[[vB]], %[[vC]]) : !linalg.view<?xf32>
Expand All @@ -90,9 +90,9 @@ TEST_FUNC(matmul_as_dot) {
// CHECK-LABEL: func @matmul_as_dot(%{{.*}}: memref<?x?xf32>, %{{.*}}: memref<?x?xf32>, %{{.*}}: memref<?x?xf32>) {
// CHECK: %[[M:.*]] = dim %{{.*}}, 0 : memref<?x?xf32>
// CHECK: %[[N:.*]] = dim %{{.*}}, 1 : memref<?x?xf32>
// CHECK: affine.for %{{.*}} = 0 to (d0) -> (d0)(%[[N]]) {
// CHECK: affine.for %{{.*}} = 0 to %[[N]] {
// CHECK: %[[vB:.*]] = linalg.view %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
// CHECK-NEXT: affine.for %{{.*}} = 0 to (d0) -> (d0)(%[[M]]) {
// CHECK-NEXT: affine.for %{{.*}} = 0 to %[[M]] {
// CHECK: %[[vA:.*]] = linalg.view %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>, index, !linalg.range, !linalg.view<?xf32>
// CHECK-NEXT: %[[vC:.*]] = linalg.view %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>, index, index, !linalg.view<f32>
// CHECK-NEXT: linalg.dot(%[[vA]], %[[vB]], %[[vC]]) : !linalg.view<f32>
Expand All @@ -117,9 +117,9 @@ TEST_FUNC(matmul_as_loops) {
// CHECK: %[[vA:.*]] = linalg.view %{{.*}}[%[[rM]], %[[rK]]] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
// CHECK: %[[vB:.*]] = linalg.view %{{.*}}[%[[rK]], %[[rN]]] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
// CHECK: %[[vC:.*]] = linalg.view %{{.*}}[%[[rM]], %[[rN]]] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
// CHECK: affine.for %{{.*}} = 0 to (d0) -> (d0)(%[[M]]) {
// CHECK: affine.for %{{.*}} = 0 to (d0) -> (d0)(%[[N]]) {
// CHECK: affine.for %{{.*}} = 0 to (d0) -> (d0)(%[[K]]) {
// CHECK: affine.for %{{.*}} = 0 to %[[M]] {
// CHECK: affine.for %{{.*}} = 0 to %[[N]] {
// CHECK: affine.for %{{.*}} = 0 to %[[K]] {
// CHECK: %{{.*}} = cmpi "eq", %{{.*}} : index
// CHECK: %{{.*}} = linalg.load %[[vC]][%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>
// CHECK: %{{.*}} = select {{.*}} : f32
Expand All @@ -146,11 +146,11 @@ TEST_FUNC(matmul_as_matvec_as_loops) {
// CHECK: %[[N:.*]] = dim %{{.*}}, 1 : memref<?x?xf32>
// CHECK: %[[K:.*]] = dim %{{.*}}, 1 : memref<?x?xf32>
// CHECK: %[[vA:.*]] = linalg.view %{{.*}}[{{.*}}, {{.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
// CHECK: affine.for %{{.*}} = 0 to (d0) -> (d0)(%[[N]]) {
// CHECK: affine.for %{{.*}} = 0 to %[[N]] {
// CHECK: %[[vB:.*]] = linalg.view %{{.*}}[{{.*}}, {{.*}}] : memref<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
// CHECK: %[[vC:.*]] = linalg.view %{{.*}}[{{.*}}, {{.*}}] : memref<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
// CHECK: affine.for %{{.*}} = 0 to (d0) -> (d0)(%[[M]]) {
// CHECK: affine.for %{{.*}} = 0 to (d0) -> (d0)(%[[K]]) {
// CHECK: affine.for %{{.*}} = 0 to %[[M]] {
// CHECK: affine.for %{{.*}} = 0 to %[[K]] {
// CHECK: %{{.*}} = cmpi "eq", %{{.*}}, %{{.*}} : index
// CHECK: %[[C:.*]] = linalg.load %[[vC]][%{{.*}}] : !linalg.view<?xf32>
// CHECK: %[[C2:.*]] = select %{{.*}}, %{{.*}}, %[[C]] : f32
Expand Down Expand Up @@ -181,10 +181,10 @@ TEST_FUNC(matmul_as_matvec_as_affine) {
// CHECK: %[[M:.*]] = dim %{{.*}}, 0 : memref<?x?xf32>
// CHECK: %[[N:.*]] = dim %{{.*}}, 1 : memref<?x?xf32>
// CHECK: %[[K:.*]] = dim %{{.*}}, 1 : memref<?x?xf32>
// CHECK: affine.for %{{.*}} = 0 to (d0) -> (d0)(%[[N]]) {
// CHECK: affine.for %{{.*}} = 0 to %[[N]] {
// CHECK-NOT: {{.*}} = linalg.
// CHECK: affine.for %{{.*}} = 0 to (d0) -> (d0)(%[[M]]) {
// CHECK: affine.for %{{.*}} = 0 to (d0) -> (d0)(%[[K]]) {
// CHECK: affine.for %{{.*}} = 0 to %[[M]] {
// CHECK: affine.for %{{.*}} = 0 to %[[K]] {
// CHECK: %{{.*}} = cmpi "eq", %{{.*}}, %{{.*}} : index
// CHECK: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
// CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : f32
Expand Down
16 changes: 8 additions & 8 deletions examples/Linalg/Linalg4/Example.cpp
Expand Up @@ -77,9 +77,9 @@ TEST_FUNC(matmul_tiled_loops) {
// CHECK: %[[M:.*]] = dim %{{.*}}, 0 : memref<?x?xf32>
// CHECK: %[[N:.*]] = dim %{{.*}}, 1 : memref<?x?xf32>
// CHECK: %[[K:.*]] = dim %{{.*}}, 1 : memref<?x?xf32>
// CHECK: affine.for %{{.*}} = 0 to (d0) -> (d0)(%[[M]]) step 8 {
// CHECK: affine.for %{{.*}} = 0 to (d0) -> (d0)(%[[N]]) step 9 {
// CHECK: affine.for %{{.*}} = 0 to (d0) -> (d0)(%[[K]]) {
// CHECK: affine.for %{{.*}} = 0 to %[[M]] step 8 {
// CHECK: affine.for %{{.*}} = 0 to %[[N]] step 9 {
// CHECK: affine.for %{{.*}} = 0 to %[[K]] {
// CHECK: affine.for %{{.*}} = max (d0) -> (0, d0)(%{{.*}}) to min (d0)[s0] -> (s0, d0 + 8)(%{{.*}})[%[[M]]] {
// CHECK: affine.for %{{.*}} = max (d0) -> (0, d0)(%{{.*}}) to min (d0)[s0] -> (s0, d0 + 9)(%{{.*}})[%[[N]]] {
// CHECK-NEXT: %{{.*}} = cmpi "eq", %{{.*}}, %{{.*}} : index
Expand Down Expand Up @@ -107,8 +107,8 @@ TEST_FUNC(matmul_tiled_views) {
// CHECK: %[[M:.*]] = dim %{{.*}}, 0 : memref<?x?xf32>
// CHECK: %[[N:.*]] = dim %{{.*}}, 1 : memref<?x?xf32>
// CHECK: %[[K:.*]] = dim %{{.*}}, 1 : memref<?x?xf32>
// CHECK: affine.for %{{.*}} = 0 to (d0) -> (d0)(%[[M]]) step 8 {
// CHECK-NEXT: affine.for %{{.*}} = 0 to (d0) -> (d0)(%[[N]]) step 9 {
// CHECK: affine.for %{{.*}} = 0 to %[[M]] step 8 {
// CHECK-NEXT: affine.for %{{.*}} = 0 to %[[N]] step 9 {
// CHECK-NEXT: %[[i0max:.*]] = affine.apply (d0) -> (d0 + 8)(%{{.*}})
// CHECK-NEXT: %[[ri0:.*]] = linalg.range %{{.*}}:%[[i0max]]:{{.*}} : !linalg.range
// CHECK: %[[rK:.*]] = linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg.range
Expand Down Expand Up @@ -141,8 +141,8 @@ TEST_FUNC(matmul_tiled_views_as_loops) {
// CHECK: %[[M:.*]] = dim %{{.*}}, 0 : memref<?x?xf32>
// CHECK: %[[N:.*]] = dim %{{.*}}, 1 : memref<?x?xf32>
// CHECK: %[[K:.*]] = dim %{{.*}}, 1 : memref<?x?xf32>
// CHECK: affine.for %{{.*}} = 0 to (d0) -> (d0)(%[[M]]) step 8 {
// CHECK-NEXT: affine.for %{{.*}} = 0 to (d0) -> (d0)(%[[N]]) step 9 {
// CHECK: affine.for %{{.*}} = 0 to %[[M]] step 8 {
// CHECK-NEXT: affine.for %{{.*}} = 0 to %[[N]] step 9 {
// CHECK-NEXT: %[[i0max:.*]] = affine.apply (d0) -> (d0 + 8)(%{{.*}})
// CHECK-NEXT: %[[ri0:.*]] = linalg.range %{{.*}}:%[[i0max]]:{{.*}} : !linalg.range
// CHECK: %[[rK:.*]] = linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg.range
Expand All @@ -153,7 +153,7 @@ TEST_FUNC(matmul_tiled_views_as_loops) {
// CHECK-NEXT: %[[vC:.*]] = linalg.view %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
// CHECK-NEXT: affine.for %{{.*}} = (d0) -> (d0)(%{{.*}}) to (d0) -> (d0)(%[[i0max]]) {
// CHECK-NEXT: affine.for %{{.*}} = (d0) -> (d0)(%{{.*}}) to (d0) -> (d0)(%[[i1max]]) {
// CHECK-NEXT: affine.for %{{.*}} = 0 to (d0) -> (d0)(%[[K]]) {
// CHECK-NEXT: affine.for %{{.*}} = 0 to %[[K]] {
// CHECK-NEXT: %{{.*}} = cmpi "eq", %{{.*}}, %{{.*}} : index
// CHECK-NEXT: %{{.*}} = linalg.load %[[vC]][%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>
// CHECK-NEXT: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : f32
Expand Down
37 changes: 35 additions & 2 deletions lib/Dialect/AffineOps/AffineOps.cpp
Expand Up @@ -1358,7 +1358,7 @@ struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
};

/// This is a pattern to fold constant loop bounds.
struct AffineForLoopBoundFolder : public OpRewritePattern<AffineForOp> {
struct AffineForOpBoundFolder : public OpRewritePattern<AffineForOp> {
using OpRewritePattern<AffineForOp>::OpRewritePattern;

PatternMatchResult matchAndRewrite(AffineForOp forOp,
Expand Down Expand Up @@ -1413,11 +1413,44 @@ struct AffineForLoopBoundFolder : public OpRewritePattern<AffineForOp> {
return matchSuccess();
}
};

/// Rewrite pattern that canonicalizes the lower and upper bound maps of an
/// affine.for op, together with their operands, using
/// canonicalizeMapAndOperands.
struct AffineForOpBoundCanonicalizer : public OpRewritePattern<AffineForOp> {
  using OpRewritePattern<AffineForOp>::OpRewritePattern;

  /// Succeeds only when canonicalization changes at least one of the bound
  /// maps; the changed bound(s) are then updated on `forOp` in place.
  PatternMatchResult matchAndRewrite(AffineForOp forOp,
                                     PatternRewriter &rewriter) const override {
    // Keep the current maps around so a change can be detected afterwards.
    auto origLowerMap = forOp.getLowerBoundMap();
    auto origUpperMap = forOp.getUpperBoundMap();

    // Canonicalize working copies of the lower bound map and its operands.
    auto newLowerMap = origLowerMap;
    SmallVector<Value *, 4> newLowerOperands(forOp.getLowerBoundOperands());
    canonicalizeMapAndOperands(&newLowerMap, &newLowerOperands);

    // Same for the upper bound.
    auto newUpperMap = origUpperMap;
    SmallVector<Value *, 4> newUpperOperands(forOp.getUpperBoundOperands());
    canonicalizeMapAndOperands(&newUpperMap, &newUpperOperands);

    // Any canonicalization change always leads to updated map(s), so comparing
    // the maps is enough to tell whether anything happened.
    bool lowerChanged = newLowerMap != origLowerMap;
    bool upperChanged = newUpperMap != origUpperMap;
    if (!lowerChanged && !upperChanged)
      return matchFailure();

    if (lowerChanged)
      forOp.setLowerBound(newLowerOperands, newLowerMap);
    if (upperChanged)
      forOp.setUpperBound(newUpperOperands, newUpperMap);

    rewriter.updatedRootInPlace(forOp);
    return matchSuccess();
  }
};

} // end anonymous namespace

/// Adds the canonicalization patterns for affine.for ops to `results`.
// NOTE(review): this span is rendered from a diff without +/- markers; the
// first `insert` line below appears to be the pre-change registration that the
// two-line `insert` after it replaces — confirm against the actual file.
void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<AffineForEmptyLoopFolder, AffineForLoopBoundFolder>(context);
results.insert<AffineForEmptyLoopFolder, AffineForOpBoundFolder,
AffineForOpBoundCanonicalizer>(context);
}

AffineBound AffineForOp::getLowerBound() {
Expand Down
30 changes: 30 additions & 0 deletions test/AffineOps/canonicalize.mlir
Expand Up @@ -446,3 +446,33 @@ func @canonicalize_affine_if(%M : index, %N : index) {
}
return
}

// -----

// CHECK-DAG: [[LBMAP:#map[0-9]+]] = ()[s0] -> (0, s0)
// CHECK-DAG: [[UBMAP:#map[0-9]+]] = ()[s0] -> (1024, s0 + s0)

// CHECK-LABEL: func @canonicalize_bounds
// CHECK-SAME: [[M:%.*]]: index,
// CHECK-SAME: [[N:%.*]]: index)
// Exercises canonicalization of affine.for loop bounds; the expected form of
// each loop is given by the CHECK line that precedes it.
func @canonicalize_bounds(%M : index, %N : index) {
%c0 = constant 0 : index
%c1024 = constant 1024 : index
// Drop unused operand %N, drop duplicate operand %M, propagate %c1024, and
// promote %M to a symbolic one.
// CHECK: affine.for %{{.*}} = 0 to min [[UBMAP]](){{\[}}[[M]]{{\]}}
affine.for %i = 0 to min (d0, d1, d2, d3) -> (d0, d1 + d2) (%c1024, %M, %M, %N) {
"foo"() : () -> ()
}
// Promote %M to symbolic position.
// CHECK: affine.for %{{.*}} = 0 to #map{{[0-9]+}}(){{\[}}[[M]]{{\]}}
affine.for %i = 0 to (d0) -> (4 * d0) (%M) {
"foo"() : () -> ()
}
// Canonicalize the lower bound: propagate %c0 into the map and promote %N to
// a symbolic operand. The upper bound %M is expected to stay as is.
// CHECK: affine.for %{{.*}} = max [[LBMAP]](){{\[}}[[N]]{{\]}} to [[M]]
affine.for %i = max (d0, d1) -> (d0, d1) (%c0, %N) to %M {
"foo"() : () -> ()
}
return
}

0 comments on commit f5de230

Please sign in to comment.