Skip to content

Commit

Permalink
PR #17754: [NFC] Extract common builder code into hlo_creation_utils
Browse files Browse the repository at this point in the history
Imported from GitHub PR #17754

Added `XlaComputationToHloComputation` helper function to `hlo_creation_utils` and replace the uses.

Copybara import of the project:

--
63cbc86 by Sergey Kozub <[email protected]>:

Extract common builder code into hlo_creation_utils

Merging this change closes #17754

FUTURE_COPYBARA_INTEGRATE_REVIEW=#17754 from openxla:skozub/xla_builder 63cbc86
PiperOrigin-RevId: 680912058
  • Loading branch information
sergey-kozub authored and Google-ML-Automation committed Oct 2, 2024
1 parent 7d4740a commit 5b5e9fd
Show file tree
Hide file tree
Showing 93 changed files with 1,592 additions and 1,193 deletions.
901 changes: 0 additions & 901 deletions third_party/stablehlo/temporary.patch

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions third_party/stablehlo/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")

def repo():
# LINT.IfChange
STABLEHLO_COMMIT = "9d9290dc2308c1850cea69ea05f8c94017e484ee"
STABLEHLO_SHA256 = "29803fc8a3a96f9e5469c7ab51f2ff4292dc2419c17bd0466f5d15a448cf6815"
STABLEHLO_COMMIT = "f7f8e4e35296deeff2e12e39421ac8d9599ba340"
STABLEHLO_SHA256 = "c92b55d5512e58d6fefba62c58e60d7762adb184dc3ad489521de562f6ca7aeb"
# LINT.ThenChange(Google-internal path)

tf_http_archive(
Expand Down
2 changes: 2 additions & 0 deletions third_party/tsl/tsl/platform/ml_dtypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ limitations under the License.
#include "ml_dtypes/include/intn.h" // from @ml_dtypes

namespace tsl {
using float8_e3m4 = ::ml_dtypes::float8_e3m4;
using float8_e4m3 = ::ml_dtypes::float8_e4m3;
using float8_e4m3fn = ::ml_dtypes::float8_e4m3fn;
using float8_e4m3fnuz = ::ml_dtypes::float8_e4m3fnuz;
using float8_e4m3b11fnuz = ::ml_dtypes::float8_e4m3b11fnuz;
Expand Down
2 changes: 2 additions & 0 deletions xla/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ xla_cc_test(
":util",
"@com_google_absl//absl/base",
"@com_google_absl//absl/numeric:bits",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:ml_dtypes",
"@tsl//tsl/platform:test_main",
],
Expand Down Expand Up @@ -373,6 +374,7 @@ xla_cc_test(
":test",
":types",
":util",
"@ml_dtypes//:float8",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:ml_dtypes",
"@tsl//tsl/platform:test_main",
Expand Down
28 changes: 28 additions & 0 deletions xla/array2d_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,20 @@ TEST(Array2dTest, LinspaceF8E5M2) {
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 3.5);
}

TEST(Array2dTest, LinspaceF8E4M3) {
auto arr = MakeLinspaceArray2D<tsl::float8_e4m3>(1.0, 3.5, 3, 2);

EXPECT_EQ(arr->n1(), 3);
EXPECT_EQ(arr->n2(), 2);

EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 0)), 1.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 1)), 1.5);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 0)), 2.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 1)), 2.5);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 0)), 3.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 3.5);
}

TEST(Array2dTest, LinspaceF8E4M3Fn) {
auto arr = MakeLinspaceArray2D<tsl::float8_e4m3fn>(1.0, 3.5, 3, 2);

Expand Down Expand Up @@ -190,6 +204,20 @@ TEST(Array2dTest, LinspaceF8E4M3FnNoNan) {
}
}

TEST(Array2dTest, LinspaceF8E3M4) {
auto arr = MakeLinspaceArray2D<tsl::float8_e3m4>(1.0, 3.5, 3, 2);

EXPECT_EQ(arr->n1(), 3);
EXPECT_EQ(arr->n2(), 2);

EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 0)), 1.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 1)), 1.5);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 0)), 2.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 1)), 2.5);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 0)), 3.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 3.5);
}

