prepare kbit on 4bit
Alessandro Sordoni committed Aug 3, 2024
1 parent 778ca09 commit 40d974b
Showing 1 changed file with 4 additions and 3 deletions.
mttl/models/expert_model.py (7 changes: 4 additions & 3 deletions)
@@ -59,15 +59,15 @@ def __init__(self, **kwargs):
             attn_implementation=getattr(self.hparams, "attn_implementation", None),
         )
 
-        if self.load_in_8bit:
-            model_object = prepare_model_for_kbit_training(model_object)
-
         # rebuild the training config, a bit cumbersome, but that's life
         self.training_config = ExpertConfig.fromdict(kwargs)
         self.training_config.vocab_size = (
             model_object.get_input_embeddings().num_embeddings
         )
 
+        if self.load_in_8bit or self.load_in_4bit:
+            model_object = prepare_model_for_kbit_training(model_object)
+
         # init the transformer just with the modifier config, this avoids
         # passing the whole training config to the modify_transformer func
         self.modifier_config = ModifierConfig.from_training_config(self.training_config)
@@ -83,6 +83,7 @@ def forward(self, batch, reduction="mean"):
         input_ids = batch["input_ids"]
         labels = batch["labels"]
 
+        print(input_ids.shape[-1])
         outputs = self.model.forward(input_ids, attention_mask=batch["attention_mask"])
 
         # calculate loss, could also be done inside of the model
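Note on the first hunk: prepare_model_for_kbit_training was previously applied only when the model was loaded in 8-bit; the new condition also covers 4-bit loading. Below is a minimal standalone sketch (not mttl code) of what that path enables, assuming the helper is PEFT's prepare_model_for_kbit_training and using a placeholder checkpoint name and quantization settings.

    # Sketch only: load a base model in 4-bit and prepare it for k-bit training.
    # The checkpoint name and quantization options are illustrative assumptions.
    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig
    from peft import prepare_model_for_kbit_training

    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",  # placeholder checkpoint
        quantization_config=quant_config,
        device_map="auto",
    )

    # Freezes the quantized base weights, upcasts remaining fp16/bf16 params
    # (e.g. layer norms) to fp32, and enables gradient checkpointing with
    # input gradients, so adapters can be trained on top of the 4-bit base.
    model = prepare_model_for_kbit_training(model)

With the guard extended to load_in_4bit, a 4-bit-quantized backbone gets the same preparation before any modifier/adapter modules are attached to it.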
