Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 157 additions & 1 deletion test/cpp/test_xla_generator.cpp
Original file line number Diff line number Diff line change
@@ -1,15 +1,54 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <torch/torch.h>

#include <cstdlib>

#include "test/cpp/torch_xla_test.h"
#include "torch_xla/csrc/xla_generator.h"

namespace torch_xla {
namespace cpp_test {

// Ensures the PJRT runtime is pointed at a CPU backend for tests that touch
// the PJRT runtime, by setting the PJRT_DEVICE and CPU_NUM_DEVICES
// environment variables.
//
// For each variable the same rule applies:
//   * if the corresponding argument is a non-empty string, it unconditionally
//     replaces whatever is currently in the environment;
//   * otherwise, if the variable is unset or set to the empty string, it is
//     given its default ("CPU" for PJRT_DEVICE, "1" for CPU_NUM_DEVICES);
//   * otherwise the existing non-empty value is left untouched.
//
// Parameters:
//   pjrt_device      - optional override for PJRT_DEVICE (nullptr/"" = none).
//   cpu_num_devices  - optional override for CPU_NUM_DEVICES (nullptr/"" =
//                      none).
static void EnsurePjrtCpuBackend(const char* pjrt_device = nullptr,
                                 const char* cpu_num_devices = nullptr) {
  // A value counts as "blank" when it is missing or the empty string.
  const auto is_blank = [](const char* value) {
    return value == nullptr || value[0] == '\0';
  };

  // Applies the override/default rule described above to one variable.
  const auto apply = [&is_blank](const char* name, const char* override_value,
                                 const char* fallback) {
    if (!is_blank(override_value)) {
      // Force override of any existing value.
      setenv(name, override_value, /*overwrite=*/1);
    } else if (is_blank(std::getenv(name))) {
      // overwrite must be 1 here: the variable may exist with an empty
      // value, and setenv(..., 0) would silently keep that empty string
      // instead of installing the default.
      setenv(name, fallback, /*overwrite=*/1);
    }
  };

  apply("PJRT_DEVICE", pjrt_device, "CPU");
  apply("CPU_NUM_DEVICES", cpu_num_devices, "1");
}

Comment on lines +13 to +44
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed above, let's not do this.
Let's leave the environment variable initialization to the test runner script.

// Test fixture for XLAGenerator tests
class XLAGeneratorTest : public ::torch_xla::cpp_test::TorchXlaTest {
protected:
// Runs once before the test suite / test case to ensure PJRT is configured
// before any XLA runtime initialization happens in per-test SetUp().
static void SetUpTestCase() { EnsurePjrtCpuBackend("CPU", "2"); }

void SetUp() {
// Create a generator for XLA device 0
gen_ = at::make_generator<at::XLAGeneratorImpl>(0);
Expand Down Expand Up @@ -102,5 +141,122 @@ TEST_F(XLAGeneratorTest, Clone) {
ASSERT_NE(cloned_gen.current_seed(), gen_.current_seed());
}

TEST_F(XLAGeneratorTest, GetDefaultXLAGenerator) {
  // An explicit index of 0 must yield an XLA generator bound to device 0.
  auto lookup_dev0 = at::detail::GetDefaultXLAGenerator(0);
  ASSERT_TRUE(lookup_dev0.ok()) << "Failed to get default generator: "
                                << lookup_dev0.status();

  const at::Generator& gen_dev0 = lookup_dev0.value();
  ASSERT_EQ(gen_dev0.device().type(), at::DeviceType::XLA);
  ASSERT_EQ(gen_dev0.device().index(), 0);

  // An index of -1 is expected to resolve to device 0 and to hand back the
  // very same default generator.
  auto lookup_implicit = at::detail::GetDefaultXLAGenerator(-1);
  ASSERT_TRUE(lookup_implicit.ok())
      << "Failed to get default generator with -1: "
      << lookup_implicit.status();

  const at::Generator& gen_implicit = lookup_implicit.value();
  ASSERT_EQ(gen_implicit.device().type(), at::DeviceType::XLA);
  ASSERT_EQ(gen_implicit.device().index(), 0);
  ASSERT_EQ(gen_dev0, gen_implicit);

  // Asking again for device 0 must not create a fresh generator.
  auto lookup_repeat = at::detail::GetDefaultXLAGenerator(0);
  ASSERT_TRUE(lookup_repeat.ok());
  const at::Generator& gen_repeat = lookup_repeat.value();
  ASSERT_EQ(gen_dev0, gen_repeat);

  // A non-default device (index 1) gets its own, distinct default generator.
  auto lookup_dev1 = at::detail::GetDefaultXLAGenerator(1);
  ASSERT_TRUE(lookup_dev1.ok())
      << "Failed to get default generator for device 1: "
      << lookup_dev1.status();

  const at::Generator& gen_dev1 = lookup_dev1.value();
  ASSERT_EQ(gen_dev1.device().type(), at::DeviceType::XLA);
  ASSERT_EQ(gen_dev1.device().index(), 1);
  ASSERT_NE(gen_dev1, gen_dev0);
}

TEST_F(XLAGeneratorTest, GetDefaultXLAGeneratorInvalidDevice) {
  // Any negative index other than -1 must be rejected as InvalidArgument.
  auto below_range = at::detail::GetDefaultXLAGenerator(-2);
  ASSERT_FALSE(below_range.ok());
  ASSERT_TRUE(absl::IsInvalidArgument(below_range.status()));
  ASSERT_THAT(below_range.status().message(),
              testing::HasSubstr("Invalid XLA device index"));

  // An index past the last device must be rejected the same way (assumes
  // fewer than 101 XLA devices are present, so index 100 is out of range).
  auto above_range = at::detail::GetDefaultXLAGenerator(100);
  ASSERT_FALSE(above_range.ok());
  ASSERT_TRUE(absl::IsInvalidArgument(above_range.status()));
  ASSERT_THAT(above_range.status().message(),
              testing::HasSubstr("Invalid XLA device index"));
}

TEST_F(XLAGeneratorTest, CreateXLAGenerator) {
  // Creating a generator for an explicit device (index 1) must succeed.
  auto explicit_outcome = at::detail::CreateXLAGenerator(1);
  ASSERT_TRUE(explicit_outcome.ok())
      << "Failed to create generator: " << explicit_outcome.status();

  at::Generator explicit_gen = explicit_outcome.value();
  ASSERT_EQ(explicit_gen.device().type(), at::DeviceType::XLA);
  ASSERT_EQ(explicit_gen.device().index(), 1);

  // A freshly created generator starts out with the default RNG seed.
  ASSERT_EQ(explicit_gen.current_seed(), c10::default_rng_seed_val);

  // Index -1 means "whatever the current XLA device is".
  auto implicit_outcome = at::detail::CreateXLAGenerator(-1);
  ASSERT_TRUE(implicit_outcome.ok())
      << "Failed to create generator with -1: " << implicit_outcome.status();

  at::Generator implicit_gen = implicit_outcome.value();
  ASSERT_EQ(implicit_gen.device().type(), at::DeviceType::XLA);
  // The concrete index depends on the current XLA device; it only has to be
  // a resolved (non-negative) one.
  ASSERT_GE(implicit_gen.device().index(), 0);
}

TEST_F(XLAGeneratorTest, CreateXLAGeneratorUniqueness) {
  // Two back-to-back creations for the same device.
  auto first_outcome = at::detail::CreateXLAGenerator(0);
  auto second_outcome = at::detail::CreateXLAGenerator(0);

  ASSERT_TRUE(first_outcome.ok());
  ASSERT_TRUE(second_outcome.ok());

  at::Generator first_gen = first_outcome.value();
  at::Generator second_gen = second_outcome.value();

  // The generators themselves must differ (this compares the generators,
  // not the addresses of the local variables holding them).
  ASSERT_NE(first_gen, second_gen);

  // They nevertheless target the same device and start from the same seed.
  ASSERT_EQ(first_gen.device(), second_gen.device());
  ASSERT_EQ(first_gen.current_seed(), second_gen.current_seed());

  // Reseeding one must leave the other untouched.
  first_gen.set_current_seed(12345);
  ASSERT_NE(first_gen.current_seed(), second_gen.current_seed());
}

TEST_F(XLAGeneratorTest, CreateXLAGeneratorInvalidDevice) {
  // Any negative index other than -1 must be rejected as InvalidArgument.
  auto below_range = at::detail::CreateXLAGenerator(-2);
  ASSERT_FALSE(below_range.ok());
  ASSERT_TRUE(absl::IsInvalidArgument(below_range.status()));
  ASSERT_THAT(below_range.status().message(),
              testing::HasSubstr("Invalid XLA device index"));

  // An index past the last device must be rejected the same way (assumes
  // fewer than 101 XLA devices are present, so index 100 is out of range).
  auto above_range = at::detail::CreateXLAGenerator(100);
  ASSERT_FALSE(above_range.ok());
  ASSERT_TRUE(absl::IsInvalidArgument(above_range.status()));
  ASSERT_THAT(above_range.status().message(),
              testing::HasSubstr("Invalid XLA device index"));
}

} // namespace cpp_test
} // namespace torch_xla
} // namespace torch_xla
105 changes: 105 additions & 0 deletions torch_xla/csrc/xla_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,115 @@
#include <ATen/core/Tensor.h>
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/GeneratorImpl.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/CallOnce.h>
#include <c10/util/intrusive_ptr.h>

#include <cstring>
#include <deque>
#include <vector>

#include "absl/status/status.h"
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/runtime.h"
#include "torch_xla/csrc/status.h"

namespace at {

namespace detail {

namespace {

// Total number of XLA devices in the system, as reported by the computation
// client. Written by InitXLAGenVector() and read by
// NormalizeXLADeviceIndex() for its range check.
static int64_t num_xla_devices;

// One once_flag per XLA device (sized by InitXLAGenVector()). Flag i guards
// the lazy, one-time construction of default_gens_xla[i] in
// GetDefaultXLAGenerator().
static std::deque<c10::once_flag> xla_gens_init_flag;

// Default, global XLA generators, one per XLA device. Sized by
// InitXLAGenVector(); each entry is filled in on first request for that
// device.
static std::vector<at::Generator> default_gens_xla;

/*
 * Populates the global variables related to XLA generators: obtains the
 * computation client, stores its device count in num_xla_devices, and
 * resizes xla_gens_init_flag and default_gens_xla to one entry per device.
 *
 * The work happens in the initializer of a function-local static, so it is
 * performed on the first call only; this function may therefore be called
 * repeatedly, and every later call just returns the status produced by that
 * first attempt (a failure status is cached as well and is not retried).
 *
 * Returns OkStatus on success, or the error produced while obtaining the
 * computation client.
 */
static absl::Status InitXLAGenVector() {
static absl::Status init_status = []() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's leak it (see this example).

XLA_ASSIGN_OR_RETURN(auto c_client,
torch_xla::runtime::GetComputationClient());
num_xla_devices = static_cast<int64_t>(c_client->GetNumDevices());
xla_gens_init_flag.resize(num_xla_devices);
default_gens_xla.resize(num_xla_devices);
return absl::OkStatus();
}();
return init_status;
}

// Resolves `requested_index` to a concrete XLA device index and checks it.
// A request of -1 stands for "the current device" and is replaced by the
// index of the current ATen XLA device. The resolved index must lie in
// [0, num_xla_devices); otherwise an InvalidArgument status naming the
// resolved index is returned.
static absl::StatusOr<c10::DeviceIndex> NormalizeXLADeviceIndex(
    c10::DeviceIndex requested_index) {
  const c10::DeviceIndex resolved =
      (requested_index == -1)
          ? torch_xla::bridge::GetCurrentAtenDevice().index()
          : requested_index;

  const bool in_range = resolved >= 0 && resolved < num_xla_devices;
  if (!in_range) {
    return absl::InvalidArgumentError(
        "Invalid device index for XLA generator. Provided index: " +
        std::to_string(resolved));
  }
  return resolved;
}

} // anonymous namespace

/**
 * Returns the default generator of one XLA device.
 *
 * PyTorch keeps a set of default generators that hold the global running
 * state of pseudo random number generation for callers that do not pass an
 * explicit generator. This function exposes the XLA one for `device_index`
 * (-1 selects the current device). The generator is built and seeded lazily,
 * exactly once per device, on the first request for that device.
 *
 * Fails if the generator globals cannot be initialized or if the resolved
 * device index is out of range.
 */
absl::StatusOr<const at::Generator&> GetDefaultXLAGenerator(
    c10::DeviceIndex device_index) {
  XLA_RETURN_IF_ERROR(InitXLAGenVector(),
                      "Failed to initialize XLA generators");
  // Resolve -1 to the current device and reject out-of-range indices.
  XLA_ASSIGN_OR_RETURN(c10::DeviceIndex idx,
                       NormalizeXLADeviceIndex(device_index),
                       "Invalid XLA device index");
  // Slot of this device in the global table of default generators.
  at::Generator& slot = default_gens_xla[idx];
  c10::call_once(xla_gens_init_flag[idx], [&slot, idx] {
    slot = at::make_generator<XLAGeneratorImpl>(idx);
    slot.seed();
  });
  return slot;
}

/**
 * Utility to create a new XLA generator backed by an XLAGeneratorImpl.
 *
 * `device_index` selects the target device; -1 resolves to the current
 * device. The returned at::Generator is a fresh instance (it is not stored
 * in the table of default generators) whose seed is set to
 * c10::default_rng_seed_val.
 *
 * Returns the generator by value, or an error status if the generator
 * globals cannot be initialized or the resolved index is out of range.
 */
absl::StatusOr<at::Generator> CreateXLAGenerator(
c10::DeviceIndex device_index) {
// InitXLAGenVector() is what populates num_xla_devices, which
// NormalizeXLADeviceIndex() below reads for its range check.
XLA_RETURN_IF_ERROR(InitXLAGenVector(),
"Failed to initialize XLA generators");
Comment on lines +102 to +103
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this actually needed for CreateXLAGenerator?

// Normalize and validate the target device index; default to current device
// when unspecified
XLA_ASSIGN_OR_RETURN(c10::DeviceIndex idx,
NormalizeXLADeviceIndex(device_index),
"Invalid XLA device index");
auto gen = at::make_generator<XLAGeneratorImpl>(idx);
auto xla_gen = at::check_generator<XLAGeneratorImpl>(gen);
xla_gen->set_current_seed(c10::default_rng_seed_val);
return gen;
}

} // namespace detail
} // namespace at

namespace at {

Expand Down
18 changes: 17 additions & 1 deletion torch_xla/csrc/xla_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,17 @@

#include <ATen/core/Generator.h>
#include <ATen/core/Tensor.h>
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/GeneratorImpl.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/intrusive_ptr.h>

#include <cstdint>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

namespace at {

// Holds the actual state variables for the XLA generator.
Expand Down Expand Up @@ -53,4 +60,13 @@ struct TORCH_API XLAGeneratorImpl : public c10::GeneratorImpl {
c10::intrusive_ptr<XLAGeneratorState> state_;
};

} // namespace at
namespace detail {

// Returns a reference to the process-wide default generator of the XLA
// device `device_index`; -1 selects the current device. Repeated calls for
// the same device return the same generator. Yields an error status if the
// resolved index is not a valid XLA device index.
absl::StatusOr<const at::Generator&> GetDefaultXLAGenerator(
c10::DeviceIndex device_index = -1);
// Creates and returns a new generator for the XLA device `device_index`; -1
// selects the current device. Each call yields a separate generator. Yields
// an error status if the resolved index is not a valid XLA device index.
absl::StatusOr<at::Generator> CreateXLAGenerator(
c10::DeviceIndex device_index = -1);

} // namespace detail

} // namespace at