
Conversation

jlamypoirier (Collaborator)

✨ Description

Workspace for dealing with the merge. Not intended to work yet.

# Using varlen_mamba for variable length sequence support
RUN MAX_JOBS=2 pip install --no-build-isolation "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1"
RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@4a8a2a2"
RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/jxiw/varlen_mamba@varlen_mamba"
Collaborator Author

Ignoring varlen mamba for now

Collaborator

Ok, but it's mission-critical.


_ACTIVATION_FN_MAP = {
ActivationType.gelu: lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
ActivationType.gelu: torch.nn.functional.gelu,
Collaborator Author

We can't change this

Collaborator

Can we then add another one instead?

Collaborator Author

I think that's possible, but is it absolutely needed? The two are usually safe to swap, so if it's just to convert HF models I'd rather just convert everything to tanh gelu.
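If a separate entry does turn out to be needed, a minimal standalone sketch could look like the following (the stand-in enum and the extra gelu_exact member are hypothetical, not an existing ActivationType):

import enum

import torch


class ActivationType(str, enum.Enum):
    # Stand-in for the project's enum, extended with a hypothetical member.
    gelu = "gelu"
    gelu_exact = "gelu_exact"


_ACTIVATION_FN_MAP = {
    # Existing entry stays on the tanh approximation.
    ActivationType.gelu: lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
    # New entry added instead of changing the existing one.
    ActivationType.gelu_exact: torch.nn.functional.gelu,
}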

@@ -0,0 +1,55 @@
import typing
Collaborator Author

This is just an MLP, let's make it one.

Collaborator

Only useful if this gives better speed because of fused kernels. Otherwise not worth it because we don't need to make this thing overly flexible.

Collaborator Author

The main benefit is that we get the implementation for free, with no new code needed. Flexibility is just a side effect.



@config_class()
class ImageNormalizationConfig(Config):
Collaborator Author

Belongs in dataset preprocessing.

@@ -0,0 +1,281 @@
import math
Collaborator Author

Moving most of this to dataset preprocessing

use_loss_masking_spans=self._parameters.use_loss_masking_spans,
)
token_ids.append(sample.token_ids)
start_pos = 0
Collaborator Author

All this mess is there so we know how many tokens the images will take, because preprocessing is done in the model. Moving to dataset preprocessing (before sampling) to simplify.
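As a rough sketch of the move, the token footprint of each image could be computed at preprocessing time with something like this (the helper name and the fixed patch size are assumptions, not the actual Fast-LLM API):

import math


def image_token_count(height: int, width: int, patch_size: int = 16) -> int:
    # Hypothetical helper: number of patch tokens an image of this size will occupy,
    # so sampling can account for it without running model-side preprocessing.
    return math.ceil(height / patch_size) * math.ceil(width / patch_size)


# Example: a 336x448 image with 16x16 patches contributes 21 * 28 = 588 tokens.
assert image_token_count(336, 448) == 588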

Collaborator

As discussed.
Watch your attitude please, though.

Collaborator Author

Sorry, commented for self-reference here


batch_data = self._distributed_config.get_distributed_dim(DistributedDimNames.batch_data)
batch_dim = TensorDim(BlockDimNames.batch, micro_batch_size * batch_data.size, batch_data)
if self._config.vision_encoder.enabled:
Collaborator Author

Unnecessary complexity, not needed once we move preprocessing

from fast_llm.engine.inference.runner import InferenceRunner
from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
from fast_llm.layers.attention.config import AttentionKwargs
from fast_llm.layers.attention.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor
Collaborator Author

Too many conflicts with recent changes, redoing entirely. Moving to a separate model.

Collaborator

Ok

kv_channels = "vision_kv_channels"


class VisionEncoderKwargs:
Collaborator Author

Unnecessary complexity. We can get those from the configs

Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER, msg=f"File: {stream.name}")
self._version = struct.unpack("<Q", stream.read(8))[0]
assert self._version in [1, 2, 3], f"Unsupported version for gpt_memmap dataset: {self._version}."
assert self._version in [1, 2, 3, 4], f"Unsupported version for gpt_memmap dataset: {self._version}."
Collaborator Author

This is not sustainable. Switching to a JSON header.
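For illustration, a rough sketch of what a JSON header could look like (field names, the magic bytes, and the layout are illustrative assumptions, not the final format):

import json
import struct

header = {
    "version": 4,
    "dtype": "int32",
    "num_documents": 1234,
    "has_spans": True,
    "has_images": False,
}
encoded = json.dumps(header).encode("utf-8")

with open("dataset.idx", "wb") as stream:
    stream.write(b"fast_llm\x00")                  # placeholder magic, not the real MEMMAP_INDEX_HEADER value
    stream.write(struct.pack("<Q", len(encoded)))  # header length, so readers can parse or skip it
    stream.write(encoded)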

Collaborator

Ok

offset=rejected_span_offset + idx * 2 * np.dtype(np.int32).itemsize,
)
)
offset += np.array(self._chosen_spans).nbytes + np.array(self._rejected_spans).nbytes
Collaborator Author (jlamypoirier, Sep 26, 2025)

This is inefficient. We can just store the image count cumsums instead to know which images belong to each sample. (Same for spans)
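A small standalone sketch of the cumsum idea (names made up): store one cumulative count per document and recover each document's image slice with two lookups. The same trick works for spans.

import numpy as np

# Images per document, known at prepare time.
image_counts = np.array([2, 0, 3, 1])

# Stored once in the index: cumulative counts with a leading zero.
image_count_cumsum = np.concatenate(([0], np.cumsum(image_counts)))


def image_range(document_index: int) -> tuple[int, int]:
    # Half-open range [start, end) into the flat, concatenated image array.
    return int(image_count_cumsum[document_index]), int(image_count_cumsum[document_index + 1])


assert image_range(2) == (2, 5)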

Collaborator

HDF5 could be an option...

Collaborator Author

That's something to consider, but it wouldn't be that beneficial right now, because nearly all the complexity is in the data processing and preparation (which we need either way) rather than in the actual file content.

loss = per_sample_loss.mean()
if target_format != TargetFormat.labels and group is not None:
all_reduce(loss, op=ReduceOp.MEAN, group=group)
all_reduce(loss, op=ReduceOp.SUM, group=group)
Collaborator Author

Why?


from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames
from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward
from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl, TargetFormat, TritonConfig
from fast_llm.functional.config import (
Collaborator Author

Ignoring changes to the model head and reverse KL for now

Collaborator

ok. cc @oleksost

input_, output = grad_context
output.backward(output_grad)
return input_.grad
return input_.grad if input_.grad is not None else torch.zeros_like(input_)
Collaborator Author

Bad idea. This adds overhead to the first layer for all models

Collaborator

It solved a problem, though. It was the most pragmatic thing to do for Soham.

):
scaled_target = target / teacher_softmax_temperature

scaled_target = torch.clamp(target, min=-50, max=50)
Collaborator Author

Why?


@@ -0,0 +1,183 @@
import typing
Collaborator Author

A much simpler solution that can go directly in LanguageModelEmbedding and generalizes to other multimodal models:

  • Move token ids to kwargs and replace the input with the image or other embeddings. Same as here, but do it in the base class too. Use a placeholder or None for the non-existent LLM input.
  • Use a pre-built map (tensors) from input_ (image embeddings) to the LM embeddings, and copy the image/multimodal embeddings with a one-liner (see the sketch below).
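For illustration only, a minimal standalone sketch of the pre-built-map idea (plain tensors, not the actual LanguageModelEmbedding code; the negative-id placeholder convention and the shapes are assumptions):

import torch

hidden_size = 8
vocab = torch.randn(100, hidden_size)           # stand-in embedding table
tokens = torch.tensor([5, -1, -1, 7, 2, -1])    # -1 marks positions to be filled from image embeddings
image_embeddings = torch.randn(3, hidden_size)  # one row per image patch token, in document order

embeddings = torch.zeros(tokens.shape[0], hidden_size)
token_mask = tokens >= 0
embeddings[token_mask] = torch.nn.functional.embedding(tokens[token_mask], vocab)

# The "one-liner": scatter the multimodal embeddings into the placeholder positions.
embeddings[~token_mask] = image_embeddings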

Collaborator

These are all inputs; it's a bit limiting to have to declare one of them the "official" input and make the others kwargs. Btw, the generic case is multiple embeddings (images, audio, etc.) and token id maps, not just one.

Collaborator Author

Having a single input is a restriction from the engine, so there's not much we can do ATM. The new version (#369) is generic in the sense that it copies any kind of embeddings from previous layers into the token ids, regardless of their source. It still doesn't combine inputs if they come from multiple previous layers, but that will be relatively straightforward to do when we need it.

# Move to the next image in the input tensor
image_embedding_offset += num_patches

if self._use_absolute_position_embeddings:
Collaborator Author

Shouldn't the position embeddings be masked?

if self._use_absolute_position_embeddings:
position_ids = split(position_ids, group=group, dim=0)
# mask padded tokens
token_mask = tokens >= 0
Collaborator Author

Is there a one-to-one correspondence between masked tokens and those replaced by image embeddings? If so, this seems a bit redundant, and there are better solutions...

Collaborator

In general, no: masked tokens could be padding, too.

Collaborator Author

Do we need explicit masking for those, though? Padding tokens are already masked for attention/SSM (through varlen) and for the loss (loss mask), so they don't contribute either way...

image_embedding_offset += num_patches
if image_embedding_offset > patch_end_offset:
break
embeddings = reduce_forward(embeddings, group)
Collaborator Author

I think this is very wrong. The image embeddings are not vocab-parallel, so this incorrectly multiplies the image embeddings by the TP size.

Collaborator Author

Actually, this might be correct in a really complicated way that mixes the gathering of the sequence-parallel vision encoder outputs with their mapping to the LM embeddings. I'm not sure about the non-sequence-parallel case. Either way, this needs simplification.

patch_position_ids = torch.cat(patch_position_ids)
kwargs[VisionEncoderKwargs.image_patches] = patches
kwargs[VisionTransformerKwargs.patch_position_ids] = patch_position_ids
kwargs[VisionEncoderKwargs.rotary_inv_freq] = create_inv_freqs(
Collaborator Author

Unused?
