Commit
Merge pull request #1516 from helmholtz-analytics/features/1383-Implement_vmap
Distributed `vmap` functionality for vectorization across the split dimension
Showing 3 changed files with 250 additions and 0 deletions.
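Before the diffs, a minimal usage sketch of the new `heat.vmap` API, assembled from the docstring and tests added in this PR (the function name `square_shift`, the shapes, and the keyword values are illustrative, not part of the commit):

import heat as ht


def square_shift(x, shift=0.0):
    # the mapped function operates on process-local, per-sample torch tensors;
    # keyword arguments are passed through without being vmapped
    return (x - shift) ** 2


# vectorization takes place along the split axis of the DNDarray (here axis 0),
# so there is no in_dims argument as in torch.vmap
x = ht.random.randn(4 * ht.MPI_WORLD.size, 10, split=0)

vfunc = ht.vmap(square_shift, out_dims=0)  # returns a callable acting on DNDarrays
(y,) = vfunc(x, shift=1.5)                 # outputs come back as a tuple of DNDarrays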
File 1 of 3:

@@ -31,3 +31,4 @@
from .types import finfo, iinfo
from . import version
from .version import __version__
from .vmap import *
File 2 of 3:

@@ -0,0 +1,145 @@
import heat as ht
import torch

from .test_suites.basic_test import TestCase


class TestVmap(TestCase):
    if torch.__version__ < "2.0.0":

        def test_vmap(self):
            out_dims = (0, 0)

            def func(x0, x1, k=2, scale=1e-2):
                return torch.topk(torch.linalg.svdvals(x0), k)[0] ** 2, scale * x0 @ x1

            with self.assertRaises(RuntimeError):
                vfunc = ht.vmap(func, out_dims)  # noqa: F841

    else:

        def test_vmap(self):
            # two inputs (both split), two outputs, including keyword arguments that are not vmapped
            # inputs split along different axes, output split along same axis (one of them different to input split)
            x0 = ht.random.randn(5 * ht.MPI_WORLD.size, 10, 10, split=0)
            x1 = ht.random.randn(10, 5 * ht.MPI_WORLD.size, split=1)
            out_dims = 0  # test with out_dims as int (tuple below)

            def func(x0, x1, k=2, scale=1e-2):
                return torch.topk(torch.linalg.svdvals(x0), k)[0] ** 2, scale * x0 @ x1

            vfunc = ht.vmap(func, out_dims)
            y0, y1 = vfunc(x0, x1, k=2, scale=2.2)

            # compare with torch
            x0_torch = x0.resplit(None).larray
            x1_torch = x1.resplit(None).larray
            vfunc_torch = torch.vmap(func, (0, 1), (0, 0))
            y0_torch, y1_torch = vfunc_torch(x0_torch, x1_torch, k=2, scale=2.2)

            self.assertTrue(torch.allclose(y0.resplit(None).larray, y0_torch))
            self.assertTrue(torch.allclose(y1.resplit(None).larray, y1_torch))

            # two inputs (only one of them split), two outputs, including keyword arguments that are not vmapped
            # output split along different axis, one output has different data type than input
            x0 = ht.random.randn(5 * ht.MPI_WORLD.size, 10, 10, split=0)
            x1 = ht.random.randn(10, 5 * ht.MPI_WORLD.size, split=None)
            out_dims = (0, 1)

            def func(x0, x1, k=2, scale=1e-2):
                return torch.topk(torch.linalg.svdvals(x0), k)[0] ** 2, (scale * x0 @ x1).int()

            vfunc = ht.vmap(func, out_dims)
            y0, y1 = vfunc(x0, x1, k=2, scale=2.2)

            # compare with torch
            x0_torch = x0.resplit(None).larray
            x1_torch = x1.resplit(None).larray
            vfunc_torch = torch.vmap(func, (0, None), (0, 1))
            y0_torch, y1_torch = vfunc_torch(x0_torch, x1_torch, k=2, scale=2.2)

            self.assertTrue(torch.allclose(y0.resplit(None).larray, y0_torch))
            self.assertTrue(torch.allclose(y1.resplit(None).larray, y1_torch))

            # catch wrong number of output dimensions
            with self.assertRaises(ValueError):
                vfunc = ht.vmap(func, (0, 1, 2))
                y0, y1 = vfunc(x0, x1, k=2, scale=2.2)

            # one output only
            def func(x0, m=1, scale=2):
                return (x0 - m) ** scale

            vfunc = ht.vmap(func, out_dims=(0,))

            x0 = ht.random.randn(5 * ht.MPI_WORLD.size, 10, 10, split=0)
            y0 = vfunc(x0, m=2, scale=3)[0]

            x0_torch = x0.resplit(None).larray
            vfunc_torch = torch.vmap(func, (0,), (0,))
            y0_torch = vfunc_torch(x0_torch, m=2, scale=3)

            print(y0.resplit(None).larray, y0_torch)

            self.assertTrue(torch.allclose(y0.resplit(None).larray, y0_torch))

        def test_vmap_with_chunks(self):
            # same as before but now with prescribed chunk sizes for the vmap
            x0 = ht.random.randn(5 * ht.MPI_WORLD.size, 10, 10, split=0)
            x1 = ht.random.randn(10, 5 * ht.MPI_WORLD.size, split=1)
            out_dims = (0, 0)

            def func(x0, x1, k=2, scale=1e-2):
                return torch.topk(torch.linalg.svdvals(x0), k)[0] ** 2, scale * x0 @ x1

            vfunc = ht.vmap(func, out_dims, chunk_size=2)
            y0, y1 = vfunc(x0, x1, k=2, scale=-2.2)

            # compare with torch
            x0_torch = x0.resplit(None).larray
            x1_torch = x1.resplit(None).larray
            vfunc_torch = torch.vmap(func, (0, 1), (0, 0))
            y0_torch, y1_torch = vfunc_torch(x0_torch, x1_torch, k=2, scale=-2.2)

            self.assertTrue(torch.allclose(y0.resplit(None).larray, y0_torch))
            self.assertTrue(torch.allclose(y1.resplit(None).larray, y1_torch))

            # two inputs (only one of them split), two outputs, including keyword arguments that are not vmapped
            # output split along different axis
            x0 = ht.random.randn(5 * ht.MPI_WORLD.size, 10, 10, split=0)
            x1 = ht.random.randn(10, 5 * ht.MPI_WORLD.size, split=None)
            out_dims = (0, 1)

            def func(x0, x1, k=2, scale=1e-2):
                return torch.topk(torch.linalg.svdvals(x0), k)[0] ** 2, scale * x0 @ x1

            vfunc = ht.vmap(func, out_dims, chunk_size=1)
            y0, y1 = vfunc(x0, x1, k=5, scale=2.2)

            # compare with torch
            x0_torch = x0.resplit(None).larray
            x1_torch = x1.resplit(None).larray
            vfunc_torch = torch.vmap(func, (0, None), (0, 1))
            y0_torch, y1_torch = vfunc_torch(x0_torch, x1_torch, k=5, scale=2.2)

            self.assertTrue(torch.allclose(y0.resplit(None).larray, y0_torch))
            self.assertTrue(torch.allclose(y1.resplit(None).larray, y1_torch))

        def test_vmap_catch_errors(self):
            # not a callable
            with self.assertRaises(TypeError):
                ht.vmap(1)
            # invalid randomness
            with self.assertRaises(ValueError):
                ht.vmap(lambda x: x, randomness="random")
            # invalid chunk_size
            with self.assertRaises(TypeError):
                ht.vmap(lambda x: x, chunk_size="1")
            with self.assertRaises(ValueError):
                ht.vmap(lambda x: x, chunk_size=0)
            # not all inputs are DNDarrays
            with self.assertRaises(TypeError):
                ht.vmap(lambda x: x, out_dims=0)(ht.ones(10), 2)
            # number of output DNDarrays does not match number of split dimensions
            with self.assertRaises(ValueError):
                ht.vmap(lambda x: x, out_dims=(0, 1))(ht.ones(10))
File 3 of 3:

@@ -0,0 +1,104 @@
""" | ||
This implements a functionality similar to PyTorchs vmap function. | ||
Requires PyTorch 2.0.0 or higher. | ||
""" | ||
|
||
import torch | ||
|
||
from .dndarray import DNDarray | ||
from .factories import array | ||
from .communication import MPI_WORLD | ||
from typing import Union, Tuple, Optional, Callable | ||
|
||
__all__ = ["vmap"] | ||
|
||
|
||
def vmap( | ||
func: Callable[[Tuple[torch.Tensor]], Tuple[torch.Tensor]], | ||
out_dims: Union[Tuple[int], int] = 0, | ||
randomness: str = "error", | ||
*, | ||
chunk_size: int = None, | ||
) -> Callable[[Tuple[DNDarray]], Tuple[DNDarray]]: | ||
""" | ||
This function is used to apply a function to a DNDarray in a vectorized way. | ||
`heat.vmap` return a callable that can be applied to DNDarrays. | ||
Vectorization will automatically take place along the split axis/axes of the DNDarray(s); | ||
therefore, unlike in PyTorch, there is no argument `in_dims`. | ||
What we here refer to as "split axis/dimension" in the Heat terminology is often referred to as "batch axis/dimension" in the PyTorch terminology. | ||
Parameters | ||
---------- | ||
func : callable | ||
The function to apply in a vmapped way to the DNDarray(s). It must take PyTorch tensor(s) as positional arguments. | ||
Additional parameters, not to be vmapped over, can be passed as keyword arguments. The callable returned by | ||
by `heat.vmap` will also accept these keyword arguments. | ||
out_dims : int or tuple of int, optional | ||
The dimensions of the output(s) that are mapped over; identical to the split dimension(s) of the output(s). | ||
Default is 0. | ||
randomness : {'error', 'different', 'same'}, optional | ||
Determines how to handle randomness in the function to be vmapped. This argument is directly passed to the underlying PyTorch vmaps; | ||
see the corresponding PyTorch documentation for more information and the note below. | ||
If 'error' (default), an error is raised if the function to be mapped contains randomness. | ||
chunk_size : int, optional | ||
The size of the chunks to use for the process-local computation. | ||
If None (default), apply a single PyTorch vmap over the process-local chunks of data. If not None, then compute the process-local PyTorch vmap `chunk_size` | ||
many samples at a time. Note that `chunk_size=1` is equivalent to computing the process-local PyTorch vmap's with a for-loop. | ||
If you run into memory issues computing the vmap, please try a non-None chunk_size. | ||
Note | ||
------ | ||
This function is a wrapper around PyTorch's `torch.vmap` function. In essence, a PyTorch vmap is applied to the input function `func` on each MPI process separately. | ||
This process-local PyTorch-vmapped function is then applied to the process-local chunks of the input DNDarray(s). | ||
Please note that the options 'same' and 'different' for `randomness` will result in behaviour different from the one known by PyTorch as (at least currently) | ||
no actions are taken to synchronize randomness across the MPI processes. | ||
""" | ||
# check PyTorch version, return error if not 2.0.0 or higher | ||
if torch.__version__ < "2.0.0": | ||
raise RuntimeError("The function `heat.vmap` requires PyTorch 2.0.0 or higher.") | ||
# rough check of input argument types | ||
if not callable(func): | ||
raise TypeError("The input function `func` must be callable.") | ||
if randomness not in ["error", "different", "same"]: | ||
raise ValueError( | ||
"The input argument `randomness` must be one of the strings 'error', 'different', or 'same'." | ||
) | ||
if chunk_size is not None and not isinstance(chunk_size, int): | ||
raise TypeError("The input argument `chunk_size` must be None or an integer.") | ||
else: | ||
if chunk_size is not None and chunk_size < 1: | ||
raise ValueError("If an integer, the input argument `chunk_size` must be at least 1.") | ||
|
||
def vmapped_func(*args, **kwargs): | ||
for arg in args: | ||
if not isinstance(arg, DNDarray): | ||
raise TypeError( | ||
f"All inputs to the vmapped-version of your function must be DNDarrays, but one is {type(arg)}." | ||
) | ||
in_dims = tuple([arg.split for arg in args]) | ||
|
||
# apply Torch vmap to the input function and the result to the local arrays of the input DNDarray | ||
torch_vmap_func = torch.vmap( | ||
func, in_dims, out_dims, randomness=randomness, chunk_size=chunk_size | ||
) | ||
out_larrays = torch_vmap_func(*[arg.larray for arg in args], **kwargs) | ||
|
||
if isinstance(out_larrays, torch.Tensor): | ||
# circumvent misinterpretation of the following call of len() in case of a single output | ||
out_larrays = [out_larrays] | ||
if isinstance(out_dims, int): | ||
out_split = [out_dims] * len(out_larrays) | ||
else: | ||
out_split = out_dims | ||
if len(out_split) != len(out_larrays): | ||
raise ValueError( | ||
f"The number of output DNDarrays ({len(out_larrays)}) must match the number of their split dimensions provided in `out_dims` ({len(out_split)})." | ||
) | ||
# generate output DNDarray(s) | ||
out_dndarrays = [ | ||
array(out_larrays[k], is_split=out_split[k]) for k in range(len(out_larrays)) | ||
] | ||
return tuple(out_dndarrays) | ||
|
||
return vmapped_func |
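As a complement to the docstring above, here is a slightly fuller, hedged sketch based only on this PR's docstring and tests: two inputs split along different axes (so the process-local `in_dims` is inferred as `(0, 1)`), a tuple-valued `out_dims`, the `chunk_size` option, and a comparison against a plain `torch.vmap` on the gathered data. The names, shapes, and scale value are illustrative only:

import heat as ht
import torch


def scaled_matmul(a, b, scale=1.0):
    # per-sample torch tensors: a has shape (6, 6), b has shape (6,)
    return scale * (a @ b)


# the split axis of each DNDarray acts as its batch axis
x0 = ht.random.randn(4 * ht.MPI_WORLD.size, 6, 6, split=0)
x1 = ht.random.randn(6, 4 * ht.MPI_WORLD.size, split=1)

# out_dims fixes the split axis of each output; chunk_size bounds per-process memory use
vfunc = ht.vmap(scaled_matmul, out_dims=(0,), chunk_size=2)
(y,) = vfunc(x0, x1, scale=0.5)

# reference: undistributed torch.vmap on the gathered tensors
y_ref = torch.vmap(scaled_matmul, (0, 1), 0)(
    x0.resplit(None).larray, x1.resplit(None).larray, scale=0.5
)
assert torch.allclose(y.resplit(None).larray, y_ref)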