
Simulation Codes for “Relay-Assisted Cooperative Federated Learning”



zhlinup/Relay-FL


Relay-FL

This is the simulation code package for the following paper:

Zehong Lin, Hang Liu, and Ying-Jun Angela Zhang, “Relay-Assisted Cooperative Federated Learning,” IEEE Transactions on Wireless Communications, DOI: 10.1109/TWC.2022.3155596. [ArXiv Version]

The package, written in Python 3 and Matlab, reproduces the numerical results of the proposed algorithm in the above paper.

Abstract of Article:

Federated learning (FL) has recently emerged as a promising technology to enable artificial intelligence (AI) at the network edge, where distributed mobile devices collaboratively train a shared AI model under the coordination of an edge server. To significantly improve the communication efficiency of FL, over-the-air computation allows a large number of mobile devices to concurrently upload their local models by exploiting the superposition property of wireless multi-access channels. Due to wireless channel fading, the model aggregation error at the edge server is dominated by the weakest channel among all devices, causing severe straggler issues. In this paper, we propose a relay-assisted cooperative FL scheme to effectively address the straggler issue. In particular, we deploy multiple half-duplex relays to cooperatively assist the devices in uploading the local model updates to the edge server. The nature of the over-the-air computation poses system objectives and constraints that are distinct from those in traditional relay communication systems. Moreover, the strong coupling between the design variables renders the optimization of such a system challenging. To tackle the issue, we propose an alternating-optimization-based algorithm to optimize the transceiver and relay operation with low complexity. Then, we analyze the model aggregation error in a single-relay case and show that our relay-assisted scheme achieves a smaller error than the one without relays provided that the relay transmit power and the relay channel gains are sufficiently large. The analysis provides critical insights on relay deployment in the implementation of cooperative FL. Extensive numerical results show that our design achieves faster convergence compared with state-of-the-art schemes.
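The straggler effect described in the abstract can be illustrated with a minimal NumPy sketch (NumPy is installed along with SciPy) of conventional over-the-air averaging without relays, using channel inversion. This is not the repository's Single() implementation, and it is not the paper's optimized design. The function names ota_aggregate and theoretical_mse are hypothetical. The sketch also assumes perfect channel knowledge, real-valued unit-power local updates, real Gaussian receiver noise, and Rayleigh fading without path loss.

```python
import numpy as np


def ota_aggregate(updates, h, P, noise_var, rng):
    """Over-the-air averaging with channel inversion (sketch, no relays).

    updates: (K, d) real array, one local update per device (assumed unit power)
    h: (K,) complex channel coefficients (perfect CSI assumed)
    P: per-device transmit power budget
    noise_var: variance of real-valued AWGN at the edge server (assumption)
    """
    K, d = updates.shape
    # Every device must invert its own channel within budget P, so the common
    # receive amplitude sqrt(eta) is capped by the weakest channel gain.
    eta = P * np.min(np.abs(h) ** 2)
    b = np.sqrt(eta) / h                    # transmit scalars, |b_k|^2 <= P
    superposed = (h * b) @ updates          # the channel sums the signals
    noise = rng.normal(0.0, np.sqrt(noise_var), size=d)
    # De-scale to estimate the average of the local updates
    return (superposed.real + noise) / (np.sqrt(eta) * K)


def theoretical_mse(h, P, noise_var):
    """Per-entry MSE of the estimate above: noise_var / (eta * K^2)."""
    eta = P * np.min(np.abs(h) ** 2)
    return noise_var / (eta * len(h) ** 2)


rng = np.random.default_rng(1)
K, d, P, noise_var = 20, 100_000, 1.0, 0.01
updates = rng.standard_normal((K, d))
h = (rng.standard_normal(K) + 1j * rng.standard_normal(K)) / np.sqrt(2)  # Rayleigh

estimate = ota_aggregate(updates, h, P, noise_var, rng)
empirical = np.mean((estimate - updates.mean(axis=0)) ** 2)
print("empirical MSE:", empirical, "theory:", theoretical_mse(h, P, noise_var))

# Weakening a single device by a factor of 1000 in amplitude raises the error
# for everyone by a factor of 1e6, because eta follows the minimum gain.
h_weak = h.copy()
h_weak[0] = 1e-3 * np.min(np.abs(h))
print("MSE ratio:", theoretical_mse(h_weak, P, noise_var) / theoretical_mse(h, P, noise_var))
```

The error scales with 1/min_k |h_k|^2. A single weak device therefore degrades the aggregate for all devices, and this is the issue that the relay-assisted scheme targets.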

Referencing

If you in any way use this code for research that results in publications, please cite our original article listed above.

Dependencies

This package is written in Matlab and Python 3. It requires the following libraries:

  • Matlab and CVX
  • Python >= 3.5
  • torch
  • torchvision
  • scipy
  • CUDA (if GPU is used)

Documentation (please also see each file for more details):

  • data/: Store the Fashion-MNIST dataset. On the first run, the dataset is automatically downloaded from the Internet.
  • store/: Store output files (*.npz)
  • matlab/: Data and code files used in Matlab
    • DATA/: Store files (*.mat) for channel models and optimization results in Matlab
    • training_result/: Store files for training results (*.mat) to be plotted for presentation
    • main_cmp.m: Initialize the simulation system and optimize the variables
    • Setup_Init.m: Specify and initialize the system parameters
    • AM.m: Alternating minimization algorithm proposed in the paper
    • Single.m: Conventional over-the-air model aggregation scheme
    • Xu.m: Existing relay-assisted scheme in Ref. [23]
    • single_relay_channel.m: Construct the channel model for the single-relay case
    • single_relay_channel_loc.m: Construct the channel model for the single-relay case with varying relay location
    • cell_channel_model.m: Construct the channel model for the multi-relay case in a single-cell
    • plot_figure.m: Plot the figure with varying transmission blocks from the training results stored in training_result/
    • plot_Pr.m: Plot the figure with varying P_r from the training results stored in training_result/
  • main.py: Initialize the simulation system, train the learning model, and store the result in store/ as an npz file
    • initial(): Initialize the parser function to read the user-input parameters
  • learning_flow.py: Read the optimization result, initialize the learning model, and perform training and testing
    • Learning_iter(): Given the learning model, compute the gradients, update the training models, and perform testing using the functions in train_script.py
    • FedAvg_grad(): Given the aggregated model changes and the current model, update the global model by eq.(5)
  • Nets.py:
    • CNNMnist(): Specify the convolutional neural network structure used for learning
    • MLP(): Specify the multilayer perceptron structure used for learning
  • AirComp.py:
    • AM(): Given the local model changes, perform relay-assisted over-the-air model aggregation; see Section II-C
    • Single(): Given the local model changes, perform conventional over-the-air model aggregation; see Section II-B
    • Xu(): Given the local model changes, perform relay-assisted over-the-air model aggregation scheme proposed in Ref. [23]
  • train_script.py:
    • Load_fmnist_iid(): Download (if needed) and load the Fashion-MNIST data, and distribute them to the local devices
    • Load_fmnist_noniid(): Download (if needed) and load the Fashion-MNIST data, and distribute them to the local devices by following a non-iid distribution
    • local_update(): Given a learning model and the distributed training data, compute the local gradients/model changes
    • test_model(): Given a learning model, test the accuracy/loss based on certain test images
  • plot_result.py: Plot the figure with varying transmission blocks from the output files in store/, and process and store the training results in matlab/training_result/
  • plot_Pr.py: Plot the figure with varying P_r from the output files in store/, and process and store the training results in matlab/training_result/
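Load_fmnist_noniid() gives each device samples from only noniid_level classes. One way to produce such a split is sketched below. The function name partition_noniid and the round-robin class assignment are hypothetical. The repository may split the data differently, for example with label-sorted shards.

```python
import numpy as np


def partition_noniid(labels, num_devices, classes_per_device, num_classes=10, seed=0):
    """Give each device `classes_per_device` distinct classes (hypothetical scheme).

    Device k is assigned the consecutive labels k*c, k*c+1, ..., k*c+c-1
    (mod num_classes), where c = classes_per_device <= num_classes. The samples
    of each class are shuffled and split evenly among the devices holding it.
    Returns one array of sample indices per device.
    """
    rng = np.random.default_rng(seed)
    owners = {cls: [] for cls in range(num_classes)}
    for k in range(num_devices):
        for j in range(classes_per_device):
            owners[(k * classes_per_device + j) % num_classes].append(k)

    device_indices = [[] for _ in range(num_devices)]
    for cls, devices in owners.items():
        if not devices:
            continue  # class held by no device: its samples are left unused
        idx = rng.permutation(np.flatnonzero(labels == cls))
        for k, part in zip(devices, np.array_split(idx, len(devices))):
            device_indices[k].extend(part.tolist())
    return [np.sort(np.array(ix, dtype=int)) for ix in device_indices]


# Synthetic labels stand in for the Fashion-MNIST training labels
labels = np.random.default_rng(0).integers(0, 10, size=6000)
parts = partition_noniid(labels, num_devices=20, classes_per_device=2)
for k, ix in enumerate(parts[:3]):
    print(f"device {k}: {len(ix)} samples, classes {sorted(set(labels[ix].tolist()))}")
```

For real labels, `np.asarray(torchvision.datasets.FashionMNIST(root="data", train=True, download=True).targets)` yields an integer label array that can be passed in directly.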

How to Use

  1. Use the channel-model codes in matlab/ to obtain the channel coefficients.

  2. Run matlab/main_cmp.m, the main optimization file in Matlab. It optimizes the variables of the proposed relay-assisted scheme and the benchmark schemes, and the obtained optimization results are then used for FL.

  3. The main file for FL is main.py. It takes the following user-input parameters via a parser (see also the function initial() in main.py):
| Parameter Name | Meaning | Default Value | Type/Range |
|---|---|---|---|
| K | total number of devices | 20 | int |
| N | total number of relays | 1 | int |
| PL | path loss exponent | 3.0 | float |
| trial | total number of Monte Carlo trials | 50 | int |
| SNR | negative noise variance in dB (noise variance = -SNR dB) | 100 | float |
| P_r | relay transmit power budget | 0.1 | float |
| verbose | output no/important/detailed messages when running the scripts | 1 | 0, 1 |
| seed | random seed | 1 | int |
| gpu | GPU index used for learning (if available) | 0 | int |
| local_ep | number of local epochs, E | 1 | int |
| local_bs | local batch size, B (0 for full batch) | 0 | int |
| lr | learning rate, lambda | 0.05 | float |
| low_lr | learning rate lower bound, bar_lambda | 1e-5 | float |
| gamma | learning rate decrease ratio, gamma | 0.9 | float |
| step | learning rate decrease step, bar_T | 50 | int |
| momentum | SGD momentum, used only for multiple local updates | 0.99 | float |
| epochs | number of training rounds, T | 500 | int |
| iid | 1 for iid, 0 for non-iid | 1 | 0, 1 |
| noniid_level | number of classes at each device for non-iid | 2 | 2, 4, 6, 8, 10 |
| V_idx | variable index | 0 | int |
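The table defines lr, low_lr, gamma, and step, but it does not say how they combine. One natural reading is a step decay with a floor, sketched below. The formula max(lr * gamma^floor(t/step), low_lr) and the helper name learning_rate are assumptions and are not taken from learning_flow.py.

```python
def learning_rate(t, lr=0.05, low_lr=1e-5, gamma=0.9, step=50):
    """Assumed step-decay schedule: multiply by gamma every `step` rounds,
    never going below low_lr. Defaults follow the parameter table."""
    return max(lr * gamma ** (t // step), low_lr)


# Learning rate at the start of each decay period over T = 500 rounds
for t in range(0, 500, 50):
    print(t, round(learning_rate(t), 6))
```

In PyTorch, `torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.9)` applies the same decay but without the low_lr floor.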

Here is an example of running the script in a Linux terminal:

python main.py --gpu=0 --trial=50 --V_idx 0
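For reference, the parameters in the table can be declared with Python's standard argparse module. The sketch below is a hypothetical reconstruction built only from the table, and build_parser is an illustrative name. The actual initial() in main.py may differ, for example in help strings or validation. The sketch parses the same flags as the example command, in both the `--flag=value` and `--flag value` forms.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the parameter table (not main.py's code)."""
    p = argparse.ArgumentParser(description="Relay-FL options (sketch)")
    p.add_argument("--K", type=int, default=20, help="total number of devices")
    p.add_argument("--N", type=int, default=1, help="total number of relays")
    p.add_argument("--PL", type=float, default=3.0, help="path loss exponent")
    p.add_argument("--trial", type=int, default=50, help="number of Monte Carlo trials")
    p.add_argument("--SNR", type=float, default=100.0, help="negative noise variance in dB")
    p.add_argument("--P_r", type=float, default=0.1, help="relay transmit power budget")
    p.add_argument("--verbose", type=int, default=1, help="message level")
    p.add_argument("--seed", type=int, default=1, help="random seed")
    p.add_argument("--gpu", type=int, default=0, help="GPU index")
    p.add_argument("--local_ep", type=int, default=1, help="local epochs E")
    p.add_argument("--local_bs", type=int, default=0, help="local batch size B, 0 = full batch")
    p.add_argument("--lr", type=float, default=0.05, help="learning rate lambda")
    p.add_argument("--low_lr", type=float, default=1e-5, help="learning rate lower bound")
    p.add_argument("--gamma", type=float, default=0.9, help="learning rate decrease ratio")
    p.add_argument("--step", type=int, default=50, help="learning rate decrease step bar_T")
    p.add_argument("--momentum", type=float, default=0.99, help="SGD momentum")
    p.add_argument("--epochs", type=int, default=500, help="training rounds T")
    p.add_argument("--iid", type=int, default=1, choices=[0, 1], help="1 iid, 0 non-iid")
    p.add_argument("--noniid_level", type=int, default=2, choices=[2, 4, 6, 8, 10],
                   help="classes per device for non-iid")
    p.add_argument("--V_idx", type=int, default=0, help="variable index")
    return p


# Same flags as the example command
args = build_parser().parse_args(["--gpu=0", "--trial=50", "--V_idx", "0"])
print(args.gpu, args.trial, args.V_idx, args.K, args.epochs)
```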
