Skip to content

Commit

Permalink
[CINN] add constant fold for max/min dim expr
Browse files Browse the repository at this point in the history
Signed-off-by: ZelinMa557 <3388706467@qq.com>
  • Loading branch information
ZelinMa557 committed Apr 29, 2024
1 parent 71fd732 commit 7fd2942
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ bool DiagonalOpInferSymbolicShape(
res_shape = shape_analysis->GetNextSymName();
}
}
out_dims.push_back(res_shape);
out_dims.push_back(symbol::SimplifyDimExpr(res_shape));

symbol::ShapeOrDataDimExprs shape_data{
symbol::TensorShapeOrDataDimExprs(out_dims)};
Expand Down
78 changes: 78 additions & 0 deletions paddle/pir/src/dialect/shape/utils/dim_expr_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,82 @@ struct FoldOperandTrait<Add> {
}
};

template <>
struct FoldOperandTrait<Max> {
using const_value_type = std::int64_t;

static bool IsConstPattern(const DimExpr& dim_expr) {
if (dim_expr.Has<std::int64_t>()) {
return true;
}
if (dim_expr.Has<Negative<DimExpr>>()) {
const auto& operand = dim_expr.Get<Negative<DimExpr>>()->data;
return operand.Has<std::int64_t>();
}
return false;
}

static const_value_type MakeUnit() { return INT64_MIN; }
static void Accumulate(const_value_type* value, const DimExpr& expr) {
*value = std::max(*value, GetInteger(expr));
}
static bool IsUnit(const const_value_type& value) {
return value == INT64_MIN;
}
static bool IsUnitDimExpr(const DimExpr& dim_expr) {
if (!dim_expr.Has<std::int64_t>()) {
return false;
}
return dim_expr.Get<std::int64_t>() == INT64_MIN;
}
static void MakeAndAppendDimExpr(const const_value_type& value,
List<DimExpr>* ret) {
(*ret)->emplace_back(value);
}

static bool IsInversedPair(const DimExpr& lhs, const DimExpr& rhs) {
return false;
}
};

template <>
struct FoldOperandTrait<Min> {
using const_value_type = std::int64_t;

static bool IsConstPattern(const DimExpr& dim_expr) {
if (dim_expr.Has<std::int64_t>()) {
return true;
}
if (dim_expr.Has<Negative<DimExpr>>()) {
const auto& operand = dim_expr.Get<Negative<DimExpr>>()->data;
return operand.Has<std::int64_t>();
}
return false;
}

static const_value_type MakeUnit() { return INT64_MAX; }
static void Accumulate(const_value_type* value, const DimExpr& expr) {
*value = std::min(*value, GetInteger(expr));
}
static bool IsUnit(const const_value_type& value) {
return value == INT64_MAX;
}
static bool IsUnitDimExpr(const DimExpr& dim_expr) {
if (!dim_expr.Has<std::int64_t>()) {
return false;
}
return dim_expr.Get<std::int64_t>() == INT64_MAX;
}
static void MakeAndAppendDimExpr(const const_value_type& value,
List<DimExpr>* ret) {
(*ret)->emplace_back(value);
}

static bool IsInversedPair(const DimExpr& lhs, const DimExpr& rhs) {
return false;
}
};

using ConstRational = std::pair<std::int64_t, std::int64_t>;

ConstRational SimplifiedConstRational(int64_t num, int64_t dem) {
Expand Down Expand Up @@ -903,6 +979,8 @@ DimExpr Simplify(const DimExpr& expr) {
DoPass<FoldUnitConstant<Broadcast>>(&keep_rewrite, &ret);
DoPass<FoldConstants<Add>>(&keep_rewrite, &ret);
DoPass<FoldConstants<Mul>>(&keep_rewrite, &ret);
DoPass<FoldConstants<Max>>(&keep_rewrite, &ret);
DoPass<FoldConstants<Min>>(&keep_rewrite, &ret);
DoPass<FoldConstants<Broadcast>>(&keep_rewrite, &ret);
DoPass<FoldInversedPairToUnit<Add>>(&keep_rewrite, &ret);
DoPass<FoldInversedPairToUnit<Mul>>(&keep_rewrite, &ret);
Expand Down
16 changes: 16 additions & 0 deletions test/cpp/pir/shape_dialect/simplify_dim_expr_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,4 +144,20 @@ TEST(Simplify, NestSymbolicMulAddUnit) {
ASSERT_TRUE((simplified_dim_expr.Has<std::string>()));
ASSERT_TRUE((simplified_dim_expr == sym));
}

TEST(Simplify, ConstantMaxMin) {
List<DimExpr> max_lists{DimExpr(4), DimExpr(6)};
DimExpr dim_expr1{Max<DimExpr>{max_lists}};

DimExpr simplified_dim_expr1 = SimplifyDimExpr(dim_expr);
ASSERT_TRUE((simplified_dim_expr1.Has<std::int64_t>()));
ASSERT_EQ((simplified_dim_expr1.Get<std::int64_t>()), 6);

List<DimExpr> min_lists{DimExpr(2), DimExpr(3)};
DimExpr dim_expr2{Min<DimExpr>{min_lists}};

DimExpr simplified_dim_expr2 = SimplifyDimExpr(dim_expr);
ASSERT_TRUE((simplified_dim_expr2.Has<std::int64_t>()));
ASSERT_EQ((simplified_dim_expr2.Get<std::int64_t>()), 2);
}
} // namespace symbol::test
4 changes: 2 additions & 2 deletions test/ir/pir/cinn/symbolic/test_infer_sym_shape_unary_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,8 @@ def prepare_data(self):
self.cases = [np.random.rand(4, 5, 6)]
self.expected = [
[
'shape[3, Min(2, 2)], data[NULL]',
'shape[2, Min(3, 2)], data[NULL]',
'shape[3, 2], data[NULL]',
'shape[2, 2], data[NULL]',
'shape[S2, Min(S0, S1)], data[NULL]',
'shape[S0, Min(S2, S1)], data[NULL]',
'shape[S0, S3], data[NULL]',
Expand Down

0 comments on commit 7fd2942

Please sign in to comment.