From 75ab16dc1cbf3c56d2ca6a271c49934a77dc570e Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Thu, 1 Aug 2024 14:59:53 -0700 Subject: [PATCH] #sdy Add custom-call sharding rule registry. PiperOrigin-RevId: 658549415 --- .../dialect/sdy/transforms/propagation/BUILD | 30 +++++ .../custom_call_sharding_registry.cc | 61 ++++++++++ .../custom_call_sharding_registry.h | 108 ++++++++++++++++++ .../custom_call_sharding_registry_test.cc | 57 +++++++++ 4 files changed, 256 insertions(+) create mode 100644 shardy/dialect/sdy/transforms/propagation/custom_call_sharding_registry.cc create mode 100644 shardy/dialect/sdy/transforms/propagation/custom_call_sharding_registry.h create mode 100644 shardy/dialect/sdy/transforms/propagation/custom_call_sharding_registry_test.cc diff --git a/shardy/dialect/sdy/transforms/propagation/BUILD b/shardy/dialect/sdy/transforms/propagation/BUILD index 51f27f5..13257a0 100644 --- a/shardy/dialect/sdy/transforms/propagation/BUILD +++ b/shardy/dialect/sdy/transforms/propagation/BUILD @@ -264,3 +264,33 @@ cc_test( "@llvm-project//mlir:Support", ], ) + +cc_library( + name = "custom_call_sharding_registry", + srcs = ["custom_call_sharding_registry.cc"], + hdrs = ["custom_call_sharding_registry.h"], + deps = [ + "//shardy/dialect/sdy/ir:dialect", + "//shardy/dialect/sdy/transforms/common:macros", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +cc_test( + name = "custom_call_sharding_registry_test", + srcs = ["custom_call_sharding_registry_test.cc"], + deps = [ + ":custom_call_sharding_registry", + ":op_sharding_rule_registry", + "//shardy/dialect/sdy/ir:dialect", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Support", + "@stablehlo//:stablehlo_ops", + ], +) diff --git a/shardy/dialect/sdy/transforms/propagation/custom_call_sharding_registry.cc b/shardy/dialect/sdy/transforms/propagation/custom_call_sharding_registry.cc new file mode 100644 index 0000000..5fb4c9a --- /dev/null +++ b/shardy/dialect/sdy/transforms/propagation/custom_call_sharding_registry.cc @@ -0,0 +1,61 @@ +/* Copyright 2024 The Shardy Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "shardy/dialect/sdy/transforms/propagation/custom_call_sharding_registry.h" + +#include +#include +#include + +#include "llvm/Support/ManagedStatic.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir { +namespace sdy { + +namespace { +static llvm::ManagedStatic cc_registry; +} // namespace + +CustomCallShardingRegistry& CustomCallShardingRegistry::GetRegistry() { + return *cc_registry; +} + +std::optional +CustomCallShardingRegistry::GetShardingRuleCallBack( + const std::string& op_name) { + return GetCallBack(op_name); +} + +LogicalResult CustomCallShardingRegistry::RegisterShardingRuleCallBack( + const std::string& op_name, ShardingRuleCallBack call_back_func) { + return RegisterCallBack(op_name, + std::move(call_back_func)); +} + +std::optional +CustomCallShardingRegistry::GetShardingPropagationCallBack( + const std::string& op_name) { + return GetCallBack(op_name); +} + +LogicalResult CustomCallShardingRegistry::RegisterShardingPropagationCallBack( + const std::string& op_name, ShardingPropagationCallBack call_back_func) { + return RegisterCallBack( + op_name, std::move(call_back_func)); +} + +} // namespace sdy +} // namespace mlir diff --git a/shardy/dialect/sdy/transforms/propagation/custom_call_sharding_registry.h b/shardy/dialect/sdy/transforms/propagation/custom_call_sharding_registry.h new file mode 100644 index 0000000..15b72cc --- /dev/null +++ b/shardy/dialect/sdy/transforms/propagation/custom_call_sharding_registry.h @@ -0,0 +1,108 @@ +/* Copyright 2024 The Shardy Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef SHARDY_DIALECT_SDY_TRANSFORMS_PROPAGATION_CUSTOM_CALL_SHARDING_REGISTRY_H_ +#define SHARDY_DIALECT_SDY_TRANSFORMS_PROPAGATION_CUSTOM_CALL_SHARDING_REGISTRY_H_ + +#include +#include +#include +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Mutex.h" +#include "mlir/IR/Operation.h" +#include "mlir/Support/LogicalResult.h" +#include "shardy/dialect/sdy/ir/dialect.h" +namespace mlir { +namespace sdy { + +// A registry for custom-call sharding information. +// +// This registry records the name of a custom-call and its corresponding +// callback function that either inspects the custom-call op to produce an +// `OpShardingRuleAttr` or update the `TensorShardingAttr` for the operands +// and results of the op. +// +// A process can have at most one globally shared instance of this object. This +// instance is created lazily the first time the registry is accessed. +// +// This class is thread-safe. +class CustomCallShardingRegistry { + public: + // A ShardingRuleCallBack is a function that inspects a custom-call op and + // returns an `OpShardingRuleAttr` that specifies the sharding rule for the + // op. + using ShardingRuleCallBack = + std::function; + // A ShardingPropagationCallBack is a function that takes a custom-call op + // and updates the `TensorShardingAttr` for the operands and results of the + // op. + using ShardingPropagationCallBack = + std::function; + // A ShardingRegistry is either a ShardingRuleCallBack or a + // ShardingPropagationCallBack. + using ShardingRegistry = + std::variant; + + static LogicalResult RegisterShardingRuleCallBack( + const std::string& op_name, ShardingRuleCallBack call_back_func); + static std::optional GetShardingRuleCallBack( + const std::string& op_name); + + static LogicalResult RegisterShardingPropagationCallBack( + const std::string& op_name, ShardingPropagationCallBack call_back_func); + static std::optional + GetShardingPropagationCallBack(const std::string& op_name); + + private: + static CustomCallShardingRegistry& GetRegistry(); + + template + static std::optional GetCallBack(const std::string& op_name) { + CustomCallShardingRegistry& registry = GetRegistry(); + llvm::sys::ScopedLock scopedLock(registry.mutex_); + if (auto iter = registry.name_to_call_back_.find(op_name); + iter != registry.name_to_call_back_.end()) { + if (std::holds_alternative(iter->second)) { + return std::get(iter->second); + } + } + return std::nullopt; + } + + template + static LogicalResult RegisterCallBack(const std::string& op_name, + T call_back_func) { + CustomCallShardingRegistry& registry = GetRegistry(); + llvm::sys::ScopedLock scopedLock(registry.mutex_); + auto [it, emplaced] = registry.name_to_call_back_.try_emplace( + op_name, std::move(call_back_func)); + if (!emplaced) { + return failure(); + } + return success(); + } + + private: + llvm::sys::Mutex mutex_; + std::unordered_map name_to_call_back_; +}; + +} // namespace sdy +} // namespace mlir + +#endif // SHARDY_DIALECT_SDY_TRANSFORMS_PROPAGATION_CUSTOM_CALL_SHARDING_REGISTRY_H_ diff --git a/shardy/dialect/sdy/transforms/propagation/custom_call_sharding_registry_test.cc b/shardy/dialect/sdy/transforms/propagation/custom_call_sharding_registry_test.cc new file mode 100644 index 0000000..c04ef93 --- /dev/null +++ b/shardy/dialect/sdy/transforms/propagation/custom_call_sharding_registry_test.cc @@ -0,0 +1,57 @@ +/* Copyright 2024 The Shardy Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "shardy/dialect/sdy/transforms/propagation/custom_call_sharding_registry.h" + +#include + +#include "mlir/IR/Operation.h" +#include "mlir/Support/LogicalResult.h" +#include "shardy/dialect/sdy/ir/dialect.h" +#include + +namespace mlir { +namespace sdy { + +namespace { + +using CustomCallShardingRegistryTest = ::testing::Test; + +TEST_F(CustomCallShardingRegistryTest, SimplyRegister) { + auto callBackFunc = [](mlir::Operation* op) { return OpShardingRuleAttr(); }; + const char kCustomCallOpName[] = "sdy_testonly1"; + + std::optional queryResult = + CustomCallShardingRegistry::GetShardingRuleCallBack(kCustomCallOpName); + EXPECT_FALSE(queryResult.has_value()); + + LogicalResult registerResult = + CustomCallShardingRegistry::RegisterShardingRuleCallBack( + kCustomCallOpName, callBackFunc); + EXPECT_TRUE(registerResult.succeeded()); + + queryResult = + CustomCallShardingRegistry::GetShardingRuleCallBack(kCustomCallOpName); + EXPECT_TRUE(queryResult.has_value()); + + registerResult = CustomCallShardingRegistry::RegisterShardingRuleCallBack( + kCustomCallOpName, callBackFunc); + EXPECT_TRUE(registerResult.failed()); +} + +} // namespace + +} // namespace sdy +} // namespace mlir