Skip to content

Commit

Permalink
Merge pull request #169 from Pale-Blue-Dot-97/torchcompile-dev
Browse files Browse the repository at this point in the history
Added `torch.compile` functionality
  • Loading branch information
Pale-Blue-Dot-97 authored May 11, 2023
2 parents a368086 + 4158003 commit 4e07006
Show file tree
Hide file tree
Showing 49 changed files with 1,157 additions and 327 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ Don't forget to give the project a star! Thanks again!

## License 🔏

Minerva is distributed under a [GNU LGPLv3 License](https://choosealicense.com/licenses/lgpl-3.0/).
Minerva is distributed under a [MIT License](https://choosealicense.com/licenses/mit/).

<p align="right">(<a href="#top">back to top</a>)</p>

Expand Down
2 changes: 1 addition & 1 deletion minerva/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,5 @@
__version__ = "0.23.4"
__author__ = "Harry Baker"
__contact__ = "[email protected]"
__license__ = "GNU LGPLv3"
__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
6 changes: 4 additions & 2 deletions minerva/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "[email protected]"
__license__ = "GNU LGPLv3"
__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
__all__ = [
"PairedDataset",
Expand Down Expand Up @@ -384,7 +384,9 @@ def get_subdataset(
)

# Construct the root to the sub-dataset's files.
sub_dataset_root: Path = universal_path(data_directory) / sub_dataset_params["root"]
sub_dataset_root: Path = (
universal_path(data_directory) / sub_dataset_params["root"]
)
sub_dataset_root = sub_dataset_root.absolute()

return _sub_dataset, str(sub_dataset_root)
Expand Down
2 changes: 1 addition & 1 deletion minerva/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

__author__ = "Harry Baker"
__contact__ = "[email protected]"
__license__ = "GNU LGPLv3"
__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
__all__ = [
"MinervaLogger",
Expand Down
2 changes: 1 addition & 1 deletion minerva/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "[email protected]"
__license__ = "GNU LGPLv3"
__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
__all__ = [
"MinervaMetrics",
Expand Down
2 changes: 1 addition & 1 deletion minerva/modelio.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "[email protected]"
__license__ = "GNU LGPLv3"
__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
__all__ = [
"sup_tg",
Expand Down
2 changes: 1 addition & 1 deletion minerva/models/__depreciated.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "[email protected]"
__license__ = "GNU LGPLv3"
__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"


Expand Down
2 changes: 1 addition & 1 deletion minerva/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "[email protected]"
__license__ = "GNU LGPLv3"
__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"

# =====================================================================================================================
Expand Down
4 changes: 2 additions & 2 deletions minerva/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "[email protected]"
__license__ = "GNU LGPLv3"
__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"

__all__ = [
Expand Down Expand Up @@ -449,7 +449,7 @@ def get_output_shape(
assert isinstance(_image_dim, Iterable)
random_input = torch.rand([4, *_image_dim])

output: Tensor = model(random_input)
output: Tensor = model(random_input.to(next(model.parameters()).device))

if len(output[0].data.shape) == 1:
return output[0].data.shape[0]
Expand Down
2 changes: 1 addition & 1 deletion minerva/models/fcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "[email protected]"
__license__ = "GNU LGPLv3"
__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"

__all__ = [
Expand Down
2 changes: 1 addition & 1 deletion minerva/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "[email protected]"
__license__ = "GNU LGPLv3"
__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
__all__ = [
"ResNet",
Expand Down
2 changes: 1 addition & 1 deletion minerva/models/siamese.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "[email protected]"
__license__ = "GNU LGPLv3"
__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
__all__ = [
"MinervaSiamese",
Expand Down
2 changes: 1 addition & 1 deletion minerva/models/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "[email protected]"
__license__ = "GNU LGPLv3"
__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"

__all__ = [
Expand Down
2 changes: 1 addition & 1 deletion minerva/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "[email protected]"
__license__ = "GNU LGPLv3"
__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
__all__ = [
"RandomPairGeoSampler",
Expand Down
18 changes: 16 additions & 2 deletions minerva/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,15 @@

__author__ = "Harry Baker"
__contact__ = "[email protected]"
__license__ = "GNU LGPLv3"
__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
__all__ = ["Trainer"]

# =====================================================================================================================
# IMPORTS
# =====================================================================================================================
import os
import warnings
from contextlib import nullcontext
from pathlib import Path
from typing import (
Expand All @@ -61,6 +62,7 @@
from alive_progress import alive_bar, alive_it
from inputimeout import TimeoutOccurred, inputimeout
from nptyping import Int, NDArray
from packaging.version import Version
from torch import Tensor
from torch.nn.modules import Module
from torch.nn.parallel import DistributedDataParallel as DDP
Expand Down Expand Up @@ -414,6 +416,16 @@ def __init__(
)
self.model = MinervaDataParallel(self.model, DDP, device_ids=[gpu])

# Wraps the model in `torch.compile` to speed up computation time.
# Python 3.11+ is not yet supported though, hence the exception clause.
if Version(torch.__version__) > Version("2.0.0"): # pragma: no cover
try:
_compiled_model = torch.compile(self.model)
assert isinstance(_compiled_model, (MinervaModel, MinervaDataParallel))
self.model = _compiled_model
except RuntimeError as err:
warnings.warn(str(err))

def init_wandb_metrics(self) -> None:
"""Setups up separate step counters for :mod:`wandb` logging of train, val, etc."""
if isinstance(self.writer, Run):
Expand Down Expand Up @@ -1063,7 +1075,9 @@ def weighted_knn_validation(
for batch in test_bar:
test_data: Tensor = batch["image"].to(self.device, non_blocking=True)
test_target: Tensor = torch.mode(
torch.flatten(batch["mask"], start_dim=1)
torch.flatten(
batch["mask"].to(self.device, non_blocking=True), start_dim=1
)
).values

# Get features from passing the input data through the model.
Expand Down
2 changes: 1 addition & 1 deletion minerva/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "[email protected]"
__license__ = "GNU LGPLv3"
__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
__all__ = [
"ClassTransform",
Expand Down
2 changes: 1 addition & 1 deletion minerva/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "[email protected]"
__license__ = "GNU LGPLv3"
__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
__all__ = [
"universal_path",
Expand Down
2 changes: 1 addition & 1 deletion minerva/utils/config_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "[email protected]"
__license__ = "GNU LGPLv3"
__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
__all__ = [
"DEFAULT_CONF_DIR_PATH",
Expand Down
4 changes: 2 additions & 2 deletions minerva/utils/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "[email protected]"
__license__ = "GNU LGPLv3"
__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
__all__ = [
"GENERIC_PARSER",
Expand All @@ -51,8 +51,8 @@
# =====================================================================================================================
import argparse
import os
import signal
import shlex
import signal
import subprocess
from argparse import Namespace
from typing import Any, Callable, Optional, Union
Expand Down
12 changes: 8 additions & 4 deletions minerva/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "[email protected]"
__license__ = "GNU LGPLv3"
__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
__all__ = [
"IMAGERY_CONFIG_PATH",
Expand Down Expand Up @@ -121,12 +121,12 @@
import random
import re as regex
import shlex
from subprocess import Popen
import sys
import webbrowser
from collections import Counter, OrderedDict
from datetime import datetime
from pathlib import Path
from subprocess import Popen
from types import ModuleType
from typing import Any, Callable
from typing import Counter as CounterType
Expand Down Expand Up @@ -1669,15 +1669,19 @@ def run_tensorboard(
os.chdir(_path)

# Activates the correct Conda environment.
Popen(shlex.split(f"conda activate {env_name}"), shell=True).wait() # nosec B607
Popen( # nosec B607, B602
shlex.split(f"conda activate {env_name}"), shell=True
).wait()

if _testing:
os.chdir(cwd)
return 0

else: # pragma: no cover
# Runs TensorBoard log.
Popen(shlex.split(f"tensorboard --logdir {exp_name}"), shell=True) # nosec B607
Popen( # nosec B607, B602
shlex.split(f"tensorboard --logdir {exp_name}"), shell=True
)

# Opens the TensorBoard log in a locally hosted webpage of the default system browser.
webbrowser.open(f"localhost:{host_num}")
Expand Down
2 changes: 1 addition & 1 deletion minerva/utils/visutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "[email protected]"
__license__ = "GNU LGPLv3"
__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
__all__ = [
"DATA_CONFIG",
Expand Down
2 changes: 1 addition & 1 deletion scripts/ManifestMake.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "[email protected]"
__license__ = "GNU GPLv3"
__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"


Expand Down
2 changes: 1 addition & 1 deletion scripts/MinervaClusterVis.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "[email protected]"
__license__ = "GNU GPLv3"
__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"


Expand Down
2 changes: 1 addition & 1 deletion scripts/MinervaExp.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "[email protected]"
__license__ = "GNU GPLv3"
__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"


Expand Down
4 changes: 2 additions & 2 deletions scripts/MinervaPipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "[email protected]"
__license__ = "GNU GPLv3"
__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"


Expand Down Expand Up @@ -54,7 +54,7 @@ def main(config_path: str):
)

try:
exit_code = subprocess.Popen( # nosec B607
exit_code = subprocess.Popen( # nosec B607, B602
shlex.split(f"python MinervaExp.py -c {config[key]}"),
shell=True,
).wait()
Expand Down
2 changes: 1 addition & 1 deletion scripts/RunTensorBoard.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "[email protected]"
__license__ = "GNU GPLv3"
__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"


Expand Down
2 changes: 1 addition & 1 deletion scripts/TorchWeightDownloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "[email protected]"
__license__ = "GNU GPLv3"
__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"


Expand Down
2 changes: 1 addition & 1 deletion scripts/Torch_to_ONNX.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "[email protected]"
__license__ = "GNU GPLv3"
__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"


Expand Down
26 changes: 26 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2023 Harry Baker
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program in LICENSE.txt. If not,
# see <https://www.gnu.org/licenses/>.
#
# @org: University of Southampton
# Created under a project funded by the Ordnance Survey Ltd.
r""":mod:`pytest` suite for :mod:`minerva` CI/CD.
"""

__author__ = "Harry Baker"
__contact__ = "[email protected]"
__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
Loading

0 comments on commit 4e07006

Please sign in to comment.