
Commit ecd1918: clean history
1 parent 4db6271

49 files changed, +3307 -224 lines

Dockerfile

Lines changed: 3 additions & 2 deletions
@@ -29,16 +29,17 @@ ENV PIP_CONSTRAINT=""
 # There is no pre-build mamba image for pytorch 2.8, we build it before the rest to avoid rebuilds.
 # We need to compile from the repo because of https://github.com/state-spaces/mamba/issues/720 (same for causal-conv1d)
 # We set the number of workers to avoid OOM when compiling on laptop. (TODO: Can we make it configurable?)
+# Using varlen_mamba for variable length sequence support
 RUN MAX_JOBS=2 pip install --no-build-isolation "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1"
-RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@4a8a2a2"
+RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/jxiw/varlen_mamba@varlen_mamba"
 # Copy dependency files with universal write permissions for all users.
 COPY --chmod=777 setup.py setup.cfg pyproject.toml ./
 COPY --chmod=777 ./fast_llm_external_models/__init__.py fast_llm_external_models/
 COPY --chmod=777 ./fast_llm/__init__.py fast_llm/
 COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/

 # Install dependencies within the virtual environment.
-RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,GENERATION,DEV]" triton==3.1.0
+RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV]" triton==3.1.0

 # Copy the remaining source code with universal write permissions.
 COPY --chmod=777 ./Megatron-LM Megatron-LM
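A quick sanity check after building the image is to import the pinned packages; this is a minimal sketch, not part of the commit, assuming the varlen_mamba fork installs under the standard mamba_ssm package name:

# Hypothetical post-build check, run inside the container.
import mamba_ssm
from causal_conv1d import causal_conv1d_fn  # public entry point of causal-conv1d

print("mamba_ssm", mamba_ssm.__version__)
print("causal_conv1d_fn available:", callable(causal_conv1d_fn))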

fast_llm/data/data/gpt/data.py

Lines changed: 18 additions & 0 deletions
@@ -32,6 +32,8 @@ class GPTBatch:
     token_ids: torch.Tensor
     loss_masking_spans: list[torch.Tensor] | None = None
     sequence_lengths: list[torch.Tensor] | None = None
+    images: list[torch.Tensor] | None = None
+    image_positions: list[torch.Tensor] | None = None
     chosen_spans: list[torch.Tensor] | None = None
     rejected_spans: list[torch.Tensor] | None = None

@@ -49,12 +51,28 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling
         stacked_rejected_spans = [torch.from_numpy(sample.rejected_span) for sample in batch]
     if not sampling_parameters.cross_document_attention:
         sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch]
+    has_images = False
+    batch_images = []
+    for sample in batch:
+        if sample.images is not None:
+            batch_images.append([torch.from_numpy(image) for image in sample.images])
+            has_images = True
+        else:
+            batch_images.append([])
+    batch_image_positions = []
+    for sample in batch:
+        if sample.image_positions is not None:
+            batch_image_positions.append(torch.from_numpy(sample.image_positions))
+        else:
+            batch_image_positions.append([])
     return GPTBatch(
         token_ids=torch.from_numpy(stacked_ids),
         loss_masking_spans=stacked_spans,
         sequence_lengths=sequence_lengths,
         chosen_spans=stacked_chosen_spans,
         rejected_spans=stacked_rejected_spans,
+        images=batch_images if has_images else None,
+        image_positions=batch_image_positions if has_images else None,
     )
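The collate keeps images as ragged per-sample lists rather than stacking them, since documents carry varying numbers of images, and only materializes the new fields when at least one sample has images. A minimal sketch of that contract, using a stand-in sample type (GPTSample's real fields live elsewhere in the repo):

from dataclasses import dataclass

import numpy as np
import torch


@dataclass
class FakeSample:  # stand-in for GPTSample, for illustration only
    images: list[np.ndarray] | None
    image_positions: np.ndarray | None


batch = [
    FakeSample(images=[np.zeros((3, 32, 32), dtype=np.uint8)], image_positions=np.array([5])),
    FakeSample(images=None, image_positions=None),
]

# Mirrors the collate logic: tensors for image-bearing samples, empty lists otherwise.
has_images = any(sample.images is not None for sample in batch)
batch_images = [
    [torch.from_numpy(image) for image in sample.images] if sample.images is not None else []
    for sample in batch
]
assert has_images and batch_images[1] == []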

fast_llm/data/dataset/gpt/config.py

Lines changed: 12 additions & 1 deletion
@@ -72,6 +72,10 @@ class GPTSamplingParameters(SamplingParameters):
     use_preference_loss_spans: bool = False
     cross_document_attention: bool = True
     truncate_documents: bool = True
+    patch_size: int | None = None
+    max_image_size: int | None = None
+    image_break_token: int | None = None
+    image_end_token: int | None = None
     # How many extra tokens to add to the sequence length.
     # This is used to provide labels even for the last tokens in the sequence.
     extra_tokens: int = 1

@@ -138,11 +142,18 @@ class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig):
         desc="Expected number of tokens in the dataset.",
         hint=FieldHint.optional,
     )
+    num_pixels: int | None = Field(
+        default=None,
+        desc="Expected number of pixels in the dataset.",
+        hint=FieldHint.optional,
+    )

     def build(self) -> "GPTMemmapDataset":
         from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset

-        return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens)
+        return GPTMemmapDataset(
+            str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens, self.num_pixels
+        )


 @config_class(dynamic_type={GPTSampledDatasetConfig: "concatenated"})
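With the new field, a memmap dataset config can declare its expected pixel count alongside documents and tokens, and build() forwards it to GPTMemmapDataset. A hypothetical usage sketch (the path and counts are invented; construction details follow Fast-LLM's config conventions):

# Hypothetical sketch: all values are invented for illustration.
config = GPTMemmapDatasetConfig(
    path="/data/my_dataset/shard_0",
    num_documents=10_000,
    num_tokens=5_000_000,
    num_pixels=250_000_000,  # new optional field from this commit
)
dataset = config.build()  # now passes num_pixels through to GPTMemmapDataset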

fast_llm/data/dataset/gpt/fim.py

Lines changed: 3 additions & 3 deletions
@@ -158,9 +158,9 @@ def _fim_permute_sequence(
         middle = contents[boundaries[0] : boundaries[1]]
         suffix = contents[boundaries[1] :]

-        prefix = np.array([*self._tokenizer.tokenize(prefix, end=False)], dtype=np.int64)
-        middle = np.array([*self._tokenizer.tokenize(middle, begin=False, end=False)], dtype=np.int64)
-        suffix = np.array([*self._tokenizer.tokenize(suffix, begin=False)], dtype=np.int64)
+        prefix = np.array([*self._tokenizer._tokenize(prefix, end=False)], dtype=np.int64)
+        middle = np.array([*self._tokenizer._tokenize(middle, begin=False, end=False)], dtype=np.int64)
+        suffix = np.array([*self._tokenizer._tokenize(suffix, begin=False)], dtype=np.int64)

         # here we truncate each given segment to fit the same length as it was before
         # A consequence is that we never reach the end of a file?
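The switch to the private _tokenize keeps FIM's convention that only the first segment carries BOS and only the last carries EOS, so the permuted segments concatenate into one well-formed sequence. A toy illustration of the begin/end flags (hypothetical tokenizer, not Fast-LLM's actual Tokenizer):

# Toy model of the begin/end flags used above.
def toy_tokenize(text: str, begin: bool = True, end: bool = True) -> list[str]:
    return (["<bos>"] if begin else []) + list(text) + (["<eos>"] if end else [])


prefix = toy_tokenize("ab", end=False)               # ['<bos>', 'a', 'b']
middle = toy_tokenize("cd", begin=False, end=False)  # ['c', 'd']
suffix = toy_tokenize("ef", begin=False)             # ['e', 'f', '<eos>']
assert prefix + middle + suffix == ["<bos>", "a", "b", "c", "d", "e", "f", "<eos>"]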

fast_llm/data/dataset/gpt/indexed.py

Lines changed: 24 additions & 2 deletions
@@ -30,6 +30,14 @@ def sample(self, sampling: GPTSamplingData) -> "GPTSampledIndexedDataset":

         return GPTSampledIndexedDataset(self, sampling)

+    @property
+    @abc.abstractmethod
+    def has_images(self) -> bool:
+        """
+        Whether the dataset contains images.
+        This is used to determine whether to use image-related fields in the sampled data.
+        """
+

 class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[IndexedDatasetType], GPTIndexedDataset):
     """

@@ -40,11 +48,16 @@ class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[Indexe

     def get_document_sizes(self) -> np.ndarray:
         # TODO: This can be really big.
-        return self._dataset.get_document_sizes()[self._begin : self._end]
+        doc_sizes, im_sizes = self._dataset.get_document_sizes()
+        return doc_sizes[self._begin : self._end], im_sizes[self._begin : self._end] if im_sizes else np.array([])

     def get_document_size(self, index: int) -> int:
         return self._dataset.get_document_size(self._begin + index)

+    @property
+    def has_images(self) -> bool:
+        return self._dataset.has_images
+

 class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset](
     ConcatenatedDataset[IndexedDatasetType], GPTIndexedDataset

@@ -53,8 +66,17 @@ class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset](

     def get_document_sizes(self) -> np.ndarray:
         # TODO: This can be really big.
-        return np.concatenate([dataset.get_document_sizes() for dataset in self._datasets])
+        # return np.concatenate([dataset.get_document_sizes() for dataset in self._datasets])
+        sizes = [dataset.get_document_sizes() for dataset in self._datasets]
+        return (
+            np.concatenate([size[0] for size in sizes]),
+            np.concatenate([size[1] for size in sizes]) if sizes[0][1] is not None else np.array([]),
+        )

     def get_document_size(self, index: int) -> int:
         dataset = np.searchsorted(self._dataset_splits[1:], index, side="right")
         return self._datasets[dataset].get_document_size(index - self._dataset_splits[dataset].item())
+
+    @property
+    def has_images(self) -> bool:
+        return any(dataset.has_images for dataset in self._datasets)
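Since get_document_sizes now returns a pair (token sizes, image sizes) rather than a single array, callers must unpack both. A minimal caller-side sketch (hypothetical helper, assuming the tuple contract above):

import numpy as np


def total_sizes(dataset) -> tuple[int, int]:
    # `dataset` is any GPTIndexedDataset; since this commit the method returns a pair.
    doc_sizes, image_sizes = dataset.get_document_sizes()
    # image_sizes is an empty array for text-only datasets.
    return int(doc_sizes.sum()), int(image_sizes.sum()) if image_sizes.size else 0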
