Skip to content
This repository has been archived by the owner on Jul 29, 2023. It is now read-only.

Commit

Permalink
Inference in lightning (#236)
Browse files Browse the repository at this point in the history
* rename lightning cli main script

* add docstring

* predict stage in data module

* write predict result with callback

* remove old inference module

* rename tests

Signed-off-by: Ziwen Liu <[email protected]>

---------

Signed-off-by: Ziwen Liu <[email protected]>
  • Loading branch information
ziw-liu committed May 29, 2023
1 parent 5f31f97 commit b0fe3bd
Show file tree
Hide file tree
Showing 18 changed files with 303 additions and 890 deletions.
9 changes: 8 additions & 1 deletion micro_dl/cli/train.py → micro_dl/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@


class VSLightningCLI(LightningCLI):
"""Extending lightning CLI arguments and defualts."""

def add_arguments_to_parser(self, parser):
# https://pytorch-lightning.readthedocs.io/en/1.6.0/api/pytorch_lightning.utilities.cli.html#pytorch_lightning.utilities.cli.LightningCLI.add_arguments_to_parser
parser.link_arguments("data.batch_size", "model.batch_size")
Expand All @@ -24,7 +26,12 @@ def add_arguments_to_parser(self, parser):
save_dir="",
version=datetime.now().strftime(r"%Y%m%d-%H%M%S"),
log_graph=True,
)
),
"trainer.callbacks": [
{
"class_path": "micro_dl.light.prediction_writer.HCSPredictionWriter",
}
],
}
)

Expand Down
2 changes: 1 addition & 1 deletion micro_dl/cli/curator_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import iohub.ngff as ngff
import argparse

import micro_dl.inference.evaluation_metrics as metrics
import micro_dl.evaluation.evaluation_metrics as metrics
import micro_dl.utils.aux_utils as aux_utils
# from waveorder.focus import focus_from_transverse_band

Expand Down
1 change: 1 addition & 0 deletions micro_dl/cli/fit_example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,5 @@ data:
- 256
augment: true
caching: false
normalize_source: false
ckpt_path: null
2 changes: 1 addition & 1 deletion micro_dl/cli/metrics_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import argparse
import pandas as pd

import micro_dl.inference.evaluation_metrics as metrics
import micro_dl.evaluation.evaluation_metrics as metrics
import micro_dl.utils.aux_utils as aux_utils

# %% read the below details from the config file
Expand Down
69 changes: 69 additions & 0 deletions micro_dl/cli/predict_example.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# lightning.pytorch==2.0.1
predict:
seed_everything: true
trainer:
accelerator: auto
strategy: auto
devices: auto
num_nodes: 1
precision: 32-true
callbacks:
- class_path: micro_dl.light.prediction_writer.HCSPredictionWriter
init_args:
output_store: null
write_input: false
write_interval: batch
fast_dev_run: false
max_epochs: null
min_epochs: null
max_steps: -1
min_steps: null
max_time: null
limit_train_batches: null
limit_val_batches: null
limit_test_batches: null
limit_predict_batches: null
overfit_batches: 0.0
val_check_interval: null
check_val_every_n_epoch: 1
num_sanity_val_steps: null
log_every_n_steps: null
enable_checkpointing: null
enable_progress_bar: null
enable_model_summary: null
accumulate_grad_batches: 1
gradient_clip_val: null
gradient_clip_algorithm: null
deterministic: null
benchmark: null
inference_mode: true
use_distributed_sampler: true
profiler: null
detect_anomaly: false
barebones: false
plugins: null
sync_batchnorm: false
reload_dataloaders_every_n_epochs: 0
default_root_dir: null
model:
model_config: {}
loss_function: null
lr: 0.001
schedule: Constant
log_num_samples: 8
data:
data_path: null
source_channel: null
target_channel: null
z_window_size: null
split_ratio: null
batch_size: 16
num_workers: 8
yx_patch_size:
- 256
- 256
augment: true
caching: false
normalize_source: false
return_predictions: null
ckpt_path: null
15 changes: 15 additions & 0 deletions micro_dl/cli/readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# CLI

## Exporting models to ONNX

If you wish to run inference via usage of the ONNXruntime, models can be exported to onnx using the `micro_dl/cli/onnx_export_script.py`. See below for an example usage of this script with 5-input-stack model:

```bash
python micro_dl/cli/onnx_export_script.py --model_path path/to/your/pt_model.pt --stack_depth 5 --export_path intended/path/to/model/export.onnx --test_input path/to/test/input.npy
```

**Some Notes:**

* For cpu sharing reasons, running an onnx model requires a dedicated node on hpc OR a non-distributed system (for example a personal laptop or other device).
* Test inputs are optional, but help verify that the exported model can be run if exporting from intended usage device.
* Models must be located in a lighting training logs directory with a valid `config.yaml` in order to be initialized. This can be "hacked" by locating the config in a directory called `checkpoints` beneath a valid config's directory.
201 changes: 0 additions & 201 deletions micro_dl/cli/torch_inference_script.py

This file was deleted.

2 changes: 1 addition & 1 deletion micro_dl/evaluation/evaluation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import micro_dl.inference.evaluation_metrics as inference_metrics
import micro_dl.evaluation.evaluation_metrics as inference_metrics
from torch.utils.tensorboard import SummaryWriter


Expand Down
1 change: 0 additions & 1 deletion micro_dl/inference/__init__.py

This file was deleted.

Loading

0 comments on commit b0fe3bd

Please sign in to comment.