Saliency Detection library (models, loss, utils) with PyTorch

riccardomusmeci/saldet


saldet

Saliency Detection (saldet) is a collection of models and tools to perform Saliency Detection with PyTorch (cuda, mps, etc.).


Models

saldet supports the following saliency detection models:

  • PGNet
  • U2Net (Lite and Full variants)
  • PFAN

Weights

  • PGNet -> weights from the official PGNet repo, converted to the saldet format, from here
  • U2Net Lite -> weights from here (U2Net repository)
  • U2Net Full -> weights from here (U2Net repository)
  • U2Net Full - Portrait -> weights for portrait images from here (U2Net repository)
  • U2Net Full - Human Segmentation -> weights for human segmentation from here (U2Net repository)
  • PFAN -> weights from the official PFAN repo, converted to the saldet format, from here

To load pre-trained weights:

from saldet import create_model
model = create_model("pgnet", checkpoint_path="PATH/TO/pgnet.pth")

Train

Easy Mode

The library provides an easy training entry point thanks to its PyTorch Lightning support.

from saldet.experiment import train

train(
    data_dir=...,
    config_path="config/u2net_lite.yaml", # check the config folder with some configurations
    output_dir=...,
    resume_from=...,
    seed=42
)

Once training is complete, the configuration file and checkpoints are saved to the output directory.

[WARNING] The dataset must be structured as follows:

dataset
    ├── train                    
    |       ├── images          
    |       │   ├── img_1.jpg
    |       │   └── img_2.jpg                
    |       └── masks
    |           ├── img_1.png
    |           └── img_2.png   
    └── val
           ├── images          
           │   ├── img_10.jpg
           │   └── img_11.jpg                
           └── masks
               ├── img_10.png
               └── img_11.png   
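Because a misplaced folder or a missing mask is an easy mistake, it can help to validate the layout above before launching a run. The helper below is a hypothetical stdlib-only sketch, not part of saldet; it assumes images and their masks share the same file stem, as in the tree shown:

```python
from pathlib import Path

def check_saliency_dataset(root: str) -> list[str]:
    """Return a list of problems found in a saldet-style dataset layout."""
    problems = []
    root_path = Path(root)
    for split in ("train", "val"):
        images = root_path / split / "images"
        masks = root_path / split / "masks"
        for folder in (images, masks):
            if not folder.is_dir():
                problems.append(f"missing folder: {folder}")
        if images.is_dir() and masks.is_dir():
            # images and masks are matched by file stem (img_1.jpg <-> img_1.png)
            image_stems = {p.stem for p in images.iterdir() if p.is_file()}
            mask_stems = {p.stem for p in masks.iterdir() if p.is_file()}
            for stem in sorted(image_stems - mask_stems):
                problems.append(f"{split}: image '{stem}' has no mask")
    return problems
```

Calling `check_saliency_dataset("dataset")` before training returns an empty list when the layout matches the expected structure.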

PyTorch Lightning Mode

The library provides PyTorch Lightning modules for both the model and the data.

import pytorch_lightning as pl
from saldet import create_model
from saldet.pl import SaliencyPLDataModule, SaliencyPLModel
from saldet.transform import SaliencyTransform

# datamodule
datamodule = SaliencyPLDataModule(
    root_dir=data_dir,
    train_transform=SaliencyTransform(train=True, **config["transform"]),
    val_transform=SaliencyTransform(train=False, **config["transform"]),
    **config["datamodule"],
)

model = create_model(...)
criterion = ...
optimizer = ...
lr_scheduler = ...

pl_model = SaliencyPLModel(
    model=model, criterion=criterion, optimizer=optimizer, lr_scheduler=lr_scheduler
)

trainer = pl.Trainer(...)

# fit
print("Launching training...")
trainer.fit(model=pl_model, datamodule=datamodule)

PyTorch Mode

Alternatively, you can define your own training loop and use the create_model() utility to instantiate the model you need.
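A custom loop might look like the sketch below. The stand-in Conv2d model is a placeholder for a saldet model (e.g. `create_model("u2net_lite")`), and it assumes the model maps a (B, 3, H, W) image batch to a (B, 1, H, W) saliency logit map; the actual saldet models may return a different output structure:

```python
import torch
import torch.nn as nn

# Placeholder standing in for a saldet model; swap in create_model(...)
# in real code. Assumes logits of shape (B, 1, H, W) as output.
model = nn.Conv2d(3, 1, kernel_size=3, padding=1)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model.train()
for step in range(2):  # replace with a loop over a real DataLoader
    images = torch.rand(4, 3, 64, 64)                     # dummy image batch
    masks = torch.randint(0, 2, (4, 1, 64, 64)).float()   # dummy binary masks
    optimizer.zero_grad()
    logits = model(images)
    loss = criterion(logits, masks)
    loss.backward()
    optimizer.step()
```

The loss and optimizer here are only examples; any segmentation-style criterion works with the same loop shape.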

Inference

The library also provides an easy way to run inference and save saliency maps for a folder of images.

from saldet.experiment import inference

inference(
    images_dir=...,
    ckpt=..., # path to ckpt/pth model file
    config_path=..., # path to configuration file from saldet train
    output_dir=..., # where to save saliency maps
    sigmoid=..., # whether to apply sigmoid to predicted masks
)
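Saved saliency maps are continuous-valued; downstream tasks often need a binary mask instead. The following is a hypothetical NumPy post-processing helper, not part of saldet; its `apply_sigmoid` flag mirrors the `sigmoid` option of `inference()` above and it assumes 8-bit maps are stored in the 0-255 range:

```python
import numpy as np

def binarize_saliency(saliency: np.ndarray, threshold: float = 0.5,
                      apply_sigmoid: bool = False) -> np.ndarray:
    """Turn a saliency map into a binary mask of 0s and 255s."""
    saliency = saliency.astype(np.float32)
    if apply_sigmoid:             # map raw logits into [0, 1]
        saliency = 1.0 / (1.0 + np.exp(-saliency))
    elif saliency.max() > 1.0:    # assume an 8-bit map saved as 0-255
        saliency = saliency / 255.0
    return np.where(saliency >= threshold, 255, 0).astype(np.uint8)
```

The resulting uint8 array can be saved directly as a PNG mask with any imaging library.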

To-Dos

[ ] Improve code coverage

[ ] ReadTheDocs documentation