Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
2742195759 committed Apr 29, 2024
1 parent 3bb5bb4 commit e3449b4
Showing 1 changed file with 2 additions and 38 deletions.
40 changes: 2 additions & 38 deletions test/ir/pir/cinn/symbolic/test_cinn_transform_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def forward(self, x, index):
return paddle.gather(x, index, axis=-1)


class TestGatherAxisPosStatic(unittest.TestCase):
class TestGatherAxisPosSymbolic(unittest.TestCase):
def setUp(self):
paddle.seed(2022)
self.prepare_data()
Expand All @@ -61,7 +61,7 @@ def check_jit_kernel_info(self, static_fn):
def eval(self, use_cinn):
net = GatherLayerAxisPos()
input_spec = [
InputSpec(shape=[32, 4], dtype='float32'),
InputSpec(shape=[None, 4], dtype='float32'),
InputSpec(shape=[1], dtype='int32'),
]
net = utils.apply_to_static(net, use_cinn, input_spec)
Expand All @@ -79,42 +79,6 @@ def test_eval(self):
)


# class TestGatherAxisPosSymbolic(unittest.TestCase):
# def setUp(self):
# paddle.seed(2022)
# self.prepare_data()
#
# def prepare_data(self):
# self.shape = [None, 4 ]
# self.x = paddle.randn(self.shape, dtype="float32")
# self.x.stop_gradient = True
# self.index = paddle.to_tensor([1])
# self.index.stop_gradient = True
#
# def check_jit_kernel_info(self, static_fn):
# utils.check_jit_kernel_number(static_fn, 1)
# utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 1})
#
# def eval(self, use_cinn):
# net = GatherLayerAxisPos()
# input_spec = [
# InputSpec(shape=[None, 4], dtype='float32'),
# InputSpec(shape=[1], dtype='int32'),
# ]
# net = utils.apply_to_static(net, use_cinn, input_spec)
# net.eval()
# out = net(self.x, self.index)
# if use_cinn:
# self.check_jit_kernel_info(net.forward)
# return out
#
# def test_eval(self):
# cinn_out = self.eval(use_cinn=True)
# dy_out = self.eval(use_cinn=False)
# np.testing.assert_allclose(
# cinn_out.numpy(), dy_out.numpy(), atol=1e-6, rtol=1e-6
# )
#
class TestGatherAxisNegStatic(unittest.TestCase):
def setUp(self):
paddle.seed(2022)
Expand Down

0 comments on commit e3449b4

Please sign in to comment.