Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

How to Save Client Models and results After Each Round in Flower #3169

Open
kalkite opened this issue Mar 25, 2024 · 4 comments
Open

How to Save Client Models and results After Each Round in Flower #3169

kalkite opened this issue Mar 25, 2024 · 4 comments
Assignees
Labels
question Further information is requested

Comments

@kalkite
Copy link

kalkite commented Mar 25, 2024

Hello,

I am required to use trained models, so I need to save the client models and results after each round for applying the XAI method (SHAP). I found an example in TensorFlow. My goal is to save the state of each client's model after each round of federated learning.

Here is my client.

import torch.optim as optim
from collections import OrderedDict
from typing import Dict, Tuple
from flwr.common import NDArrays, Scalar
import torch
import flwr as fl
from src.deepLearn.model_class import DeepNeuralNetwork
from src.fedLearn.centralized import fed_train, fed_test
from helpers.utils_ import get_device
import time
from log_config import base_logger
from flwr.client import Client

# Module-level logger obtained from the project's base_logger helper.
logger = base_logger(__name__)

# Most recent training metrics per client, keyed by client_id.
# FlowerClient.fit overwrites a client's entry on every call.
Results = {}


class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates a local PyTorch model.

    The client keeps one model plus its own train/test loaders. Weights are
    exchanged with the server as a list of numpy arrays ordered like the
    model's ``state_dict``.
    """

    def __init__(self, client_id, model, train_loader, test_loader):
        super().__init__()
        self.client_id = client_id
        self.model = model
        self.train_loader = train_loader
        self.test_loader = test_loader
        # Stored for callers; not referenced by the methods of this class.
        self.device = get_device()

    def set_parameters(self, parameters):
        """Load ``parameters`` into the local model.

        The arrays are paired positionally with the keys of the model's
        current ``state_dict`` and loaded with ``strict=True``.
        """
        layer_names = self.model.state_dict().keys()
        new_state = OrderedDict(
            (name, torch.Tensor(array))
            for name, array in zip(layer_names, parameters)
        )
        self.model.load_state_dict(new_state, strict=True)
        print("Parameters Successfully Set for Client {0}".format(int(self.client_id) + 1))

    def get_parameters(self, config: Dict[str, Scalar]):
        """Return the model's ``state_dict`` values as CPU numpy arrays."""
        print(f"[Client {int(self.client_id) + 1}] get_parameters")
        current_state = self.model.state_dict()
        return [tensor.cpu().numpy() for tensor in current_state.values()]

    def fit(self, parameters, config):
        """Train the local model for one round.

        ``config`` must provide ``'learning_rate'``, ``'optimizer'`` (the name
        of a class in ``torch.optim``) and ``'epochs'``. The resulting loss
        and accuracy are recorded in the module-level ``Results`` dict.

        Returns the updated weights, ``len(self.train_loader)`` and an empty
        metrics dict.
        """
        print(f"[Client {int(self.client_id) + 1}] fit, config: {config}")
        self.set_parameters(parameters)

        learning_rate = config['learning_rate']
        optimizer_name = config['optimizer']
        # Resolve the optimizer class by name, e.g. "SGD" -> optim.SGD.
        optimizer_cls = getattr(optim, optimizer_name)
        local_optimizer = optimizer_cls(self.model.parameters(), lr=learning_rate)
        num_epochs = config['epochs']

        train_loss, train_accuracy = fed_train(
            model=self.model,
            train_loader=self.train_loader,
            optimizer=local_optimizer,
            epochs=num_epochs,
        )
        print(f"Client {int(self.client_id) + 1} train_loss: {train_loss}, train_accuracy: {train_accuracy}")

        # Keep only the latest metrics for this client.
        round_metrics = {"train_loss": train_loss, "train_accuracy": train_accuracy}
        Results[self.client_id] = round_metrics

        time.sleep(5)
        updated_weights = self.get_parameters({})
        # NOTE(review): for a torch DataLoader, len() is the number of
        # batches, not samples -- confirm this is the intended weight.
        return updated_weights, len(self.train_loader), {}

    def evaluate(self, parameters: NDArrays, config: Dict[str, Scalar]):
        """Evaluate the given weights on this client's test loader.

        Returns the loss as a float, ``len(self.test_loader)`` and a metrics
        dict holding ``"accuracy"``.
        """
        print(f"[Client {int(self.client_id) + 1}] evaluate, config: {config}")
        self.set_parameters(parameters)
        loss, accuracy = fed_test(self.model, self.test_loader)
        time.sleep(5)
        metrics = {"accuracy": accuracy}
        return float(loss), len(self.test_loader), metrics


