Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove if_cuda_is_configured and if_rocm_is_configured from command_buffer_cmd and custom_call_thunk #17201

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 108 additions & 104 deletions xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ cc_library(
name = "ir_emitter_context",
srcs = ["ir_emitter_context.cc"],
hdrs = ["ir_emitter_context.h"],
tags = ["gpu"],
deps = [
":execution_stream_assignment",
":gpu_constants",
Expand Down Expand Up @@ -308,6 +309,7 @@ cc_library(
]) + if_rocm_hipblaslt([
"TF_HIPBLASLT=1",
]),
tags = ["gpu"],
deps = [
":backend_configs_cc",
":cublas_cudnn",
Expand Down Expand Up @@ -350,12 +352,14 @@ cc_library(
"//xla/service/gpu/kernels:custom_kernel",
"//xla/service/gpu/kernels:topk_custom_kernel",
"//xla/service/gpu/model:tiled_hlo_instruction_or_computation",
"//xla/service/gpu/runtime:cholesky_thunk",
"//xla/service/gpu/runtime:command_buffer_cmd",
"//xla/service/gpu/runtime:command_buffer_cmd_emitter",
"//xla/service/gpu/runtime:command_buffer_thunk",
"//xla/service/gpu/runtime:conditional_thunk",
"//xla/service/gpu/runtime:convolution_thunk",
"//xla/service/gpu/runtime:copy_thunk",
"//xla/service/gpu/runtime:cub_sort_thunk",
"//xla/service/gpu/runtime:cudnn_thunk",
"//xla/service/gpu/runtime:custom_call_thunk",
"//xla/service/gpu/runtime:fft_thunk",
Expand All @@ -379,6 +383,7 @@ cc_library(
"//xla/service/gpu/runtime:send_recv_thunk",
"//xla/service/gpu/runtime:sequential_thunk",
"//xla/service/gpu/runtime:thunk",
"//xla/service/gpu/runtime:triangular_solve_thunk",
"//xla/service/gpu/runtime:wait_for_streams_thunk",
"//xla/service/gpu/runtime:while_thunk",
"//xla/service/llvm_ir:buffer_assignment_util",
Expand Down Expand Up @@ -421,11 +426,7 @@ cc_library(
"@tsl//tsl/platform:human_readable_json",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/protobuf:dnn_proto_cc",
] + if_gpu_is_configured([
"//xla/service/gpu/runtime:cholesky_thunk",
"//xla/service/gpu/runtime:cub_sort_thunk",
"//xla/service/gpu/runtime:triangular_solve_thunk",
]) + if_rocm_is_configured([
] + if_rocm_is_configured([
"@local_config_rocm//rocm:rocm_headers",
]),
)
Expand All @@ -442,7 +443,7 @@ cc_library(
"ir_emitter.h",
"ir_emitter_nested.h",
],
copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]),
tags = ["gpu"],
deps = [
":backend_configs_cc",
":hlo_to_ir_bindings",
Expand Down Expand Up @@ -486,11 +487,9 @@ cc_library(

cc_library(
name = "triton_call",
srcs = if_gpu_is_configured(["triton_call.cc"]),
srcs = ["triton_call.cc"],
hdrs = ["triton_call.h"],
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([
"TENSORFLOW_USE_ROCM=1",
]),
tags = ["gpu"],
deps = [
"@llvm-project//mlir:AsmParser",
"@llvm-project//mlir:IR",
Expand Down Expand Up @@ -556,6 +555,7 @@ cc_library(
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([
"TENSORFLOW_USE_ROCM=1",
]),
tags = ["gpu"],
deps = [
":backend_configs_cc",
":buffer_allocations",
Expand Down Expand Up @@ -1252,6 +1252,7 @@ cc_library(
"compile_module_to_llvm_ir.h",
],
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
tags = ["gpu"],
deps = [
":executable_proto_cc",
":execution_stream_assignment",
Expand Down Expand Up @@ -1303,6 +1304,7 @@ cc_library(
name = "fusion_pipeline",
srcs = ["fusion_pipeline.cc"],
hdrs = ["fusion_pipeline.h"],
tags = ["gpu"],
deps = [
"//xla:xla_proto_cc",
"//xla/service:cpu_gpu_shape_verifier",
Expand Down Expand Up @@ -1350,14 +1352,14 @@ cc_library(

cc_library(
name = "gpu_compiler",
srcs = if_gpu_is_configured([
srcs = [
"gpu_compiler.cc",
]),
hdrs = if_gpu_is_configured([
],
hdrs = [
"gpu_compiler.h",
]),
deps = if_gpu_is_configured([
# go/keep-sorted start prefix_order=":,,
],
tags = ["gpu"],
deps = [
":buffer_sharing",
":compile_module_to_llvm_ir",
":conv_layout_normalization",
Expand All @@ -1384,83 +1386,17 @@ cc_library(
":reduction_utils",
":runtime_intrinsics",
":stream_executor_util",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@com_google_absl//absl/types:variant",
"@llvm-project//llvm:AsmParser",
"@llvm-project//llvm:BitReader",
"@llvm-project//llvm:BitWriter",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TransformUtils",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"//xla:autotune_results_proto_cc",
"//xla:debug_options_flags",
"//xla:shape_util",
"//xla:status_macros",
"//xla:types",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla:xla_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/ir:hlo_module_group",
"//xla/pjrt/distributed:key_value_store_interface",
"//xla/service/gpu/autotuning:autotuner_util",
"//xla/service/gpu/autotuning:custom_kernel_fusion_autotuner",
"//xla/service/gpu/fusions/triton:triton_support",
"//xla/service/gpu/model:gpu_cost_model_stats_collection",
"//xla/service/gpu/model:gpu_hlo_cost_analysis",
"//xla/service/gpu/runtime:thunk",
"//xla/service/gpu/transforms:algebraic_simplifier",
"//xla/service/gpu/transforms:algorithm_checker",
"//xla/service/gpu/transforms:all_gather_dynamic_slice_simplifier",
"//xla/service/gpu/transforms:all_gather_optimizer",
"//xla/service/gpu/transforms:all_reduce_blueconnect",
"//xla/service/gpu/transforms:all_reduce_splitter",
"//xla/service/gpu/transforms:async_collective_annotator",
"//xla/service/gpu/transforms:async_wrapper",
"//xla/service/gpu/transforms:collective_permute_cycle_decomposer",
"//xla/service/gpu/transforms:collective_permute_valid_iteration_annotator",
"//xla/service/gpu/transforms:command_buffer_scheduling",
"//xla/service/gpu/transforms:conv_rewriter",
"//xla/service/gpu/transforms:convert_async_collectives_to_sync",
"//xla/service/gpu/transforms:cudnn_custom_call_converter",
"//xla/service/gpu/transforms:custom_kernel_fusion_rewriter",
"//xla/service/gpu/transforms:dot_dimension_sorter",
"//xla/service/gpu/transforms:dot_operand_converter",
"//xla/service/gpu/transforms:double_buffer_loop_unrolling",
"//xla/service/gpu/transforms:dynamic_slice_fusion_rewriter",
"//xla/service/gpu/transforms:fusion_wrapper",
"//xla/service/gpu/transforms:gemm_broadcast_folding_rewriter",
"//xla/service/gpu/transforms:gemm_fusion",
"//xla/service/gpu/transforms:gemm_rewriter",
"//xla/service/gpu/transforms:gemv_rewriter",
"//xla/service/gpu/transforms:layout_assignment",
"//xla/service/gpu/transforms:move_copy_to_users",
"//xla/service/gpu/transforms:pipelined_p2p_rewriter",
"//xla/service/gpu/transforms:reduce_scatter_creator",
"//xla/service/gpu/transforms:reduction_degenerate_dim_remover",
"//xla/service/gpu/transforms:reduction_dimension_grouper",
"//xla/service/gpu/transforms:reduction_layout_normalizer",
"//xla/service/gpu/transforms:reduction_splitter",
"//xla/service/gpu/transforms:rename_fusions",
"//xla/service/gpu/transforms:sanitize_constant_names",
"//xla/service/gpu/transforms:scatter_expander",
"//xla/service/gpu/transforms:scatter_slice_simplifier",
"//xla/service/gpu/transforms:softmax_rewriter_triton",
"//xla/service/gpu/transforms:stream_attribute_annotator",
"//xla/service/gpu/transforms:stream_attribute_async_wrapper",
"//xla/service/gpu/transforms:topk_specializer",
"//xla/service/gpu/transforms:topk_splitter",
"//xla/service/gpu/transforms:transpose_dimension_grouper",
"//xla/service/gpu/transforms:tree_reduction_rewriter",
"//xla/service/gpu/transforms:triton_fusion_numerics_verifier",
"//xla/service/gpu/transforms:windowed_einsum_handler",
"//xla/service/llvm_ir:llvm_util",
"//xla/service/spmd:collective_permute_motion",
"//xla/service:algebraic_simplifier",
"//xla/service:all_gather_broadcast_reorder",
"//xla/service:all_gather_combiner",
Expand Down Expand Up @@ -1561,24 +1497,90 @@ cc_library(
"//xla/service:while_loop_simplifier",
"//xla/service:while_loop_trip_count_annotator",
"//xla/service:zero_sized_hlo_elimination",
"//xla/service/gpu/autotuning:autotuner_util",
"//xla/service/gpu/autotuning:custom_kernel_fusion_autotuner",
"//xla/service/gpu/fusions/triton:triton_support",
"//xla/service/gpu/model:gpu_cost_model_stats_collection",
"//xla/service/gpu/model:gpu_hlo_cost_analysis",
"//xla/service/gpu/runtime:thunk",
"//xla/service/gpu/transforms:algebraic_simplifier",
"//xla/service/gpu/transforms:algorithm_checker",
"//xla/service/gpu/transforms:all_gather_dynamic_slice_simplifier",
"//xla/service/gpu/transforms:all_gather_optimizer",
"//xla/service/gpu/transforms:all_reduce_blueconnect",
"//xla/service/gpu/transforms:all_reduce_splitter",
"//xla/service/gpu/transforms:async_collective_annotator",
"//xla/service/gpu/transforms:async_wrapper",
"//xla/service/gpu/transforms:collective_permute_cycle_decomposer",
"//xla/service/gpu/transforms:collective_permute_valid_iteration_annotator",
"//xla/service/gpu/transforms:command_buffer_scheduling",
"//xla/service/gpu/transforms:conv_rewriter",
"//xla/service/gpu/transforms:convert_async_collectives_to_sync",
"//xla/service/gpu/transforms:cudnn_custom_call_converter",
"//xla/service/gpu/transforms:custom_kernel_fusion_rewriter",
"//xla/service/gpu/transforms:dot_dimension_sorter",
"//xla/service/gpu/transforms:dot_operand_converter",
"//xla/service/gpu/transforms:double_buffer_loop_unrolling",
"//xla/service/gpu/transforms:dynamic_slice_fusion_rewriter",
"//xla/service/gpu/transforms:fusion_wrapper",
"//xla/service/gpu/transforms:gemm_broadcast_folding_rewriter",
"//xla/service/gpu/transforms:gemm_fusion",
"//xla/service/gpu/transforms:gemm_rewriter",
"//xla/service/gpu/transforms:gemv_rewriter",
"//xla/service/gpu/transforms:layout_assignment",
"//xla/service/gpu/transforms:move_copy_to_users",
"//xla/service/gpu/transforms:pipelined_p2p_rewriter",
"//xla/service/gpu/transforms:reduce_scatter_creator",
"//xla/service/gpu/transforms:reduction_degenerate_dim_remover",
"//xla/service/gpu/transforms:reduction_dimension_grouper",
"//xla/service/gpu/transforms:reduction_layout_normalizer",
"//xla/service/gpu/transforms:reduction_splitter",
"//xla/service/gpu/transforms:rename_fusions",
"//xla/service/gpu/transforms:sanitize_constant_names",
"//xla/service/gpu/transforms:scatter_expander",
"//xla/service/gpu/transforms:scatter_slice_simplifier",
"//xla/service/gpu/transforms:softmax_rewriter_triton",
"//xla/service/gpu/transforms:stream_attribute_annotator",
"//xla/service/gpu/transforms:stream_attribute_async_wrapper",
"//xla/service/gpu/transforms:topk_specializer",
"//xla/service/gpu/transforms:topk_splitter",
"//xla/service/gpu/transforms:transpose_dimension_grouper",
"//xla/service/gpu/transforms:tree_reduction_rewriter",
"//xla/service/gpu/transforms:triton_fusion_numerics_verifier",
"//xla/service/gpu/transforms:windowed_einsum_handler",
"//xla/service/llvm_ir:llvm_util",
"//xla/service/spmd:collective_permute_motion",
"//xla/stream_executor",
"//xla/stream_executor/gpu:gpu_driver_header",
"//xla/stream_executor/integrations:device_mem_allocator",
"//xla/stream_executor:device_description",
"//xla/stream_executor:device_description_proto_cc",
"//xla/stream_executor:dnn",
"//xla/stream_executor:platform_manager",
"//xla/stream_executor:semantic_version",
"//xla/stream_executor/gpu:gpu_driver_header",
"//xla/stream_executor/integrations:device_mem_allocator",
"//xla/translate/hlo_to_mhlo:hlo_utils",
"//xla/translate/mhlo_to_hlo:location_exporter",
"//xla:autotune_results_proto_cc",
"//xla:debug_options_flags",
"//xla:shape_util",
"//xla:status_macros",
"//xla:types",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla:xla_proto_cc",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@com_google_absl//absl/types:variant",
"@llvm-project//llvm:AsmParser",
"@llvm-project//llvm:BitReader",
"@llvm-project//llvm:BitWriter",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TransformUtils",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@tsl//tsl/lib/monitoring:counter",
"@tsl//tsl/platform:blocking_counter",
"@tsl//tsl/platform:casts",
Expand All @@ -1592,8 +1594,7 @@ cc_library(
"@tsl//tsl/platform:statusor",
"@tsl//tsl/profiler/lib:scoped_annotation",
"@tsl//tsl/profiler/lib:traceme",
# go/keep-sorted end
]) + xla_internal(["service:export_hlo"]) + if_google([
] + xla_internal(["service:export_hlo"]) + if_google([
"//xla/hlo/experimental/auto_sharding",
]),
)
Expand Down Expand Up @@ -2072,6 +2073,7 @@ cc_library(
name = "gpu_hlo_schedule",
srcs = ["gpu_hlo_schedule.cc"],
hdrs = ["gpu_hlo_schedule.h"],
tags = ["gpu"],
deps = [
":backend_configs_cc",
":gpu_latency_hiding_scheduler",
Expand Down Expand Up @@ -2625,8 +2627,9 @@ cc_library(

cc_library(
name = "make_batch_pointers",
srcs = if_gpu_is_configured(["make_batch_pointers.cc"]),
hdrs = if_gpu_is_configured(["make_batch_pointers.h"]),
srcs = ["make_batch_pointers.cc"],
hdrs = ["make_batch_pointers.h"],
tags = ["gpu"],
deps = [
"//xla:types",
"//xla:util",
Expand Down Expand Up @@ -2943,6 +2946,7 @@ cc_library(
xla_cc_test(
name = "gpu_latency_hiding_scheduler_test",
srcs = ["gpu_latency_hiding_scheduler_test.cc"],
tags = ["gpu"],
deps = [
":gpu_hlo_schedule",
":gpu_latency_hiding_scheduler",
Expand Down
Loading
Loading