
TorchScript compatibility? #13

Open
Linux-cpp-lisp opened this issue Apr 9, 2021 · 7 comments

@Linux-cpp-lisp

Hi all,

This library looks very nice :)

Is TensorType compatible with the TorchScript compiler? As in, are the annotations transparently converted to torch.Tensor as far as torch.jit.script is concerned, allowing annotated modules/functions to be compiled? (I'm not worried about whether the type checking is applied in TorchScript, just whether an annotated program that gets shape-checked in Python can be compiled down to TorchScript.)
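
Concretely, a toy sketch of the behaviour I'm hoping for (hypothetical f; I haven't checked whether this compiles):

import torch
from torchtyping import TensorType

def f(x: TensorType["batch"]) -> torch.Tensor:
    return x * 2

# The hope is that torch.jit.script sees the annotation as plain
# torch.Tensor, so this would compile:
# scripted = torch.jit.script(f)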

Thanks!

@patrick-kidger
Owner

patrick-kidger commented Apr 9, 2021

So I've been playing with this for a bit and unfortunately can't get it to work.

If you or someone else does manage to get this working, then I'd be happy to accept a PR on it.

For posterity:

  • TensorType does not currently inherit from torch.Tensor. This means that applying @torch.jit.script to def f(x: TensorType) results in TorchScript trying to compile TensorType itself, which fails.
  • This happens even when adding @torch.jit.ignore in various places. I think ignoring things only really works for free functions or for methods of subclasses of torch.nn.Module.
  • Changing TensorType to inherit from torch.Tensor allows @torch.jit.script def f(x: TensorType), but @torch.jit.script def f(x: TensorType["b"]) still breaks, with the error message Unknown type constructor TensorType from the TorchScript compiler.
  • Nothing I tried managed to fix that, and indeed a little googling suggests that it might be impossible, as type constructors are apparently parsed as strings: "Unknown type constructor" error in TorchScript pytorch/pytorch#29094. (And indeed I tried sneaky things like inheriting class TensorType(typing.List), without success.) My impression is that the only parameterised types admitted as annotations are the standard built-in ones like List; see the sketch after this list.
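
A minimal sketch of the contrast (my summary, not torchtyping code): TorchScript accepts the standard built-in generics as annotations but rejects custom type constructors at source-parse time.

import torch
from typing import List

# Built-in parameterised generics are accepted by the TorchScript frontend:
@torch.jit.script
def g(xs: List[int]) -> int:
    return len(xs)

# A custom parameterised annotation is rejected when the source is parsed
# (commented out because it raises at decoration time):
#
# from torchtyping import TensorType
#
# @torch.jit.script
# def f(x: TensorType["b"]):  # "Unknown type constructor TensorType"
#     return x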

@Linux-cpp-lisp
Author

Linux-cpp-lisp commented Apr 12, 2021

Hi @patrick-kidger, thanks for the quick answer! This level of arcane tinkering with TorchScript definitely sounds familiar to me... 😁

The issue you link in the fourth bullet does make it look like nothing can be done here until PyTorch resolves the underlying incompatibility with Python. (If I'm understanding this right, you couldn't even use Annotated[torch.Tensor, something_else], since it wouldn't be parsable as a string, even though the Python people worked hard to make Annotated backwards compatible; a sketch below.) Hopefully the PyTorch people will start using Python inspection for this, as they said in the linked issue.
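
E.g. (a hypothetical sketch, not something I've run):

import torch
from typing import Annotated  # typing_extensions.Annotated on Python < 3.9

def f(x: Annotated[torch.Tensor, "batch"]) -> torch.Tensor:
    return x.sum()

# PEP 593 tooling is supposed to treat Annotated[T, ...] as just T, but
# because TorchScript re-parses the annotation text itself, this would
# presumably hit the same "Unknown type constructor" path:
# torch.jit.script(f)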

EDIT: it looks like fixes to this may have been merged? unclear: pytorch/pytorch#29623

@patrick-kidger
Owner

Haha!

To answer the question: I agree it seems unclear whether or not that issue is fixed. Either way, because of that or some other issue, our end use case doesn't seem to be working at the moment.

@kharitonov-ivan

Hi! Are there any updates on this, guys?

@patrick-kidger
Owner

patrick-kidger commented Nov 9, 2021

Not that I know of. As far as I know this is still a limitation in TorchScript itself.

If this is a priority for you then you might like to bring it up with the TorchScript team. They might know more about any possibilities for making this work.

@martenlienen

I have found a workaround. Let's say you have the following function

from torchtyping import TensorType

def f(x: TensorType["batch", "feature"]):
    return x.sum()

which you want to use in TorchScript. TorchScript does not like generic types in signatures, but we want to keep the dimension annotations somewhere for documentation purposes. We can work around this with a subclass:

import torch
from torchtyping import TensorType

# The subclass bakes the dimension names into the class itself, so the
# signature below only mentions a plain (unparameterised) name, which
# TorchScript then treats as torch.Tensor (see the printed code below).
class BatchedFeatureTensor(TensorType["batch", "feature"]):
    pass

@torch.jit.script
def f(x: BatchedFeatureTensor):
    return x.sum()

print(f(torch.tensor([[-1.0, 2.0, 1.2]])))
print(f.code)

# => tensor(2.2000)
# => def f(x: Tensor) -> Tensor:
# =>   return torch.sum(x)

@Datasciensyash

Found another way to deal with TorchScript: paste the code below and call patch_torchscript() before exporting.

import re
import typing as tp

import torch

# Matches e.g. TensorType["batch", "feature"] (no nested brackets)
ttp_regexp = re.compile(r"TensorType\[[^\]]*\]")
torchtyping_replacer = "torch.Tensor"


def _replace_torchtyping(source_lines: tp.List[str]) -> tp.List[str]:

    # Join all lines (each already ends with its newline)
    cat_lines = "".join(source_lines)

    # Quick exit if torchtyping is not used
    if ttp_regexp.search(cat_lines) is None:
        return source_lines

    # Replace every TensorType[...] annotation with torch.Tensor
    cat_lines = ttp_regexp.sub(torchtyping_replacer, cat_lines)

    # Split back into lines, keeping the newlines so that line numbers in
    # TorchScript error messages stay correct
    return cat_lines.splitlines(keepends=True)


def _torchtyping_destruct_wrapper(func: tp.Callable) -> tp.Callable:
    # Wraps torch's get_source_lines_and_file, rewriting the fetched
    # source before the TorchScript frontend parses it.
    def _wrap_func(obj: tp.Any, error_msg: tp.Optional[str] = None) -> tp.Tuple[tp.List[str], int, tp.Optional[str]]:
        srclines, file_lineno, filename = func(obj, error_msg)
        srclines = _replace_torchtyping(srclines)
        return srclines, file_lineno, filename

    return _wrap_func


def patch_torchscript() -> None:
    """
    Patch TorchScript to work with torchtyping.

    Monkey-patches get_source_lines_and_file so that TensorType[...]
    annotations are rewritten to torch.Tensor before compilation.
    """
    # The function lives in torch._sources on torch >= 1.10.0,
    # and in torch.jit.frontend on older versions
    if hasattr(torch, "_sources"):
        src = getattr(torch, "_sources")  # noqa: B009
    else:
        src = getattr(torch.jit, "frontend")  # noqa: B009

    src.get_source_lines_and_file = _torchtyping_destruct_wrapper(src.get_source_lines_and_file)
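
For example (hypothetical f, assuming the snippet above has already been run):

import torch
from torchtyping import TensorType

def f(x: TensorType["batch", "feature"]) -> torch.Tensor:
    return x.sum()

patch_torchscript()              # must be called before scripting
scripted = torch.jit.script(f)   # annotation is rewritten to torch.Tensor
print(scripted.code)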
