Code Availability Notice

We are currently working on making SPINN available under a license agreement. If you are interested, please contact the authors.

SPINN - Scalable Parallel INference Network

SPINN (Scalable Parallel INference Network) is a framework designed to perform ML model inference at large scale. It supports loading ML models from different frameworks (PyTorch, TensorFlow, and ONNX) on computational clusters, subdividing the input data (a large data volume) into smaller chunks, performing the model inference on the chunks, and then reconstructing the output volume. Since overlapping patch grids can cover the same data region, it also provides three combination strategies: weighted average, hard voting, and soft voting.

The engines that can be used to run the inference in a parallel and distributed manner are Ray and Dask. For each engine, a cluster can be created or connected to in three ways: SLURM, Local, or Running.

[Figure: Patching Operation]

[Figure: Sliding Window Operation]

Config Files

Config files are how inference tasks are configured. They can be JSON files passed to the script to be read, or, when working from a Python script, a Python dictionary.

Base Config

The base config holds configurations shared by Dask and Ray executions. The fields that can be set are as follows.

Required Fields

  • cluster: which cluster type to use. Acceptable values: local (single node cluster instantiation), slurm (single or multi-node instantiation inside a SLURM job), and running (connect to running cluster).
  • combine: which combination strategy to use. Acceptable values: weighted_avg (weighted average), soft (soft voting), and hard (hard voting).
  • inference_type: which inference strategy to use. Acceptable values: patch (Patch Inference with Offsets), sliding_window (Sliding Window Inference with Strides).
  • chunksize: list indicating the dimensions of the chunks to read from the source file, e.g., [128, 128, 128].
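
Putting the required fields together, a minimal base config (here as a Python dictionary; the values are illustrative) could look like this:

config = {
    "cluster": "local",            # single node cluster instantiation
    "combine": "weighted_avg",     # weighted average combination
    "inference_type": "patch",     # patch inference with offsets
    "chunksize": [128, 128, 128],  # chunk dimensions read from the source file
}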

Optional Fields

  • hardware: hardware to use on inference. Acceptable values: gpu and cpu_only. Defaults to gpu.
  • pad_mode: which pad mode to use. Acceptable values: constant, reflect, symmetric, and edge (see NumPy pad documentation for reference). Defaults to constant.
  • pad_value: value used as padding when pad_mode=constant. Defaults to 0.
  • pad_width: list with pad for each of the input data dimensions. If the data has 3 dimensions, [10, 10] would be an invalid pad_width, while [0, 10, 10] is valid.
  • offsets: list of offset tuples to extract overlapping patch grids, e.g., [[-32, -32, -32], [0, 16, 16]], used in Patch Inference (inference_type=patch). The provided example will extract two additional patch grids offset from the starting point of the base patch grid; base_grid_point + offset gives the starting point of each offset grid.
  • strides: strides for each dimension, e.g., [10, 10, 1], used in Sliding Window Inference (inference_type=sliding_window).
  • weight_function: function that computes a weight distribution for a given patch shape. See spinn.core.weights for reference. The function must be passed as a list. If the function is implemented in spinn.core.weights, the list contains only the name of the function, e.g., ["weight_gaussian"], which then uses spinn.core.weights.weight_gaussian to generate the weight distribution. For a user-defined weight distribution, the path to the Python script with the implementation and the function name must be provided, e.g., ["path/to/my/python/script.py", "my_weights"], which loads the my_weights function from the file path/to/my/python/script.py (see the sketch after this list).
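
To illustrate the user-defined case, below is a minimal sketch of such a script. The exact signature SPINN expects is an assumption here: we assume the function receives the patch shape and must return a NumPy array of per-voxel weights with that shape (see spinn.core.weights for the reference implementations).

# script.py -- hypothetical user-defined weight function, referenced in the
# config as ["path/to/my/python/script.py", "my_weights"].
import numpy as np

def my_weights(patch_shape):
    # Weigh the patch center more than its borders with a linear falloff,
    # so overlapping patches blend smoothly under weighted averaging.
    weights = np.ones(patch_shape, dtype=np.float32)
    for dim, size in enumerate(patch_shape):
        profile = 1.0 - np.abs(np.linspace(-1.0, 1.0, num=size))
        profile = np.maximum(profile, 0.05)  # keep border weights nonzero
        shape = [1] * len(patch_shape)
        shape[dim] = size
        weights *= profile.reshape(shape)
    return weights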

Dask Config

Dask config is an optional field (dask_config) at the root of the dictionary, alongside the base config fields. It is used to set Dask cluster configurations.

Required Fields

  • address: cluster scheduler address, required only when cluster=running.

Optional Fields

  • protocol: communication protocol between workers and scheduler. Acceptable values: tcp and ucx. Defaults to tcp.
  • interface: network interface to use. Acceptable values are strings, such as ib0. Defaults to None, which will use the default network interface of the system.
  • scheduler_port: Dask scheduler port. Acceptable values are ints. Defaults to 8786.
  • n_workers: number of Dask workers to instantiate per node on the cluster. Acceptable values are ints. If not set, defaults to 1 worker per GPU when hardware=gpu, and 1 worker per node when hardware=cpu_only.
  • threads_per_worker: number of threads per worker. Acceptable values are ints. If not set, Dask's default behavior is used.
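
As a sketch, connecting to an already-running Dask scheduler combines cluster=running with the required address field (the address below is hypothetical):

config = {
    "cluster": "running",
    # ... base config fields (combine, inference_type, chunksize, ...) ...
    "dask_config": {
        "address": "tcp://10.0.0.1:8786"  # hypothetical scheduler address
    }
}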

Ray Config

Ray config is an optional field (ray_config) at the root of the dictionary, alongside the base config fields. It is used to set Ray cluster configurations.

Required Fields

  • address: Ray head node address, required only when cluster=running.
  • port: Ray head node connection port, required only when cluster=running.

Optional Fields

  • port: Ray connection port, optional when cluster!=running. Acceptable values are ints. Defaults to 6379.
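
For example, a SLURM-launched Ray cluster using a non-default port could be configured as follows (the port value is illustrative):

config = {
    "cluster": "slurm",
    # ... base config fields ...
    "ray_config": {
        "port": 6380  # optional override of the default 6379
    }
}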

Model Config

Model config is a required field (model_config) at the root of the dictionary, alongside the base config fields. It is used to correctly load and instantiate the ML model across the cluster. It has a set of required and optional fields, as well as runtime-specific fields.

Required Fields

  • runtime: runtime to use in model loading and inference. Supported values are torch, tensorflow, and onnx.
  • model_input_shape: expected input shape of the model, e.g., [1, 128, 128].
  • dtype: data type of the model input. Supported values are float16, float32, float64, int8, int16, int32, and int64.
  • max_batch_size: limits the number of data items used in inference at a time. Especially important in GPU inference, to avoid out-of-memory problems. Acceptable values are ints.

Optional Fields

  • device: device to use for inference. Supported values are cpu and cuda. Defaults to cpu.
  • model_output_shape: output shape of the model. Defaults to model_input_shape value.
  • processing_steps: path to a Python script containing pre/post-processing steps. Functions must be named preprocessing and postprocessing to be called outside the ModelWrapper call method, or preprocessing_protected and postprocessing_protected to be called inside it (at which point the data has already been converted and is resident on the device used for inference). See the sketch after this list.
  • lost_dim: used for 2.5D inference that consumes a neighboring region extracted from the volume. The model input size for the dimension marked as lost must be an odd number (region of interest + 2 * considered neighborhood), and pad_width must match the size of the neighborhood. For example, for a model that takes (128, 5, 128) as input and produces (128, 1, 128), lost_dim is 1 and the pad_width for dimension 1 must be 2, so a valid pad_width would be (10, 2, 10).
  • join_dim: used for multiple-input inference. The model input shape must be specified as 1 on the dimension that will be used to concatenate the inputs. For example, with 4 inputs (A, B, C, D), a model_input_shape of (1, 128, 128), and join_dim=1, the data received by the model will have shape (N, 4, 128, 128), where N is the number of patches in the batch. Note that the inputs are concatenated in the order they were provided, so data[:,0,:,:] is from input A, data[:,1,:,:] from B, and so on.
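
As an illustration of processing_steps, here is a minimal sketch of such a script. The exact function signatures SPINN expects are an assumption: we assume each function receives a batch and returns the transformed batch.

# steps.py -- hypothetical pre/post-processing script.
import numpy as np

def preprocessing(batch):
    # Called outside the ModelWrapper call method, before device transfer:
    # normalize the input batch to zero mean and unit variance.
    return (batch - batch.mean()) / (batch.std() + 1e-8)

def postprocessing(batch):
    # Called outside the ModelWrapper call method, after inference:
    # clip the model outputs to a valid range.
    return np.clip(batch, 0.0, 1.0)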

Runtime Specific Fields: Torch

  • type: type of the Torch model being loaded. Supported types are class, file, and torchscript.
  • class_file: path to Python file containing model class definition. Required when type=class.
  • class_name: name of the model class. Required when type=class.
  • kwargs_file: path to Python file containing a dictionary named kwargs to be used in model class instantiation. Optional when type=class.
  • checkpoint: path to model checkpoint. Optional when type=class.
  • model_file: path to serialized model file. Required when type=file.
  • ts_file: path to TorchScript serialized model file. Required when type=torchscript.
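
Putting the Torch fields together, a hypothetical model_config for a model loaded from a class definition might look like this (all paths and names are illustrative):

model_config = {
    "runtime": "torch",
    "type": "class",
    "class_file": "models/unet.py",     # hypothetical file with the class definition
    "class_name": "UNet",               # hypothetical class name
    "kwargs_file": "models/kwargs.py",  # optional: defines a dict named kwargs
    "checkpoint": "weights/unet.ckpt",  # optional: model checkpoint
    "model_input_shape": [1, 128, 128],
    "dtype": "float32",
    "max_batch_size": 64,
    "device": "cuda"
}

The tensorflow runtime accepts the same class-related fields (plus type=json with json_file), while the onnx runtime only supports type=file with model_file.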

Runtime Specific Fields: TensorFlow

  • type: type of the TensorFlow model being loaded. Supported types are class, file, and json.
  • class_file: path to Python file containing model class definition. Required when type=class.
  • class_name: name of the model class. Required when type=class.
  • kwargs_file: path to Python file containing a dictionary named kwargs to be used in model class instantiation. Optional when type=class.
  • checkpoint: path to model checkpoint. Optional when type=class or type=json.
  • json_file: path to JSON-serialized model file. Required when type=json.
  • model_file: path to serialized model file. Required when type=file.

Runtime Specific Fields: ONNX

  • type: type of the ONNX model being loaded. Supported type is file.
  • model_file: path to serialized model file. Required when type=file.

Complete Example

{
  "cluster": "slurm",
  "hardware": "gpu",
  "combine": "soft",
  "chunksize": [3, 384, 384],
  "pad_width": [0, 128, 128],
  "dask_config": {
    "n_workers": 4,
    "threads_per_worker": 10
  },
  "ray_config": {},
  "model_config": {
    "runtime": "torch",
    "processing_steps": "steps.py",
    "max_batch_size": 256,
    "model_input_shape": [1, 128, 128],
    "model_output_shape": [21, 1, 128, 128],
    "dtype": "float32",
    "device": "cuda",
    "type": "torchscript",
    "ts_file": "ts_model.pt"
  }
}

Running an inference

The file spinn/inference.py contains a CLI utility that can be used to run inferences. It can be invoked with the spinn command. A possible usage would be:

spinn -e [dask|ray] -i <INPUT_PATH> -o <OUTPUT_PATH> -c <CONFIG_PATH>

TensorFlow logs a lot of information by default. To disable that, set the environment variable TF_CPP_MIN_LOG_LEVEL to 3 (export TF_CPP_MIN_LOG_LEVEL=3).

Running using SLURM's sbatch

When running experiments as SLURM jobs, the mode cluster=slurm is recommended. SPINN will automatically create the Dask/Ray cluster using all the nodes available for the job.

#!/usr/bin/env bash

#SBATCH -J spinn_job
#SBATCH -p <SLURM_PARTITION>
#SBATCH -A <SLURM_ACCOUNT>
#SBATCH --nodes=4
#SBATCH --ntasks=4
#SBATCH -t 2:00:00

srun --ntasks=4 spinn -e [dask|ray] -i <INPUT_PATH> -o <OUTPUT_PATH> -c <CONFIG_PATH>

Docker Base Images

For this SPINN version, we recommend the following base images:

  • python:3.10-slim-bullseye
  • nvcr.io/nvidia/pytorch:24.10-py3
