We are currently working on making SPINN available under a license agreement. If you are interested, please contact the authors.
- @SerodioJ - [email protected]
- @eborin - [email protected]
SPINN (Scalable Parallel INference Network) is a framework for running ML model inference at large scale.
It loads ML models from different frameworks (PyTorch, TensorFlow, and ONNX) on computational clusters, subdivides the input data (a large data volume) into smaller chunks, runs model inference on the chunks, and then reconstructs the output volume.
Because it allows overlapping patch grids over the same data region, it also provides 3 combination strategies: weighted average, hard voting, and soft voting.
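
To make the three combination strategies concrete, here is a minimal NumPy sketch of their textbook definitions, applied to predictions that several overlapping grids produced for the same voxels. The function names and array layouts are assumptions made for illustration; this is not SPINN's internal implementation, which may differ (for example, in what the voting strategies return).

```python
import numpy as np


def combine_weighted_avg(preds, weights):
    """Weighted average over G overlapping grids; preds and weights are (G, ...)."""
    preds = np.asarray(preds, dtype=np.float64)
    weights = np.asarray(weights, dtype=np.float64)
    total = weights.sum(axis=0)
    # Leave voxels that no grid covers (total weight 0) at 0 instead of dividing by 0.
    return np.divide((preds * weights).sum(axis=0), total,
                     out=np.zeros_like(total), where=total > 0)


def combine_soft_voting(probs):
    """Average class probabilities (G, C, ...) over grids, then pick the best class."""
    return np.asarray(probs).mean(axis=0).argmax(axis=0)


def combine_hard_voting(probs):
    """Each grid votes for its own argmax class; the most voted class wins."""
    probs = np.asarray(probs)
    votes = probs.argmax(axis=1)  # (G, ...)
    counts = np.stack([(votes == c).sum(axis=0) for c in range(probs.shape[1])])
    return counts.argmax(axis=0)


# Three grids, two classes, one voxel: one confident vote for class 1,
# two weak votes for class 0.
probs = np.array([[[0.1], [0.9]], [[0.6], [0.4]], [[0.6], [0.4]]])
soft = combine_soft_voting(probs)   # averages probabilities first
hard = combine_hard_voting(probs)   # counts per-grid decisions
```

In this example the two voting strategies disagree: soft voting lets the single confident grid outweigh the two uncertain ones, while hard voting follows the majority.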
The engines that can run the inference in a parallel and distributed manner are Ray and Dask.
The cluster for each of these engines can be created or connected to in 3 ways: SLURM, Local, or Running.
Inference tasks are configured through config files. A config can be a JSON file that is passed to the script or, when working in a Python script, a Python dictionary.
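
Since a config can live either in a JSON file or in a Python dictionary, a small helper can normalize both forms. The sketch below is only an illustration: `load_config` is a hypothetical helper, not part of SPINN, and the field values are copied from the example config in this README.

```python
import json


def load_config(source):
    """Return a config dict from either a dict or a path to a JSON file.

    Hypothetical helper for illustration; not a SPINN function.
    """
    if isinstance(source, dict):
        return source
    with open(source, "r", encoding="utf-8") as fh:
        return json.load(fh)


config = {
    "cluster": "slurm",
    "hardware": "gpu",
    "combine": "soft",
    "chunksize": [3, 384, 384],
    "pad_width": [0, 128, 128],
    "model_config": {
        "runtime": "torch",
        "type": "torchscript",
        "ts_file": "ts_model.pt",
        "model_input_shape": [1, 128, 128],
        "dtype": "float32",
        "max_batch_size": 256,
    },
}

# The same content serialized as JSON, ready to be written to a config file.
config_json = json.dumps(config, indent=2)
```

Because the config uses only lists, strings, numbers, and nested dictionaries, the dictionary and JSON forms round-trip without loss.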
The base config holds the settings shared by Dask and Ray executions. The following fields can be set:
- cluster: which cluster type to use. Acceptable values: local (single node cluster instantiation), slurm (single or multi-node instantiation inside a SLURM job), and running (connect to running cluster).
- combine: which combination strategy to use. Acceptable values: weighted_avg (weighted average), soft (soft voting), and hard (hard voting).
- inference_type: which inference strategy to use. Acceptable values: patch (Patch Inference with Offsets) and sliding_window (Sliding Window Inference with Strides).
- chunksize: list indicating the dimensions of the chunks to read from the source file, e.g., [128, 128, 128].
- hardware: hardware to use on inference. Acceptable values: gpu and cpu_only. Defaults to gpu.
- pad_mode: which pad mode to use. Acceptable values: constant, reflect, symmetric, and edge (see NumPy pad documentation for reference). Defaults to constant.
- pad_value: value used as padding when pad_mode=constant. Defaults to 0.
- pad_width: list with the pad for each of the input data dimensions. If the data has 3 dimensions, [10, 10] would be an invalid pad_width, while [0, 10, 10] is valid.
- offsets: list of offset tuples to extract overlapping patch grids, e.g., [[-32, -32, -32], [0, 16, 16]], used in Patch Inference (inference_type=patch). The provided example will extract two additional patch grids offset from the starting point of the base patch grid; base_grid_point + offset is used to get the starting point of the offset grid.
- strides: strides for each dimension, e.g., [10, 10, 1], used in Sliding Window Inference (inference_type=sliding_window).
- weight_function: function that computes a weight distribution for a given patch shape. See spinn.core.weights for reference. The function must be passed as a list. If the function is implemented in spinn.core.weights, the list contains only the name of the function, e.g., ["weight_gaussian"], which then uses spinn.core.weights.weight_gaussian to generate the weight distribution. For a user-defined weight distribution, the path to the Python script with the implementation and the function name must be provided, e.g., ["path/to/my/python/script.py", "my_weights"], which then loads the my_weights function from the file path/to/my/python/script.py.
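
The pad_width and offsets rules above can be illustrated with a short NumPy sketch. It assumes that each pad_width entry is applied symmetrically (before and after) along its dimension, which this README does not state explicitly, and both helper names are hypothetical rather than SPINN functions.

```python
import numpy as np


def pad_volume(volume, pad_width, pad_mode="constant", pad_value=0):
    """Pad a volume with one pad_width entry per dimension.

    Assumption: each entry pads both sides of its dimension equally.
    """
    if len(pad_width) != volume.ndim:
        raise ValueError(
            f"pad_width needs {volume.ndim} entries, got {len(pad_width)}"
        )
    pairs = [(p, p) for p in pad_width]
    if pad_mode == "constant":
        return np.pad(volume, pairs, mode="constant", constant_values=pad_value)
    return np.pad(volume, pairs, mode=pad_mode)


def offset_grid_origins(base_grid_point, offsets):
    """Starting point of each offset grid: base_grid_point + offset."""
    return [tuple(b + o for b, o in zip(base_grid_point, off)) for off in offsets]


volume = np.ones((4, 8, 8), dtype=np.float32)
padded = pad_volume(volume, [0, 10, 10])  # valid: one entry per dimension
origins = offset_grid_origins((32, 32, 32), [[-32, -32, -32], [0, 16, 16]])
```

With a base grid starting at (32, 32, 32), the two offsets from the example place the additional grids at (0, 0, 0) and (32, 48, 48). Passing [10, 10] for a 3D volume raises an error, matching the invalid case described above.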
Dask config is an optional field that will be at the root of the dictionary, alongside the base config fields. It is used to set Dask cluster configurations.
- address: cluster scheduler address, required only when cluster=running.
- protocol: communication protocol between workers and scheduler. Acceptable values: tcp and ucx. Defaults to tcp.
- interface: network interface to use. Acceptable values are strings, such as ib0. Defaults to None, which will use the default network interface of the system.
- scheduler_port: Dask scheduler port. Acceptable values are ints. Defaults to 8786.
- n_workers: number of Dask workers to instantiate per node on the cluster. Acceptable values are ints. If not set, defaults to 1 worker per GPU when hardware=gpu, and 1 worker per node when hardware=cpu_only.
- threads_per_worker: number of threads per worker. Acceptable values are ints. If not set, defaults to Dask default behavior.
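
As a rough illustration of how these defaults fit together, the sketch below builds a scheduler address in the usual Dask `protocol://host:port` form and resolves the default number of workers per node. Both helpers and the host value are hypothetical; they mirror the rules listed above and are not SPINN's own code.

```python
def scheduler_address(host, protocol="tcp", scheduler_port=8786):
    """Compose a Dask-style scheduler address from the config defaults."""
    if protocol not in ("tcp", "ucx"):
        raise ValueError(f"unsupported protocol: {protocol}")
    return f"{protocol}://{host}:{scheduler_port}"


def default_n_workers(hardware="gpu", gpus_per_node=1):
    """Workers per node when n_workers is not set in dask_config."""
    if hardware == "gpu":
        return gpus_per_node  # 1 worker per GPU
    if hardware == "cpu_only":
        return 1              # 1 worker per node
    raise ValueError(f"unsupported hardware: {hardware}")


# Hypothetical node with 4 GPUs and a scheduler reachable at 10.0.0.1.
address = scheduler_address("10.0.0.1")
workers = default_n_workers("gpu", gpus_per_node=4)
```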
Ray config is an optional field that will be at the root of the dictionary, alongside the base config fields. It is used to set Ray cluster configurations.
- address: Ray head node address, required only when cluster=running.
- port: Ray head node connection port. Required when cluster=running; optional otherwise, in which case it defaults to 6379. Acceptable values are ints.
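
The required/optional rules for the Ray fields can be expressed as a small check. This is a sketch under the rules stated above; `resolve_ray_config` is a hypothetical helper and the address value is made up.

```python
def resolve_ray_config(cluster, ray_config=None):
    """Validate ray_config against the cluster mode and fill in defaults."""
    cfg = dict(ray_config or {})
    if cluster == "running":
        # Connecting to an existing cluster needs both address and port.
        missing = [key for key in ("address", "port") if key not in cfg]
        if missing:
            raise ValueError(f"cluster=running requires: {', '.join(missing)}")
    else:
        # For local and slurm clusters the port is optional.
        cfg.setdefault("port", 6379)
    return cfg


local_cfg = resolve_ray_config("local")
running_cfg = resolve_ray_config("running", {"address": "10.0.0.1", "port": 6379})
```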
Model config is a required field that will be at the root of the dictionary, alongside the base config fields. It is used to correctly load and instantiate the ML model across the cluster. It has a set of required and optional fields, but also runtime specific fields.
- runtime: runtime to use for model loading and inference. Supported values are torch, tensorflow, and onnx.
- model_input_shape: expected input shape of the model, e.g., [1, 128, 128].
- dtype: data type of the model input. Supported values are float16, float32, float64, int8, int16, int32, and int64.
- max_batch_size: limits the number of data items used for inference at a time. Especially important in GPU inference, to avoid out-of-memory problems. Acceptable values are ints.
- device: device to use for inference. Supported values are cpu and cuda. Defaults to cpu.
- model_output_shape: output shape of the model. Defaults to the model_input_shape value.
- processing_steps: path to a Python script containing pre/post-processing steps. Functions must be named preprocessing and postprocessing to be called outside the ModelWrapper call method, or preprocessing_protected and postprocessing_protected to be called inside the ModelWrapper call method (at that point the data has already been converted and is resident on the device used for inference).
- lost_dim: used for 2.5D inference, which uses a neighboring region extracted from the volume. The model input size for the dimension marked as lost must be an odd number (Region of Interest + 2 * considered neighborhood), and pad_width must match the size of the neighborhood to be considered. For example, for a model that takes (128, 5, 128) as input and produces (128, 1, 128), lost_dim is 1 and the pad_width for dimension 1 needs to be 2, so a valid pad_width would be (10, 2, 10).
- join_dim: used for inference with multiple inputs. The model input shape needs to be specified as 1 on the dimension that will be used to concatenate the inputs. For example, when working with 4 inputs (A, B, C, D), a model_input_shape of (1, 128, 128), and join_dim=1, the data received by the model will have shape (N, 4, 128, 128), where N is the number of patches in the batch. The inputs are concatenated in the order they were provided, so data[:,0,:,:] is from input A, data[:,1,:,:] from B, and so on.
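
The shape arithmetic behind lost_dim and join_dim can be checked with a few lines of NumPy. This is a sketch built from the two examples above: the helper names are hypothetical, and it assumes join_dim indexes the batched array (batch axis first), which is how the (N, 4, 128, 128) example reads.

```python
import numpy as np


def neighborhood_for_lost_dim(model_input_shape, lost_dim):
    """Neighborhood size implied by an odd input size: (size - 1) // 2."""
    size = model_input_shape[lost_dim]
    if size % 2 == 0:
        raise ValueError("the model input size on lost_dim must be odd")
    return (size - 1) // 2


def join_inputs(batches, join_dim):
    """Concatenate per-input batches along join_dim, preserving input order."""
    return np.concatenate(batches, axis=join_dim)


# 2.5D example: input (128, 5, 128) -> 1 slice of interest + 2 neighbors per side.
pad_for_dim_1 = neighborhood_for_lost_dim((128, 5, 128), lost_dim=1)

# Multi-input example: inputs A, B, C, D, each a batch of N=2 patches with
# model_input_shape (1, 128, 128); every input is filled with its own index.
batches = [np.full((2, 1, 128, 128), i, dtype=np.float32) for i in range(4)]
data = join_inputs(batches, join_dim=1)
```

Here `pad_for_dim_1` is 2, matching the pad_width entry required for dimension 1, and `data` has shape (2, 4, 128, 128) with `data[:, 1, :, :]` holding the values of the second input.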
Fields specific to runtime=torch:
- type: type of the Torch model being loaded. Supported types are class, file, and torchscript.
- class_file: path to Python file containing model class definition. Required when type=class.
- class_name: name of the model class. Required when type=class.
- kwargs_file: path to Python file containing a dictionary named kwargs to be used in model class instantiation. Optional when type=class.
- checkpoint: path to model checkpoint. Optional when type=class.
- model_file: path to serialized model file. Required when type=file.
- ts_file: path to TorchScript serialized model file. Required when type=torchscript.

Fields specific to runtime=tensorflow:
- type: type of the TensorFlow model being loaded. Supported types are class, file, and json.
- class_file: path to Python file containing model class definition. Required when type=class.
- class_name: name of the model class. Required when type=class.
- kwargs_file: path to Python file containing a dictionary named kwargs to be used in model class instantiation. Optional when type=class.
- checkpoint: path to model checkpoint. Optional when type=class or type=json.
- json_file: path to JSON serialized model file. Required when type=json.
- model_file: path to serialized model file. Required when type=file.

Fields specific to runtime=onnx:
- type: type of the ONNX model being loaded. The only supported type is file.
- model_file: path to serialized model file. Required when type=file.
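
The required fields per runtime and type listed above can be collected into a lookup table, which makes it easy to check a model config before submitting a job. The table below is transcribed from this README; the `missing_model_fields` helper is hypothetical and not part of SPINN.

```python
# Required runtime-specific fields, keyed by runtime and then by type.
REQUIRED_FIELDS = {
    "torch": {
        "class": ["class_file", "class_name"],
        "file": ["model_file"],
        "torchscript": ["ts_file"],
    },
    "tensorflow": {
        "class": ["class_file", "class_name"],
        "file": ["model_file"],
        "json": ["json_file"],
    },
    "onnx": {"file": ["model_file"]},
}


def missing_model_fields(model_config):
    """Return the required runtime-specific fields absent from model_config."""
    runtime = model_config.get("runtime")
    model_type = model_config.get("type")
    if runtime not in REQUIRED_FIELDS:
        raise ValueError(f"unsupported runtime: {runtime}")
    if model_type not in REQUIRED_FIELDS[runtime]:
        raise ValueError(f"unsupported type for {runtime}: {model_type}")
    return [f for f in REQUIRED_FIELDS[runtime][model_type] if f not in model_config]


ok = missing_model_fields(
    {"runtime": "torch", "type": "torchscript", "ts_file": "ts_model.pt"}
)
incomplete = missing_model_fields({"runtime": "tensorflow", "type": "json"})
```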
{
  "cluster": "slurm",
  "hardware": "gpu",
  "combine": "soft",
  "chunksize": [3, 384, 384],
  "pad_width": [0, 128, 128],
  "dask_config": {
    "n_workers": 4,
    "threads_per_worker": 10
  },
  "ray_config": {},
  "model_config": {
    "runtime": "torch",
    "processing_steps": "steps.py",
    "max_batch_size": 256,
    "model_input_shape": [1, 128, 128],
    "model_output_shape": [21, 1, 128, 128],
    "dtype": "float32",
    "device": "cuda",
    "type": "torchscript",
    "ts_file": "ts_model.pt"
  }
}

The file spinn/inference.py contains a CLI utility for running inference. It can be invoked with the spinn command.
A possible usage would be:

spinn -e [dask|ray] -i <INPUT_PATH> -o <OUTPUT_PATH> -c <CONFIG_PATH>

TensorFlow logs a lot of information by default. To disable that, set the environment variable TF_CPP_MIN_LOG_LEVEL to 3 (export TF_CPP_MIN_LOG_LEVEL=3).
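
When launching the CLI from a Python script, the command and the TensorFlow logging variable described above can be assembled as follows. This is a minimal sketch: the helper names and the file paths are placeholders chosen for illustration, and only the flags shown in the usage line are used.

```python
import os


def build_spinn_command(engine, input_path, output_path, config_path):
    """Argument list for the spinn CLI, using the flags from the usage line."""
    if engine not in ("dask", "ray"):
        raise ValueError("engine must be 'dask' or 'ray'")
    return ["spinn", "-e", engine, "-i", input_path, "-o", output_path,
            "-c", config_path]


def quiet_tensorflow_env(base_env=None):
    """Copy of the environment with TensorFlow logging reduced."""
    env = dict(os.environ if base_env is None else base_env)
    env["TF_CPP_MIN_LOG_LEVEL"] = "3"
    return env


# Placeholder paths; replace them with real input, output, and config files.
cmd = build_spinn_command("dask", "input_volume", "output_volume", "config.json")
env = quiet_tensorflow_env()

# To actually launch the run (requires spinn to be installed):
# import subprocess
# subprocess.run(cmd, env=env, check=True)
```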
When running experiments as SLURM jobs, the mode cluster=slurm is recommended.
SPINN will automatically create the Dask/Ray cluster using all the nodes available for the job.
#!/usr/bin/env bash
#SBATCH -J spinn_job
#SBATCH -p <SLURM_PARTITION>
#SBATCH -A <SLURM_ACCOUNT>
#SBATCH --nodes=4
#SBATCH --ntasks=4
#SBATCH -t 2:00:00
srun --ntasks=4 spinn -e [dask|ray] -i <INPUT_PATH> -o <OUTPUT_PATH> -c <CONFIG_PATH>

For this SPINN version, we recommend the following base images:
- python:3.10-slim-bullseye
- nvcr.io/nvidia/pytorch:24.10-py3

