Skip to content

Commit

Permalink
Add support for some multi-store cases in affine fusion
Browse files Browse the repository at this point in the history
This PR is a stepping stone towards supporting generic multi-store
source loop nests in affine loop fusion. It extends the algorithm to
support fusion of multi-store loop nests that:
 1. have only one store that writes to a function-local live out, and
 2. the remaining stores are involved in loop nest self dependences
    or no dependences within the function.

Closes #162

COPYBARA_INTEGRATE_REVIEW=tensorflow/mlir#162 from dcaballe:dcaballe/multi-output-fusion 7fb7dec6fe8b45f5ce176f018bfe37b256420c45
PiperOrigin-RevId: 273773907
  • Loading branch information
dcaballe authored and tensorflower-gardener committed Oct 9, 2019
1 parent 6165ddd commit 4b1e0f8
Showing 1 changed file with 66 additions and 34 deletions.
100 changes: 66 additions & 34 deletions third_party/mlir/lib/Transforms/LoopFusion.cpp
Expand Up @@ -322,6 +322,44 @@ struct MemRefDependenceGraph {
return false;
}

// Returns the unique AffineStoreOp in `node` that meets all the following:
// *) store is the only one that writes to a function-local memref live out
// of `node`,
// *) store is not the source of a self-dependence on `node`.
// Otherwise, returns a null AffineStoreOp.
AffineStoreOp getUniqueOutgoingStore(Node *node) {
AffineStoreOp uniqueStore;

// Return null if `node` doesn't have any outgoing edges.
auto outEdgeIt = outEdges.find(node->id);
if (outEdgeIt == outEdges.end())
return nullptr;

const auto &nodeOutEdges = outEdgeIt->second;
for (auto *op : node->stores) {
auto storeOp = cast<AffineStoreOp>(op);
auto *memref = storeOp.getMemRef();
// Skip this store if there are no dependences on its memref. This means
// that store either:
// *) writes to a memref that is only read within the same loop nest
// (self-dependence edges are not represented in graph at the moment),
// *) writes to a function live out memref (function parameter), or
// *) is dead.
if (llvm::all_of(nodeOutEdges, [=](const Edge &edge) {
return (edge.value != memref);
}))
continue;

if (uniqueStore)
// Found multiple stores to function-local live-out memrefs.
return nullptr;
// Found first store to function-local live-out memref.
uniqueStore = storeOp;
}

return uniqueStore;
}

// Returns true if node 'id' can be removed from the graph. Returns false
// otherwise. A node can be removed from the graph iff the following
// conditions are met:
Expand Down Expand Up @@ -963,42 +1001,30 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
return newMemRef;
}

