[Workspace] Dev branch merge attempt #367
base: main
Conversation
3dbcab5 to ecd1918
# Using varlen_mamba for variable length sequence support
RUN MAX_JOBS=2 pip install --no-build-isolation "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1"
RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@4a8a2a2"
RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/jxiw/varlen_mamba@varlen_mamba"
Ignoring varlen mamba for now
Ok, but it's mission critical
_ACTIVATION_FN_MAP = {
    ActivationType.gelu: lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
    ActivationType.gelu: torch.nn.functional.gelu,
We can't change this
Can we then add another one instead?
I think that's possible, but is it absolutely needed? The two are usually safe to swap, so if it's just to convert HF models I'd rather just convert everything to tanh gelu.
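A minimal sketch of the "add another one" option, assuming a hypothetical `gelu_exact` enum member (the name and the illustrative enum are not from the PR; the real `ActivationType` lives in Fast-LLM's config module):

```python
import enum

import torch


class ActivationType(str, enum.Enum):
    # Illustrative subset only; the real enum is defined in Fast-LLM.
    gelu = "gelu"
    gelu_exact = "gelu_exact"  # hypothetical new member, not in the PR


_ACTIVATION_FN_MAP = {
    # Keep the existing tanh-approximate entry unchanged...
    ActivationType.gelu: lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
    # ...and map the new member to the exact implementation.
    ActivationType.gelu_exact: torch.nn.functional.gelu,
}
```

Whether a separate entry is needed at all is the open question in this thread; the alternative suggested above is to convert everything to tanh gelu when converting HF models.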
@@ -0,0 +1,55 @@
import typing
This is just an MLP, let's make it one.
Only useful if this gives better speed because of fused kernels. Otherwise not worth it because we don't need to make this thing overly flexible.
Main benefit is we get the implementation for free, no new code needed. Flexibility is just a side effect.
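For context, the module under discussion boils down to something like this generic two-layer MLP (a sketch, not the code in the diff):

```python
import torch


class TwoLayerMLP(torch.nn.Module):
    """Generic projector-style MLP: linear -> activation -> linear."""

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int):
        super().__init__()
        self.fc1 = torch.nn.Linear(in_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(torch.nn.functional.gelu(self.fc1(x), approximate="tanh"))
```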
@config_class()
class ImageNormalizationConfig(Config):
Belongs to dataset preprocessing
@@ -0,0 +1,281 @@
import math
Moving most of this to dataset preprocessing
use_loss_masking_spans=self._parameters.use_loss_masking_spans,
)
token_ids.append(sample.token_ids)
start_pos = 0
All this mess is there so we know how many tokens the images will take, because preprocessing is done in the model. Moving to dataset preprocessing (before sampling) to simplify.
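A sketch of what "knowing how many tokens the images will take" could look like at dataset-preprocessing time (the helper name is hypothetical, and model-specific details such as resizing or a class token are omitted):

```python
import math


def image_token_count(height: int, width: int, patch_size: int) -> int:
    # Non-overlapping patch_size x patch_size patches, rounded up per dimension.
    return math.ceil(height / patch_size) * math.ceil(width / patch_size)


# Example: a 336x336 image with 14x14 patches occupies 24 * 24 = 576 tokens.
assert image_token_count(336, 336, 14) == 576
```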
As discussed.
Watch your attitude please, though.
Sorry, commented for self-reference here
batch_data = self._distributed_config.get_distributed_dim(DistributedDimNames.batch_data)
batch_dim = TensorDim(BlockDimNames.batch, micro_batch_size * batch_data.size, batch_data)
if self._config.vision_encoder.enabled:
Unnecessary complexity, not needed once we move preprocessing
from fast_llm.engine.inference.runner import InferenceRunner
from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
from fast_llm.layers.attention.config import AttentionKwargs
from fast_llm.layers.attention.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor
Too many conflicts with recent changes, redoing entirely. Moving to a separate model.
Ok
kv_channels = "vision_kv_channels"


class VisionEncoderKwargs:
Unnecessary complexity. We can get those from the configs
Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER, msg=f"File: {stream.name}")
self._version = struct.unpack("<Q", stream.read(8))[0]
assert self._version in [1, 2, 3], f"Unsupported version for gpt_memmap dataset: {self._version}."
assert self._version in [1, 2, 3, 4], f"Unsupported version for gpt_memmap dataset: {self._version}."
This is not sustainable. Switching to a json header.
Ok
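A rough sketch of the JSON-header idea (the layout and field handling are assumptions, not the format that was actually implemented):

```python
import json
import struct


def write_header(stream, metadata: dict) -> None:
    # Length-prefixed JSON blob instead of a fixed binary version field:
    # new metadata fields can be added without bumping a version number.
    blob = json.dumps(metadata).encode("utf-8")
    stream.write(struct.pack("<Q", len(blob)))
    stream.write(blob)


def read_header(stream) -> dict:
    (length,) = struct.unpack("<Q", stream.read(8))
    return json.loads(stream.read(length).decode("utf-8"))
```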
offset=rejected_span_offset + idx * 2 * np.dtype(np.int32).itemsize,
)
)
offset += np.array(self._chosen_spans).nbytes + np.array(self._rejected_spans).nbytes
This is inefficient. We can just store the image count cumsums instead to know which images belong to each sample. (Same for spans)
hdf could be an option...
That's something to consider, but wouldn't be that beneficial right now because nearly all the complexity is in the data processing and preparation (which we need either way) rather than the actual file content.
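A small sketch of the cumsum suggestion, with assumed variable names: store one image count per document and recover each document's image range from the cumulative sums instead of storing per-image offsets.

```python
import numpy as np

# Per-document image counts, known at prepare time.
image_counts = np.array([2, 0, 3, 1], dtype=np.int64)

# Store only the cumulative sums, with a leading 0 as the first boundary.
image_count_cumsums = np.insert(np.cumsum(image_counts), 0, 0)


def image_range(doc_index: int) -> tuple[int, int]:
    # Half-open range of the document's images in the flat image array.
    return int(image_count_cumsums[doc_index]), int(image_count_cumsums[doc_index + 1])


assert image_range(2) == (2, 5)
```

The same trick applies to spans, as noted above.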
loss = per_sample_loss.mean()
if target_format != TargetFormat.labels and group is not None:
    all_reduce(loss, op=ReduceOp.MEAN, group=group)
    all_reduce(loss, op=ReduceOp.SUM, group=group)
Why?
from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames
from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward
from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl, TargetFormat, TritonConfig
from fast_llm.functional.config import (
Ignoring changes to model head and reverse kl for now
ok. cc @oleksost
input_, output = grad_context
output.backward(output_grad)
return input_.grad
return input_.grad if input_.grad is not None else torch.zeros_like(input_)
Bad idea. This adds overhead to the first layer for all models
Solved a problem though. Was the most pragmatic thing to do for Soham
):
    scaled_target = target / teacher_softmax_temperature

    scaled_target = torch.clamp(target, min=-50, max=50)
Why?
@@ -0,0 +1,183 @@
import typing
Much simpler solution that can go directly in LanguageModelEmbedding and generalizes to other multimodal models:

- Move token ids to kwargs, replace input by the image or other embeddings. Same as here, but do it in the base class too. Use a placeholder or None for the LLM non-existent input.
- Use a pre-built map (tensors) from input_ (image embeddings) to LM embeddings, and copy image/multimodal embeddings with a one-liner.
These are all inputs; it's a bit limiting to have to declare one of them the "official" input and make the others kwargs. By the way, the generic case is multiple embeddings (images, audio, etc.) and token id maps, not just one.
Having a single input is a restriction from the engine, so not much we can do ATM. The new version (#369) is generic in the sense that it copies any kind of embeddings from previous layers into token ids regardless of their source. It's still not combining inputs if they come from multiple previous layers, but that will be relatively straightforward to do when we need it.
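A rough sketch of the pre-built map plus one-liner copy (names, shapes, and layout are assumptions, not the actual code in #369):

```python
import torch

hidden = 64
token_embeddings = torch.zeros(1, 16, hidden)  # (batch, sequence, hidden)
image_embeddings = torch.randn(6, hidden)      # flattened image patch embeddings

# Built during preprocessing: flat positions of the image placeholder tokens.
image_token_positions = torch.tensor([3, 4, 5, 10, 11, 12])

# The "one-liner": copy the image embeddings over the placeholder positions.
token_embeddings.view(-1, hidden)[image_token_positions] = image_embeddings
```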
# Move to the next image in the input tensor
image_embedding_offset += num_patches

if self._use_absolute_position_embeddings:
Position embeddings should be masked?
if self._use_absolute_position_embeddings:
    position_ids = split(position_ids, group=group, dim=0)
# mask padded tokens
token_mask = tokens >= 0
Is there a one-to-one correspondence between masked tokens and those replaced by image embeddings? If so this seems a bit redundant, and there are better solutions...
In general no, masked tokens could be padding, too
Do we need explicit masking for those though? Padding tokens are already masked for attention/ssm (through varlen) and loss (loss mask), so they don't contribute either way...
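A minimal illustration of the loss-mask point (toy numbers, not the actual Fast-LLM loss code): whatever values the padded positions carry, the mask zeroes their contribution.

```python
import torch

per_token_loss = torch.tensor([0.7, 1.2, 0.4, 9.9, 9.9])  # last two are padding
loss_mask = torch.tensor([1.0, 1.0, 1.0, 0.0, 0.0])

# Padded positions are multiplied by 0, so their values never reach the loss.
loss = (per_token_loss * loss_mask).sum() / loss_mask.sum()
assert torch.isclose(loss, torch.tensor((0.7 + 1.2 + 0.4) / 3))
```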
image_embedding_offset += num_patches
if image_embedding_offset > patch_end_offset:
    break
embeddings = reduce_forward(embeddings, group)
I think this is very wrong. The image embeddings are not vocab-parallel, so this incorrectly multiplies the image embeddings by TP.
Actually this might be correct in a really complicated way that mixes the gathering of the sequence-parallel vision encoder outputs with their mapping to the lm embeddings. Not sure about non-sequence-parallel. Either way, this needs simplification.
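A toy illustration of the concern, with plain tensors standing in for tensor-parallel ranks (no distributed code): vocab-parallel token embeddings are nonzero on exactly one rank, so a sum-reduce is correct for them, but an embedding replicated on every rank gets multiplied by the group size.

```python
import torch

tp = 2  # tensor-parallel group size

# Vocab-parallel token embedding: only the rank owning the token's vocab shard
# contributes a nonzero value, so summing across ranks reconstructs it exactly.
token_contributions = [torch.tensor([1.0, 2.0]), torch.tensor([0.0, 0.0])]

# Image embedding replicated identically on every rank (not vocab-parallel).
image_contributions = [torch.tensor([3.0, 4.0]), torch.tensor([3.0, 4.0])]

# all_reduce(SUM) over the group is equivalent to summing the contributions.
print(sum(token_contributions))  # tensor([1., 2.])  -> correct
print(sum(image_contributions))  # tensor([6., 8.])  -> scaled by tp
```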
patch_position_ids = torch.cat(patch_position_ids)
kwargs[VisionEncoderKwargs.image_patches] = patches
kwargs[VisionTransformerKwargs.patch_position_ids] = patch_position_ids
kwargs[VisionEncoderKwargs.rotary_inv_freq] = create_inv_freqs(
Unused?
✨ Description
Workspace for dealing with the merge. Not intended to work yet.