Skip to content

Commit

Permalink
Add float8 & int4 numpy integration
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 573934532
Change-Id: Iee8ffac67909534bd87f5d4df30b9cd8cc326b4a
  • Loading branch information
ChromeHearts authored and copybara-github committed Oct 16, 2023
1 parent 957f027 commit b8ad93d
Show file tree
Hide file tree
Showing 17 changed files with 286 additions and 2,426 deletions.
22 changes: 13 additions & 9 deletions python/tensorstore/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ py_library(
visibility = ["//visibility:public"],
deps = [
":_tensorstore",
"@pypa_ml_dtypes//:ml_dtypes", # build_cleaner: keep
"@pypa_numpy//:numpy",
],
)
Expand Down Expand Up @@ -438,14 +439,10 @@ pybind11_cc_library(
pybind11_cc_library(
name = "data_type",
srcs = [
"bfloat16.cc",
"data_type.cc",
"int4.cc",
],
hdrs = [
"bfloat16.h",
"data_type.h",
"int4.h",
],
deps = [
":json_type_caster",
Expand All @@ -454,15 +451,12 @@ pybind11_cc_library(
":tensorstore_module_components",
"//tensorstore:data_type",
"//tensorstore/internal:global_initializer",
"//tensorstore/internal:type_traits",
"//tensorstore/util:bfloat16",
"//tensorstore/internal:no_destructor",
"//tensorstore/util:executor",
"//tensorstore/util:int4",
"//tensorstore/util:quote_string",
"//tensorstore/util:str_cat",
"@com_github_nlohmann_json//:nlohmann_json",
"@com_github_pybind_pybind11//:pybind11",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/hash",
],
alwayslink = True,
Expand Down Expand Up @@ -1109,3 +1103,13 @@ tensorstore_pytest_test(
":tensorstore",
],
)

tensorstore_pytest_test(
name = "tests/custom_dtypes_test",
srcs = ["tests/custom_dtypes_test.py"],
deps = [
":tensorstore",
"@pypa_ml_dtypes//:ml_dtypes",
"@pypa_numpy//:numpy",
],
)
45 changes: 45 additions & 0 deletions python/tensorstore/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,51 @@ class FutureLike(metaclass=_abc.ABCMeta):
Data types
"""

float8_e4m3fn: dtype
"""8-bit floating-point data type.
Details in https://github.com/jax-ml/ml_dtypes#float8_e4m3fn
Group:
Data types
"""

float8_e4m3fnuz: dtype
"""8-bit floating-point data type.
Details in https://github.com/jax-ml/ml_dtypes#float8_e4m3fnuz
Group:
Data types
"""

float8_e4m3b11fnuz: dtype
"""8-bit floating-point data type.
Details in https://github.com/jax-ml/ml_dtypes#float8_e4m3b11fnuz
Group:
Data types
"""

float8_e5m2: dtype
"""8-bit floating-point data type.
Details in https://github.com/jax-ml/ml_dtypes#float8_e5m2
Group:
Data types
"""

float8_e5m2fnuz: dtype
"""8-bit floating-point data type.
Details in https://github.com/jax-ml/ml_dtypes#float8_e5m2fnuz
Group:
Data types
"""

float16: dtype
""":wikipedia:`IEEE 754 binary16 <Half-precision_floating-point_format>` half-precision floating-point data type. Correspond to ``numpy.float16``.
Expand Down
Loading

0 comments on commit b8ad93d

Please sign in to comment.