Skip to content

Commit

Permalink
[CINN]Fix GatherOpPattern in pd_to_cinn_pass (#63972)
Browse files Browse the repository at this point in the history
* [CINN]Fix GatherOpPattern in pd_to_cinn_pass

* fix return

* fix attribute
  • Loading branch information
Aurelius84 committed Apr 29, 2024
1 parent 0474465 commit 698d2f0
Showing 1 changed file with 23 additions and 25 deletions.
48 changes: 23 additions & 25 deletions paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -941,9 +941,12 @@ class SigmoidOpPattern
: public pir::OpRewritePattern<paddle::dialect::SigmoidOp> {
public:
using pir::OpRewritePattern<paddle::dialect::SigmoidOp>::OpRewritePattern;
bool Match(paddle::dialect::SigmoidOp op) const override {
return !CompatibleInfo::IsDeniedForCinn(*op.operation());
}

bool MatchAndRewrite(paddle::dialect::SigmoidOp op,
pir::PatternRewriter &rewriter) const override {
void Rewrite(paddle::dialect::SigmoidOp op,
pir::PatternRewriter &rewriter) const override {
auto input_dtype = paddle::dialect::TransToPhiDataType(
op->operand_source(0)
.type()
Expand Down Expand Up @@ -976,46 +979,41 @@ class SigmoidOpPattern
}

rewriter.ReplaceAllUsesWith(op.result(0), div);

rewriter.EraseOp(op);

return true;
}
};
class GatherOpPattern
: public pir::OpRewritePattern<paddle::dialect::GatherOp> {
public:
using pir::OpRewritePattern<paddle::dialect::GatherOp>::OpRewritePattern;

bool MatchAndRewrite(paddle::dialect::GatherOp op,
pir::PatternRewriter &rewriter) const override {
bool Match(paddle::dialect::GatherOp op) const override {
const bool is_denied = CompatibleInfo::IsDeniedForCinn(*op.operation());
auto axis_gen_op = op->operand_source(2).defining_op();
auto full_op = axis_gen_op->dyn_cast<paddle::dialect::FullOp>();
return !is_denied && full_op;
}

void Rewrite(paddle::dialect::GatherOp op,
pir::PatternRewriter &rewriter) const override {
auto gather_op = op->dyn_cast<paddle::dialect::GatherOp>();
auto x = op.operand_source(0);
auto index = op->operand_source(1);
const int axis = [&]() -> int {
int axis = 0;
if (gather_op->attributes().count("index")) {
axis =
gather_op.attribute("index").dyn_cast<pir::Int32Attribute>().data();
} else {
auto axis_gen_op = op.operand_source(2).defining_op();
PADDLE_ENFORCE_EQ(axis_gen_op->isa<paddle::dialect::FullOp>(),
true,
::phi::errors::InvalidArgument(
"Not Supported: The gather operator for CINN "
"only supports constant value"));
auto full_op = axis_gen_op->dyn_cast<paddle::dialect::FullOp>();
axis = static_cast<int>(full_op.attribute("value")
.dyn_cast<::pir::FloatAttribute>()
.data());
return axis;
}
auto axis_gen_op = op.operand_source(2).defining_op();
PADDLE_ENFORCE_EQ(axis_gen_op->isa<paddle::dialect::FullOp>(),
true,
::phi::errors::InvalidArgument(
"Not Supported: The gather operator for CINN "
"only supports constant value"));
auto full_op = axis_gen_op->dyn_cast<paddle::dialect::FullOp>();
return static_cast<int>(
full_op.attribute("value").dyn_cast<::pir::FloatAttribute>().data());
}();
auto out =
rewriter.Build<cinn::dialect::GatherOp>(x, index, axis)->result(0);
rewriter.ReplaceAllUsesWith(op->result(0), out);
rewriter.EraseOp(op);
return true;
}
};

Expand Down

0 comments on commit 698d2f0

Please sign in to comment.