
(<class 'RuntimeError'>, RuntimeError('Unsupported dtype Half'), <traceback object at 0x7ff3c12f6280>) #289

Open
dbl001 opened this issue Feb 20, 2024 · 3 comments


@dbl001

dbl001 commented Feb 20, 2024

I am running on 'MPS', which does not support the dtype complex64. Initially, I got:

RuntimeError: MPS device does not support bmm for non-float inputs

So, I tried setting

fno_block_precision="half",

This generated:

(<class 'RuntimeError'>, RuntimeError('Unsupported dtype Half'), <traceback object at 0x7ff3c12f6280>)
@JeanKossaifi
Member

This seems to be an issue with PyTorch's MPS support rather than with the neuraloperator library. FNO is built on top of the FFT, so complex support is needed. Some support may be available in the nightly builds of PyTorch: pytorch/pytorch#116630
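For context, the FFT's output is complex-valued by construction, so an FNO backend needs complex dtype support end to end. A minimal stdlib sketch (a naive DFT, not PyTorch's implementation) makes this concrete:

```python
import cmath

def dft(signal):
    """Naive discrete Fourier transform of a real signal.

    Even for purely real input, the spectrum is complex, which is why
    FNO's spectral convolutions need a complex dtype on the backend.
    """
    n = len(signal)
    return [
        sum(signal[t] * cmath.exp(-2j * cmath.pi * k * t / n) for t in range(n))
        for k in range(n)
    ]

# An impulse has a flat spectrum: every bin is the complex value 1+0j.
spectrum = dft([1.0, 0.0, 0.0, 0.0])
```

torch.fft.rfft computes the same transform (restricted to the non-negative frequencies) with a fast algorithm, but the output dtype is still complex, so the device has to support it.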

@dbl001
Author

dbl001 commented Feb 24, 2024

‘mps’ recently added FFT support as well as support for complex numbers, which I’m evaluating.

In [5]: x32 = torch.randn(3, 3, dtype=torch.complex32, device=device)
   ...: y32 = torch.randn(3, 3, dtype=torch.complex32, device=device)
   ...: 
   ...: x64 = torch.randn(3, 3, dtype=torch.complex64, device=device)
   ...: y64 = torch.randn(3, 3, dtype=torch.complex64, device=device)

In [6]: x64
Out[6]: 
tensor([[ 1.0493+1.4454j, -0.7140-0.5213j, -0.1548+0.7472j],
        [-0.5421+0.0464j, -1.4906-0.9571j, -1.4293+0.4247j],
        [ 0.2054-0.0431j, -0.3497+1.4027j,  0.4203-1.3591j]], device='mps:0')

In [7]: y32
Out[7]: 
tensor([[-0.5088+0.1964j,  1.3496-0.5630j,  1.3145-2.2715j],
        [ 1.1006+0.0748j, -1.4512+0.8315j, -1.5723+0.8379j],
        [ 1.7051+1.1533j, -0.4478-0.8452j, -0.1201-1.3057j]], device='mps:0',
       dtype=torch.complex32)

In [8]: y64
Out[8]: 
tensor([[ 1.6695+0.6466j, -0.8143-0.5023j, -0.2274-0.8802j],
        [-0.0850+0.8814j,  0.4071-1.8256j,  0.0106-0.0880j],
        [ 0.9343+0.2286j, -1.3573+2.0575j, -0.2211-0.2846j]], device='mps:0')

In [9]: y64.type
Out[9]: <function Tensor.type>

In [10]: x64*y64
Out[10]: 
tensor([[ 0.8172+3.0915j,  0.3196+0.7831j,  0.6929-0.0337j],
        [ 0.0052-0.4818j, -2.3543+2.3316j,  0.0222+0.1303j],
        [ 0.2017+0.0067j, -2.4114-2.6234j, -0.4797+0.1809j]], device='mps:0')

In [11]: print(x64.dtype)
    ...: # torch.complex64
    ...: 
    ...: dtype = x64.dtype
    ...: print(type(dtype))
    ...: # <class 'torch.dtype'>
    ...: 
    ...: print(dtype == torch.complex64)
    ...: # True
torch.complex64
<class 'torch.dtype'>
True

In [2]: import torch
   ...: x = torch.randn(1, 16000, device="mps")
   ...: y = torch.fft.rfft(x)
   ...: y_abs = y.abs()
Out[2]: 
tensor([[  1.7614,   0.0000, 115.4741,  ...,  63.1403,   0.0000, 148.1510]],
       device='mps:0')

When I try to run train_darcy with device='mps', I get:

Traceback (most recent call last):
  File "/Users/davidlaxer/neuraloperator/scripts/train_darcy.py", line 187, in <module>
    trainer.train(
  File "/Users/davidlaxer/neuraloperator/neuralop/training/trainer.py", line 158, in train
    out  = self.model(**sample)
  File "/Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1529, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/davidlaxer/neuraloperator/neuralop/models/fno.py", line 253, in forward
    x = self.fno_blocks(x, layer_idx, output_shape=output_shape[layer_idx])
  File "/Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1529, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/davidlaxer/neuraloperator/neuralop/layers/fno_block.py", line 195, in forward
    return self.forward_with_postactivation(x, index, output_shape)
  File "/Users/davidlaxer/neuraloperator/neuralop/layers/fno_block.py", line 208, in forward_with_postactivation
    x_fno = self.convs(x, index, output_shape=output_shape)
  File "/Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1529, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/davidlaxer/neuraloperator/neuralop/layers/spectral_convolution.py", line 459, in forward
    out_fft[slices_x] = self._contract(x[slices_x], weight, separable=False)
  File "/Users/davidlaxer/neuraloperator/neuralop/layers/spectral_convolution.py", line 46, in _contract_dense
    return tl.einsum(eq, x, weight)
  File "/Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.10/site-packages/tensorly/backend/__init__.py", line 206, in wrapped_backend_method
    return getattr(
  File "/Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.10/site-packages/tensorly/plugins.py", line 77, in cached_einsum
    return expression(*args)
  File "/Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.10/site-packages/opt_einsum/contract.py", line 763, in __call__
    return self._contract(ops, out, backend, evaluate_constants=evaluate_constants)
  File "/Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.10/site-packages/opt_einsum/contract.py", line 693, in _contract
    return _core_contract(list(arrays),
  File "/Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.10/site-packages/opt_einsum/contract.py", line 591, in _core_contract
    new_view = _einsum(einsum_str, *tmp_operands, backend=backend, **einsum_kwargs)
  File "/Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.10/site-packages/opt_einsum/sharing.py", line 151, in cached_einsum
    return einsum(*args, **kwargs)
  File "/Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.10/site-packages/opt_einsum/contract.py", line 353, in _einsum
    return fn(einsum_str, *operands, **kwargs)
  File "/Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.10/site-packages/opt_einsum/backends/torch.py", line 45, in einsum
    return torch.einsum(equation, operands)
  File "/Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.10/site-packages/torch/functional.py", line 380, in einsum
    return einsum(equation, *_operands)
  File "/Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.10/site-packages/torch/functional.py", line 385, in einsum
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
RuntimeError: MPS device does not support bmm for non-float inputs

Next:

In [13]: import torch
    ...: 
    ...: device = torch.device("mps")
    ...: 
    ...: x = torch.randn(3, 4, 5, dtype=torch.complex64, device=device)
    ...: y = torch.randn(3, 5, 2, dtype=torch.complex64, device=device)
    ...: 
    ...: x = x.to(torch.float32)
    ...: y = y.to(torch.float32)
    ...: 
    ...: z = torch.bmm(x, y)
    ...: 
    ...: print(f"X dtype: {x.dtype}")
    ...: print(f"Y dtype: {y.dtype}")
    ...: print(f"Result dtype: {z.dtype}")
    ...: print(f"Result device: {z.device}")
    ...: print(f"Result shape: {z.shape}")
X dtype: torch.float32
Y dtype: torch.float32
Result dtype: torch.float32
Result device: mps:0
Result shape: torch.Size([3, 4, 2])

In [14]: import torch
    ...: 
    ...: device = torch.device("mps")
    ...: 
    ...: x = torch.randn(3, 4, 5, dtype=torch.complex64, device=device)
    ...: y = torch.randn(3, 5, 2, dtype=torch.complex64, device=device)
    ...: 
    ...: z = torch.bmm(x, y)
    ...: 
    ...: print(f"X dtype: {x.dtype}")
    ...: print(f"Y dtype: {y.dtype}")
    ...: print(f"Result dtype: {z.dtype}")
    ...: print(f"Result device: {z.device}")
    ...: print(f"Result shape: {z.shape}")
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[14], line 8
      5 x = torch.randn(3, 4, 5, dtype=torch.complex64, device=device)
      6 y = torch.randn(3, 5, 2, dtype=torch.complex64, device=device)
----> 8 z = torch.bmm(x, y)
     10 print(f"X dtype: {x.dtype}")
     11 print(f"Y dtype: {y.dtype}") 

RuntimeError: MPS device does not support bmm for non-float inputs
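One possible interim workaround (not part of neuraloperator; the helper names below are mine) is the standard real/imaginary split: since (Ar + iAi)(Br + iBi) = (ArBr - AiBi) + i(ArBi + AiBr), one complex matmul becomes four real matmuls, and MPS does support bmm for float inputs. A stdlib sketch on nested lists, just to show the algebra:

```python
def cmatmul(a, b):
    """Reference complex matrix multiply on nested lists."""
    rows, inner, cols = len(a), len(b), len(b[0])
    return [[sum(a[i][k] * b[k][j] for k in range(inner)) for j in range(cols)]
            for i in range(rows)]

def rmatmul(a, b):
    """Real matrix multiply on nested lists (stands in for a float bmm)."""
    rows, inner, cols = len(a), len(b), len(b[0])
    return [[sum(a[i][k] * b[k][j] for k in range(inner)) for j in range(cols)]
            for i in range(rows)]

def split(a):
    """Split a complex matrix into its real and imaginary parts."""
    return ([[v.real for v in row] for row in a],
            [[v.imag for v in row] for row in a])

def cmatmul_via_real(a, b):
    """Complex matmul rebuilt from four real matmuls:
    (Ar + iAi)(Br + iBi) = (ArBr - AiBi) + i(ArBi + AiBr)."""
    ar, ai = split(a)
    br, bi = split(b)
    rr, ii = rmatmul(ar, br), rmatmul(ai, bi)
    ri, ir = rmatmul(ar, bi), rmatmul(ai, br)
    return [[complex(rr[i][j] - ii[i][j], ri[i][j] + ir[i][j])
             for j in range(len(rr[0]))] for i in range(len(rr))]
```

In PyTorch the same idea would use x.real and x.imag with four torch.bmm calls on float tensors; whether that beats falling back to CPU depends on the tensor sizes involved.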


So, I tried to set fno_block_precision="half" in the SpectralConv class.
Setting fno_block_precision="half" is what caused the exception I reported:

Traceback (most recent call last):
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/pydevd.py", line 1534, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/Users/davidlaxer/neuraloperator/scripts/train_darcy.py", line 187, in <module>
    trainer.train(
  File "/Users/davidlaxer/neuraloperator/neuralop/training/trainer.py", line 158, in train
    out  = self.model(**sample)
  File "/Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1529, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/davidlaxer/neuraloperator/neuralop/models/fno.py", line 253, in forward
    x = self.fno_blocks(x, layer_idx, output_shape=output_shape[layer_idx])
  File "/Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1529, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/davidlaxer/neuraloperator/neuralop/layers/fno_block.py", line 195, in forward
    return self.forward_with_postactivation(x, index, output_shape)
  File "/Users/davidlaxer/neuraloperator/neuralop/layers/fno_block.py", line 208, in forward_with_postactivation
    x_fno = self.convs(x, index, output_shape=output_shape)
  File "/Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1529, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/davidlaxer/neuraloperator/neuralop/layers/spectral_convolution.py", line 434, in forward
    x = torch.fft.rfftn(x, norm=self.fft_norm, dim=fft_dims)
RuntimeError: Unsupported dtype Half
python-BaseException


@dbl001
Author

dbl001 commented May 8, 2024

PyTorch still doesn't support bmm() for complex inputs on the 'mps' device, which tensorly's einsum() relies on. I created a temporary workaround using numpy.einsum():

def _contract_dense(x, weight, separable=False):
    order = tl.ndim(x)
    # batch-size, in_channels, x, y...
    x_syms = list(einsum_symbols[:order])

    # in_channels, out_channels, x, y...
    weight_syms = list(x_syms[1:])  # no batch-size

    # batch-size, out_channels, x, y...
    if separable:
        out_syms = [x_syms[0]] + list(weight_syms)
    else:
        weight_syms.insert(1, einsum_symbols[order])  # outputs
        out_syms = list(weight_syms)
        out_syms[0] = x_syms[0]

    eq = f'{"".join(x_syms)},{"".join(weight_syms)}->{"".join(out_syms)}'

    if not torch.is_tensor(weight):
        weight = weight.to_tensor()

    if x.dtype == torch.complex32:
        # if x is half precision, run a specialized einsum
        return einsum_complexhalf(eq, x, weight)
    else:
        # return tl.einsum(eq, x, weight)
        # Workaround: contract on CPU via NumPy (requires `import numpy as np`).
        # Note: .detach() severs the autograd graph through this contraction.
        result_numpy = np.einsum(eq, x.detach().cpu().numpy(), weight.detach().cpu().numpy())
        result = torch.from_numpy(result_numpy).to(x.device)
        return result
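For reference, the equation string that _contract_dense hands to the einsum backend can be reproduced with a stdlib-only sketch of the same symbol logic (I use string.ascii_lowercase as a stand-in for the library's einsum_symbols):

```python
from string import ascii_lowercase as einsum_symbols

def contract_equation(order, separable=False):
    """Rebuild the einsum equation used by _contract_dense (illustrative only)."""
    x_syms = list(einsum_symbols[:order])        # batch, in_channels, x, y, ...
    weight_syms = list(x_syms[1:])               # weight has no batch dim
    if separable:
        out_syms = [x_syms[0]] + list(weight_syms)
    else:
        weight_syms.insert(1, einsum_symbols[order])  # out_channels symbol
        out_syms = list(weight_syms)
        out_syms[0] = x_syms[0]
    return f'{"".join(x_syms)},{"".join(weight_syms)}->{"".join(out_syms)}'

# For a 2D FNO input of shape (batch, channels, x, y):
#   contract_equation(4) -> "abcd,becd->aecd"
```

This shows why the contraction hits the complex bmm path: the equation mixes a complex activation with a complex weight over the channel index, which opt_einsum lowers to batched matrix multiplies.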

Output from Train_Darcy.py:

[296] time=9.40, avg_loss=0.9336, train_err=4.4459, 16_h1=0.2855, 16_l2=0.1698, 32_h1=0.4454, 32_l2=0.1980
[297] time=9.04, avg_loss=0.9336, train_err=4.4458, 16_h1=0.2902, 16_l2=0.1777, 32_h1=0.4353, 32_l2=0.1961
[298] time=8.80, avg_loss=0.9349, train_err=4.4521, 16_h1=0.2859, 16_l2=0.1708, 32_h1=0.4440, 32_l2=0.1949
[299] time=8.95, avg_loss=0.9347, train_err=4.4510, 16_h1=0.2873, 16_l2=0.1736, 32_h1=0.4420, 32_l2=0.1986
