Skip to content

Commit

Permalink
[CINN] add symbolic select for where
Browse files Browse the repository at this point in the history
  • Loading branch information
hxzd5568 committed May 9, 2024
1 parent fe0bb4a commit 5e1ad8a
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 10 deletions.
4 changes: 2 additions & 2 deletions paddle/cinn/hlir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@
args : (Tensor condition, Tensor true_value, Tensor false_value )
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
param : [condition]
func : WhereInferMeta
spmd_rule: WhereInferSpmd
kernel :
func : where
interfaces : paddle::dialect::InferSymbolicShapeInterface
Expand Down
66 changes: 66 additions & 0 deletions paddle/cinn/hlir/op/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2253,6 +2253,70 @@ std::shared_ptr<OpStrategy> StrategyForSelect(
return strategy;
}

std::shared_ptr<OpStrategy> StrategyForSelectSymbolic(
const framework::NodeAttr &attrs,
const std::vector<ir::Tensor> &inputs,
const std::vector<Type> &out_type,
const std::vector<std::vector<ir::Dim>> &output_shapes,
const Target &target) {
framework::CINNCompute select_compute([=](lang::Args args,
lang::RetValue *ret) {
PADDLE_ENFORCE_EQ(
!args.empty(),
true,
::common::errors::InvalidArgument(
"The input argument of select compute is empty! Please check."));
CINNValuePack pack_args = args[0];
PADDLE_ENFORCE_GE(pack_args.size(),
3U,
::common::errors::InvalidArgument(
"at least three input tensor for select compute."));
Expr condition = pack_args[0];
Expr true_value = pack_args[1];
Expr false_value = pack_args[2];
PADDLE_ENFORCE_NE(condition.as_tensor(),
nullptr,
::common::errors::InvalidArgument(
"The condation arg's type should be Tensor."));
PADDLE_ENFORCE_NE(true_value.as_tensor(),
nullptr,
::common::errors::InvalidArgument(
"The true_value arg's type should be Tensor."));
PADDLE_ENFORCE_NE(false_value.as_tensor(),
nullptr,
::common::errors::InvalidArgument(
"The false_value arg's type should be Tensor."));
PADDLE_ENFORCE_EQ(pack_args.size(),
4U,
::common::errors::InvalidArgument(
"The size of inputs must be equal to 4."));
PADDLE_ENFORCE_EQ(pack_args[3].is_string(),
true,
::common::errors::InvalidArgument(
"The name arg's type should be string."));
std::string tensor_name = pack_args[3].operator std::string();

auto out = pe::Select(condition.as_tensor_ref(),
true_value.as_tensor_ref(),
false_value.as_tensor_ref(),
tensor_name);
auto stages = CreateStages({condition.as_tensor_ref(),
true_value.as_tensor_ref(),
false_value.as_tensor_ref(),
out});
*ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}};
});

auto strategy = std::make_shared<framework::OpStrategy>();
PADDLE_ENFORCE_NE(out_type.size(),
0U,
::common::errors::InvalidArgument(
"Out_type of select op is empty! Please check."));
strategy->AddImpl(
select_compute, lang::PackedFunc(), "strategy.select.x86", 1);
return strategy;
}

std::vector<framework::shape_t> InferShapeForSelect(
const std::vector<framework::shape_t> &inputs_shape,
const framework::AttrMapType &attrs) {
Expand Down Expand Up @@ -2674,6 +2738,8 @@ CINN_REGISTER_HELPER(nn_ops) {
.set_num_outputs(1)
.set_attr<cinn::hlir::framework::StrategyFunction>(
"CINNStrategy", cinn::hlir::op::StrategyForSelect)
.set_attr<cinn::hlir::framework::StrategyFunctionSymbolic>(
"CINNStrategySymbolic", cinn::hlir::op::StrategyForSelectSymbolic)
.set_attr("infershape",
MakeOpFunction(cinn::hlir::op::InferShapeForSelect))
.set_attr("inferdtype",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ bool ScaleOpInferSymbolicShape(pir::Operation *op,
namespace cinn::dialect {
using paddle::dialect::ReverseOpInferSymbolicShape;
using paddle::dialect::ScaleOpInferSymbolicShape;
using paddle::dialect::SelectOpInferSymbolicShape;
} // namespace cinn::dialect

#undef OP_SAME_OPERANDS_AND_RESULT
Original file line number Diff line number Diff line change
Expand Up @@ -135,4 +135,5 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sigmoid_)
namespace cinn::dialect {
using paddle::dialect::ReverseOpInferSymbolicShape;
using paddle::dialect::ScaleOpInferSymbolicShape;
using paddle::dialect::SelectOpInferSymbolicShape;
} // namespace cinn::dialect
8 changes: 4 additions & 4 deletions paddle/phi/infermeta/multiary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4357,8 +4357,8 @@ void WhereInferMeta(const MetaTensor& condition,
"But received Condition's rank is [%d], X's rank is [%d]",
cond_dims.size(),
x_dims.size()));

for (size_t i = 0; i < cond_dims.size(); ++i) {
size_t cond_dims_size = static_cast<size_t>(cond_dims.size());
for (size_t i = 0; i < cond_dims_size; ++i) {
if (cond_dims[i] == -1 || x_dims[i] == -1) {
continue;
}
Expand All @@ -4380,8 +4380,8 @@ void WhereInferMeta(const MetaTensor& condition,
"But received X's shape is [%d], Y's shape is [%d]",
x_dims.size(),
y_dims.size()));

for (size_t i = 0; i < x_dims.size(); ++i) {
size_t x_dims_size = static_cast<size_t>(x_dims.size());
for (size_t i = 0; i < x_dims_size; ++i) {
if (x_dims[i] == -1 || y_dims[i] == -1) {
continue;
}
Expand Down
4 changes: 2 additions & 2 deletions test/ir/pir/cinn/inference/test_llama_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ def prepare_data(self):
self.input_ids = paddle.randint(0, 512, [1, 32], dtype="int64")

def check_jit_kernel_info(self, static_fn):
utils.check_jit_kernel_number(static_fn, 10)
utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 10})
utils.check_jit_kernel_number(static_fn, 8)
utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 8})

def eval(self, use_cinn):
paddle.seed(2024)
Expand Down
4 changes: 2 additions & 2 deletions test/ir/pir/cinn/sub_graphs/test_sub_graph_where.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,5 +67,5 @@ def test_where(self):
np.testing.assert_allclose(st.numpy(), cinn.numpy(), atol=1e-8)


# if __name__ == '__main__':
# unittest.main()
if __name__ == '__main__':
unittest.main()

0 comments on commit 5e1ad8a

Please sign in to comment.