
Commit

[CINN / Symbolic] Fix gather infer symbolic bugs (PaddlePaddle#63973)
* fix gather infer symbolic shape

* fix

* fix

* fix
2742195759 authored and co63oc committed May 10, 2024
1 parent 590ec9a commit 9ba186f
Showing 2 changed files with 53 additions and 40 deletions.
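
For context on what the C++ change below handles: at the Python API level, paddle.gather accepts the axis either as a plain int or as a Tensor, so the shape-inference code cannot assume a third (axis) operand is always present. A minimal usage sketch, not part of the diff (variable names are ours):

import paddle

x = paddle.randn([32, 4], dtype="float32")
index = paddle.to_tensor([1], dtype="int32")

# axis given as a plain int
out_int_axis = paddle.gather(x, index, axis=0)

# axis given as a Tensor
out_tensor_axis = paddle.gather(x, index, axis=paddle.to_tensor(0, dtype="int32"))

assert out_int_axis.shape == out_tensor_axis.shape == [1, 4]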
@@ -235,8 +235,22 @@ bool GatherOpInferSymbolicShape(
     return numel;
   }();
 
-  const auto &axis_shape_or_data =
-      shape_analysis->GetShapeOrDataForValue(op->operand_source(2));
+  int axis = 0;
+  const auto &attributes = op->attributes();
+  if (op->HasAttribute("axis")) {  // CINN Dialect
+    axis = attributes.at("axis").dyn_cast<pir::Int32Attribute>().data();
+  } else {
+    PADDLE_ENFORCE_EQ(
+        op->num_operands() == 3,
+        true,
+        phi::errors::InvalidArgument(
+            "in GatherOpInferSymbolicShape: The number of operands should be "
+            "3 when the axis is not set."));
+    const auto &axis_shape_or_data =
+        shape_analysis->GetShapeOrDataForValue(op->operand_source(2));
+    axis =
+        static_cast<int>(axis_shape_or_data.data().value()[0].Get<int64_t>());
+  }
 
   const std::vector<symbol::DimExpr> &input_sym_shape =
       input_shape_or_data.data().has_value()
@@ -248,8 +262,6 @@ bool GatherOpInferSymbolicShape(
           ? index_shape_or_data.data().value()
          : index_shape_or_data.shape();
 
-  int axis =
-      static_cast<int>(axis_shape_or_data.data().value()[0].Get<int64_t>());
   if (axis < 0) axis += input_sym_shape.size();
 
   const auto &out_sym_shape = [&] {
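
For reference, the output-shape rule that paddle.gather follows, and that the out_sym_shape lambda (cut off below the shown hunk) needs to reproduce, keeps the input's dimensions with dimension `axis` replaced by the index length; a negative axis is first normalized by adding the input rank, exactly as the added `if (axis < 0)` line does. A minimal Python sketch of that rule; the helper expected_gather_shape is ours, not part of the change:

import paddle

def expected_gather_shape(x_shape, index_len, axis):
    # Normalize a negative axis the same way the C++ code does.
    if axis < 0:
        axis += len(x_shape)
    # Keep x's dims, with dim `axis` replaced by the index length.
    return x_shape[:axis] + [index_len] + x_shape[axis + 1:]

x = paddle.randn([32, 4], dtype="float32")
index = paddle.to_tensor([1], dtype="int32")
out = paddle.gather(x, index, axis=0)
assert list(out.shape) == expected_gather_shape([32, 4], 1, axis=0)  # [1, 4]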
test/ir/pir/cinn/symbolic/test_cinn_transform_symbolic.py (37 additions, 36 deletions)

@@ -79,42 +79,43 @@ def test_eval(self):
         )
 
 
-# class TestGatherAxisPosSymbolic(unittest.TestCase):
-#     def setUp(self):
-#         paddle.seed(2022)
-#         self.prepare_data()
-#
-#     def prepare_data(self):
-#         self.shape = [None, 4]
-#         self.x = paddle.randn(self.shape, dtype="float32")
-#         self.x.stop_gradient = True
-#         self.index = paddle.to_tensor([1])
-#         self.index.stop_gradient = True
-#
-#     def check_jit_kernel_info(self, static_fn):
-#         utils.check_jit_kernel_number(static_fn, 1)
-#         utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 1})
-#
-#     def eval(self, use_cinn):
-#         net = GatherLayerAxisPos()
-#         input_spec = [
-#             InputSpec(shape=[None, 4], dtype='float32'),
-#             InputSpec(shape=[1], dtype='int32'),
-#         ]
-#         net = utils.apply_to_static(net, use_cinn, input_spec)
-#         net.eval()
-#         out = net(self.x, self.index)
-#         if use_cinn:
-#             self.check_jit_kernel_info(net.forward)
-#         return out
-#
-#     def test_eval(self):
-#         cinn_out = self.eval(use_cinn=True)
-#         dy_out = self.eval(use_cinn=False)
-#         np.testing.assert_allclose(
-#             cinn_out.numpy(), dy_out.numpy(), atol=1e-6, rtol=1e-6
-#         )
-#
+class TestGatherAxisPosSymbolic(unittest.TestCase):
+    def setUp(self):
+        paddle.seed(2022)
+        self.prepare_data()
+
+    def prepare_data(self):
+        self.shape = [32, 4]
+        self.x = paddle.randn(self.shape, dtype="float32")
+        self.x.stop_gradient = True
+        self.index = paddle.to_tensor([1])
+        self.index.stop_gradient = True
+
+    def check_jit_kernel_info(self, static_fn):
+        utils.check_jit_kernel_number(static_fn, 1)
+        utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 1})
+
+    def eval(self, use_cinn):
+        net = GatherLayerAxisPos()
+        input_spec = [
+            InputSpec(shape=[None, 4], dtype='float32'),
+            InputSpec(shape=[1], dtype='int32'),
+        ]
+        net = utils.apply_to_static(net, use_cinn, input_spec)
+        net.eval()
+        out = net(self.x, self.index)
+        if use_cinn:
+            self.check_jit_kernel_info(net.forward)
+        return out
+
+    def test_eval(self):
+        cinn_out = self.eval(use_cinn=True)
+        dy_out = self.eval(use_cinn=False)
+        np.testing.assert_allclose(
+            cinn_out.numpy(), dy_out.numpy(), atol=1e-6, rtol=1e-6
+        )
+
+
 class TestGatherAxisNegStatic(unittest.TestCase):
     def setUp(self):
         paddle.seed(2022)
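
The re-enabled test builds GatherLayerAxisPos, which is defined earlier in test_cinn_transform_symbolic.py and is not shown in this hunk. A hypothetical sketch of the kind of layer the name and the test imply; the real definition in the file may differ:

import paddle

class GatherLayerAxisPos(paddle.nn.Layer):
    # Hypothetical sketch; the actual layer lives earlier in this test file.
    def forward(self, x, index):
        # Gather the rows of x selected by index along a positive (non-negative) axis.
        return paddle.gather(x, index, axis=0)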
