[CINN] Add symbolic select #63971

Merged: 5 commits, May 10, 2024

Changes from all commits:
10 changes: 10 additions & 0 deletions paddle/cinn/hlir/dialect/operator/ir/ops.yaml
@@ -103,6 +103,16 @@
func : scale
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : select
  args : (Tensor condition, Tensor true_value, Tensor false_value)
output : Tensor(out)
infer_meta :
func : WhereInferMeta
    spmd_rule : WhereInferSpmd
kernel :
func : where
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : slice
args : (Tensor x, int64_t[] axes, int64_t[] starts, int64_t[] ends, int64_t[] infer_flags, int64_t[] decrease_axis)
output : Tensor
19 changes: 19 additions & 0 deletions paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc
@@ -1068,6 +1068,24 @@ class SigmoidOpPattern
rewriter.EraseOp(op);
}
};

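// Rewrites paddle::dialect::WhereOp into cinn::dialect::SelectOp so that the
// CINN backend can lower it with the select strategies registered in nn.cc.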
class WhereOpPattern : public pir::OpRewritePattern<paddle::dialect::WhereOp> {
public:
using pir::OpRewritePattern<paddle::dialect::WhereOp>::OpRewritePattern;

bool MatchAndRewrite(paddle::dialect::WhereOp op,
pir::PatternRewriter &rewriter) const override {
auto select_op = rewriter.Build<cinn::dialect::SelectOp>(
op->operand_source(0), op->operand_source(1), op->operand_source(2));

rewriter.ReplaceAllUsesWith(op.result(0), select_op.result(0));

rewriter.EraseOp(op);

return true;
}
};
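Note the ordering inside MatchAndRewrite: the pattern redirects every use of the where op's result to the new select op before erasing the original op, the usual OpRewritePattern sequencing that keeps the IR valid at each step of the rewrite.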

class GatherOpPattern
: public pir::OpRewritePattern<paddle::dialect::GatherOp> {
public:
@@ -1132,6 +1150,7 @@ pir::RewritePatternSet PdOpToCinnOpPass::InitializePatterns(
ps.Add<UnsqueezeOpPattern>(context);
ps.Add<SigmoidOpPattern>(context);
ps.Add<GatherOpPattern>(context);
ps.Add<WhereOpPattern>(context);
ps.Add<FlattenOpPattern>(context);

return ps;
66 changes: 66 additions & 0 deletions paddle/cinn/hlir/op/nn.cc
@@ -2253,6 +2253,70 @@ std::shared_ptr<OpStrategy> StrategyForSelect(
return strategy;
}

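// Symbolic-shape counterpart of StrategyForSelect above: output shapes are
// carried as ir::Dim expressions rather than concrete extents, so this
// strategy also serves graphs with dynamic dimensions.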
std::shared_ptr<OpStrategy> StrategyForSelectSymbolic(
const framework::NodeAttr &attrs,
const std::vector<ir::Tensor> &inputs,
const std::vector<Type> &out_type,
const std::vector<std::vector<ir::Dim>> &output_shapes,
const Target &target) {
framework::CINNCompute select_compute([=](lang::Args args,
lang::RetValue *ret) {
PADDLE_ENFORCE_EQ(
!args.empty(),
true,
::common::errors::InvalidArgument(
"The input argument of select compute is empty! Please check."));
CINNValuePack pack_args = args[0];
PADDLE_ENFORCE_GE(pack_args.size(),
3U,
::common::errors::InvalidArgument(
"at least three input tensor for select compute."));
Expr condition = pack_args[0];
Expr true_value = pack_args[1];
Expr false_value = pack_args[2];
PADDLE_ENFORCE_NE(condition.as_tensor(),
nullptr,
::common::errors::InvalidArgument(
"The condation arg's type should be Tensor."));
PADDLE_ENFORCE_NE(true_value.as_tensor(),
nullptr,
::common::errors::InvalidArgument(
"The true_value arg's type should be Tensor."));
PADDLE_ENFORCE_NE(false_value.as_tensor(),
nullptr,
::common::errors::InvalidArgument(
"The false_value arg's type should be Tensor."));
PADDLE_ENFORCE_EQ(pack_args.size(),
4U,
::common::errors::InvalidArgument(
"The size of inputs must be equal to 4."));
PADDLE_ENFORCE_EQ(pack_args[3].is_string(),
true,
::common::errors::InvalidArgument(
"The name arg's type should be string."));
std::string tensor_name = pack_args[3].operator std::string();

auto out = pe::Select(condition.as_tensor_ref(),
true_value.as_tensor_ref(),
false_value.as_tensor_ref(),
tensor_name);
auto stages = CreateStages({condition.as_tensor_ref(),
true_value.as_tensor_ref(),
false_value.as_tensor_ref(),
out});
*ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}};
});

auto strategy = std::make_shared<framework::OpStrategy>();
PADDLE_ENFORCE_NE(out_type.size(),
0U,
::common::errors::InvalidArgument(
"Out_type of select op is empty! Please check."));
strategy->AddImpl(
select_compute, lang::PackedFunc(), "strategy.select.x86", 1);
return strategy;
}
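Note the empty lang::PackedFunc() passed to AddImpl: unlike the static-shape strategy, no schedule function is registered here, presumably because scheduling for symbolic shapes is handled later by the dynamic-shape pipeline.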

std::vector<framework::shape_t> InferShapeForSelect(
const std::vector<framework::shape_t> &inputs_shape,
const framework::AttrMapType &attrs) {
@@ -2674,6 +2738,8 @@ CINN_REGISTER_HELPER(nn_ops) {
.set_num_outputs(1)
.set_attr<cinn::hlir::framework::StrategyFunction>(
"CINNStrategy", cinn::hlir::op::StrategyForSelect)
.set_attr<cinn::hlir::framework::StrategyFunctionSymbolic>(
"CINNStrategySymbolic", cinn::hlir::op::StrategyForSelectSymbolic)
.set_attr("infershape",
MakeOpFunction(cinn::hlir::op::InferShapeForSelect))
.set_attr("inferdtype",
@@ -117,6 +117,7 @@ OP_SAME_OPERANDS_AND_RESULT(Scale_)
OP_SAME_OPERANDS_AND_RESULT(ScatterNdAdd)
OP_SAME_OPERANDS_AND_RESULT(Scatter)
OP_SAME_OPERANDS_AND_RESULT(Scatter_)
OP_SAME_OPERANDS_AND_RESULT(Select)
OP_SAME_OPERANDS_AND_RESULT(Sign)
OP_SAME_OPERANDS_AND_RESULT(Sin)
OP_SAME_OPERANDS_AND_RESULT(Sin_)
@@ -178,6 +179,7 @@ bool ScaleOpInferSymbolicShape(pir::Operation *op,
namespace cinn::dialect {
using paddle::dialect::ReverseOpInferSymbolicShape;
using paddle::dialect::ScaleOpInferSymbolicShape;
using paddle::dialect::SelectOpInferSymbolicShape;
} // namespace cinn::dialect

#undef OP_SAME_OPERANDS_AND_RESULT
@@ -109,6 +109,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Scale_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ScatterNdAdd)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Scatter)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Scatter_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Select)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sign)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sin)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sin_)
@@ -134,4 +135,5 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sigmoid_)
namespace cinn::dialect {
using paddle::dialect::ReverseOpInferSymbolicShape;
using paddle::dialect::ScaleOpInferSymbolicShape;
using paddle::dialect::SelectOpInferSymbolicShape;
} // namespace cinn::dialect
52 changes: 42 additions & 10 deletions paddle/phi/infermeta/multiary.cc
@@ -4413,20 +4413,52 @@ void WhereInferMeta(const MetaTensor& condition,
auto x_dims = x.dims();
auto y_dims = y.dims();
PADDLE_ENFORCE_EQ(
    cond_dims.size(),
    x_dims.size(),
    phi::errors::InvalidArgument(
        "The ranks of Inputs(Condition) and Inputs(X) should be the same. "
        "But received Condition's rank is [%d], X's rank is [%d]",
        cond_dims.size(),
        x_dims.size()));
size_t cond_dims_size = static_cast<size_t>(cond_dims.size());
for (size_t i = 0; i < cond_dims_size; ++i) {
Contributor review comment: The comparison logic for cond_dim, x_dim, and y_dim looks identical here; consider using a lambda to reuse it (a sketch follows this diff).

if (cond_dims[i] == -1 || x_dims[i] == -1) {
continue;
}
PADDLE_ENFORCE_EQ(
cond_dims[i],
x_dims[i],
phi::errors::InvalidArgument(
"The [%d] th of Inputs(Condition) and Inputs(X) should be same. "
"But received Condition's shape is [%d], X's shape is [%d]",
i,
cond_dims[i],
x_dims[i]));
}

PADDLE_ENFORCE_EQ(x_dims.size(),
                  y_dims.size(),
                  phi::errors::InvalidArgument(
                      "The ranks of Inputs(X) and Inputs(Y) should be the "
                      "same. But received X's rank is [%d], Y's rank is [%d]",
                      x_dims.size(),
                      y_dims.size()));
size_t x_dims_size = static_cast<size_t>(x_dims.size());
for (size_t i = 0; i < x_dims_size; ++i) {
if (x_dims[i] == -1 || y_dims[i] == -1) {
continue;
}
PADDLE_ENFORCE_EQ(
x_dims[i],
y_dims[i],
phi::errors::InvalidArgument(
"The [%d] th of Inputs(X) and Inputs(Y) should be same. "
"But received X's shape is [%s], Y's shape is [%s]",
i,
x_dims[i],
y_dims[i]));
}

out->share_meta(x);
}
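A minimal sketch of the refactor suggested in the review comment above: one lambda reused for both dimension comparisons. This is illustrative only, assuming the surrounding WhereInferMeta scope; check_dims_match is a hypothetical name, not part of the merged diff.

  auto check_dims_match = [](const phi::DDim& lhs,
                             const phi::DDim& rhs,
                             const char* lhs_name,
                             const char* rhs_name) {
    // Ranks must match exactly.
    PADDLE_ENFORCE_EQ(lhs.size(),
                      rhs.size(),
                      phi::errors::InvalidArgument(
                          "The ranks of Inputs(%s) and Inputs(%s) should be "
                          "the same. But received %s's rank is [%d], %s's "
                          "rank is [%d]",
                          lhs_name, rhs_name, lhs_name, lhs.size(),
                          rhs_name, rhs.size()));
    for (int i = 0; i < lhs.size(); ++i) {
      // -1 marks a dynamic dimension; skip it, as the merged code does.
      if (lhs[i] == -1 || rhs[i] == -1) continue;
      PADDLE_ENFORCE_EQ(lhs[i],
                        rhs[i],
                        phi::errors::InvalidArgument(
                            "The [%d]-th dimensions of Inputs(%s) and "
                            "Inputs(%s) should be the same.",
                            i, lhs_name, rhs_name));
    }
  };
  check_dims_match(cond_dims, x_dims, "Condition", "X");
  check_dims_match(x_dims, y_dims, "X", "Y");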

4 changes: 2 additions & 2 deletions test/ir/pir/cinn/inference/test_llama_postprocess.py
@@ -90,8 +90,8 @@ def prepare_data(self):
        self.input_ids = paddle.randint(0, 512, [1, 32], dtype="int64")

    def check_jit_kernel_info(self, static_fn):
        utils.check_jit_kernel_number(static_fn, 8)
        utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 8})

    def eval(self, use_cinn):
        paddle.seed(2024)
71 changes: 71 additions & 0 deletions test/ir/pir/cinn/sub_graphs/test_sub_graph_where.py
@@ -0,0 +1,71 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle


class WhereCase(paddle.nn.Layer):
    def __init__(self):
        super().__init__()

    def forward(
        self,
        cond,
        true_branch,
        false_branch,
    ):
        return paddle.tensor.where(cond, true_branch.sin(), false_branch)


class TestWhere(unittest.TestCase):
Contributor review comment: Consider inheriting from TestBase in base.py here; only the init-function logic needs to be overridden.

    def setUp(self):
        self.inputs = (
            paddle.rand(shape=[16, 16], dtype=paddle.float32).cast("bool"),
            paddle.rand(shape=[16, 16], dtype=paddle.float32),
            paddle.rand(shape=[16, 16], dtype=paddle.float32),
        )

    def train(self, net, to_static, with_prim=False, with_cinn=False):
        if to_static:
            paddle.set_flags({'FLAGS_prim_all': with_prim})
            if with_cinn:
                build_strategy = paddle.static.BuildStrategy()
                build_strategy.build_cinn_pass = True
                net = paddle.jit.to_static(
                    net, build_strategy=build_strategy, full_graph=True
                )
            else:
                net = paddle.jit.to_static(net, full_graph=True)
        paddle.seed(123)
        outs = net(*self.inputs)
        return outs

    def test_where(self):
        net = WhereCase()
        st_out = self.train(net, to_static=True)
        cinn_out = self.train(
            net, to_static=True, with_prim=True, with_cinn=True
        )
        for st, cinn in zip(
            paddle.utils.flatten(st_out), paddle.utils.flatten(cinn_out)
        ):
            np.testing.assert_allclose(st.numpy(), cinn.numpy(), atol=1e-8)


if __name__ == '__main__':
    unittest.main()