TEST(Array2dTest, Stringification) {
auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2);
const std::string expected = R"([[1, 1.5],
Expand Down
4 changes: 4 additions & 0 deletions xla/ffi/api/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ inline std::ostream& operator<<(std::ostream& os,
return os << "TOKEN";
case XLA_FFI_DataType_F8E5M2:
return os << "F8E5M2";
case XLA_FFI_DataType_F8E3M4:
return os << "F8E3M4";
case XLA_FFI_DataType_F8E4M3:
return os << "F8E4M3";
case XLA_FFI_DataType_F8E4M3FN:
return os << "F8E4M3FN";
case XLA_FFI_DataType_F8E4M3B11FNUZ:
Expand Down
2 changes: 2 additions & 0 deletions xla/ffi/api/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ typedef enum {
XLA_FFI_DataType_C128 = 18,
XLA_FFI_DataType_TOKEN = 17,
XLA_FFI_DataType_F8E5M2 = 19,
XLA_FFI_DataType_F8E3M4 = 29,
XLA_FFI_DataType_F8E4M3 = 28,
XLA_FFI_DataType_F8E4M3FN = 20,
XLA_FFI_DataType_F8E4M3B11FNUZ = 23,
XLA_FFI_DataType_F8E5M2FNUZ = 24,
Expand Down
6 changes: 6 additions & 0 deletions xla/ffi/api/ffi.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,12 @@ enum class DataType : uint8_t {
C128 = XLA_FFI_DataType_C128,
TOKEN = XLA_FFI_DataType_TOKEN,
F8E5M2 = XLA_FFI_DataType_F8E5M2,
F8E4M3 = XLA_FFI_DataType_F8E4M3,
F8E4M3FN = XLA_FFI_DataType_F8E4M3FN,
F8E4M3B11FNUZ = XLA_FFI_DataType_F8E4M3B11FNUZ,
F8E5M2FNUZ = XLA_FFI_DataType_F8E5M2FNUZ,
F8E4M3FNUZ = XLA_FFI_DataType_F8E4M3FNUZ,
F8E3M4 = XLA_FFI_DataType_F8E3M4,
};

// Create aliases in ::xla::ffi namespace for all DataTypes, for consistency
Expand All @@ -98,10 +100,12 @@ inline constexpr DataType C64 = DataType::C64;
inline constexpr DataType C128 = DataType::C128;
inline constexpr DataType TOKEN = DataType::TOKEN;
inline constexpr DataType F8E5M2 = DataType::F8E5M2;
inline constexpr DataType F8E4M3 = DataType::F8E4M3;
inline constexpr DataType F8E4M3FN = DataType::F8E4M3FN;
inline constexpr DataType F8E4M3B11FNUZ = DataType::F8E4M3B11FNUZ;
inline constexpr DataType F8E5M2FNUZ = DataType::F8E5M2FNUZ;
inline constexpr DataType F8E4M3FNUZ = DataType::F8E4M3FNUZ;
inline constexpr DataType F8E3M4 = DataType::F8E3M4;

inline std::ostream& operator<<(std::ostream& os, const DataType dtype) {
return os << static_cast<XLA_FFI_DataType>(dtype);
Expand All @@ -117,10 +121,12 @@ constexpr size_t ByteWidth(DataType dtype) {
case DataType::S8:
case DataType::U8:
case DataType::F8E5M2:
case DataType::F8E4M3:
case DataType::F8E4M3FN:
case DataType::F8E4M3B11FNUZ:
case DataType::F8E5M2FNUZ:
case DataType::F8E4M3FNUZ:
case DataType::F8E3M4:
return 1;
case DataType::S16:
case DataType::U16:
Expand Down
6 changes: 6 additions & 0 deletions xla/ffi/api/ffi_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,13 @@ TEST(FfiTest, DataTypeEnumValue) {
EXPECT_EQ(encoded(PrimitiveType::TOKEN), encoded(DataType::TOKEN));

EXPECT_EQ(encoded(PrimitiveType::F8E5M2), encoded(DataType::F8E5M2));
EXPECT_EQ(encoded(PrimitiveType::F8E4M3), encoded(DataType::F8E4M3));
EXPECT_EQ(encoded(PrimitiveType::F8E4M3FN), encoded(DataType::F8E4M3FN));
EXPECT_EQ(encoded(PrimitiveType::F8E4M3B11FNUZ),
encoded(DataType::F8E4M3B11FNUZ));
EXPECT_EQ(encoded(PrimitiveType::F8E5M2FNUZ), encoded(DataType::F8E5M2FNUZ));
EXPECT_EQ(encoded(PrimitiveType::F8E4M3FNUZ), encoded(DataType::F8E4M3FNUZ));
EXPECT_EQ(encoded(PrimitiveType::F8E3M4), encoded(DataType::F8E3M4));
}

TEST(FfiTest, DataTypeByteWidth) {
Expand Down Expand Up @@ -179,6 +181,8 @@ TEST(FfiTest, DataTypeByteWidth) {

EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E5M2),
ByteWidth(DataType::F8E5M2));
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3),
ByteWidth(DataType::F8E4M3));
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3FN),
ByteWidth(DataType::F8E4M3FN));
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3B11FNUZ),
Expand All @@ -187,6 +191,8 @@ TEST(FfiTest, DataTypeByteWidth) {
ByteWidth(DataType::F8E5M2FNUZ));
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3FNUZ),
ByteWidth(DataType::F8E4M3FNUZ));
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E3M4),
ByteWidth(DataType::F8E3M4));
}

TEST(FfiTest, ErrorEnumValue) {
Expand Down
2 changes: 2 additions & 0 deletions xla/ffi/call_frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -265,10 +265,12 @@ static XLA_FFI_DataType ToDataType(PrimitiveType primitive_type) {
case PrimitiveType::C128:
case PrimitiveType::TOKEN:
case PrimitiveType::F8E5M2:
case PrimitiveType::F8E4M3:
case PrimitiveType::F8E4M3FN:
case PrimitiveType::F8E4M3B11FNUZ:
case PrimitiveType::F8E5M2FNUZ:
case PrimitiveType::F8E4M3FNUZ:
case PrimitiveType::F8E3M4:
return static_cast<XLA_FFI_DataType>(primitive_type);
default:
DCHECK(false) << "Unsupported primitive type "
Expand Down
62 changes: 58 additions & 4 deletions xla/fp_util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include <cstdint>
#include <limits>

#include <gtest/gtest.h>
#include "absl/base/casts.h"
#include "absl/numeric/bits.h"
#include "xla/bit_cast.h"
Expand Down Expand Up @@ -111,21 +112,74 @@ INSTANTIATE_TEST_SUITE_P(DoublePrecisionInputs, FixedValueTest,
0x1.fffffffffffffp-127,
0x1.aaaaaaaaaaaaap-127));

// Test F8E4M3 floating-point types (F8E4M3FN)
// Test F8E4M3 floating-point types (F8E4M3, F8E4M3FN)
template <typename T>
class FP8E4M3DistanceTest : public ::testing::Test {};

using F8E4M3Types = ::testing::Types<tsl::float8_e4m3fn>;
using F8E4M3Types = ::testing::Types<tsl::float8_e4m3, tsl::float8_e4m3fn>;
TYPED_TEST_SUITE(FP8E4M3DistanceTest, F8E4M3Types);

TEST(FPDistanceTest, F8E3M4Distance) {
// a & b are equal
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e3m4>(tsl::float8_e3m4(8.0),
tsl::float8_e3m4(8.0)),
0);

// a & b have the same exponents
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e3m4>(tsl::float8_e3m4(8.0),
tsl::float8_e3m4(15.5)),
15);

// a & b have different exponents
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e3m4>(tsl::float8_e3m4(8.0),
tsl::float8_e3m4(6)),
8);

// 1 from 0 in the positive direction
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e3m4>(
std::numeric_limits<tsl::float8_e3m4>::denorm_min(),
tsl::float8_e3m4(0)),
1);

// 1 from 0 in the negative direction
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e3m4>(
-std::numeric_limits<tsl::float8_e3m4>::denorm_min(),
tsl::float8_e3m4(0)),
1);

// a & b have different signs
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e3m4>(
-std::numeric_limits<tsl::float8_e3m4>::denorm_min(),
std::numeric_limits<tsl::float8_e3m4>::denorm_min()),
2);

// 1 non denorm from 0 in the positive direction
EXPECT_EQ(
CalculateDistanceInFloats<tsl::float8_e3m4>(
std::numeric_limits<tsl::float8_e3m4>::min(), tsl::float8_e3m4(0)),
16);

// 1 non denorm from 0 in the negative direction
EXPECT_EQ(
CalculateDistanceInFloats<tsl::float8_e3m4>(
-std::numeric_limits<tsl::float8_e3m4>::min(), tsl::float8_e3m4(0)),
16);

// a & b have different signs
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e3m4>(
-std::numeric_limits<tsl::float8_e3m4>::min(),
std::numeric_limits<tsl::float8_e3m4>::min()),
32);
}

