diff --git a/test/cpp/test_xla_generator.cpp b/test/cpp/test_xla_generator.cpp index d45991f72d39..c89b0f957bee 100644 --- a/test/cpp/test_xla_generator.cpp +++ b/test/cpp/test_xla_generator.cpp @@ -1,15 +1,54 @@ +#include #include #include +#include + #include "test/cpp/torch_xla_test.h" #include "torch_xla/csrc/xla_generator.h" namespace torch_xla { namespace cpp_test { +// Ensure PJRT is configured to a CPU backend for tests that touch the PJRT +// runtime. Optionally allow overriding the environment values by passing +// `pjrt_device` and/or `cpu_num_devices`. +static void EnsurePjrtCpuBackend(const char* pjrt_device = nullptr, + const char* cpu_num_devices = nullptr) { + // PJRT_DEVICE: override if provided, otherwise set default if not present + if (pjrt_device != nullptr && pjrt_device[0] != '\0') { + // Force override of any existing value + setenv("PJRT_DEVICE", pjrt_device, 1); + } else { + const char* pjrt = std::getenv("PJRT_DEVICE"); + if (pjrt == nullptr || pjrt[0] == '\0') { + // PJRT_DEVICE is unset or empty: fall back to the CPU backend. + setenv("PJRT_DEVICE", "CPU", 1); + } + } + + // CPU_NUM_DEVICES: override if provided, otherwise set default if not present + if (cpu_num_devices != nullptr && cpu_num_devices[0] != '\0') { + // Force override of any existing value + setenv("CPU_NUM_DEVICES", cpu_num_devices, 1); + } else { + const char* cpu_devices = std::getenv("CPU_NUM_DEVICES"); + if (cpu_devices == nullptr || cpu_devices[0] == '\0') { + // Default to a single CPU device. NOTE(review): overwrite=0 means setenv() + // leaves an existing entry alone, so a set-but-empty CPU_NUM_DEVICES stays + // empty here; confirm that is intended. + setenv("CPU_NUM_DEVICES", "1", 0); + } + } +} + // Test fixture for XLAGenerator tests class XLAGeneratorTest : public ::torch_xla::cpp_test::TorchXlaTest { protected: + // Runs once before the test suite / test case to ensure PJRT is configured + // before any XLA runtime initialization happens in per-test SetUp(). 
+ static void SetUpTestCase() { EnsurePjrtCpuBackend("CPU", "2"); } + void SetUp() { // Create a generator for XLA device 0 gen_ = at::make_generator(0); @@ -102,5 +141,122 @@ TEST_F(XLAGeneratorTest, Clone) { ASSERT_NE(cloned_gen.current_seed(), gen_.current_seed()); } +TEST_F(XLAGeneratorTest, GetDefaultXLAGenerator) { + // Test getting default generator for device 0 + auto result = at::detail::GetDefaultXLAGenerator(0); + ASSERT_TRUE(result.ok()) << "Failed to get default generator: " + << result.status(); + + const at::Generator& default_gen = result.value(); + ASSERT_EQ(default_gen.device().type(), at::DeviceType::XLA); + ASSERT_EQ(default_gen.device().index(), 0); + + // Test getting default generator with -1 (should default to device 0) + auto result_default = at::detail::GetDefaultXLAGenerator(-1); + ASSERT_TRUE(result_default.ok()) + << "Failed to get default generator with -1: " << result_default.status(); + + const at::Generator& default_gen_neg1 = result_default.value(); + ASSERT_EQ(default_gen_neg1.device().type(), at::DeviceType::XLA); + ASSERT_EQ(default_gen_neg1.device().index(), 0); + ASSERT_EQ(default_gen, default_gen_neg1); + + // Test that subsequent calls return the same generator instance + auto result2 = at::detail::GetDefaultXLAGenerator(0); + ASSERT_TRUE(result2.ok()); + const at::Generator& default_gen2 = result2.value(); + ASSERT_EQ(default_gen, default_gen2); + + // Test getting non-default device generator + auto result_device1 = at::detail::GetDefaultXLAGenerator(1); + ASSERT_TRUE(result_device1.ok()) + << "Failed to get default generator for device 1: " + << result_device1.status(); + + const at::Generator& default_gen_device1 = result_device1.value(); + ASSERT_EQ(default_gen_device1.device().type(), at::DeviceType::XLA); + ASSERT_EQ(default_gen_device1.device().index(), 1); + ASSERT_NE(default_gen_device1, default_gen); +} + +TEST_F(XLAGeneratorTest, GetDefaultXLAGeneratorInvalidDevice) { + // Test with invalid device indices + auto 
result_neg2 = at::detail::GetDefaultXLAGenerator(-2); + ASSERT_FALSE(result_neg2.ok()); + ASSERT_TRUE(absl::IsInvalidArgument(result_neg2.status())); + ASSERT_THAT(result_neg2.status().message(), + testing::HasSubstr("Invalid XLA device index")); + + // Test with very large device index (assuming there aren't 100 XLA devices) + auto result_large = at::detail::GetDefaultXLAGenerator(100); + ASSERT_FALSE(result_large.ok()); + ASSERT_TRUE(absl::IsInvalidArgument(result_large.status())); + ASSERT_THAT(result_large.status().message(), + testing::HasSubstr("Invalid XLA device index")); +} + +TEST_F(XLAGeneratorTest, CreateXLAGenerator) { + // Test creating generator for device 1 + auto result = at::detail::CreateXLAGenerator(1); + ASSERT_TRUE(result.ok()) << "Failed to create generator: " << result.status(); + + at::Generator created_gen = result.value(); + ASSERT_EQ(created_gen.device().type(), at::DeviceType::XLA); + ASSERT_EQ(created_gen.device().index(), 1); + + // Test that the generator is initialized with default seed + ASSERT_EQ(created_gen.current_seed(), c10::default_rng_seed_val); + + // Test creating generator with -1 (should use current device) + auto result_current = at::detail::CreateXLAGenerator(-1); + ASSERT_TRUE(result_current.ok()) + << "Failed to create generator with -1: " << result_current.status(); + + at::Generator created_gen_neg1 = result_current.value(); + ASSERT_EQ(created_gen_neg1.device().type(), at::DeviceType::XLA); + // Device index should be >= 0 (actual device depends on current XLA device) + ASSERT_GE(created_gen_neg1.device().index(), 0); +} + +TEST_F(XLAGeneratorTest, CreateXLAGeneratorUniqueness) { + // Test that each call creates a new generator instance + auto result1 = at::detail::CreateXLAGenerator(0); + auto result2 = at::detail::CreateXLAGenerator(0); + + ASSERT_TRUE(result1.ok()); + ASSERT_TRUE(result2.ok()); + + at::Generator gen1 = result1.value(); + at::Generator gen2 = result2.value(); + + // Should be different 
instances (compare generators, not their stack + // addresses) + ASSERT_NE(gen1, gen2); + + // But should have same device and initial seed + ASSERT_EQ(gen1.device(), gen2.device()); + ASSERT_EQ(gen1.current_seed(), gen2.current_seed()); + + // Modifying one should not affect the other + gen1.set_current_seed(12345); + ASSERT_NE(gen1.current_seed(), gen2.current_seed()); +} + +TEST_F(XLAGeneratorTest, CreateXLAGeneratorInvalidDevice) { + // Test with invalid device indices + auto result_neg2 = at::detail::CreateXLAGenerator(-2); + ASSERT_FALSE(result_neg2.ok()); + ASSERT_TRUE(absl::IsInvalidArgument(result_neg2.status())); + ASSERT_THAT(result_neg2.status().message(), + testing::HasSubstr("Invalid XLA device index")); + + // Test with very large device index (assuming there aren't 100 XLA devices) + auto result_large = at::detail::CreateXLAGenerator(100); + ASSERT_FALSE(result_large.ok()); + ASSERT_TRUE(absl::IsInvalidArgument(result_large.status())); + ASSERT_THAT(result_large.status().message(), + testing::HasSubstr("Invalid XLA device index")); +} + } // namespace cpp_test -} // namespace torch_xla \ No newline at end of file +} // namespace torch_xla diff --git a/torch_xla/csrc/xla_generator.cpp b/torch_xla/csrc/xla_generator.cpp index 5d0a7c15866b..0e311bda2632 100644 --- a/torch_xla/csrc/xla_generator.cpp +++ b/torch_xla/csrc/xla_generator.cpp @@ -5,10 +5,115 @@ #include #include #include +#include #include +#include #include #include +#include +#include + +#include "absl/status/status.h" +#include "torch_xla/csrc/aten_xla_bridge.h" +#include "torch_xla/csrc/runtime/computation_client.h" +#include "torch_xla/csrc/runtime/runtime.h" +#include "torch_xla/csrc/status.h" + +namespace at { + +namespace detail { + +namespace { + +// Number of XLA devices reported by the computation client (GetNumDevices). +static int64_t num_xla_devices; + +// Per-device flags for c10::call_once: each default generator is built once. +static std::deque xla_gens_init_flag; + +// Default, global XLA generators, one per XLA device. 
+static std::vector default_gens_xla; + +/* + * Populates the global variables related to XLA generators. The setup runs + * once (function-local static); every call returns that cached status. + */ +static absl::Status InitXLAGenVector() { + static absl::Status init_status = []() { + XLA_ASSIGN_OR_RETURN(auto c_client, + torch_xla::runtime::GetComputationClient()); + num_xla_devices = static_cast(c_client->GetNumDevices()); + xla_gens_init_flag.resize(num_xla_devices); + default_gens_xla.resize(num_xla_devices); + return absl::OkStatus(); + }(); + return init_status; +} + +// Validates and normalizes an XLA device index. +// If requested_index == -1, the current device index is used. +// Returns InvalidArgument if the resolved index is out of range. +static absl::StatusOr NormalizeXLADeviceIndex( + c10::DeviceIndex requested_index) { + c10::DeviceIndex idx = requested_index; + if (idx == -1) { + idx = torch_xla::bridge::GetCurrentAtenDevice().index(); + } + if (idx < 0 || idx >= num_xla_devices) { + return absl::InvalidArgumentError( + "Invalid device index for XLA generator. Provided index: " + + std::to_string(idx)); + } + return idx; +} + +} // anonymous namespace + +/** + * PyTorch maintains a collection of default generators that get + * initialized once. The purpose of these default generators is to + * maintain a global running state of the pseudo random number generation, + * when a user does not explicitly mention any generator. + * GetDefaultXLAGenerator gets the default generator for a particular + * XLA device. 
+ */ +absl::StatusOr GetDefaultXLAGenerator( + c10::DeviceIndex device_index) { + XLA_RETURN_IF_ERROR(InitXLAGenVector(), + "Failed to initialize XLA generators"); + // Normalize and validate the target device index; default to current device + // when unspecified + XLA_ASSIGN_OR_RETURN(c10::DeviceIndex idx, + NormalizeXLADeviceIndex(device_index), + "Invalid XLA device index"); + c10::call_once(xla_gens_init_flag[idx], [&] { + default_gens_xla[idx] = at::make_generator(idx); + default_gens_xla[idx].seed(); + }); + return default_gens_xla[idx]; +} + +/** + * Creates a new XLA generator seeded with c10::default_rng_seed_val. + */ +absl::StatusOr CreateXLAGenerator( + c10::DeviceIndex device_index) { + XLA_RETURN_IF_ERROR(InitXLAGenVector(), + "Failed to initialize XLA generators"); + // Normalize and validate the target device index; default to current device + // when unspecified + XLA_ASSIGN_OR_RETURN(c10::DeviceIndex idx, + NormalizeXLADeviceIndex(device_index), + "Invalid XLA device index"); + auto gen = at::make_generator(idx); + auto xla_gen = at::check_generator(gen); + xla_gen->set_current_seed(c10::default_rng_seed_val); + return gen; +} + +} // namespace detail +} // namespace at namespace at { diff --git a/torch_xla/csrc/xla_generator.h b/torch_xla/csrc/xla_generator.h index 330d32861200..8001737e795c 100644 --- a/torch_xla/csrc/xla_generator.h +++ b/torch_xla/csrc/xla_generator.h @@ -2,10 +2,17 @@ #include #include +#include +#include +#include +#include #include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" + namespace at { // Holds the actual state variables for the XLA generator. 
@@ -53,4 +60,13 @@ struct TORCH_API XLAGeneratorImpl : public c10::GeneratorImpl { c10::intrusive_ptr state_; }; -} // namespace at \ No newline at end of file +namespace detail { + +absl::StatusOr GetDefaultXLAGenerator( + c10::DeviceIndex device_index = -1); +absl::StatusOr CreateXLAGenerator( + c10::DeviceIndex device_index = -1); + +} // namespace detail + +} // namespace at