Facing issue with Flower Simulation with ResNet18 and MNIST dataset #3237

Open
EzyHow opened this issue Apr 8, 2024 · 3 comments
Labels
bug Something isn't working

Comments

@EzyHow

EzyHow commented Apr 8, 2024

Describe the bug

I was trying an example project for Flower Simulation (Flower Simulation Step by Step Pytorch - Part II). Everything went well until I changed the model to resnet18, as shown below:

class Net(nn.Module):
    def __init__(self, num_classes: int) -> None:
        super(Net, self).__init__()
        self.model = models.resnet18()
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Linear(num_ftrs, num_classes)
        summary(self.model, input_size=(1, 28, 28)) # <<== THIS LINE

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.model(x)
        return x

If I add summary(self.model, input_size=(1, 28, 28)) at the end of the __init__() method, everything works. But when I remove it, I get the error input_param = input_param[0] IndexError: index 0 is out of bounds for dimension 0 with size 0 in evaluate_fn of server.py:

params_dict = zip(model.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
model.load_state_dict(state_dict, strict=True) # <= At this line I'm getting error

Steps/Code to Reproduce

Clone the repository from Flower Simulation Step by Step Pytorch Part-II and follow the instructions to set up the environment.

Then change the model to resnet18 in the model.py file as shown below:

import torch
import torch.nn as nn
import torchvision.models as models
from flwr.common.parameter import ndarrays_to_parameters
from torchsummary import summary

class Net(nn.Module):
    def __init__(self, num_classes: int) -> None:
        super(Net, self).__init__()

        self.model = models.resnet18()
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Linear(num_ftrs, num_classes)
        summary(self.model, input_size=(1, 28, 28))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.model(x)
        return x

The following packages are installed in the conda environment:

# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                       2_gnu    conda-forge
absl-py                   2.1.0                    pypi_0    pypi
aiohttp                   3.9.3                    pypi_0    pypi
aiosignal                 1.3.1                    pypi_0    pypi
antlr4-python3-runtime    4.9.3                    pypi_0    pypi
asttokens                 2.4.1              pyhd8ed1ab_0    conda-forge
astunparse                1.6.3                    pypi_0    pypi
async-timeout             4.0.3                    pypi_0    pypi
attrs                     23.2.0                   pypi_0    pypi
blas                      1.0                         mkl  
brotli-python             1.0.9            py39h6a678d5_7  
bzip2                     1.0.8                h5eee18b_5  
ca-certificates           2024.3.11            h06a4308_0  
certifi                   2024.2.2           pyhd8ed1ab_0    conda-forge
cffi                      1.16.0                   pypi_0    pypi
charset-normalizer        2.0.4              pyhd3eb1b0_0  
click                     8.1.7                    pypi_0    pypi
comm                      0.2.2              pyhd8ed1ab_0    conda-forge
contourpy                 1.2.0                    pypi_0    pypi
cryptography              41.0.7                   pypi_0    pypi
cycler                    0.12.1                   pypi_0    pypi
datasets                  2.18.0                   pypi_0    pypi
debugpy                   1.6.7            py39h6a678d5_0  
decorator                 5.1.1              pyhd8ed1ab_0    conda-forge
dill                      0.3.8                    pypi_0    pypi
exceptiongroup            1.2.0              pyhd8ed1ab_2    conda-forge
executing                 2.0.1              pyhd8ed1ab_0    conda-forge
ffmpeg                    4.3                  hf484d3e_0    pytorch
filelock                  3.13.3                   pypi_0    pypi
flatbuffers               24.3.25                  pypi_0    pypi
flwr                      1.7.0                    pypi_0    pypi
flwr-datasets             0.1.0                    pypi_0    pypi
fonttools                 4.50.0                   pypi_0    pypi
freetype                  2.12.1               h4a9f257_0  
frozenlist                1.4.1                    pypi_0    pypi
fsspec                    2024.2.0                 pypi_0    pypi
gast                      0.5.4                    pypi_0    pypi
gmp                       6.2.1                h295c915_3  
gnutls                    3.6.15               he1e5248_0  
google-pasta              0.2.0                    pypi_0    pypi
grpcio                    1.62.1                   pypi_0    pypi
h5py                      3.10.0                   pypi_0    pypi
huggingface-hub           0.22.1                   pypi_0    pypi
hydra-core                1.3.2                    pypi_0    pypi
idna                      3.4              py39h06a4308_0  
importlib-metadata        7.1.0              pyha770c72_0    conda-forge
importlib-resources       6.4.0                    pypi_0    pypi
importlib_metadata        7.1.0                hd8ed1ab_0    conda-forge
intel-openmp              2023.1.0         hdb19cb5_46306  
ipykernel                 6.29.3             pyhd33586a_0    conda-forge
ipython                   8.18.1             pyh707e725_3    conda-forge
iterators                 0.0.2                    pypi_0    pypi
jedi                      0.19.1             pyhd8ed1ab_0    conda-forge
jpeg                      9e                   h5eee18b_1  
jsonschema                4.21.1                   pypi_0    pypi
jsonschema-specifications 2023.12.1                pypi_0    pypi
jupyter_client            8.6.1              pyhd8ed1ab_0    conda-forge
jupyter_core              5.7.2            py39hf3d152e_0    conda-forge
keras                     3.1.1                    pypi_0    pypi
kiwisolver                1.4.5                    pypi_0    pypi
lame                      3.100                h7b6447c_0  
lcms2                     2.12                 h3be6417_0  
ld_impl_linux-64          2.38                 h1181459_1  
lerc                      3.0                  h295c915_0  
libclang                  18.1.1                   pypi_0    pypi
libdeflate                1.17                 h5eee18b_1  
libffi                    3.4.4                h6a678d5_0  
libgcc-ng                 13.2.0               h807b86a_5    conda-forge
libgomp                   13.2.0               h807b86a_5    conda-forge
libiconv                  1.16                 h7f8727e_2  
libidn2                   2.3.4                h5eee18b_0  
libpng                    1.6.39               h5eee18b_0  
libsodium                 1.0.18               h36c2ea0_1    conda-forge
libstdcxx-ng              11.2.0               h1234567_1  
libtasn1                  4.19.0               h5eee18b_0  
libtiff                   4.5.1                h6a678d5_0  
libunistring              0.9.10               h27cfd23_0  
libwebp-base              1.3.2                h5eee18b_0  
lz4-c                     1.9.4                h6a678d5_0  
markdown                  3.6                      pypi_0    pypi
markdown-it-py            3.0.0                    pypi_0    pypi
markupsafe                2.1.5                    pypi_0    pypi
matplotlib                3.8.3                    pypi_0    pypi
matplotlib-inline         0.1.6              pyhd8ed1ab_0    conda-forge
mdurl                     0.1.2                    pypi_0    pypi
mkl                       2023.1.0         h213fc3f_46344  
mkl-service               2.4.0            py39h5eee18b_1  
mkl_fft                   1.3.8            py39h5eee18b_0  
mkl_random                1.2.4            py39hdb19cb5_0  
ml-dtypes                 0.3.2                    pypi_0    pypi
msgpack                   1.0.8                    pypi_0    pypi
multidict                 6.0.5                    pypi_0    pypi
multiprocess              0.70.16                  pypi_0    pypi
namex                     0.0.7                    pypi_0    pypi
ncurses                   6.4                  h6a678d5_0  
nest-asyncio              1.6.0              pyhd8ed1ab_0    conda-forge
nettle                    3.7.3                hbbd107a_1  
numpy                     1.26.4           py39h5f9d8c6_0  
numpy-base                1.26.4           py39hb5e798b_0  
omegaconf                 2.3.0                    pypi_0    pypi
openh264                  2.1.1                h4ff587b_0  
openjpeg                  2.4.0                h3ad879b_0  
openssl                   3.2.1                hd590300_1    conda-forge
opt-einsum                3.3.0                    pypi_0    pypi
optree                    0.11.0                   pypi_0    pypi
packaging                 24.0               pyhd8ed1ab_0    conda-forge
pandas                    2.2.1                    pypi_0    pypi
parso                     0.8.4              pyhd8ed1ab_0    conda-forge
pexpect                   4.9.0              pyhd8ed1ab_0    conda-forge
pickleshare               0.7.5                   py_1003    conda-forge
pillow                    10.2.0           py39h5eee18b_0  
pip                       23.3.1           py39h06a4308_0  
platformdirs              4.2.0              pyhd8ed1ab_0    conda-forge
prompt-toolkit            3.0.42             pyha770c72_0    conda-forge
protobuf                  4.25.3                   pypi_0    pypi
psutil                    5.9.8            py39hd1e30aa_0    conda-forge
ptyprocess                0.7.0              pyhd3deb0d_0    conda-forge
pure_eval                 0.2.2              pyhd8ed1ab_0    conda-forge
pyarrow                   15.0.2                   pypi_0    pypi
pyarrow-hotfix            0.6                      pypi_0    pypi
pycparser                 2.21                     pypi_0    pypi
pycryptodome              3.20.0                   pypi_0    pypi
pydantic                  1.10.14                  pypi_0    pypi
pygments                  2.17.2             pyhd8ed1ab_0    conda-forge
pyparsing                 3.1.2                    pypi_0    pypi
pysocks                   1.7.1            py39h06a4308_0  
python                    3.9.19               h955ad1f_0  
python-dateutil           2.9.0.post0              pypi_0    pypi
python_abi                3.9                      2_cp39    conda-forge
pytorch                   1.13.1              py3.9_cpu_0    pytorch
pytorch-mutex             1.0                         cpu    pytorch
pytz                      2024.1                   pypi_0    pypi
pyyaml                    6.0.1                    pypi_0    pypi
pyzmq                     25.1.2           py39h6a678d5_0  
ray                       2.6.3                    pypi_0    pypi
readline                  8.2                  h5eee18b_0  
referencing               0.34.0                   pypi_0    pypi
requests                  2.31.0           py39h06a4308_1  
rich                      13.7.1                   pypi_0    pypi
rpds-py                   0.18.0                   pypi_0    pypi
scipy                     1.12.0                   pypi_0    pypi
setuptools                68.2.2           py39h06a4308_0  
six                       1.16.0             pyh6c4a22f_0    conda-forge
sqlite                    3.41.2               h5eee18b_0  
stack_data                0.6.2              pyhd8ed1ab_0    conda-forge
tbb                       2021.8.0             hdb19cb5_0  
tensorboard               2.16.2                   pypi_0    pypi
tensorboard-data-server   0.7.2                    pypi_0    pypi
tensorflow-io-gcs-filesystem 0.36.0                   pypi_0    pypi
termcolor                 2.4.0                    pypi_0    pypi
tk                        8.6.12               h1ccaba5_0  
torchaudio                0.13.1                 py39_cpu    pytorch
torchsummary              1.5.1                    pypi_0    pypi
torchvision               0.14.1                 py39_cpu    pytorch
tornado                   6.4              py39hd1e30aa_0    conda-forge
tqdm                      4.66.2                   pypi_0    pypi
traitlets                 5.14.2             pyhd8ed1ab_0    conda-forge
typing_extensions         4.9.0            py39h06a4308_1  
tzdata                    2024.1                   pypi_0    pypi
urllib3                   2.1.0            py39h06a4308_1  
wcwidth                   0.2.13             pyhd8ed1ab_0    conda-forge
werkzeug                  3.0.2                    pypi_0    pypi
wheel                     0.41.2           py39h06a4308_0  
wrapt                     1.16.0                   pypi_0    pypi
xxhash                    3.4.1                    pypi_0    pypi
xz                        5.4.6                h5eee18b_0  
yarl                      1.9.4                    pypi_0    pypi
zeromq                    4.3.5                h6a678d5_0  
zipp                      3.18.1                   pypi_0    pypi
zlib                      1.2.13               h5eee18b_0  
zstd                      1.5.5                hc292b87_0 

requirement.txt file

datasets==2.18.0
flwr==1.7.0
hydra-core==1.3.2
omegaconf==2.3.0
torch==1.13.1
torchvision==0.14.1
flwr[simulation]>=1.0, <2.0
matplotlib==3.8.3
scipy==1.12.0
torchsummary==1.5.1

Expected Results

The following is the output when it runs successfully (with the line summary(self.model, input_size=(1, 28, 28)) added):

{'history': History (loss, distributed):
    round 1: 6.738090056180954
    round 2: 3.8934330970048903
History (loss, centralized):
    round 0: 366.1482033729553
    round 1: 97.4027541577816
    round 2: 52.76616382226348
History (metrics, centralized):
{'accuracy': [(0, 0.1086), (1, 0.8021), (2, 0.8959)]}

Actual Results

When I remove the line summary(self.model, input_size=(1, 28, 28)), I get the following error:

[2024-04-08 09:43:34,760][flwr][INFO] - Initializing global parameters
[2024-04-08 09:43:34,761][flwr][INFO] - Requesting initial parameters from one random client
[2024-04-08 09:43:37,337][flwr][INFO] - Received initial parameters from one random client
[2024-04-08 09:43:37,338][flwr][INFO] - Evaluating initial parameters
[2024-04-08 09:43:37,644][flwr][ERROR] - index 0 is out of bounds for dimension 0 with size 0
[2024-04-08 09:43:37,646][flwr][ERROR] - Traceback (most recent call last):
  File "/root/miniconda3/envs/flower_env/lib/python3.9/site-packages/flwr/simulation/app.py", line 308, in start_simulation
    hist = run_fl(
  File "/root/miniconda3/envs/flower_env/lib/python3.9/site-packages/flwr/server/app.py", line 225, in run_fl
    hist = server.fit(num_rounds=config.num_rounds, timeout=config.round_timeout)
  File "/root/miniconda3/envs/flower_env/lib/python3.9/site-packages/flwr/server/server.py", line 92, in fit
    res = self.strategy.evaluate(0, parameters=self.parameters)
  File "/root/miniconda3/envs/flower_env/lib/python3.9/site-packages/flwr/server/strategy/fedavg.py", line 165, in evaluate
    eval_res = self.evaluate_fn(server_round, parameters_ndarrays, {})
  File "/root/development/machine-learning-project/server.py", line 42, in evaluate_fn
    model.load_state_dict(state_dict, strict=True)
  File "/root/miniconda3/envs/flower_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1657, in load_state_dict
    load(self, state_dict)
  File "/root/miniconda3/envs/flower_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1645, in load
    load(child, child_state_dict, child_prefix)
  File "/root/miniconda3/envs/flower_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1645, in load
    load(child, child_state_dict, child_prefix)
  File "/root/miniconda3/envs/flower_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1639, in load
    module._load_from_state_dict(
  File "/root/miniconda3/envs/flower_env/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py", line 110, in _load_from_state_dict
    super(_NormBase, self)._load_from_state_dict(
  File "/root/miniconda3/envs/flower_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _load_from_state_dict
    input_param = input_param[0]
IndexError: index 0 is out of bounds for dimension 0 with size 0

[2024-04-08 09:43:37,648][flwr][ERROR] - Your simulation crashed :(. This could be because of several reasons. The most common are: 
	 > Sometimes, issues in the simulation code itself can cause crashes. It's always a good idea to double-check your code for any potential bugs or inconsistencies that might be contributing to the problem. For example: 
		 - You might be using a class attribute in your clients that hasn't been defined.
		 - There could be an incorrect method call to a 3rd party library (e.g., PyTorch).
		 - The return types of methods in your clients/strategies might be incorrect.
	 > Your system couldn't fit a single VirtualClient: try lowering `client_resources`.
	 > All the actors in your pool crashed. This could be because: 
		 - You clients hit an out-of-memory (OOM) error and actors couldn't recover from it. Try launching your simulation with more generous `client_resources` setting (i.e. it seems {'num_cpus': 1, 'num_gpus': 0.0} is not enough for your run). Use fewer concurrent actors. 
		 - You were running a multi-node simulation and all worker nodes disconnected. The head node might still be alive but cannot accommodate any actor with resources: {'num_cpus': 1, 'num_gpus': 0.0}.
Take a look at the Flower simulation examples for guidance <https://flower.dev/docs/framework/how-to-run-simulations.html>.
@EzyHow EzyHow added the bug Something isn't working label Apr 8, 2024
@jafermarq
Contributor

Hi @EzyHow, have you added that summary(self.model, input_size=(1, 28, 28)) call somewhere else, maybe also in the evaluation in server.py? I wonder if torchsummary is adding something to the state_dict...
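One way to test this hypothesis is to snapshot the state_dict before and after a forward pass and diff the two. The sketch below is only illustrative: it uses a small stand-in model instead of ResNet18, replaces the summary(...) call with a plain forward pass in training mode (on the assumption that this is roughly what a summary tool triggers), and diff_state_dicts is a hypothetical helper name, not part of Flower, PyTorch, or torchsummary.

```python
import torch
import torch.nn as nn


def diff_state_dicts(before, after):
    """Return (added_keys, removed_keys, changed_keys) between two state dicts.

    Hypothetical helper for illustration only.
    """
    added = sorted(set(after) - set(before))
    removed = sorted(set(before) - set(after))
    changed = sorted(
        k
        for k in set(before) & set(after)
        if before[k].shape != after[k].shape or not torch.equal(before[k], after[k])
    )
    return added, removed, changed


# Small stand-in model with a BatchNorm layer (ResNet18 contains many of these).
model = nn.Sequential(nn.Conv2d(1, 4, kernel_size=3), nn.BatchNorm2d(4))

# Clone the tensors: state_dict() returns references to the live buffers.
before = {k: v.clone() for k, v in model.state_dict().items()}

# A forward pass in training mode (assumed stand-in for the summary call).
model(torch.randn(2, 1, 28, 28))

after = model.state_dict()
added, removed, changed = diff_state_dicts(before, after)
print("added:", added)
print("removed:", removed)
print("changed:", changed)
```

If no keys are added or removed but BatchNorm buffers such as num_batches_tracked show up as changed, then the forward pass is altering buffer values rather than the structure of the state_dict.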

@EzyHow
Author

EzyHow commented Apr 8, 2024

Flower Simulation Step by Step Pytorch Part-II

Kindly check this repository for detailed code: Testing Flower Simulation

In this repository, please go through the main.log files for the three different scenarios in the output directory.

@rhythm1827

Hello,

I encountered the same issue and found a solution. I noticed the ndarrays_to_model function in src/model_utils.py. The relevant code is:

def ndarrays_to_model(model: torch.nn.ModuleList, params: List[np.ndarray]):
    """Set model weights from a list of NumPy ndarrays."""
    params_dict = zip(model.state_dict().keys(), params)
    state_dict = OrderedDict({k: torch.from_numpy(np.copy(v)) for k, v in params_dict})
    model.load_state_dict(state_dict, strict=True)

Therefore, I changed

state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})

to

state_dict = OrderedDict({k: torch.from_numpy(np.copy(v)) for k, v in params_dict})

in the set_parameters function in client.py and in evaluate_fn in server.py. You also need to import numpy:

import numpy as np

I hope it will work for you as well.
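To see why the conversion matters, a rough check is to export a model's state_dict to NumPy arrays, convert them back with each approach, and compare the resulting shapes with what the model expects. This is a minimal sketch under assumptions: it uses a small stand-in model with one BatchNorm layer rather than ResNet18, the export line imitates what a Flower client typically does in get_parameters, and shape_mismatches and via_from_numpy are hypothetical helper names.

```python
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn


def shape_mismatches(model, ndarrays, convert):
    """List (key, expected_shape, got_shape) wherever convert(array) changes the shape.

    Hypothetical helper for illustration only.
    """
    mismatches = []
    for (key, expected), array in zip(model.state_dict().items(), ndarrays):
        got = convert(array)
        if tuple(got.shape) != tuple(expected.shape):
            mismatches.append((key, tuple(expected.shape), tuple(got.shape)))
    return mismatches


def via_from_numpy(array):
    """Conversion suggested in this thread: copy the array, then wrap it."""
    return torch.from_numpy(np.copy(array))


# Stand-in model; the BatchNorm layer carries a zero-dimensional buffer.
model = nn.Sequential(nn.Conv2d(1, 4, kernel_size=3), nn.BatchNorm2d(4))

# Export to NumPy (assumed to mirror a typical get_parameters implementation).
ndarrays = [v.cpu().numpy() for v in model.state_dict().values()]

# Compare both conversions; an empty list means every shape was preserved.
print("from_numpy:", shape_mismatches(model, ndarrays, via_from_numpy))
print("torch.Tensor:", shape_mismatches(model, ndarrays, torch.Tensor))

# Loading with the from_numpy conversion under strict=True.
state_dict = OrderedDict(
    (k, via_from_numpy(a)) for k, a in zip(model.state_dict().keys(), ndarrays)
)
model.load_state_dict(state_dict, strict=True)
```

Any entries reported for torch.Tensor point at the parameters or buffers whose shape did not survive the round trip, which is where load_state_dict would be expected to fail.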
