FedProx example #564

Open · wants to merge 6 commits into master

Changes from 1 commit
8 changes: 8 additions & 0 deletions examples/mnist-pytorch-fedprox/bin/build.sh
@@ -0,0 +1,8 @@
#!/bin/bash
set -e

# Init seed
client/entrypoint init_seed

# Make compute package
tar -czvf package.tgz client
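
To verify what build.sh produced, the seed model can be reloaded with the same numpyhelper the entrypoint uses. A minimal sketch, assuming it runs from the example root with the venv active:

    from fedn.utils.helpers.helpers import get_helper

    # Reload the seed written by `client/entrypoint init_seed` above.
    helper = get_helper('numpyhelper')
    params = helper.load('seed.npz')
    print([p.shape for p in params])  # one array per state_dict entry of the model
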
21 changes: 21 additions & 0 deletions examples/mnist-pytorch-fedprox/bin/get_data
@@ -0,0 +1,21 @@
#!./.mnist-pytorch/bin/python
import os

import fire
import torchvision


def get_data(out_dir='data'):
    # Make dir if necessary
    if not os.path.exists(out_dir):
        os.mkdir(out_dir)

    # Download data (note: ToTensor must be instantiated, hence the parentheses)
    torchvision.datasets.MNIST(
        root=f'{out_dir}/train', transform=torchvision.transforms.ToTensor(), train=True, download=True)
    torchvision.datasets.MNIST(
        root=f'{out_dir}/test', transform=torchvision.transforms.ToTensor(), train=False, download=True)


if __name__ == '__main__':
    fire.Fire(get_data)
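
As a quick check that the download succeeded, the datasets can be reopened without the fire CLI. A minimal sketch, assuming the default out_dir='data':

    import torchvision

    # Reopen the splits fetched by bin/get_data (no download, just a read).
    train = torchvision.datasets.MNIST(root='data/train', train=True)
    test = torchvision.datasets.MNIST(root='data/test', train=False)
    print(len(train), len(test))  # expected: 60000 10000
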
10 changes: 10 additions & 0 deletions examples/mnist-pytorch-fedprox/bin/init_venv.sh
@@ -0,0 +1,10 @@
#!/bin/bash
set -e

# Init venv
python3.10 -m venv .mnist-pytorch

# Pip deps
.mnist-pytorch/bin/pip install --upgrade pip
.mnist-pytorch/bin/pip install -e ../../fedn
.mnist-pytorch/bin/pip install -r requirements.txt
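
To confirm the editable FEDn install succeeded, import the helper module with the venv's interpreter. A minimal sketch:

    # Run with .mnist-pytorch/bin/python to check the editable FEDn install.
    from fedn.utils.helpers.helpers import get_helper

    print(get_helper('numpyhelper'))  # should print a helper instance, not raise
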
51 changes: 51 additions & 0 deletions examples/mnist-pytorch-fedprox/bin/split_data
@@ -0,0 +1,51 @@
#!./.mnist-pytorch/bin/python
import os
from math import floor

import fire
import torch
import torchvision


def splitset(dataset, parts):
    n = dataset.shape[0]
    local_n = floor(n / parts)
    result = []
    for i in range(parts):
        result.append(dataset[i * local_n: (i + 1) * local_n])
    return result


def split(out_dir='data', n_splits=2):
    # Make dir
    if not os.path.exists(f'{out_dir}/clients'):
        os.mkdir(f'{out_dir}/clients')

    # Load and convert to dict
    train_data = torchvision.datasets.MNIST(
        root=f'{out_dir}/train', transform=torchvision.transforms.ToTensor(), train=True)
    test_data = torchvision.datasets.MNIST(
        root=f'{out_dir}/test', transform=torchvision.transforms.ToTensor(), train=False)
    data = {
        'x_train': splitset(train_data.data, n_splits),
        'y_train': splitset(train_data.targets, n_splits),
        'x_test': splitset(test_data.data, n_splits),
        'y_test': splitset(test_data.targets, n_splits),
    }

    # Make splits
    for i in range(n_splits):
        subdir = f'{out_dir}/clients/{str(i + 1)}'
        if not os.path.exists(subdir):
            os.mkdir(subdir)
        torch.save({
            'x_train': data['x_train'][i],
            'y_train': data['y_train'][i],
            'x_test': data['x_test'][i],
            'y_test': data['y_test'][i],
        },
            f'{subdir}/mnist.pt')


if __name__ == '__main__':
    fire.Fire(split)
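
After running bin/split_data, each client directory holds one torch-serialized dict. A minimal sketch to inspect the first client's shard, assuming the default paths above:

    import torch

    # Load the shard written for client 1 by bin/split_data.
    shard = torch.load('data/clients/1/mnist.pt')
    for key, tensor in shard.items():
        print(key, tuple(tensor.shape))
    # With the default n_splits=2: x_train is roughly (30000, 28, 28).
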
245 changes: 245 additions & 0 deletions examples/mnist-pytorch-fedprox/client/entrypoint
@@ -0,0 +1,245 @@
#!./.mnist-pytorch/bin/python
import collections
import copy
import math
import os

import docker
import fire
import torch

from fedn.utils.helpers.helpers import get_helper, save_metadata, save_metrics

HELPER_MODULE = 'numpyhelper'
helper = get_helper(HELPER_MODULE)

NUM_CLASSES = 10


def _get_data_path():
    """ For test automation using docker-compose. """
    # Figure out FEDn client number from container name
    client = docker.from_env()
    container = client.containers.get(os.environ['HOSTNAME'])
    number = container.name[-1]

    # Return data path
    return f"/var/data/clients/{number}/mnist.pt"


def compile_model():
    """ Compile the pytorch model.

    :return: The compiled model.
    :rtype: torch.nn.Module
    """
    class Net(torch.nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.fc1 = torch.nn.Linear(784, 64)
            self.fc2 = torch.nn.Linear(64, 32)
            self.fc3 = torch.nn.Linear(32, 10)

        def forward(self, x):
            x = torch.nn.functional.relu(self.fc1(x.reshape(x.size(0), 784)))
            x = torch.nn.functional.dropout(x, p=0.5, training=self.training)
            x = torch.nn.functional.relu(self.fc2(x))
            x = torch.nn.functional.log_softmax(self.fc3(x), dim=1)
            return x

    return Net()


def load_data(data_path, is_train=True):
    """ Load data from disk.

    :param data_path: Path to data file.
    :type data_path: str
    :param is_train: Whether to load training or test data.
    :type is_train: bool
    :return: Tuple of data and labels.
    :rtype: tuple
    """
    if data_path is None:
        data = torch.load(_get_data_path())
    else:
        data = torch.load(data_path)

    if is_train:
        X = data['x_train']
        y = data['y_train']
    else:
        X = data['x_test']
        y = data['y_test']

    # Normalize
    X = X / 255

    return X, y


def save_parameters(model, out_path):
    """ Save model parameters to file.

    :param model: The model to serialize.
    :type model: torch.nn.Module
    :param out_path: The path to save to.
    :type out_path: str
    """
    parameters_np = [val.cpu().numpy() for _, val in model.state_dict().items()]
    helper.save(parameters_np, out_path)


def load_parameters(model_path):
    """ Load model parameters from file and populate model.

    :param model_path: The path to load from.
    :type model_path: str
    :return: The loaded model.
    :rtype: torch.nn.Module
    """
    model = compile_model()
    parameters_np = helper.load(model_path)

    # Re-key the flat parameter list; assumes the helper preserved state_dict order.
    params_dict = zip(model.state_dict().keys(), parameters_np)
    state_dict = collections.OrderedDict({key: torch.tensor(x) for key, x in params_dict})
    model.load_state_dict(state_dict, strict=True)
    return model


def init_seed(out_path='seed.npz'):
    """ Initialize seed model and save it to file.

    :param out_path: The path to save the seed model to.
    :type out_path: str
    """
    # Init and save
    model = compile_model()
    save_parameters(model, out_path)


def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1, lr=0.01, mu=3):
    """ Complete a model update.

    Load model parameters from in_model_path (managed by the FEDn client),
    perform a model update, and write updated parameters
    to out_model_path (picked up by the FEDn client).

    :param in_model_path: The path to the input model.
    :type in_model_path: str
    :param out_model_path: The path to save the output model to.
    :type out_model_path: str
    :param data_path: The path to the data file.
    :type data_path: str
    :param batch_size: The batch size to use.
    :type batch_size: int
    :param epochs: The number of epochs to train.
    :type epochs: int
    :param lr: The learning rate to use.
    :type lr: float
    :param mu: The FedProx proximal term weight.
    :type mu: float
    """
    # Load data
    x_train, y_train = load_data(data_path)

    # Load parameters and initialize model
    model = load_parameters(in_model_path)
    # Keep a frozen copy of the incoming global weights; the FedProx
    # penalty below is measured against this snapshot.
    global_model = copy.deepcopy(model)

    # Train
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    n_batches = int(math.ceil(len(x_train) / batch_size))
    criterion = torch.nn.NLLLoss()
    for e in range(epochs):  # epoch loop
        for b in range(n_batches):  # batch loop
            # Retrieve current batch
            batch_x = x_train[b * batch_size:(b + 1) * batch_size]
            batch_y = y_train[b * batch_size:(b + 1) * batch_size]
            # Train on batch
            optimizer.zero_grad()
            outputs = model(batch_x)

            # FedProx proximal term: (mu / 2) * ||w - w_global||^2,
            # accumulated layer by layer (squared L2 norm, as in the paper).
            proximal_term = 0.0
            for w, w_t in zip(model.parameters(), global_model.parameters()):
                proximal_term += torch.square((w - w_t).norm(2))

            loss = criterion(outputs, batch_y) + (mu / 2) * proximal_term

            loss.backward()
            optimizer.step()
            # Log
            if b % 100 == 0:
                print(
                    f"Epoch {e}/{epochs-1} | Batch: {b}/{n_batches-1} | Loss: {loss.item()}")

    # Metadata needed for aggregation server side
    metadata = {
        # num_examples are mandatory
        'num_examples': len(x_train),
        'batch_size': batch_size,
        'epochs': epochs,
        'lr': lr
    }

    # Save JSON metadata file (mandatory)
    save_metadata(metadata, out_model_path)

    # Save model update (mandatory)
    save_parameters(model, out_model_path)


def validate(in_model_path, out_json_path, data_path=None):
    """ Validate model.

    :param in_model_path: The path to the input model.
    :type in_model_path: str
    :param out_json_path: The path to save the output JSON to.
    :type out_json_path: str
    :param data_path: The path to the data file.
    :type data_path: str
    """
    # Load data
    x_train, y_train = load_data(data_path)
    x_test, y_test = load_data(data_path, is_train=False)

    # Load model
    model = load_parameters(in_model_path)
    model.eval()

    # Evaluate
    criterion = torch.nn.NLLLoss()
    with torch.no_grad():
        train_out = model(x_train)
        training_loss = criterion(train_out, y_train)
        training_accuracy = torch.sum(torch.argmax(
            train_out, dim=1) == y_train) / len(train_out)
        test_out = model(x_test)
        test_loss = criterion(test_out, y_test)
        test_accuracy = torch.sum(torch.argmax(
            test_out, dim=1) == y_test) / len(test_out)

    # JSON schema
    report = {
        "training_loss": training_loss.item(),
        "training_accuracy": training_accuracy.item(),
        "test_loss": test_loss.item(),
        "test_accuracy": test_accuracy.item(),
    }
    print("validation data: ", report)

    # Save JSON
    save_metrics(report, out_json_path)


if __name__ == '__main__':
    fire.Fire({
        'init_seed': init_seed,
        'train': train,
        'validate': validate,
    })
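
The heart of this example is the proximal term in train(), which implements the FedProx objective h_k(w) = F_k(w) + (mu / 2) * ||w - w_global||^2 from Li et al. The standalone sketch below shows the penalty in isolation; the toy model and mu value are illustrative assumptions, not part of the compute package:

    import copy

    import torch

    # Toy stand-ins for the local and global models.
    model = torch.nn.Linear(4, 2)
    global_model = copy.deepcopy(model)

    # Perturb the local weights so the penalty is non-zero.
    with torch.no_grad():
        for w in model.parameters():
            w.add_(0.1)

    # Same accumulation as in train(): squared L2 distance, layer by layer.
    mu = 3.0
    proximal_term = sum(torch.square((w - w_t).norm(2))
                        for w, w_t in zip(model.parameters(), global_model.parameters()))
    print((mu / 2) * proximal_term)  # scalar penalty added to the task loss
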
5 changes: 5 additions & 0 deletions examples/mnist-pytorch-fedprox/client/fedn.yaml
@@ -0,0 +1,5 @@
entry_points:
    train:
        command: /venv/bin/python entrypoint train $ENTRYPOINT_OPTS
    validate:
        command: /venv/bin/python entrypoint validate $ENTRYPOINT_OPTS
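
A minimal sketch to confirm the package configuration parses and exposes both entry points, assuming PyYAML is available and the file sits at client/fedn.yaml:

    import yaml

    # Parse the compute-package config shown above.
    with open('client/fedn.yaml') as f:
        config = yaml.safe_load(f)
    print(sorted(config['entry_points']))  # expected: ['train', 'validate']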