
After running the train function, the model parameters are not updated, so the final evaluation metrics output by the test function are the same as at the initial stage #3164

Open
CHENxx23 opened this issue Mar 21, 2024 · 4 comments
Labels
bug Something isn't working

Comments

@CHENxx23

Describe the bug

In my train function, I save the best model parameters to checkpoint.pth as shown below. However, the model I get in the test function still has the initial-stage parameters, not the updated ones, so the final training output is the same as at the initial stage. Is the problem in my set_parameters() and get_parameters() functions, or in this way of saving the model parameters? How can I solve it?

Steps/Code to Reproduce

        early_stopping(vali_loss, self.model, path)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    best_model_path = path + '/' + 'checkpoint.pth'
    self.model.load_state_dict(torch.load(best_model_path))

    return self.model

def test(self, model):
    test_data, test_loader = self._get_data(flag='test')
    train_data, train_loader = self._get_data(flag='train')
    test_steps = len(train_loader)
    criterion = self._select_criterion()
    preds = []
    trues = []
    inputx = []
    self.model.eval()

def set_parameters(self, parameters):
    params_dict = zip(self.model.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    # now replace the parameters
    self.model.load_state_dict(state_dict, strict=True)
    # print("Parameters set successfully:", state_dict)

def get_parameters(self, config):
    print(f"[Client {self.cid}] get_parameters")
    return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

def fit(self, parameters, config):
    print(f"[Client {self.cid}] fit, config: {config}")
    # copy parameters sent by the server into client's local model
    self.set_parameters(parameters)
    # Train model
    exp = Exp_Long_Term_Forecast(self.args)
    exp.train(self.setting)
    # return the model parameters to the server as well as extra info (number of training examples in this case)
    return (
        self.get_parameters({}),
        0,
        {},
    )

def evaluate(self, parameters, config):
    self.set_parameters(parameters)
    exp = Exp_Long_Term_Forecast(self.args)
    # loss, metrics = exp.test(self.setting, test=1)
    loss, metrics = exp.test(self.model)
    return loss, 0, metrics
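For reference, a minimal sketch of the round-trip a Flower NumPyClient normally performs, where one model instance is shared by set_parameters(), fit(), and get_parameters(), so the weights trained in fit() are the same ones serialized back to the server (this follows the pattern of Flower's PyTorch quickstart; train_fn and test_fn are placeholder helpers, not functions from this project):

from collections import OrderedDict

import flwr as fl
import torch

class FlowerClient(fl.client.NumPyClient):
    def __init__(self, model, trainloader, valloader):
        # One model instance, reused across all rounds
        self.model = model
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

    def set_parameters(self, parameters):
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        self.model.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        self.set_parameters(parameters)
        train_fn(self.model, self.trainloader)  # trains self.model in place
        # The second element is the number of local training examples
        return self.get_parameters({}), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        loss, metrics = test_fn(self.model, self.valloader)
        return float(loss), len(self.valloader.dataset), metrics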

Expected Results

I expect the test function to receive the model parameters produced by the train function, so that the final evaluation metrics differ from those of the initial stage.

Actual Results

DEBUG flwr 2024-03-20 23:38:35,516 | server.py:187 | evaluate_round 1 received 3 results and 0 failures
WARNING flwr 2024-03-20 23:38:35,516 | fedavg.py:273 | No evaluate_metrics_aggregation_fn provided
INFO flwr 2024-03-20 23:38:35,517 | server.py:153 | FL finished in 4911.190346067073
INFO flwr 2024-03-20 23:38:35,597 | app.py:226 | app_fit: losses_distributed [(1, nan)]
INFO flwr 2024-03-20 23:38:35,597 | app.py:227 | app_fit: metrics_distributed_fit {}
INFO flwr 2024-03-20 23:38:35,597 | app.py:228 | app_fit: metrics_distributed {}
INFO flwr 2024-03-20 23:38:35,597 | app.py:229 | app_fit: losses_centralized [(0, 0.7456194370291954), (1, 0.7456194370291954)]
INFO flwr 2024-03-20 23:38:35,597 | app.py:230 | app_fit: metrics_centralized {'MAE': [(0, 0.5941068), (1, 0.5941068)], 'MSE': [(0, 0.7456194), (1, 0.7456194)], 'RMSE': [(0, 0.86349255), (1, 0.86349255)]}

@CHENxx23 added the bug label Mar 21, 2024
@jafermarq
Contributor

Hi @CHENxx23, I see your fit() method is returning a 0 as its second element. The second element should be the number of datapoints on the client (so that a weighted average can be performed during aggregation). With a 0, if you are using FedAvg, you'll be dividing by zero. That might explain why you see this line in the log:

INFO flwr 2024-03-20 23:38:35,597 | app.py:226 | app_fit: losses_distributed [(1, nan)] # <-- see the NaN
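A quick way to see where the nan comes from (a toy sketch of weighted averaging, not Flower's actual aggregation code): if every client reports 0 examples, the denominator of the weighted average is 0.

import numpy as np

# Toy weighted average over (num_examples, loss) pairs, one per client.
def weighted_loss(results):
    total_examples = np.float64(sum(n for n, _ in results))
    return sum(n * loss for n, loss in results) / total_examples

print(weighted_loss([(0, 0.7456), (0, 0.7456), (0, 0.7456)]))  # nan
print(weighted_loss([(100, 0.7), (300, 0.8)]))                 # 0.775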

@CHENxx23
Author

Do you mean that because the second value I returned from the evaluate function was 0, the output appears as nan and the result is the same in every round?

@jafermarq
Contributor

Do you mean that because the second value I returned from the evaluate function was 0, the output appears as nan and the result is the same in every round?

Yes, that's highly likely to be one of the issues you are having. The return signature of a fit() method is: <parameters as a list of NumPy arrays>, <number_of_examples>, <metrics>. You can check this in all our examples, including this one for PyTorch:

return self.get_parameters(config={}), len(trainloader.dataset), {}

@CHENxx23
Author

I don't think that was the problem; I changed the code and still ended up with the same problem as before. Here is the part of the code I modified:
def fit(self, parameters, config):
    print(f"[Client {self.cid}] fit, config: {config}")
    self.set_parameters(parameters)
    exp = Exp_Long_Term_Forecast(self.args)
    exp.train(self.setting)
    train_dataset, _ = exp.get_data(flag='train')
    return self.get_parameters({}), len(train_dataset), {}

def get_data(self, flag):
    data_set, data_loader = data_provider(self.args, flag)
    return data_set, data_loader

def data_provider(args, flag, is_shuffle=None):
    data_set = Data(
        root_path=args.root_path,
        data_path=args.data_path,
        flag=flag,
        size=[args.seq_len, args.label_len, args.pred_len],
        features=args.features,
        target=args.target,
        timeenc=timeenc,
        percent=percent,
        freq=freq,
        seasonal_patterns=args.seasonal_patterns,
    )
    return data_set, data_loader
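One more thing that may be worth checking (a hedged sketch, under the assumption that Exp_Long_Term_Forecast builds and stores its own network as exp.model, which is not shown in the snippets above): fit() trains a freshly created exp rather than self.model, so get_parameters() may keep serializing the untouched self.model. Copying the weights between the two objects explicitly would make the update visible to the server:

def fit(self, parameters, config):
    print(f"[Client {self.cid}] fit, config: {config}")
    self.set_parameters(parameters)
    exp = Exp_Long_Term_Forecast(self.args)
    # Assumption: the experiment exposes its network as exp.model.
    # Push the server-provided weights into the experiment's model ...
    exp.model.load_state_dict(self.model.state_dict())
    exp.train(self.setting)
    # ... and pull the trained weights back into the client's model,
    # so that get_parameters() returns the updated state.
    self.model.load_state_dict(exp.model.state_dict())
    train_dataset, _ = exp.get_data(flag='train')
    return self.get_parameters({}), len(train_dataset), {}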
