Skip to content

Commit

Permalink
Remove some unnecessary stream_executor/platform dependencies.
Browse files Browse the repository at this point in the history
FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#17295 from ROCm:ci_rv_clang b0f316408f62052125973cfff6f9371a91e84464
PiperOrigin-RevId: 683241935
  • Loading branch information
klucke authored and tensorflower-gardener committed Oct 7, 2024
1 parent 456e264 commit 208b1a1
Show file tree
Hide file tree
Showing 25 changed files with 534 additions and 263 deletions.
1 change: 1 addition & 0 deletions tensorflow/compiler/mlir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ tf_cc_binary(
"//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
"//tensorflow/compiler/mlir/tensorflow:translate_registration",
"//tensorflow/compiler/mlir/tf2xla/tests/registration:graph_to_tf_executor_registration",
"//tensorflow/core:lib",
"//tensorflow/core:tensorflow",
"@com_google_absl//absl/strings",
Expand Down
14 changes: 0 additions & 14 deletions tensorflow/compiler/mlir/tensorflow/translate/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -69,20 +69,6 @@ cc_library(
],
)

# Small unit test exercising the mlir-translate registrations provided by
# :translate_registration (linked in via deps).
tf_cc_test(
name = "tf_mlir_translate_registration_test",
size = "small",
srcs = ["tf_mlir_translate_registration_test.cc"],
deps = [
":translate_registration",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:TranslateLib",
],
)

