5 changes: 3 additions & 2 deletions Dockerfile
@@ -29,16 +29,17 @@ ENV PIP_CONSTRAINT=""
  # There is no pre-built mamba image for pytorch 2.8, so we build it before the rest to avoid rebuilds.
  # We need to compile from the repo because of https://github.com/state-spaces/mamba/issues/720 (same for causal-conv1d)
  # We set the number of workers to avoid OOM when compiling on laptop. (TODO: Can we make it configurable?)
+ # Using varlen_mamba for variable length sequence support
  RUN MAX_JOBS=2 pip install --no-build-isolation "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1"
- RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@4a8a2a2"
+ RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/jxiw/varlen_mamba@varlen_mamba"
Collaborator (Author): Ignoring varlen mamba for now.

Collaborator: Ok, but it's mission critical.
  # Copy dependency files with universal write permissions for all users.
  COPY --chmod=777 setup.py setup.cfg pyproject.toml ./
  COPY --chmod=777 ./fast_llm_external_models/__init__.py fast_llm_external_models/
  COPY --chmod=777 ./fast_llm/__init__.py fast_llm/
  COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/

  # Install dependencies within the virtual environment.
- RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,GENERATION,DEV]" triton==3.1.0
+ RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV]" triton==3.1.0

  # Copy the remaining source code with universal write permissions.
  COPY --chmod=777 ./Megatron-LM Megatron-LM
18 changes: 18 additions & 0 deletions fast_llm/data/data/gpt/data.py
@@ -32,6 +32,8 @@ class GPTBatch:
      token_ids: torch.Tensor
      loss_masking_spans: list[torch.Tensor] | None = None
      sequence_lengths: list[torch.Tensor] | None = None
+     images: list[list[torch.Tensor]] | None = None
+     image_positions: list[torch.Tensor] | None = None
      chosen_spans: list[torch.Tensor] | None = None
      rejected_spans: list[torch.Tensor] | None = None

@@ -49,12 +51,28 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling
          stacked_rejected_spans = [torch.from_numpy(sample.rejected_span) for sample in batch]
      if not sampling_parameters.cross_document_attention:
          sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch]
+     has_images = False
+     batch_images = []
+     for sample in batch:
+         if sample.images is not None:
+             batch_images.append([torch.from_numpy(image) for image in sample.images])
+             has_images = True
+         else:
+             batch_images.append([])
+     batch_image_positions = []
+     for sample in batch:
+         if sample.image_positions is not None:
+             batch_image_positions.append(torch.from_numpy(sample.image_positions))
+         else:
+             batch_image_positions.append([])
      return GPTBatch(
          token_ids=torch.from_numpy(stacked_ids),
          loss_masking_spans=stacked_spans,
          sequence_lengths=sequence_lengths,
          chosen_spans=stacked_chosen_spans,
          rejected_spans=stacked_rejected_spans,
+         images=batch_images if has_images else None,
+         image_positions=batch_image_positions if has_images else None,
      )
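Because each document can carry a different number of images at different resolutions, the collate function keeps images as per-sample lists of tensors rather than stacking them into one batch tensor. A minimal self-contained sketch of that pattern; `FakeSample` is a hypothetical stand-in for `GPTSample`, not the real class:

```python
from dataclasses import dataclass

import numpy as np
import torch


@dataclass
class FakeSample:  # hypothetical stand-in for GPTSample
    images: list[np.ndarray] | None
    image_positions: np.ndarray | None


def collate_images(batch: list[FakeSample]):
    has_images = False
    batch_images, batch_positions = [], []
    for sample in batch:
        if sample.images is not None:
            # Variable image counts and sizes: keep a list per sample instead of stacking.
            batch_images.append([torch.from_numpy(image) for image in sample.images])
            batch_positions.append(torch.from_numpy(sample.image_positions))
            has_images = True
        else:
            batch_images.append([])
            batch_positions.append([])
    return (batch_images, batch_positions) if has_images else (None, None)


samples = [
    FakeSample(images=[np.zeros((3, 32, 32), dtype=np.float32)], image_positions=np.array([5])),
    FakeSample(images=None, image_positions=None),
]
images, positions = collate_images(samples)
print(len(images[0]), images[1])  # 1 []
```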


13 changes: 12 additions & 1 deletion fast_llm/data/dataset/gpt/config.py
@@ -72,6 +72,10 @@ class GPTSamplingParameters(SamplingParameters):
      use_preference_loss_spans: bool = False
      cross_document_attention: bool = True
      truncate_documents: bool = True
+     patch_size: int | None = None
+     max_image_size: int | None = None
+     image_break_token: int | None = None
+     image_end_token: int | None = None
      # How many extra tokens to add to the sequence length.
      # This is used to provide labels even for the last tokens in the sequence.
      extra_tokens: int = 1
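The new sampling parameters describe how images are folded into the token stream: images are cut into `patch_size` patches, optionally capped at `max_image_size`, with `image_break_token` and `image_end_token` marking row and image boundaries. A hedged sketch of how a per-image token budget could be derived from them; the formula follows common ViT-style patching and is an assumption, not code from this PR:

```python
import math


def image_token_count(
    height: int,
    width: int,
    patch_size: int,
    max_image_size: int | None = None,
    image_break_token: int | None = None,
    image_end_token: int | None = None,
) -> int:
    # Assumption: downscale so the longest side fits within max_image_size.
    if max_image_size is not None and max(height, width) > max_image_size:
        scale = max_image_size / max(height, width)
        height, width = int(height * scale), int(width * scale)
    rows = math.ceil(height / patch_size)
    cols = math.ceil(width / patch_size)
    tokens = rows * cols
    if image_break_token is not None:
        tokens += rows  # assumption: one break token after each row of patches
    if image_end_token is not None:
        tokens += 1  # one end-of-image token
    return tokens


# A 512x768 image with 16-pixel patches: 32*48 patches + 32 breaks + 1 end = 1569.
print(image_token_count(512, 768, 16, max_image_size=1024, image_break_token=0, image_end_token=1))
```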
@@ -138,11 +142,18 @@ class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig):
desc="Expected number of tokens in the dataset.",
hint=FieldHint.optional,
)
num_pixels: int | None = Field(
default=None,
desc="Expected number of pixels in the dataset.",
hint=FieldHint.optional,
)

def build(self) -> "GPTMemmapDataset":
from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset

return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens)
return GPTMemmapDataset(
str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens, self.num_pixels
)


@config_class(dynamic_type={GPTSampledDatasetConfig: "concatenated"})
6 changes: 3 additions & 3 deletions fast_llm/data/dataset/gpt/fim.py
@@ -158,9 +158,9 @@ def _fim_permute_sequence(
          middle = contents[boundaries[0] : boundaries[1]]
          suffix = contents[boundaries[1] :]

-         prefix = np.array([*self._tokenizer.tokenize(prefix, end=False)], dtype=np.int64)
-         middle = np.array([*self._tokenizer.tokenize(middle, begin=False, end=False)], dtype=np.int64)
-         suffix = np.array([*self._tokenizer.tokenize(suffix, begin=False)], dtype=np.int64)
+         prefix = np.array([*self._tokenizer._tokenize(prefix, end=False)], dtype=np.int64)
+         middle = np.array([*self._tokenizer._tokenize(middle, begin=False, end=False)], dtype=np.int64)
+         suffix = np.array([*self._tokenizer._tokenize(suffix, begin=False)], dtype=np.int64)

          # here we truncate each given segment to fit the same length as it was before
          # A consequence is that we never reach the end of a file?
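The FIM path now calls the tokenizer's `_tokenize` helper, whose `begin`/`end` flags control whether BOS/EOS tokens are attached, so special tokens appear only at the outer boundaries of the reassembled sequence. A toy sketch of such a helper; the class and its encoding are hypothetical, inferred from the call sites rather than taken from this repo's tokenizer:

```python
class TinyTokenizer:
    """Hypothetical stand-in illustrating the begin/end flags used above."""

    bos_token_id = 1
    eos_token_id = 2

    def _encode(self, text: str) -> list[int]:
        return [ord(c) for c in text]  # toy encoding, no special tokens

    def _tokenize(self, text: str, begin: bool = True, end: bool = True) -> list[int]:
        tokens = self._encode(text)
        if begin:
            tokens = [self.bos_token_id, *tokens]
        if end:
            tokens = [*tokens, self.eos_token_id]
        return tokens


tok = TinyTokenizer()
# Only the prefix gets BOS and only the suffix gets EOS, matching the FIM calls above.
print(tok._tokenize("ab", end=False), tok._tokenize("cd", begin=False, end=False), tok._tokenize("ef", begin=False))
```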
26 changes: 24 additions & 2 deletions fast_llm/data/dataset/gpt/indexed.py
@@ -30,6 +30,14 @@ def sample(self, sampling: GPTSamplingData) -> "GPTSampledIndexedDataset":

return GPTSampledIndexedDataset(self, sampling)

@property
@abc.abstractmethod
def has_images(self) -> bool:
"""
Whether the dataset contains images.
This is used to determine whether to use image-related fields in the sampled data.
"""


  class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[IndexedDatasetType], GPTIndexedDataset):
      """
@@ -40,11 +48,16 @@ class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[Indexe

      def get_document_sizes(self) -> np.ndarray:
          # TODO: This can be really big.
-         return self._dataset.get_document_sizes()[self._begin : self._end]
+         doc_sizes, im_sizes = self._dataset.get_document_sizes()
+         return doc_sizes[self._begin : self._end], (
+             im_sizes[self._begin : self._end] if im_sizes is not None else np.array([])
+         )

      def get_document_size(self, index: int) -> int:
          return self._dataset.get_document_size(self._begin + index)

+     @property
+     def has_images(self) -> bool:
+         return self._dataset.has_images


  class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset](
      ConcatenatedDataset[IndexedDatasetType], GPTIndexedDataset
@@ -53,8 +66,17 @@ class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset](

      def get_document_sizes(self) -> np.ndarray:
          # TODO: This can be really big.
-         return np.concatenate([dataset.get_document_sizes() for dataset in self._datasets])
+         sizes = [dataset.get_document_sizes() for dataset in self._datasets]
+         return (
+             np.concatenate([size[0] for size in sizes]),
+             np.concatenate([size[1] for size in sizes]) if sizes[0][1] is not None else np.array([]),
+         )

      def get_document_size(self, index: int) -> int:
          dataset = np.searchsorted(self._dataset_splits[1:], index, side="right")
          return self._datasets[dataset].get_document_size(index - self._dataset_splits[dataset].item())

+     @property
+     def has_images(self) -> bool:
+         return any(dataset.has_images for dataset in self._datasets)
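With these changes, `get_document_sizes` returns a `(token_sizes, image_sizes)` pair instead of a single array, and callers can check `has_images` before touching the second element. A hedged sketch of how a sampler might fold both into a per-document length budget; the one-image-per-document layout and the patch-token conversion are assumptions, not this PR's sampling code:

```python
import numpy as np


def total_lengths(token_sizes: np.ndarray, image_sizes: np.ndarray, patch_size: int = 16) -> np.ndarray:
    # token_sizes: (n_docs,) token counts per document.
    # image_sizes: (n_docs, 2) pixel dimensions, or empty when the dataset
    # has no images (mirroring the np.array([]) fallback above).
    if image_sizes.size == 0:
        return token_sizes
    patches = np.ceil(image_sizes / patch_size).prod(axis=1).astype(np.int64)
    return token_sizes + patches


token_sizes = np.array([100, 200])
image_sizes = np.array([[32, 32], [64, 48]])
print(total_lengths(token_sizes, image_sizes))  # [104 212]
```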