Skip to content

Commit

Permalink
support bfloat16
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Mar 12, 2024
1 parent bd82051 commit 222ffa0
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,12 @@ DenseElementsAttr fromTensor(stablehlo::Tensor inp) {
auto type = inp.getType();
auto elemType = type.getElementType();

if (elemType.isBF16()) {
auto floatValues =
ArrayRef((char *)inp.getData(), 2 * inp.getNumElements());
return DenseFPElementsAttr::getFromRawBuffer(type, floatValues);
}

if (elemType.isF32()) {
auto floatValues = ArrayRef((float *)inp.getData(), inp.getNumElements());
return DenseFPElementsAttr::get(type, floatValues);
Expand Down
14 changes: 14 additions & 0 deletions test/lit_tests/convertconst.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s

module {
func.func @main() -> tensor<4xbf16> {
%concat = stablehlo.constant dense<3.140000e+00> : tensor<4xf32>
%conv = stablehlo.convert %concat : (tensor<4xf32>) -> tensor<4xbf16>
return %conv : tensor<4xbf16>
}
}

// CHECK: func.func @main() -> tensor<4xbf16> {
// CHECK-NEXT: %0 = stablehlo.constant dense<3.140630e+00> : tensor<4xbf16>
// CHECK-NEXT: return %0 : tensor<4xbf16>
// CHECK-NEXT: }

0 comments on commit 222ffa0

Please sign in to comment.