Shuffle Tokenized Data #290

Merged (13 commits) on Jan 21, 2025
24 changes: 24 additions & 0 deletions src/modalities/__main__.py
@@ -22,6 +22,7 @@
from modalities.config.component_factory import ComponentFactory
from modalities.config.config import ProcessGroupBackendType, load_app_config_dict
from modalities.config.instantiation_models import TrainingComponentsInstantiationModel, TrainingReportGenerator
from modalities.dataloader.shuffle_tokenized_data import shuffle_tokenized_data
from modalities.evaluator import Evaluator
from modalities.gym import Gym
from modalities.logging_broker.message_broker import MessageBroker
@@ -181,6 +182,29 @@ def CMD_entry_point_merge_packed_data(src_paths: list[Path], target_path: Path):
    merge_packed_data_files(src_paths=src_paths, target_path=target_path)


@data.command(name="shuffle_tokenized_data")
@click.option(
    "--input_data_path",
    type=click_pathlib.Path(exists=False),
    required=True,
    help="Path to a tokenized file (.pbin).",
)
@click.option(
    "--batch-size", type=int, default=100, show_default=True, help="Number of documents to process per batch."
)
def shuffle_tokenized_data_entrypoint(input_data_path: Path, batch_size: int) -> None:
    """Entrypoint for shuffling tokenized data.

    Args:
        input_data_path (Path): The path to the input tokenized data (.pbin).
        batch_size (int): Number of documents to process per batch when writing the shuffled output.

    Returns:
        None
    """
    shuffle_tokenized_data(input_data_path=input_data_path, batch_size=batch_size)


class Main:
    """Main class that orchestrates the training process."""

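For context, a minimal sketch of how the new entrypoint could be exercised, driving the click command through `click.testing.CliRunner`. The import path assumes the command object is importable from `modalities.__main__` as defined in this diff, and the `.pbin` path is a placeholder; on the command line this would presumably correspond to something like `modalities data shuffle_tokenized_data --input_data_path <file>.pbin --batch-size 100`.

```python
from click.testing import CliRunner

from modalities.__main__ import shuffle_tokenized_data_entrypoint

runner = CliRunner()
# Invoke the click command exactly as the CLI would; point the path at a real tokenized .pbin file.
result = runner.invoke(
    shuffle_tokenized_data_entrypoint,
    ["--input_data_path", "data/train.pbin", "--batch-size", "100"],
)
assert result.exit_code == 0  # succeeds when the input file exists; writes <stem>_shuffled.pbin next to it
```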
92 changes: 92 additions & 0 deletions src/modalities/dataloader/shuffle_tokenized_data.py
@@ -0,0 +1,92 @@
import pickle
import random
from pathlib import Path

from modalities.dataloader.create_packed_data import EmbeddedStreamData


def _process_batch(
    batch: list[tuple[int, int]], data: bytes, start_position: int
) -> tuple[bytes, list[tuple[int, int]]]:
    """Process a batch of index entries to extract documents and create a new index.

    Args:
        batch (list[tuple[int, int]]): List of index entries [(start, length), ...].
        data (bytes): Byte stream of the entire data loaded in memory.
        start_position (int): The starting position for this batch in the byte stream.

    Returns:
        tuple[bytes, list[tuple[int, int]]]: A tuple containing the processed data (bytes)
            and the new index [(position, length), ...].
    """
    processed_data = []
    new_index = []

    current_position = start_position

    for start, length in batch:
        # Access the data slice directly from the in-memory bytes
        document = data[start : start + length]
        processed_data.append(document)  # Already bytes

        # Record the current position and length in the new index
        new_index.append((current_position, length))
        current_position += length

    return b"".join(processed_data), new_index


def shuffle_tokenized_data(input_data_path: Path, batch_size: int) -> None:
    """Shuffle data and index segments loaded fully into memory.
    Shuffled data is written to a new file with the postfix "_shuffled".

    Args:
        input_data_path (Path): Path to the tokenized data (.pbin).
        batch_size (int): Number of documents to process per batch.

    Returns:
        None
    """
    # Step 1: Load the entire data into memory
    with input_data_path.open("rb") as f:
        # Read the header
        data_section_length_in_bytes = f.read(EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES)
        data_len = int.from_bytes(data_section_length_in_bytes, byteorder="little")

        token_size_as_bytes = f.read(EmbeddedStreamData.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES)

        # Load the data
        data = f.read(data_len)

        # Load the index
        pkl_encoded_index = f.read()
        index_base = pickle.loads(pkl_encoded_index)

    # Step 2: Shuffle the index
    random.shuffle(index_base)

    # Step 3: Divide the shuffled index into batches
    batches = [index_base[i : i + batch_size] for i in range(0, len(index_base), batch_size)]

    # Step 4: Prepare the output file
    stem = input_data_path.stem
    suffix = input_data_path.suffix
    output_data_path = input_data_path.with_name(f"{stem}_shuffled{suffix}")

    header_data = data_section_length_in_bytes + token_size_as_bytes

    with output_data_path.open("wb") as f:
        # Write the header data
        f.write(header_data)
        current_position = 0
        final_index = []

        # Process and write each batch sequentially
        for batch in batches:
            data_segment, new_index = _process_batch(batch, data, current_position)
            f.write(data_segment)
            final_index.extend(new_index)
            current_position += len(data_segment)

        # Write the final index to the file
        f.write(pickle.dumps(final_index))
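For orientation, a minimal sketch of the on-disk layout this function preserves, reading a shuffled file back with the same header constants and byte order used above. The file path is a placeholder, not part of the PR.

```python
import pickle
from pathlib import Path

from modalities.dataloader.create_packed_data import EmbeddedStreamData

shuffled_path = Path("data/train_shuffled.pbin")  # placeholder path

with shuffled_path.open("rb") as f:
    # Header: data-section length, then the token size descriptor
    data_len = int.from_bytes(f.read(EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES), byteorder="little")
    token_size_in_bytes = int.from_bytes(
        f.read(EmbeddedStreamData.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES), byteorder="little"
    )
    # Data section, followed by the pickled index of (position, length) tuples
    data = f.read(data_len)
    index = pickle.loads(f.read())

# Each index entry addresses one document's byte span inside the data section
start, length = index[0]
first_document_bytes = data[start : start + length]
```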
77 changes: 77 additions & 0 deletions tests/dataloader/test_shuffle_tokenized_data.py
@@ -0,0 +1,77 @@
import pickle

from modalities.dataloader.create_packed_data import EmbeddedStreamData
from modalities.dataloader.shuffle_tokenized_data import _process_batch, shuffle_tokenized_data


def test_process_batch_with_embedded_stream_with_memmap(tmp_path):
    # Create a temporary file
    file_path = tmp_path / "test_data.pbin"
    data = b"IloveModalities"  # Example data

    with open(file_path, "wb") as f:
        f.write(data)

    # Load the data into memory
    with open(file_path, "rb") as f:
        in_memory_data = f.read()

    # Define a batch
    batch = [(0, 1), (1, 4), (5, 10)]

    # Call the function
    new_data, new_index = _process_batch(batch=batch, data=in_memory_data, start_position=0)

    # Validate the result
    expected_data = b"IloveModalities"
    expected_index = [(0, 1), (1, 4), (5, 10)]
    assert (new_data, new_index) == (expected_data, expected_index)


def test_shuffle_tokenized_data(tmp_path):
    # Create test input data
    data = b"IloveModalities"
    data_section_length_as_bytes = len(data).to_bytes(
        EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES, byteorder="little"
    )
    token_size_in_bytes = 4
    token_size_as_bytes = token_size_in_bytes.to_bytes(
        EmbeddedStreamData.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES, byteorder="little"
    )
    index = [(0, 1), (1, 4), (5, 10)]

    # Prepare the input file
    input_path = tmp_path / "input.pbin"
    with input_path.open("wb") as f:
        f.write(data_section_length_as_bytes)
        f.write(token_size_as_bytes)
        f.write(data)
        f.write(pickle.dumps(index))

    for batch_size in [1, 2, 3]:
        # Call shuffle_tokenized_data
        output_path = tmp_path / "input_shuffled.pbin"
        shuffle_tokenized_data(input_path, batch_size=batch_size)

        # Validate the output
        assert output_path.is_file()

        with output_path.open("rb") as f:
            # Validate header and data
            written_data_section_length_as_bytes = f.read(len(data_section_length_as_bytes))
            assert written_data_section_length_as_bytes == data_section_length_as_bytes
            assert f.read(len(token_size_as_bytes)) == token_size_as_bytes
            data_len = int.from_bytes(written_data_section_length_as_bytes, byteorder="little")
            data_written = f.read(data_len)

            # Validate the shuffled index
            written_index = pickle.loads(f.read())

            # Extract substrings from the data using written_index
            extracted_substrings = [data_written[start : start + length] for start, length in written_index]

            # Verify that these substrings match the original defined ones
            original_substrings = [data[start : start + length] for start, length in index]

            # Ensure that extracted substrings are a valid permutation of original substrings
            assert sorted(extracted_substrings) == sorted(original_substrings)
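One design note implied by the permutation check above: `shuffle_tokenized_data` relies on the module-level `random.shuffle`, so the document order differs between runs, which is why the test compares sorted document sets rather than a fixed order. A minimal sketch of how a caller could pin the global seed for reproducible output; the seed value and path are illustrative, not part of the PR.

```python
import random
from pathlib import Path

from modalities.dataloader.shuffle_tokenized_data import shuffle_tokenized_data

# Seeding the global RNG before the call makes the shuffle reproducible,
# because shuffle_tokenized_data uses random.shuffle on the index.
random.seed(42)  # illustrative seed
shuffle_tokenized_data(Path("data/train.pbin"), batch_size=100)  # placeholder path
```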