```
self = <flax.linen.module.create_descriptor_wrapper.<locals>._DescriptorWrapper object at 0x7f896ff65410>
args = (JaxSCGENVAE(
    # attributes
    n_input = 100
    n_hidden = 800
    n_latent = 100
    n_layers = 2
    dropout_ra...e_layer_norm = 'none'
    kl_weight = 5e-05
    training = True
), <class 'pertpy.tools._scgen._scgenvae.JaxSCGENVAE'>)
kwargs = {}

    def __get__(self, *args, **kwargs):
        # here we will catch internal AttributeError and re-raise it as a
        # more informative and correct error message.
        try:
>           return self.wrapped.__get__(*args, **kwargs)

../../miniconda3/envs/pertpy/lib/python3.11/site-packages/flax/linen/module.py:927:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = JaxSCGENVAE(
    # attributes
    n_input = 100
    n_hidden = 800
    n_latent = 100
    n_layers = 2
    dropout_rat...ribution = 'normal'
    use_batch_norm = 'both'
    use_layer_norm = 'none'
    kl_weight = 5e-05
    training = True
)

    @property
    def device(self):
>       return self.seed_rng.device()
E       AttributeError: 'jaxlib.xla_extension.ArrayImpl' object has no attribute 'device'

../../miniconda3/envs/pertpy/lib/python3.11/site-packages/scvi/module/base/_base_module.py:584: AttributeError

The above exception was the direct cause of the following exception:

    def test_scgen():
        adata = scvi.data.synthetic_iid()
        pt.tl.SCGEN.setup_anndata(
            adata,
            batch_key="batch",
            labels_key="labels",
        )
        scg = pt.tl.SCGEN(adata)
>       scg.train(max_epochs=1, batch_size=32, early_stopping=True, early_stopping_patience=25)

tests/tools/test_scgen.py:16:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../miniconda3/envs/pertpy/lib/python3.11/site-packages/scvi/model/base/_jaxmixin.py:74: in train
    self.module.to(device)
../../miniconda3/envs/pertpy/lib/python3.11/site-packages/flax/linen/module.py:701: in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
../../miniconda3/envs/pertpy/lib/python3.11/site-packages/flax/linen/module.py:1233: in _call_wrapped_method
    y = run_fun(self, *args, **kwargs)
../../miniconda3/envs/pertpy/lib/python3.11/site-packages/scvi/module/base/_base_module.py:647: in to
    if device is not self.device:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <flax.linen.module.create_descriptor_wrapper.<locals>._DescriptorWrapper object at 0x7f896ff65410>
args = (JaxSCGENVAE(
    # attributes
    n_input = 100
    n_hidden = 800
    n_latent = 100
    n_layers = 2
    dropout_ra...e_layer_norm = 'none'
    kl_weight = 5e-05
    training = True
), <class 'pertpy.tools._scgen._scgenvae.JaxSCGENVAE'>)
kwargs = {}

    def __get__(self, *args, **kwargs):
        # here we will catch internal AttributeError and re-raise it as a
        # more informative and correct error message.
        try:
            return self.wrapped.__get__(*args, **kwargs)
        except AttributeError as e:
>           raise errors.DescriptorAttributeError() from e
E           flax.errors.DescriptorAttributeError: Trying to access a property that is accessing a non-existent attribute. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.DescriptorAttributeError)
```
So the issue lies with `JaxBaseModuleClass` here:

```python
@property
def device(self):
    return self.seed_rng.device()
```

Versions:
scvi-tools: 1.1.2
Looks like this is caused by the newest JAX release, 0.4.27. Looking into this now.
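The traceback shows that, with the installed JAX, the `ArrayImpl` stored in `seed_rng` no longer has a `device` attribute, so `self.seed_rng.device()` fails. One version-robust alternative is `jax.Array.devices()`, which returns the set of devices an array lives on. The sketch below assumes a single-device key; `array_device` and `DeviceModuleSketch` are hypothetical names used for illustration only. They are not scvi-tools API and not necessarily what #2787 changed.

```python
import jax
import jax.numpy as jnp


def array_device(x: jax.Array):
    """Return the single device holding `x` (hypothetical helper, not scvi-tools API).

    Uses `Array.devices()`, which returns the set of devices the array lives on,
    instead of the `Array.device()` method that the traceback shows is missing.
    """
    devices = x.devices()
    if len(devices) != 1:
        # A sharded array spans several devices, so there is no single answer.
        raise ValueError(f"expected a single-device array, got {len(devices)} devices")
    return next(iter(devices))


class DeviceModuleSketch:
    """Hypothetical stand-in for a module that stores a PRNG key as `seed_rng`."""

    def __init__(self, seed: int = 0):
        self.seed_rng = jax.random.PRNGKey(seed)

    @property
    def device(self):
        # Same idea as JaxBaseModuleClass.device, but via devices() instead of device().
        return array_device(self.seed_rng)


module = DeviceModuleSketch(seed=0)
print(module.device)  # prints the device that holds seed_rng
```

Going through `devices()` avoids depending on whether a given JAX release exposes `device` as a method, as a property, or not at all.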
Closed via #2787