Add support for Torch conv aliases #2011

Open
wants to merge 3 commits into
base: main

alealv
Contributor

@alealv alealv commented Oct 12, 2023

When trying to convert a Torch Conv1d layer, I get:

the following model ops are MISSING:
  conv1d
...
RuntimeError: PyTorch convert function for op 'conv1d' not implemented.

Although the op itself seems to be supported, its aliases aren't registered, hence the error.

A similar situation happens with conv_transpose.

It seems to be supported but fails for the same reason. In this case, though, some code had to be added to match the order of inputs used by Torch.

I hope this helps to improve the code.

Note: I'm using a scripted model (torch.jit.script).

This should also close #1753

@@ -889,7 +889,7 @@ def linear(context, node):
     context.add(res, torch_name=node.name)


-@register_torch_op(torch_alias=["conv2d"])
+@register_torch_op(torch_alias=["conv1d", "conv2d", "conv3d", "conv_transpose1d", "conv_transpose2d", "conv_transpose3d"])
Collaborator

Since we already have passing unit tests for conv1d, I don't think we should add an alias for it here. PyTorch must be lowering it to other ops which we do support.

Contributor Author

@alealv Oct 13, 2023

Can you help me understand why I get the following, then?

the following model ops are IMPLEMENTED:
  add
  atan2
  clamp
  complex
  constant
  constantchunk
  cos
  exp
  gelu
  layer_norm
  linear
  mul
  sin
  tensor
  transpose
the following model ops are MISSING:
  conv1d
  floatimplicit
  istft
Traceback (most recent call last):
  File "/home/aalvarez/Projects/main/apps/tts-train/vocoder/export.py", line 73, in <module>
    main()
  File "/home/aalvarez/Projects/main/apps/tts-train/vocoder/export.py", line 54, in main
    model.to_coreml(
  File "/home/aalvarez/.virtualenvs/tts-train-XZ1ykfT_-py3.9/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/aalvarez/Projects/main/apps/tts-train/vocoder/vocos/__init__.py", line 502, in to_coreml
    coreml_mdl = ct.convert(
  File "/home/aalvarez/Projects/coremltools/coremltools/converters/_converters_entry.py", line 553, in convert
    mlmodel = mil_convert(
  File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/converter.py", line 188, in mil_convert
    return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
  File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/converter.py", line 212, in _mil_convert
    proto, mil_program = mil_convert_to_proto(
  File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/converter.py", line 286, in mil_convert_to_proto
    prog = frontend_converter(model, **kwargs)
  File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/converter.py", line 108, in __call__
    return load(*args, **kwargs)
  File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/frontend/torch/load.py", line 75, in load
    return _perform_torch_convert(converter, debug)
  File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/frontend/torch/load.py", line 122, in _perform_torch_convert
    raise e
  File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/frontend/torch/load.py", line 114, in _perform_torch_convert
    prog = converter.convert()
  File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/frontend/torch/converter.py", line 484, in convert
    convert_nodes(self.context, self.graph)
  File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/frontend/torch/ops.py", line 88, in convert_nodes
    raise RuntimeError(
RuntimeError: PyTorch convert function for op 'conv1d' not implemented.

Collaborator

Are you calling torch.jit.trace prior to conversion?

Contributor Author

No, I'm only using torch.jit.script, and I guess that is where the discrepancy lies.

In my experience, torch.jit.script is much more reliable because it:

  • handles data-dependent models with conditionals (the main point)
  • allows adding extra functions
  • needs no example-input boilerplate
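
The first point can be illustrated with a small sketch. The `Gate` module below is a made-up toy, not part of the model discussed here; it only shows that tracing records the branch taken for the example input, while scripting keeps both branches.

```python
import warnings

import torch
from torch import nn


class Gate(nn.Module):
    """Toy module whose output depends on the values in the input."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.sum() > 0:
            return x * 2
        return x - 10


model = Gate().eval()
positive = torch.ones(3)
negative = -torch.ones(3)

# Tracing records only the branch taken for the example input.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the expected TracerWarning
    traced = torch.jit.trace(model, positive)

# Scripting compiles the Python source, so both branches are kept.
scripted = torch.jit.script(model)

print("eager   :", model(negative))
print("scripted:", scripted(negative))
print("traced  :", traced(negative))
```

For the negative input, the scripted module follows the eager result, whereas the traced module replays the `x * 2` branch it saw during tracing.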

@TobyRoseman
Collaborator

@alealv - thanks for the pull request. In order to merge this, we need unit tests. Please add unit tests for these new aliases.

@jakesabathia2
Collaborator

@alealv does torch.jit.trace work for your model?
That is what we generally recommend, instead of using torch.jit.script.

@katelyn-chen

I'm also facing the same issue but with conv_transpose2d. I'm using torch.jit.script because while torch.jit.trace runs, it outputs the following: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error: Tensor-likes are not close! Mismatched elements: 239964 / 259200 (92.6%)

@TobyRoseman
Collaborator

@katelyn-chen - I don't know. That is a PyTorch issue. I suggest you ask in a PyTorch forum.

@alealv
Contributor Author

alealv commented Oct 17, 2023

> @alealv does torch.jit.trace work for your model? That is what we generally recommend, instead of using torch.jit.script.

I just tried tracing, as you suggested. Here are the differences:

With torch.jit.script

Support for converting Torch Script Models is experimental. If possible you should use a traced model for conversion.
Converting PyTorch Frontend ==> MIL Ops:  52%|███████████████████████████████████████▉                                     | 95/183 [00:00<00:00, 6825.15 ops/s]
the following model ops are IMPLEMENTED:
  add
  clamp
  complex
  constant
  constantchunk
  cos
  exp
  gelu
  layer_norm
  linear
  mul
  sin
  transpose
the following model ops are MISSING:
  conv1d
  istft
Traceback (most recent call last):
  File "/home/aalvarez/Projects/main/apps/tts-train/vocoder/export.py", line 73, in <module>
    main()
  File "/home/aalvarez/Projects/main/apps/tts-train/vocoder/export.py", line 54, in main
    model.to_coreml(
  File "/home/aalvarez/.virtualenvs/tts-train-XZ1ykfT_-py3.9/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/aalvarez/Projects/main/apps/tts-train/vocoder/vocos/__init__.py", line 502, in to_coreml
    coreml_mdl = ct.convert(
  File "/home/aalvarez/Projects/coremltools/coremltools/converters/_converters_entry.py", line 553, in convert
    mlmodel = mil_convert(
  File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/converter.py", line 188, in mil_convert
    return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
  File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/converter.py", line 212, in _mil_convert
    proto, mil_program = mil_convert_to_proto(
  File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/converter.py", line 286, in mil_convert_to_proto
    prog = frontend_converter(model, **kwargs)
  File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/converter.py", line 108, in __call__
    return load(*args, **kwargs)
  File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/frontend/torch/load.py", line 75, in load
    return _perform_torch_convert(converter, debug)
  File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/frontend/torch/load.py", line 122, in _perform_torch_convert
    raise e
  File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/frontend/torch/load.py", line 114, in _perform_torch_convert
    prog = converter.convert()
  File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/frontend/torch/converter.py", line 484, in convert
    convert_nodes(self.context, self.graph)
  File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/frontend/torch/ops.py", line 88, in convert_nodes
    raise RuntimeError(
RuntimeError: PyTorch convert function for op 'conv1d' not implemented.

With torch.jit.trace

Support for converting Torch Script Models is experimental. If possible you should use a traced model for conversion.
Converting PyTorch Frontend ==> MIL Ops:  96%|████████████████████████████████████████████████████████████████████████▉   | 145/151 [00:00<00:00, 3665.71 ops/s]
the following model ops are IMPLEMENTED:
  _convolution
  add
  complex
  constant
  constantchunk
  cos
  exp
  gelu
  layer_norm
  linear
  listconstruct
  mul
  sin
  transpose
the following model ops are MISSING:
  clip
  istft
Traceback (most recent call last):
  File "/home/aalvarez/Projects/main/apps/tts-train/vocoder/export.py", line 73, in <module>
    main()
  File "/home/aalvarez/Projects/main/apps/tts-train/vocoder/export.py", line 54, in main
    model.to_coreml(
  File "/home/aalvarez/.virtualenvs/tts-train-XZ1ykfT_-py3.9/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/aalvarez/Projects/main/apps/tts-train/vocoder/vocos/__init__.py", line 531, in to_coreml
    coreml_mdl = ct.convert(
  File "/home/aalvarez/Projects/coremltools/coremltools/converters/_converters_entry.py", line 553, in convert
    mlmodel = mil_convert(
  File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/converter.py", line 188, in mil_convert
    return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
  File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/converter.py", line 212, in _mil_convert
    proto, mil_program = mil_convert_to_proto(
  File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/converter.py", line 286, in mil_convert_to_proto
    prog = frontend_converter(model, **kwargs)
  File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/converter.py", line 108, in __call__
    return load(*args, **kwargs)
  File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/frontend/torch/load.py", line 75, in load
    return _perform_torch_convert(converter, debug)
  File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/frontend/torch/load.py", line 122, in _perform_torch_convert
    raise e
  File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/frontend/torch/load.py", line 114, in _perform_torch_convert
    prog = converter.convert()
  File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/frontend/torch/converter.py", line 484, in convert
    convert_nodes(self.context, self.graph)
  File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/frontend/torch/ops.py", line 88, in convert_nodes
    raise RuntimeError(
RuntimeError: PyTorch convert function for op 'clip' not implemented.

So Conv1d is mapped to _convolution when tracing but not when scripting, and that is what my PR adds.
I don't know all the mechanisms behind coremltools, but both paths should end up with the same representation, meaning we should also map conv1d to _convolution, shouldn't we?
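
The difference between the two logs can be reproduced outside coremltools with a sketch like the one below. It assumes a plain `nn.Conv1d` with arbitrary sizes, and the op names may vary across PyTorch versions; on the versions discussed in this thread, scripting emits `aten::conv1d` while tracing emits `aten::_convolution`.

```python
import torch
from torch import nn

conv = nn.Conv1d(in_channels=3, out_channels=4, kernel_size=3).eval()
example = torch.rand(1, 3, 7)

scripted = torch.jit.script(conv)
traced = torch.jit.trace(conv, example)

# inlined_graph expands method calls so the aten ops are visible directly.
scripted_ir = str(scripted.inlined_graph)
traced_ir = str(traced.inlined_graph)

print("aten::conv1d in scripted graph:", "aten::conv1d" in scripted_ir)
print("aten::_convolution in traced graph:", "aten::_convolution" in traced_ir)
```

This is why a converter that only registers `_convolution` works for traced models but reports `conv1d` as missing for scripted ones.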

@TobyRoseman
Collaborator

@alealv - if you want to add conv1d support for torch.script, that's fine. We'll just need unit tests for this functionality. FYI - coremltools support for torch.script is only "experimental". So this isn't a priority for us.

Regarding your torch.jit.trace error: it looks like clip is just an alias for clamp, which we already support.
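
As a quick check of the alias claim on the PyTorch side, `torch.clip` and `torch.clamp` can be compared directly. The input values here are arbitrary.

```python
import torch

x = torch.linspace(-2.0, 2.0, steps=5)

clamped = torch.clamp(x, min=-1.0, max=1.0)
clipped = torch.clip(x, min=-1.0, max=1.0)

# Both calls bound every element to the range [-1, 1].
print(torch.equal(clamped, clipped))
```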

@alealv force-pushed the add_conv1d_op branch 2 times, most recently from 906f052 to f31b8de, on November 15, 2023 09:55
@alealv
Contributor Author

alealv commented Nov 15, 2023

I just updated the convolution tests so that they also run with torch.jit.script.

Comment on lines 1691 to 1705
         if padding == "same" and stride != 1:
             return

         class FunctionalConv1D(nn.Module):
+            def __init__(self):
+                super(FunctionalConv1D, self).__init__()
+                self.stride=stride
+                self.padding=padding
+                self.groups=groups
             def forward(self, input_data, weights):
                 return nn.functional.conv1d(
-                    input_data, weights, stride=stride, padding=padding
+                    input_data, weights, stride=self.stride, padding=self.padding, groups=self.groups
                 )

-        model = DynamicConv()
+        model = FunctionalConv1D().eval()
Contributor Author

@alealv Dec 5, 2023

With this code I get the following error:

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py:1710:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
coremltools/converters/mil/frontend/torch/test/testing_utils.py:286: in run_compare_torch
    model_spec, mlmodel, coreml_inputs, coreml_results = convert_and_compare(
coremltools/converters/mil/frontend/torch/test/testing_utils.py:195: in convert_and_compare
    mlmodel = convert_to_mlmodel(model_spec, input_data, backend=backend,
coremltools/converters/mil/frontend/torch/test/testing_utils.py:120: in convert_to_mlmodel
    return ct_convert(model_spec, inputs=inputs, convert_to=backend,
coremltools/converters/mil/testing_utils.py:466: in ct_convert
    mlmodel = converter(
coremltools/converters/_converters_entry.py:583: in convert
    mlmodel = mil_convert(
coremltools/converters/mil/converter.py:188: in mil_convert
    return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
coremltools/converters/mil/converter.py:212: in _mil_convert
    proto, mil_program = mil_convert_to_proto(
coremltools/converters/mil/converter.py:286: in mil_convert_to_proto
    prog = frontend_converter(model, **kwargs)
coremltools/converters/mil/converter.py:108: in __call__
    return load(*args, **kwargs)
coremltools/converters/mil/frontend/torch/load.py:80: in load
    return _perform_torch_convert(converter, debug)
coremltools/converters/mil/frontend/torch/load.py:99: in _perform_torch_convert
    prog = converter.convert()
coremltools/converters/mil/frontend/torch/converter.py:519: in convert
    convert_nodes(self.context, self.graph)
coremltools/converters/mil/frontend/torch/ops.py:89: in convert_nodes
    add_op(context, node)
coremltools/converters/mil/frontend/torch/ops.py:729: in listconstruct
    _array_construct(context, node, array_type=list)
coremltools/converters/mil/frontend/torch/ops.py:702: in _array_construct
    inputs = _get_inputs(context, node)
coremltools/converters/mil/frontend/torch/ops.py:224: in _get_inputs
    inputs = get_bindings(node.inputs)
coremltools/converters/mil/frontend/torch/ops.py:210: in get_bindings
    results.append(context[i])
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = %input_data.1 : (1, 3, 7, fp32)
%weights.1 : (3, 3, 3, fp32)
%3 : (1,int32)*
%4 : None

, torch_name = 'stride'

    def __getitem__(self, torch_name: str) -> Var:
        """
        Lookup a name in the context. Note that since nested blocks must be
        able to access anything that was defined before them, we have to
        search all contexts for a name, starting with the most local scope.
        """
        for idx in reversed(range(len(self._current_graph))):
            current_graph = self._current_graph[idx]
            if torch_name in current_graph:
                return self._current_graph[idx][torch_name]
>       raise ValueError(f"Torch var {torch_name} not found in context {self.name}")
E       ValueError: Torch var stride not found in context

coremltools/converters/mil/frontend/torch/converter.py:251: ValueError

Can someone help me fix this? I don't know coremltools deeply enough to solve it.

Contributor Author

If I try to follow how it was done previously:

        class FunctionalConv1D(nn.Module):
            def forward(self, input_data, weights):
                return nn.functional.conv1d(
                    input_data, weights, stride=stride, padding=padding, groups=groups
                )

        model = FunctionalConv1D().eval()
        input_shape = [
            (1, in_channels, width),
            (out_channels, int(in_channels / groups), kernel_size),
        ]

I get:

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py:1705:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
coremltools/converters/mil/frontend/torch/test/testing_utils.py:267: in run_compare_torch
    model_spec = torch.jit.script(model)
envs/coremltools-dev-py3.9/lib/python3.9/site-packages/torch/jit/_script.py:1324: in script
    return torch.jit._recursive.create_script_module(
envs/coremltools-dev-py3.9/lib/python3.9/site-packages/torch/jit/_recursive.py:559: in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
envs/coremltools-dev-py3.9/lib/python3.9/site-packages/torch/jit/_recursive.py:636: in create_script_module_impl
    create_methods_and_properties_from_stubs(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

concrete_type = <torch.ConcreteModuleType object at 0x7f65d8224970>
method_stubs = [ScriptMethodStub(resolution_callback=<function createResolutionCallbackFromEnv.<locals>.<lambda> at 0x7f65d88a2b80>, ...l_method=<bound method TestFunctionalConv.test_convolution1d.<locals>.FunctionalConv1D.forward of FunctionalConv1D()>)]
property_stubs = []

    def create_methods_and_properties_from_stubs(
        concrete_type, method_stubs, property_stubs
    ):
        method_defs = [m.def_ for m in method_stubs]
        method_rcbs = [m.resolution_callback for m in method_stubs]
        method_defaults = [get_default_args(m.original_method) for m in method_stubs]

        property_defs = [p.def_ for p in property_stubs]
        property_rcbs = [p.resolution_callback for p in property_stubs]

>       concrete_type._create_methods_and_properties(
            property_defs, property_rcbs, method_defs, method_rcbs, method_defaults
        )
E       RuntimeError:
E       python value of type 'int' cannot be used as a value. Perhaps it is a closed over global variable? If so, please consider passing it in as an argument or use a local varible instead.:
E         File "/root/coremltools/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py", line 1697
E                   def forward(self, input_data, weights):
E                       return nn.functional.conv1d(
E                           input_data, weights, stride=stride, padding=padding, groups=groups
E                                                       ~~~~~~ <--- HERE
E                       )

envs/coremltools-dev-py3.9/lib/python3.9/site-packages/torch/jit/_recursive.py:469: RuntimeError
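
The TorchScript error above is about closed-over Python values. A minimal sketch of the usual workaround is to pass the values to the constructor and store them on the module, as below. The constructor signature and the concrete sizes are assumptions for illustration. This only shows that scripting itself succeeds; it does not address the separate "Torch var stride not found in context" error raised later by the converter.

```python
import torch
from torch import nn


class FunctionalConv1D(nn.Module):
    def __init__(self, stride: int, padding: int, groups: int):
        super().__init__()
        # Plain Python ints stored on the module become typed attributes
        # that TorchScript can read, unlike closed-over local variables.
        self.stride = stride
        self.padding = padding
        self.groups = groups

    def forward(self, input_data: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
        return nn.functional.conv1d(
            input_data,
            weights,
            stride=self.stride,
            padding=self.padding,
            groups=self.groups,
        )


model = FunctionalConv1D(stride=1, padding=1, groups=1).eval()
scripted = torch.jit.script(model)

input_data = torch.rand(1, 3, 7)
weights = torch.rand(3, 3, 3)
print(scripted(input_data, weights).shape)
```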

Collaborator

These look like errors from trying to convert PyTorch models that haven't been traced. Do you only get these errors when use_scripting=True? That setting causes the model not to be traced.

Contributor Author

Yes, and that is exactly what I want: to add tests for scripting.

@alealv
Contributor Author

alealv commented Dec 21, 2023

I'm trying to figure out how nn.Conv1d(...) gets converted with JIT.

The test fails with:

E       ValueError: Torch var bias not found in context

I only hit the problem when bias is False, and I don't fully understand what I should do.

Here is the output graph:

graph(%self : __torch__.torch.nn.modules.conv.Conv1d,
      %input.1 : Tensor):
  %weight : Tensor = prim::GetAttr[name="weight"](%self)
  %bias : Tensor? = prim::GetAttr[name="bias"](%self)
  %4 : int = prim::Constant[value=1]() # /root/coremltools/envs/coremltools-dev-py3.9/lib/python3.9/site-packages/torch/nn/modules/conv.py:306:45
  %5 : int = prim::Constant[value=0]() # /root/coremltools/envs/coremltools-dev-py3.9/lib/python3.9/site-packages/torch/nn/modules/conv.py:307:24
  %6 : int = prim::Constant[value=3]() # /root/coremltools/envs/coremltools-dev-py3.9/lib/python3.9/site-packages/torch/nn/modules/conv.py:307:38
  %7 : int[] = prim::ListConstruct(%4)
  %8 : int[] = prim::ListConstruct(%5)
  %9 : int[] = prim::ListConstruct(%6)
  %10 : Tensor = aten::conv1d(%input.1, %weight, %bias, %7, %8, %9, %4) # /root/coremltools/envs/coremltools-dev-py3.9/lib/python3.9/site-packages/torch/nn/modules/conv.py:306:15
  return (%10)

So my understanding is that the module has only one parameter, the weights, which makes sense given that bias is False.
aten::conv1d still takes a bias argument, but the graph types it as optional (Tensor?), so PyTorch runs fine when it isn't provided.

How does coremltools handle this?

I see that:

Node: %bias : Tensor? = prim::GetAttr[name="bias"](%self)

Type: <class 'torch.Node'>
Is tensor: False
Is quantize tensor: False
prefix: bias
Module: None

And the lowering code only acts on a node if it is a tensor:

    def _lower_graph_block(graph):
        for node in list(graph.nodes()):
            ...
            is_tensor = _check_is_tensor(node, module)
            is_quantized_tensor = _check_is_quantized_tensor(node, module)

            if is_tensor or is_quantized_tensor:
                ...

where

    def _check_is_tensor(node, module):
        if not isinstance(module, torch.Tensor):
            return False
        if str(node.output().type()) not in ("Tensor", "Optional[Tensor]"):
            raise TypeError(f'Type "{node.output().type()}" not supported')
        return True

Can anyone help me to understand this?
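
One way to see why the check reports "Is tensor: False" for the bias node is to run the same logic on a module created with bias=False. In this sketch `check_is_tensor` repeats the logic of the `_check_is_tensor` snippet quoted above, and passing `None` as the node is only for illustration.

```python
import torch
from torch import nn


def check_is_tensor(node, module):
    # Same logic as the _check_is_tensor snippet quoted above.
    if not isinstance(module, torch.Tensor):
        return False
    if str(node.output().type()) not in ("Tensor", "Optional[Tensor]"):
        raise TypeError(f'Type "{node.output().type()}" not supported')
    return True


conv = nn.Conv1d(3, 3, kernel_size=3, bias=False)

# With bias=False the module still has a `bias` attribute, but it is None.
print(conv.bias is None)

# None is not a torch.Tensor, so the check returns False before it ever
# looks at the node.
print(check_is_tensor(None, conv.bias))
```

Under that reading, the `bias` attribute is never added to the context as a tensor, which would explain the later "Torch var bias not found in context" lookup failure.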

@alealv changed the title from "Add support for Torch ops conv_transpose and add conv aliases" to "Add support for Torch conv aliases" on Jan 9, 2024
Comment on lines +220 to +223
+    try:
+        results.append(context[i])
+    except ValueError:
+        results.append(None)
Contributor Author

This was the trick that lets conversion continue when bias is None, since bias is an optional parameter.
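
The effect of this change can be sketched in isolation. The `Context` class and the `get_bindings` signature below are simplified, hypothetical stand-ins for the converter internals; only the try/except pattern and the ValueError message are taken from this PR.

```python
from typing import Dict, List, Optional


class Context:
    """Hypothetical stand-in for the converter context (name -> value)."""

    def __init__(self, values: Dict[str, object]):
        self._values = values

    def __getitem__(self, torch_name: str) -> object:
        if torch_name in self._values:
            return self._values[torch_name]
        raise ValueError(f"Torch var {torch_name} not found in context")


def get_bindings(context: Context, names: List[str]) -> List[Optional[object]]:
    """Resolve input names, using None for names the context does not know."""
    results: List[Optional[object]] = []
    for name in names:
        try:
            results.append(context[name])
        except ValueError:
            # e.g. an optional `bias` that was never materialized
            results.append(None)
    return results


context = Context({"input.1": "x", "weight": "w"})
bindings = get_bindings(context, ["input.1", "weight", "bias"])
print(bindings)  # ['x', 'w', None]
```

One trade-off of this approach is that a genuinely missing variable is also turned into None, so a naming bug elsewhere could surface later and further from its cause.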

@@ -962,7 +965,8 @@ def linear(context, node):
context.add(res, torch_name=node.name)


@register_torch_op(torch_alias=["conv2d", "convolution"])
# NOTE: This function is also an alias of: ["conv_transpose1d", "conv_transpose2d", "conv_transpose3d"] but we lack tests for those
Contributor Author

Unfortunately, I failed to add tests for conv_transposeXd because it asks for output_size as an input in the computational graph, and I don't know how to solve that. The function should support the operation, though.
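
The extra graph input can be inspected with a sketch like the following, which scripts a plain `nn.ConvTranspose1d` (sizes chosen arbitrarily) and lists the inputs of the resulting graph. It assumes the module's `forward` accepts an optional `output_size` argument, as it does in current PyTorch releases; the exact debug names may differ between versions.

```python
import torch
from torch import nn

deconv = nn.ConvTranspose1d(in_channels=3, out_channels=4, kernel_size=3).eval()
scripted = torch.jit.script(deconv)

# Scripting compiles the full forward signature, so optional arguments
# show up as graph inputs next to the module itself and the data tensor.
input_names = [value.debugName() for value in scripted.graph.inputs()]
print(input_names)
```

A traced module would not carry this input, because tracing only records the arguments that were actually passed in the example call.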

Successfully merging this pull request may close these issues.

Add support for torch op conv1d
4 participants