[IR] Support float4e2m1 #1908
base: main
Conversation
❌ 14 Tests Failed (3 flaky).
    Returns:
        A numpy array of float32 reshaped to dims.
    """
    return _unpack_uint4_as_uint8(data, dims).view(ml_dtypes.float4_e2m1fn)
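The unpack step this line relies on can be sketched in plain NumPy. This is a hypothetical reconstruction of what a `_unpack_uint4_as_uint8`-style helper does (the real implementation may differ): each input byte holds two 4-bit values, low nibble first, and a trailing padding nibble is dropped when the element count is odd.

```python
import numpy as np

def unpack_uint4_as_uint8(data: np.ndarray, dims: tuple) -> np.ndarray:
    """Unpack two 4-bit values per byte into one uint8 value each.

    Hypothetical sketch: low nibble comes first, and any padding nibble
    beyond the requested element count is discarded.
    """
    data = np.asarray(data, dtype=np.uint8).reshape(-1)
    low = data & 0x0F
    high = data >> 4
    # Interleave low/high nibbles, then trim to the target element count.
    unpacked = np.stack([low, high], axis=-1).reshape(-1)
    size = int(np.prod(dims))
    return unpacked[:size].reshape(dims)
```

The resulting uint8 array can then be reinterpreted with `.view(ml_dtypes.float4_e2m1fn)` as the diff shows, since `float4_e2m1fn` is a one-byte storage type.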
Do we need to handle when ml_dtypes.float4_e2m1fn is not available?
@@ -32,6 +32,8 @@ def test_enums_are_the_same_as_spec(self):
        self.assertEqual(_enums.DataType.FLOAT8E5M2FNUZ, onnx.TensorProto.FLOAT8E5M2FNUZ)
        self.assertEqual(_enums.DataType.UINT4, onnx.TensorProto.UINT4)
        self.assertEqual(_enums.DataType.INT4, onnx.TensorProto.INT4)
        if hasattr(onnx.TensorProto, "FLOAT4E2M1"):
I guess the new functions are already tested somewhere else.
Support the float4e2m1 dtype from IRv11 (not yet released). This allows our tests to pass in the weekly-onnx CI. We use the ml_dtypes.float4_e2m1fn type for numpy conversion. Since ml_dtypes.float4_e2m1fn is only available in the latest ml_dtypes release, which has dropped support for Python 3.8, I used conditional logic to build the numpy dtype mapping table.