Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use switch case without default in LiteralToValue #17279

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 56 additions & 55 deletions xla/mlir/tools/mlir_replay/public/execution_trace_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -236,66 +236,67 @@ TensorOrMemref<T> ArrayLiteralToTensor(const xla::Literal& literal) {
} // namespace

absl::StatusOr<InterpreterValue> LiteralToValue(const xla::Literal& literal) {
if (literal.shape().IsTuple()) {
auto elements = literal.Clone().DecomposeTuple();
Tuple result;
for (auto& element : elements) {
TF_ASSIGN_OR_RETURN(auto converted, LiteralToValue(element));
result.values.push_back(
std::make_shared<InterpreterValue>(std::move(converted)));
}
return {{result}};
}

if (literal.shape().IsToken()) {
return absl::UnimplementedError("token arguments are not implemented");
}

if (literal.shape().IsArray()) {
auto type = literal.shape().element_type();
if (xla::primitive_util::IsF8Type(type)) {
auto type = literal.shape().element_type();
switch (type) {
case xla::PRED:
return {{ArrayLiteralToTensor<bool>(literal)}};
case xla::S8:
return {{ArrayLiteralToTensor<int8_t>(literal)}};
case xla::S16:
return {{ArrayLiteralToTensor<int16_t>(literal)}};
case xla::S32:
return {{ArrayLiteralToTensor<int32_t>(literal)}};
case xla::S64:
return {{ArrayLiteralToTensor<int64_t>(literal)}};
case xla::U8:
return {{ArrayLiteralToTensor<uint8_t>(literal)}};
case xla::U16:
return {{ArrayLiteralToTensor<uint16_t>(literal)}};
case xla::U32:
return {{ArrayLiteralToTensor<uint32_t>(literal)}};
case xla::U64:
return {{ArrayLiteralToTensor<uint64_t>(literal)}};
case xla::F32:
return {{ArrayLiteralToTensor<float>(literal)}};
case xla::F64:
return {{ArrayLiteralToTensor<double>(literal)}};
case xla::C64:
return {{ArrayLiteralToTensor<std::complex<float>>(literal)}};
case xla::C128:
return {{ArrayLiteralToTensor<std::complex<double>>(literal)}};
case xla::S2:
case xla::S4:
case xla::U2:
case xla::U4:
case xla::F16:
case xla::BF16:
case xla::F8E4M3FN:
case xla::F8E4M3FNUZ:
case xla::F8E4M3B11FNUZ:
case xla::F8E5M2:
case xla::F8E5M2FNUZ:
case xla::PrimitiveType::TOKEN:
return absl::UnimplementedError(
absl::StrCat(xla::primitive_util::LowercasePrimitiveTypeName(type),
" not implemented"));
case xla::PrimitiveType::TUPLE: {
auto elements = literal.Clone().DecomposeTuple();
Tuple result;
for (auto& element : elements) {
TF_ASSIGN_OR_RETURN(auto converted, LiteralToValue(element));
result.values.push_back(
std::make_shared<InterpreterValue>(std::move(converted)));
}
return {{result}};
}
switch (type) {
case xla::PRED:
return {{ArrayLiteralToTensor<bool>(literal)}};
case xla::S8:
return {{ArrayLiteralToTensor<int8_t>(literal)}};
case xla::S16:
return {{ArrayLiteralToTensor<int16_t>(literal)}};
case xla::S32:
return {{ArrayLiteralToTensor<int32_t>(literal)}};
case xla::S64:
return {{ArrayLiteralToTensor<int64_t>(literal)}};
case xla::U8:
return {{ArrayLiteralToTensor<uint8_t>(literal)}};
case xla::U16:
return {{ArrayLiteralToTensor<uint16_t>(literal)}};
case xla::U32:
return {{ArrayLiteralToTensor<uint32_t>(literal)}};
case xla::U64:
return {{ArrayLiteralToTensor<uint64_t>(literal)}};
case xla::F16:
return absl::UnimplementedError("F16 not implemented");
case xla::F32:
return {{ArrayLiteralToTensor<float>(literal)}};
case xla::BF16:
return absl::UnimplementedError("BF16 not implemented");
case xla::F64:
return {{ArrayLiteralToTensor<double>(literal)}};
case xla::C64:
return {{ArrayLiteralToTensor<std::complex<float>>(literal)}};
case xla::C128:
return {{ArrayLiteralToTensor<std::complex<double>>(literal)}};
default:
// Fallthrough intended.
break;
}
case xla::PrimitiveType::PRIMITIVE_TYPE_INVALID:
case xla::PrimitiveType::OPAQUE_TYPE:
case xla::PrimitiveType::PrimitiveType_INT_MIN_SENTINEL_DO_NOT_USE_:
case xla::PrimitiveType::PrimitiveType_INT_MAX_SENTINEL_DO_NOT_USE_:
return absl::InvalidArgumentError(
absl::StrCat("Unexpected literal type: ",
xla::primitive_util::LowercasePrimitiveTypeName(type)));
}

return absl::InvalidArgumentError("unexpected literal type");
}

absl::StatusOr<InterpreterValue> LiteralToValue(
Expand Down
Loading