Skip to content

Commit

Permalink
Add dependency on ml_types
Browse files Browse the repository at this point in the history
This will enable us to write bf16 numpy files.

Signed-off-by: Harsh Menon <[email protected]>
  • Loading branch information
harsh-nod committed Dec 4, 2024
1 parent f9edfe7 commit e617d5b
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
12 changes: 11 additions & 1 deletion iree/turbine/kernel/wave/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@

# TODO: Monkey-patching f16 support, need to fix in iree.
import numpy
import ml_dtypes

bench.DTYPE_TO_ABI_TYPE[numpy.dtype(numpy.float16)] = "f16"

Expand Down Expand Up @@ -654,7 +655,16 @@ def compile_and_invoke(
for inp in all_inputs:
if isinstance(inp, torch.Tensor):
tf = tempfile.NamedTemporaryFile(suffix=".npy")
numpy.save(tf, inp.cpu().numpy())
inp = inp.cpu()
if inp.dtype == torch.bfloat16:
inp = (
inp.view(dtype=torch.uint16)
.numpy()
.view(dtype=ml_dtypes.bfloat16)
)
else:
inp = inp.numpy()
numpy.save(tf, inp)
tempfiles.append(tf)
inputs.append("@" + tf.name)
elif isinstance(inp, int):
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ pytest==8.0.0
pytest-xdist==3.5.0
lit==18.1.7
mypy==1.8.0
ml_dtypes==0.5.0
setuptools
wheel

Expand Down

0 comments on commit e617d5b

Please sign in to comment.