Skip to content

Commit

Permalink
Migrate uses of bar.dyn_cast<Foo>() to dyn_cast<Foo>(bar) (#394)
Browse files Browse the repository at this point in the history
* Migrate method dyn_cast to llvm::dyn_cast

```bash
perl -i -pe 's/([^ ]*)\.dyn_cast<([^>]*)>\(\)/llvm::dyn_cast<\2>(\1)/g' **/*.cpp **/*.h **/*.cc
```

* manually migrate some dyn_cast_or_null

* strip llvm:: prefix

* clang-format

---------

Co-authored-by: Jeremy Kun <j2kun@users.noreply.github.com>
  • Loading branch information
j2kun and j2kun committed Apr 3, 2024
1 parent 69c116a commit 95ee215
Show file tree
Hide file tree
Showing 23 changed files with 265 additions and 272 deletions.
52 changes: 26 additions & 26 deletions lib/polygeist/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1631,9 +1631,9 @@ class CopySimplification final : public OpRewritePattern<T> {
Type elTy = dstTy.getElementType();

size_t width = 1;
if (auto IT = elTy.dyn_cast<IntegerType>())
if (auto IT = dyn_cast<IntegerType>(elTy))
width = IT.getWidth() / 8;
else if (auto FT = elTy.dyn_cast<FloatType>())
else if (auto FT = dyn_cast<FloatType>(elTy))
width = FT.getWidth() / 8;
else {
// TODO extend to llvm compatible type
Expand Down Expand Up @@ -1734,9 +1734,9 @@ class SetSimplification final : public OpRewritePattern<T> {
return failure();

size_t width = 1;
if (auto IT = elTy.dyn_cast<IntegerType>())
if (auto IT = dyn_cast<IntegerType>(elTy))
width = IT.getWidth() / 8;
else if (auto FT = elTy.dyn_cast<FloatType>())
else if (auto FT = dyn_cast<FloatType>(elTy))
width = FT.getWidth() / 8;
else {
// TODO extend to llvm compatible type
Expand Down Expand Up @@ -1787,7 +1787,7 @@ class SetSimplification final : public OpRewritePattern<T> {
SmallVector<Value> idxs;
Value val;

if (auto IT = elTy.dyn_cast<IntegerType>())
if (auto IT = dyn_cast<IntegerType>(elTy))
val =
rewriter.create<arith::ConstantIntOp>(op.getLoc(), 0, IT.getWidth());
else {
Expand Down Expand Up @@ -2479,7 +2479,7 @@ class SelectOfExt final : public OpRewritePattern<arith::SelectOp> {

LogicalResult matchAndRewrite(arith::SelectOp op,
PatternRewriter &rewriter) const override {
auto ty = op.getType().dyn_cast<IntegerType>();
auto ty = dyn_cast<IntegerType>(op.getType());
if (!ty)
return failure();
IntegerAttr lhs, rhs;
Expand Down Expand Up @@ -2589,10 +2589,10 @@ class CmpProp final : public OpRewritePattern<CmpIOp> {
v.getDefiningOp<LLVM::UndefOp>() ||
v.getDefiningOp<polygeist::UndefOp>();
if (auto extOp = v.getDefiningOp<ExtUIOp>())
if (auto it = extOp.getIn().getType().dyn_cast<IntegerType>())
if (auto it = dyn_cast<IntegerType>(extOp.getIn().getType()))
change |= it.getWidth() == 1;
if (auto extOp = v.getDefiningOp<ExtSIOp>())
if (auto it = extOp.getIn().getType().dyn_cast<IntegerType>())
if (auto it = dyn_cast<IntegerType>(extOp.getIn().getType()))
change |= it.getWidth() == 1;
}
if (!change) {
Expand Down Expand Up @@ -3175,7 +3175,7 @@ bool valueCmp(Cmp cmp, Value bval, ValueOrInt val) {
}
}

if (auto baval = bval.dyn_cast<BlockArgument>()) {
if (auto baval = dyn_cast<BlockArgument>(bval)) {
if (affine::AffineForOp afFor =
dyn_cast<affine::AffineForOp>(baval.getOwner()->getParentOp())) {
auto for_lb = afFor.getLowerBoundMap().getResults()[baval.getArgNumber()];
Expand Down Expand Up @@ -3487,7 +3487,7 @@ bool valueCmp(Cmp cmp, AffineExpr expr, size_t numDim, ValueRange operands,

// Range is [lb, ub)
bool rangeIncludes(Value bval, ValueOrInt lb, ValueOrInt ub) {
if (auto baval = bval.dyn_cast<BlockArgument>()) {
if (auto baval = dyn_cast<BlockArgument>(bval)) {
if (affine::AffineForOp afFor =
dyn_cast<affine::AffineForOp>(baval.getOwner()->getParentOp())) {
return valueCmp(
Expand Down Expand Up @@ -3621,7 +3621,7 @@ struct AffineIfSinking : public OpRewritePattern<affine::AffineIfOp> {
if (!opd) {
return failure();
}
auto ival = op.getOperands()[opd.getPosition()].dyn_cast<BlockArgument>();
auto ival = dyn_cast<BlockArgument>(op.getOperands()[opd.getPosition()]);
if (!ival) {
return failure();
}
Expand Down Expand Up @@ -3663,7 +3663,7 @@ struct AffineIfSinking : public OpRewritePattern<affine::AffineIfOp> {
if (!par.getRegion().isAncestor(v.getParentRegion()) ||
op.getThenRegion().isAncestor(v.getParentRegion()))
return;
if (auto ba = v.dyn_cast<BlockArgument>()) {
if (auto ba = dyn_cast<BlockArgument>(v)) {
if (ba.getOwner()->getParentOp() == par) {
return;
}
Expand Down Expand Up @@ -4285,7 +4285,7 @@ struct MergeNestedAffineParallelIf
continue;
}
if (auto dim = cur.dyn_cast<AffineDimExpr>()) {
auto ival = operands[dim.getPosition()].dyn_cast<BlockArgument>();
auto ival = dyn_cast<BlockArgument>(operands[dim.getPosition()]);
if (!ival || ival.getOwner()->getParentOp() != op) {
rhs = rhs + dim;
if (failure)
Expand Down Expand Up @@ -4315,7 +4315,7 @@ struct MergeNestedAffineParallelIf

if (auto dim = bop.getLHS().dyn_cast<AffineDimExpr>()) {
auto ival =
operands[dim.getPosition()].dyn_cast<BlockArgument>();
dyn_cast<BlockArgument>(operands[dim.getPosition()]);
if (!ival || ival.getOwner()->getParentOp() != op) {
rhs = rhs + bop;
// While legal, this may run before parallel merging
Expand Down Expand Up @@ -4511,7 +4511,7 @@ struct MergeParallelInductions
continue;
}
if (auto dim = cur.dyn_cast<AffineDimExpr>()) {
auto ival = operands[dim.getPosition()].dyn_cast<BlockArgument>();
auto ival = dyn_cast<BlockArgument>(operands[dim.getPosition()]);
if (!ival || ival.getOwner()->getParentOp() != op) {
rhs = rhs + dim;
continue;
Expand All @@ -4538,7 +4538,7 @@ struct MergeParallelInductions
}

if (auto dim = bop.getLHS().dyn_cast<AffineDimExpr>()) {
auto ival = operands[dim.getPosition()].dyn_cast<BlockArgument>();
auto ival = dyn_cast<BlockArgument>(operands[dim.getPosition()]);
if (!ival || ival.getOwner()->getParentOp() != op) {
rhs = rhs + bop;
continue;
Expand Down Expand Up @@ -4983,8 +4983,8 @@ template <typename T> struct BufferElimination : public OpRewritePattern<T> {
auto opd = map.getResults()[0].dyn_cast<AffineDimExpr>();
if (!opd)
continue;
auto val = ((Value)load.getMapOperands()[opd.getPosition()])
.dyn_cast<BlockArgument>();
auto val = dyn_cast<BlockArgument>(
((Value)load.getMapOperands()[opd.getPosition()]));
if (!val)
continue;

Expand Down Expand Up @@ -5033,8 +5033,8 @@ template <typename T> struct BufferElimination : public OpRewritePattern<T> {
auto opd = map.getResults()[0].dyn_cast<AffineDimExpr>();
if (!opd)
continue;
auto val = ((Value)load.getMapOperands()[opd.getPosition()])
.dyn_cast<BlockArgument>();
auto val = dyn_cast<BlockArgument>(
((Value)load.getMapOperands()[opd.getPosition()]));
if (!val)
continue;

Expand Down Expand Up @@ -5279,7 +5279,7 @@ struct AffineBufferElimination : public OpRewritePattern<T> {
}
if (storeIdxs[pair.index()].isValue) {
Value auval = storeIdxs[pair.index()].v_val;
BlockArgument bval = auval.dyn_cast<BlockArgument>();
BlockArgument bval = dyn_cast<BlockArgument>(auval);
if (!bval) {
LLVM_DEBUG(llvm::dbgs() << " + non bval expr " << bval << "\n");
continue;
Expand Down Expand Up @@ -5339,7 +5339,7 @@ struct AffineBufferElimination : public OpRewritePattern<T> {
if (!VI.isValue)
continue;
auto V = VI.v_val;
auto BA = V.dyn_cast<BlockArgument>();
auto BA = dyn_cast<BlockArgument>(V);
if (!BA) {
LLVM_DEBUG(llvm::dbgs() << " + non map oper " << V << "\n");
return failure();
Expand Down Expand Up @@ -5426,7 +5426,7 @@ struct AffineBufferElimination : public OpRewritePattern<T> {
if (!VI.isValue)
continue;
auto V = VI.v_val;
auto BA = V.dyn_cast<BlockArgument>();
auto BA = dyn_cast<BlockArgument>(V);
Operation *c = BA.getOwner()->getParentOp();
if (isa<affine::AffineParallelOp>(c) || isa<scf::ParallelOp>(c)) {
Operation *tmp = store;
Expand Down Expand Up @@ -5468,7 +5468,7 @@ struct AffineBufferElimination : public OpRewritePattern<T> {
if (storeIdxSet.count(V))
continue;

if (auto BA = V.dyn_cast<BlockArgument>()) {
if (auto BA = dyn_cast<BlockArgument>(V)) {
Operation *parent = BA.getOwner()->getParentOp();

if (auto sop = storeVal.getDefiningOp())
Expand Down Expand Up @@ -5518,10 +5518,10 @@ struct AffineBufferElimination : public OpRewritePattern<T> {
if (!isa<MemoryEffects::Read>(res.getEffect()))
return false;
unsigned addr = 0;
if (auto MT = v.getType().dyn_cast<MemRefType>())
if (auto MT = dyn_cast<MemRefType>(v.getType()))
addr = MT.getMemorySpaceAsInt();
else if (auto LT =
v.getType().dyn_cast<LLVM::LLVMPointerType>())
dyn_cast<LLVM::LLVMPointerType>(v.getType()))
addr = LT.getAddressSpace();
else
return false;
Expand Down
8 changes: 4 additions & 4 deletions lib/polygeist/Passes/AffineCFG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ static bool legalCondition(Value en, bool dim = false) {
// if (!outer || legalCondition(IC.getOperand(), false)) return true;
//}
if (!dim)
if (auto BA = en.dyn_cast<BlockArgument>()) {
if (auto BA = dyn_cast<BlockArgument>(en)) {
if (isa<affine::AffineForOp, affine::AffineParallelOp>(
BA.getOwner()->getParentOp()))
return true;
Expand Down Expand Up @@ -730,7 +730,7 @@ static void setLocationAfter(PatternRewriter &b, mlir::Value val) {
it++;
b.setInsertionPoint(val.getDefiningOp()->getBlock(), it);
}
if (auto bop = val.dyn_cast<mlir::BlockArgument>())
if (auto bop = dyn_cast<mlir::BlockArgument>(val))
b.setInsertionPoint(bop.getOwner(), bop.getOwner()->begin());
}

Expand All @@ -745,7 +745,7 @@ struct IndexCastMovement : public OpRewritePattern<IndexCastOp> {
}

mlir::Value val = op.getOperand();
if (auto bop = val.dyn_cast<mlir::BlockArgument>()) {
if (auto bop = dyn_cast<mlir::BlockArgument>(val)) {
if (op.getOperation()->getBlock() != bop.getOwner()) {
op.getOperation()->moveBefore(bop.getOwner(), bop.getOwner()->begin());
return success();
Expand Down Expand Up @@ -1010,7 +1010,7 @@ bool isValidIndex(Value val) {
if (val.getDefiningOp<ConstantIntOp>())
return true;

if (auto ba = val.dyn_cast<BlockArgument>()) {
if (auto ba = dyn_cast<BlockArgument>(val)) {
auto *owner = ba.getOwner();
assert(owner);

Expand Down
28 changes: 14 additions & 14 deletions lib/polygeist/Passes/CanonicalizeFor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ struct ForBreakAddUpgrade : public OpRewritePattern<scf::ForOp> {

auto condition = outerIfOp.getCondition();
// and that the outermost if's condition is an iter arg of the for
auto condArg = condition.dyn_cast<BlockArgument>();
auto condArg = dyn_cast<BlockArgument>(condition);
if (!condArg)
return failure();
if (condArg.getOwner()->getParentOp() != forOp)
Expand All @@ -121,7 +121,7 @@ struct ForBreakAddUpgrade : public OpRewritePattern<scf::ForOp> {
// and is false unless coming from inside the if
auto forYieldOp = cast<scf::YieldOp>(block.getTerminator());
auto opres =
forYieldOp.getOperand(condArg.getArgNumber() - 1).dyn_cast<OpResult>();
dyn_cast<OpResult>(forYieldOp.getOperand(condArg.getArgNumber() - 1));
if (!opres)
return failure();
if (opres.getOwner() != outerIfOp)
Expand All @@ -143,7 +143,7 @@ struct ForBreakAddUpgrade : public OpRewritePattern<scf::ForOp> {
if (opres.getResultNumber() == regionArg.getArgNumber() - 1)
continue;

auto opres2 = forYieldOperand.dyn_cast<OpResult>();
auto opres2 = dyn_cast<OpResult>(forYieldOperand);
if (!opres2)
continue;
if (opres2.getOwner() != outerIfOp)
Expand Down Expand Up @@ -627,13 +627,13 @@ yop2.results()[idx]);
*/

bool isTopLevelArgValue(Value value, Region *region) {
if (auto arg = value.dyn_cast<BlockArgument>())
if (auto arg = dyn_cast<BlockArgument>(value))
return arg.getParentRegion() == region;
return false;
}

bool isBlockArg(Value value) {
if (auto arg = value.dyn_cast<BlockArgument>())
if (auto arg = dyn_cast<BlockArgument>(value))
return true;
return false;
}
Expand All @@ -642,7 +642,7 @@ bool dominateWhile(Value value, WhileOp loop) {
if (Operation *op = value.getDefiningOp()) {
DominanceInfo dom(loop);
return dom.properlyDominates(op, loop);
} else if (auto arg = value.dyn_cast<BlockArgument>()) {
} else if (auto arg = dyn_cast<BlockArgument>(value)) {
return arg.getOwner()->getParentOp()->isProperAncestor(loop);
} else {
assert("????");
Expand Down Expand Up @@ -682,11 +682,11 @@ struct WhileToForHelper {
negativeStep = false;

auto condOp = loop.getConditionOp();
indVar = cmpIOp.getLhs().dyn_cast<BlockArgument>();
indVar = dyn_cast<BlockArgument>(cmpIOp.getLhs());
Type extType = nullptr;
// todo handle ext
if (auto ext = cmpIOp.getLhs().getDefiningOp<ExtSIOp>()) {
indVar = ext.getIn().dyn_cast<BlockArgument>();
indVar = dyn_cast<BlockArgument>(ext.getIn());
extType = ext.getType();
}
// Condition is not the same as an induction variable
Expand Down Expand Up @@ -1004,7 +1004,7 @@ struct MoveWhileAndDown : public OpRewritePattern<WhileOp> {

Value extraCmp = andIOp->getOperand(1 - i);
Value lookThrough = nullptr;
if (auto BA = extraCmp.dyn_cast<BlockArgument>()) {
if (auto BA = dyn_cast<BlockArgument>(extraCmp)) {
lookThrough = oldYield.getOperand(BA.getArgNumber());
}
if (!helper.computeLegality(/*sizeCheck*/ false, lookThrough)) {
Expand Down Expand Up @@ -1341,7 +1341,7 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
// yield i:pair<2>
// }
if (!std::get<0>(pair).use_empty()) {
if (auto blockArg = elseYielded.dyn_cast<BlockArgument>())
if (auto blockArg = dyn_cast<BlockArgument>(elseYielded))
if (blockArg.getOwner() == &op.getBefore().front()) {
if (afterYield.getResults()[blockArg.getArgNumber()] ==
std::get<2>(pair) &&
Expand Down Expand Up @@ -1580,7 +1580,7 @@ struct WhileCmpOffset : public OpRewritePattern<WhileOp> {
if (addI.getOperand(1).getDefiningOp() &&
!op.getBefore().isAncestor(
addI.getOperand(1).getDefiningOp()->getParentRegion()))
if (auto blockArg = addI.getOperand(0).dyn_cast<BlockArgument>()) {
if (auto blockArg = dyn_cast<BlockArgument>(addI.getOperand(0))) {
if (blockArg.getOwner() == &op.getBefore().front()) {
auto rng = llvm::make_early_inc_range(blockArg.getUses());

Expand Down Expand Up @@ -1859,7 +1859,7 @@ struct WhileLICM : public OpRewritePattern<WhileOp> {
auto isDefinedOutsideOfBody = [&](Value value) {
auto *definingOp = value.getDefiningOp();
if (!definingOp) {
if (auto ba = value.dyn_cast<BlockArgument>())
if (auto ba = dyn_cast<BlockArgument>(value))
definingOp = ba.getOwner()->getParentOp();
assert(definingOp);
}
Expand Down Expand Up @@ -2125,7 +2125,7 @@ struct WhileShiftToInduction : public OpRewritePattern<WhileOp> {
if (!matchPattern(cmpIOp.getRhs(), m_Zero()))
return failure();

auto indVar = cmpIOp.getLhs().dyn_cast<BlockArgument>();
auto indVar = dyn_cast<BlockArgument>(cmpIOp.getLhs());
if (!indVar)
return failure();

Expand All @@ -2144,7 +2144,7 @@ struct WhileShiftToInduction : public OpRewritePattern<WhileOp> {
if (!matchPattern(shiftOp.getRhs(), m_One()))
return failure();

auto prevIndVar = shiftOp.getLhs().dyn_cast<BlockArgument>();
auto prevIndVar = dyn_cast<BlockArgument>(shiftOp.getLhs());
if (!prevIndVar)
return failure();

Expand Down
4 changes: 2 additions & 2 deletions lib/polygeist/Passes/CollectKernelStatistics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ std::array<StrideTy, 3> estimateStride(mlir::OperandRange indices,
} else {
return UNKNOWN;
}
} else if (auto ba = v.dyn_cast<BlockArgument>()) {
} else if (auto ba = dyn_cast<BlockArgument>(v)) {
return 0;
if (isa<gpu::GPUFuncOp>(ba.getOwner()->getParentOp())) {
return 0;
Expand Down Expand Up @@ -339,7 +339,7 @@ static void generateAlternativeKernelDescs(mlir::ModuleOp m) {
isa<arith::SubFOp>(&op) || isa<arith::AddFOp>(&op) ||
isa<arith::RemFOp>(&op) || false) {
int width =
op.getOperand(0).getType().dyn_cast<FloatType>().getWidth();
dyn_cast<FloatType>(op.getOperand(0).getType()).getWidth();
addTo(floatOps, width, blockTrips);
} else if (isa<arith::MulIOp>(&op) || isa<arith::DivUIOp>(&op) ||
isa<arith::DivSIOp>(&op) || isa<arith::SubIOp>(&op) ||
Expand Down

0 comments on commit 95ee215

Please sign in to comment.