Using DataLoader in the mnist-pytorch example #584

Open · wants to merge 1 commit into base: master
129 changes: 86 additions & 43 deletions examples/mnist-pytorch/client/data.py
@@ -3,11 +3,60 @@

 import torch
 import torchvision
+from torch.utils.data import Dataset
+from torchvision import datasets, transforms

 dir_path = os.path.dirname(os.path.realpath(__file__))
 abs_path = os.path.abspath(dir_path)


+# Function to save data in chunks
+def save_chunks(dataset, chunk_size, folder):
+    if not os.path.exists(folder):
+        os.makedirs(folder)
+    num_chunks = len(dataset) // chunk_size + (1 if len(dataset) % chunk_size != 0 else 0)
+
+    for i in range(num_chunks):
+        start = i * chunk_size
+        end = min((i + 1) * chunk_size, len(dataset))
+        images = []
+        labels = []
+        # Extract images and labels
+        for j in range(start, end):
+            image, label = dataset[j]
+            images.append(image)
+            labels.append(label)
+
+        # Stack images into a single tensor and convert labels to a tensor
+        images = torch.stack(images)
+        labels = torch.tensor(labels, dtype=torch.long)
+
+        # Save the tuple of images and labels
+        torch.save((images, labels), os.path.join(folder, f'chunk_{i}.pt'))
+
+
+class SingleChunkDataset(Dataset):
+    def __init__(self, chunk_file):
+        """
+        Initialize the dataset with the path to a specific chunk file.
+        """
+        # Load the data from the specified chunk file
+        self.data = torch.load(chunk_file)
+        self.images = self.data[0]  # Images tensor
+        self.labels = self.data[1]  # Labels tensor
+
+    def __len__(self):
+        """
+        Return the total number of samples in the chunk.
+        """
+        return len(self.labels)
+
+    def __getitem__(self, idx):
+        """
+        Retrieve the image and label at the specified index.
+        """
+        image = self.images[idx]
+        label = self.labels[idx]
+        return image, label
+
+
 def get_data(out_dir='data'):
     # Make dir if necessary
     if not os.path.exists(out_dir):
@@ -29,34 +78,29 @@ def load_data(data_path, is_train=True):
     :type data_path: str
     :param is_train: Whether to load training or test data.
     :type is_train: bool
-    :return: Tuple of data and labels.
-    :rtype: tuple
+    :return: Dataset holding the client's data chunk.
+    :rtype: SingleChunkDataset
     """
     if data_path is None:
-        data_path = os.environ.get("FEDN_DATA_PATH", abs_path+'/data/clients/1/mnist.pt')
+        data_path = os.environ.get("FEDN_DATA_PATH")
+
+    data_chunk = os.environ.get("FEDN_DATA_CHUNK")

-    data = torch.load(data_path)
-
     if is_train:
-        X = data['x_train']
-        y = data['y_train']
+        # Create an instance of the dataset for a specific chunk
+        data_path += '/data/train_chunks/chunk_' + data_chunk + '.pt'
+        chunk_dataset = SingleChunkDataset(data_path)
     else:
-        X = data['x_test']
-        y = data['y_test']
+        data_path += '/data/test_chunks/chunk_' + data_chunk + '.pt'
+        chunk_dataset = SingleChunkDataset(data_path)

-    # Normalize
-    X = X / 255
-
-    return X, y
+    return chunk_dataset


 def splitset(dataset, parts):
     n = dataset.shape[0]
     local_n = floor(n/parts)
     result = []
     for i in range(parts):
         result.append(dataset[i*local_n: (i+1)*local_n])
     return result


 def split(out_dir='data'):
@@ -67,30 +111,29 @@ def split(out_dir='data'):
     if not os.path.exists(f'{out_dir}/clients'):
         os.mkdir(f'{out_dir}/clients')

-    # Load and convert to dict
-    train_data = torchvision.datasets.MNIST(
-        root=f'{out_dir}/train', transform=torchvision.transforms.ToTensor, train=True)
-    test_data = torchvision.datasets.MNIST(
-        root=f'{out_dir}/test', transform=torchvision.transforms.ToTensor, train=False)
-    data = {
-        'x_train': splitset(train_data.data, n_splits),
-        'y_train': splitset(train_data.targets, n_splits),
-        'x_test': splitset(test_data.data, n_splits),
-        'y_test': splitset(test_data.targets, n_splits),
-    }
-
-    # Make splits
-    for i in range(n_splits):
-        subdir = f'{out_dir}/clients/{str(i+1)}'
-        if not os.path.exists(subdir):
-            os.mkdir(subdir)
-        torch.save({
-            'x_train': data['x_train'][i],
-            'y_train': data['y_train'][i],
-            'x_test': data['x_test'][i],
-            'y_test': data['y_test'][i],
-        },
-            f'{subdir}/mnist.pt')
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.5,), (0.5,))
+    ])
+    mnist_trainset = torchvision.datasets.MNIST(
+        root=f'{out_dir}/train', transform=transform, train=True)
+    mnist_testset = torchvision.datasets.MNIST(
+        root=f'{out_dir}/test', transform=transform, train=False)
+
+    # Define chunk size and folders
+    chunk_size = int(floor(len(mnist_trainset) / n_splits))  # Size of each training chunk
+    train_folder = 'data/train_chunks'
+    test_folder = 'data/test_chunks'
+
+    # Save chunks
+    save_chunks(mnist_trainset, chunk_size, train_folder)
+    chunk_size = int(floor(len(mnist_testset) / n_splits))  # Size of each test chunk
+    save_chunks(mnist_testset, chunk_size, test_folder)


 if __name__ == '__main__':
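A minimal, self-contained sketch of the round-trip the new `data.py` implements: chunk a dataset to disk with `save_chunks`, then serve one chunk through `SingleChunkDataset` and a `DataLoader`. The fake tensors and `n_splits = 2` are illustrative assumptions, not values taken from the diff:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from data import save_chunks, SingleChunkDataset  # the helpers added in this PR

# Stand-in for MNIST: 100 fake images and labels (shapes illustrative).
fake = TensorDataset(torch.randn(100, 1, 28, 28), torch.randint(0, 10, (100,)))

n_splits = 2
chunk_size = len(fake) // n_splits
save_chunks(fake, chunk_size, 'data/train_chunks')  # writes chunk_0.pt and chunk_1.pt

# Each client then loads only its own chunk, selected via FEDN_DATA_CHUNK.
chunk = SingleChunkDataset('data/train_chunks/chunk_0.pt')
loader = DataLoader(chunk, batch_size=32, shuffle=True)
for images, labels in loader:
    print(images.shape, labels.shape)  # torch.Size([32, 1, 28, 28]) torch.Size([32])
```

Note that the chunk files are zero-indexed, which is why the compose file below assigns chunks 0 and 1 to the two clients.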
15 changes: 7 additions & 8 deletions examples/mnist-pytorch/client/train.py
@@ -7,6 +7,7 @@
 from model import load_parameters, save_parameters

 from fedn.utils.helpers.helpers import save_metadata
+from torch.utils.data import DataLoader

 dir_path = os.path.dirname(os.path.realpath(__file__))
 sys.path.append(os.path.abspath(dir_path))
@@ -33,20 +34,18 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1
     :type lr: float
     """
     # Load data
-    x_train, y_train = load_data(data_path)
+    chunk_dataset = load_data(data_path)
+    chunk_loader = DataLoader(chunk_dataset, batch_size=batch_size, shuffle=True)

     # Load parameters and initialize model
     model = load_parameters(in_model_path)

     # Train
     optimizer = torch.optim.SGD(model.parameters(), lr=lr)
-    n_batches = int(math.ceil(len(x_train) / batch_size))
+    n_batches = int(math.ceil(len(chunk_dataset) / batch_size))
     criterion = torch.nn.NLLLoss()
     for e in range(epochs):  # epoch loop
-        for b in range(n_batches):  # batch loop
-            # Retrieve current batch
-            batch_x = x_train[b * batch_size:(b + 1) * batch_size]
-            batch_y = y_train[b * batch_size:(b + 1) * batch_size]
+        for b, (batch_x, batch_y) in enumerate(chunk_loader):  # batch loop
             # Train on batch
             optimizer.zero_grad()
             outputs = model(batch_x)
@@ -61,7 +60,7 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1
     # Metadata needed for aggregation server side
     metadata = {
         # num_examples are mandatory
-        'num_examples': len(x_train),
+        'num_examples': len(chunk_dataset),
         'batch_size': batch_size,
         'epochs': epochs,
         'lr': lr
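A side note on the loop change: the batch count that `train.py` still computes with `math.ceil` is also available directly as `len(chunk_loader)`, and the two agree as long as the loader keeps its default `drop_last=False`. A quick standalone check (sizes illustrative):

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(100, 4), torch.randint(0, 10, (100,)))
loader = DataLoader(dataset, batch_size=32)

# 100 examples at batch_size 32 -> 4 batches (32 + 32 + 32 + 4).
assert len(loader) == int(math.ceil(len(dataset) / 32))
```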
31 changes: 18 additions & 13 deletions examples/mnist-pytorch/client/validate.py
@@ -4,14 +4,14 @@
 import torch
 from data import load_data
 from model import load_parameters
+from torch.utils.data import DataLoader

 from fedn.utils.helpers.helpers import save_metrics

 dir_path = os.path.dirname(os.path.realpath(__file__))
 sys.path.append(os.path.abspath(dir_path))


-def validate(in_model_path, out_json_path, data_path=None):
+def validate(in_model_path, out_json_path, data_path=None, batch_size=32):
     """ Validate model.

     :param in_model_path: The path to the input model.
@@ -21,9 +21,12 @@ def validate(in_model_path, out_json_path, data_path=None):
     :param data_path: The path to the data file.
     :type data_path: str
+    :param batch_size: Batch size used for evaluation.
+    :type batch_size: int
     """
     # Load data
-    x_train, y_train = load_data(data_path)
-    x_test, y_test = load_data(data_path, is_train=False)
+    chunk_dataset = load_data(data_path)
+    chunk_loader = DataLoader(chunk_dataset, batch_size=batch_size, shuffle=False)
+
+    chunk_dataset_test = load_data(data_path, is_train=False)
+    chunk_loader_test = DataLoader(chunk_dataset_test, batch_size=batch_size, shuffle=False)

     # Load model
     model = load_parameters(in_model_path)
@@ -32,14 +35,16 @@ def validate(in_model_path, out_json_path, data_path=None):
     # Evaluate
     criterion = torch.nn.NLLLoss()
     with torch.no_grad():
-        train_out = model(x_train)
-        training_loss = criterion(train_out, y_train)
-        training_accuracy = torch.sum(torch.argmax(
-            train_out, dim=1) == y_train) / len(train_out)
-        test_out = model(x_test)
-        test_loss = criterion(test_out, y_test)
-        test_accuracy = torch.sum(torch.argmax(
-            test_out, dim=1) == y_test) / len(test_out)
+        # Accumulate size-weighted loss and correct predictions over all batches
+        training_loss, training_correct = torch.tensor(0.0), torch.tensor(0)
+        for x_train, y_train in chunk_loader:
+            train_out = model(x_train)
+            training_loss += criterion(train_out, y_train) * len(y_train)
+            training_correct += torch.sum(torch.argmax(train_out, dim=1) == y_train)
+        training_loss /= len(chunk_dataset)
+        training_accuracy = training_correct / len(chunk_dataset)
+
+        test_loss, test_correct = torch.tensor(0.0), torch.tensor(0)
+        for x_test, y_test in chunk_loader_test:
+            test_out = model(x_test)
+            test_loss += criterion(test_out, y_test) * len(y_test)
+            test_correct += torch.sum(torch.argmax(test_out, dim=1) == y_test)
+        test_loss /= len(chunk_dataset_test)
+        test_accuracy = test_correct / len(chunk_dataset_test)

     # JSON schema
     report = {
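The evaluation loops above weight each batch by its size before dividing by the dataset length, so the reported loss and accuracy are exact dataset-level averages even when the last batch is smaller than `batch_size`. The same accumulation pattern as a reusable helper — a sketch, assuming `criterion` returns a per-batch mean, as `NLLLoss` does by default:

```python
import torch

def evaluate(model, loader, criterion, n_examples):
    """Return dataset-level average loss and accuracy, weighting batches by size."""
    loss_sum = torch.tensor(0.0)
    correct = torch.tensor(0)
    with torch.no_grad():
        for x, y in loader:
            out = model(x)
            loss_sum += criterion(out, y) * len(y)   # undo the per-batch mean
            correct += (out.argmax(dim=1) == y).sum()
    return (loss_sum / n_examples).item(), (correct / n_examples).item()
```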
8 changes: 6 additions & 2 deletions examples/mnist-pytorch/docker-compose.override.yaml
@@ -17,7 +17,9 @@ services:
       service: client
     environment:
       <<: *defaults
-      FEDN_DATA_PATH: /app/package/data/clients/1/mnist.pt
+      FEDN_DATA_PATH: /app/package
+      FEDN_DATA_CHUNK: 0
     deploy:
       replicas: 1
     volumes:
@@ -29,7 +31,9 @@
       service: client
     environment:
       <<: *defaults
-      FEDN_DATA_PATH: /app/package/data/clients/2/mnist.pt
+      FEDN_DATA_PATH: /app/package
+      FEDN_DATA_CHUNK: 1
     deploy:
       replicas: 1
     volumes:
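With the compose settings above, each client resolves its chunk file purely from the two environment variables. Tracing the string concatenation in `load_data`, with the first client's values as an example:

```python
import os

# Values as set for the first client in docker-compose.override.yaml.
os.environ["FEDN_DATA_PATH"] = "/app/package"
os.environ["FEDN_DATA_CHUNK"] = "0"

data_path = os.environ.get("FEDN_DATA_PATH")
data_chunk = os.environ.get("FEDN_DATA_CHUNK")
train_file = data_path + '/data/train_chunks/chunk_' + data_chunk + '.pt'
print(train_file)  # /app/package/data/train_chunks/chunk_0.pt
```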