Type promotion rule fix (pytorch#574)
* refactor tests

* add category based type promotion

* add category test

* re-order test

* clang-tidy; improve integer variance

* flake
shmsong committed Dec 15, 2020
1 parent 1a21483 commit 3ce81c9
Showing 2 changed files with 160 additions and 60 deletions.
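For context, the category rule added here mirrors PyTorch's eager-mode type promotion: operands are ranked by category (n-dim tensor > 0-dim tensor > Python scalar), and a lower-category operand only influences the result dtype when it comes from a different type category (e.g. floating point vs. integral). The sketch below is illustrative only and is not part of the commit; it checks the same five cases as the new test_category_rule test, using torch.result_type on CPU.

```python
import torch

# Same category (both floating point): the 0-dim double does not promote the n-dim float.
nd_float = torch.randn(4, 8)
zd_double = torch.tensor(2.0, dtype=torch.double)
assert torch.result_type(nd_float, zd_double) == torch.float32

# Different categories (integral vs. floating point): promotion does happen.
nd_long = torch.randn(4, 8).to(torch.long)
assert torch.result_type(nd_long, zd_double) == torch.float64

# Two n-dim tensors share a category, so ordinary dtype promotion applies.
nd_double = torch.randn(4, 8, dtype=torch.double)
assert torch.result_type(nd_float, nd_double) == torch.float64

# A Python float scalar is the lowest category: it leaves a half tensor alone ...
nd_half = torch.randn(4, 8).to(torch.float16)
assert torch.result_type(nd_half, 3.0) == torch.float16

# ... but promotes an integral tensor to the default floating-point dtype.
assert torch.result_type(nd_long, 3.0) == torch.float32
```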
164 changes: 114 additions & 50 deletions test/test_jit_cuda_fuser.py
@@ -33,6 +33,14 @@ class TestCudaFuser(JitTestCase):
math.pi, 10, float("inf"),
float("nan")], dtype=torch.float, device='cuda')

int_types = [
torch.int8,
torch.uint8,
torch.int16,
torch.int32,
torch.int64
]

def _getSubgraphInFusion(self, graph):
num_node = 0
subgraph = None
@@ -144,6 +152,42 @@ def t(x, y, z, q):
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, y, z, q), FUSION_GUARD)

@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_reduction_half(self):
def t(x: torch.Tensor):
o = torch.mul(x, 1.0)
o = torch.sum(o, dim=[2])
return o

t_jit = torch.jit.script(t)
x = torch.randn(8, 4, 16, dtype=torch.float16, device="cuda")
jit_o = t_jit(x)
jit_o = t_jit(x)
o = t(x)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4))
self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)

@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_reduction_float(self):
def t(x: torch.Tensor):
o = torch.mul(x, 1.0)
o = torch.sum(o, dim=[2], dtype=torch.float32)
return o
t_jit = torch.jit.script(t)

x = torch.randn(8, 4, 16, dtype=torch.float, device="cuda")
jit_o = t_jit(x)
jit_o = t_jit(x)
o = t(x)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4))
self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)

@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@@ -409,19 +453,23 @@ def test_unary_ops(self):
for op in operations:
self._unary_test_helper(op)

def _unary_type_test_helper(self, operation, dtype, data=None):
def _unary_type_test_helper(self, operation, dtype, random_data=True):
shape = (4, 8, 32, 32)

def t(x: torch.Tensor):
o = x * 1.0
o = operation(o)
return o

if random_data:
x = torch.randn(shape, dtype=torch.float32, device="cuda")
if dtype in self.int_types:
# prefer a larger variance for integer types
x *= 5
x = x.to(dtype=dtype)
else:
x = self.special_values.to(dtype=dtype)
try:
if data is None:
x = torch.randn(shape, dtype=dtype, device="cuda")
else:
x = special_values.to(dtype=dtype)
ref = t(x)
except Exception:
# same way as TE checker, if eager mode throws, ignore this test
@@ -432,23 +480,20 @@ def t(x: torch.Tensor):
o = t(x)
self.assertEqual(o, jit_o, msg=f"""
failing case:
{dtype} {operation} {data}
{dtype} {operation} {x}
""")

@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_data_compatibility(self):
dtypes = [
torch.int8,
torch.uint8,
torch.int16,
torch.int32,
torch.int64,
*self.int_types,
torch.float16,
torch.float32,
torch.float64,
torch.bool
torch.float64
            # Bool cannot pass yet; see the comment on logical ops in arith.cpp
# torch.bool
]
operations = [torch.neg,
torch.abs,
@@ -483,10 +528,65 @@ def test_data_compatibility(self):
prev_fallback = os.environ['PYTORCH_NVFUSER_DISABLE_FALLBACK']
os.environ['PYTORCH_NVFUSER_DISABLE_FALLBACK'] = '0'
for op, dtype in itertools.product(operations, dtypes):
self._unary_type_test_helper(op, dtype) # test special numbers
self._unary_type_test_helper(op, dtype, False) # test special numbers
self._unary_type_test_helper(op, dtype) # test random data
os.environ['PYTORCH_NVFUSER_DISABLE_FALLBACK'] = prev_fallback

@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_category_rule(self):
def run_tensor(x, z):
def t(x: torch.Tensor, z: torch.Tensor):
o = x + z
o = torch.abs(o)
return o
t_jit = torch.jit.script(t)
jit_o = t_jit(x, z)
jit_o = t_jit(x, z)
o = t(x, z)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, z), FUSION_GUARD)

def run_scalar(x, z):
def t(x: torch.Tensor, z: float):
o = x + z
o = torch.abs(o)
return o
t_jit = torch.jit.script(t)
jit_o = t_jit(x, z)
jit_o = t_jit(x, z)
o = t(x, z)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, z), FUSION_GUARD)

# n-dim with 0-dim (no type-promote)
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
z = torch.tensor(2.0, dtype=torch.double, device="cuda")
run_tensor(x, z)

# n-dim with 0-dim (type-promote)
x = torch.randn(4, 8, 32, 32, device="cuda").to(dtype=torch.long)
z = torch.tensor(2.0, dtype=torch.double, device="cuda")
run_tensor(x, z)

# n-dim with n-dim (type-promote)
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
z = torch.randn(4, 8, 32, 32, dtype=torch.double, device="cuda")
run_tensor(x, z)

# n-dim with scalar (no type-promote)
x = torch.randn(4, 8, 32, 32, dtype=torch.float16, device="cuda")
z = 3.
run_scalar(x, z)

# n-dim with scalar (type-promote)
x = torch.randn(4, 8, 32, 32, device="cuda").to(dtype=torch.long)
z = 3.
run_scalar(x, z)

@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@@ -1088,42 +1188,6 @@ def t(x: torch.Tensor, y: torch.Tensor, scale: float, z: torch.Tensor):
self.assertGraphContains(t_jit.graph_for(x, y, scale, z), FUSION_GUARD)
torch._C._jit_set_nvfuser_guard_mode(old_guard)

@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_reduction_dtype(self):
def t(x: torch.Tensor):
o = torch.mul(x, 1.0)
o = torch.sum(o, dim=[2], dtype=torch.float32)
return o
t_jit = torch.jit.script(t)

x = torch.randn(8, 4, 16, dtype=torch.float, device="cuda")
jit_o = t_jit(x)
jit_o = t_jit(x)
o = t(x)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4))
self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)

@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_reduction_half(self):
def t(x: torch.Tensor):
o = torch.mul(x, 1.0)
o = torch.sum(o, dim=[2])
return o

t_jit = torch.jit.script(t)
x = torch.randn(8, 4, 16, dtype=torch.float16, device="cuda")
jit_o = t_jit(x)
jit_o = t_jit(x)
o = t(x)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4))
self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)

@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
56 changes: 46 additions & 10 deletions torch/csrc/jit/codegen/cuda/arith.cpp
@@ -228,20 +228,41 @@ TensorView* arithOpOverloads(
->template as<TensorView>();
}

namespace {
enum class Category { Scalar, ZeroDimTensor, DimTensor };

inline Category getCategory(const Val* v) {
if (v->isA<TensorView>()) {
if (v->as<TensorView>()->nDims() > 0) {
return Category::DimTensor;
} else {
return Category::ZeroDimTensor;
}
} else {
return Category::Scalar;
}
}

// replicated logic from Aten/native/TypeProperties.cpp, minus complex support
DataType getCommonType(DataType higher, DataType lower) {
if (isFloatingPointType(higher)) {
return higher;
}
if (higher == DataType::Bool || isFloatingPointType(lower)) {
return promote_type(higher, lower);
}
if (higher != DataType::Null) {
return higher;
}
return lower;
}
} // namespace

// Type promotion logic for binary operators
DataType getOutputType(BinaryOpType op_type, Val* v1, Val* v2) {
DataType v1_dtype = v1->getDataType().value();
DataType v2_dtype = v2->getDataType().value();

// If we have a tensor view in one argument but a scalar in the other, don't
// type promote, just use the tensorview type
if (v1->isA<TensorView>() && !v2->isA<TensorView>()) {
v2_dtype = v1_dtype;
}
if (v2->isA<TensorView>() && !v1->isA<TensorView>()) {
v1_dtype = v2_dtype;
}

const bool floating_input =
isFloatingPointType(v1_dtype) || isFloatingPointType(v2_dtype);

@@ -251,11 +272,27 @@ DataType getOutputType(BinaryOpType op_type, Val* v1, Val* v2) {
const bool all_integer_input =
isIntegralType(v1_dtype) && isIntegralType(v2_dtype);

// Combine categories
const auto v1_cat = getCategory(v1);
const auto v2_cat = getCategory(v2);
if (v1_cat != v2_cat) {
const DataType higher = v1_cat > v2_cat ? v1_dtype : v2_dtype;
const DataType lower = v1_cat > v2_cat ? v2_dtype : v1_dtype;
const DataType common_type = getCommonType(higher, lower);
v1_dtype = common_type;
v2_dtype = common_type;
}

if (isIntegerOp(op_type) || (alsoBooleanOperator(op_type) && integer_input)) {
    // Integer op, or a bool-capable op (& / |) given integer inputs, i.e. an integer binary op
if (integer_input && all_integer_input) {
return promote_type(v1_dtype, v2_dtype);
} else if (integer_input && !all_integer_input) {
TORCH_CHECK(
!floating_input,
"Operator ",
op_type,
" not supported with floating point inputs.");
return isIntegralType(v1_dtype) ? v1_dtype : v2_dtype;
} else {
TORCH_INTERNAL_ASSERT(
@@ -264,7 +301,6 @@
"Inputs should be manually casted first.");
}
} else if (isLogicalOp(op_type)) {
// If boolean op
return DataType::Bool;
} else if (alsoBooleanOperator(op_type)) {
// If boolean op that can't have floating inputs (& or |)
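To make the C++ change above easier to follow, here is a rough Python paraphrase of the new promotion logic (getCategory, getCommonType, and the "Combine categories" step in getOutputType). This is an illustrative sketch, not code from the commit: the names Category, common_type, and combine are invented, the DataType::Null branch is dropped (torch dtypes are never null here), and equal-category operands simply fall through to ordinary dtype promotion.

```python
from enum import IntEnum

import torch


class Category(IntEnum):
    # Ordering matters: a higher value wins the category comparison,
    # mirroring the C++ enum {Scalar, ZeroDimTensor, DimTensor}.
    SCALAR = 0
    ZERO_DIM_TENSOR = 1
    DIM_TENSOR = 2


def common_type(higher: torch.dtype, lower: torch.dtype) -> torch.dtype:
    # Paraphrase of getCommonType: 'higher' is the dtype of the higher-category operand.
    if higher.is_floating_point:
        return higher
    if higher == torch.bool or lower.is_floating_point:
        return torch.promote_types(higher, lower)
    return higher


def combine(dtype1: torch.dtype, cat1: Category,
            dtype2: torch.dtype, cat2: Category) -> torch.dtype:
    # Only reach across dtypes when the operands sit in different categories;
    # otherwise ordinary dtype promotion decides.
    if cat1 == cat2:
        return torch.promote_types(dtype1, dtype2)
    if cat1 > cat2:
        return common_type(dtype1, dtype2)
    return common_type(dtype2, dtype1)


# n-dim float + 0-dim double: categories differ but both are floating point,
# so the n-dim tensor's dtype wins -> float32 (no promotion).
assert combine(torch.float32, Category.DIM_TENSOR,
               torch.float64, Category.ZERO_DIM_TENSOR) == torch.float32

# n-dim long + 0-dim double: type categories differ (integral vs. floating), so promote -> float64.
assert combine(torch.int64, Category.DIM_TENSOR,
               torch.float64, Category.ZERO_DIM_TENSOR) == torch.float64
```

Both checks match the eager results exercised by the new test_category_rule cases in the test file above.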