def generate_client_fn(model, train_loaders, test_loaders):
    """Return a factory that builds the FlowerClient for a given client id.

    Every client created by the factory receives the same ``model`` object.
    Its loaders are taken from ``train_loaders`` / ``test_loaders`` at index
    ``int(client_id)``.
    """

    def client_fn(client_id):
        try:
            idx = int(client_id)
            client = FlowerClient(
                client_id=client_id,
                model=model,
                train_loader=train_loaders[idx],
                test_loader=test_loaders[idx],
            )
        except Exception as e:
            # Record which client failed, then let the error propagate.
            logger.error(f"Error occurred in client {client_id}: {e}")
            raise
        return client

    return client_fn

Saving the server model for each round.

    @hydra.main(config_path="conf", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
    """Run the federated simulation and checkpoint the global model each round.

    Steps:
      1. Build the train/test loaders and instantiate the model from ``cfg``.
      2. Run a Flower simulation with a FedAvg-based strategy that writes the
         aggregated weights to ``fed_models/multiclass/model_round_<n>.pth``
         after every successful aggregation.
      3. Pickle the returned history into Hydra's runtime output directory.

    Args:
        cfg: Hydra configuration; this function reads ``data_config``,
            ``fed_config``, ``model`` and ``config_fit`` from it.
    """
    print(OmegaConf.to_yaml(cfg))
    print("config: ", cfg)
    train_loaders, test_loaders = prepare_data_loaders(folder_name=cfg.data_config.folder_name,
                                                       clients=cfg.fed_config.clients,
                                                       class_name=cfg.data_config.class_name,
                                                       num_samples=cfg.data_config.num_samples,
                                                       n_features=cfg.data_config.n_features)
    #
    print("cfg", cfg.model)
    model = instantiate(cfg.model)
    print("model", model)

    save_path = HydraConfig.get().runtime.output_dir
    client_fn = generate_client_fn(model=model, train_loaders=train_loaders,
                                   test_loaders=test_loaders)

    class SaveModelStrategy(fl.server.strategy.FedAvg):
        """FedAvg variant that saves the aggregated model after each round."""

        def __init__(self, model, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Model used as a container to turn aggregated arrays back into a
            # state_dict before saving.
            self.model = model

        def aggregate_fit(
                self,
                server_round,
                results,
                failures,
        ):
            """Aggregate model weights using weighted average and store checkpoint"""

            aggregated_parameters, aggregated_metrics = super().aggregate_fit(server_round, results, failures)
            if aggregated_parameters is not None:
                print(f"Saving round {server_round} aggregated_parameters...")
                aggregated_ndarrays = fl.common.parameters_to_ndarrays(aggregated_parameters)
                params_dict = zip(self.model.state_dict().keys(), aggregated_ndarrays)
                state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
                self.model.load_state_dict(state_dict, strict=True)

                checkpoint_path = Path(f"fed_models/multiclass/model_round_"
                                       f"{server_round}.pth")
                # torch.save does not create missing directories; without this
                # the first round fails on a fresh checkout.
                checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
                torch.save(self.model.state_dict(), str(checkpoint_path))
                time.sleep(5)
            return aggregated_parameters, aggregated_metrics

    strategy = SaveModelStrategy(
        model=model,
        fraction_fit=0.1,
        min_fit_clients=cfg.fed_config.num_clients_per_round_fit,
        fraction_evaluate=0.1,
        min_evaluate_clients=cfg.fed_config.num_clients_per_round_eval,
        min_available_clients=cfg.fed_config.num_clients,
        on_fit_config_fn=get_on_fit_config(cfg.config_fit),
        # Server-side evaluation uses the first test loader.
        evaluate_fn=get_evaluate_server_fn(model=model, test_loader=test_loaders[0]),
    )

    history = fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=cfg.fed_config.num_clients,
        config=fl.server.ServerConfig(
            num_rounds=cfg.fed_config.num_rounds
        ),
        strategy=strategy,  # our strategy of choice
        client_resources={
            "num_cpus": 2,
            "num_gpus": 0.0,
        },
    )
    print("history: ", history)

    # Persist the simulation history next to the other Hydra run outputs.
    results_path = Path(save_path) / f"fed_results_{cfg.data_config.class_name}.pkl"
    results = {"history": history, "anythingelse": "here"}

    with open(str(results_path), "wb") as h:
        pickle.dump(results, h, protocol=pickle.HIGHEST_PROTOCOL)
@kalkite kalkite added the question Further information is requested label Mar 25, 2024
@adam-narozniak
Copy link
Member

Hi, you can apply any custom logic in the client's methods, e.g. in `fit`. For example, you can save the results and the model state once you're done with fitting.

@kalkite
Copy link
Author

kalkite commented Mar 26, 2024

Hi @adam-narozniak, thank you for your response. I tried to save the client models, but I am only able to save the model from the last round of communication: each new file written in `fit` replaces the previous one. However, my goal is to save the client model for each round.

    def fit(self, parameters, config):
        """Train locally for one round and write the client's weights to disk.

        The checkpoint file name depends only on the client id, so each call
        for a given client writes to the same path.
        """
        print(f"[Client {int(self.client_id) + 1}] fit, config: {config}")
        self.set_parameters(parameters)

        learning_rate = config['learning_rate']
        optimizer_name = config['optimizer']
        # Resolve the optimizer class by name, e.g. "SGD" -> optim.SGD.
        optimizer_cls = getattr(optim, optimizer_name)
        local_optimizer = optimizer_cls(self.model.parameters(), lr=learning_rate)
        num_epochs = config['epochs']

        train_loss, train_accuracy = fed_train(
            model=self.model,
            train_loader=self.train_loader,
            optimizer=local_optimizer,
            epochs=num_epochs,
        )
        print(f"Client {int(self.client_id) + 1} train_loss: {train_loss}, train_accuracy: {train_accuracy}")

        # Persist this client's trained weights.
        trained_state = self.model.state_dict()
        torch.save(trained_state, f"client_{int(self.client_id) + 1}_model.pth")

        time.sleep(5)
        updated_weights = self.get_parameters({})
        return updated_weights, len(self.train_loader), {}

@adam-narozniak
Copy link
Member

Hi @kalkite,
You're close. Now you can also send the round_id in the config. With a simple FedAvg I'd do it as follows:

# Send each client a fit config that carries the value the strategy passes in
# (the server round, per the Flower FedAvg docs) under "round_id".
FedAvg(
    other_params,
    on_fit_config_fn=lambda server_round: {"round_id": server_round},
)

but since you're already sending e.g. the learning rate, it's just one more entry to add to the config returned by your function.

    def fit(self, parameters, config):
        """Train locally for one round and save a per-round client checkpoint.

        In addition to ``'learning_rate'``, ``'optimizer'`` and ``'epochs'``,
        ``config`` must provide ``'round_id'``. It is embedded in the
        checkpoint file name so that each round gets its own file.
        """
        print(f"[Client {int(self.client_id) + 1}] fit, config: {config}")
        self.set_parameters(parameters)

        learning_rate = config['learning_rate']
        # Round identifier supplied through the fit config.
        round_id = config['round_id']
        optimizer_name = config['optimizer']
        # Resolve the optimizer class by name, e.g. "SGD" -> optim.SGD.
        optimizer_cls = getattr(optim, optimizer_name)
        local_optimizer = optimizer_cls(self.model.parameters(), lr=learning_rate)
        num_epochs = config['epochs']

        train_loss, train_accuracy = fed_train(
            model=self.model,
            train_loader=self.train_loader,
            optimizer=local_optimizer,
            epochs=num_epochs,
        )
        print(f"Client {int(self.client_id) + 1} train_loss: {train_loss}, train_accuracy: {train_accuracy}")

        # One file per (client, round) pair.
        trained_state = self.model.state_dict()
        torch.save(trained_state, f"client_{int(self.client_id) + 1}_round_{round_id}_model.pth")

        time.sleep(5)
        updated_weights = self.get_parameters({})
        return updated_weights, len(self.train_loader), {}

@adam-narozniak adam-narozniak self-assigned this Apr 2, 2024
@kalkite
Copy link
Author

kalkite commented Apr 12, 2024

Is this round_id of the client taken from the server round? Something like this?

def get_on_fit_config(client_configs):
    """Build the function that produces the per-round fit config.

    The returned callable takes the server round and returns a dict with the
    ``learning_rate``, ``optimizer`` and ``epochs`` entries of
    ``client_configs`` plus ``round_id`` set to that round. The entries are
    looked up on every call, not when the factory runs.
    """
    forwarded_keys = ("learning_rate", "optimizer", "epochs")

    def fit_config_fn(server_round: int):
        fit_config = {key: client_configs[key] for key in forwarded_keys}
        fit_config["round_id"] = server_round
        return fit_config

    return fit_config_fn

I keep on_fit_config_fn = get_on_fit_config(cfg.config_fit)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
question Further information is requested
Projects
None yet
Development

No branches or pull requests

2 participants