
TorchScript error in JitTraceEnum_ELBO (torch 2.2.1, CUDA 12.3) #3338

Open
mtvector opened this issue Mar 10, 2024 · 1 comment

mtvector commented Mar 10, 2024

Hi there,
I've noticed a bug in JitTraceEnum_ELBO. My code runs fine with a previous version of PyTorch, or with JitTrace_ELBO (by using RelaxedOneHotCategorical instead of OneHotCategorical for the site I was enumerating). I don't personally need this fixed right now, and the root cause is out of my depth, but I figured I'd report it in case someone else hits the same problem.
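
For context, here is a rough, untested sketch of the kind of setup involved (the names, shapes, and toy likelihood are placeholders, not my actual model): an enumerated OneHotCategorical site optimized with SVI and JitTraceEnum_ELBO on GPU.

```python
# Illustrative sketch only, not a verified reproducer of this exact error.
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, JitTraceEnum_ELBO, config_enumerate
from pyro.optim import Adam

data = torch.randn(100, 5).cuda()

@config_enumerate  # enumerate the discrete site in parallel
def model(data):
    probs = pyro.param("probs", torch.ones(3).cuda() / 3.0,
                       constraint=dist.constraints.simplex)
    locs = pyro.param("locs", torch.zeros(3, 5).cuda())
    with pyro.plate("data", data.shape[0]):
        z = pyro.sample("z", dist.OneHotCategorical(probs))  # enumerated discrete site
        pyro.sample("obs", dist.Normal(z @ locs, 1.0).to_event(1), obs=data)

def guide(data):
    pass  # the discrete site is enumerated out; nothing left to sample here

svi = SVI(model, guide, Adam({"lr": 1e-2}),
          loss=JitTraceEnum_ELBO(max_plate_nesting=1))
svi.step(data)  # on torch==2.2.1 + CUDA, this is where the nvrtc error surfaces for me
```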

The error appears to come from a TorchScript issue when computing the enumerated ELBO inside pyro.infer.SVI:

    315 def step(self, *args, **kwargs):
    316     # Compute loss and gradients
    317     with poutine.trace(param_only=True) as param_capture:
--> 318         loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
    320     loss_val = torch_item(loss)
    321     self.losses.append(loss_val)

File /allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/Team/Matthew/utils/miniconda3/envs/pyro2/lib/python3.11/site-packages/pyro/infer/traceenum_elbo.py:564, in JitTraceEnum_ELBO.loss_and_grads(self, model, guide, *args, **kwargs)
    563 def loss_and_grads(self, model, guide, *args, **kwargs):
--> 564     differentiable_loss = self.differentiable_loss(model, guide, *args, **kwargs)
    565     differentiable_loss.backward()  # this line triggers jit compilation
    566     loss = differentiable_loss.item()

File /allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/Team/Matthew/utils/miniconda3/envs/pyro2/lib/python3.11/site-packages/pyro/infer/traceenum_elbo.py:561, in JitTraceEnum_ELBO.differentiable_loss(self, model, guide, *args, **kwargs)
    557         return elbo * (-1.0 / self.num_particles)
    559     self._differentiable_loss = differentiable_loss
--> 561 return self._differentiable_loss(*args, **kwargs)

File /allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/Team/Matthew/utils/miniconda3/envs/pyro2/lib/python3.11/site-packages/pyro/ops/jit.py:120, in CompiledFunction.__call__(self, *args, **kwargs)
    118 with poutine.block(hide=self._param_names):
    119     with poutine.trace(param_only=True) as param_capture:
--> 120         ret = self.compiled[key](*params_and_args)
    122 for name in param_capture.trace.nodes.keys():
    123     if name not in self._param_names:

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: default_program(23): error: extra text after expected end of number
      aten_exp[(long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)] = expf(v - (tshift_1_1<-3.402823466385289e+38.f ? -3.402823466385289e+38.f : tshift_1_1));
                                                                                                                        ^

default_program(23): error: extra text after expected end of number
      aten_exp[(long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)] = expf(v - (tshift_1_1<-3.402823466385289e+38.f ? -3.402823466385289e+38.f : tshift_1_1));
                                                                                                                                                   ^

2 errors detected in the compilation of "default_program".

nvrtc compilation failed: 

#define NAN __int_as_float(0x7fffffff)
#define POS_INFINITY __int_as_float(0x7f800000)
#define NEG_INFINITY __int_as_float(0xff800000)


template<typename T>
__device__ T maximum(T a, T b) {
  return isnan(a) ? a : (a > b ? a : b);
}

template<typename T>
__device__ T minimum(T a, T b) {
  return isnan(a) ? a : (a < b ? a : b);
}

extern "C" __global__
void fused_clamp_sub_exp(float* tt_3, float* tshift_1, float* aten_exp) {
{
if ((long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)<45150ll ? 1 : 0) {
    float tshift_1_1 = __ldg(tshift_1 + (long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x));
    float v = __ldg(tt_3 + (long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x));
    aten_exp[(long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)] = expf(v - (tshift_1_1<-3.402823466385289e+38.f ? -3.402823466385289e+38.f : tshift_1_1));
  }}
}

My environment is as follows:

absl-py==2.1.0
aiohttp==3.9.1
aiosignal==1.3.1
anndata==0.10.4
annotated-types==0.6.0
anyio==4.2.0
array_api_compat==1.4.1
arrow==1.3.0
asttokens @ file:///opt/conda/conda-bld/asttokens_1646925590279/work
attrs==23.2.0
backoff==2.2.1
beautifulsoup4==4.12.3
blessed==1.20.0
boto3==1.34.28
botocore==1.34.28
certifi==2023.11.17
charset-normalizer==3.3.2
chex==0.1.7
click==8.1.7
comm @ file:///work/ci_py311/comm_1677709131612/work
contextlib2==21.6.0
contourpy==1.2.0
croniter==1.4.1
cycler==0.12.1
dateutils==0.6.12
debugpy @ file:///croot/debugpy_1690905042057/work
decorator @ file:///opt/conda/conda-bld/decorator_1643638310831/work
deepdiff==6.7.1
dm-tree==0.1.8
docrep==0.3.2
editor==1.6.6
etils==1.6.0
executing @ file:///opt/conda/conda-bld/executing_1646925071911/work
fastapi==0.109.0
filelock @ file:///croot/filelock_1700591183607/work
flax==0.8.0
fonttools==4.47.2
frozenlist==1.4.1
fsspec==2023.12.2
gmpy2 @ file:///work/ci_py311/gmpy2_1676839849213/work
h11==0.14.0
h5py==3.10.0
idna==3.6
igraph==0.11.3
importlib-resources==6.1.1
inquirer==3.2.1
ipykernel @ file:///croot/ipykernel_1705933831282/work
ipython @ file:///croot/ipython_1704833016303/work
itsdangerous==2.1.2
jax==0.4.23
jaxlib==0.4.23
jedi @ file:///work/ci_py311_2/jedi_1679336495545/work
Jinja2 @ file:///work/ci_py311/jinja2_1676823587943/work
jmespath==1.0.1
joblib==1.3.2
jupyter_client @ file:///croot/jupyter_client_1699455897726/work
jupyter_core @ file:///croot/jupyter_core_1698937308754/work
kiwisolver==1.4.5
leidenalg==0.10.2
lightning==2.0.9.post0
lightning-cloud==0.5.61
lightning-utilities==0.10.1
llvmlite==0.41.1
markdown-it-py==3.0.0
MarkupSafe @ file:///croot/markupsafe_1704205993651/work
matplotlib==3.8.2
matplotlib-inline @ file:///work/ci_py311/matplotlib-inline_1676823841154/work
mdurl==0.1.2
mkl-fft @ file:///croot/mkl_fft_1695058164594/work
mkl-random @ file:///croot/mkl_random_1695059800811/work
mkl-service==2.4.0
ml-collections==0.1.1
ml-dtypes @ file:///croot/ml_dtypes_1702691022032/work
mpmath @ file:///croot/mpmath_1690848262763/work
msgpack==1.0.7
mudata==0.2.3
multidict==6.0.4
multipledispatch==1.0.0
natsort==8.4.0
nest-asyncio @ file:///work/ci_py311/nest-asyncio_1676823382924/work
networkx==3.2.1
numba==0.58.1
numpy==1.26.1
numpyro==0.13.2
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.19.3
nvidia-nvjitlink-cu12==12.4.99
nvidia-nvtx-cu12==12.1.105
opt-einsum @ file:///home/conda/feedstock_root/build_artifacts/opt_einsum_1696448916724/work
optax==0.1.8
orbax-checkpoint==0.5.1
ordered-set==4.1.0
packaging @ file:///croot/packaging_1693575174725/work
pandas==2.2.0
parso @ file:///opt/conda/conda-bld/parso_1641458642106/work
patsy==0.5.6
pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work
pillow==10.2.0
platformdirs @ file:///croot/platformdirs_1692205439124/work
prompt-toolkit @ file:///croot/prompt-toolkit_1704404351921/work
protobuf==4.25.2
psutil @ file:///work/ci_py311_2/psutil_1679337388738/work
ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
pure-eval @ file:///opt/conda/conda-bld/pure_eval_1646925070566/work
pydantic==2.1.1
pydantic_core==2.4.0
Pygments @ file:///croot/pygments_1684279966437/work
PyJWT==2.8.0
pymde==0.1.18
pynndescent==0.5.11
pyparsing==3.1.1
pyro-api==0.1.2
pyro-ppl==1.8.6
python-dateutil @ file:///tmp/build/80754af9/python-dateutil_1626374649649/work
python-multipart==0.0.6
pytorch-lightning==2.1.3
pytz==2023.3.post1
PyYAML @ file:///croot/pyyaml_1698096049011/work
pyzmq @ file:///croot/pyzmq_1705605076900/work
readchar==4.0.5
requests==2.31.0
rich==13.7.0
runs==1.2.2
s3transfer==0.10.0
scanpy==1.9.6
scikit-learn==1.3.2
scipy==1.11.4
scvi-tools==1.0.4
seaborn==0.13.1
session-info==1.0.0
six @ file:///tmp/build/80754af9/six_1644875935023/work
sniffio==1.3.0
soupsieve==2.5
sparse==0.15.1
stack-data @ file:///opt/conda/conda-bld/stack_data_1646927590127/work
starlette==0.35.1
starsessions==1.3.0
statsmodels==0.14.1
stdlib-list==0.10.0
sympy @ file:///croot/sympy_1701397643339/work
tensorstore==0.1.52
texttable==1.7.0
threadpoolctl==3.2.0
toolz==0.12.1
torch==2.2.1
torchmetrics==1.3.0.post0
torchvision==0.17.1
tornado @ file:///croot/tornado_1696936946304/work
tqdm==4.66.1
traitlets @ file:///work/ci_py311/traitlets_1676823305040/work
triton==2.2.0
types-python-dateutil==2.8.19.20240106
typing_extensions @ file:///croot/typing_extensions_1705599297034/work
tzdata==2023.4
umap-learn==0.5.5
urllib3==2.0.7
uvicorn==0.27.0
wcwidth @ file:///Users/ktietz/demo/mc3/conda-bld/wcwidth_1629357192024/work
websocket-client==1.7.0
websockets==12.0
xarray==2024.1.1
xgboost==2.0.1
xmod==1.8.1
yarl==1.9.4
zipp==3.17.0

Thanks for all the development work, pyro rules!

fritzo (Member) commented Mar 17, 2024

Thanks for the bug report. My guess is that this is an upstream bug in PyTorch's code generation, which is emitting a floating-point constant with two decimal points (`-3.402823466385289e+38.f`). I'm not sure there is anything we can do on the Pyro side other than wait for an upstream fix.
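
If it helps in the meantime, one possible (untested) mitigation might be to disable TorchScript's GPU fusion so this kernel is never generated. Note that both toggles below are internal PyTorch APIs and may change between releases.

```python
# Untested workaround sketch: turn off TorchScript's GPU fusion so the
# failing "fused_clamp_sub_exp" nvrtc kernel is not generated.
# These are internal PyTorch toggles, not public API.
import torch

torch._C._jit_set_texpr_fuser_enabled(False)   # disable the NNC/TensorExpr fuser
torch._C._jit_override_can_fuse_on_gpu(False)  # disable the legacy GPU fuser

# ...then construct SVI with JitTraceEnum_ELBO and call step() as before.
```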
