How to train a layer before the first layer of existing models (e.g. gemma) using torchtune? #2766
-
In our case, we want to train a layer before the first layer of an existing LLM (e.g. Gemma) using torchtune. The newly added layer is used to process images. What can we do to implement this?
Replies: 1 comment
-
So you want to add an image encoder layer? The standard way to do this is to create an Early Fusion model. You can define your custom model with:

```python
def my_mm_model(...):
    decoder = existingLLM(...)
    encoder = nn.Conv2d(3, embed_dim, ...)
    return EarlyFusionModel(
        decoder,
        {"image": encoder},
        encoder_tokens={"image": 128256},
        decoder_trainable=True,
        encoder_trainable=True,
    )
```
```python
class MyTransform(MyModelTokenizer, Transform):
    def __init__(self, ...):
        super().__init__(...)
        self.image_transform = MyImageTransform(...)

    def __call__(self, sample, inference=False):
        images = []
        for message in sample["messages"]:
            for image in message.get_media():
                image = self.image_transform(image, inference)
                images.append(image)
        sample["encoder_input"] = {"images": images}
        sample = self.tokenizer(sample, inference=inference)
        return sample
```

You can see real examples of our Transforms here and here. Currently our recipes expect the checkpoint to match the full model definition, so before you can train with this model, you'll want to create a new checkpoint:

```python
>>> model = my_mm_model(...)
>>> state_dict = torch.load("path/to/checkpoint", ...)
>>> model.decoder.load_state_dict(state_dict)
>>> torch.save(model.state_dict(), "path/to/mm/checkpoint")
```

Then just use a multimodal dataset and your custom model and tokenizer with the standard padded_collate_sft in your config and you should be good to go. If you want to do LoRA, you'll have to define another model with the LoRA version of the decoder and decide whether you want a LoRA version of the encoder. You can see an example here.

If you want to do something more custom and unique, you'd have to take the model builder and insert your custom layer in the TransformerDecoder layers list. It can use the encoder_input option. This may break some optimization features, so use at your own risk.
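As a note on why the checkpoint conversion works: when a wrapper module stores the original decoder as an attribute named `decoder` and registers no custom state-dict hooks, PyTorch prefixes each of the decoder's parameter keys with `decoder.` in the wrapper's state dict. Under that assumption, a decoder-only checkpoint could also be converted by remapping keys directly, without instantiating the model. This is a hypothetical sketch (key names illustrative; torchtune's fusion models may register hooks that change the actual layout):

```python
def remap_decoder_checkpoint(decoder_sd):
    """Prefix every key of a decoder-only state dict so it matches a wrapper
    module that holds the decoder under the attribute name 'decoder'.
    Assumes no state-dict hooks alter the key layout."""
    return {f"decoder.{k}": v for k, v in decoder_sd.items()}
```

This avoids loading the full model into memory just to re-save its state dict, at the cost of hard-coding knowledge of the wrapper's attribute names.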