flax: Trying to access a property that is accessing a non-existent attribute. - JaxBaseModuleClass #2782

Closed
Zethson opened this issue May 8, 2024 · 2 comments

Zethson (Member) commented May 8, 2024

self = <flax.linen.module.create_descriptor_wrapper.<locals>._DescriptorWrapper object at 0x7f896ff65410>
args = (JaxSCGENVAE(
    # attributes
    n_input = 100
    n_hidden = 800
    n_latent = 100
    n_layers = 2
    dropout_ra...e_layer_norm = 'none'
    kl_weight = 5e-05
    training = True
), <class 'pertpy.tools._scgen._scgenvae.JaxSCGENVAE'>)
kwargs = {}

    def __get__(self, *args, **kwargs):
      # here we will catch internal AttributeError and re-raise it as a
      # more informative and correct error message.
      try:
>       return self.wrapped.__get__(*args, **kwargs)

../../miniconda3/envs/pertpy/lib/python3.11/site-packages/flax/linen/module.py:927: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = JaxSCGENVAE(
    # attributes
    n_input = 100
    n_hidden = 800
    n_latent = 100
    n_layers = 2
    dropout_rat...ribution = 'normal'
    use_batch_norm = 'both'
    use_layer_norm = 'none'
    kl_weight = 5e-05
    training = True
)

    @property
    def device(self):
>       return self.seed_rng.device()
E       AttributeError: 'jaxlib.xla_extension.ArrayImpl' object has no attribute 'device'

../../miniconda3/envs/pertpy/lib/python3.11/site-packages/scvi/module/base/_base_module.py:584: AttributeError

The above exception was the direct cause of the following exception:

    def test_scgen():
        adata = scvi.data.synthetic_iid()
        pt.tl.SCGEN.setup_anndata(
            adata,
            batch_key="batch",
            labels_key="labels",
        )
    
        scg = pt.tl.SCGEN(adata)
>       scg.train(max_epochs=1, batch_size=32, early_stopping=True, early_stopping_patience=25)

tests/tools/test_scgen.py:16: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../miniconda3/envs/pertpy/lib/python3.11/site-packages/scvi/model/base/_jaxmixin.py:74: in train
    self.module.to(device)
../../miniconda3/envs/pertpy/lib/python3.11/site-packages/flax/linen/module.py:701: in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
../../miniconda3/envs/pertpy/lib/python3.11/site-packages/flax/linen/module.py:1233: in _call_wrapped_method
    y = run_fun(self, *args, **kwargs)
../../miniconda3/envs/pertpy/lib/python3.11/site-packages/scvi/module/base/_base_module.py:647: in to
    if device is not self.device:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <flax.linen.module.create_descriptor_wrapper.<locals>._DescriptorWrapper object at 0x7f896ff65410>
args = (JaxSCGENVAE(
    # attributes
    n_input = 100
    n_hidden = 800
    n_latent = 100
    n_layers = 2
    dropout_ra...e_layer_norm = 'none'
    kl_weight = 5e-05
    training = True
), <class 'pertpy.tools._scgen._scgenvae.JaxSCGENVAE'>)
kwargs = {}

    def __get__(self, *args, **kwargs):
      # here we will catch internal AttributeError and re-raise it as a
      # more informative and correct error message.
      try:
        return self.wrapped.__get__(*args, **kwargs)
      except AttributeError as e:
>       raise errors.DescriptorAttributeError() from e
E       flax.errors.DescriptorAttributeError: Trying to access a property that is accessing a non-existent attribute. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.DescriptorAttributeError)
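
Why the final error reads the way it does: in Python, an `AttributeError` raised *inside* a property getter looks the same as a genuinely missing attribute. If the class also defines `__getattr__`, Python silently falls back to it and produces a misleading message. The `_DescriptorWrapper` frames above guard against this by catching the internal `AttributeError` and re-raising a dedicated error chained with `from e`, so the original `'ArrayImpl' object has no attribute 'device'` survives as the cause. The sketch below is a minimal, pure-Python illustration of that pattern. `DescriptorWrapper`, `DescriptorAttributeError` and `Module` are hypothetical stand-ins, not flax's actual classes.

```python
class DescriptorAttributeError(Exception):
    """Hypothetical stand-in for flax.errors.DescriptorAttributeError.

    Deliberately NOT a subclass of AttributeError, so it cannot trigger a
    silent fallback to __getattr__.
    """


class DescriptorWrapper:
    """Wraps a descriptor (e.g. a property) and converts AttributeErrors raised
    inside its getter into a distinct, clearly named error."""

    def __init__(self, wrapped):
        self.wrapped = wrapped

    def __get__(self, *args, **kwargs):
        try:
            return self.wrapped.__get__(*args, **kwargs)
        except AttributeError as e:
            # Chain with `from e` so the real cause stays in the traceback.
            raise DescriptorAttributeError(
                "Trying to access a property that is accessing a non-existent attribute."
            ) from e


class Module:
    def __init__(self, seed_rng):
        self.seed_rng = seed_rng

    @property
    def device(self):
        # The failure mode from this issue: the property itself exists, but
        # its body touches an attribute that the object does not have.
        return self.seed_rng.device()


# Wrap the property after class creation (a hypothetical simplification of
# what flax does for Module subclasses).
Module.device = DescriptorWrapper(Module.__dict__["device"])
```

With a `seed_rng` object that lacks `device`, accessing `Module(...).device` raises `DescriptorAttributeError`, and its `__cause__` is the original `AttributeError`, mirroring the "direct cause" chain in the traceback.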

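So the issue lies in the `device` property of `JaxBaseModuleClass` (`scvi/module/base/_base_module.py`, line 584), which calls `device()` on the `seed_rng` array: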

    @property
    def device(self):
        return self.seed_rng.device()
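
The traceback shows that, in this environment, the `ArrayImpl` held in `seed_rng` has no `device` attribute at all. One hedged way to make such a lookup tolerant of JAX API changes is to prefer `jax.Array.devices()`, which returns a set of devices, and fall back to a `device()` method only when `devices()` is unavailable. This is a sketch under those assumptions; `single_device` is a hypothetical helper and is not necessarily what the fix in #2787 does.

```python
def single_device(arr):
    """Return the one device that holds `arr`.

    Hypothetical helper. Assumes JAX arrays expose `devices()` (a set of
    devices); falls back to a `device()` method for objects that only
    provide that.
    """
    devices = getattr(arr, "devices", None)
    if callable(devices):
        devs = devices()
        if len(devs) != 1:
            raise ValueError(
                f"expected a single-device array, got {len(devs)} devices"
            )
        # The set has exactly one element; return it.
        return next(iter(devs))
    # Fallback for objects that only implement `device()`.
    return arr.device()


# Hypothetical use inside the module class:
#
#     @property
#     def device(self):
#         return single_device(self.seed_rng)
```

With JAX installed, `single_device(jax.random.PRNGKey(0))` should return the device the key was placed on. Raising on multi-device (sharded) arrays keeps the helper honest, since a single `device` is ill-defined for them.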

Versions:

scvi-tools: 1.1.2

martinkim0 (Contributor) commented:

Looks like this is caused by the new JAX 0.4.27 release. Looking into it now.

martinkim0 (Contributor) commented:

Closed via #2787