// Checks if node 'srcId' (which writes to a live out memref), can be safely
// fused into node 'dstId'. Returns true if the following conditions are met:
// *) 'srcNode' only writes to live out 'memref'.
// *) 'srcNode' has exactly one output edge on 'memref' (which is to 'dstId').
// *) 'dstNode's read/write region to 'memref' is a super set of 'srcNode's
// write region to 'memref'.
// Checks if node 'srcId' can be safely fused into node 'dstId'. Node 'srcId'
// may write to multiple memrefs but it is required that only one of them,
// 'srcLiveOutStoreOp', have an output edge.
// Returns true if 'dstNode's read/write region to 'memref' is a super set of
// 'srcNode's write region to 'memref'.
// TODO(andydavis) Generalize this to handle more live in/out cases.
static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
Value *memref,
AffineStoreOp srcLiveOutStoreOp,
MemRefDependenceGraph *mdg) {
auto *srcNode = mdg->getNode(srcId);
assert(srcLiveOutStoreOp && "Expected a valid store op");
assert(mdg->getOutEdgeCount(srcId) == 1 && "Expected only one output edge");
auto *dstNode = mdg->getNode(dstId);
Value *memref = srcLiveOutStoreOp.getMemRef();

// Gather all memrefs from 'srcNode' store ops.
DenseSet<Value *> storeMemrefs;
for (auto *storeOpInst : srcNode->stores) {
storeMemrefs.insert(cast<AffineStoreOp>(storeOpInst).getMemRef());
}
// Return false if any of the following are true:
// *) 'srcNode' writes to a live in/out memref other than 'memref'.
// *) 'srcNode' has more than one output edge on 'memref'.
// Check that all stores are to the same memref.
if (storeMemrefs.size() != 1 ||
mdg->getOutEdgeCount(srcNode->id, memref) != 1)
return false;
// Compute MemRefRegion 'srcWriteRegion' for 'srcStoreOpInst' on 'memref'.
auto *srcStoreOpInst = srcNode->stores.front();
MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) {
// Compute MemRefRegion 'srcWriteRegion' for 'srcStoreOp' on 'memref'.
MemRefRegion srcWriteRegion(srcLiveOutStoreOp.getLoc());
if (failed(srcWriteRegion.compute(srcLiveOutStoreOp, /*loopDepth=*/0))) {
LLVM_DEBUG(llvm::dbgs()
<< "Unable to compute MemRefRegion for source operation\n.");
return false;
}
SmallVector<int64_t, 4> srcShape;
// Query 'srcWriteRegion' for 'srcShape' and 'srcNumElements'.
// by 'srcStoreOpInst' at depth 'dstLoopDepth'.
// by 'srcStoreOp' at depth 'dstLoopDepth'.
Optional<int64_t> srcNumElements =
srcWriteRegion.getConstantBoundingSizeAndShape(&srcShape);
if (!srcNumElements.hasValue())
Expand Down Expand Up @@ -1491,17 +1517,25 @@ struct GreedyFusion {
// Skip if 'srcNode' is not a loop nest.
if (!isa<AffineForOp>(srcNode->op))
continue;
// Skip if 'srcNode' has more than one store to any memref.
// TODO(andydavis) Support fusing multi-output src loop nests.
if (srcNode->stores.size() != 1)
// Skip if 'srcNode' has more than one live-out store to a
// function-local memref.
// TODO(andydavis) Support more generic multi-output src loop nests
// fusion.
auto srcStoreOp = mdg->getUniqueOutgoingStore(srcNode);
if (!srcStoreOp)
continue;
// Unique outgoing store found must write to 'memref' since 'memref'
// is the one that established the producer-consumer relationship
// between 'srcNode' and 'dstNode'.
assert(srcStoreOp.getMemRef() == memref &&
"Found store to unexpected memref");

// Skip if 'srcNode' writes to any live in or escaping memrefs,
// and cannot be fused.
bool writesToLiveInOrOut =
mdg->writesToLiveInOrEscapingMemrefs(srcNode->id);
if (writesToLiveInOrOut &&
!canFuseSrcWhichWritesToLiveOut(srcId, dstId, memref, mdg))
!canFuseSrcWhichWritesToLiveOut(srcId, dstId, srcStoreOp, mdg))
continue;

// Skip if 'srcNode' out edge count on 'memref' > 'maxSrcUserCount'.
Expand All @@ -1515,8 +1549,6 @@ struct GreedyFusion {
if (insertPointInst == nullptr)
continue;

// Get unique 'srcNode' store op.
auto *srcStoreOpInst = srcNode->stores.front();
// Gather 'dstNode' store ops to 'memref'.
SmallVector<Operation *, 2> dstStoreOpInsts;
for (auto *storeOpInst : dstNode->stores)
Expand All @@ -1526,8 +1558,8 @@ struct GreedyFusion {
unsigned bestDstLoopDepth;
mlir::ComputationSliceState sliceState;
// Check if fusion would be profitable.
if (!isFusionProfitable(srcStoreOpInst, srcStoreOpInst,
dstLoadOpInsts, dstStoreOpInsts, &sliceState,
if (!isFusionProfitable(srcStoreOp, srcStoreOp, dstLoadOpInsts,
dstStoreOpInsts, &sliceState,
&bestDstLoopDepth, maximalFusion))
continue;
// TODO(andydavis) Remove the following test code when canFuseLoops
Expand All @@ -1542,7 +1574,7 @@ struct GreedyFusion {
}
// Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
auto sliceLoopNest = mlir::insertBackwardComputationSlice(
srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
srcStoreOp, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
if (sliceLoopNest) {
LLVM_DEBUG(llvm::dbgs() << "\tslice loop nest:\n"
<< *sliceLoopNest.getOperation() << "\n");
Expand Down

0 comments on commit 4b1e0f8

Please sign in to comment.