
problem with loading 32px checkpoint (CIFAR10) #329

Open
elad-be opened this issue Aug 8, 2022 · 1 comment

Comments

@elad-be

elad-be commented Aug 8, 2022

Hi, I'm trying to load the CIFAR10 checkpoint (from the official NVIDIA implementation page, after converting the weights) with this code:
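```python
import torch

from model import Generator  # model.py from this repo

model = Generator(32, 512, 2)  # size=32, style_dim=512, n_mlp=2

ckpt = torch.load("path to CIFAR.pt")
model.load_state_dict(ckpt["g_ema"])
```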

and I get this error:

```
RuntimeError: Error(s) in loading state_dict for Generator:
size mismatch for style.1.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([512, 512]).
```

How can I solve this problem?

Thanks.
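Before calling `load_state_dict`, it can help to list every parameter whose shape differs between the model and the checkpoint. The sketch below is a minimal example under that idea; `find_shape_mismatches` is a hypothetical helper, not part of PyTorch or this repo, and it only assumes each value has a `.shape` attribute (as `torch.Tensor` does).

```python
from typing import Any, Dict, List, Tuple


def find_shape_mismatches(
    model_state: Dict[str, Any], ckpt_state: Dict[str, Any]
) -> List[Tuple[str, Tuple[int, ...], Tuple[int, ...]]]:
    """Return (key, checkpoint_shape, model_shape) for every key present in
    both dicts whose shapes differ. Keys found in only one dict are skipped,
    since those show up as missing/unexpected keys instead."""
    mismatches = []
    for key, ckpt_value in ckpt_state.items():
        if key not in model_state:
            continue
        ckpt_shape = tuple(ckpt_value.shape)
        model_shape = tuple(model_state[key].shape)
        if ckpt_shape != model_shape:
            mismatches.append((key, ckpt_shape, model_shape))
    return mismatches
```

With the names from the issue, `find_shape_mismatches(model.state_dict(), ckpt["g_ema"])` should report `style.1.weight` with checkpoint shape `(512, 1024)` and model shape `(512, 512)`.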

@sergevkim

I suppose you converted the weights from the stylegan2-ada repository. However, NVIDIA's stylegan2-ada implementation differs substantially from rosinality's stylegan2 implementation. The stylegan2-ada CIFAR-10 model is class-conditional (its mapping network takes a class label in addition to the latent, which is why `style.1.weight` has 1024 input features instead of 512), while this repo implements unconditional generation. So it's impossible to convert the weights automatically.
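To illustrate where the 1024 comes from, here is a sketch of the input stage of a class-conditional mapping network, assuming it follows the stylegan2-ada-pytorch design: the one-hot label is embedded to 512 features, both the embedding and the 512-d z are normalized, and the two are concatenated before the first fully connected layer. The class name, helper name, and sizes are assumptions for the CIFAR-10 case, and the real layers (equalized learning rate, leaky ReLU activation) are simplified to plain `nn.Linear`.

```python
import torch
import torch.nn.functional as F
from torch import nn

# Assumed sizes for the CIFAR-10 case: 512-d z, 10 classes, 512-d embedding/w.
Z_DIM, C_DIM, EMBED_DIM, W_DIM = 512, 10, 512, 512


def normalize_2nd_moment(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Scale each row to unit RMS over the feature dimension.
    return x * (x.square().mean(dim=1, keepdim=True) + eps).rsqrt()


class ConditionalMappingInput(nn.Module):
    """Hypothetical first stage of a class-conditional mapping network."""

    def __init__(self) -> None:
        super().__init__()
        self.embed = nn.Linear(C_DIM, EMBED_DIM)        # one-hot label -> 512
        self.fc0 = nn.Linear(Z_DIM + EMBED_DIM, W_DIM)  # 1024 -> 512 (activation omitted)

    def forward(self, z: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        c = F.one_hot(labels, num_classes=C_DIM).float()
        x = torch.cat([normalize_2nd_moment(z), normalize_2nd_moment(self.embed(c))], dim=1)
        return self.fc0(x)


net = ConditionalMappingInput()
print(tuple(net.fc0.weight.shape))  # (512, 1024): nn.Linear stores (out_features, in_features)
w = net(torch.randn(4, Z_DIM), torch.tensor([0, 1, 2, 3]))
print(tuple(w.shape))  # (4, 512)
```

By contrast, `Generator(32, 512, 2)` in this repo builds an unconditional mapping network whose first layer maps 512 features to 512, matching the `torch.Size([512, 512])` in the error. The extra label-embedding weights have no counterpart in this repo's `Generator`, so a shape-only conversion cannot bridge the two.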
