
Generate name for each member of list arg #571

Draft · wants to merge 2 commits into master
Conversation

@xuzijian629 (Member) commented Jul 12, 2022

Problem

ppe.onnx.export_testcase fails with list inputs because the number of entries in input_names is wrong: the exported ONNX graph has one input per tensor (the total tensor count after unrolling), not len(args).

Reproduction

import torch
import pytorch_pfn_extras.onnx as tou
import torch.onnx.symbolic_helper
import torch.onnx.symbolic_registry

@torch.onnx.symbolic_helper.parse_args("v", "b", "f")
def pad_sequence(g, input, batch_first, padding_value):
    assert batch_first, "batch_first=False is not supported"
    ret = g.op("org.chainer::ChainerSequencePad", input, value_f=padding_value)
    return ret

for opset in range(9, 16):
    torch.onnx.symbolic_registry.register_op("pad_sequence", pad_sequence, "", opset)

class Model(torch.nn.Module):
    def forward(self, xs):
        return torch._C._nn.pad_sequence(xs, True, 0)


model = Model()
args = ([torch.rand(2, 5), torch.rand(3, 5)],)
torch.onnx.export(model, args, "test.onnx")  # Success!!
tou.export_testcase(model, args, "test")  # Fails!!

Error:

Traceback (most recent call last):
  File "hoge.py", line 36, in <module>
    tou.export_testcase(model, args, "test", use_pfto=False)
  File "/mnt/vol21/joe/pfvm/third_party/pytorch-pfn-extras/pytorch_pfn_extras/onnx/export_testcase.py", line 303, in export_testcase
    used_input_index_list.append(input_names.index(used_input.name))
ValueError: 'onnx::SequenceConstruct_1' is not in list

torch version: 1.12.0

This PR

Unrolls nested list args and generates a name for each tensor.
For example, if the input is args=([[a, b], c], d), the generated input names are

[input_0_0_0, input_0_0_1, input_0_1, input_1]

whereas the current master generates

[input_0, input_1]
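
For reference, a minimal sketch of recursive name generation that produces exactly this scheme (unroll_input_names is a hypothetical helper for illustration, not the code in this PR):

from typing import Any, List

def unroll_input_names(prefix: str, arg: Any, names: List[str]) -> None:
    # Recurse into nested lists/tuples, appending each index to the name;
    # a non-sequence leaf contributes a single name.
    if isinstance(arg, (list, tuple)):
        for i, item in enumerate(arg):
            unroll_input_names(f"{prefix}_{i}", item, names)
    else:
        names.append(prefix)

names: List[str] = []
for i, arg in enumerate(([["a", "b"], "c"], "d")):
    unroll_input_names(f"input_{i}", arg, names)
print(names)  # ['input_0_0_0', 'input_0_0_1', 'input_0_1', 'input_1']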

@xuzijian629 (Member Author) commented:

/test

@asi1024 asi1024 self-assigned this Jul 13, 2022
@asi1024 asi1024 added the cat:bug Something isn't working label Jul 13, 2022
@asi1024 asi1024 added this to the v0.6.0 milestone Jul 13, 2022
@xuzijian629 (Member Author) commented:

todo: add test

gen_input_names = []
unrolled_args = []

def append_input_name(prefix: str, arg: Any) -> None:
Review comment (Member):

How about also supporting dict and namedtuple types, like the _normalize_outputs function in _logic.py?

from typing import Any, Dict

def _normalize_outputs(outputs: Any) -> Dict[str, Any]:
    target: Dict[str, Any]
    if isinstance(outputs, tuple) and hasattr(outputs, '_fields'):
        # namedtuple
        target = outputs._asdict()  # type: ignore[attr-defined]
    elif isinstance(outputs, dict):
        target = outputs
    elif isinstance(outputs, (list, tuple)):
        target = {str(i): out for i, out in enumerate(outputs)}
    else:
        target = {"0": outputs}
    return target
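
A sketch of how append_input_name could follow this suggestion, mirroring the branches of _normalize_outputs (the dict/namedtuple handling and the explicit names parameter are assumptions for illustration, not code from this PR):

from typing import Any, List

def append_input_name(prefix: str, arg: Any, names: List[str]) -> None:
    if isinstance(arg, tuple) and hasattr(arg, "_fields"):
        # namedtuple: use the field names as name components
        for field, value in arg._asdict().items():
            append_input_name(f"{prefix}_{field}", value, names)
    elif isinstance(arg, dict):
        for key, value in arg.items():
            append_input_name(f"{prefix}_{key}", value, names)
    elif isinstance(arg, (list, tuple)):
        for i, value in enumerate(arg):
            append_input_name(f"{prefix}_{i}", value, names)
    else:
        names.append(prefix)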

@xuzijian629 xuzijian629 marked this pull request as draft July 14, 2022 07:37
@xuzijian629
Copy link
Member Author

Update: I summarized the problems with extending the current ppe.onnx.export to models with list inputs in #572.

In short, torch.onnx.export automatically unrolls list inputs in its internal trace API (torch.jit._get_trace_graph), but torch's public trace API (torch.jit.trace) does not unroll them. So perhaps we should create a wrapper class for export that unrolls list inputs. What do you think?
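
As a rough illustration of that idea (ListUnrollWrapper and its fixed list_len are hypothetical, not part of this PR): a wrapper that accepts unrolled tensors and repacks them into the list the wrapped module expects, so the public trace API only ever sees tensors.

import torch

class ListUnrollWrapper(torch.nn.Module):
    # Repacks unrolled tensor arguments into the list that the wrapped
    # module expects, so torch.jit.trace sees only tensor inputs.
    def __init__(self, module: torch.nn.Module, list_len: int) -> None:
        super().__init__()
        self.module = module
        self.list_len = list_len

    def forward(self, *tensors: torch.Tensor) -> torch.Tensor:
        xs = list(tensors[: self.list_len])
        return self.module(xs)

# e.g. for the reproduction above:
# traced = torch.jit.trace(ListUnrollWrapper(model, 2), tuple(args[0]))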

@xuzijian629 (Member Author) commented:

With 6c73264, ONNX export with use_pfto=False works well (ppe.onnx.export then delegates to torch.onnx.export, which unrolls list inputs automatically, so it makes sense to unroll the inputs when generating names).

Note that with use_pfto=True, the exported ONNX currently has a single list input (sequence-typed in ONNX), so we MUST NOT unroll the input args.

@xuzijian629 (Member Author) commented:

My idea: always unroll args to follow the torch.onnx.export style. For PFTO, wrap the module to handle list inputs (future work).

@kmaehashi kmaehashi removed this from the v0.6.0 milestone Jul 19, 2022
@xuzijian629 (Member Author) commented:

Memo: also support list and tuple outputs.
