Skip to content

Commit

Permalink
[CINN] add symbolic select
Browse files Browse the repository at this point in the history
  • Loading branch information
hxzd5568 committed May 8, 2024
1 parent 444e08f commit 50500ec
Show file tree
Hide file tree
Showing 8 changed files with 103 additions and 8 deletions.
4 changes: 2 additions & 2 deletions paddle/cinn/hlir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@
args : (Tensor condition, Tensor true_value, Tensor false_value )
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
param : [condition]
func : WhereInferMeta
spmd_rule: WhereInferSpmd
kernel :
func : where
interfaces : paddle::dialect::InferSymbolicShapeInterface
Expand Down
66 changes: 66 additions & 0 deletions paddle/cinn/hlir/op/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2253,6 +2253,70 @@ std::shared_ptr<OpStrategy> StrategyForSelect(
return strategy;
}

std::shared_ptr<OpStrategy> StrategyForSelectSymbolic(
const framework::NodeAttr &attrs,
const std::vector<ir::Tensor> &inputs,
const std::vector<Type> &out_type,
const std::vector<std::vector<ir::Dim>> &output_shapes,
const Target &target) {
framework::CINNCompute select_compute([=](lang::Args args,
lang::RetValue *ret) {
PADDLE_ENFORCE_EQ(
!args.empty(),
true,
::common::errors::InvalidArgument(
"The input argument of select compute is empty! Please check."));
CINNValuePack pack_args = args[0];
PADDLE_ENFORCE_GE(pack_args.size(),
3U,
::common::errors::InvalidArgument(
"at least three input tensor for select compute."));
Expr condition = pack_args[0];
Expr true_value = pack_args[1];
Expr false_value = pack_args[2];
PADDLE_ENFORCE_NE(condition.as_tensor(),
nullptr,
::common::errors::InvalidArgument(
"The condation arg's type should be Tensor."));
PADDLE_ENFORCE_NE(true_value.as_tensor(),
nullptr,
::common::errors::InvalidArgument(
"The true_value arg's type should be Tensor."));
PADDLE_ENFORCE_NE(false_value.as_tensor(),
nullptr,
::common::errors::InvalidArgument(
"The false_value arg's type should be Tensor."));
PADDLE_ENFORCE_EQ(pack_args.size(),
4U,
::common::errors::InvalidArgument(
"The size of inputs must be equal to 4."));
PADDLE_ENFORCE_EQ(pack_args[3].is_string(),
true,
::common::errors::InvalidArgument(
"The name arg's type should be string."));
std::string tensor_name = pack_args[3].operator std::string();

auto out = pe::Select(condition.as_tensor_ref(),
true_value.as_tensor_ref(),
false_value.as_tensor_ref(),
tensor_name);
auto stages = CreateStages({condition.as_tensor_ref(),
true_value.as_tensor_ref(),
false_value.as_tensor_ref(),
out});
*ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}};
});

auto strategy = std::make_shared<framework::OpStrategy>();
PADDLE_ENFORCE_NE(out_type.size(),
0U,
::common::errors::InvalidArgument(
"Out_type of select op is empty! Please check."));
strategy->AddImpl(
select_compute, lang::PackedFunc(), "strategy.select.x86", 1);
return strategy;
}

std::vector<framework::shape_t> InferShapeForSelect(
const std::vector<framework::shape_t> &inputs_shape,
const framework::AttrMapType &attrs) {
Expand Down Expand Up @@ -2674,6 +2738,8 @@ CINN_REGISTER_HELPER(nn_ops) {
.set_num_outputs(1)
.set_attr<cinn::hlir::framework::StrategyFunction>(
"CINNStrategy", cinn::hlir::op::StrategyForSelect)
.set_attr<cinn::hlir::framework::StrategyFunctionSymbolic>(
"CINNStrategySymbolic", cinn::hlir::op::StrategyForSelectSymbolic)
.set_attr("infershape",
MakeOpFunction(cinn::hlir::op::InferShapeForSelect))
.set_attr("inferdtype",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -611,4 +611,30 @@ bool Where_OpInferSymbolicShape(pir::Operation *op,
return WhereOpInferSymbolicShape(op, infer_context);
}

bool SelectOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
// The shape of output is the same as input `values` (op->operand_source(1))
shape_analysis->SetShapeOrDataForValue(
op->result(0),
shape_analysis->GetShapeOrDataForValue(op->operand_source(0)));

const std::vector<pir::Value> &operands = {op->operand_source(0),
op->operand_source(1)};

size_t rank = shape_analysis->GetShapeOrDataForValue(op->operand_source(0))
.shape()
.size();

for (size_t i = 0; i < rank; ++i) {
paddle::dialect::details::BuildCstrEqForTensorListAlongAxis(
shape_analysis, operands, i);
}

return true;
}

} // namespace paddle::dialect

namespace cinn::dialect {
using paddle::dialect::SelectOpInferSymbolicShape;
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,10 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Stack)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(TrilinearInterp)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Where)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Where_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Select)

} // namespace paddle::dialect

namespace cinn::dialect {
using paddle::dialect::SelectOpInferSymbolicShape;
}
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ OP_SAME_OPERANDS_AND_RESULT(Scale_)
OP_SAME_OPERANDS_AND_RESULT(ScatterNdAdd)
OP_SAME_OPERANDS_AND_RESULT(Scatter)
OP_SAME_OPERANDS_AND_RESULT(Scatter_)
OP_SAME_OPERANDS_AND_RESULT(Select)
OP_SAME_OPERANDS_AND_RESULT(Sign)
OP_SAME_OPERANDS_AND_RESULT(Sin)
OP_SAME_OPERANDS_AND_RESULT(Sin_)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Scale_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ScatterNdAdd)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Scatter)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Scatter_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Select)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sign)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sin)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sin_)
Expand Down
4 changes: 2 additions & 2 deletions test/ir/pir/cinn/inference/test_llama_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ def prepare_data(self):
self.input_ids = paddle.randint(0, 512, [1, 32], dtype="int64")

def check_jit_kernel_info(self, static_fn):
utils.check_jit_kernel_number(static_fn, 10)
utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 10})
utils.check_jit_kernel_number(static_fn, 8)
utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 8})

def eval(self, use_cinn):
paddle.seed(2024)
Expand Down
4 changes: 2 additions & 2 deletions test/ir/pir/cinn/sub_graphs/test_sub_graph_where.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,5 +67,5 @@ def test_where(self):
np.testing.assert_allclose(st.numpy(), cinn.numpy(), atol=1e-8)


# if __name__ == '__main__':
# unittest.main()
if __name__ == '__main__':
unittest.main()

0 comments on commit 50500ec

Please sign in to comment.