-
Notifications
You must be signed in to change notification settings - Fork 89
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: to/from PyTorch Tensor (#3259)
* add new to_torch function * add new from_torch function * add changes suggested by Jim * style: pre-commit fixes * fix style --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
cfe58f3
commit ee5865a
Showing
4 changed files
with
213 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE | ||
|
||
from __future__ import annotations | ||
|
||
import awkward as ak | ||
from awkward._dispatch import high_level_function | ||
|
||
__all__ = ("from_torch",) | ||
|
||
|
||
@high_level_function() | ||
def from_torch(array): | ||
""" | ||
Args: | ||
array: (PyTorch Tensor): | ||
Tensor to convert into an Awkward Array. | ||
Converts a PyTorch Tensor into an Awkward Array. | ||
If `array` contains any other data types the function raises an error. | ||
""" | ||
|
||
# Dispatch | ||
yield (array,) | ||
|
||
# Implementation | ||
return _impl(array) | ||
|
||
|
||
def _impl(array): | ||
try: | ||
import torch | ||
except ImportError as err: | ||
raise ImportError( | ||
"""to use ak.from_torch, you must install 'torch' package with: | ||
pip install torch | ||
or | ||
conda install pytorch""" | ||
) from err | ||
|
||
# check if array is a Tensor | ||
if not isinstance(array, torch.Tensor): | ||
raise TypeError("""only PyTorch Tensor can be converted to Awkward Array""") | ||
|
||
# keep the resulting array on the same device as input tensor | ||
device = "cuda" if array.is_cuda else "cpu" | ||
|
||
# convert tensors to cupy if they are on cuda | ||
if device == "cuda": | ||
from awkward._nplikes.cupy import Cupy | ||
|
||
cp = Cupy.instance() | ||
|
||
# zero-copy data exchange through DLPack | ||
cp_array = cp.from_dlpack(array) | ||
ak_array = ak.from_cupy(cp_array) | ||
|
||
else: | ||
np_array = array.numpy() | ||
ak_array = ak.from_numpy(np_array) | ||
|
||
return ak_array |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE | ||
|
||
from __future__ import annotations | ||
|
||
import awkward as ak | ||
from awkward._dispatch import high_level_function | ||
from awkward._nplikes.numpy_like import NumpyMetadata | ||
|
||
__all__ = ("to_torch",) | ||
|
||
np = NumpyMetadata.instance() | ||
|
||
|
||
@high_level_function() | ||
def to_torch(array): | ||
""" | ||
Args: | ||
array: Array-like data. May be a high level #ak.Array, | ||
or low-level #ak.contents.ListOffsetArray, #ak.contents.ListArray, | ||
#ak.contents.RegularArray, #ak.contents.NumpyArray | ||
Converts `array` (only ListOffsetArray, ListArray, RegularArray and NumpyArray data types supported) | ||
into a PyTorch Tensor, if possible. | ||
If `array` contains any other data types (RecordArray for example) the function raises a TypeError. | ||
""" | ||
|
||
# Dispatch | ||
yield (array,) | ||
|
||
# Implementation | ||
return _impl(array) | ||
|
||
|
||
def _impl(array): | ||
try: | ||
import torch | ||
except ImportError as err: | ||
raise ImportError( | ||
"""to use ak.to_torch, you must install 'torch' package with: | ||
pip install torch | ||
or | ||
conda install pytorch""" | ||
) from err | ||
|
||
# useful function that handles all possible input arrays | ||
array = ak.to_layout(array, allow_record=False) | ||
|
||
# get the device array is on | ||
device = ak.backend(array) | ||
|
||
if device not in ["cuda", "cpu"]: | ||
raise ValueError("Only 'cpu' and 'cuda' backend conversions are allowed") | ||
|
||
# convert to numpy or cupy if `array` on gpu | ||
try: | ||
backend_array = array.to_backend_array(allow_missing=False) | ||
except ValueError as err: | ||
raise TypeError( | ||
"Only arrays containing equal-length lists of numbers can be converted into a PyTorch Tensor" | ||
) from err | ||
|
||
# check if cupy or numpy | ||
if isinstance(backend_array, np.ndarray): | ||
# convert numpy to a torch tensor | ||
tensor = torch.from_numpy(backend_array) | ||
else: | ||
# cupy -> torch tensor | ||
tensor = torch.utils.dlpack.from_dlpack(backend_array.toDlpack()) | ||
|
||
return tensor |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE | ||
|
||
from __future__ import annotations | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
import awkward as ak | ||
|
||
to_torch = ak.operations.to_torch | ||
from_torch = ak.operations.from_torch | ||
|
||
torch = pytest.importorskip("torch") | ||
|
||
a = np.arange(2 * 2 * 2, dtype=np.float64).reshape(2, 2, 2) | ||
b = np.arange(2 * 2 * 2).reshape(2, 2, 2) | ||
|
||
array = np.arange(2 * 3 * 5).reshape(2, 3, 5) | ||
content2 = ak.contents.NumpyArray(array.reshape(-1)) | ||
inneroffsets = ak.index.Index64(np.array([0, 5, 10, 15, 20, 25, 30])) | ||
outeroffsets = ak.index.Index64(np.array([0, 3, 6])) | ||
|
||
|
||
def test_to_torch(): | ||
# a basic test for a 4 dimensional array | ||
array1 = ak.Array([a, b]) | ||
i = 0 | ||
for sub_array in [ | ||
[[[0.0, 1.0], [2.0, 3.0]], [[4.0, 5.0], [6.0, 7.0]]], | ||
[[[0.0, 1.0], [2.0, 3.0]], [[4.0, 5.0], [6.0, 7.0]]], | ||
]: | ||
assert to_torch(array1)[i].tolist() == sub_array | ||
i += 1 | ||
|
||
# test that the data types are remaining the same (float64 in this case) | ||
assert array1.layout.to_backend_array().dtype.name in str(to_torch(array1).dtype) | ||
|
||
# try a listoffset array inside a listoffset array | ||
array2 = ak.contents.ListOffsetArray( | ||
outeroffsets, ak.contents.ListOffsetArray(inneroffsets, content2) | ||
) | ||
assert to_torch(array2)[0].tolist() == [ | ||
[0, 1, 2, 3, 4], | ||
[5, 6, 7, 8, 9], | ||
[10, 11, 12, 13, 14], | ||
] | ||
assert to_torch(array2)[1].tolist() == [ | ||
[15, 16, 17, 18, 19], | ||
[20, 21, 22, 23, 24], | ||
[25, 26, 27, 28, 29], | ||
] | ||
|
||
# try just a python list | ||
array3 = [3, 1, 4, 1, 9, 2, 6] | ||
assert to_torch(array3).tolist() == [3, 1, 4, 1, 9, 2, 6] | ||
|
||
|
||
array1 = torch.tensor([[1.0, -1.0], [1.0, -1.0]], dtype=torch.float32) | ||
array2 = torch.tensor(np.array([[1, 2, 3], [4, 5, 6]])) | ||
|
||
|
||
def test_from_torch(): | ||
# Awkward.to_list() == Tensor.tolist() | ||
assert from_torch(array1).to_list() == array1.tolist() | ||
|
||
assert from_torch(array2).to_list() == array2.tolist() | ||
|
||
# test that the data types are remaining the same (int64 in this case) | ||
assert from_torch(array1).layout.dtype.name in str(array1.dtype) | ||
|
||
# test that the data types are remaining the same (float32 in this case) | ||
assert from_torch(array2).layout.dtype.name in str(array2.dtype) |