
Commit e15c359

[FXML-2172] Torch Div Conversion Fix (#45)
* fix(TorchToTosa.cpp): adjust the torch div conversion. Check the return type of the division to decide whether to use the floating-point or the integer implementation. The issue arose because the inputs were all integers but the result was cast to floating point; the conversion then chose the integer implementation of division, which is not legal in TOSA once all the inputs are cast to floating point.
* test(e2e): add a test for integer division resulting in a float: a PyTorch example of two integers being divided whose result should cast to a float.
* fix(TorchToTosa.cpp): correct the type promotion for the reciprocal; the operation should only handle floats, not integers.
* Update python/torch_mlir_e2e_test/test_suite/elementwise.py
* fix(xfail_sets.py): add a TorchDynamo xfail for the tensor-divided-by-scalar case.

Co-authored-by: Matthias Gehre <93204396+mgehre-amd@users.noreply.github.com>
1 parent: 97503d6
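For context, the behavior this commit targets can be reproduced in plain PyTorch. The following is a minimal sketch for illustration (tensor values chosen arbitrarily), not part of the commit:

    import torch

    # aten.div performs true division: dividing an integer tensor by an integer
    # scalar yields a floating-point result, so the lowering has to pick the
    # float or integer TOSA implementation based on the result type, not the
    # input types.
    x = torch.randint(0, 1000, (3, 4), dtype=torch.int64)
    y = torch.ops.aten.div(x, 128)
    print(y.dtype)  # torch.float32, even though both inputs are integers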

File tree

5 files changed: +53 -6 lines changed

e2e_testing/xfail_sets.py
lib/Conversion/TorchToLinalg/Uncategorized.cpp
lib/Conversion/TorchToTosa/TorchToTosa.cpp
python/torch_mlir_e2e_test/test_suite/elementwise.py
test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir

e2e_testing/xfail_sets.py

Lines changed: 5 additions & 0 deletions
@@ -229,6 +229,9 @@
     # ERROR: 'torch.aten.div.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
     "ElementwiseDivScalarModule_basic",

+    # ERROR: 'torch.aten.div.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int'
+    "ElementwiseDivIntScalarModule_basic",
+
     # ERROR: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int'
     "ElementwiseMulScalarModule_int",

@@ -424,6 +427,7 @@
     "ElementwiseAddScalar_NumToTensorFloat_Module_basic",
     "ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
     "ElementwiseDivScalarModule_basic",
+    "ElementwiseDivIntScalarModule_basic",
     "ElementwiseEqDiffWidthScalarModule_basic",
     "ElementwiseEqFloatScalarModule_basic",
     "ElementwiseEqIntScalarModule_basic",

@@ -893,6 +897,7 @@
     "ElementwiseMulScalarModule_float",
     "ElementwiseMulTensorIntModule_basic",
     "ElementwiseDivScalarModule_basic",
+    "ElementwiseDivIntScalarModule_basic",
     "ElementwiseSubScalarFloatModule_basic",
     "ElementwiseAddScalarFloatModule_basic",
     "ElementwiseAddScalar_TensorLiteralInt32_Module_basic",

lib/Conversion/TorchToLinalg/Uncategorized.cpp

Lines changed: 1 addition & 1 deletion
@@ -958,7 +958,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
       divScalar.emitError("unimplemented: non-floating point dtype");
       return nullptr;
     }
-    Value self = payloadArgs[0];
+    Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
     Value other = convertScalarToDtype(b, loc, operands[1], dtype);
     return b.create<arith::DivFOp>(loc, self, other);
   }
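The one-line change above matters for mixed types: when the scalar division has an integer tensor input but a floating-point result dtype, both the tensor element and the scalar must be converted to that dtype before the arith.divf is emitted. A rough sketch of the equivalent eager-mode promotion (values chosen purely for illustration):

    import torch

    # Both operands are promoted to the floating-point result dtype before the
    # divide; the lowering mirrors this by also casting payloadArgs[0].
    x = torch.tensor([[3, 4]], dtype=torch.int32)
    y = torch.ops.aten.div(x, 2)
    print(y, y.dtype)  # tensor([[1.5000, 2.0000]]) torch.float32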

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 8 additions & 3 deletions
@@ -497,15 +497,20 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {

     // auto result;
     Value result;
-    if (lhsElemTy.isa<mlir::FloatType>()) {
+    if (outType.getElementType().template isa<mlir::FloatType>()) {
+      // The input to the reciprocal is an integer sometimes, and we may need to
+      // promote it to a floating point. Per TOSA specification, the input types
+      // can only be floating point for tosa::ReciprocalOp.
+      Value rhsCasted = tosa::promoteType(rewriter, rhsTensor, outType);
       auto rcpOp = rewriter.create<tosa::ReciprocalOp>(
-          op->getLoc(), rhsTy ? rhsTy : RankedTensorType::get({}, lhsElemTy),
-          rhsTensor);
+          op->getLoc(), rhsCasted.getType(), rhsCasted);

       result = tosa::createMulOpAndCast(rewriter, op, outType, lhs,
                                         rcpOp.getResult(), /*shift=*/0)
                    .getResult();
     } else {
+      // If the output type of the original operation is an integer then we will
+      // apply a tosa div knowing that rounding will occur and truncate to zero.
       result = tosa::createBinaryOpAndCast<tosa::DivOp>(rewriter, op, outType,
                                                         lhs, rhsTensor)
                    .getResult();
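To summarize the two paths selected above, here is a rough NumPy sketch of the intended numerics. It is an illustration only, not the actual lowering; the function name and dtype handling are invented for this example. The floating-point path promotes the divisor, takes its reciprocal, and multiplies, while the integer path divides and truncates toward zero as tosa.div does:

    import numpy as np

    def div_like_tosa_lowering(lhs, rhs, out_dtype):
        if np.issubdtype(out_dtype, np.floating):
            # Float result: promote the divisor, take its reciprocal, then
            # multiply (TOSA only defines reciprocal for floating-point inputs).
            rhs_f = rhs.astype(out_dtype)
            return (lhs.astype(out_dtype) * np.reciprocal(rhs_f)).astype(out_dtype)
        # Integer result: divide and truncate toward zero, as tosa.div does.
        return np.trunc(lhs / rhs).astype(out_dtype)

    a = np.array([7, -7], dtype=np.int32)
    print(div_like_tosa_lowering(a, np.array(2), np.float32))  # [ 3.5 -3.5]
    print(div_like_tosa_lowering(a, np.array(2), np.int32))    # [ 3 -3]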

python/torch_mlir_e2e_test/test_suite/elementwise.py

Lines changed: 22 additions & 0 deletions
@@ -1744,6 +1744,28 @@ def forward(self, x):
 def ElementwiseDivScalarModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))

+
+# ==============================================================================
+
+
+class ElementwiseDivIntScalarModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.div(x, 128)
+
+
+@register_test_case(module_factory=lambda: ElementwiseDivIntScalarModule())
+def ElementwiseDivIntScalarModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4))
+
 # ==============================================================================
17471769
# ==============================================================================
17481770

17491771

test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir

Lines changed: 17 additions & 2 deletions
@@ -91,8 +91,8 @@ func.func @torch.aten.bitwise_and.Tensor$mixed_type(%arg0: !torch.vtensor<[?,?],
 // CHECK-LABEL: torch.aten.div.Tensor$mixed_type_fp
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32>,
 // CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xi32>
-// CHECK: %[[VAL_2:.*]] = "tosa.reciprocal"(%[[VAL_1]]) : (tensor<?x?xi32>) -> tensor<?x?xi32>
-// CHECK: %[[VAL_3:.*]] = "tosa.cast"(%[[VAL_2]]) : (tensor<?x?xi32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_1]]) : (tensor<?x?xi32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_3:.*]] = "tosa.reciprocal"(%[[VAL_2]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK: %[[VAL_4:.*]] = "tosa.mul"(%[[VAL_0]], %[[VAL_3]]) <{shift = 0 : i32}> : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
 func.func @torch.aten.div.Tensor$mixed_type_fp(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],si32>) -> !torch.vtensor<[?, ?],f32> {
   %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],si32> -> !torch.vtensor<[?, ?],f32>

@@ -113,6 +113,21 @@ func.func @torch.aten.div.Tensor$mixed_type_int(%arg0: !torch.vtensor<[?, ?],si1

 // -----

+// CHECK-LABEL: torch.aten.div.Scalar$int_input_fp_output
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xi64>
+// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<1.280000e+02> : tensor<f32>}> : () -> tensor<f32>
+// CHECK: %[[VAL_2:.*]] = "tosa.reciprocal"(%[[VAL_1]]) : (tensor<f32>) -> tensor<f32>
+// CHECK: %[[VAL_3:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor<?x?xi64>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_2]]) <{new_shape = array<i64: 1, 1>}> : (tensor<f32>) -> tensor<1x1xf32>
+// CHECK: %[[VAL_5:.*]] = "tosa.mul"(%[[VAL_3]], %[[VAL_4]]) <{shift = 0 : i32}> : (tensor<?x?xf32>, tensor<1x1xf32>) -> tensor<?x?xf32>
+func.func @torch.aten.div.Scalar$int_input_fp_output(%arg0: !torch.vtensor<[?, ?],si64>) -> !torch.vtensor<[?, ?],f32> {
+  %int128 = torch.constant.int 128
+  %0 = torch.aten.div.Scalar %arg0, %int128 : !torch.vtensor<[?, ?],si64>, !torch.int -> !torch.vtensor<[?, ?],f32>
+  return %0 : !torch.vtensor<[?, ?],f32>
+}
+
+// -----
+
 // CHECK-LABEL: torch.aten.pow.Tensor$mixed_type
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf16>
 // CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<3.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
