This repository has been archived by the owner on Jul 29, 2023. It is now read-only.

Revert "Fix model architecture for deployment to ONNX (#234)"
This reverts commit 00375b0.
ziw-liu authored May 30, 2023
1 parent 00375b0 commit 00c484a
Showing 13 changed files with 310 additions and 454 deletions.
178 changes: 0 additions & 178 deletions micro_dl/cli/onnx_export_script.py

This file was deleted.

4 changes: 3 additions & 1 deletion micro_dl/cli/torch_inference_script.py
@@ -8,6 +8,8 @@

import micro_dl.inference.inference as torch_inference_utils
import micro_dl.utils.aux_utils as aux_utils
import micro_dl.utils.gpu_utils as gpu_utils


def check_gpu_availability(gpu_id):
"""
@@ -178,7 +180,7 @@ def main(config, gpu, gpu_mem_frac):

if gpu is not None:
# Get GPU ID and memory fraction
gpu_id, gpu_mem_frac = select_gpu(
gpu_id, gpu_mem_frac = gpu_utils.select_gpu(
gpu,
gpu_mem_frac,
)
28 changes: 15 additions & 13 deletions micro_dl/data_organization.md
@@ -1,6 +1,10 @@
# Data Organization for Virtual Staining

Here we document our conventions for storing data, metadata, configs, and models.
> This advisory only applies to data management on Biohub's compute infrastructure.
> It is not normative for external users.
Here we document the conventions for storing data, metadata, configs,
and models during the development of the virtual staining pipeline.

## Data flow in the pipeline

@@ -86,21 +90,19 @@ virtual_staining:
config.yaml
yyyymmdd-hhmmss:
...
# Inference and/or Evaluation of selected models.
test:
# config for prediction with test dataset.
test_<suffix>.yml # config used for inference, optionally copies ground truth and input for evaluation. This config will follow the lightning CLI/config format.

# inference output on test dataset, may include copies of input and ground truth to facilitate visualization of model performance.
test_<suffix>.zarr # Not all test datasets need to have human curated ground truth.
# evaluation of select models
evaluation:
# configs for evaluation: checkpoint path, test data path, ground truth path, and choice of metrics.
evaluation_01.yaml
evaluation_02.yaml
...

# config for evaluation: checkpoint path, test data path that has ground truth included, and choice of metrics.
evaluation_<suffix>.yaml
# inference output on test dataset, may include copies of input and ground truth to facilitate visualization of model performance.
prediction_01.zarr
prediction_02.zarr
...

# evaluation metrics
evaluation_metrics_<suffix>.csv
metrics_01.csv
metrics_02.csv
...
# (optional) tensorboard logs generated to visualize distribution of metrics or specific samples of input, prediction, ground truth.
evaluation_logs:
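To make the `evaluation_<suffix>.yaml` comments in the tree above concrete, here is a hypothetical sketch of such a config, shown as the equivalent Python dict; the field names are illustrative guesses from the tree comments, not a documented schema:

```python
# Hypothetical evaluation config, shown as the Python dict a YAML loader
# would produce. All field names and values are illustrative guesses based
# on the tree comments above (checkpoint, test data with ground truth,
# choice of metrics).
evaluation_config = {
    "checkpoint_path": "yyyymmdd-hhmmss/checkpoints/last.ckpt",
    "test_data_path": "test/test_dataset.zarr",  # includes curated ground truth
    "metrics": ["SSIM", "PearsonCorrelation"],   # assumed metric names
}
```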
8 changes: 1 addition & 7 deletions micro_dl/inference/inference.py
@@ -64,18 +64,13 @@ def __init__(self, config, device, single_prediction=False) -> None:
self.input_channels = self.config["input_channels"]
self.time_indices = self.config["time_indices"]

scale_factor = 1
if "model_mag" in self.config and "source_mag" in self.config:
scale_factor = self.config["model_mag"] / self.config["source_mag"]

self.dataset = inference_dataset.TorchInferenceDataset(
zarr_dir=self.config["zarr_dir"],
batch_pred_num=self.config["batch_size"],
normalize_inputs=self.config["normalize_inputs"],
norm_type=self.config["norm_type"],
norm_scheme=self.config["norm_scheme"],
sample_depth=self.network_z_depth,
scale_factor=scale_factor,
device=self.device,
)

@@ -165,7 +160,7 @@ def predict_image(
img_tensor = aux_utils.ToTensor(device=self.device)(input_image)

img_tensor, pads = _pad_input(img_tensor, num_blocks=model.num_blocks)
pred = model(img_tensor)
pred = model(img_tensor, validate_input=False)
return TF.crop(
pred.detach().cpu(), *(pads[1], pads[0]) + input_image.shape[-2:]
).numpy()
@@ -221,7 +216,6 @@ def run_inference(self):
for batch, z0, size, _ in dataloader:
batch_pred = self.predict_image(batch[0]) # redundant batch dim
batch_pred = np.squeeze(np.swapaxes(batch_pred, 0, 2), axis=0)
batch_pred = self.dataset._scale(batch_pred, unscale=True)
output_array[time_idx, :, z0 : z0 + size, ...] = batch_pred

# write config to save dir
26 changes: 5 additions & 21 deletions micro_dl/inference/readme.md
@@ -1,13 +1,11 @@
# Inference
## Inference

The main command for inference is:

```sh
python micro_dl/cli/torch_inference_script.py --config <config path (.yml)> --gpu <gpu id (default 0)> --gpu_mem_frac <0-1 (default 1)>
```

where the parameters are defined as follows:

* **config** (yaml file): Configuration file, see below.
* **gpu** (int): ID number of the GPU you'd like to run on. If you don't
specify a GPU, the GPU with the largest amount of available memory will be selected for you.
@@ -16,7 +14,7 @@ If there's not enough memory available on the GPU, an AssertionError will be raised.
If memory fraction is unspecified, all memory currently available on the GPU will automatically
be allocated for you.
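
The CLI resolves these flags through `micro_dl.utils.gpu_utils.select_gpu`, as the `torch_inference_script.py` diff above shows. A minimal sketch of a direct call, with illustrative values:

```python
import micro_dl.utils.gpu_utils as gpu_utils

# Hypothetical call mirroring the CLI diff above: validate a requested
# GPU id and memory fraction before running inference. The values 0 and
# 0.5 are placeholders.
gpu_id, gpu_mem_frac = gpu_utils.select_gpu(0, 0.5)
```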

## Config
# Config

> **zarr_dir**: `absolute path` (absolute path to HCS-compatible zarr store containing data)
>
@@ -44,15 +42,13 @@ be allocated for you.
>
> ***custom_save_preds_dir:*** `absolute path` (Path to a custom save directory. Generally try to avoid using this, since it separates model predictions from the models)
## Config Example

# Config Example
Some working config examples can be found at:

```text
/hpc/projects/CompMicro/projects/virtualstaining/torch_microDL/config_files/2022_HEK_nuc_mem_Soorya/TestData_HEK_2022_04_16/
```
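
For orientation, the config can also be pictured as the Python dict it deserializes into. The keys below are taken from the `self.config` lookups in `micro_dl/inference/inference.py` (shown earlier in this diff); every value is a placeholder assumption, not a verified working example:

```python
# Hypothetical inference config, mirroring the keys read by the predictor
# constructor earlier in this diff. All values are placeholders.
config = {
    "zarr_dir": "/path/to/input_data.zarr",  # HCS-compatible zarr store
    "batch_size": 8,
    "normalize_inputs": True,
    "norm_type": "median_and_iqr",  # assumed option name
    "norm_scheme": "dataset",       # assumed option name
    "input_channels": ["Phase3D"],  # assumed channel name
    "time_indices": [0],            # assumed: timepoints to predict
}
```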

## Single Sample Prediction
# Single Sample Prediction
Sometimes inference needs to be run on individual samples. This can be done using the `predict_image` method of the `TorchPredictor` object.

Initializing the `TorchPredictor` object for this task requires a config dictionary specifying the model to use for prediction:
@@ -87,16 +83,4 @@ sample_prediction = torch_predictor.predict_image(sample_input)

```
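
The initialization example above is collapsed in this view. A minimal sketch of the pattern, assuming `TorchPredictor` lives in `micro_dl.inference.inference` with the `(config, device, single_prediction)` constructor shown earlier in this diff, and omitting any model-loading steps from the collapsed example:

```python
# Minimal single-sample prediction sketch. Import path, device string, and
# input layout are assumptions; any model-loading call from the collapsed
# example is omitted.
import numpy as np
from micro_dl.inference.inference import TorchPredictor

config = {"zarr_dir": "/path/to/input_data.zarr"}  # plus the other keys sketched in the Config section
torch_predictor = TorchPredictor(
    config=config,
    device="cuda:0",         # assumed CUDA device string
    single_prediction=True,
)
sample_input = np.load("path/to/test/input.npy")  # assumed (C, Z, Y, X) stack
sample_prediction = torch_predictor.predict_image(sample_input)
```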

## Exporting models to ONNX

If you wish to run inference via the ONNX runtime, models can be exported to ONNX using `micro_dl/cli/onnx_export_script.py`. See below for an example usage of this script with a 5-input-stack model (a sketch of running the exported model follows the notes at the end of this section):

```bash
python micro_dl/cli/onnx_export_script.py --model_path path/to/your/pt_model.pt --stack_depth 5 --export_path intended/path/to/model/export.onnx --test_input path/to/test/input.npy
```

**Some Notes:**

* For CPU-sharing reasons, running an ONNX model requires a dedicated node on HPC OR a non-distributed system (for example, a personal laptop or other device).
* Test inputs are optional, but they help verify that the exported model runs correctly when exporting from the intended usage device.
* Models must be located in a Lightning training logs directory with a valid `config.yaml` in order to be initialized. This can be "hacked" by locating the config in a directory called `checkpoints` beneath a valid config's directory.
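
Once exported, the model can be run outside this repository with the `onnxruntime` package. A minimal sketch, assuming a single-input model and CPU execution:

```python
# Hypothetical ONNX runtime inference on the exported model; not part of
# micro_dl. Assumes one input tensor, saved as .npy in the layout the
# exported model expects.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession(
    "intended/path/to/model/export.onnx",
    providers=["CPUExecutionProvider"],
)
input_name = session.get_inputs()[0].name
test_input = np.load("path/to/test/input.npy").astype(np.float32)
prediction = session.run(None, {input_name: test_input})[0]
```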
#TODO: evaluation script and config
9 changes: 5 additions & 4 deletions micro_dl/light/data.py
@@ -1,7 +1,7 @@
import logging
import os
import tempfile
from typing import Any, Callable, Iterable, Literal, Union
from typing import Any, Callable, Literal, Union, Iterable

import numpy as np
import torch
@@ -21,6 +21,7 @@
from numpy.typing import NDArray
from torch.utils.data import DataLoader, Dataset


Sample = dict[str, torch.Tensor]


@@ -293,12 +294,12 @@ def _train_transform(self) -> list[Callable]:
prob=0.5,
rotate_range=(np.pi, 0, 0),
shear_range=(0, (0.05), (0.05)),
scale_range=(0, 0.3, 0.3),
scale_range=(0, 0.2, 0.2),
),
RandAdjustContrastd(keys=["source"], prob=0.3, gamma=(0.75, 1.5)),
RandAdjustContrastd(keys=["source"], prob=0.1, gamma=(0.75, 1.5)),
RandGaussianSmoothd(
keys=["source"],
prob=0.3,
prob=0.2,
sigma_x=(0.05, 0.25),
sigma_y=(0.05, 0.25),
sigma_z=((0.05, 0.25)),
3 changes: 1 addition & 2 deletions micro_dl/light/engine.py
@@ -70,7 +70,6 @@ def training_step(self, batch, batch_idx):
prog_bar=True,
logger=True,
batch_size=self.batch_size,
sync_dist=True,
)
if batch_idx < self.log_num_samples:
self.training_step_outputs.append(
@@ -83,7 +82,7 @@ def validation_step(self, batch, batch_idx):
target = batch["target"]
pred = self.forward(source)
loss = self.loss_function(pred, target)
self.log("val_loss", loss, batch_size=self.batch_size, sync_dist=True)
self.log("loss/val", loss, batch_size=self.batch_size)
if batch_idx < self.log_num_samples:
self.validation_step_outputs.append(
self._detach_sample((source, target, pred))