PyroModelGuideWarmup fails on GPU - probably need to be manually run before trainer.fit()
#2616
This can probably be addressed using the following modification of the class TrainRunner:

```python
def __call__(self):
    # other code .....
    from copy import copy

    # Copy the data splitter so the warmup pass does not disturb
    # the splitter state used by the real training loop.
    dl = copy(self.data_splitter)
    dl.setup()
    dl = dl.train_dataloader()
    # Run the guide warmup manually, before trainer.fit() moves
    # things into the accelerator environment.
    PyroModelGuideWarmup(dl).setup(
        self.trainer, self.training_plan, stage="fit"
    )
    self.trainer.fit(
        self.training_plan, self.data_splitter, ckpt_path=self.ckpt_path
    )
    # other code .....
```
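The pattern in the snippet above can be sketched with plain-Python stand-ins (all class names here are hypothetical placeholders, not the scvi-tools API): copy the splitter, set it up, and run the warmup callback once before fitting.

```python
from copy import copy

class DataSplitter:
    """Stand-in for a Lightning-style data module (hypothetical)."""
    def __init__(self, data):
        self.data = data
        self._ready = False

    def setup(self):
        self._ready = True

    def train_dataloader(self):
        assert self._ready, "call setup() first"
        return iter(self.data)

class TrainingPlan:
    """Stand-in for the training plan; records batches it has seen."""
    def __init__(self):
        self.seen_batches = []

class GuideWarmup:
    """Stand-in callback: consumes one batch to initialize guide params."""
    def __init__(self, dataloader):
        self.dataloader = dataloader
        self.initialized = False

    def setup(self, trainer, plan, stage):
        # One pass over a single batch, mimicking a guide warmup.
        batch = next(self.dataloader)
        plan.seen_batches.append(batch)
        self.initialized = True

def run(data_splitter, plan):
    # Shallow-copy the splitter so the warmup does not consume the
    # dataloader state the real training loop will use.
    dl = copy(data_splitter)
    dl.setup()
    warmup = GuideWarmup(dl.train_dataloader())
    warmup.setup(trainer=None, plan=plan, stage="fit")
    return warmup.initialized
```

Because the warmup runs against a copy, the original splitter is untouched when `trainer.fit()` later calls its own `setup()`.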
Would be great to hear what you think, @martinkim0, and I can add the proposed changes.
Hey, sorry, took a look at this and forgot to respond. I think it makes sense to add the fixes to …
Sounds good! Later this week, I will make a PR about this issue, as well as another issue with the second GuideWarmup callback (pyro doesn't track deterministic variables initialised after setup). I think we need to get rid of both pyro GuideWarmup callbacks and just run the guide once in …
PyroModelGuideWarmup fails on GPU, probably because Callback.setup() is called in the accelerator environment in the latest PyTorch Lightning. This test fails on GPU:
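One plausible reading of the failure, sketched with stand-in classes (no torch or Lightning here; names are illustrative, not the real API): by the time `setup()` runs, the model has already been moved to the accelerator, so a warmup pass fed a batch that is still on the CPU hits a device mismatch.

```python
class Tensor:
    """Stand-in for a tensor that only tracks its device."""
    def __init__(self, device="cpu"):
        self.device = device

class Model:
    """Stand-in model whose guide requires inputs on its own device."""
    def __init__(self):
        self.device = "cpu"

    def to(self, device):
        self.device = device
        return self

    def guide(self, batch):
        # Pyro-style guides raise when inputs and parameters live on
        # different devices; this mimics that behaviour.
        if batch.device != self.device:
            raise RuntimeError(
                f"device mismatch: batch on {batch.device}, "
                f"model on {self.device}"
            )

model = Model().to("cuda")   # trainer has already moved the model to GPU
cpu_batch = Tensor("cpu")    # dataloader batch, not yet transferred
try:
    model.guide(cpu_batch)   # setup()-time warmup pass
    ok = True
except RuntimeError:
    ok = False               # fails, matching the GPU-only test failure
```

Running the warmup manually before `trainer.fit()`, while the model is still on the CPU, sidesteps this ordering problem.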
Versions: