diff --git a/xla/mlir/tools/mlir_replay/public/execution_trace_utils.cc b/xla/mlir/tools/mlir_replay/public/execution_trace_utils.cc
index 81173bfd18456..b61f34ea8f2c3 100644
--- a/xla/mlir/tools/mlir_replay/public/execution_trace_utils.cc
+++ b/xla/mlir/tools/mlir_replay/public/execution_trace_utils.cc
@@ -236,66 +236,67 @@ TensorOrMemref<T> ArrayLiteralToTensor(const xla::Literal& literal) {
 }  // namespace
 
 absl::StatusOr<InterpreterValue> LiteralToValue(const xla::Literal& literal) {
-  if (literal.shape().IsTuple()) {
-    auto elements = literal.Clone().DecomposeTuple();
-    Tuple result;
-    for (auto& element : elements) {
-      TF_ASSIGN_OR_RETURN(auto converted, LiteralToValue(element));
-      result.values.push_back(
-          std::make_shared<InterpreterValue>(std::move(converted)));
-    }
-    return {{result}};
-  }
-
-  if (literal.shape().IsToken()) {
-    return absl::UnimplementedError("token arguments are not implemented");
-  }
-
-  if (literal.shape().IsArray()) {
-    auto type = literal.shape().element_type();
-    if (xla::primitive_util::IsF8Type(type)) {
+  auto type = literal.shape().element_type();
+  switch (type) {
+    case xla::PRED:
+      return {{ArrayLiteralToTensor<bool>(literal)}};
+    case xla::S8:
+      return {{ArrayLiteralToTensor<int8_t>(literal)}};
+    case xla::S16:
+      return {{ArrayLiteralToTensor<int16_t>(literal)}};
+    case xla::S32:
+      return {{ArrayLiteralToTensor<int32_t>(literal)}};
+    case xla::S64:
+      return {{ArrayLiteralToTensor<int64_t>(literal)}};
+    case xla::U8:
+      return {{ArrayLiteralToTensor<uint8_t>(literal)}};
+    case xla::U16:
+      return {{ArrayLiteralToTensor<uint16_t>(literal)}};
+    case xla::U32:
+      return {{ArrayLiteralToTensor<uint32_t>(literal)}};
+    case xla::U64:
+      return {{ArrayLiteralToTensor<uint64_t>(literal)}};
+    case xla::F32:
+      return {{ArrayLiteralToTensor<float>(literal)}};
+    case xla::F64:
+      return {{ArrayLiteralToTensor<double>(literal)}};
+    case xla::C64:
+      return {{ArrayLiteralToTensor<std::complex<float>>(literal)}};
+    case xla::C128:
+      return {{ArrayLiteralToTensor<std::complex<double>>(literal)}};
+    case xla::S2:
+    case xla::S4:
+    case xla::U2:
+    case xla::U4:
+    case xla::F16:
+    case xla::BF16:
+    case xla::F8E4M3FN:
+    case xla::F8E4M3FNUZ:
+    case xla::F8E4M3B11FNUZ:
+    case xla::F8E5M2:
+    case xla::F8E5M2FNUZ:
+    case xla::PrimitiveType::TOKEN:
       return absl::UnimplementedError(
           absl::StrCat(xla::primitive_util::LowercasePrimitiveTypeName(type),
                        " not implemented"));
+    case xla::PrimitiveType::TUPLE: {
+      auto elements = literal.Clone().DecomposeTuple();
+      Tuple result;
+      for (auto& element : elements) {
+        TF_ASSIGN_OR_RETURN(auto converted, LiteralToValue(element));
+        result.values.push_back(
+            std::make_shared<InterpreterValue>(std::move(converted)));
+      }
+      return {{result}};
     }
-    switch (type) {
-      case xla::PRED:
-        return {{ArrayLiteralToTensor<bool>(literal)}};
-      case xla::S8:
-        return {{ArrayLiteralToTensor<int8_t>(literal)}};
-      case xla::S16:
-        return {{ArrayLiteralToTensor<int16_t>(literal)}};
-      case xla::S32:
-        return {{ArrayLiteralToTensor<int32_t>(literal)}};
-      case xla::S64:
-        return {{ArrayLiteralToTensor<int64_t>(literal)}};
-      case xla::U8:
-        return {{ArrayLiteralToTensor<uint8_t>(literal)}};
-      case xla::U16:
-        return {{ArrayLiteralToTensor<uint16_t>(literal)}};
-      case xla::U32:
-        return {{ArrayLiteralToTensor<uint32_t>(literal)}};
-      case xla::U64:
-        return {{ArrayLiteralToTensor<uint64_t>(literal)}};
-      case xla::F16:
-        return absl::UnimplementedError("F16 not implemented");
-      case xla::F32:
-        return {{ArrayLiteralToTensor<float>(literal)}};
-      case xla::BF16:
-        return absl::UnimplementedError("BF16 not implemented");
-      case xla::F64:
-        return {{ArrayLiteralToTensor<double>(literal)}};
-      case xla::C64:
-        return {{ArrayLiteralToTensor<std::complex<float>>(literal)}};
-      case xla::C128:
-        return {{ArrayLiteralToTensor<std::complex<double>>(literal)}};
-      default:
-        // Fallthrough intended.
-        break;
-    }
+    case xla::PrimitiveType::PRIMITIVE_TYPE_INVALID:
+    case xla::PrimitiveType::OPAQUE_TYPE:
+    case xla::PrimitiveType::PrimitiveType_INT_MIN_SENTINEL_DO_NOT_USE_:
+    case xla::PrimitiveType::PrimitiveType_INT_MAX_SENTINEL_DO_NOT_USE_:
+      return absl::InvalidArgumentError(
+          absl::StrCat("Unexpected literal type: ",
+                       xla::primitive_util::LowercasePrimitiveTypeName(type)));
   }
-
-  return absl::InvalidArgumentError("unexpected literal type");
 }
 
 absl::StatusOr<InterpreterValue> LiteralToValue(