Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Write multiple items to output file at once, in distributed data analyzer. #5169

Merged
Show file tree
Hide file tree
Changes from 44 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
14f2bbe
added assert of torch vs numpy types
bm-synth Feb 9, 2024
796341d
first draft
bm-synth Feb 14, 2024
07aa4b4
reverted to original master
bm-synth Feb 14, 2024
815a789
added metric type accumulate_value_over_samples
bm-synth Feb 14, 2024
28a72e7
pre-commit
bm-synth Feb 14, 2024
e8dbf0b
Merge branch 'master' into distributed_data_analyzer
bm-synth Feb 14, 2024
ec3479f
Merge branch 'distributed_data_analyzer' of github.com:bm-synth/DeepS…
bm-synth Feb 14, 2024
38d7ce6
Update data_analyzer.py
bm-synth Feb 14, 2024
295fba6
added check for single node reduce. added barriers
bm-synth Feb 14, 2024
4144e42
more bug fixes
bm-synth Feb 14, 2024
a1e121c
new iteration, many bug fixes
bm-synth Feb 15, 2024
e045753
bug fixes
bm-synth Feb 15, 2024
3a89116
Merge branch 'master' into distributed_data_analyzer
bm-synth Feb 15, 2024
cdc838c
fixing previous commit
bm-synth Feb 15, 2024
ba34a55
Merge branch 'master' into distributed_data_analyzer
bm-synth Feb 16, 2024
5c07710
pre-commit
bm-synth Feb 16, 2024
87d7686
Merge branch 'distributed_data_analyzer' of github.com:bm-synth/DeepS…
bm-synth Feb 16, 2024
a634787
write sequentially to file
bm-synth Feb 16, 2024
848ffd5
Merge branch 'master' into distributed_data_analyzer
bm-synth Feb 16, 2024
ec59f08
fixes in sequential write
bm-synth Feb 16, 2024
832874c
Merge branch 'distributed_data_analyzer' of github.com:bm-synth/DeepS…
bm-synth Feb 16, 2024
ea0d65f
pre-commit hooks
bm-synth Feb 16, 2024
c6c9bc5
Merge branch 'master' into distributed_data_analyzer
bm-synth Feb 16, 2024
56a9533
added main as example
bm-synth Feb 18, 2024
b4d8654
Merge branch 'distributed_data_analyzer' of github.com:bm-synth/DeepS…
bm-synth Feb 18, 2024
676dc1a
Merge branch 'master' into distributed_data_analyzer
bm-synth Feb 18, 2024
6788af5
Update data_analyzer.py
bm-synth Feb 18, 2024
bd61d9c
first working version. idx files differ
bm-synth Feb 19, 2024
7ac5e45
Merge branch 'distributed_data_analyzer' of github.com:bm-synth/DeepS…
bm-synth Feb 19, 2024
8bf0e63
added missing static function
bm-synth Feb 19, 2024
e5a7eb0
removed/added breaklines to match base code
bm-synth Feb 19, 2024
3b8014f
corrected comment
bm-synth Feb 19, 2024
5a42687
imports
bm-synth Feb 19, 2024
cdaad36
removed main
bm-synth Feb 19, 2024
b3d4062
reverted main
bm-synth Feb 19, 2024
7cabfa2
bug fix in sample calculation
bm-synth Feb 19, 2024
62f68dd
added worker_an and num_worker to kwargs
bm-synth Feb 19, 2024
6d35e45
removed dist.initialize ()from DataAnalyzer.run_map_reduce
bm-synth Feb 19, 2024
be91d37
first iteration
bm-synth Feb 20, 2024
5fd0546
updated with add_items
bm-synth Feb 21, 2024
f5be5e1
added add_items
bm-synth Feb 21, 2024
1ccd3ba
Merge branch 'master' into write_multiple_items_at_once_in_distribute…
bm-synth Feb 21, 2024
9d5c171
Update indexed_dataset.py
bm-synth Feb 21, 2024
dc1dbb3
Merge branch 'master' into write_multiple_items_at_once_in_distribute…
conglongli Feb 21, 2024
db29942
formatting
bm-synth Feb 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,17 +482,17 @@ def __init__(
dist.init_distributed()

# comm_group and worker_id+num_workers are mutually exclusive
if comm_group is not None:
self.comm_group = comm_group
self.num_workers = self.comm_group.size()
self.worker_id = self.comm_group.rank()
self.comm_group = comm_group
if self.comm_group is None:
# self.comm_group = deepspeed.utils.groups._clone_world_group()
self.num_workers = num_workers
self.worker_id = worker_id
else:
self.comm_group = groups._clone_world_group()
self.num_workers = self.comm_group.size()
self.worker_id = self.comm_group.rank()

if self.worker_id == 0:
logger.info(f"Data analyzer initialized with {self.num_workers} workers.")
logger.info(f"Distributed data analyzer initialized with {self.num_workers} workers.")

def run_map_reduce(self):

Expand Down Expand Up @@ -635,9 +635,18 @@ def file_write_ordered(self, tensor_list, fname, numpy_dtype):
# method to deserializes a buffer into rows of different lengths and write them to file
def write_buffer_to_file(buff, src, builder):
assert self.worker_id == 0, "only rank 0 can write to file"

# # write one buffer at a time
# for row_len in row_lens[src]:
# builder.add_item(buff[:row_len].cpu())
# buff = buff[row_len:]

# collect all buffers and write them all at once
buffer_list = []
for row_len in row_lens[src]:
builder.add_item(buff[:row_len].cpu())
buffer_list.append(buff[:row_len].cpu())
buff = buff[row_len:]
builder.add_items(buffer_list)

# 5. rank 0 prepares output folder and file
if self.worker_id == 0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -581,10 +581,18 @@ def __init__(self, out_file, dtype=np.int64):
self._doc_idx = [0]

def add_item(self, tensor):
""" write the tensor to the file and update its size in the index"""
np_array = np.array(tensor.numpy(), dtype=self._dtype)
self._data_file.write(np_array.tobytes(order='C'))
self._sizes.append(np_array.size)

def add_items(self, tensor_list):
""" write a list of tensors to the file and update their sizes in the index"""
np_arrays = [np.array(t.numpy(), dtype=self._dtype) for t in tensor_list]
self._data_file.writelines([arr.tobytes(order='C') for arr in np_arrays])
for arr in np_arrays:
self._sizes.append(arr.size)

def add_item_numpy(self, np_array):
if np_array.dtype != self._dtype:
np_array = np_array.astype(self._dtype)
Expand Down
Loading