Skip to content

Commit

Permalink
Use switch case without default in LiteralToValue
Browse files Browse the repository at this point in the history
  • Loading branch information
apivovarov committed Sep 17, 2024
1 parent f6b6175 commit 914db47
Showing 1 changed file with 56 additions and 55 deletions.
111 changes: 56 additions & 55 deletions xla/mlir/tools/mlir_replay/public/execution_trace_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -236,66 +236,67 @@ TensorOrMemref<T> ArrayLiteralToTensor(const xla::Literal& literal) {
} // namespace

absl::StatusOr<InterpreterValue> LiteralToValue(const xla::Literal& literal) {
if (literal.shape().IsTuple()) {
auto elements = literal.Clone().DecomposeTuple();
Tuple result;
for (auto& element : elements) {
TF_ASSIGN_OR_RETURN(auto converted, LiteralToValue(element));
result.values.push_back(
std::make_shared<InterpreterValue>(std::move(converted)));
}
return {{result}};
}

if (literal.shape().IsToken()) {
return absl::UnimplementedError("token arguments are not implemented");
}

if (literal.shape().IsArray()) {
auto type = literal.shape().element_type();
if (xla::primitive_util::IsF8Type(type)) {
auto type = literal.shape().element_type();
switch (type) {
case xla::PRED:
return {{ArrayLiteralToTensor<bool>(literal)}};
case xla::S8:
return {{ArrayLiteralToTensor<int8_t>(literal)}};
case xla::S16:
return {{ArrayLiteralToTensor<int16_t>(literal)}};
case xla::S32:
return {{ArrayLiteralToTensor<int32_t>(literal)}};
case xla::S64:
return {{ArrayLiteralToTensor<int64_t>(literal)}};
case xla::U8:
return {{ArrayLiteralToTensor<uint8_t>(literal)}};
case xla::U16:
return {{ArrayLiteralToTensor<uint16_t>(literal)}};
case xla::U32:
return {{ArrayLiteralToTensor<uint32_t>(literal)}};
case xla::U64:
return {{ArrayLiteralToTensor<uint64_t>(literal)}};
case xla::F32:
return {{ArrayLiteralToTensor<float>(literal)}};
case xla::F64:
return {{ArrayLiteralToTensor<double>(literal)}};
case xla::C64:
return {{ArrayLiteralToTensor<std::complex<float>>(literal)}};
case xla::C128:
return {{ArrayLiteralToTensor<std::complex<double>>(literal)}};
case xla::S2:
case xla::S4:
case xla::U2:
case xla::U4:
case xla::F16:
case xla::BF16:
case xla::F8E4M3FN:
case xla::F8E4M3FNUZ:
case xla::F8E4M3B11FNUZ:
case xla::F8E5M2:
case xla::F8E5M2FNUZ:
case xla::PrimitiveType::TOKEN:
return absl::UnimplementedError(
absl::StrCat(xla::primitive_util::LowercasePrimitiveTypeName(type),
" not implemented"));
case xla::PrimitiveType::TUPLE: {
auto elements = literal.Clone().DecomposeTuple();
Tuple result;
for (auto& element : elements) {
TF_ASSIGN_OR_RETURN(auto converted, LiteralToValue(element));
result.values.push_back(
std::make_shared<InterpreterValue>(std::move(converted)));
}
return {{result}};
}
switch (type) {
case xla::PRED:
return {{ArrayLiteralToTensor<bool>(literal)}};
case xla::S8:
return {{ArrayLiteralToTensor<int8_t>(literal)}};
case xla::S16:
return {{ArrayLiteralToTensor<int16_t>(literal)}};
case xla::S32:
return {{ArrayLiteralToTensor<int32_t>(literal)}};
case xla::S64:
return {{ArrayLiteralToTensor<int64_t>(literal)}};
case xla::U8:
return {{ArrayLiteralToTensor<uint8_t>(literal)}};
case xla::U16:
return {{ArrayLiteralToTensor<uint16_t>(literal)}};
case xla::U32:
return {{ArrayLiteralToTensor<uint32_t>(literal)}};
case xla::U64:
return {{ArrayLiteralToTensor<uint64_t>(literal)}};
case xla::F16:
return absl::UnimplementedError("F16 not implemented");
case xla::F32:
return {{ArrayLiteralToTensor<float>(literal)}};
case xla::BF16:
return absl::UnimplementedError("BF16 not implemented");
case xla::F64:
return {{ArrayLiteralToTensor<double>(literal)}};
case xla::C64:
return {{ArrayLiteralToTensor<std::complex<float>>(literal)}};
case xla::C128:
return {{ArrayLiteralToTensor<std::complex<double>>(literal)}};
default:
// Fallthrough intended.
break;
}
case xla::PrimitiveType::PRIMITIVE_TYPE_INVALID:
case xla::PrimitiveType::OPAQUE_TYPE:
case xla::PrimitiveType::PrimitiveType_INT_MIN_SENTINEL_DO_NOT_USE_:
case xla::PrimitiveType::PrimitiveType_INT_MAX_SENTINEL_DO_NOT_USE_:
return absl::InvalidArgumentError(
absl::StrCat("Unexpected literal type: ",
xla::primitive_util::LowercasePrimitiveTypeName(type)));
}

return absl::InvalidArgumentError("unexpected literal type");
}

absl::StatusOr<InterpreterValue> LiteralToValue(
Expand Down

0 comments on commit 914db47

Please sign in to comment.