Merge pull request #1516 from helmholtz-analytics/features/1383-Implement_vmap

Distributed `vmap` functionality for vectorization across split dimension
mrfh92 authored Jul 8, 2024
2 parents 72dffaf + 2512954 commit b19931c
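As a quick orientation, here is a minimal usage sketch of the new API, based on the docstring and tests added in this commit; the function name, array shape, and the `scale` keyword are purely illustrative:

```python
import heat as ht


def row_sum(x, scale=1.0):
    # x is the per-sample torch.Tensor handed over by the underlying torch.vmap;
    # keyword arguments such as `scale` are passed through, not vmapped
    return scale * x.sum()


# DNDarray distributed (split) across the MPI processes along dimension 0
x = ht.random.randn(5 * ht.MPI_WORLD.size, 10, split=0)

# vectorize along the split dimension; unlike torch.vmap, no in_dims is needed,
# since it is inferred from x.split
vfunc = ht.vmap(row_sum, out_dims=0)
y = vfunc(x, scale=2.0)[0]  # the vmapped callable always returns a tuple of DNDarrays
```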
Showing 3 changed files with 250 additions and 0 deletions.
1 change: 1 addition & 0 deletions heat/core/__init__.py
@@ -31,3 +31,4 @@
from .types import finfo, iinfo
from . import version
from .version import __version__
from .vmap import *
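This one-line addition is what exposes the new function at the package top level: `heat/__init__.py` star-imports `heat.core`, and the `__all__ = ["vmap"]` in `heat/core/vmap.py` (below) restricts the wildcard to `vmap`. A quick sanity check, assuming a build of this branch is installed:

```python
import heat as ht

# the new functionality is reachable directly from the top-level namespace
assert callable(ht.vmap)
```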
145 changes: 145 additions & 0 deletions heat/core/tests/test_vmap.py
@@ -0,0 +1,145 @@
import heat as ht
import torch

from .test_suites.basic_test import TestCase


class TestVmap(TestCase):
    if torch.__version__ < "2.0.0":

        def test_vmap(self):
            out_dims = (0, 0)

            def func(x0, x1, k=2, scale=1e-2):
                return torch.topk(torch.linalg.svdvals(x0), k)[0] ** 2, scale * x0 @ x1

            with self.assertRaises(RuntimeError):
                vfunc = ht.vmap(func, out_dims)  # noqa: F841

    else:

        def test_vmap(self):
            # two inputs (both split), two outputs, including keyword arguments that are not vmapped
            # inputs split along different axes, output split along same axis (one of them different to input split)
            x0 = ht.random.randn(5 * ht.MPI_WORLD.size, 10, 10, split=0)
            x1 = ht.random.randn(10, 5 * ht.MPI_WORLD.size, split=1)
            out_dims = 0  # test with out_dims as int (tuple below)

            def func(x0, x1, k=2, scale=1e-2):
                return torch.topk(torch.linalg.svdvals(x0), k)[0] ** 2, scale * x0 @ x1

            vfunc = ht.vmap(func, out_dims)
            y0, y1 = vfunc(x0, x1, k=2, scale=2.2)

            # compare with torch
            x0_torch = x0.resplit(None).larray
            x1_torch = x1.resplit(None).larray
            vfunc_torch = torch.vmap(func, (0, 1), (0, 0))
            y0_torch, y1_torch = vfunc_torch(x0_torch, x1_torch, k=2, scale=2.2)

            self.assertTrue(torch.allclose(y0.resplit(None).larray, y0_torch))
            self.assertTrue(torch.allclose(y1.resplit(None).larray, y1_torch))

            # two inputs (only one of them split), two outputs, including keyword arguments that are not vmapped
            # output split along different axis, one output has different data type than input
            x0 = ht.random.randn(5 * ht.MPI_WORLD.size, 10, 10, split=0)
            x1 = ht.random.randn(10, 5 * ht.MPI_WORLD.size, split=None)
            out_dims = (0, 1)

            def func(x0, x1, k=2, scale=1e-2):
                return torch.topk(torch.linalg.svdvals(x0), k)[0] ** 2, (scale * x0 @ x1).int()

            vfunc = ht.vmap(func, out_dims)
            y0, y1 = vfunc(x0, x1, k=2, scale=2.2)

            # compare with torch
            x0_torch = x0.resplit(None).larray
            x1_torch = x1.resplit(None).larray
            vfunc_torch = torch.vmap(func, (0, None), (0, 1))
            y0_torch, y1_torch = vfunc_torch(x0_torch, x1_torch, k=2, scale=2.2)

            self.assertTrue(torch.allclose(y0.resplit(None).larray, y0_torch))
            self.assertTrue(torch.allclose(y1.resplit(None).larray, y1_torch))

            # catch wrong number of output dimensions
            with self.assertRaises(ValueError):
                vfunc = ht.vmap(func, (0, 1, 2))
                y0, y1 = vfunc(x0, x1, k=2, scale=2.2)

            # one output only
            def func(x0, m=1, scale=2):
                return (x0 - m) ** scale

            vfunc = ht.vmap(func, out_dims=(0,))

            x0 = ht.random.randn(5 * ht.MPI_WORLD.size, 10, 10, split=0)
            y0 = vfunc(x0, m=2, scale=3)[0]

            x0_torch = x0.resplit(None).larray
            vfunc_torch = torch.vmap(func, (0,), (0,))
            y0_torch = vfunc_torch(x0_torch, m=2, scale=3)

            self.assertTrue(torch.allclose(y0.resplit(None).larray, y0_torch))

        def test_vmap_with_chunks(self):
            # same as before but now with prescribed chunk sizes for the vmap
            x0 = ht.random.randn(5 * ht.MPI_WORLD.size, 10, 10, split=0)
            x1 = ht.random.randn(10, 5 * ht.MPI_WORLD.size, split=1)
            out_dims = (0, 0)

            def func(x0, x1, k=2, scale=1e-2):
                return torch.topk(torch.linalg.svdvals(x0), k)[0] ** 2, scale * x0 @ x1

            vfunc = ht.vmap(func, out_dims, chunk_size=2)
            y0, y1 = vfunc(x0, x1, k=2, scale=-2.2)

            # compare with torch
            x0_torch = x0.resplit(None).larray
            x1_torch = x1.resplit(None).larray
            vfunc_torch = torch.vmap(func, (0, 1), (0, 0))
            y0_torch, y1_torch = vfunc_torch(x0_torch, x1_torch, k=2, scale=-2.2)

            self.assertTrue(torch.allclose(y0.resplit(None).larray, y0_torch))
            self.assertTrue(torch.allclose(y1.resplit(None).larray, y1_torch))

            # two inputs (only one of them split), two outputs, including keyword arguments that are not vmapped
            # output split along different axis
            x0 = ht.random.randn(5 * ht.MPI_WORLD.size, 10, 10, split=0)
            x1 = ht.random.randn(10, 5 * ht.MPI_WORLD.size, split=None)
            out_dims = (0, 1)

            def func(x0, x1, k=2, scale=1e-2):
                return torch.topk(torch.linalg.svdvals(x0), k)[0] ** 2, scale * x0 @ x1

            vfunc = ht.vmap(func, out_dims, chunk_size=1)
            y0, y1 = vfunc(x0, x1, k=5, scale=2.2)

            # compare with torch
            x0_torch = x0.resplit(None).larray
            x1_torch = x1.resplit(None).larray
            vfunc_torch = torch.vmap(func, (0, None), (0, 1))
            y0_torch, y1_torch = vfunc_torch(x0_torch, x1_torch, k=5, scale=2.2)

            self.assertTrue(torch.allclose(y0.resplit(None).larray, y0_torch))
            self.assertTrue(torch.allclose(y1.resplit(None).larray, y1_torch))

        def test_vmap_catch_errors(self):
            # not a callable
            with self.assertRaises(TypeError):
                ht.vmap(1)
            # invalid randomness
            with self.assertRaises(ValueError):
                ht.vmap(lambda x: x, randomness="random")
            # invalid chunk_size
            with self.assertRaises(TypeError):
                ht.vmap(lambda x: x, chunk_size="1")
            with self.assertRaises(ValueError):
                ht.vmap(lambda x: x, chunk_size=0)
            # not all inputs are DNDarrays
            with self.assertRaises(TypeError):
                ht.vmap(lambda x: x, out_dims=0)(ht.ones(10), 2)
            # number of output DNDarrays does not match number of split dimensions
            with self.assertRaises(ValueError):
                ht.vmap(lambda x: x, out_dims=(0, 1))(ht.ones(10))
104 changes: 104 additions & 0 deletions heat/core/vmap.py
@@ -0,0 +1,104 @@
"""
This implements a functionality similar to PyTorchs vmap function.
Requires PyTorch 2.0.0 or higher.
"""

import torch

from .dndarray import DNDarray
from .factories import array
from .communication import MPI_WORLD
from typing import Union, Tuple, Optional, Callable

__all__ = ["vmap"]


def vmap(
    func: Callable[[Tuple[torch.Tensor]], Tuple[torch.Tensor]],
    out_dims: Union[Tuple[int], int] = 0,
    randomness: str = "error",
    *,
    chunk_size: Optional[int] = None,
) -> Callable[[Tuple[DNDarray]], Tuple[DNDarray]]:
"""
This function is used to apply a function to a DNDarray in a vectorized way.
`heat.vmap` return a callable that can be applied to DNDarrays.
Vectorization will automatically take place along the split axis/axes of the DNDarray(s);
therefore, unlike in PyTorch, there is no argument `in_dims`.
What we here refer to as "split axis/dimension" in the Heat terminology is often referred to as "batch axis/dimension" in the PyTorch terminology.
Parameters
----------
func : callable
The function to apply in a vmapped way to the DNDarray(s). It must take PyTorch tensor(s) as positional arguments.
Additional parameters, not to be vmapped over, can be passed as keyword arguments. The callable returned by
by `heat.vmap` will also accept these keyword arguments.
out_dims : int or tuple of int, optional
The dimensions of the output(s) that are mapped over; identical to the split dimension(s) of the output(s).
Default is 0.
randomness : {'error', 'different', 'same'}, optional
Determines how to handle randomness in the function to be vmapped. This argument is directly passed to the underlying PyTorch vmaps;
see the corresponding PyTorch documentation for more information and the note below.
If 'error' (default), an error is raised if the function to be mapped contains randomness.
chunk_size : int, optional
The size of the chunks to use for the process-local computation.
If None (default), apply a single PyTorch vmap over the process-local chunks of data. If not None, then compute the process-local PyTorch vmap `chunk_size`
many samples at a time. Note that `chunk_size=1` is equivalent to computing the process-local PyTorch vmap's with a for-loop.
If you run into memory issues computing the vmap, please try a non-None chunk_size.
Note
------
This function is a wrapper around PyTorch's `torch.vmap` function. In essence, a PyTorch vmap is applied to the input function `func` on each MPI process separately.
This process-local PyTorch-vmapped function is then applied to the process-local chunks of the input DNDarray(s).
Please note that the options 'same' and 'different' for `randomness` will result in behaviour different from the one known by PyTorch as (at least currently)
no actions are taken to synchronize randomness across the MPI processes.
"""
    # check PyTorch version, return error if not 2.0.0 or higher
    if torch.__version__ < "2.0.0":
        raise RuntimeError("The function `heat.vmap` requires PyTorch 2.0.0 or higher.")
    # rough check of input argument types
    if not callable(func):
        raise TypeError("The input function `func` must be callable.")
    if randomness not in ["error", "different", "same"]:
        raise ValueError(
            "The input argument `randomness` must be one of the strings 'error', 'different', or 'same'."
        )
    if chunk_size is not None and not isinstance(chunk_size, int):
        raise TypeError("The input argument `chunk_size` must be None or an integer.")
    elif chunk_size is not None and chunk_size < 1:
        raise ValueError("If an integer, the input argument `chunk_size` must be at least 1.")

    def vmapped_func(*args, **kwargs):
        for arg in args:
            if not isinstance(arg, DNDarray):
                raise TypeError(
                    f"All inputs to the vmapped-version of your function must be DNDarrays, but one is {type(arg)}."
                )
        in_dims = tuple([arg.split for arg in args])

        # apply Torch vmap to the input function and the result to the local arrays of the input DNDarray
        torch_vmap_func = torch.vmap(
            func, in_dims, out_dims, randomness=randomness, chunk_size=chunk_size
        )
        out_larrays = torch_vmap_func(*[arg.larray for arg in args], **kwargs)

        if isinstance(out_larrays, torch.Tensor):
            # circumvent misinterpretation of the following call of len() in case of a single output
            out_larrays = [out_larrays]
        if isinstance(out_dims, int):
            out_split = [out_dims] * len(out_larrays)
        else:
            out_split = out_dims
        if len(out_split) != len(out_larrays):
            raise ValueError(
                f"The number of output DNDarrays ({len(out_larrays)}) must match the number of their split dimensions provided in `out_dims` ({len(out_split)})."
            )
        # generate output DNDarray(s)
        out_dndarrays = [
            array(out_larrays[k], is_split=out_split[k]) for k in range(len(out_larrays))
        ]
        return tuple(out_dndarrays)

    return vmapped_func
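To make the control flow above concrete: the split axes of the input DNDarrays take the role of PyTorch's `in_dims`, and `chunk_size` bounds how many samples each process-local `torch.vmap` handles at once. A hedged sketch condensed from the tests in this commit (the function, shapes, and the `scale` keyword are illustrative):

```python
import heat as ht


def weighted_matmul(a, b, scale=1e-2):
    # per-sample computation on the process-local torch tensors
    return scale * a @ b


# a is split along dimension 0, b along dimension 1,
# so the wrapper derives in_dims = (0, 1) from the split attributes
a = ht.random.randn(4 * ht.MPI_WORLD.size, 10, 10, split=0)
b = ht.random.randn(10, 4 * ht.MPI_WORLD.size, split=1)

# chunk_size=2: each process-local torch.vmap works on two samples at a time,
# trading some speed for a lower peak memory footprint
vfunc = ht.vmap(weighted_matmul, out_dims=0, chunk_size=2)
(c,) = vfunc(a, b, scale=0.5)  # c has shape (4 * nprocs, 10) and is split along dimension 0
```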
