From b692cdea479fba8201584054d654f639e925a265 Mon Sep 17 00:00:00 2001 From: Joe Mayer <114769929+jomayeri@users.noreply.github.com> Date: Tue, 12 Nov 2024 08:34:17 -0800 Subject: [PATCH 01/16] AIO File Offsets (#6641) Adding the option for a file offset to the read/write functions of AIO & GDS ops. --------- Co-authored-by: jomayeri Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> Co-authored-by: Olatunji Ruwase Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- accelerator/cpu_accelerator.py | 2 + csrc/aio/common/deepspeed_aio_utils.cpp | 18 ++-- csrc/aio/common/deepspeed_aio_utils.h | 4 +- csrc/aio/py_lib/deepspeed_aio_op_desc.cpp | 6 +- csrc/aio/py_lib/deepspeed_aio_op_desc.h | 4 +- csrc/aio/py_lib/deepspeed_cpu_op.cpp | 19 ++-- csrc/aio/py_lib/deepspeed_cpu_op.h | 3 +- csrc/aio/py_lib/deepspeed_py_aio.cpp | 7 +- csrc/aio/py_lib/deepspeed_py_io_handle.cpp | 65 +++++++++----- csrc/aio/py_lib/deepspeed_py_io_handle.h | 27 ++++-- csrc/aio/py_lib/py_ds_aio.cpp | 24 +++-- csrc/aio/py_test/ds_aio_handle.py | 2 +- csrc/gds/py_lib/deepspeed_gds_op.cpp | 24 +++-- csrc/gds/py_lib/deepspeed_gds_op.h | 3 +- csrc/gds/py_lib/deepspeed_py_gds_handle.cpp | 15 +++- csrc/gds/py_lib/deepspeed_py_gds_handle.h | 3 +- csrc/gds/py_lib/py_ds_gds.cpp | 24 +++-- deepspeed/runtime/swap_tensor/utils.py | 4 +- deepspeed/utils/numa.py | 5 +- tests/unit/ops/aio/test_aio.py | 97 +++++++++++++++++++-- tests/unit/ops/aio/test_gds.py | 87 ++++++++++++++++-- 21 files changed, 342 insertions(+), 101 deletions(-) diff --git a/accelerator/cpu_accelerator.py b/accelerator/cpu_accelerator.py index 1e4335b19292..0e49bd9f6458 100644 --- a/accelerator/cpu_accelerator.py +++ b/accelerator/cpu_accelerator.py @@ -71,6 +71,8 @@ def device_count(self): # In flat mode, HBM is in a separate NUMA node with no cores. # Ignore these NUMA nodes with no cores.
numa_core_lists = get_numa_cores() + if not numa_core_lists: + return 1 numa_count = 0 prev_core_list = [] for core_list in numa_core_lists: diff --git a/csrc/aio/common/deepspeed_aio_utils.cpp b/csrc/aio/common/deepspeed_aio_utils.cpp index 0536ff6a362e..fb269b58315f 100644 --- a/csrc/aio/common/deepspeed_aio_utils.cpp +++ b/csrc/aio/common/deepspeed_aio_utils.cpp @@ -19,9 +19,14 @@ const int c_io_queue_depth = 8; io_xfer_ctxt::io_xfer_ctxt(const int fd, const int64_t file_offset, + const int64_t buffer_offset, const int64_t num_bytes, const void* buffer) - : _fd(fd), _base_offset(file_offset), _mem_buffer(buffer), _num_bytes(num_bytes) + : _fd(fd), + _file_base_offset(file_offset), + _buffer_base_offset(buffer_offset), + _mem_buffer(buffer), + _num_bytes(num_bytes) { } @@ -41,9 +46,10 @@ void io_prep_context::prep_iocbs(const int n_iocbs, assert(static_cast<size_t>(n_iocbs) <= _iocbs->size()); for (auto i = 0; i < n_iocbs; ++i) { const auto shift = i * _block_size; - const auto xfer_buffer = (char*)start_buffer + _xfer_ctxt->_base_offset + shift; - const auto xfer_offset = _xfer_ctxt->_base_offset + start_offset + shift; + const auto xfer_buffer = (char*)start_buffer + _xfer_ctxt->_buffer_base_offset + shift; + const auto xfer_offset = _xfer_ctxt->_file_base_offset + start_offset + shift; auto byte_count = _block_size; + if ((shift + _block_size) > num_bytes) { byte_count = num_bytes - shift; } if (_read_op) { @@ -79,10 +85,10 @@ int io_prep_generator::prep_iocbs(const int n_iocbs, std::vector<struct iocb*>* auto actual_n_iocbs = min(static_cast<int64_t>(n_iocbs), _remaining_io_blocks); for (auto i = 0; i < actual_n_iocbs; ++i, ++_next_iocb_index) { - const auto xfer_offset = _xfer_ctxt->_base_offset + (_next_iocb_index * _block_size); - const auto xfer_buffer = (char*)_xfer_ctxt->_mem_buffer + xfer_offset; + const auto xfer_buffer = (char*)_xfer_ctxt->_mem_buffer + _xfer_ctxt->_buffer_base_offset + + (_next_iocb_index * _block_size); + const auto xfer_offset = _xfer_ctxt->_file_base_offset + (_next_iocb_index * _block_size); const auto num_bytes = min(static_cast<int64_t>(_block_size), _remaining_bytes); - if (_read_op) { io_prep_pread(iocbs->at(i), _xfer_ctxt->_fd, xfer_buffer, num_bytes, xfer_offset); } else { diff --git a/csrc/aio/common/deepspeed_aio_utils.h b/csrc/aio/common/deepspeed_aio_utils.h index 20e81fe8eebd..6b7599acecb4 100644 --- a/csrc/aio/common/deepspeed_aio_utils.h +++ b/csrc/aio/common/deepspeed_aio_utils.h @@ -30,12 +30,14 @@ Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
struct io_xfer_ctxt { const int _fd; - const int64_t _base_offset; + const int64_t _file_base_offset; + const int64_t _buffer_base_offset; const void* _mem_buffer; const int64_t _num_bytes; io_xfer_ctxt(const int fd, const int64_t file_offset, + const int64_t buffer_offset, const int64_t num_bytes, const void* buffer); }; diff --git a/csrc/aio/py_lib/deepspeed_aio_op_desc.cpp b/csrc/aio/py_lib/deepspeed_aio_op_desc.cpp index 6f311c5400c7..945251397225 100644 --- a/csrc/aio/py_lib/deepspeed_aio_op_desc.cpp +++ b/csrc/aio/py_lib/deepspeed_aio_op_desc.cpp @@ -13,14 +13,16 @@ io_op_desc_t::io_op_desc_t(const bool read_op, const char* filename, const int64_t file_num_bytes, const int intra_op_parallelism, - const bool validate) + const bool validate, + const int64_t file_offset) : _read_op(read_op), _buffer(buffer), _fd(fd), _filename(filename), _file_num_bytes(file_num_bytes), + _file_offset(file_offset), _intra_op_parallelism(intra_op_parallelism), - _num_bytes_per_thread(file_num_bytes / intra_op_parallelism), + _num_bytes_per_thread(static_cast<const int64_t>(buffer.nbytes()) / intra_op_parallelism), _validate(validate) { } diff --git a/csrc/aio/py_lib/deepspeed_aio_op_desc.h b/csrc/aio/py_lib/deepspeed_aio_op_desc.h index f841b8ce520a..ac1cdf90f78b 100644 --- a/csrc/aio/py_lib/deepspeed_aio_op_desc.h +++ b/csrc/aio/py_lib/deepspeed_aio_op_desc.h @@ -19,6 +19,7 @@ struct io_op_desc_t { const int64_t _num_bytes_per_thread; torch::Tensor _contiguous_buffer; const bool _validate; + const int64_t _file_offset; io_op_desc_t(const bool read_op, const torch::Tensor& buffer, @@ -26,7 +27,8 @@ struct io_op_desc_t { const char* filename, const int64_t file_num_bytes, const int intra_op_parallelism, - const bool validate); + const bool validate, + const int64_t file_offset); virtual void run(const int tid, std::unique_ptr<aio_context>& aio_ctxt, diff --git a/csrc/aio/py_lib/deepspeed_cpu_op.cpp b/csrc/aio/py_lib/deepspeed_cpu_op.cpp index da2ff568d74b..56fb33fb1886 100644 --- a/csrc/aio/py_lib/deepspeed_cpu_op.cpp +++ b/csrc/aio/py_lib/deepspeed_cpu_op.cpp @@ -16,8 +16,16 @@ cpu_op_desc_t::cpu_op_desc_t( const char* filename, const int64_t file_num_bytes, const int intra_op_parallelism, - const bool validate) - : io_op_desc_t(read_op, buffer, fd, filename, file_num_bytes, intra_op_parallelism, validate), + const bool validate, + const int64_t file_offset) + : io_op_desc_t(read_op, + buffer, + fd, + filename, + file_num_bytes, + intra_op_parallelism, + validate, + file_offset), _cpu_buffer(buffer), _pinned_tensor_mgr(pinned_tensor_mgr), _is_managed_bounce_buffer(false) @@ -66,10 +74,11 @@ void cpu_op_desc_t::run(const int tid, deepspeed_aio_config_t* aio_config) { assert(tid < _intra_op_parallelism); - const auto base_offset = _num_bytes_per_thread * tid; + const auto buffer_base_offset = _num_bytes_per_thread * tid; + const auto file_base_offset = _file_offset + (_num_bytes_per_thread * tid); - std::unique_ptr<io_xfer_ctxt> xfer_ctxt( - new io_xfer_ctxt(_fd, base_offset, _num_bytes_per_thread, data_ptr())); + std::unique_ptr<io_xfer_ctxt> xfer_ctxt(new io_xfer_ctxt( + _fd, file_base_offset, buffer_base_offset, _num_bytes_per_thread, data_ptr())); if (aio_config->_overlap_events) { do_aio_operation_overlap(_read_op, aio_ctxt, xfer_ctxt, aio_config, nullptr); diff --git a/csrc/aio/py_lib/deepspeed_cpu_op.h b/csrc/aio/py_lib/deepspeed_cpu_op.h index 9de2fa254048..debaf4a90731 100644 --- a/csrc/aio/py_lib/deepspeed_cpu_op.h +++ b/csrc/aio/py_lib/deepspeed_cpu_op.h @@ -20,7 +20,8 @@ struct cpu_op_desc_t : io_op_desc_t { const char* filename, const int64_t
file_num_bytes, const int intra_op_parallelism, - const bool validate); + const bool validate, + const int64_t file_offset); void run(const int tid, std::unique_ptr<aio_context>& aio_ctxt, diff --git a/csrc/aio/py_lib/deepspeed_py_aio.cpp b/csrc/aio/py_lib/deepspeed_py_aio.cpp index 02b04057d1ac..1ff0397043fa 100644 --- a/csrc/aio/py_lib/deepspeed_py_aio.cpp +++ b/csrc/aio/py_lib/deepspeed_py_aio.cpp @@ -52,7 +52,9 @@ int deepspeed_py_aio_write(const torch::Tensor& buffer, auto write_buffer = (char*)buffer.data_ptr(); const auto num_write_bytes = static_cast<const int64_t>(buffer.nbytes()); - std::unique_ptr<io_xfer_ctxt> xfer_ctxt(new io_xfer_ctxt(fd, 0, num_write_bytes, write_buffer)); + + std::unique_ptr<io_xfer_ctxt> xfer_ctxt( + new io_xfer_ctxt(fd, 0, 0, num_write_bytes, write_buffer)); std::unique_ptr<aio_context> aio_ctxt(new aio_context(config._block_size, config._queue_depth)); if (config._overlap_events) { @@ -97,7 +99,8 @@ int deepspeed_py_aio_read(torch::Tensor& buffer, auto read_buffer = (char*)buffer.data_ptr(); assert(static_cast<int64_t>(buffer.nbytes()) == num_file_bytes); - std::unique_ptr<io_xfer_ctxt> xfer_ctxt(new io_xfer_ctxt(fd, 0, num_file_bytes, read_buffer)); + std::unique_ptr<io_xfer_ctxt> xfer_ctxt( + new io_xfer_ctxt(fd, 0, 0, num_file_bytes, read_buffer)); std::unique_ptr<aio_context> aio_ctxt(new aio_context(config._block_size, config._queue_depth)); if (config._overlap_events) { diff --git a/csrc/aio/py_lib/deepspeed_py_io_handle.cpp b/csrc/aio/py_lib/deepspeed_py_io_handle.cpp index 48ea8a1339d4..64d7c2e0541e 100644 --- a/csrc/aio/py_lib/deepspeed_py_io_handle.cpp +++ b/csrc/aio/py_lib/deepspeed_py_io_handle.cpp @@ -58,7 +58,10 @@ const bool deepspeed_io_handle_t::get_overlap_events() const { return _overlap_e const int deepspeed_io_handle_t::get_intra_op_parallelism() const { return _intra_op_parallelism; } -int deepspeed_io_handle_t::read(torch::Tensor& buffer, const char* filename, const bool validate) +int deepspeed_io_handle_t::read(torch::Tensor& buffer, + const char* filename, + const bool validate, + const int64_t file_offset) { const auto start_time = std::chrono::high_resolution_clock::now(); @@ -76,7 +79,8 @@ int deepspeed_io_handle_t::read(torch::Tensor& buffer, const char* filename, con if (fd == -1) { return -1; } auto read_buffer = (char*)buffer.data_ptr(); - std::unique_ptr<io_xfer_ctxt> xfer_ctxt(new io_xfer_ctxt(fd, 0, num_file_bytes, read_buffer)); + std::unique_ptr<io_xfer_ctxt> xfer_ctxt( + new io_xfer_ctxt(fd, file_offset, 0, num_file_bytes, read_buffer)); if (_aio_config._overlap_events) { do_aio_operation_overlap(true, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); @@ -98,7 +102,8 @@ int deepspeed_io_handle_t::read(torch::Tensor& buffer, const char* filename, con int deepspeed_io_handle_t::write(const torch::Tensor& buffer, const char* filename, - const bool validate) + const bool validate, + const int64_t file_offset) { assert(_aio_ctxt); @@ -109,7 +114,8 @@ int deepspeed_io_handle_t::write(const torch::Tensor& buffer, auto write_buffer = (char*)buffer.data_ptr(); const auto num_write_bytes = static_cast<const int64_t>(buffer.nbytes()); - std::unique_ptr<io_xfer_ctxt> xfer_ctxt(new io_xfer_ctxt(fd, 0, num_write_bytes, write_buffer)); + std::unique_ptr<io_xfer_ctxt> xfer_ctxt( + new io_xfer_ctxt(fd, file_offset, 0, num_write_bytes, write_buffer)); if (_aio_config._overlap_events) { do_aio_operation_overlap(false, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); @@ -206,7 +212,8 @@ std::shared_ptr<struct io_op_desc_t> deepspeed_io_handle_t::_create_io_op_desc( const int fd, const char* filename, const int64_t file_num_bytes, - const bool validate) + const bool validate, + const int64_t file_offset) { return std::make_shared<cpu_op_desc_t>(read_op, buffer, @@
-215,13 +222,15 @@ std::shared_ptr<struct io_op_desc_t> deepspeed_io_handle_t::_create_io_op_desc( filename, file_num_bytes, _intra_op_parallelism, - validate); + validate, + file_offset); } int deepspeed_io_handle_t::pread(const torch::Tensor& buffer, const char* filename, const bool validate, - const bool async) + const bool async, + const int64_t file_offset) { int64_t num_file_bytes; if (-1 == get_file_size(filename, num_file_bytes)) { @@ -229,20 +238,18 @@ int deepspeed_io_handle_t::pread(const torch::Tensor& buffer, report_file_error(filename, " fstat for read", error_code); return -1; } + + // buffer can exceed file size to enable 4k alignment const auto buffer_bytes = static_cast<int64_t>(buffer.nbytes()); - if (buffer_bytes != num_file_bytes) { - std::cout << filename << ": buffer nbytes != file bytes " << buffer_bytes - << " != " << num_file_bytes << std::endl; - } - assert(buffer_bytes == num_file_bytes); assert((num_file_bytes % _intra_op_parallelism) == 0); - if (!_is_valid_parallel_aio_op(true, num_file_bytes)) { return -1; } + if (!_is_valid_parallel_aio_op(true, buffer_bytes)) { return -1; } const auto fd = open_file(filename, true); if (fd == -1) { return -1; } - auto scheduled_op = _create_io_op_desc(true, buffer, fd, filename, num_file_bytes, validate); + auto scheduled_op = + _create_io_op_desc(true, buffer, fd, filename, num_file_bytes, validate, file_offset); _schedule_aio_work(scheduled_op); @@ -254,7 +261,8 @@ int deepspeed_io_handle_t::pread(const torch::Tensor& buffer, int deepspeed_io_handle_t::pwrite(const torch::Tensor& buffer, const char* filename, const bool validate, - const bool async) + const bool async, + const int64_t file_offset) { const auto num_write_bytes = static_cast<const int64_t>(buffer.nbytes()); assert((num_write_bytes % _intra_op_parallelism) == 0); @@ -264,7 +272,8 @@ int deepspeed_io_handle_t::pwrite(const torch::Tensor& buffer, const auto fd = open_file(filename, false); if (fd == -1) { return -1; } - auto scheduled_op = _create_io_op_desc(false, buffer, fd, filename, num_write_bytes, validate); + auto scheduled_op = + _create_io_op_desc(false, buffer, fd, filename, num_write_bytes, validate, file_offset); _schedule_aio_work(scheduled_op); @@ -273,24 +282,32 @@ int deepspeed_io_handle_t::pwrite(const torch::Tensor& buffer, return wait(); } -int deepspeed_io_handle_t::sync_pread(torch::Tensor& buffer, const char* filename) +int deepspeed_io_handle_t::sync_pread(torch::Tensor& buffer, + const char* filename, + const int64_t file_offset) { - return pread(buffer, filename, false, false); + return pread(buffer, filename, false, false, file_offset); } -int deepspeed_io_handle_t::sync_pwrite(const torch::Tensor& buffer, const char* filename) +int deepspeed_io_handle_t::sync_pwrite(const torch::Tensor& buffer, + const char* filename, + const int64_t file_offset) { - return pwrite(buffer, filename, false, false); + return pwrite(buffer, filename, false, false, file_offset); } -int deepspeed_io_handle_t::async_pread(torch::Tensor& buffer, const char* filename) +int deepspeed_io_handle_t::async_pread(torch::Tensor& buffer, + const char* filename, + const int64_t file_offset) { - return pread(buffer, filename, false, true); + return pread(buffer, filename, false, true, file_offset); } -int deepspeed_io_handle_t::async_pwrite(const torch::Tensor& buffer, const char* filename) +int deepspeed_io_handle_t::async_pwrite(const torch::Tensor& buffer, + const char* filename, + const int64_t file_offset) { - return pwrite(buffer, filename, false, true); + return pwrite(buffer, filename, false, true,
file_offset); } at::Tensor deepspeed_io_handle_t::new_cpu_locked_tensor(const int64_t num_elem, diff --git a/csrc/aio/py_lib/deepspeed_py_io_handle.h b/csrc/aio/py_lib/deepspeed_py_io_handle.h index 4fedf8080818..dfcb4125ab9a 100644 --- a/csrc/aio/py_lib/deepspeed_py_io_handle.h +++ b/csrc/aio/py_lib/deepspeed_py_io_handle.h @@ -38,27 +38,35 @@ struct deepspeed_io_handle_t { const bool get_overlap_events() const; const int get_intra_op_parallelism() const; - int read(torch::Tensor& buffer, const char* filename, const bool validate); + int read(torch::Tensor& buffer, + const char* filename, + const bool validate, + const int64_t file_offset); - int write(const torch::Tensor& buffer, const char* filename, const bool validate); + int write(const torch::Tensor& buffer, + const char* filename, + const bool validate, + const int64_t file_offset); int pread(const torch::Tensor& buffer, const char* filename, const bool validate, - const bool async); + const bool async, + const int64_t file_offset); int pwrite(const torch::Tensor& buffer, const char* filename, const bool validate, - const bool async); + const bool async, + const int64_t file_offset); - int sync_pread(torch::Tensor& buffer, const char* filename); + int sync_pread(torch::Tensor& buffer, const char* filename, const int64_t file_offset); - int sync_pwrite(const torch::Tensor& buffer, const char* filename); + int sync_pwrite(const torch::Tensor& buffer, const char* filename, const int64_t file_offset); - int async_pread(torch::Tensor& buffer, const char* filename); + int async_pread(torch::Tensor& buffer, const char* filename, const int64_t file_offset); - int async_pwrite(const torch::Tensor& buffer, const char* filename); + int async_pwrite(const torch::Tensor& buffer, const char* filename, const int64_t file_offset); // TODO: Make API's args to be shape and dtype. torch::Tensor new_cpu_locked_tensor(const int64_t num_elem, @@ -81,5 +89,6 @@ struct deepspeed_io_handle_t { const int fd, const char* filename, const int64_t file_num_bytes, - const bool validate); + const bool validate, + const int64_t file_offset); }; diff --git a/csrc/aio/py_lib/py_ds_aio.cpp b/csrc/aio/py_lib/py_ds_aio.cpp index b80fa2d6c8e6..bf298b691b81 100644 --- a/csrc/aio/py_lib/py_ds_aio.cpp +++ b/csrc/aio/py_lib/py_ds_aio.cpp @@ -40,14 +40,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) "Synchronous and non-parallel file read. Returns count of completed read ops", "buffer"_a, "filename"_a, - "validate"_a) + "validate"_a, + "file_offset"_a = 0) .def("write", &deepspeed_aio_handle_t::write, "Synchronous and non-parallel file write. Returns count of completed write ops", "buffer"_a, "filename"_a, - "validate"_a) + "validate"_a, + "file_offset"_a = 0) .def("pread", &deepspeed_aio_handle_t::pread, @@ -55,7 +57,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) "buffer"_a, "filename"_a, "validate"_a, - "async"_a) + "async"_a, + "file_offset"_a = 0) .def("pwrite", &deepspeed_aio_handle_t::pwrite, @@ -63,33 +66,38 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) "buffer"_a, "filename"_a, "validate"_a, - "async"_a) + "async"_a, + "file_offset"_a = 0) .def("sync_pread", &deepspeed_aio_handle_t::sync_pread, "Synchronous parallel file read. Returns count of completed read ops", "buffer"_a, - "filename"_a) + "filename"_a, + "file_offset"_a = 0) .def("sync_pwrite", &deepspeed_aio_handle_t::sync_pwrite, "Synchronous parallel file write.
Returns count of completed write ops", "buffer"_a, - "filename"_a) + "filename"_a, + "file_offset"_a = 0) .def("async_pread", &deepspeed_aio_handle_t::async_pread, "Asynchronous parallel file read. Returns 0 on success, and " "following wait() returns count of completed ops.", "buffer"_a, - "filename"_a) + "filename"_a, + "file_offset"_a = 0) .def("async_pwrite", &deepspeed_aio_handle_t::async_pwrite, "Asynchronous parallel file write. Returns 0 on success, and following wait() returns " "count of completed ops.", "buffer"_a, - "filename"_a) + "filename"_a, + "file_offset"_a = 0) .def("new_cpu_locked_tensor", &deepspeed_aio_handle_t::new_cpu_locked_tensor, diff --git a/csrc/aio/py_test/ds_aio_handle.py b/csrc/aio/py_test/ds_aio_handle.py index f4a179deb9ec..6913e9090bf5 100755 --- a/csrc/aio/py_test/ds_aio_handle.py +++ b/csrc/aio/py_test/ds_aio_handle.py @@ -92,7 +92,7 @@ def main_parallel_read(pool_params): start_time = time.time() dest_buffer = BOUNCE_BUFFER if ctxt[BOUNCE_BUFFER] is not None else BUFFER - ret = handle.pread(ctxt[dest_buffer], ctxt['file'], args.validate, True) + ret = handle.pread(ctxt[dest_buffer], ctxt['file'], args.validate, True, 0) assert ret != -1 handle.wait() if dest_buffer == BOUNCE_BUFFER: diff --git a/csrc/gds/py_lib/deepspeed_gds_op.cpp b/csrc/gds/py_lib/deepspeed_gds_op.cpp index f49f74394374..b7055c8cc72b 100644 --- a/csrc/gds/py_lib/deepspeed_gds_op.cpp +++ b/csrc/gds/py_lib/deepspeed_gds_op.cpp @@ -95,8 +95,16 @@ gds_op_desc_t::gds_op_desc_t(const bool read_op, const char* filename, const int64_t file_num_bytes, const int intra_op_parallelism, - const bool validate) - : io_op_desc_t(read_op, buffer, fd, filename, file_num_bytes, intra_op_parallelism, validate) + const bool validate, + const int64_t file_offset) + : io_op_desc_t(read_op, + buffer, + fd, + filename, + file_num_bytes, + intra_op_parallelism, + validate, + file_offset) { _contiguous_buffer = _buffer.contiguous(); const int64_t device = _buffer.get_device(); @@ -124,17 +132,17 @@ void gds_op_desc_t::run(const int tid, { assert(tid < _intra_op_parallelism); check_cudaruntimecall(cudaSetDevice(_buffer.get_device())); - int64_t buf_offset = data_ptr() + (_num_bytes_per_thread * tid) - (char*)_base_ptr; - const auto file_offset = _num_bytes_per_thread * tid; + const auto buf_offset = data_ptr() + (_num_bytes_per_thread * tid) - (char*)_base_ptr; + const auto tid_file_offset = _file_offset + (_num_bytes_per_thread * tid); if (_read_op) { auto ret = - cuFileRead(_cf_handle, _base_ptr, _num_bytes_per_thread, file_offset, buf_offset); - if (ret < 0) { _report_error(ret, errno, buf_offset); } + cuFileRead(_cf_handle, _base_ptr, _num_bytes_per_thread, tid_file_offset, buf_offset); + if (ret < 0) { _report_error(ret, errno, tid_file_offset); } } else { auto ret = - cuFileWrite(_cf_handle, _base_ptr, _num_bytes_per_thread, file_offset, buf_offset); - if (ret < 0) { _report_error(ret, errno, buf_offset); } + cuFileWrite(_cf_handle, _base_ptr, _num_bytes_per_thread, tid_file_offset, buf_offset); + if (ret < 0) { _report_error(ret, errno, tid_file_offset); } } } diff --git a/csrc/gds/py_lib/deepspeed_gds_op.h b/csrc/gds/py_lib/deepspeed_gds_op.h index 380bb0b9b6ae..d955527b1ba3 100644 --- a/csrc/gds/py_lib/deepspeed_gds_op.h +++ b/csrc/gds/py_lib/deepspeed_gds_op.h @@ -24,7 +24,8 @@ struct gds_op_desc_t : io_op_desc_t { const char* filename, const int64_t file_num_bytes, const int intra_op_parallelism, - const bool validate); + const bool validate, + const int64_t file_offset); void
run(const int tid, std::unique_ptr<aio_context>& aio_ctxt, diff --git a/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp b/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp index c052144a0190..f11245c75a5e 100644 --- a/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp +++ b/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp @@ -107,12 +107,19 @@ std::shared_ptr<struct io_op_desc_t> deepspeed_gds_handle_t::_create_io_op_desc( const int fd, const char* filename, const int64_t file_num_bytes, - const bool validate) + const bool validate, + const int64_t file_offset) { if (buffer.is_cuda()) { - return std::make_shared<gds_op_desc_t>( read_op, buffer, fd, filename, file_num_bytes, _intra_op_parallelism, validate); + return std::make_shared<gds_op_desc_t>(read_op, + buffer, + fd, + filename, + file_num_bytes, + _intra_op_parallelism, + validate, + file_offset); } return deepspeed_io_handle_t::_create_io_op_desc( - read_op, buffer, fd, filename, file_num_bytes, validate); + read_op, buffer, fd, filename, file_num_bytes, validate, file_offset); } diff --git a/csrc/gds/py_lib/deepspeed_py_gds_handle.h b/csrc/gds/py_lib/deepspeed_py_gds_handle.h index 131e83e7b838..25f68e177b2c 100644 --- a/csrc/gds/py_lib/deepspeed_py_gds_handle.h +++ b/csrc/gds/py_lib/deepspeed_py_gds_handle.h @@ -42,7 +42,8 @@ struct deepspeed_gds_handle_t : deepspeed_io_handle_t { const int fd, const char* filename, const int64_t file_num_bytes, - const bool validate); + const bool validate, + const int64_t file_offset); static int s_cuFile_init; }; diff --git a/csrc/gds/py_lib/py_ds_gds.cpp b/csrc/gds/py_lib/py_ds_gds.cpp index 57bf8d2207c4..2f165ee2c32a 100644 --- a/csrc/gds/py_lib/py_ds_gds.cpp +++ b/csrc/gds/py_lib/py_ds_gds.cpp @@ -33,14 +33,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) "Synchronous and non-parallel file read. Returns count of completed read ops", "buffer"_a, "filename"_a, - "validate"_a) + "validate"_a, + "file_offset"_a = 0) .def("write", &deepspeed_gds_handle_t::write, "Synchronous and non-parallel file write. Returns count of completed write ops", "buffer"_a, "filename"_a, - "validate"_a) + "validate"_a, + "file_offset"_a = 0) .def("pread", &deepspeed_gds_handle_t::pread, @@ -48,7 +50,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) "buffer"_a, "filename"_a, "validate"_a, - "async"_a) + "async"_a, + "file_offset"_a = 0) .def("pwrite", &deepspeed_gds_handle_t::pwrite, @@ -56,33 +59,38 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) "buffer"_a, "filename"_a, "validate"_a, - "async"_a) + "async"_a, + "file_offset"_a = 0) .def("sync_pread", &deepspeed_gds_handle_t::sync_pread, "Synchronous parallel file read. Returns count of completed read ops", "buffer"_a, - "filename"_a) + "filename"_a, + "file_offset"_a = 0) .def("sync_pwrite", &deepspeed_gds_handle_t::sync_pwrite, "Synchronous parallel file write. Returns count of completed write ops", "buffer"_a, - "filename"_a) + "filename"_a, + "file_offset"_a = 0) .def("async_pread", &deepspeed_gds_handle_t::async_pread, "Asynchronous parallel file read. Returns 0 on success, and " "following wait() returns count of completed ops.", "buffer"_a, - "filename"_a) + "filename"_a, + "file_offset"_a = 0) .def("async_pwrite", &deepspeed_gds_handle_t::async_pwrite, "Asynchronous parallel file write.
Returns 0 on success, and following wait() returns " "count of completed ops.", "buffer"_a, - "filename"_a) + "filename"_a, + "file_offset"_a = 0) .def("new_cpu_locked_tensor", &deepspeed_gds_handle_t::new_cpu_locked_tensor, diff --git a/deepspeed/runtime/swap_tensor/utils.py b/deepspeed/runtime/swap_tensor/utils.py index 90b2d9b8bd31..1f9825c34638 100644 --- a/deepspeed/runtime/swap_tensor/utils.py +++ b/deepspeed/runtime/swap_tensor/utils.py @@ -18,12 +18,12 @@ def swap_in_tensors(swap_handle, tensor_buffers, swap_paths): for buffer, path in zip(tensor_buffers, swap_paths): - assert (swap_handle.async_pread(buffer, path) == 0) + assert (swap_handle.async_pread(buffer, path, 0) == 0) def swap_out_tensors(swap_handle, tensor_buffers, swap_paths): for buffer, path in zip(tensor_buffers, swap_paths): - assert (swap_handle.async_pwrite(buffer, path) == 0) + assert (swap_handle.async_pwrite(buffer, path, 0) == 0) def print_object(obj, name, exclude_list=[]): diff --git a/deepspeed/utils/numa.py b/deepspeed/utils/numa.py index 4fe7cbba90ae..aba3b5179d41 100644 --- a/deepspeed/utils/numa.py +++ b/deepspeed/utils/numa.py @@ -23,7 +23,10 @@ # ] def get_numa_cores(): ret = [] - output = subprocess.check_output(['numactl', '--hardware']).decode("utf-8") + try: + output = subprocess.check_output(['numactl', '--hardware']).decode("utf-8") + except (subprocess.CalledProcessError, OSError): + return [] lines = output.split('\n') for line in lines: if line.startswith('available:'): diff --git a/tests/unit/ops/aio/test_aio.py b/tests/unit/ops/aio/test_aio.py index a074cfca317f..1aa5f647a8aa 100644 --- a/tests/unit/ops/aio/test_aio.py +++ b/tests/unit/ops/aio/test_aio.py @@ -35,16 +35,21 @@ def _get_local_rank(): return 0 -def _do_ref_write(tmpdir, index=0): +def _do_ref_write(tmpdir, index=0, file_size=IO_SIZE): file_suffix = f'{_get_local_rank()}_{index}' ref_file = os.path.join(tmpdir, f'_py_random_{file_suffix}.pt') - ref_buffer = os.urandom(IO_SIZE) + ref_buffer = os.urandom(file_size) with open(ref_file, 'wb') as f: f.write(ref_buffer) return ref_file, ref_buffer +def _get_file_path(tmpdir, file_prefix, index=0): + file_suffix = f'{_get_local_rank()}_{index}' + return os.path.join(tmpdir, f'{file_prefix}_{file_suffix}.pt') + + def _get_test_write_file(tmpdir, index): file_suffix = f'{_get_local_rank()}_{index}' return os.path.join(tmpdir, f'_aio_write_random_{file_suffix}.pt') @@ -103,7 +108,7 @@ def test_parallel_read(self, tmpdir, use_cuda_pinned_tensor, single_submit, over _validate_handle_state(h, single_submit, overlap_events) ref_file, _ = _do_ref_write(tmpdir) - read_status = h.sync_pread(aio_buffer, ref_file) + read_status = h.sync_pread(aio_buffer, ref_file, 0) assert read_status == 1 with open(ref_file, 'rb') as f: @@ -131,7 +136,7 @@ def test_async_read(self, tmpdir, use_cuda_pinned_tensor, single_submit, overlap _validate_handle_state(h, single_submit, overlap_events) ref_file, _ = _do_ref_write(tmpdir) - read_status = h.async_pread(aio_buffer, ref_file) + read_status = h.async_pread(aio_buffer, ref_file, 0) assert read_status == 0 wait_status = h.wait() @@ -172,7 +177,7 @@ def test_parallel_write(self, tmpdir, use_cuda_pinned_tensor, single_submit, ove _validate_handle_state(h, single_submit, overlap_events) - write_status = h.sync_pwrite(aio_buffer, aio_file) + write_status = h.sync_pwrite(aio_buffer, aio_file, 0) assert write_status == 1 if not use_cuda_pinned_tensor: @@ -201,7 +206,7 @@ def test_async_write(self, tmpdir, use_cuda_pinned_tensor, single_submit, overla _validate_handle_state(h, single_submit, overlap_events) -
write_status = h.async_pwrite(aio_buffer, aio_file) + write_status = h.async_pwrite(aio_buffer, aio_file, 0) assert write_status == 0 wait_status = h.wait() @@ -258,7 +263,7 @@ def test_read(self, tmpdir, async_queue, use_cuda_pinned_tensor, use_unpinned_te _validate_handle_state(h, single_submit, overlap_events) for i in range(async_queue): - read_status = h.async_pread(aio_buffers[i], ref_files[i]) + read_status = h.async_pread(aio_buffers[i], ref_files[i], 0) assert read_status == 0 wait_status = h.wait() @@ -305,7 +310,7 @@ def test_write(self, tmpdir, use_cuda_pinned_tensor, async_queue, use_unpinned_t _validate_handle_state(h, single_submit, overlap_events) for i in range(async_queue): - read_status = h.async_pwrite(aio_buffers[i], aio_files[i]) + read_status = h.async_pwrite(aio_buffers[i], aio_files[i], 0) assert read_status == 0 wait_status = h.wait() @@ -320,3 +325,79 @@ def test_write(self, tmpdir, use_cuda_pinned_tensor, async_queue, use_unpinned_t filecmp.clear_cache() assert filecmp.cmp(ref_files[i], aio_files[i], shallow=False) + + +@pytest.mark.parametrize("use_cuda_pinned_tensor", [True, False]) +@pytest.mark.parametrize('file_partitions', [[1, 1, 1], [1, 1, 2], [1, 2, 1], [2, 1, 1]]) +class TestAsyncFileOffset(DistributedTest): + world_size = 1 + + def test_offset_write(self, tmpdir, file_partitions, use_cuda_pinned_tensor): + + _skip_for_invalid_environment(use_cuda_pinned_tensor=use_cuda_pinned_tensor) + ref_file = _get_file_path(tmpdir, '_py_random') + aio_file = _get_file_path(tmpdir, '_aio_random') + partition_unit_size = BLOCK_SIZE + file_size = sum(file_partitions) * partition_unit_size + + h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, QUEUE_DEPTH, True, True, IO_PARALLEL) + + if use_cuda_pinned_tensor: + data_buffer = torch.ByteTensor(list(os.urandom(file_size))).pin_memory() + else: + data_buffer = h.new_cpu_locked_tensor(file_size, torch.empty(0, dtype=torch.uint8)) + + file_offsets = [] + next_offset = 0 + for i in range(len(file_partitions)): + file_offsets.append(next_offset) + next_offset += file_partitions[i] * partition_unit_size + + ref_fd = open(ref_file, 'wb') + for i in range(len(file_partitions)): + src_buffer = torch.narrow(data_buffer, 0, file_offsets[i], file_partitions[i] * partition_unit_size) + + ref_fd.write(src_buffer.numpy().tobytes()) + ref_fd.flush() + + assert 1 == h.sync_pwrite(buffer=src_buffer, filename=aio_file, file_offset=file_offsets[i]) + + filecmp.clear_cache() + assert filecmp.cmp(ref_file, aio_file, shallow=False) + + ref_fd.close() + + if not use_cuda_pinned_tensor: + h.free_cpu_locked_tensor(data_buffer) + + def test_offset_read(self, tmpdir, file_partitions, use_cuda_pinned_tensor): + + _skip_for_invalid_environment(use_cuda_pinned_tensor=use_cuda_pinned_tensor) + partition_unit_size = BLOCK_SIZE + file_size = sum(file_partitions) * partition_unit_size + ref_file, _ = _do_ref_write(tmpdir, 0, file_size) + h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, QUEUE_DEPTH, True, True, IO_PARALLEL) + + if use_cuda_pinned_tensor: + data_buffer = torch.zeros(file_size, dtype=torch.uint8, device='cpu').pin_memory() + else: + data_buffer = h.new_cpu_locked_tensor(file_size, torch.empty(0, dtype=torch.uint8)) + + file_offsets = [] + next_offset = 0 + for i in range(len(file_partitions)): + file_offsets.append(next_offset) + next_offset += file_partitions[i] * partition_unit_size + + with open(ref_file, 'rb') as ref_fd: + for i in range(len(file_partitions)): + ref_fd.seek(file_offsets[i]) + bytes_to_read = file_partitions[i] * 
partition_unit_size + ref_buf = list(ref_fd.read(bytes_to_read)) + + dst_tensor = torch.narrow(data_buffer, 0, 0, bytes_to_read) + assert 1 == h.sync_pread(dst_tensor, ref_file, file_offsets[i]) + assert dst_tensor.tolist() == ref_buf + + if not use_cuda_pinned_tensor: + h.free_cpu_locked_tensor(data_buffer) diff --git a/tests/unit/ops/aio/test_gds.py b/tests/unit/ops/aio/test_gds.py index e94d42cd22af..d97eff452eb5 100644 --- a/tests/unit/ops/aio/test_gds.py +++ b/tests/unit/ops/aio/test_gds.py @@ -29,16 +29,21 @@ def _get_local_rank(): return 0 -def _do_ref_write(tmpdir, index=0): +def _do_ref_write(tmpdir, index=0, file_size=IO_SIZE): file_suffix = f'{_get_local_rank()}_{index}' ref_file = os.path.join(tmpdir, f'_py_random_{file_suffix}.pt') - ref_buffer = os.urandom(IO_SIZE) + ref_buffer = os.urandom(file_size) with open(ref_file, 'wb') as f: f.write(ref_buffer) return ref_file, ref_buffer +def _get_file_path(tmpdir, file_prefix, index=0): + file_suffix = f'{_get_local_rank()}_{index}' + return os.path.join(tmpdir, f'{file_prefix}_{file_suffix}.pt') + + def _get_test_write_file(tmpdir, index): file_suffix = f'{_get_local_rank()}_{index}' return os.path.join(tmpdir, f'_gds_write_random_{file_suffix}.pt') @@ -78,7 +83,7 @@ def test_parallel_read(self, tmpdir, single_submit, overlap_events): _validate_handle_state(h, single_submit, overlap_events) ref_file, _ = _do_ref_write(tmpdir) - read_status = h.sync_pread(gds_buffer, ref_file) + read_status = h.sync_pread(gds_buffer, ref_file, 0) assert read_status == 1 with open(ref_file, 'rb') as f: @@ -97,7 +102,7 @@ def test_async_read(self, tmpdir, single_submit, overlap_events): _validate_handle_state(h, single_submit, overlap_events) ref_file, _ = _do_ref_write(tmpdir) - read_status = h.async_pread(gds_buffer, ref_file) + read_status = h.async_pread(gds_buffer, ref_file, 0) assert read_status == 0 wait_status = h.wait() @@ -128,7 +133,7 @@ def test_parallel_write(self, tmpdir, single_submit, overlap_events): _validate_handle_state(h, single_submit, overlap_events) - write_status = h.sync_pwrite(gds_buffer, gds_file) + write_status = h.sync_pwrite(gds_buffer, gds_file, 0) assert write_status == 1 h.unpin_device_tensor(gds_buffer) @@ -146,7 +151,7 @@ def test_async_write(self, tmpdir, single_submit, overlap_events): _validate_handle_state(h, single_submit, overlap_events) - write_status = h.async_pwrite(gds_buffer, gds_file) + write_status = h.async_pwrite(gds_buffer, gds_file, 0) assert write_status == 0 wait_status = h.wait() @@ -188,7 +193,7 @@ def test_read(self, tmpdir, async_queue): _validate_handle_state(h, single_submit, overlap_events) for i in range(async_queue): - read_status = h.async_pread(gds_buffers[i], ref_files[i]) + read_status = h.async_pread(gds_buffers[i], ref_files[i], 0) assert read_status == 0 wait_status = h.wait() @@ -225,7 +230,7 @@ def test_write(self, tmpdir, async_queue): _validate_handle_state(h, single_submit, overlap_events) for i in range(async_queue): - read_status = h.async_pwrite(gds_buffers[i], gds_files[i]) + read_status = h.async_pwrite(gds_buffers[i], gds_files[i], 0) assert read_status == 0 wait_status = h.wait() @@ -268,3 +273,69 @@ def test_pin_device_tensor(self, use_new_api): h.free_pinned_device_tensor(pinned_buffer) else: h.unpin_device_tensor(pinned_buffer) + + +@pytest.mark.parametrize('file_partitions', [[1, 1, 1], [1, 1, 2], [1, 2, 1], [2, 1, 1]]) +class TestAsyncFileOffset(DistributedTest): + world_size = 1 + + def test_offset_write(self, tmpdir, file_partitions): + ref_file = 
_get_file_path(tmpdir, '_py_random') + aio_file = _get_file_path(tmpdir, '_aio_random') + partition_unit_size = IO_SIZE + file_size = sum(file_partitions) * partition_unit_size + + h = GDSBuilder().load().gds_handle(BLOCK_SIZE, QUEUE_DEPTH, True, True, IO_PARALLEL) + + gds_buffer = torch.empty(file_size, dtype=torch.uint8, device=get_accelerator().device_name()) + h.pin_device_tensor(gds_buffer) + + file_offsets = [] + next_offset = 0 + for i in range(len(file_partitions)): + file_offsets.append(next_offset) + next_offset += file_partitions[i] * partition_unit_size + + ref_fd = open(ref_file, 'wb') + for i in range(len(file_partitions)): + src_buffer = torch.narrow(gds_buffer, 0, file_offsets[i], + file_partitions[i] * partition_unit_size).to(device='cpu') + + ref_fd.write(src_buffer.numpy().tobytes()) + ref_fd.flush() + + assert 1 == h.sync_pwrite(buffer=src_buffer, filename=aio_file, file_offset=file_offsets[i]) + + filecmp.clear_cache() + assert filecmp.cmp(ref_file, aio_file, shallow=False) + + ref_fd.close() + + h.unpin_device_tensor(gds_buffer) + + def test_offset_read(self, tmpdir, file_partitions): + partition_unit_size = BLOCK_SIZE + file_size = sum(file_partitions) * partition_unit_size + ref_file, _ = _do_ref_write(tmpdir, 0, file_size) + h = GDSBuilder().load().gds_handle(BLOCK_SIZE, QUEUE_DEPTH, True, True, IO_PARALLEL) + + gds_buffer = torch.empty(file_size, dtype=torch.uint8, device=get_accelerator().device_name()) + h.pin_device_tensor(gds_buffer) + + file_offsets = [] + next_offset = 0 + for i in range(len(file_partitions)): + file_offsets.append(next_offset) + next_offset += file_partitions[i] * partition_unit_size + + with open(ref_file, 'rb') as ref_fd: + for i in range(len(file_partitions)): + ref_fd.seek(file_offsets[i]) + bytes_to_read = file_partitions[i] * partition_unit_size + ref_buf = list(ref_fd.read(bytes_to_read)) + + dst_tensor = torch.narrow(gds_buffer, 0, 0, bytes_to_read) + assert 1 == h.sync_pread(dst_tensor, ref_file, file_offsets[i]) + assert dst_tensor.tolist() == ref_buf + + h.unpin_device_tensor(gds_buffer) From 877aa0dba673c2aa2157029c28363b804d6ee03d Mon Sep 17 00:00:00 2001 From: Logan Adams <114770087+loadams@users.noreply.github.com> Date: Tue, 12 Nov 2024 10:50:02 -0800 Subject: [PATCH 02/16] Update path for BingBertSquad from DeepSpeedExamples (#6746) In https://github.com/microsoft/DeepSpeedExamples/pull/245, the DeepSpeedExamples directory structure was refactored; this updates the paths in DeepSpeed to match those changes. --- docs/_tutorials/bert-finetuning.md | 4 ++-- docs/_tutorials/onebit-adam.md | 4 ++-- tests/model/BingBertSquad/run_BingBertSquad.sh | 2 +- tests/model/BingBertSquad/run_BingBertSquad_sanity.sh | 2 +- tests/model/BingBertSquad/run_tests.sh | 2 +- tests/model/BingBertSquad/test_e2e_squad.py | 4 ++-- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/_tutorials/bert-finetuning.md b/docs/_tutorials/bert-finetuning.md index 3014be18d682..f833acebde9a 100755 --- a/docs/_tutorials/bert-finetuning.md +++ b/docs/_tutorials/bert-finetuning.md @@ -10,14 +10,14 @@ In this tutorial we will be adding DeepSpeed to the BingBert model for the SQuAD If you don't already have a copy of the DeepSpeed repository, please clone it now and check out the DeepSpeedExamples submodule that contains the BingBertSquad -example (DeepSpeedExamples/BingBertSquad) we will be going over in the rest of +example (DeepSpeedExamples/training/BingBertSquad) we will be going over in the rest of this tutorial.
```shell git clone https://github.com/microsoft/DeepSpeed cd DeepSpeed git submodule update --init --recursive -cd DeepSpeedExamples/BingBertSquad +cd DeepSpeedExamples/training/BingBertSquad ``` ### Pre-requisites diff --git a/docs/_tutorials/onebit-adam.md b/docs/_tutorials/onebit-adam.md index b1a8b5369761..e66bba3f818b 100644 --- a/docs/_tutorials/onebit-adam.md +++ b/docs/_tutorials/onebit-adam.md @@ -136,7 +136,7 @@ You can also use a pre-trained BERT model checkpoint from either DeepSpeed, [Hug ### 2.1 Running BingBertSQuAD with DeepSpeed and 1-bit Adam -We provide example scripts under [DeepSpeedExamples/BingBertSquad/1-bit_adam/](https://github.com/microsoft/DeepSpeedExamples/tree/master/BingBertSquad/1-bit_adam). There are 3 sets of scripts corresponding to NCCL-based implementation, MPI-based implementation on Ethernet systems, and MPI-based implementation on InfiniBand systems. For MPI-based implementation, we provide both example scripts when launching with deepspeed or mpirun. +We provide example scripts under [DeepSpeedExamples/training/BingBertSquad/1-bit_adam/](https://github.com/microsoft/DeepSpeedExamples/tree/master/training/BingBertSquad/1-bit_adam). There are three sets of scripts, corresponding to the NCCL-based implementation, the MPI-based implementation on Ethernet systems, and the MPI-based implementation on InfiniBand systems. For the MPI-based implementations, we provide example scripts for launching with either deepspeed or mpirun.
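Editor's note: to illustrate the API surface added in PATCH 01/16, below is a minimal usage sketch of the new `file_offset` argument on the handle bindings. It is not part of the patch series. It assumes DeepSpeed is installed with the `async_io` op buildable on the target system, and the file path `/tmp/ds_aio_demo.bin` is illustrative; the calls mirror those exercised in `tests/unit/ops/aio/test_aio.py`.

```python
# Minimal sketch of the file_offset argument added to the aio_handle bindings
# in PATCH 01/16. Assumption: the async_io op builds on this system; the demo
# path /tmp/ds_aio_demo.bin is hypothetical.
import torch
from deepspeed.ops.op_builder import AsyncIOBuilder

BLOCK_SIZE, QUEUE_DEPTH, IO_PARALLEL = 128 * 1024, 8, 1
handle = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, QUEUE_DEPTH, True, True, IO_PARALLEL)

# Pinned source buffer holding two partitions of BLOCK_SIZE bytes each.
buf = handle.new_cpu_locked_tensor(2 * BLOCK_SIZE, torch.empty(0, dtype=torch.uint8))

# Write each partition to its own offset within the same file.
for i in range(2):
    part = torch.narrow(buf, 0, i * BLOCK_SIZE, BLOCK_SIZE)
    assert 1 == handle.sync_pwrite(part, '/tmp/ds_aio_demo.bin', file_offset=i * BLOCK_SIZE)

# Read the second partition back from its offset; file_offset defaults to 0,
# which preserves the behavior of callers written before this patch.
dst = torch.narrow(buf, 0, 0, BLOCK_SIZE)
assert 1 == handle.sync_pread(dst, '/tmp/ds_aio_demo.bin', BLOCK_SIZE)

handle.free_cpu_locked_tensor(buf)
```

As in the unit tests added by the patch, the offsets here are multiples of the handle block size, which keeps each transfer aligned for the underlying AIO submission path.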