
S4 Module incompatible with pytorch 2.0's torch.compile #91

Open
RoiEXLab opened this issue Mar 18, 2023 · 7 comments

RoiEXLab commented Mar 18, 2023

Hi,

I've been using the sashimi model on my own dataset with reasonable success for a while now, and I wanted to see whether the recently released torch.compile function could speed up training for my experiments.

Unfortunately it doesn't work. The following line seems to fail (for reasons I don't understand): https://github.com/HazyResearch/state-spaces/blob/06dbbdfd0876501a7f12bf3262121badbc7658af/src/models/s4/s4.py#L703

The PyTorch site has some information on how to deal with these issues, so I hope the code can eventually be extended to support compilation for a noticeable speedup.

Thanks in advance.

@RoiEXLab RoiEXLab changed the title S4 Module incompatible with pytorch 2.0 S4 Module incompatible with pytorch 2.0's torch.compile Mar 18, 2023
albertfgu (Contributor) commented:

Thanks a lot for the report! Because of potential issues like these, I'm waiting until PyTorch 2.x is more stable before upgrading the entire repo. If the underlying issue is complex-number support in torch.compile (can you repro the error with a minimal example?), there's not much I can do on this end, and it would be best to file an issue directly with PyTorch.

RoiEXLab (Author) commented:

Thanks for the reply.

can you repro the error with a minimal example?

It should be rather easy to reproduce, just by running a simple forward pass on the standalone sashimi model after it has been compiled. But I'll try to provide one within the next 24 hours if I get to it.

albertfgu (Contributor) commented:

Right, although if the line you pointed out is indeed the problem, it would be helpful to see whether it fails with a model more minimal than the full S4 layer. If you don't get to it, I'll keep this in mind when I get around to upgrading the library versions.

RoiEXLab (Author) commented:

Ah, I see. I'll see what I can do.

RoiEXLab (Author) commented:

@albertfgu Here you go: https://gist.github.com/RoiEXLab/5cc1630aca71b603528a574b2a2e3326

It turns out the SSKernel is the issue. When running python3 reproducer.py (see the gist; make sure to install the required dependencies, and note that some minor adjustments were made to the files to keep the reproducer as simple as possible), I get the following error:

CUDA extension for cauchy multiplication not found. Install by going to extensions/cauchy/ and running `python setup.py install`. This should speed up end-to-end training by 10-50%
[2023-03-22 16:52:30,027] torch._inductor.graph: [ERROR] Error from lowering
Traceback (most recent call last):
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_inductor/graph.py", line 333, in call_function
    out = lowerings[target](*args, **kwargs)
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_inductor/lowering.py", line 225, in wrapped
    out = decomp_fn(*args, **kwargs)
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_inductor/lowering.py", line 3633, in mul
    return make_pointwise(fn)(a, b)
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_inductor/lowering.py", line 329, in inner
    loaders = [x.make_loader() for x in inputs]
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_inductor/lowering.py", line 329, in <listcomp>
    loaders = [x.make_loader() for x in inputs]
AttributeError: 'complex' object has no attribute 'make_loader'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 670, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.fake_example_inputs())
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/debug_utils.py", line 1055, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/__init__.py", line 1390, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 455, in compile_fx
    return aot_autograd(
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/backends/common.py", line 48, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 2805, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 2498, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config)
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 1713, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config)
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 1326, in aot_dispatch_base
    compiled_fw = aot_config.fw_compiler(fw_module, flat_args_with_views_handled)
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 430, in fw_compiler
    return inner_compile(
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/debug_utils.py", line 595, in debug_wrapper
    compiled_fn = compiler_fn(gm, example_inputs)
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_inductor/debug.py", line 239, in inner
    return fn(*args, **kwargs)
  File "/usr/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 176, in compile_fx_inner
    graph.run(*example_inputs)
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_inductor/graph.py", line 194, in run
    return super().run(*args)
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/fx/interpreter.py", line 136, in run
    self.env[node] = self.run_node(node)
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_inductor/graph.py", line 407, in run_node
    result = super().run_node(n)
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/fx/interpreter.py", line 177, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_inductor/graph.py", line 337, in call_function
    raise LoweringException(e, target, args, kwargs) from e
torch._inductor.exc.LoweringException: AttributeError: 'complex' object has no attribute 'make_loader'
  target: aten.mul.Tensor
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg1_1', layout=FixedLayout('cuda', torch.float32, size=[64, 32], stride=[32, 1]))
  ))
  args[1]: 1j

While executing %mul : [#users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg1_1, 1j), kwargs = {})
Original traceback:
  File "/home/roiex/s4-reproducer/s4.py", line 702, in _w
    w = w_real + 1j * self.w_imag


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "reproducer.py", line 14, in <module>
    y, _ = model(x)
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/home/roiex/s4-reproducer/s4.py", line 1312, in forward
    return self.kernel(state=state, L=L, rate=rate)
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/roiex/s4-reproducer/s4.py", line 717, in forward
    if self.L.item() == 0 and self.l_max is not None and self.l_max > 0:
  File "/home/roiex/s4-reproducer/s4.py", line 723, in <graph break in forward>
    L = round(self.L.item() / rate)
  File "/home/roiex/s4-reproducer/s4.py", line 736, in <graph break in forward>
    w = self._w() # (n_ssm, N)
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 337, in catch_errors
    return callback(frame, cache_size, hooks)
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 404, in _convert_frame
    result = inner_convert(frame, cache_size, hooks)
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 104, in _fn
    return fn(*args, **kwargs)
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 262, in _convert_frame_assert
    return _compile(
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
    out_code = transform_code_object(code, transform)
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/bytecode_transformation.py", line 445, in transform_code_object
    transformations(instructions, code_options)
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 311, in transform
    tracer.run()
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1726, in run
    super().run()
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
    and self.step()
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
    getattr(self, inst.opname)(inst)
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1792, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 517, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 588, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 675, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised LoweringException: AttributeError: 'complex' object has no attribute 'make_loader'
  target: aten.mul.Tensor
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg1_1', layout=FixedLayout('cuda', torch.float32, size=[64, 32], stride=[32, 1]))
  ))
  args[1]: 1j

While executing %mul : [#users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg1_1, 1j), kwargs = {})
Original traceback:
  File "/home/roiex/s4-reproducer/s4.py", line 702, in _w
    w = w_real + 1j * self.w_imag



You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True

So it does indeed seem the issue is with complex numbers. Looking at the PyTorch repo, there are a lot of open issues regarding complex numbers, but I'm not quite sure how well they apply to this exact problem. I also tried using different backends for compilation (see import torch._dynamo; torch._dynamo.list_backends()), but they didn't work out of the box either (I assume some need additional dependencies installed, but even the ones without extra dependencies didn't work).
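For what it's worth, the usual workaround when a compiler backend cannot lower complex dtypes is to keep the real and imaginary parts as separate real tensors and expand the complex arithmetic by hand. Here is a minimal sketch of the algebra in plain Python (my own illustration, not code from the repo):

```python
# Sketch (not from the repo): complex multiplication expanded into real
# arithmetic, the standard trick when a backend lacks complex support.
# (a + b*j) * (c + d*j) = (a*c - b*d) + (a*d + b*c)*j

def complex_mul(ar, ai, br, bi):
    """Multiply two complex numbers given as (real, imag) pairs."""
    return ar * br - ai * bi, ar * bi + ai * br

# Cross-check against Python's built-in complex arithmetic.
re, im = complex_mul(1.0, 2.0, 3.0, 4.0)
assert complex(re, im) == (1 + 2j) * (3 + 4j)  # -5 + 10j
print(re, im)
```

Applied to the failing line, `w = w_real + 1j * self.w_imag` would become a pair of real tensors that the inductor backend can lower; the cost is rewriting every downstream complex operation (multiplications, exponentials, FFTs) in the same style.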

realCrush commented:

Same issue here: I import SSMKernelDPLR from state_spaces.models.s4.s4 as a module of my custom model, and it causes the same error when I try torch.compile. Any progress on this problem?

albertfgu (Contributor) commented:

Unfortunately this is a missing functionality on PyTorch's end (in turn coming from lack of support in Triton): pytorch/pytorch#98161. The PyTorch team is aware of this and may look to support it eventually, but it's unclear how long that would take.

I don't think that the core state space kernels (SSKernelDiag or SSKernelDPLR) are a bottleneck for larger scale models. The main benefit of compilation would be fusing together the main computation pathway of the FFT-convolution and the surrounding linears. Unfortunately I don't see a way to do this at the moment.
