Substra is an open source federated learning framework. It provides a flexible Python interface and a web app to run federated learning training at scale.

Substra's main usage is in production. It has already been deployed and used by hospitals and biotech companies (see the MELLODDY project for instance).

Yet Substra can also be used on a single machine on a virtually split dataset for two use cases (see the backend sketch after this list):

  • debugging code before launching experiments on a real network
  • performing FL simulations
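Both local modes are selected through the client's backend type. Below is a minimal sketch (the "docker" and "remote" lines are illustrative and commented out, the URL is a placeholder, and this example only uses "subprocess"):

from substra import Client

# "subprocess" runs compute tasks in the current Python process (fastest, good for debugging)
local_client = Client(backend_type="subprocess")

# "docker" runs each task in a container, closer to a deployed setup
# docker_client = Client(backend_type="docker")

# "remote" targets a deployed Substra platform
# remote_client = Client(url="https://substra.example.org", backend_type="remote")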

Substra was created by Owkin and is now hosted by the Linux Foundation for AI and Data.

SubstraFL example on FLamby TCGA BRCA dataset

This example illustrates the basic usage of SubstraFL and proposes a model training with federated learning, using the Federated Averaging strategy, on the TCGA BRCA dataset.

The objective of this example is to launch a federated learning experiment on the six centers provided by FLamby (called organizations in Substra), using the FedAvg strategy on the baseline model.

This example does not use the deployed platform of Substra and will run in local mode.

Requirements

To run this example locally, please make sure to download and unzip, in the same directory as this example, the assets needed to run it.

Please ensure that you have all the libraries installed; a requirements.txt file is included in the example.

You can run the command: pip install -r requirements.txt to install them.

Objective

This example will run a federated training on all 6 centers of the TCGA BRCA FLamby dataset.

This example shows how to interface Substra with FLamby.

This example runs in local mode, simulating a federated learning experiment.

Setup

We work with seven different organizations, defined by their IDs. Six organizations provide a FLamby dataset configuration. The last one provides the algorithm and will register the machine learning tasks.

Once these variables are defined, we can create our different Substra clients (one for each organization/center).

from substra import Client
from flamby.datasets import fed_tcga_brca

MODE = "subprocess"

# Create the substra clients
data_provider_clients = [Client(backend_type=MODE) for _ in range(fed_tcga_brca.NUM_CLIENTS)]
data_provider_clients = {client.organization_info().organization_id: client for client in data_provider_clients}

algo_provider_client = Client(backend_type=MODE)

# Store their IDs
DATA_PROVIDER_ORGS_ID = list(data_provider_clients.keys())

# The org id on which your computation tasks are registered
ALGO_ORG_ID = algo_provider_client.organization_info().organization_id

Data and metrics

Dataset registration

A Dataset is composed of an opener, which is a Python script with instructions on how to load the data from files into memory, and a description markdown file.

The Dataset object itself does not contain the data. The proper asset to access them is the datasample asset.

A datasample contains a local path to the data, and the key identifying the Dataset in order to have access to the proper opener.py file.

To interface with FLamby, in the simple case of a single machine and a single runtime where clients are simple Python objects, the opener and the associated datasample are used as a simple lookup table indicating which center lives in which client. Substra is a library built to be deployed on a real federated network, where the path to the data is given to the opener, which loads and reads the data in a way personalized to the premises of the organization that owns the data.

As we directly load a torch dataset from FLamby, the folders parameter is unused, and the path that would usually lead to the data points to an empty folder in this example.
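For illustration, here is a minimal sketch of what such an opener could look like, assuming the substratools Opener API with its get_data and fake_data methods (the center index is a placeholder; the actual opener_train_*.py files shipped with the assets may differ):

import substratools as tools

class FLambyTrainOpener(tools.Opener):
    def get_data(self, folders):
        # folders is unused here: instead of reading files from disk, we
        # return the configuration needed to instantiate the FLamby torch
        # dataset for this center.
        return {"center": 0, "train": True}

    def fake_data(self, n_samples=None):
        # No fake data is needed for this example.
        pass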

As data cannot be seen once registered on the platform, we set a Permissions object for each asset, defining its access rights to the different organizations.

import pathlib

from substra.sdk.schemas import DatasetSpec
from substra.sdk.schemas import DataSampleSpec
from substra.sdk.schemas import Permissions

assets_directory = pathlib.Path.cwd() / "assets"
empty_path = assets_directory / "empty_datasamples"

permissions_dataset = Permissions(public=False, authorized_ids=[ALGO_ORG_ID])

train_dataset_keys = {}
test_dataset_keys = {}

train_datasample_keys = {}
test_datasample_keys = {}


for org_id in DATA_PROVIDER_ORGS_ID:
    client = data_provider_clients[org_id]

    # DatasetSpec is the specification of a dataset. It makes sure every field
    # is well defined, and that our dataset is ready to be registered.
    # The real dataset object is created in the add_dataset method.
    dataset = DatasetSpec(
        name="FLamby",
        type="torchDataset",
        data_opener=assets_directory / "dataset" / f"opener_train_{org_id}.py",
        description=assets_directory / "dataset" / "description.md",
        permissions=permissions_dataset,
        logs_permission=permissions_dataset,
    )

    # Add the dataset to the client to provide access to the opener in each organization.
    train_dataset_key = client.add_dataset(dataset)
    assert train_dataset_key, "Missing data manager key"

    train_dataset_keys[org_id] = train_dataset_key

    # Add the training data on each organization.
    data_sample = DataSampleSpec(
        data_manager_keys=[train_dataset_key],
        test_only=False,
        path=empty_path,
    )
    train_datasample_key = client.add_data_sample(
        data_sample,
        local=True,
    )

    train_datasample_keys[org_id] = train_datasample_key

    # Add the testing data.

    test_dataset_key = client.add_dataset(
        DatasetSpec(
            name="FLamby",
            type="torchDataset",
            data_opener=assets_directory / "dataset" / f"opener_test_{org_id}.py",
            description=assets_directory / "dataset" / "description.md",
            permissions=permissions_dataset,
            logs_permission=permissions_dataset,
        )
    )
    assert test_dataset_key, "Missing data manager key"
    test_dataset_keys[org_id] = test_dataset_key

    data_sample = DataSampleSpec(
        data_manager_keys=[test_dataset_key],
        test_only=True,
        path=empty_path,
    )
    test_datasample_key = client.add_data_sample(
        data_sample,
        local=True,
    )

    test_datasample_keys[org_id] = test_datasample_key

Metrics registration

A metric is a function used to compute the score of predictions on one or several datasamples.

To add a metric, you need to define a function that computes and returns a performance from the datasamples (as returned by the opener) and the predictions_path (to be loaded within the function).

When using a Torch SubstraFL algorithm, the predictions are saved in the predict function in NumPy format, so you can simply load them using numpy.load.

After defining the metrics dependencies and permissions, we use the add_metric function to register the metric. This metric will be used on the test datasamples to evaluate the model performances.

import numpy as np

from torch.utils import data

from substrafl.dependency import Dependency
from substrafl.remote.register import add_metric


def tcga_brca_metric(datasamples, predictions_path):

    config = datasamples

    dataset = fed_tcga_brca.FedTcgaBrca(**config)
    dataloader = data.DataLoader(dataset, batch_size=len(dataset))

    y_true = next(iter(dataloader))[1]
    y_pred = np.load(predictions_path)

    return float(fed_tcga_brca.metric(y_true, y_pred))

# The Dependency object is instantiated in order to install the right libraries in
# the Python environment of each organization.
# The local dependencies are local packages to be installed using the command `pip install -e .`.
# FLamby is a local dependency: we pass the path to the directory containing its `setup.py` file.
metric_deps = Dependency(pypi_dependencies=["torch==1.11.0","numpy==1.23.1"],
                         local_dependencies=[pathlib.Path.cwd().parent.parent], # Flamby dependency
                        )
permissions_metric = Permissions(public=False, authorized_ids=DATA_PROVIDER_ORGS_ID + [ALGO_ORG_ID])

metric_key = add_metric(
    client=algo_provider_client,
    metric_function=tcga_brca_metric,
    permissions=permissions_metric,
    dependencies=metric_deps,
)

Specify the machine learning components

This section uses the PyTorch-based SubstraFL API to simplify the definition of the machine learning components.

However, SubstraFL is compatible with any machine learning framework.

Specifying how much data to train on

To specify how much data to train on at each round, we use an index generator object.

We specify the batch size and the number of batches (called num_updates) to consider for each round.

from substrafl.index_generator import NpIndexGenerator

NUM_UPDATES = 16
SEED = 42

index_generator = NpIndexGenerator(
    batch_size=fed_tcga_brca.BATCH_SIZE,
    num_updates=NUM_UPDATES,
)
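As a quick illustrative check (plain arithmetic, not part of the SubstraFL API), the amount of data each center sees per round follows directly from these two settings:

# Each round, every organization trains on NUM_UPDATES mini-batches, i.e. up to
# BATCH_SIZE * NUM_UPDATES samples (the generator wraps around smaller datasets).
samples_per_round = fed_tcga_brca.BATCH_SIZE * NUM_UPDATES
print(f"Up to {samples_per_round} samples per center per round")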

Torch Dataset definition

To instantiate a SubstraFL Torch algorithm, you need to define a torch Dataset with a specific __init__ signature, which must be (self, datasamples, is_inference).

This torch Dataset is typically where you preprocess your data, in the __getitem__ function.

class TorchDataset(fed_tcga_brca.FedTcgaBrca):

    def __init__(self, datasamples, is_inference):
        # The datasamples returned by the opener are the FLamby dataset
        # configuration; is_inference is ignored because FLamby's __getitem__
        # always returns the full (X, y) tuple.
        config = datasamples
        super().__init__(**config)

SubstraFL algo definition

A SubstraFL Algo gathers all the elements that we defined and that run locally in each organization. This is the only SubstraFL object that is framework specific (here, PyTorch specific).

The torch dataset is passed as a class to the Torch algorithm. Indeed, this torch Dataset will be instantiated within the algorithm, using the opener's output as the datasamples parameter.

Concerning the TorchFedAvgAlgo interaction with FLamby, we need to overwrite the _local_predict function, used to compute the predictions of the model. The default _local_predict provided by SubstraFL assumes that the __getitem__ method of the dataset has a keyword argument is_inference used to return only X (and not the tuple X, y).

But as the __getitem__ function is provided by FLamby, the is_inference argument is ignored. We need to overwrite _local_predict to change this behavior, and we can use this opportunity to optimize the computation time of the function using our knowledge of the output dimension of the TCGA-BRCA dataset.

import torch

from substrafl.algorithms.pytorch import TorchFedAvgAlgo

model = fed_tcga_brca.Baseline()

class MyAlgo(TorchFedAvgAlgo):
    def __init__(self):
        super().__init__(
            model=model,
            criterion=fed_tcga_brca.BaselineLoss(),
            optimizer=fed_tcga_brca.Optimizer(model.parameters(), lr=fed_tcga_brca.LR),
            index_generator=index_generator,
            dataset=TorchDataset,
            seed=SEED,
        )

    def _local_predict(self, predict_dataset: torch.utils.data.Dataset, predictions_path):

        batch_size = self._index_generator.batch_size
        predict_loader = torch.utils.data.DataLoader(predict_dataset, batch_size=batch_size)

        self._model.eval()

        # The output dimension of the model is of size (1,)
        predictions = torch.zeros((len(predict_dataset), 1))

        with torch.inference_mode():
            for i, (x, _) in enumerate(predict_loader):
                x = x.to(self._device)
                predictions[i * batch_size: (i+1) * batch_size] = self._model(x)

        predictions = predictions.cpu().detach()
        self._save_predictions(predictions, predictions_path)

Federated Learning strategies

A FL strategy specifies how to train a model on distributed data. The most well-known strategy is Federated Averaging: train a model locally in every organization, aggregate the weight updates from every organization, and then apply the averaged updates locally at each organization.

For this example, we choose the Federated Averaging strategy, based on the FedAvg paper by McMahan et al., 2017.
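To make the aggregation step concrete, here is a minimal NumPy sketch of the weighted averaging at the heart of FedAvg (illustrative only, not SubstraFL internals; the numbers are made up):

import numpy as np

# Three centers send their local weight updates along with their sample counts.
local_updates = [np.array([0.1, -0.2]), np.array([0.3, 0.0]), np.array([-0.1, 0.4])]
n_samples = [100, 50, 150]

# FedAvg averages the updates, weighted by each center's share of the data.
weights = np.array(n_samples) / sum(n_samples)
averaged_update = sum(w * u for w, u in zip(weights, local_updates))
print(averaged_update)  # this averaged update is then applied by every center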

from substrafl.strategies import FedAvg

strategy = FedAvg()

Where to train, where to aggregate

We specify on which data we want to train our model, using the TrainDataNode objects. Here we train on the six datasets that we registered earlier.

The AggregationNode specifies the organization on which the aggregation operation will be computed.

from substrafl.nodes import TrainDataNode
from substrafl.nodes import AggregationNode


aggregation_node = AggregationNode(ALGO_ORG_ID)

train_data_nodes = list()

for org_id in DATA_PROVIDER_ORGS_ID:

    # Create the Train Data Node (or training task) and save it in a list
    train_data_node = TrainDataNode(
        organization_id=org_id,
        data_manager_key=train_dataset_keys[org_id],
        data_sample_keys=[train_datasample_keys[org_id]],
    )
    train_data_nodes.append(train_data_node)

Where and when to test

With the same logic as the train nodes, we create TestDataNodes to specify on which data we want to test our model.

The Evaluation Strategy defines where and at which frequency we evaluate the model, using the metric(s) registered in the previous section.

from substrafl.nodes import TestDataNode
from substrafl.evaluation_strategy import EvaluationStrategy


test_data_nodes = list()

for org_id in DATA_PROVIDER_ORGS_ID:

    # Create the Test Data Node (or testing task) and save it in a list
    test_data_node = TestDataNode(
        organization_id=org_id,
        data_manager_key=test_dataset_keys[org_id],
        test_data_sample_keys=[test_datasample_keys[org_id]],
        metric_keys=[metric_key],
    )
    test_data_nodes.append(test_data_node)

# Test at the end of every round
my_eval_strategy = EvaluationStrategy(test_data_nodes=test_data_nodes, rounds=1)
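If you only want to evaluate at specific rounds, the rounds argument should also accept a list of round indices (an assumption about the SubstraFL version used here; check the EvaluationStrategy signature of your installed version):

# Illustrative alternative (assumed API): evaluate only after the final round.
# my_eval_strategy = EvaluationStrategy(test_data_nodes=test_data_nodes, rounds=[3])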

Running the experiment

We now have all the necessary objects to launch our experiment. Below is a summary of all the objects we created so far:

  • A Client to orchestrate all the assets of our project, using their keys to identify them
  • A Torch algorithm, to define the training parameters (optimizer, train function, predict function, etc.)
  • A strategy, to specify the federated learning aggregation operation
  • TrainDataNodes, to indicate where we can process training tasks, on which data and using which opener
  • An Evaluation Strategy, to define where and at which frequency we evaluate the model
  • An AggregationNode, to specify the node on which the aggregation operation will be computed
  • The number of rounds, a round being defined by a local training step followed by an aggregation operation
  • An experiment folder to save a summary of the operations made
  • The Dependency to define the libraries the experiment needs to run

from substrafl.experiment import execute_experiment

# Number of times to apply the compute plan.
NUM_ROUNDS = 3

# The Dependency object is instantiated in order to install the right libraries in
# the Python environment of each organization.
# The local dependencies are local packages to be installed using the command `pip install -e .`.
# FLamby is a local dependency: we pass the path to the directory containing its `setup.py` file.
algo_deps = Dependency(pypi_dependencies=["torch==1.11.0"], local_dependencies=[pathlib.Path.cwd().parent.parent])

compute_plan = execute_experiment(
    client=algo_provider_client,
    algo=MyAlgo(),
    strategy=strategy,
    train_data_nodes=train_data_nodes,
    evaluation_strategy=my_eval_strategy,
    aggregation_node=aggregation_node,
    num_rounds=NUM_ROUNDS,
    experiment_folder=str(pathlib.Path.cwd() / "experiment_summaries"),
    dependencies=algo_deps,
)
2022-11-23 16:27:22,920 - INFO - Building the compute plan.
2022-11-23 16:27:22,937 - INFO - Registering the algorithm to Substra.
2022-11-23 16:27:22,965 - INFO - Registering the compute plan to Substra.
2022-11-23 16:27:22,966 - INFO - Experiment summary saved.

Compute plan progress:   0%|          | 0/75 [00:00<?, ?it/s]

Plot results

import pandas as pd
import matplotlib.pyplot as plt
plt.title("Performance evolution on each center of the baseline on Fed-TCGA-BRCA with Federated Averaging training")
plt.xlabel("Rounds")
plt.ylabel("Metric")

# Fetch the performances from the client that registered the compute plan.
performance_df = pd.DataFrame(algo_provider_client.get_performances(compute_plan.key).dict())

for i, id in enumerate(DATA_PROVIDER_ORGS_ID):
    df = performance_df.query(f"worker == '{id}'")
    plt.plot(df["round_idx"], df["performance"], label=f"Client {i} ({id})")

plt.legend(loc=(1.1, 0.3), title="Test set")
plt.show()


Download a model

After the experiment, you might be interested in getting your trained model. To do so, you will need the source code in order to reload your model architecture in memory.

You have the option to choose the client and the round you are interested in.

If round_idx is set to None, the last round will be selected by default.

from substrafl.model_loading import download_algo_files
from substrafl.model_loading import load_algo

client_to_download_from = DATA_PROVIDER_ORGS_ID[0]
round_idx = None

folder = str(pathlib.Path.cwd() / "experiment_summaries" / compute_plan.key / ALGO_ORG_ID / (round_idx or "last"))

download_algo_files(
    client=data_provider_clients[client_to_download_from],
    compute_plan_key=compute_plan.key,
    round_idx=round_idx,
    dest_folder=folder,
)

model = load_algo(input_folder=folder)._model

print(model)
print([p for p in model.parameters()])
Baseline(
  (fc): Linear(in_features=39, out_features=1, bias=True)
)
[Parameter containing:
tensor([[ 0.0398,  0.6503, -0.2928, -0.1939,  0.0650, -0.1311, -0.0354,  0.0586,
         -0.4715, -0.1761, -0.3675, -0.2748,  0.0926, -0.5276, -0.1857, -0.3742,
          0.0829,  0.2343, -0.8615,  0.0280, -0.0237, -0.1865, -0.5507,  0.4314,
         -0.3690, -0.2061,  0.0499, -0.2285,  0.1102, -0.0276,  0.2751,  0.5251,
         -0.5587, -0.6355, -0.6012, -0.1639, -0.3266,  0.0889,  0.2282]],
       requires_grad=True), Parameter containing:
tensor([-0.3652], requires_grad=True)]
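As a final illustrative check (not part of the original example), the reloaded model can be evaluated directly on a FLamby test split, mirroring the metric function defined earlier:

import torch
from torch.utils import data

# Illustrative sketch: run the downloaded model on center 0's test split.
test_dataset = fed_tcga_brca.FedTcgaBrca(center=0, train=False)
test_loader = data.DataLoader(test_dataset, batch_size=len(test_dataset))

X, y_true = next(iter(test_loader))
with torch.inference_mode():
    y_pred = model(X)

# Same metric call as in the registered metric function above.
print(float(fed_tcga_brca.metric(y_true, y_pred.numpy())))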