Fix NULL dereference in binary CPU ops (#115241)
* Fix NULL dereference in binary CPU ops (#115183)

Targeted fix for #113037

A more fundamental fix, where those functions are not called at all for
empty tensors, is coming later

Pull Request resolved: #115183
Approved by: https://github.com/drisspg, https://github.com/atalman, https://github.com/huydhn

* Fix build after conflict resolution

* Also include #113262 to pass the test

---------

Co-authored-by: Nikita Shulga <nshulga@meta.com>
huydhn and malfet committed Dec 6, 2023
1 parent 5965649 commit 448700d
Showing 3 changed files with 62 additions and 29 deletions.
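For context, the crash being fixed is reachable from Python by mixing a 0-dim reduced-precision tensor with a zero-element tensor: with nothing to iterate, TensorIterator can classify the empty operand as a scalar, and the fast path then reads that "scalar" through the empty tensor's NULL data pointer. A minimal repro sketch distilled from the new regression tests (on builds without this fix, the calls below could segfault on CPU):

    import numpy as np
    import torch

    x = torch.rand((), dtype=torch.float16)                   # 0-dim fp16 tensor
    y = torch.tensor(np.random.rand(0), dtype=torch.float16)  # zero-element fp16 tensor

    # Before this fix these could dereference NULL on CPU
    # (#115066, #115068, #113037); with it they return empty results.
    print(torch.mul(x, y).shape)                         # torch.Size([0])
    print(torch.true_divide(x, y).shape)                 # torch.Size([0])
    print(torch.div(x, y, rounding_mode='floor').shape)  # torch.Size([0])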
68 changes: 41 additions & 27 deletions aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
@@ -101,7 +101,7 @@ void mul_kernel(TensorIteratorBase& iter) {
       using comp_t = c10::complex<float>;
       return comp_t{a} * comp_t{b};
     });
-  } else if (iter.is_scalar(2) && at::isReducedFloatingType(dtype)) {
+  } else if (iter.is_scalar(2) && iter.data_ptr(2) != nullptr && at::isReducedFloatingType(dtype)) {
     AT_DISPATCH_REDUCED_FLOATING_TYPES(dtype, "mul_cpu_reduced_float", [&]() {
       using opmath_t = at::opmath_type<scalar_t>;
       opmath_t b = iter.original_scalar_value<opmath_t>(2);
@@ -125,7 +125,7 @@ void mul_kernel(TensorIteratorBase& iter) {
 
 void div_true_kernel(TensorIteratorBase& iter) {
   const auto dtype = iter.common_dtype();
-  if (iter.is_scalar(2) && at::isReducedFloatingType(dtype)) {
+  if (iter.is_scalar(2) && iter.data_ptr(2) != nullptr && at::isReducedFloatingType(dtype)) {
     AT_DISPATCH_REDUCED_FLOATING_TYPES(dtype, "div_cpu_reduced_float", [&]() {
       using opmath_t = at::opmath_type<scalar_t>;
       opmath_t b = iter.original_scalar_value<opmath_t>(2);
@@ -162,19 +162,28 @@ void div_trunc_kernel(TensorIteratorBase& iter) {
         return a / b;
       });
     });
-  } else if (iter.is_scalar(2) && at::isReducedFloatingType(dtype)) {
-    AT_DISPATCH_REDUCED_FLOATING_TYPES(dtype, "div_trunc_cpu_reduced_float", [&]() {
-      using opmath_t = at::opmath_type<scalar_t>;
-      opmath_t b = iter.original_scalar_value<opmath_t>(2);
-      iter.remove_operand(2);
-      cpu_kernel_vec(iter,
-          [=](scalar_t a) __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
-            return std::trunc(static_cast<opmath_t>(a) / b);
-          },
-          [=](Vectorized<scalar_t> a) {
-            return binary_op_scalar(a, b, [](const Vectorized<opmath_t>& x, const Vectorized<opmath_t>& y) { return (x / y).trunc(); });
-          });
-    });
+  } else if (iter.is_scalar(2) && iter.data_ptr(2) != nullptr && at::isReducedFloatingType(dtype)) {
+    AT_DISPATCH_REDUCED_FLOATING_TYPES(
+        dtype, "div_trunc_cpu_reduced_float", [&]() {
+          using opmath_t = at::opmath_type<scalar_t>;
+          opmath_t b = iter.original_scalar_value<opmath_t>(2);
+          iter.remove_operand(2);
+          cpu_kernel_vec(
+              iter,
+              [=](scalar_t a)
+                  __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
+                    return std::trunc(static_cast<opmath_t>(a) / b);
+                  },
+              [=](Vectorized<scalar_t> a) {
+                return binary_op_scalar(
+                    a,
+                    b,
+                    [](const Vectorized<opmath_t>& x,
+                       const Vectorized<opmath_t>& y) {
+                      return (x / y).trunc();
+                    });
+              });
+        });
   } else {
     AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, dtype, "div_trunc_cpu", [&]() {
       cpu_kernel_vec(iter,
@@ -223,20 +232,25 @@ void div_floor_kernel(TensorIteratorBase& iter) {
     });
   } else {
     // See NOTE: [Floor Division in Python]
-    if (iter.is_scalar(2) && at::isReducedFloatingType(dtype)) {
-      AT_DISPATCH_REDUCED_FLOATING_TYPES(dtype, "div_floor_cpu_reduced_float", [&]() {
-        using opmath_t = at::opmath_type<scalar_t>;
-        opmath_t b = iter.original_scalar_value<opmath_t>(2);
-        iter.remove_operand(2);
-        using vec_t = Vectorized<opmath_t>;
-        cpu_kernel_vec(iter,
-            [=](scalar_t a) -> scalar_t {
-              return div_floor_floating(static_cast<opmath_t>(a), b);
-            },
-            [=](Vectorized<scalar_t> a) {
-              return binary_op_scalar(a, b, [](const vec_t& x, const vec_t& y) { return div_floor_floating_vec(x, y); });
-            });
-      });
+    if (iter.is_scalar(2) && iter.data_ptr(2) != nullptr && at::isReducedFloatingType(dtype)) {
+      AT_DISPATCH_REDUCED_FLOATING_TYPES(
+          dtype, "div_floor_cpu_reduced_float", [&]() {
+            using opmath_t = at::opmath_type<scalar_t>;
+            opmath_t b = iter.original_scalar_value<opmath_t>(2);
+            iter.remove_operand(2);
+            using vec_t = Vectorized<opmath_t>;
+            cpu_kernel_vec(
+                iter,
+                [=](scalar_t a) -> scalar_t {
+                  return div_floor_floating(static_cast<opmath_t>(a), b);
+                },
+                [=](Vectorized<scalar_t> a) {
+                  return binary_op_scalar(
+                      a, b, [](const vec_t& x, const vec_t& y) {
+                        return div_floor_floating_vec(x, y);
+                      });
+                });
+          });
     } else {
       AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, dtype, "div_floor_cpu", [&]() {
         using vec_t = Vectorized<scalar_t>;
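Throughout these kernel changes the pattern is the same: the reduced-float fast path reads the scalar operand via iter.original_scalar_value(2), which goes through the operand's data pointer, so the new iter.data_ptr(2) != nullptr guard skips the fast path whenever that pointer is NULL. Zero-element CPU tensors are exactly that case, since they allocate no storage. A quick illustration of the property (observed CPU behavior, not a documented contract):

    import torch

    # A zero-element tensor, like the ones in the regression tests below.
    empty = torch.empty_strided((0, 1), (0, 0), dtype=torch.float16)
    print(empty.numel())     # 0
    print(empty.data_ptr())  # 0, i.e. NULL -- reading a "scalar" here would crash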
8 changes: 6 additions & 2 deletions test/test_foreach.py
@@ -516,16 +516,20 @@ def test_reduce_op(self, device, dtype, op, is_fastpath):
             sum(ref((ref_tensors,), ord=ord)).backward()
             self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
 
-    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
+    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
     def test_add_scalar_with_empty_list_and_empty_tensor(self, device, dtype):
         # TODO: enable empty list case
-        for tensors in [[torch.randn([0])]]:
+        for tensors in [[torch.randn([0], device=device, dtype=dtype)],
+                        [torch.empty_strided((0, 1), (0, 0), dtype=dtype, device=device)]]:
             res = torch._foreach_add(tensors, 1)
             self.assertEqual(res, tensors)
 
             torch._foreach_add_(tensors, 1)
             self.assertEqual(res, tensors)
 
+            # Regression test for https://github.com/pytorch/pytorch/issues/113156
+            torch._foreach_mul_(tensors, 1)
+
     @ops(
         filter(lambda op: not op.has_no_out_of_place, foreach_binary_op_db),
         dtypes=OpDTypes.supported,
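The same degenerate-iterator case exists on the foreach fast path, which is what the added _foreach_mul_ call above guards against (#113156). A sketch of the scenario the updated test now covers:

    import torch

    # Lists holding only zero-element tensors used to crash in-place foreach ops.
    tensors = [torch.empty_strided((0, 1), (0, 0), dtype=torch.float32)]
    torch._foreach_add_(tensors, 1)    # fine: nothing to add to
    torch._foreach_mul_(tensors, 1)    # regression case from #113156
    print([t.shape for t in tensors])  # [torch.Size([0, 1])]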
15 changes: 15 additions & 0 deletions test/test_numpy_interop.py
@@ -472,6 +472,21 @@ def test_numpy_scalar_cmp(self, device, dtype):
             else:
                 self.assertTrue(t == a)
 
+    @onlyCPU
+    def test_empty_tensors_interop(self, device):
+        x = torch.rand((), dtype=torch.float16)
+        y = torch.tensor(np.random.rand(0), dtype=torch.float16)
+        # Same can be achieved by running
+        # y = torch.empty_strided((0,), (0,), dtype=torch.float16)
+
+        # Regression test for https://github.com/pytorch/pytorch/issues/115068
+        self.assertEqual(torch.true_divide(x, y).shape, y.shape)
+        # Regression test for https://github.com/pytorch/pytorch/issues/115066
+        self.assertEqual(torch.mul(x, y).shape, y.shape)
+        # Regression test for https://github.com/pytorch/pytorch/issues/113037
+        self.assertEqual(torch.div(x, y, rounding_mode='floor').shape, y.shape)
+
+
 instantiate_device_type_tests(TestNumPyInterop, globals())
 
 if __name__ == '__main__':
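As the comment in the test notes, the numpy round-trip is only one way to produce the zero-element operand; an equivalent pure-torch construction exercises the same code path:

    import torch

    # Equivalent to torch.tensor(np.random.rand(0), dtype=torch.float16)
    y = torch.empty_strided((0,), (0,), dtype=torch.float16)
    x = torch.rand((), dtype=torch.float16)
    assert torch.mul(x, y).shape == (0,)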
