🛣️ Road Segmentation

This project aims at classifying the pixels that represent a road and those that don't in an aerial/satellite image, using a Convolutional Neural Network (CNN).

The problem was part of the Road Segmentation Challenge hosted on AICrowd. Our team, Pasta-Balalaika, reached position 12 out of 107 on the leaderboard, with an F1 score of 0.910 and an accuracy of 0.952. This project was also done as an assignment for the EPFL course CS-433 Machine Learning.

The following image shows an example of the prediction made by our final model:

[Example prediction of the final model]

Authors

Summary

  • How to install
  • Usage
  • How to train and test
  • Results
  • Conclusion and Future work

How to install

  • Download the project repository:
git clone https://github.com/EliaFantini/Road-Segmentation-convolutional-neural-network-classifier.git

cd Road-Segmentation-convolutional-neural-network-classifier
  • Create a virtual environment (download and install Python 3 first if you don't have it already):
python3 -m pip install --user virtualenv

virtualenv -p python3 pasta-balalaika
  • Activate the environment:
source pasta-balalaika/bin/activate
  • Install the requirements:
python3 -m pip install -r requirements.txt

Usage

In the repository, the libseg Python package contains all the code of the project.

The experiments package contains the final notebook, where you can find the runs of the different pipelines and experiments, along with plots of the losses and metrics.

libseg:

  • dataset: contains the code of the SegmentDataset class, which creates the items for training, validation, or testing.

  • losses: contains the code for DiceLoss, which was used during training.

  • model: in this package you can find the code for the Model class, the main class for training and testing.

  • nets: comprises the code of three different neural networks.

    More information about the neural networks as well as our explanations of why we used these architectures can be found in our report.

  • preprocessing: this module comprises the code for data preprocessing:

    • data augmentation (flipping, rotations)
    • gamma correction
    • CLAHE (contrast limited adaptive histogram equalization)
    • standardization

    More about the preprocessing methods can be found in our report; a brief sketch of the gamma correction and CLAHE steps is given right after this list.

  • utils: consists of helpers such as train_valid_split, a function for splitting the data into training and validation sets, and fix_seed, a function that fixes all sources of randomness during training (model initialization, splitting, and so on).

This module also includes the code for cropping images into patches for the final submission, as well as functions for choosing the criterion and net.

  • config file consists of the following keys (an illustrative example is sketched after this list):

    • seed - random state

    • valid_size - size of the validation part,

    • data_path - the global path to the data folder,

    • clahe - the flag for clahe,

    • gamma_correction - the flag for gamma correction,

    • gamma - gamma value,

    • normalize_and_centre - the flag for normalization,

    • data_augmentation - the flag for data augmentation,

    • num_rotations - if data_augmentation is True, then this is the number of random rotations,

    • divide_in_patches - the flag for dividing into patches (small subimages),

    • patches_size - if divide_in_patches is True, this is the size of the patches,

    • batch_size - batch size for training,

    • backbone - the name of the net,

    • criterion - the name of the criterion,

    • optimizer_learning_rate - learning rate,

    • epochs - the number of the epochs to train,

    • epochs_print_gap - the gap to verbose current metrics and losses,

    • foreground_threshold - the threshold above which a patch is labelled as foreground (road),

    • device - the name of the device (cpu or cuda),

    • create_submission - the flag for creating the submission after training,

    • save_model - the flag for saving the model weights after training,

    • train - train or test,

    • postprocessing - the flag for using ensembling during testing (more in the report)
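
As a rough illustration of the gamma correction and CLAHE steps listed under preprocessing, the sketch below shows how they are commonly implemented with OpenCV and NumPy. It is a minimal example under assumed conventions (8-bit RGB input, fixed CLAHE parameters), not the project's actual code.

import cv2
import numpy as np

def gamma_correction(image: np.ndarray, gamma: float) -> np.ndarray:
    # Raise the normalized intensities to the power gamma, then rescale to 8 bits.
    normalized = image.astype(np.float32) / 255.0
    return (np.power(normalized, gamma) * 255.0).astype(np.uint8)

def apply_clahe(image: np.ndarray) -> np.ndarray:
    # Equalize local contrast on the luminance channel only, leaving colours untouched.
    lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
    l, a, b = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    return cv2.cvtColor(cv2.merge((clahe.apply(l), a, b)), cv2.COLOR_LAB2RGB)

Flips and rotations for data augmentation follow the same pattern, as long as the same transform is applied to both the image and its ground-truth mask.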
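
The exact format and default values of the config file are not reproduced here; purely as an illustration, a configuration with the keys listed above could look like the following (all values are placeholders, not the defaults that reproduce our submission):

# Illustrative only: a config with the keys documented above; values are placeholders.
config = {
    "seed": 42,
    "valid_size": 0.2,
    "data_path": "data/",
    "clahe": True,
    "gamma_correction": True,
    "gamma": 0.8,
    "normalize_and_centre": True,
    "data_augmentation": True,
    "num_rotations": 4,
    "divide_in_patches": False,
    "patches_size": 16,
    "batch_size": 8,
    "backbone": "DeepLabPlus",
    "criterion": "DiceLoss",
    "optimizer_learning_rate": 1e-3,
    "epochs": 50,
    "epochs_print_gap": 5,
    "foreground_threshold": 0.25,
    "device": "cuda",
    "create_submission": True,
    "save_model": True,
    "train": True,
    "postprocessing": True,
}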

How to train and test

In order to train the model:

  1. Download the train and test data here
  2. Put the data folder into the project repository
  3. Run the script with the following command: python run.py (with config['train'] = True)

Using the default config, you can reproduce our result.

The submission will be saved into the file submissions.csv and the model into the file model.pt (if all required flags are True).

The weights of the final model can also be found here.

In order to evaluate our final model:

  1. Download the weights from the link above.
  2. Put the downloaded file model.pt in the project repository (same folder as run.py).
  3. Set config['train'] = False and run python run.py.

Results

We obtained the best score on the test data (0.910 F1 score) with the DeepLabPlus network and DiceLoss, training for 50 epochs with the Adam optimizer.
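
For reference, the F1 score quoted above is the harmonic mean of precision and recall over the binary road/background labels. A minimal, self-contained way to compute it (not the project's own evaluation code) is:

import numpy as np

def f1_score(pred: np.ndarray, target: np.ndarray) -> float:
    # pred and target are binary arrays (1 = road, 0 = background).
    tp = np.sum((pred == 1) & (target == 1))
    fp = np.sum((pred == 1) & (target == 0))
    fn = np.sum((pred == 0) & (target == 1))
    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    return 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0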

More important details about our work can be found in the report.

The picture below shows the plots of train/test loss, F1 score, and accuracy for the different networks and preprocessing pipelines:

[Plots of train/test loss, F1 score, and accuracy for different networks and preprocessing pipelines]

Conclusion and Future work

In this project we tackled a road segmentation task.

With a pretrained DeepLabPlus model (link), data preprocessing, and ensembling of predictions, we achieved a 0.910 F1 score on the test dataset.

Many adaptations, tests, and experiments have been left for the future due to lack of time. The following ideas could be explored:

  • Add extra data to increase the model's generalization ability; for example, this dataset could be used.
  • Use a combination of DiceLoss and BCELoss (inspired by this paper); a sketch of such a combined loss is given after this list.
  • Try other optimizers, such as AMSGrad and Yogi (inspired by this paper and this paper, respectively).
  • Try an ensemble machine learning algorithm such as Stacked Generalization.
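
To make the loss-combination idea concrete, a weighted sum of Dice and binary cross-entropy in PyTorch could look like the sketch below. DiceBCELoss is a hypothetical class, not part of libseg, and the 50/50 weighting is an arbitrary choice.

import torch
import torch.nn as nn

class DiceBCELoss(nn.Module):
    # Hypothetical combination: weighted sum of BCE-with-logits and soft Dice loss.
    def __init__(self, bce_weight: float = 0.5, smooth: float = 1.0):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.bce_weight = bce_weight
        self.smooth = smooth

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Dice term is computed on probabilities; BCE term works directly on logits.
        probs = torch.sigmoid(logits)
        intersection = (probs * target).sum()
        dice = 1.0 - (2.0 * intersection + self.smooth) / (probs.sum() + target.sum() + self.smooth)
        return self.bce_weight * self.bce(logits, target) + (1.0 - self.bce_weight) * dice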

🛠 Skills

Python, PyTorch, Matplotlib, Jupyter Notebooks. Machine learning and convolutional neural network knowledge, analysis of the impact of different preprocessing techniques on training, plotting the experiments, ensuring reproducibility.

🔗 Links

portfolio linkedin