Saliency Detection (saldet) is a collection of models and tools for performing saliency detection with PyTorch (CUDA, MPS, etc.).
List of saliency detection models supported by saldet:
- U2Net - https://arxiv.org/abs/2005.09007v3 (U2Net repo)
- PGNet - https://arxiv.org/abs/2204.05041 (follow the training instructions from the PGNet repo)
- PFAN - https://arxiv.org/pdf/1903.00179v2.pdf (PFAN repo)

Pre-trained weights:
- PGNet -> weights from the PGNet repo, converted to the saldet format, from here
- U2Net Lite -> weights from here (U2Net repository)
- U2Net Full -> weights from here (U2Net repository)
- U2Net Full - Portrait -> weights for portrait images from here (U2Net repository)
- U2Net Full - Human Segmentation -> weights for segmenting humans from here (U2Net repository)
- PFAN -> weights from the PFAN repo, converted to the saldet format, from here
To load pre-trained weights:

```python
from saldet import create_model

model = create_model("pgnet", checkpoint_path="PATH/TO/pgnet.pth")
```
The library makes it easy to train models thanks to PyTorch Lightning support.
```python
from saldet.experiment import train

train(
    data_dir=...,
    config_path="config/u2net_lite.yaml",  # check the config folder for example configurations
    output_dir=...,
    resume_from=...,
    seed=42,
)
```
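The config files themselves are not documented here; check the real files shipped in the `config` folder. Purely as a hypothetical sketch of the layout (the section names `transform` and `datamodule` appear in the Lightning example further below, but the individual keys are invented for illustration):

```yaml
# Hypothetical sketch only — the actual keys come from the files in the config folder.
transform:
  # entries here are unpacked as **config["transform"] into SaliencyTransform
  input_size: 320
datamodule:
  # entries here are unpacked as **config["datamodule"] into SaliencyPLDataModule
  batch_size: 8
  num_workers: 4
```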
Once training is over, the configuration file and checkpoints are saved to the output directory.
[WARNING] The dataset must be structured as follows:

```
dataset
├── train
│   ├── images
│   │   ├── img_1.jpg
│   │   └── img_2.jpg
│   └── masks
│       ├── img_1.png
│       └── img_2.png
└── val
    ├── images
    │   ├── img_10.jpg
    │   └── img_11.jpg
    └── masks
        ├── img_10.png
        └── img_11.png
```
The library also provides utilities to assemble the PyTorch Lightning data and model modules yourself.
```python
import pytorch_lightning as pl

from saldet import create_model
from saldet.pl import SaliencyPLDataModule, SaliencyPLModel
from saldet.transform import SaliencyTransform

# datamodule
datamodule = SaliencyPLDataModule(
    root_dir=data_dir,
    train_transform=SaliencyTransform(train=True, **config["transform"]),
    val_transform=SaliencyTransform(train=False, **config["transform"]),
    **config["datamodule"],
)

# model, loss, and optimization setup
model = create_model(...)
criterion = ...
optimizer = ...
lr_scheduler = ...

pl_model = SaliencyPLModel(
    model=model, criterion=criterion, optimizer=optimizer, lr_scheduler=lr_scheduler
)

trainer = pl.Trainer(...)

# fit
print("Launching training...")
trainer.fit(model=pl_model, datamodule=datamodule)
```
Alternatively, you can define your own training loop and use the create_model() utility to instantiate the model you prefer.
The library also makes it easy to run inference and generate saliency maps from a folder of images.
```python
from saldet.experiment import inference

inference(
    images_dir=...,
    ckpt=...,  # path to the ckpt/pth model file
    config_path=...,  # path to the configuration file produced by saldet train
    output_dir=...,  # where to save the saliency maps
    sigmoid=...,  # whether to apply a sigmoid to the predicted masks
)
```
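The `sigmoid` flag matters because some models output raw logits rather than probabilities. A minimal stdlib sketch of what applying a sigmoid to a predicted value means (the function here is illustrative, not part of saldet):

```python
import math


def sigmoid(x: float) -> float:
    """Map a raw logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


# A logit of 0 maps to 0.5; large positive logits approach 1,
# large negative logits approach 0.
logits = [-4.0, 0.0, 4.0]
probs = [sigmoid(x) for x in logits]
```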
To-Dos:
- [ ] Improve code coverage
- [ ] ReadTheDocs documentation