cc_library(
name = "export_tf_dialect_op",
srcs = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,60 +17,18 @@ limitations under the License.
// to satisfy the API of MLIR pass registration. In order to do this, the
// command-line option header is pulled in.

#include <memory>
#include <utility>

#include "absl/container/flat_hash_set.h"
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h"
#include "tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "xla/client/client_library.h"
#include "xla/client/compile_only_client.h"
#include "xla/stream_executor/host/host_platform_id.h"
#include "xla/stream_executor/platform_manager.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tsl/platform/protobuf.h"

namespace mlir {
using tsl::Status;
using tsl::StatusOr;

static constexpr char kMlirToGraphCompilationCheckName[] =
"mlir-to-graph-compilation-check";
// Use CPU arbitrarily in order to check that a graph compiles at all
static constexpr char kArbitraryDeviceName[] = "XLA_CPU_JIT";

namespace {
inline absl::string_view StringRefToView(llvm::StringRef ref) {
return {ref.data(), ref.size()};
}
} // namespace

// Parses a GraphDef from `input` and imports it as an MLIR module in the
// TF executor dialect. All import behavior is driven by the command-line
// flags declared in tf_mlir_translate_cl.h. Returns a null module on error
// (the underlying status is dropped, matching mlir-translate conventions).
static OwningOpRef<mlir::ModuleOp> GraphdefToMlirTranslateFunction(
    llvm::StringRef input, MLIRContext* context) {
  const tensorflow::GraphdefToMlirOptions import_options{
      debug_info_file,        xla_compile_device_type,
      prune_unused_nodes,     convert_legacy_fed_inputs,
      graph_as_function,      upgrade_legacy,
      enable_shape_inference, unconditionally_use_set_output_shapes,
      enable_soft_placement,  set_original_tf_func_name};

  auto imported = tensorflow::GraphdefToMlirTranslateFunction(
      input, input_arrays, input_dtypes, input_shapes, output_arrays,
      control_output_arrays, import_options, context);
  if (!imported.status().ok()) return nullptr;
  return std::move(imported).value();
}

// Registers the "graphdef-to-mlir" translation with mlir-translate; the
// static initializer runs at load time, so linking this file in is enough.
static TranslateToMLIRRegistration GraphdefToMlirTranslate(
"graphdef-to-mlir", "graphdef-to-mlir", GraphdefToMlirTranslateFunction);

static OwningOpRef<mlir::ModuleOp> GraphdefToSplattedMlirTranslateFunction(
llvm::StringRef input, MLIRContext* context) {
Expand All @@ -90,112 +48,4 @@ static TranslateToMLIRRegistration GraphdefToSplattedMlirTranslate(
"graphdef-to-splatted-mlir", "graphdef-to-splatted-mlir",
GraphdefToSplattedMlirTranslateFunction);

// Compiles `graph` with the given compile-only client purely as a
// compilation-check: the compilation result is discarded, only the status
// is returned. Returns kInvalidArgument when either pointer is null.
static Status CompileGraph(tensorflow::Graph* graph,
                           xla::CompileOnlyClient* client) {
  if (!graph || !client) {
    return Status(absl::StatusCode::kInvalidArgument,
                  "Invalid graph or client");
  }

  // An empty function library suffices: the check only needs the ops in the
  // global registry.
  tensorflow::FunctionDefLibrary flib;
  auto flib_def = std::make_unique<tensorflow::FunctionLibraryDefinition>(
      tensorflow::OpRegistry::Global(), flib);

  tensorflow::XlaCompiler::Options options;
  // Use CPU arbitrarily (see kArbitraryDeviceName) — any device proves the
  // graph compiles at all.
  options.device_type = tensorflow::DeviceType(kArbitraryDeviceName);
  options.client = client;
  options.flib_def = flib_def.get();
  tensorflow::XlaCompiler compiler(options);

  // CompileGraph consumes the graph it is given, so hand it a copy and
  // leave the caller's graph untouched. make_unique instead of raw new.
  auto graph_copy =
      std::make_unique<tensorflow::Graph>(tensorflow::OpRegistry::Global());
  tensorflow::CopyGraph(*graph, graph_copy.get());

  tensorflow::XlaCompiler::CompileOptions compile_options;
  tensorflow::XlaCompiler::CompilationResult result;
  return compiler.CompileGraph(compile_options,
                               kMlirToGraphCompilationCheckName,
                               std::move(graph_copy), {}, &result);
}

// Converts an MLIR module (TF executor dialect) to a tensorflow::Graph,
// verifies the graph compiles with XLA on the host platform, and prints the
// resulting GraphDef to `output`. Returns failure() on any error.
static LogicalResult MlirToGraphTranslateFunction(ModuleOp module,
                                                  llvm::raw_ostream& output) {
  if (!module) return failure();

  tensorflow::GraphExportConfig confs;
  confs.export_entry_func_to_flib = export_entry_func_to_flib;
  confs.export_original_tf_func_name = export_original_tf_func_name;

  // NOTE(review): flib_def is never allocated, so a null
  // FunctionLibraryDefinition* is passed below — presumably
  // ConvertTfExecutorToGraph tolerates null here; confirm.
  std::unique_ptr<tensorflow::FunctionLibraryDefinition> flib_def;
  auto graph =
      std::make_unique<tensorflow::Graph>(tensorflow::OpRegistry::Global());
  absl::flat_hash_set<tensorflow::Node*> control_ret_nodes;
  auto status = tensorflow::tf2xla::v2::ConvertTfExecutorToGraph(
      module, confs, &graph, flib_def.get(), &control_ret_nodes);
  if (!status.ok()) {
    LOG(ERROR) << "Export to Graph failed: " << status;
    return mlir::failure();
  }

  // Use Host platform, which should always exist, to make sure graphs compile.
  // Check the StatusOr instead of calling .value() unchecked, which would
  // crash the whole translate tool if the platform were unavailable.
  auto platform = stream_executor::PlatformManager::PlatformWithId(
      stream_executor::host::kHostPlatformId);
  if (!platform.ok()) {
    LOG(ERROR) << "Host platform lookup failed: " << platform.status();
    return mlir::failure();
  }
  auto client =
      xla::ClientLibrary::GetOrCreateCompileOnlyClient(platform.value());
  if (!client.ok()) {
    LOG(ERROR) << "Compile-only client creation failed: " << client.status();
    return mlir::failure();
  }

  tensorflow::XlaOpRegistry::RegisterCompilationKernels();

  // Verify that the resulting graph can compile.
  if (!CompileGraph(graph.get(), client.value()).ok()) {
    return mlir::failure();
  }

  auto graphdef = std::make_unique<tensorflow::GraphDef>();
  // Print the graph to the output after going through GraphDef conversion.
  // The DumpGraphToFile would do this anyway so just skip straight to it.
  graph->ToGraphDef(graphdef.get());
  output << tsl::LegacyUnredactedDebugString(*graphdef);

  return success();
}

// Registers the "mlir-to-graph" translation. The callback registers all TF
// dialects so the input module parses before MlirToGraphTranslateFunction runs.
static TranslateFromMLIRRegistration mlir_to_graph_translate(
/*name=*/"mlir-to-graph", /*description=*/"convert mlir to graph",
MlirToGraphTranslateFunction, [](DialectRegistry& registry) {
mlir::RegisterAllTensorFlowDialects(registry);
});

// Serializes an MLIR module (TF executor dialect) as a textual GraphDef on
// `output`. Export behavior is controlled by the translate command-line
// flags. Returns failure() for a null module or a failed conversion.
static LogicalResult MlirToGraphdefTranslateFunction(
    ModuleOp module, llvm::raw_ostream& output) {
  if (!module) return failure();

  tensorflow::GraphExportConfig export_config;
  export_config.export_entry_func_to_flib = export_entry_func_to_flib;
  export_config.export_original_tf_func_name = export_original_tf_func_name;

  tensorflow::FunctionLibraryDefinition flib_def(
      tensorflow::OpRegistry::Global(), tensorflow::FunctionDefLibrary());
  auto graph =
      std::make_unique<tensorflow::Graph>(tensorflow::OpRegistry::Global());
  absl::flat_hash_set<tensorflow::Node*> control_ret_nodes;

  const auto convert_status = tensorflow::tf2xla::v2::ConvertTfExecutorToGraph(
      module, export_config, &graph, &flib_def, &control_ret_nodes);
  if (!convert_status.ok()) {
    LOG(ERROR) << "Export to Graph failed: " << convert_status;
    return mlir::failure();
  }

  tensorflow::GraphDef graphdef;
  graph->ToGraphDef(&graphdef);
  output << tsl::LegacyUnredactedDebugString(graphdef);
  return success();
}

// Registers the "mlir-to-graphdef" translation. Like mlir-to-graph, the
// callback pre-registers all TF dialects so the input module can be parsed.
static TranslateFromMLIRRegistration mlir_to_graphdef_translate(
"mlir-to-graphdef", "mlir-to-graphdef", MlirToGraphdefTranslateFunction,
[](DialectRegistry& registry) {
mlir::RegisterAllTensorFlowDialects(registry);
});

} // namespace mlir
59 changes: 59 additions & 0 deletions tensorflow/compiler/mlir/tf2xla/tests/registration/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")

package(
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility = ["//visibility:public"],
)

# mlir-translate registrations for graph <-> TF executor dialect conversion.
# alwayslink = 1 because registration happens via static initializers that
# would otherwise be dropped by the linker.
cc_library(
name = "graph_to_tf_executor_registration",
srcs = [
"graph_to_tf_executor_registration.cc",
],
deps = [
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow/translate:mlir_roundtrip_flags",
"//tensorflow/compiler/mlir/tensorflow/translate:translate_cl_options",
"//tensorflow/compiler/mlir/tensorflow/translate:translate_lib",
"//tensorflow/compiler/mlir/tf2xla/api/v2:tf_executor_to_graph",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:string_view",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TranslateLib",
"@local_tsl//tsl/platform:protobuf",
"@local_tsl//tsl/platform:status",
"@local_xla//xla/client:client_library",
"@local_xla//xla/client:compile_only_client",
"@local_xla//xla/service/cpu:cpu_compiler",
"@local_xla//xla/service/cpu:cpu_transfer_manager",
"@local_xla//xla/stream_executor",
"@local_xla//xla/stream_executor/host:host_platform",
"@local_xla//xla/stream_executor/host:host_platform_id",
],
alwayslink = 1,
)

# Small unit test exercising the registrations linked in from
# :graph_to_tf_executor_registration.
tf_cc_test(
name = "graph_to_tf_executor_registration_test",
size = "small",
srcs = ["graph_to_tf_executor_registration_test.cc"],
deps = [
":graph_to_tf_executor_registration",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TranslateLib",
],
)
Loading

0 comments on commit 208b1a1

Please sign in to comment.