TYPED_TEST(FP8E4M3DistanceTest, F8E4M3Distance) {
// a & b are equal, distance should be 0
EXPECT_EQ(
CalculateDistanceInFloats<TypeParam>(TypeParam(8.0), TypeParam(8.0)), 0);

// a & b have the same exponents
EXPECT_EQ(CalculateDistanceInFloats<TypeParam>(TypeParam(8.0), TypeParam(13)),
5);
EXPECT_EQ(
CalculateDistanceInFloats<TypeParam>(TypeParam(8.0), TypeParam(15.0)), 7);

// a & b have different exponents
EXPECT_EQ(
Expand Down
10 changes: 6 additions & 4 deletions xla/hlo/builder/lib/math.cc
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ XlaOp IsNegZero(XlaOp operand) {
case F32:
return Eq(BitcastConvertType(operand, U32),
ConstantR0WithType(&b, U32, uint32_t{1} << 31));
case F8E3M4:
case F8E4M3:
case F8E5M2:
case F8E4M3FN:
case F8E4M3B11FNUZ:
Expand Down Expand Up @@ -973,8 +975,8 @@ XlaOp Igamma(XlaOp a, XlaOp x) {
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Igamma", a));
PrimitiveType a_x_type = a_shape.element_type();
bool needs_upcast = false;
for (PrimitiveType type :
{BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
for (PrimitiveType type : {BF16, F16, F8E3M4, F8E4M3, F8E5M2, F8E4M3FN,
F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
if (a_shape.element_type() == type) {
needs_upcast = true;
break;
Expand Down Expand Up @@ -1026,8 +1028,8 @@ XlaOp IgammaGradA(XlaOp a, XlaOp x) {
}
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IgammaGradA", a));
bool needs_upcast = false;
for (PrimitiveType type :
{BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
for (PrimitiveType type : {BF16, F16, F8E3M4, F8E4M3, F8E5M2, F8E4M3FN,
F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
if (a_shape.element_type() == type) {
needs_upcast = true;
break;
Expand Down
2 changes: 2 additions & 0 deletions xla/hlo/evaluator/hlo_evaluator_typed_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -1743,10 +1743,12 @@ extern template class HloEvaluatorTypedVisitor<complex64>;
extern template class HloEvaluatorTypedVisitor<complex128>;
extern template class HloEvaluatorTypedVisitor<bfloat16, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e5m2, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e4m3, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e4m3fn, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e4m3b11fnuz, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e5m2fnuz, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e4m3fnuz, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e3m4, float>;

} // namespace xla

Expand Down
2 changes: 2 additions & 0 deletions xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ limitations under the License.

namespace xla {
template class HloEvaluatorTypedVisitor<tsl::float8_e5m2, float>;
template class HloEvaluatorTypedVisitor<tsl::float8_e4m3, float>;
template class HloEvaluatorTypedVisitor<tsl::float8_e4m3fn, float>;
template class HloEvaluatorTypedVisitor<tsl::float8_e4m3b11fnuz, float>;
template class HloEvaluatorTypedVisitor<tsl::float8_e5m2fnuz, float>;
template class HloEvaluatorTypedVisitor<tsl::float8_e4m3fnuz, float>;
template class HloEvaluatorTypedVisitor<tsl::float8_e3m4, float>;
} // namespace xla
16 changes: 8 additions & 8 deletions xla/hlo/experimental/auto_sharding/cluster_environment.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/log/check.h"
#include "absl/types/span.h"
#include "xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h"
#include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h"
#include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h"
#include "xla/hlo/ir/hlo_sharding.h"
Expand Down Expand Up @@ -162,7 +163,7 @@ double ClusterEnvironment::ReshardingCostMixedMeshShape(
}
if (IsSubset((*dst_tensor_dim_to_mesh_axis)[i],
(*src_tensor_dim_to_mesh_axis)[i])) {
// do nothing; the src is sharded more than the dest
// do nothing; the dst is sharded more than the src
continue;
}
if (!IsSubset((*src_tensor_dim_to_mesh_axis)[i],
Expand Down Expand Up @@ -231,17 +232,16 @@ double ClusterEnvironment::CollectivePermuteCost(
// Overestimate the cost of replicating a tensor by decomposing the resharding
// operation as an all-gather on all mesh dimensions.
double ClusterEnvironment::OverestimateReplicationCost(
const Shape& shape, const HloSharding& src_spec,
const Shape& shape, const HloSharding& src_sharding,
const DeviceMesh& device_mesh) const {
if (src_spec.IsTileMaximal() || src_spec.IsManual()) {
// TODO(b/238210866) Do not use kInfinityCost.
return kInfinityCost;
if (src_sharding.IsReplicated()) {
return 0;
}
int64_t bytes_moved = ByteSizeOfShapeWithSharding(shape, src_spec);
int64_t bytes_moved = ByteSizeOfShapeWithSharding(shape, src_sharding);
double cost = 0.0;
for (size_t i = 0; i < device_mesh.num_dimensions(); ++i) {
auto this_cost = this->AllGatherCost(bytes_moved, i);
cost += this_cost;
cost += src_sharding.IsTileMaximal() ? this->AllReduceCost(bytes_moved, i)
: this->AllGatherCost(bytes_moved, i);
bytes_moved *= device_mesh.dimensions()[i];
}
return cost;
Expand Down
Loading

0 comments on commit 5b5e9fd

Please sign in to comment.