-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #169 from Pale-Blue-Dot-97/torchcompile-dev
Added `torch.compile` functionality
- Loading branch information
Showing
49 changed files
with
1,157 additions
and
327 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,5 +45,5 @@ | |
__version__ = "0.23.4" | ||
__author__ = "Harry Baker" | ||
__contact__ = "[email protected]" | ||
__license__ = "GNU LGPLv3" | ||
__license__ = "MIT License" | ||
__copyright__ = "Copyright (C) 2023 Harry Baker" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,7 +34,7 @@ | |
# ===================================================================================================================== | ||
__author__ = "Harry Baker" | ||
__contact__ = "[email protected]" | ||
__license__ = "GNU LGPLv3" | ||
__license__ = "MIT License" | ||
__copyright__ = "Copyright (C) 2023 Harry Baker" | ||
__all__ = [ | ||
"PairedDataset", | ||
|
@@ -384,7 +384,9 @@ def get_subdataset( | |
) | ||
|
||
# Construct the root to the sub-dataset's files. | ||
sub_dataset_root: Path = universal_path(data_directory) / sub_dataset_params["root"] | ||
sub_dataset_root: Path = ( | ||
universal_path(data_directory) / sub_dataset_params["root"] | ||
) | ||
sub_dataset_root = sub_dataset_root.absolute() | ||
|
||
return _sub_dataset, str(sub_dataset_root) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,7 +31,7 @@ | |
|
||
__author__ = "Harry Baker" | ||
__contact__ = "[email protected]" | ||
__license__ = "GNU LGPLv3" | ||
__license__ = "MIT License" | ||
__copyright__ = "Copyright (C) 2023 Harry Baker" | ||
__all__ = [ | ||
"MinervaLogger", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,7 +29,7 @@ | |
# ===================================================================================================================== | ||
__author__ = "Harry Baker" | ||
__contact__ = "[email protected]" | ||
__license__ = "GNU LGPLv3" | ||
__license__ = "MIT License" | ||
__copyright__ = "Copyright (C) 2023 Harry Baker" | ||
__all__ = [ | ||
"MinervaMetrics", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,7 +29,7 @@ | |
# ===================================================================================================================== | ||
__author__ = "Harry Baker" | ||
__contact__ = "[email protected]" | ||
__license__ = "GNU LGPLv3" | ||
__license__ = "MIT License" | ||
__copyright__ = "Copyright (C) 2023 Harry Baker" | ||
__all__ = [ | ||
"sup_tg", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,7 +30,7 @@ | |
# ===================================================================================================================== | ||
__author__ = "Harry Baker" | ||
__contact__ = "[email protected]" | ||
__license__ = "GNU LGPLv3" | ||
__license__ = "MIT License" | ||
__copyright__ = "Copyright (C) 2023 Harry Baker" | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,7 +30,7 @@ | |
# ===================================================================================================================== | ||
__author__ = "Harry Baker" | ||
__contact__ = "[email protected]" | ||
__license__ = "GNU LGPLv3" | ||
__license__ = "MIT License" | ||
__copyright__ = "Copyright (C) 2023 Harry Baker" | ||
|
||
# ===================================================================================================================== | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,7 +30,7 @@ | |
# ===================================================================================================================== | ||
__author__ = "Harry Baker" | ||
__contact__ = "[email protected]" | ||
__license__ = "GNU LGPLv3" | ||
__license__ = "MIT License" | ||
__copyright__ = "Copyright (C) 2023 Harry Baker" | ||
|
||
__all__ = [ | ||
|
@@ -449,7 +449,7 @@ def get_output_shape( | |
assert isinstance(_image_dim, Iterable) | ||
random_input = torch.rand([4, *_image_dim]) | ||
|
||
output: Tensor = model(random_input) | ||
output: Tensor = model(random_input.to(next(model.parameters()).device)) | ||
|
||
if len(output[0].data.shape) == 1: | ||
return output[0].data.shape[0] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,7 +30,7 @@ | |
# ===================================================================================================================== | ||
__author__ = "Harry Baker" | ||
__contact__ = "[email protected]" | ||
__license__ = "GNU LGPLv3" | ||
__license__ = "MIT License" | ||
__copyright__ = "Copyright (C) 2023 Harry Baker" | ||
|
||
__all__ = [ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,7 +30,7 @@ | |
# ===================================================================================================================== | ||
__author__ = "Harry Baker" | ||
__contact__ = "[email protected]" | ||
__license__ = "GNU LGPLv3" | ||
__license__ = "MIT License" | ||
__copyright__ = "Copyright (C) 2023 Harry Baker" | ||
__all__ = [ | ||
"ResNet", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,7 +31,7 @@ | |
# ===================================================================================================================== | ||
__author__ = "Harry Baker" | ||
__contact__ = "[email protected]" | ||
__license__ = "GNU LGPLv3" | ||
__license__ = "MIT License" | ||
__copyright__ = "Copyright (C) 2023 Harry Baker" | ||
__all__ = [ | ||
"MinervaSiamese", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,7 +30,7 @@ | |
# ===================================================================================================================== | ||
__author__ = "Harry Baker" | ||
__contact__ = "[email protected]" | ||
__license__ = "GNU LGPLv3" | ||
__license__ = "MIT License" | ||
__copyright__ = "Copyright (C) 2023 Harry Baker" | ||
|
||
__all__ = [ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,7 +29,7 @@ | |
# ===================================================================================================================== | ||
__author__ = "Harry Baker" | ||
__contact__ = "[email protected]" | ||
__license__ = "GNU LGPLv3" | ||
__license__ = "MIT License" | ||
__copyright__ = "Copyright (C) 2023 Harry Baker" | ||
__all__ = [ | ||
"RandomPairGeoSampler", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,14 +31,15 @@ | |
|
||
__author__ = "Harry Baker" | ||
__contact__ = "[email protected]" | ||
__license__ = "GNU LGPLv3" | ||
__license__ = "MIT License" | ||
__copyright__ = "Copyright (C) 2023 Harry Baker" | ||
__all__ = ["Trainer"] | ||
|
||
# ===================================================================================================================== | ||
# IMPORTS | ||
# ===================================================================================================================== | ||
import os | ||
import warnings | ||
from contextlib import nullcontext | ||
from pathlib import Path | ||
from typing import ( | ||
|
@@ -61,6 +62,7 @@ | |
from alive_progress import alive_bar, alive_it | ||
from inputimeout import TimeoutOccurred, inputimeout | ||
from nptyping import Int, NDArray | ||
from packaging.version import Version | ||
from torch import Tensor | ||
from torch.nn.modules import Module | ||
from torch.nn.parallel import DistributedDataParallel as DDP | ||
|
@@ -414,6 +416,16 @@ def __init__( | |
) | ||
self.model = MinervaDataParallel(self.model, DDP, device_ids=[gpu]) | ||
|
||
# Wraps the model in `torch.compile` to speed up computation time. | ||
# Python 3.11+ is not yet supported though, hence the exception clause. | ||
if Version(torch.__version__) > Version("2.0.0"): # pragma: no cover | ||
try: | ||
_compiled_model = torch.compile(self.model) | ||
assert isinstance(_compiled_model, (MinervaModel, MinervaDataParallel)) | ||
self.model = _compiled_model | ||
except RuntimeError as err: | ||
warnings.warn(str(err)) | ||
|
||
def init_wandb_metrics(self) -> None: | ||
"""Setups up separate step counters for :mod:`wandb` logging of train, val, etc.""" | ||
if isinstance(self.writer, Run): | ||
|
@@ -1063,7 +1075,9 @@ def weighted_knn_validation( | |
for batch in test_bar: | ||
test_data: Tensor = batch["image"].to(self.device, non_blocking=True) | ||
test_target: Tensor = torch.mode( | ||
torch.flatten(batch["mask"], start_dim=1) | ||
torch.flatten( | ||
batch["mask"].to(self.device, non_blocking=True), start_dim=1 | ||
) | ||
).values | ||
|
||
# Get features from passing the input data through the model. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,7 +29,7 @@ | |
# ===================================================================================================================== | ||
__author__ = "Harry Baker" | ||
__contact__ = "[email protected]" | ||
__license__ = "GNU LGPLv3" | ||
__license__ = "MIT License" | ||
__copyright__ = "Copyright (C) 2023 Harry Baker" | ||
__all__ = [ | ||
"ClassTransform", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,7 +37,7 @@ | |
# ===================================================================================================================== | ||
__author__ = "Harry Baker" | ||
__contact__ = "[email protected]" | ||
__license__ = "GNU LGPLv3" | ||
__license__ = "MIT License" | ||
__copyright__ = "Copyright (C) 2023 Harry Baker" | ||
__all__ = [ | ||
"universal_path", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,7 +34,7 @@ | |
# ===================================================================================================================== | ||
__author__ = "Harry Baker" | ||
__contact__ = "[email protected]" | ||
__license__ = "GNU LGPLv3" | ||
__license__ = "MIT License" | ||
__copyright__ = "Copyright (C) 2023 Harry Baker" | ||
__all__ = [ | ||
"DEFAULT_CONF_DIR_PATH", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,7 +35,7 @@ | |
# ===================================================================================================================== | ||
__author__ = "Harry Baker" | ||
__contact__ = "[email protected]" | ||
__license__ = "GNU LGPLv3" | ||
__license__ = "MIT License" | ||
__copyright__ = "Copyright (C) 2023 Harry Baker" | ||
__all__ = [ | ||
"GENERIC_PARSER", | ||
|
@@ -51,8 +51,8 @@ | |
# ===================================================================================================================== | ||
import argparse | ||
import os | ||
import signal | ||
import shlex | ||
import signal | ||
import subprocess | ||
from argparse import Namespace | ||
from typing import Any, Callable, Optional, Union | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -47,7 +47,7 @@ | |
# ===================================================================================================================== | ||
__author__ = "Harry Baker" | ||
__contact__ = "[email protected]" | ||
__license__ = "GNU LGPLv3" | ||
__license__ = "MIT License" | ||
__copyright__ = "Copyright (C) 2023 Harry Baker" | ||
__all__ = [ | ||
"IMAGERY_CONFIG_PATH", | ||
|
@@ -121,12 +121,12 @@ | |
import random | ||
import re as regex | ||
import shlex | ||
from subprocess import Popen | ||
import sys | ||
import webbrowser | ||
from collections import Counter, OrderedDict | ||
from datetime import datetime | ||
from pathlib import Path | ||
from subprocess import Popen | ||
from types import ModuleType | ||
from typing import Any, Callable | ||
from typing import Counter as CounterType | ||
|
@@ -1669,15 +1669,19 @@ def run_tensorboard( | |
os.chdir(_path) | ||
|
||
# Activates the correct Conda environment. | ||
Popen(shlex.split(f"conda activate {env_name}"), shell=True).wait() # nosec B607 | ||
Popen( # nosec B607, B602 | ||
shlex.split(f"conda activate {env_name}"), shell=True | ||
).wait() | ||
|
||
if _testing: | ||
os.chdir(cwd) | ||
return 0 | ||
|
||
else: # pragma: no cover | ||
# Runs TensorBoard log. | ||
Popen(shlex.split(f"tensorboard --logdir {exp_name}"), shell=True) # nosec B607 | ||
Popen( # nosec B607, B602 | ||
shlex.split(f"tensorboard --logdir {exp_name}"), shell=True | ||
) | ||
|
||
# Opens the TensorBoard log in a locally hosted webpage of the default system browser. | ||
webbrowser.open(f"localhost:{host_num}") | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -42,7 +42,7 @@ | |
# ===================================================================================================================== | ||
__author__ = "Harry Baker" | ||
__contact__ = "[email protected]" | ||
__license__ = "GNU LGPLv3" | ||
__license__ = "MIT License" | ||
__copyright__ = "Copyright (C) 2023 Harry Baker" | ||
__all__ = [ | ||
"DATA_CONFIG", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,7 +26,7 @@ | |
# ===================================================================================================================== | ||
__author__ = "Harry Baker" | ||
__contact__ = "[email protected]" | ||
__license__ = "GNU GPLv3" | ||
__license__ = "MIT License" | ||
__copyright__ = "Copyright (C) 2023 Harry Baker" | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,7 +30,7 @@ | |
# ===================================================================================================================== | ||
__author__ = "Harry Baker" | ||
__contact__ = "[email protected]" | ||
__license__ = "GNU GPLv3" | ||
__license__ = "MIT License" | ||
__copyright__ = "Copyright (C) 2023 Harry Baker" | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,7 +33,7 @@ | |
# ===================================================================================================================== | ||
__author__ = "Harry Baker" | ||
__contact__ = "[email protected]" | ||
__license__ = "GNU GPLv3" | ||
__license__ = "MIT License" | ||
__copyright__ = "Copyright (C) 2023 Harry Baker" | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,7 +25,7 @@ | |
# ===================================================================================================================== | ||
__author__ = "Harry Baker" | ||
__contact__ = "[email protected]" | ||
__license__ = "GNU GPLv3" | ||
__license__ = "MIT License" | ||
__copyright__ = "Copyright (C) 2023 Harry Baker" | ||
|
||
|
||
|
@@ -54,7 +54,7 @@ def main(config_path: str): | |
) | ||
|
||
try: | ||
exit_code = subprocess.Popen( # nosec B607 | ||
exit_code = subprocess.Popen( # nosec B607, B602 | ||
shlex.split(f"python MinervaExp.py -c {config[key]}"), | ||
shell=True, | ||
).wait() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,7 +25,7 @@ | |
# ===================================================================================================================== | ||
__author__ = "Harry Baker" | ||
__contact__ = "[email protected]" | ||
__license__ = "GNU GPLv3" | ||
__license__ = "MIT License" | ||
__copyright__ = "Copyright (C) 2023 Harry Baker" | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,7 +29,7 @@ | |
# ===================================================================================================================== | ||
__author__ = "Harry Baker" | ||
__contact__ = "[email protected]" | ||
__license__ = "GNU GPLv3" | ||
__license__ = "MIT License" | ||
__copyright__ = "Copyright (C) 2023 Harry Baker" | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,7 +25,7 @@ | |
# ===================================================================================================================== | ||
__author__ = "Harry Baker" | ||
__contact__ = "[email protected]" | ||
__license__ = "GNU GPLv3" | ||
__license__ = "MIT License" | ||
__copyright__ = "Copyright (C) 2023 Harry Baker" | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# -*- coding: utf-8 -*- | ||
# Copyright (C) 2023 Harry Baker | ||
# | ||
# This program is free software: you can redistribute it and/or modify | ||
# it under the terms of the GNU Lesser General Public License as published by | ||
# the Free Software Foundation, either version 3 of the License, or | ||
# (at your option) any later version. | ||
# | ||
# This program is distributed in the hope that it will be useful, | ||
# but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
# GNU Lesser General Public License for more details. | ||
# | ||
# You should have received a copy of the GNU Lesser General Public License | ||
# along with this program in LICENSE.txt. If not, | ||
# see <https://www.gnu.org/licenses/>. | ||
# | ||
# @org: University of Southampton | ||
# Created under a project funded by the Ordnance Survey Ltd. | ||
r""":mod:`pytest` suite for :mod:`minerva` CI/CD. | ||
""" | ||
|
||
__author__ = "Harry Baker" | ||
__contact__ = "[email protected]" | ||
__license__ = "MIT License" | ||
__copyright__ = "Copyright (C) 2023 Harry Baker" |
Oops, something went wrong.