Fix for MPS regression in #122016 and #123178 (#123385)
Fixes #122016 and #123178. This regression is caused by an OS-side change that requires a slight adjustment on the PyTorch side to restore the previous behavior. We also removed the pre-MacOS-13 workarounds.

Before the fix on MacOS 14.4:

```
python -c "import torch;x=torch.zeros(3, device='mps');x[1] = 1; x[2] = 3; print(x)"
tensor([0., 3., 3.], device='mps:0')
```

After the fix:
```
python -c "import torch;x=torch.zeros(3, device='mps');x[1] = 1; x[2] = 3; print(x)"
tensor([0., 1., 3.], device='mps:0')
```
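The same repro can be run as a short script. This sketch assumes `torch` is installed and falls back to CPU when no MPS device is present, so it stays runnable anywhere; on MacOS 14.4 without the fix, the MPS path produced `tensor([0., 3., 3.])` instead of `tensor([0., 1., 3.])`:

```python
import torch

# Exercises the scalar-fill indexing path that regressed on MacOS 14.4.
# Falls back to CPU when MPS is unavailable so the snippet runs anywhere.
device = "mps" if torch.backends.mps.is_available() else "cpu"
x = torch.zeros(3, device=device)
x[1] = 1
x[2] = 3
print(x.cpu())  # expected after the fix: tensor([0., 1., 3.])
```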

This also fixes complex-number initialization and, as a result, makes `nn.functional.rms_norm` pass on MacOS-14+.
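The shape of the adjustment, per the `ConstantOps.mm` diff in this commit: rather than baking the fill value into the cached graph with `constantWithScalar`, the graph now takes a scalar placeholder and the concrete value is supplied as a feed when the graph is run. Below is a minimal Python sketch of that build-once, feed-at-run-time pattern; all names are hypothetical stand-ins for the MPSGraph concepts, not the real API:

```python
# Sketch only: stand-ins for MPSGraph concepts, not the real API.

def build_fill_graph(n):
    """Build a 'graph' whose scalar input is a placeholder, resolved at run time."""
    placeholder = "scalar_in"  # stands in for mpsGraphScalarPlaceHolder(...)

    def run(feeds):
        value = feeds[placeholder]  # the scalar arrives via the feeds dictionary
        return [value] * n          # identity pass-through, like identityWithTensor

    return placeholder, run

# Build once; the graph itself does not embed any particular fill value.
placeholder, graph = build_fill_graph(3)

# At run time, feed the scalar -- analogous to runMPSGraph(stream, graph, feeds, results).
print(graph({placeholder: 1.0}))  # [1.0, 1.0, 1.0]
```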

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: #123234
Approved by: https://github.com/malfet, https://github.com/kulinseth

(cherry picked from commit 05289a2)

Co-authored-by: Joona Havukainen <jhavukainen@apple.com>
pytorchbot and jhavukainen committed Apr 5, 2024
1 parent 23961ce commit fb38ab7
Showing 2 changed files with 9 additions and 20 deletions.
28 changes: 8 additions & 20 deletions aten/src/ATen/native/mps/operations/ConstantOps.mm
```diff
@@ -28,43 +28,31 @@

 struct CachedGraph : public MPSCachedGraph {
   CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
+  MPSGraphTensor* inputTensor_ = nil;
   MPSGraphTensor* outputTensor_ = nil;
 };

 @autoreleasepool {
   string key = "fill_scalar_mps_impl" + getTensorsStringKey(self) + ":" + to_string(value.toDouble());

   auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
-    auto isBool = self.scalar_type() == c10::ScalarType::Bool;
-    auto isUInt8 = self.scalar_type() == c10::ScalarType::Byte;
-    auto dataType = !isUInt8 ? !isBool ? getMPSScalarType(self.scalar_type()) : MPSDataTypeInt8 : MPSDataTypeUInt32;
-    // constantWithScalar does not work for boolTypes on MacOS-12.[34]
-    // workaround by filing it as int8 tensor and than casting to bool
-    // See https://github.com/pytorch/pytorch/issues/82427
-    // constantWithScalar does not work for UInt8 Types on MacOS-12.[34]/Ventura preview
-    // workaround by filing it as uint32 tensor and than casting to uint8
-    // See https://github.com/pytorch/pytorch/issues/83692
-    MPSGraphTensor* inputTensor = [mpsGraph constantWithScalar:value.toDouble()
-                                                         shape:getMPSShape(self)
-                                                      dataType:dataType];
+    MPSGraphTensor* inputTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSDataType(self.scalar_type()));
     MPSGraphTensor* outputTensor = [mpsGraph identityWithTensor:inputTensor name:nil];
-    if (isBool) {
-      outputTensor = [mpsGraph castTensor:outputTensor toType:MPSDataTypeBool name:@"constWithBool-workaround"];
-    }
-    if (isUInt8) {
-      outputTensor = [mpsGraph castTensor:outputTensor toType:MPSDataTypeUInt8 name:@"constWithUInt8-workaround"];
-    }
-
+    newCachedGraph->inputTensor_ = inputTensor;
     newCachedGraph->outputTensor_ = outputTensor;
   });

+  auto mpsScalar = getMPSScalar(value, self.scalar_type());
+  auto mpsScalarData = getMPSGraphTensorFromScalar(getCurrentMPSStream(), mpsScalar);
+  NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{cachedGraph->inputTensor_ : mpsScalarData};
+
   Placeholder outputPlaceholder =
       Placeholder(cachedGraph->outputTensor_, needsCopyToOutput ? output : self, nullptr, !needsCopyToOutput);

   NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
       @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};

-  runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), /*feeds*/ nil, results);
+  runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);

   if (needsCopyToOutput) {
     self.copy_(output);
```
1 change: 1 addition & 0 deletions test/test_mps.py
```diff
@@ -412,6 +412,7 @@ def mps_ops_modifier(ops):
         'mean',
         'ne',
         'neg',
+        'nn.functional.rms_norm',
         'nn.functional.padconstant',
         'nn.functional.padreflect',
         'nn.functional.padreplicate',
```
