Skip to content

Commit

Permalink
[CINN] Add symbolic select (#63971)
Browse files Browse the repository at this point in the history
* fix where op infer meta check

* convert where to cinn select

* convert where to cinn select

* [CINN] add symbolic select for where

---------

Co-authored-by: phlrain <phliuhongyu@126.com>
  • Loading branch information
hxzd5568 and phlrain committed May 10, 2024
1 parent 7fbd868 commit 8b03a31
Show file tree
Hide file tree
Showing 8 changed files with 214 additions and 12 deletions.
10 changes: 10 additions & 0 deletions paddle/cinn/hlir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,16 @@
func : scale
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : select
args : (Tensor condition, Tensor true_value, Tensor false_value )
output : Tensor(out)
infer_meta :
func : WhereInferMeta
spmd_rule: WhereInferSpmd
kernel :
func : where
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : slice
args : (Tensor x, int64_t[] axes, int64_t[] starts, int64_t[] ends, int64_t[] infer_flags, int64_t[] decrease_axis)
output : Tensor
Expand Down
19 changes: 19 additions & 0 deletions paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1068,6 +1068,24 @@ class SigmoidOpPattern
rewriter.EraseOp(op);
}
};

class WhereOpPattern : public pir::OpRewritePattern<paddle::dialect::WhereOp> {
public:
using pir::OpRewritePattern<paddle::dialect::WhereOp>::OpRewritePattern;

bool MatchAndRewrite(paddle::dialect::WhereOp op,
pir::PatternRewriter &rewriter) const override {
auto select_op = rewriter.Build<cinn::dialect::SelectOp>(
op->operand_source(0), op->operand_source(1), op->operand_source(2));

rewriter.ReplaceAllUsesWith(op.result(0), select_op.result(0));

rewriter.EraseOp(op);

return true;
}
};

class GatherOpPattern
: public pir::OpRewritePattern<paddle::dialect::GatherOp> {
public:
Expand Down Expand Up @@ -1132,6 +1150,7 @@ pir::RewritePatternSet PdOpToCinnOpPass::InitializePatterns(
ps.Add<UnsqueezeOpPattern>(context);
ps.Add<SigmoidOpPattern>(context);
ps.Add<GatherOpPattern>(context);
ps.Add<WhereOpPattern>(context);
ps.Add<FlattenOpPattern>(context);

return ps;
Expand Down
66 changes: 66 additions & 0 deletions paddle/cinn/hlir/op/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2253,6 +2253,70 @@ std::shared_ptr<OpStrategy> StrategyForSelect(
return strategy;
}

std::shared_ptr<OpStrategy> StrategyForSelectSymbolic(
const framework::NodeAttr &attrs,
const std::vector<ir::Tensor> &inputs,
const std::vector<Type> &out_type,
const std::vector<std::vector<ir::Dim>> &output_shapes,
const Target &target) {
framework::CINNCompute select_compute([=](lang::Args args,
lang::RetValue *ret) {
PADDLE_ENFORCE_EQ(
!args.empty(),
true,
::common::errors::InvalidArgument(
"The input argument of select compute is empty! Please check."));
CINNValuePack pack_args = args[0];
PADDLE_ENFORCE_GE(pack_args.size(),
3U,
::common::errors::InvalidArgument(
"at least three input tensor for select compute."));
Expr condition = pack_args[0];
Expr true_value = pack_args[1];
Expr false_value = pack_args[2];
PADDLE_ENFORCE_NE(condition.as_tensor(),
nullptr,
::common::errors::InvalidArgument(
"The condation arg's type should be Tensor."));
PADDLE_ENFORCE_NE(true_value.as_tensor(),
nullptr,
::common::errors::InvalidArgument(
"The true_value arg's type should be Tensor."));
PADDLE_ENFORCE_NE(false_value.as_tensor(),
nullptr,
::common::errors::InvalidArgument(
"The false_value arg's type should be Tensor."));
PADDLE_ENFORCE_EQ(pack_args.size(),
4U,
::common::errors::InvalidArgument(
"The size of inputs must be equal to 4."));
PADDLE_ENFORCE_EQ(pack_args[3].is_string(),
true,
::common::errors::InvalidArgument(
"The name arg's type should be string."));
std::string tensor_name = pack_args[3].operator std::string();

auto out = pe::Select(condition.as_tensor_ref(),
true_value.as_tensor_ref(),
false_value.as_tensor_ref(),
tensor_name);
auto stages = CreateStages({condition.as_tensor_ref(),
true_value.as_tensor_ref(),
false_value.as_tensor_ref(),
out});
*ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}};
});

auto strategy = std::make_shared<framework::OpStrategy>();
PADDLE_ENFORCE_NE(out_type.size(),
0U,
::common::errors::InvalidArgument(
"Out_type of select op is empty! Please check."));
strategy->AddImpl(
select_compute, lang::PackedFunc(), "strategy.select.x86", 1);
return strategy;
}

