I'm just opening a discussion as I can't seem to get this to work using PyTorch without using an insane amount of VRAM for no apparent reason.
I have got it working using the Triton kernels, though:
```python
tl.store(logits_ptr + col_offsets, dloss * y, mask=mask)

# Zero out the gradient for all the "special" Cohere tokeniser tokens
# NOTE: This should freeze the probability of generating these regardless of our fine-tuning dataset
# SEE: https://huggingface.co/CohereForAI/c4ai-command-r-v01/blob/main/tokenizer.json
# SEE: https://huggingface.co/spaces/Xenova/the-tokenizer-playground
zero_mask = ((col_offsets <= 7) | (col_offsets >= 255000)) & (col_offsets < VOCAB_SIZE)
tl.store(logits_ptr + col_offsets, 0.0, mask=zero_mask)
```
or:
```python
# Zero out the gradient for all the "special" Cohere tokeniser tokens
# NOTE: This should freeze the probability of generating these regardless of our fine-tuning dataset
# SEE: https://huggingface.co/CohereForAI/c4ai-command-r-v01/blob/main/tokenizer.json
# SEE: https://huggingface.co/spaces/Xenova/the-tokenizer-playground
y = tl.where(
    (col_offsets <= 7) | (col_offsets >= 255000),
    0.0,  # Set gradient to zero for special tokens
    y,    # Keep existing gradient calculation otherwise
)

# If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0.
if DO_LOGIT_SCALING:
    # d/dx [s * x] = s
    y = LOGIT_SCALE * y
    pass
tl.store(logits_ptr + col_offsets, dloss * y, mask=mask)
```
Both methods avoid hurting the special tokens for the Cohere models specifically, and both use approximately the same small amount of extra VRAM for the masks (which are 4096 elements long due to the chunking).
It isn't reflected in the loss, but this is equivalent to assuming that the model's output for the special tokens is always correct; in essence, it stops the slow downward drift of these tokens when they aren't properly represented in the fine-tuning data we are using (e.g. the "Instruct-Storywriter" method, no tool use, etc.).
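To spell that out (ignoring the logit-scaling factor), the per-logit gradient of softmax cross-entropy is

$$\frac{\partial \mathcal{L}}{\partial x_i} = p_i - y_i,$$

so forcing that entry to zero for the special tokens is the same as pretending $p_i = y_i$ there, i.e. treating the model's current prediction for those tokens as already correct.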
If this weren't such a painful hack, you could also use it to offset certain tokens that you do or don't want to appear more or less often, similar to how logit bias is used at inference time (except you would offset in the opposite direction, like the "focal loss star" method does).
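For example (purely illustrative, not something I'd actually ship), the same `tl.where` trick could apply a constant gradient offset instead of zeroing; `GRAD_OFFSET` here is a made-up constant, not a real kernel parameter:

```python
# Hypothetical variant: nudge the selected token ids instead of freezing them.
# Because gradient descent subtracts the gradient, adding a positive
# GRAD_OFFSET pushes those logits down over training, i.e. the opposite
# sign convention to a positive logit bias applied at inference time.
y = tl.where(
    (col_offsets <= 7) | (col_offsets >= 255000),
    y + GRAD_OFFSET,  # constant extra pressure on the selected tokens
    y,                # leave every other token's gradient alone
)
```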
I tried to write a "chunked" version similar to your `entropy_fn` code, called inside of `with torch.no_grad():`, but it still uses ~10 GB more VRAM for whatever reason.
I can't seem to figure it out, and it makes me wonder whether the `logit_scale` parameter in `LmHeadPipe` might be causing a similarly huge chunk of VRAM to be used?
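For reference, this is roughly the shape of the pure-PyTorch version I was aiming for, written here as a gradient hook on the logits rather than my exact chunked `torch.no_grad()` code (a minimal sketch; the function name is just for illustration):

```python
import torch

def zero_special_token_grads(logits: torch.Tensor) -> torch.Tensor:
    """Zero the gradient flowing into the special Cohere token columns.

    `logits` is the [batch, seq, vocab] tensor produced by the LM head,
    before the cross-entropy loss is computed.
    """
    vocab_size = logits.shape[-1]
    col = torch.arange(vocab_size, device=logits.device)
    special = (col <= 7) | (col >= 255000)  # 1-D bool mask over the vocab

    def hook(grad: torch.Tensor) -> torch.Tensor:
        # Broadcasts the [vocab] mask over [batch, seq, vocab] and zeroes
        # the special-token columns of the incoming gradient.
        return grad.masked_fill(special, 0.0)

    logits.register_hook(hook)
    return logits
```

Usage would just be `logits = zero_special_token_grads(logits)` right before computing the loss; only the small boolean mask is created up front, so I wouldn't expect it to cost anywhere near the extra VRAM I'm seeing.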