
RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28] #33

Open
abcxq opened this issue Dec 26, 2021 · 2 comments


@abcxq

abcxq commented Dec 26, 2021

When I run your CGAN program, the following error occurs:
D:\ProgramData\Anaconda3\python.exe I:/gupaocode/pytorch-generative-model-collections-master/main.py --dataset mnist --gan_type CGAN --epoch 50 --batch_size 64
Traceback (most recent call last):
  File "I:/gupaocode/pytorch-generative-model-collections-master/main.py", line 111, in <module>
    main()
  File "I:/gupaocode/pytorch-generative-model-collections-master/main.py", line 82, in main
    gan = CGAN(args)
  File "I:\gupaocode\pytorch-generative-model-collections-master\CGAN.py", line 94, in __init__
    data = self.data_loader.__iter__().__next__()[0]
  File "D:\ProgramData\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 517, in __next__
    data = self._next_data()
  File "D:\ProgramData\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 557, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "D:\ProgramData\Anaconda3\lib\site-packages\torch\utils\data\_utils\fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "D:\ProgramData\Anaconda3\lib\site-packages\torch\utils\data\_utils\fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "D:\ProgramData\Anaconda3\lib\site-packages\torchvision\datasets\mnist.py", line 106, in __getitem__
    img = self.transform(img)
  File "D:\ProgramData\Anaconda3\lib\site-packages\torchvision\transforms\transforms.py", line 60, in __call__
    img = t(img)
  File "D:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "D:\ProgramData\Anaconda3\lib\site-packages\torchvision\transforms\transforms.py", line 221, in forward
    return F.normalize(tensor, self.mean, self.std, self.inplace)
  File "D:\ProgramData\Anaconda3\lib\site-packages\torchvision\transforms\functional.py", line 336, in normalize
    tensor.sub_(mean).div_(std)
RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]
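The error can be reproduced with plain tensors. The last frame shows Normalize subtracting the mean in place (tensor.sub_(mean)), and an in-place operation cannot grow its output: a 1-channel image minus a 3-entry mean would broadcast to [3, 28, 28], which does not fit back into the [1, 28, 28] input. This is a minimal sketch assuming only torch is installed; the random tensor is a stand-in for one MNIST image after ToTensor(), and try_inplace_sub is a hypothetical helper written for this illustration.

```python
import torch

# Stand-in for one MNIST image after ToTensor(): 1 channel, 28x28, values in [0, 1]
img = torch.rand(1, 28, 28)

# Per-channel statistics reshaped to (C, 1, 1), the shape used for broadcasting
mean_rgb = torch.tensor([0.5, 0.5, 0.5]).view(-1, 1, 1)  # 3 channels: mismatch
mean_gray = torch.tensor([0.5]).view(-1, 1, 1)           # 1 channel: matches


def try_inplace_sub(t, mean):
    """Hypothetical helper: return None on success, or the RuntimeError message."""
    try:
        t.clone().sub_(mean)  # clone so the original image is left untouched
        return None
    except RuntimeError as err:
        return str(err)


print(try_inplace_sub(img, mean_rgb))   # the broadcast-shape error from the issue
print(try_inplace_sub(img, mean_gray))  # None
```

An out-of-place subtraction (img - mean_rgb) would silently produce a [3, 28, 28] tensor instead of failing, so the in-place error is actually the more helpful outcome here.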

@WhiteGL

WhiteGL commented Feb 4, 2022

If you are using MNIST as your dataset, which has only one input channel, this error will occur.
Try changing the transform on line 5 of dataloader.py from transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) to transforms.Normalize(mean=(0.5), std=(0.5)).

@ClayShoaf

@WhiteGL Now, when I run python main.py --gan_type 'GAN' --dataset 'mnist', I get:

Traceback (most recent call last):
  File "/home/user/testpad/python/pytorch-generative-model-collections/main.py", line 111, in <module>
    main()
  File "/home/user/testpad/python/pytorch-generative-model-collections/main.py", line 80, in main
    gan = GAN(args)
  File "/home/user/testpad/python/pytorch-generative-model-collections/GAN.py", line 90, in __init__
    data = self.data_loader.__iter__().__next__()[0]
  File "/home/user/anaconda3/envs/threeten/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
  File "/home/user/anaconda3/envs/threeten/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 674, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/user/anaconda3/envs/threeten/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/user/anaconda3/envs/threeten/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/user/anaconda3/envs/threeten/lib/python3.10/site-packages/torchvision/datasets/mnist.py", line 145, in __getitem__
    img = self.transform(img)
  File "/home/user/anaconda3/envs/threeten/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/anaconda3/envs/threeten/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/anaconda3/envs/threeten/lib/python3.10/site-packages/torchvision/transforms/transforms.py", line 277, in forward
    return F.normalize(tensor, self.mean, self.std, self.inplace)
  File "/home/user/anaconda3/envs/threeten/lib/python3.10/site-packages/torchvision/transforms/functional.py", line 361, in normalize
    raise TypeError(f"img should be Tensor Image. Got {type(tensor)}")
TypeError: img should be Tensor Image. Got <class 'PIL.Image.Image'>
