This repository has been archived by the owner on Jul 29, 2023. It is now read-only.

Fix model architecture for deployment to ONNX (#234)
* added primitive xy scaling feature to inference

* fixed opset mismatches

* changes to model + scripts for exporting to onnx

* fixed dropout incompatibility, moved logger to utils.

* removed inference session from model export script

* use builtins to normalize

* caching argument
not doing anything yet

* profiling script

* separate augmentation

* Revert "separate augmentation"

This reverts commit c30ed61.

* remove unused import

* let zarr auto-detect multi-processing

* remove unused import

* copy zarr store to memory

* allow cache and preload

* update profiling script

* format

* remove preload
this is not practical for larger datasets

* cleanup

* configurable number of log samples

* split train and validation dataset at fov level

* fix shuffling

* fix dropout layers initialization

* fix filter hyperparameter check

* disable dropout for the head

* stronger default augmentation

* formalized onnx exporting and moved script to CLI, updated documentation in inference readme

* fix merge conflict

* isort

* combine training CLI with others

* Multi-GPU training (#235)

* sync log metrics

* example of using more GPUs

* sync log metrics

* example of using more GPUs

* revised data format

* updated data org

* remove profile output
this was accidentally tracked

---------

Co-authored-by: Christian Foley <[email protected]>
Co-authored-by: Christian Foley <[email protected]>
Co-authored-by: Shalin Mehta <[email protected]>
4 people committed May 29, 2023
1 parent 7d81270 commit 00375b0
Showing 13 changed files with 454 additions and 310 deletions.
@@ -2,7 +2,7 @@
seed_everything: true
trainer:
accelerator: auto
strategy: auto
strategy: auto # ddp_find_unused_parameters_true for more GPUs
devices: auto
num_nodes: 1
precision: 32-true
@@ -36,7 +36,7 @@ trainer:
detect_anomaly: false
barebones: false
plugins: null
sync_batchnorm: false
sync_batchnorm: true
reload_dataloaders_every_n_epochs: 0
default_root_dir: null
model:
178 changes: 178 additions & 0 deletions micro_dl/cli/onnx_export_script.py
@@ -0,0 +1,178 @@
import argparse
import numpy as np
import os
import onnxruntime as ort
import onnx
import pathlib
import sys
import torch.onnx as torch_onnx
import torch

sys.path.insert(0, "/home/christian.foley/virtual_staining/workspaces/microDL")
import micro_dl.inference.inference as inference

def parse_args():
"""
Parse command line arguments
In python namespaces are implemented as dictionaries
:return: namespace containing the arguments passed.
"""
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_path",
type=str,
help="path to yaml configuration file",
)
parser.add_argument(
"--stack_depth",
required=True,
help="Stack depth of model. If model is 2D, use 1."
)
parser.add_argument(
"--export_path",
required=True,
help="Path to store exported model"
)
parser.add_argument(
"--test_input",
required=False,
default=None,
help="Path to .npy test input for additional model validation after export."
)
args = parser.parse_args()
return args


def validate_export(model_path) -> None:
"""
Run ONNX validation on an exported model to ensure the export succeeded.
:param str model_path: path to exported onnx model
"""
print("Validating model...", end='')
onnx_model = onnx.load(model_path)
try:
onnx.checker.check_model(onnx_model)
print("Passed!")
except Exception as e:
print("Failed:")
print("\t", e)
sys.exit()


def remove_initializer_from_input(model_path):
"""
Removes initializers from the graph inputs of the model at model_path, overwriting it in place.
:param str model_path: path to the model whose graph inputs should be de-initialized
"""
model = onnx.load(model_path)
if model.ir_version < 4:
print(
"Model with ir_version below 4 requires to include initilizer in graph input"
)
return

inputs = model.graph.input
name_to_input = {}
for input in inputs:
name_to_input[input.name] = input

for initializer in model.graph.initializer:
if initializer.name in name_to_input:
inputs.remove(name_to_input[initializer.name])

onnx.save(model, model_path)


def export_model(model_dir, model_name, stack_depth, export_path) -> None:
"""
Export a model to ONNX. Due to restrictions in the pytorch-onnx opset conversion,
the opset for 2.5D and 2D UNets is limited to version 10, without dropout.
:param str model_dir: path to model directory
:param str model_name: name of model in directory
:param int stack_depth: input stack depth of the model (1 for 2D models)
:param str export_path: intended path for exported model
"""
print("Initializing model in pytorch...")
torch_predictor = inference.TorchPredictor(
config={"model_dir": model_dir, "model_name": model_name},
device="cpu",
single_prediction=True,
)
torch_predictor.load_model()
model = torch_predictor.model
model.eval()

if stack_depth == 1:
sample_input = np.random.rand(1, 1, 512, 512)
else:
sample_input = np.random.rand(1, 1, stack_depth, 512, 512)
input_tensor = torch.tensor(sample_input.astype(np.float32), requires_grad=True)
print("Done!")

# Export the model
print("Exporting model to onnx...", end="")
torch_onnx.export(
model,
input_tensor,
export_path,
export_params=True,
opset_version=10, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
input_names=["input"],
output_names=["output"],
dynamic_axes={
"input": {0: "batch_size", 1: "channels", 3: "num_rows", 4: "num_cols"},
"output": {0: "batch_size", 1: "channels", 3: "num_rows", 4: "num_cols"},
},
)
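# Removing initializers from the graph inputs lets ONNX Runtime treat the weights as
# constants (graph inputs are assumed overridable, which would block constant folding).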
remove_initializer_from_input(export_path)
validate_export(export_path)
print("Done!")


def infer(model_path, data_path, output_path) -> None:
"""
Run an inference session with an ONNX model. Input data is read from a .npy file
and the prediction is saved as a .npy array.
:param str model_path: path to onnx model for inference
:param str data_path: path to data for inference
:param str output_path: path to save model output to
"""

validate_export(model_path)
data = np.load(data_path)

options = ort.SessionOptions()
options.intra_op_num_threads = 1
options.inter_op_num_threads = 1

ort_sess = ort.InferenceSession(model_path, sess_options=options)
outputs = ort_sess.run(None, {"input": data})

np.save(output_path, outputs)


def main(args):
model_dir = pathlib.Path(args.model_path).parent.absolute().parent.absolute()
model_name = os.path.basename(args.model_path)

export_model(model_dir, model_name, args.stack_depth, args.export_path)

# if specified, run test with some test input numpy array
if args.test_input is not None:
print("Running inference test with ONNX model on CPU...")
test_out_dir = pathlib.Path(args.test_input).parent.absolute()
test_out_name = "test_pred_" + os.path.basename(args.test_input)
test_out_path = os.path.join(test_out_dir, test_out_name)
infer(args.export_path, args.test_input, test_out_path)
print("Done!")

if __name__ == "__main__":
args = parse_args()
main(args)

4 changes: 1 addition & 3 deletions micro_dl/cli/torch_inference_script.py
@@ -8,8 +8,6 @@

import micro_dl.inference.inference as torch_inference_utils
import micro_dl.utils.aux_utils as aux_utils
import micro_dl.utils.gpu_utils as gpu_utils


def check_gpu_availability(gpu_id):
"""
@@ -180,7 +178,7 @@ def main(config, gpu, gpu_mem_frac):

if gpu is not None:
# Get GPU ID and memory fraction
gpu_id, gpu_mem_frac = gpu_utils.select_gpu(
gpu_id, gpu_mem_frac = select_gpu(
gpu,
gpu_mem_frac,
)
File renamed without changes.
28 changes: 13 additions & 15 deletions micro_dl/data_organization.md
@@ -1,10 +1,6 @@
# Data Organization for Virtual Staining

> This advisory only applies to data management on Biohub's compute infrastructure.
> It is not normative for external users.
Here we document the conventions for storing data, metadata, configs,
and models during the development of the virtual staining pipeline.
Here we document our conventions for storing data, metadata, configs, and models.

## Data flow in the pipeline

@@ -90,19 +86,21 @@ virtual_staining:
config.yaml
yyyymmdd-hhmmss:
...
# evaluation of select models
evaluation:
# configs for evaluation: checkpoint path, test data path, ground truth path, and choice of metrics.
evaluation_01.yaml
evaluation_02.yaml
...
# Inference and/or Evaluation of selected models.
test:
# config for prediction with test dataset.
test_<suffix>.yml # config used for inference, optionally copies ground truth and input for evaluation. This config will follow the lightning CLI/config format.

# inference output on test dataset, may include copies of input and ground truth to facilitate visualization of model performance.
prediction_01.zarr
prediction_02.zarr
test_<suffix>.zarr # Not all test datasets need to have human curated ground truth.
...

# config for evaluation: checkpoint path, test data path that has ground truth included, and choice of metrics.
evaluation_<suffix>.yaml
...

# evaluation metrics
metrics_01.csv
metrics_02.csv
evaluation_metrics_<suffix>.csv
...
# (optional) tensorboard logs generated to visualize distribution of metrics or specific samples of input, prediction, ground truth.
evaluation_logs:
8 changes: 7 additions & 1 deletion micro_dl/inference/inference.py
@@ -64,13 +64,18 @@ def __init__(self, config, device, single_prediction=False) -> None:
self.input_channels = self.config["input_channels"]
self.time_indices = self.config["time_indices"]

scale_factor = 1
if "model_mag" in self.config and "source_mag" in self.config:
scale_factor = self.config["model_mag"] / self.config["source_mag"]
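# e.g. a model trained at 20x applied to 40x source data gives scale_factor = 0.5;
# the inference dataset uses this to rescale inputs in xy, and run_inference un-scales the predictions.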

self.dataset = inference_dataset.TorchInferenceDataset(
zarr_dir=self.config["zarr_dir"],
batch_pred_num=self.config["batch_size"],
normalize_inputs=self.config["normalize_inputs"],
norm_type=self.config["norm_type"],
norm_scheme=self.config["norm_scheme"],
sample_depth=self.network_z_depth,
scale_factor=scale_factor,
device=self.device,
)

@@ -160,7 +165,7 @@ def predict_image(
img_tensor = aux_utils.ToTensor(device=self.device)(input_image)

img_tensor, pads = _pad_input(img_tensor, num_blocks=model.num_blocks)
pred = model(img_tensor, validate_input=False)
pred = model(img_tensor)
return TF.crop(
pred.detach().cpu(), *(pads[1], pads[0]) + input_image.shape[-2:]
).numpy()
@@ -216,6 +221,7 @@ def run_inference(self):
for batch, z0, size, _ in dataloader:
batch_pred = self.predict_image(batch[0]) # redundant batch dim
batch_pred = np.squeeze(np.swapaxes(batch_pred, 0, 2), axis=0)
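# undo the xy scaling applied to the inputs so the prediction matches the source resolution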
batch_pred = self.dataset._scale(batch_pred, unscale=True)
output_array[time_idx, :, z0 : z0 + size, ...] = batch_pred

# write config to save dir
26 changes: 21 additions & 5 deletions micro_dl/inference/readme.md
@@ -1,11 +1,13 @@
## Inference
# Inference

The main command for inference is:

```buildoutcfg
python micro_dl/cli/torch_inference_script.py --config <config path (.yml)> --gpu <gpu id (default 0)> --gpu_mem_frac <0-1 (default 1)>
```

where the parameters are defined as follows:

* **config** (yaml file): Configuration file, see below.
* **gpu** (int): ID number of the GPU you'd like to run on. If you don't
specify a GPU, the one with the largest amount of available memory will be selected for you.
@@ -14,7 +16,7 @@ If there's not enough memory available on the GPU, an AssertionError will be raised.
If memory fraction is unspecified, all memory currently available on the GPU will automatically
be allocated for you.

# Config
## Config

> **zarr_dir**: `absolute path` (absolute path to HCS-compatible zarr store containing data)
>
@@ -42,13 +44,15 @@ be allocated for you.
>
> ***custom_save_preds_dir:*** `absolute path` (Path to a custom save directory. Generally avoid using this, since it separates model predictions from their models)
# Config Example
## Config Example

Some working config examples can be found at:

```buildoutcfg
/hpc/projects/CompMicro/projects/virtualstaining/torch_microDL/config_files/2022_HEK_nuc_mem_Soorya/TestData_HEK_2022_04_16/
```

# Single Sample Prediction
## Single Sample Prediction
Sometimes inference needs to be run on individual samples. This can be done using the `predict_image` method of the `TorchPredictor` object.

Initializing the `TorchPredictor` object for this task requires a config dictionary specifying the model to use for prediction:
@@ -83,4 +87,16 @@ sample_prediction = torch_predictor.predict_image(sample_input)

```

#TODO: evaluation script and config
## Exporting models to onnx

If you wish to run inference with the ONNX Runtime, models can be exported to ONNX using `micro_dl/cli/onnx_export_script.py`. See below for an example usage of this script with a model that takes a 5-slice input stack:

```bash
python micro_dl/cli/onnx_export_script.py --model_path path/to/your/pt_model.pt --stack_depth 5 --export_path intended/path/to/model/export.onnx --test_input path/to/test/input.npy
```

**Some Notes:**

* Because of CPU-sharing constraints, running an ONNX model requires a dedicated node on HPC or a non-distributed system (for example, a personal laptop or other device). A minimal session sketch is shown after these notes.
* Test inputs are optional, but they help verify that the exported model runs, especially when exporting from the device you intend to use for inference.
* Models must be located in a Lightning training logs directory with a valid `config.yaml` in order to be initialized. This can be "hacked" by placing the model in a directory called `checkpoints` beneath a valid config's directory.
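
Below is a minimal, hedged sketch of running an exported model with ONNX Runtime on CPU. The file names (`export.onnx`, `input.npy`) are placeholders, the single-threaded session options mirror those used in `onnx_export_script.py`, and the `input`/`output` tensor names are the ones assigned at export time.

```python
import numpy as np
import onnxruntime as ort

# Placeholder paths: substitute your exported model and a test stack.
model_path = "export.onnx"
data = np.load("input.npy").astype(np.float32)  # e.g. (1, 1, 5, 512, 512) for a 5-slice 2.5D model

# Single-threaded session, in line with the CPU-sharing note above.
options = ort.SessionOptions()
options.intra_op_num_threads = 1
options.inter_op_num_threads = 1

session = ort.InferenceSession(model_path, sess_options=options)
(prediction,) = session.run(None, {"input": data})  # "input"/"output" names are set at export time
print(prediction.shape)
```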
9 changes: 4 additions & 5 deletions micro_dl/light/data.py
@@ -1,7 +1,7 @@
import logging
import os
import tempfile
from typing import Any, Callable, Literal, Union, Iterable
from typing import Any, Callable, Iterable, Literal, Union

import numpy as np
import torch
@@ -21,7 +21,6 @@
from numpy.typing import NDArray
from torch.utils.data import DataLoader, Dataset


Sample = dict[str, torch.Tensor]


@@ -294,12 +293,12 @@ def _train_transform(self) -> list[Callable]:
prob=0.5,
rotate_range=(np.pi, 0, 0),
shear_range=(0, (0.05), (0.05)),
scale_range=(0, 0.2, 0.2),
scale_range=(0, 0.3, 0.3),
),
RandAdjustContrastd(keys=["source"], prob=0.1, gamma=(0.75, 1.5)),
RandAdjustContrastd(keys=["source"], prob=0.3, gamma=(0.75, 1.5)),
RandGaussianSmoothd(
keys=["source"],
prob=0.2,
prob=0.3,
sigma_x=(0.05, 0.25),
sigma_y=(0.05, 0.25),
sigma_z=((0.05, 0.25)),