
An implementation of several unsupervised object discovery models (Slot Attention, SLATE, GNM) in PyTorch with pre-trained models.


HHousen/object-discovery-pytorch


Object Discovery PyTorch

This is an implementation of several unsupervised object discovery models (Slot Attention, SLATE, GNM) in PyTorch.


This repo is in active development. Expect some breaking changes.

The initial code for this repo was forked from untitled-ai/slot_attention.

Visualization of a Slot Attention model trained on CLEVR6, demonstrating the model's ability to separate objects into slots.

Setup

Requirements

  • Poetry
  • Python >= 3.9
  • CUDA-enabled computing device

Getting Started

  1. Clone the repo: git clone https://github.com/HHousen/slot-attention-pytorch/ && cd slot-attention-pytorch.
  2. Install the requirements and activate the environment: poetry install, then poetry shell.
  3. Download the CLEVR (with masks) dataset (or the original CLEVR dataset by running ./data_scripts/download_clevr.sh /tmp/CLEVR). More details about the datasets are below.
  4. Modify the hyperparameters in object_discovery/params.py to fit your needs. Make sure to change data_root to the location of your dataset.
  5. Train a model: python -m object_discovery.train.

Pre-trained Models

Code to load these models can be adapted from predict.py.

| Model          | Dataset      | Download     |
| -------------- | ------------ | ------------ |
| Slot Attention | CLEVR6 Masks | Google Drive |
| Slot Attention | Sketchy      | Google Drive |
| GNM            | CLEVR6 Masks | Google Drive |
| Slot Attention | ClevrTex6    | Google Drive |
| GNM            | ClevrTex6    | Google Drive |
| SLATE          | CLEVR6 Masks | Google Drive |

Usage

Train a model by running python -m object_discovery.train.

Hyperparameters can be changed in object_discovery/params.py. training_params holds global parameters that apply to all model types; a global value is overridden when the same key is also present in slot_attention_params or slate_params. Set the global parameter model_type to sa to use Slot Attention (SlotAttentionModel in slot_attention_model.py) or to slate to use SLATE (SLATE in slate_model.py). This choice determines which model's parameter set is merged into training_params.
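The override rule described above amounts to a dictionary merge. The sketch below assumes the parameter sets are plain dicts; the keys and values (lr, num_slots, vocab_size, and so on) and the helper merge_params are hypothetical illustrations, not the actual contents of params.py or the training script.

```python
# Hypothetical stand-ins for the dictionaries in object_discovery/params.py.
# Keys and values are illustrative only, not the repo's actual defaults.
training_params = {
    "model_type": "sa",
    "dataset": "clevr",
    "data_root": "/tmp/CLEVR",
    "lr": 1e-3,
    "batch_size": 64,
}
slot_attention_params = {"lr": 4e-4, "num_slots": 7}
slate_params = {"lr": 1e-4, "vocab_size": 1024}


def merge_params(global_params, sa_params, slate_model_params):
    """Return global_params overlaid with the selected model's parameters.

    model_type picks the model-specific dict; any key it shares with the
    global dict overrides the global value.
    """
    model_specific = {"sa": sa_params, "slate": slate_model_params}[
        global_params["model_type"]
    ]
    merged = dict(global_params)  # copy so the global dict is not mutated
    merged.update(model_specific)
    return merged


params = merge_params(training_params, slot_attention_params, slate_params)
print(params["lr"], params["num_slots"])
```

Copying before updating leaves training_params unchanged, so switching model_type and merging again always starts from the same global defaults.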

Perform inference by modifying and running the predict.py script.

Models

Our implementations are based on several open-source repositories.

  1. Slot Attention ("Object-Centric Learning with Slot Attention"): untitled-ai/slot_attention & Official
  2. SLATE ("Illiterate DALL-E Learns to Compose"): Official
  3. GNM ("Generative Neurosymbolic Machines"): karazijal/clevrtex & Official

Datasets

Select a dataset by setting the dataset parameter in object_discovery/params.py to the dataset's name, for example clevr, shapes3d, or ravens. Then set the data_root parameter to the location of the data. The code for loading the supported datasets is in object_discovery/data.py.

  1. CLEVR: Download by executing download_clevr.sh.
  2. CLEVR (with masks): Original TFRecords Download / Our HDF5 PyTorch Version.
    • This dataset is a regenerated version of CLEVR but with ground-truth segmentation masks. This enables the training script to calculate Adjusted Rand Index (ARI) during validation runs.
    • The dataset contains 100,000 images with a resolution of 240x320 pixels. The dataloader splits them into 70K train, 15K validation, and 15K test images. Test images are not used by the object_discovery/train.py script.
    • We convert the original TFRecords dataset to HDF5 for easy use with PyTorch. This was done using the data_scripts/preprocess_clevr_with_masks.py script, which takes approximately 2 hours to execute depending on your machine.
  3. 3D Shapes: Official Google Cloud Bucket
  4. RAVENS Robot Data: Official Train & Official Test
    • We generated a dataset similar in structure to CLEVR (with masks) but of robotic images using RAVENS. Our modified version of RAVENS used to generate the dataset is HHousen/ravens.
    • The dataset contains 85,002 images: 70,002 for training and 15K for validation/test.
  5. Sketchy: Download and process by following directions in applied-ai-lab/genesis / Download Our Processed Version
  6. ClevrTex: Download by executing download_clevrtex.sh. Our dataloader must index the entire dataset before training can begin, which can take around 2 hours. We therefore recommend downloading our pre-made index from this Google Drive folder and placing it in ./data/cache/.
  7. Tetrominoes: Original TFRecords Download / Our HDF5 PyTorch Version.
    • There are 1,000,000 samples in the dataset. However, following the Slot Attention paper, we only use the first 60K samples for training.
    • We convert the original TFRecords dataset to HDF5 for easy use with PyTorch. This was done using the data_scripts/preprocess_tetrominoes.py script, which takes approximately 2 hours to execute depending on your machine.
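The split sizes listed above can be expressed as index ranges. In this sketch, contiguous, in-order splits are an assumption for CLEVR (with masks); for Tetrominoes, the list states that the first 60K samples are used for training. split_ranges is a hypothetical helper, not a function from object_discovery/data.py.

```python
def split_ranges(num_samples, num_train, num_val):
    """Split indices 0..num_samples-1 into contiguous train/val/test ranges.

    Assumption: samples are taken in order, so the first num_train go to
    training, the next num_val to validation, and the rest to testing.
    """
    if num_train + num_val > num_samples:
        raise ValueError("split sizes exceed the number of samples")
    train = range(0, num_train)
    val = range(num_train, num_train + num_val)
    test = range(num_train + num_val, num_samples)
    return train, val, test


# CLEVR (with masks): 100,000 images split 70K / 15K / 15K.
clevr_train, clevr_val, clevr_test = split_ranges(100_000, 70_000, 15_000)

# Tetrominoes: only the first 60K of the 1,000,000 samples are used for
# training; validation/test handling is not covered by this sketch.
tetrominoes_train, _, _ = split_ranges(1_000_000, 60_000, 0)

print(len(clevr_train), len(clevr_val), len(clevr_test), len(tetrominoes_train))
```

Returning range objects keeps the split lazy; a PyTorch dataset could then be restricted with something like torch.utils.data.Subset(dataset, clevr_train).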

Logging

To log outputs to wandb, run wandb login YOUR_API_KEY and set is_logging_enabled=True in SlotAttentionParams.

If you use a dataset with ground-truth segmentation masks, the Adjusted Rand Index (ARI), a clustering similarity score, is logged for each validation loop. We ported the implementation from deepmind/multi_object_datasets to PyTorch in object_discovery/segmentation_metrics.py.
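For reference, ARI can be computed from a contingency table of (ground-truth, predicted) segment IDs. The sketch below is a standalone pure-Python version for flat per-pixel label lists; it is not the PyTorch code in object_discovery/segmentation_metrics.py, and the name adjusted_rand_index is hypothetical. Object-discovery work often scores only foreground pixels (FG-ARI); this function scores whatever pixels it is given, so background pixels would have to be filtered out beforehand.

```python
from collections import Counter
from math import comb


def adjusted_rand_index(true_labels, pred_labels):
    """Adjusted Rand Index between two flat per-pixel labelings.

    Each argument is a sequence with one segment ID per pixel. Returns 1.0
    for identical partitions (up to relabeling) and about 0.0 for random ones.
    """
    if len(true_labels) != len(pred_labels):
        raise ValueError("label sequences must have the same length")
    n = len(true_labels)
    if n < 2:
        raise ValueError("need at least two pixels")
    # Contingency table: pixel counts for each (true, predicted) ID pair.
    pair_counts = Counter(zip(true_labels, pred_labels))
    true_counts = Counter(true_labels)
    pred_counts = Counter(pred_labels)

    index = sum(comb(c, 2) for c in pair_counts.values())
    sum_true = sum(comb(c, 2) for c in true_counts.values())
    sum_pred = sum(comb(c, 2) for c in pred_counts.values())
    expected = sum_true * sum_pred / comb(n, 2)
    max_index = (sum_true + sum_pred) / 2
    if max_index == expected:
        # Degenerate case: both labelings put all pixels in one segment, or
        # both put every pixel in its own segment. The partitions are then
        # identical, so return a perfect score.
        return 1.0
    return (index - expected) / (max_index - expected)


gt = [0, 0, 1, 1, 2, 2]
pred = [5, 5, 3, 3, 4, 4]  # same partition, different segment IDs
print(adjusted_rand_index(gt, pred))  # 1.0
```

Because ARI compares partitions rather than IDs, slot order does not matter: a model that assigns an object to slot 5 instead of slot 0 is not penalized.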

More Visualizations

[Images: Slot Attention CLEVR10 | Slot Attention Sketchy]

Visualizations (above) for a model trained on CLEVR6 predicting on CLEVR10 (with no increase in the number of slots) and a model trained and predicting on Sketchy. From left to right, the images are: original, reconstruction, raw predicted segmentation mask, processed segmentation mask, and then the slots.

[Images: Slot Attention ClevrTex6 | GNM ClevrTex6]

The Slot Attention visualization uses the same image order as the visualizations above. For GNM, the order is original, reconstruction, ground-truth segmentation mask, and predicted segmentation mask (repeated 4 times).

[Images: SLATE CLEVR6 | GNM CLEVR6]

For SLATE, the image order is original, dVAE reconstruction, autoregressive reconstruction, and then the pixels each slot attends to.

References

  1. untitled-ai/slot_attention: An unofficial implementation of Slot Attention from which this repo was forked.
  2. Slot Attention: Official Code / "Object-Centric Learning with Slot Attention".
  3. SLATE: Official Code / "Illiterate DALL-E Learns to Compose".
  4. IODINE: Official Code / "Multi-Object Representation Learning with Iterative Variational Inference". The Slot Attention paper frequently compares against IODINE, and the IODINE code was helpful in creating this repo.
  5. Multi-Object Datasets: deepmind/multi_object_datasets. This is the original source of the CLEVR (with masks) dataset.
  6. Implicit Slot Attention: "Object Representations as Fixed Points: Training Iterative Refinement Algorithms with Implicit Differentiation". This paper describes a one-line change that improves the optimization of Slot Attention while giving backpropagation constant space and time complexity with respect to the number of refinement iterations.
