without third part
walkalone20 committed Apr 29, 2024
1 parent 898ecab commit f76bd1b
Showing 158 changed files with 1,251 additions and 4,218 deletions.
4 changes: 0 additions & 4 deletions paddle/cinn/hlir/dialect/operator/ir/op_attribute.h
@@ -30,8 +30,6 @@ class GroupInfoAttribute : public pir::Attribute {
return storage() < right.storage();
}

static std::string name() { return "a_group_info"; }

const GroupInfo& data() const;
};

@@ -46,8 +44,6 @@ class CINNKernelInfoAttribute : public pir::Attribute {
return storage() < right.storage();
}

static std::string name() { return "a_cinn_kernel_info"; }

const cinn::hlir::framework::pir::CINNKernelInfo& data() const;
};

7 changes: 0 additions & 7 deletions paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc
@@ -43,7 +43,6 @@
#include "paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/lower_cinn_fusion_op_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/replace_dynamic_expand_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.h"
#include "paddle/fluid/pir/transforms/build_cinn_pass.h"
@@ -220,14 +219,8 @@ void ApplyCinnPass(::pir::Program* program,
ApplyPdToCinnPass(program, CreatePassManager);
ApplyCinnPreprocessPass(program, CreatePassManager);
ApplyBuildGroupOpPass(program, CreatePassManager);
LOG(INFO) << "====[pir-to-py-code group-ops begin]===" << std::endl
<< PirToPyCodeConverter().Convert(*program);
LOG(INFO) << "====[pir-to-py-code group-ops end]===";
ApplyGroupOpPass(program, CreatePassManager);
ApplyDivideGroupOpToFusionOpPass(program, CreatePassManager);
LOG(INFO) << "====[pir-to-py-code fusion-ops begin]===" << std::endl
<< PirToPyCodeConverter().Convert(*program);
LOG(INFO) << "====[pir-to-py-code fusion-ops end]===";
LOG(INFO) << "FusionOp count before lowering : *****[ "
<< GetOpCount<cinn::dialect::FusionOp>(program->module_op())
<< " ]*****";
34 changes: 0 additions & 34 deletions paddle/cinn/hlir/dialect/operator/transforms/attr_adt_type_id.cc

This file was deleted.

104 changes: 0 additions & 104 deletions paddle/cinn/hlir/dialect/operator/transforms/attr_adt_type_id.h

This file was deleted.

@@ -59,11 +59,6 @@ class MergeParallelMatmulPattern
std::vector<std::int64_t>(b.begin(), b.end() - 1);
};

auto IsDynamicShape = [&](const std::vector<int64_t>& dims) {
return std::any_of(
dims.begin(), dims.end(), [](int64_t dim) { return dim < 0; });
};

auto input_x = matmul_op.operand_source(0);
const std::vector<pir::Operation*> merge_ops = [&]() {
std::vector<pir::Operation*> ret;
@@ -87,9 +82,6 @@
.type()
.dyn_cast<paddle::dialect::DenseTensorType>()
.dims());
if (IsDynamicShape(cur_dim)) {
continue;
}
if (VectorPrefixEqual(pre_dim.value(), cur_dim)) {
ret.push_back(it->owner());
}
@@ -127,7 +119,6 @@
.result(0);

for (size_t i = 0; i < merge_ops.size(); ++i) {
rewriter.SetInsertionPointAfter(merge_ops[i]);
auto split_out = rewriter
.Build<paddle::dialect::SliceOp>(
matmul_out,
48 changes: 25 additions & 23 deletions paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc
@@ -941,12 +941,9 @@ class SigmoidOpPattern
: public pir::OpRewritePattern<paddle::dialect::SigmoidOp> {
public:
using pir::OpRewritePattern<paddle::dialect::SigmoidOp>::OpRewritePattern;
bool Match(paddle::dialect::SigmoidOp op) const override {
return !CompatibleInfo::IsDeniedForCinn(*op.operation());
}

void Rewrite(paddle::dialect::SigmoidOp op,
pir::PatternRewriter &rewriter) const override {
bool MatchAndRewrite(paddle::dialect::SigmoidOp op,
pir::PatternRewriter &rewriter) const override {
auto input_dtype = paddle::dialect::TransToPhiDataType(
op->operand_source(0)
.type()
@@ -979,41 +976,46 @@ class SigmoidOpPattern
}

rewriter.ReplaceAllUsesWith(op.result(0), div);

rewriter.EraseOp(op);

return true;
}
};
class GatherOpPattern
: public pir::OpRewritePattern<paddle::dialect::GatherOp> {
public:
using pir::OpRewritePattern<paddle::dialect::GatherOp>::OpRewritePattern;

bool Match(paddle::dialect::GatherOp op) const override {
const bool is_denied = CompatibleInfo::IsDeniedForCinn(*op.operation());
auto axis_gen_op = op->operand_source(2).defining_op();
auto full_op = axis_gen_op->dyn_cast<paddle::dialect::FullOp>();
return !is_denied && full_op;
}

void Rewrite(paddle::dialect::GatherOp op,
pir::PatternRewriter &rewriter) const override {
bool MatchAndRewrite(paddle::dialect::GatherOp op,
pir::PatternRewriter &rewriter) const override {
auto gather_op = op->dyn_cast<paddle::dialect::GatherOp>();
auto x = op.operand_source(0);
auto index = op->operand_source(1);
const int axis = [&]() -> int {
auto axis_gen_op = op.operand_source(2).defining_op();
PADDLE_ENFORCE_EQ(axis_gen_op->isa<paddle::dialect::FullOp>(),
true,
::phi::errors::InvalidArgument(
"Not Supported: The gather operator for CINN "
"only supports constant value"));
auto full_op = axis_gen_op->dyn_cast<paddle::dialect::FullOp>();
return static_cast<int>(
full_op.attribute("value").dyn_cast<::pir::FloatAttribute>().data());
int axis = 0;
if (gather_op->attributes().count("index")) {
axis =
gather_op.attribute("index").dyn_cast<pir::Int32Attribute>().data();
} else {
auto axis_gen_op = op.operand_source(2).defining_op();
PADDLE_ENFORCE_EQ(axis_gen_op->isa<paddle::dialect::FullOp>(),
true,
::phi::errors::InvalidArgument(
"Not Supported: The gather operator for CINN "
"only supports constant value"));
auto full_op = axis_gen_op->dyn_cast<paddle::dialect::FullOp>();
axis = static_cast<int>(full_op.attribute("value")
.dyn_cast<::pir::FloatAttribute>()
.data());
return axis;
}
}();
auto out =
rewriter.Build<cinn::dialect::GatherOp>(x, index, axis)->result(0);
rewriter.ReplaceAllUsesWith(op->result(0), out);
rewriter.EraseOp(op);
return true;
}
};

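Both SigmoidOpPattern and GatherOpPattern above collapse the separate Match/Rewrite overrides into a single MatchAndRewrite entry point that reports whether a rewrite was applied. A minimal sketch of that shape, assuming only the pir::OpRewritePattern and pir::PatternRewriter interfaces visible in the diff; SomeOp, ShouldSkip, and BuildReplacement are hypothetical placeholders, and the required Paddle headers are omitted:

// Sketch only: SomeOp stands in for a concrete dialect op, and
// ShouldSkip / BuildReplacement are hypothetical helpers, not Paddle APIs.
class SomeOpPattern : public pir::OpRewritePattern<paddle::dialect::SomeOp> {
 public:
  using pir::OpRewritePattern<paddle::dialect::SomeOp>::OpRewritePattern;

  bool MatchAndRewrite(paddle::dialect::SomeOp op,
                       pir::PatternRewriter &rewriter) const override {
    if (ShouldSkip(op)) {
      return false;  // reject the match; the IR is left untouched
    }
    pir::Value out = BuildReplacement(op, rewriter);  // build replacement ops
    rewriter.ReplaceAllUsesWith(op.result(0), out);   // redirect all uses
    rewriter.EraseOp(op);                             // remove the original op
    return true;  // a rewrite took place
  }
};

With a single entry point, the guard that the removed Match methods performed (such as the CompatibleInfo::IsDeniedForCinn check) and the rewrite itself live in one method, and the pattern can bail out with return false at any point instead of duplicating checks across two overrides.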