std::vector<framework::shape_t> InferShapeForSelect(
const std::vector<framework::shape_t> &inputs_shape,
const framework::AttrMapType &attrs) {
Expand Down Expand Up @@ -2674,6 +2738,8 @@ CINN_REGISTER_HELPER(nn_ops) {
.set_num_outputs(1)
.set_attr<cinn::hlir::framework::StrategyFunction>(
"CINNStrategy", cinn::hlir::op::StrategyForSelect)
.set_attr<cinn::hlir::framework::StrategyFunctionSymbolic>(
"CINNStrategySymbolic", cinn::hlir::op::StrategyForSelectSymbolic)
.set_attr("infershape",
MakeOpFunction(cinn::hlir::op::InferShapeForSelect))
.set_attr("inferdtype",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ OP_SAME_OPERANDS_AND_RESULT(Scale_)
OP_SAME_OPERANDS_AND_RESULT(ScatterNdAdd)
OP_SAME_OPERANDS_AND_RESULT(Scatter)
OP_SAME_OPERANDS_AND_RESULT(Scatter_)
OP_SAME_OPERANDS_AND_RESULT(Select)
OP_SAME_OPERANDS_AND_RESULT(Sign)
OP_SAME_OPERANDS_AND_RESULT(Sin)
OP_SAME_OPERANDS_AND_RESULT(Sin_)
Expand Down Expand Up @@ -178,6 +179,7 @@ bool ScaleOpInferSymbolicShape(pir::Operation *op,
namespace cinn::dialect {
using paddle::dialect::ReverseOpInferSymbolicShape;
using paddle::dialect::ScaleOpInferSymbolicShape;
using paddle::dialect::SelectOpInferSymbolicShape;
} // namespace cinn::dialect

#undef OP_SAME_OPERANDS_AND_RESULT
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Scale_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ScatterNdAdd)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Scatter)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Scatter_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Select)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sign)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sin)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sin_)
Expand All @@ -134,4 +135,5 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sigmoid_)
namespace cinn::dialect {
using paddle::dialect::ReverseOpInferSymbolicShape;
using paddle::dialect::ScaleOpInferSymbolicShape;
using paddle::dialect::SelectOpInferSymbolicShape;
} // namespace cinn::dialect
52 changes: 42 additions & 10 deletions paddle/phi/infermeta/multiary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4524,20 +4524,52 @@ void WhereInferMeta(const MetaTensor& condition,
auto x_dims = x.dims();
auto y_dims = y.dims();
PADDLE_ENFORCE_EQ(
cond_dims,
x_dims,
cond_dims.size(),
x_dims.size(),
phi::errors::InvalidArgument(
"The dims of Inputs(Condition) and Inputs(X) should be same. "
"But received Condition's shape is [%s], X's shape is [%s]",
cond_dims,
x_dims));
PADDLE_ENFORCE_EQ(x_dims,
y_dims,
"But received Condition's rank is [%d], X's rank is [%d]",
cond_dims.size(),
x_dims.size()));
size_t cond_dims_size = static_cast<size_t>(cond_dims.size());
for (size_t i = 0; i < cond_dims_size; ++i) {
if (cond_dims[i] == -1 || x_dims[i] == -1) {
continue;
}
PADDLE_ENFORCE_EQ(
cond_dims[i],
x_dims[i],
phi::errors::InvalidArgument(
"The [%d] th of Inputs(Condition) and Inputs(X) should be same. "
"But received Condition's shape is [%d], X's shape is [%d]",
i,
cond_dims[i],
x_dims[i]));
}

PADDLE_ENFORCE_EQ(x_dims.size(),
y_dims.size(),
phi::errors::InvalidArgument(
"The dims of Inputs(X) and Inputs(Y) should be same. "
"But received X's shape is [%s], Y's shape is [%s]",
x_dims,
y_dims));
"But received X's shape is [%d], Y's shape is [%d]",
x_dims.size(),
y_dims.size()));
size_t x_dims_size = static_cast<size_t>(x_dims.size());
for (size_t i = 0; i < x_dims_size; ++i) {
if (x_dims[i] == -1 || y_dims[i] == -1) {
continue;
}
PADDLE_ENFORCE_EQ(
x_dims[i],
y_dims[i],
phi::errors::InvalidArgument(
"The [%d] th of Inputs(X) and Inputs(Y) should be same. "
"But received X's shape is [%s], Y's shape is [%s]",
i,
x_dims[i],
y_dims[i]));
}

out->share_meta(x);
}

Expand Down
4 changes: 2 additions & 2 deletions test/ir/pir/cinn/inference/test_llama_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ def prepare_data(self):
self.input_ids = paddle.randint(0, 512, [1, 32], dtype="int64")

def check_jit_kernel_info(self, static_fn):
utils.check_jit_kernel_number(static_fn, 10)
utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 10})
utils.check_jit_kernel_number(static_fn, 8)
utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 8})

def eval(self, use_cinn):
paddle.seed(2024)
Expand Down
71 changes: 71 additions & 0 deletions test/ir/pir/cinn/sub_graphs/test_sub_graph_where.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle


class WhereCase(paddle.nn.Layer):
def __init__(self):
super().__init__()

def forward(
self,
cond,
true_branch,
false_branch,
):
return paddle.tensor.where(cond, true_branch.sin(), false_branch)


class TestWhere(unittest.TestCase):
def setUp(self):
self.inputs = (
paddle.rand(shape=[16, 16], dtype=paddle.float32).cast("bool"),
paddle.rand(shape=[16, 16], dtype=paddle.float32),
paddle.rand(shape=[16, 16], dtype=paddle.float32),
)

def train(self, net, to_static, with_prim=False, with_cinn=False):
if to_static:
paddle.set_flags({'FLAGS_prim_all': with_prim})
if with_cinn:
build_strategy = paddle.static.BuildStrategy()
build_strategy.build_cinn_pass = True
net = paddle.jit.to_static(
net, build_strategy=build_strategy, full_graph=True
)
else:
net = paddle.jit.to_static(net, full_graph=True)
paddle.seed(123)
outs = net(*self.inputs)
return outs

def test_where(self):
net = WhereCase()
st_out = self.train(net, to_static=True)
cinn_out = self.train(
net, to_static=True, with_prim=True, with_cinn=True
)
for st, cinn in zip(
paddle.utils.flatten(st_out), paddle.utils.flatten(cinn_out)
):
np.testing.assert_allclose(st.numpy(), cinn.numpy(), atol=1e-8)


if __name__ == '__main__':
unittest.main()

0 comments on commit 8b03a31

Please sign in to comment.