Checkpointing support for transformer type models #247
Conversation
@hariharan-devarajan ready for you to review again. I added two other features since last time we talked:
Almost there.
```python
if self.args.hidden_size <= 0:
    return 0
head_size = self.args.hidden_size // self.args.num_attention_heads
dim_kv = head_size * self.args.num_kv_heads
```
Missed this one. Please spell out the full form of `dim_kv`.
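For context, a minimal self-contained sketch of what this quantity computes; the spelled-out function name and comments are assumptions, not from the PR:

```python
def key_value_projection_dim(hidden_size: int, num_attention_heads: int,
                             num_kv_heads: int) -> int:
    # Each attention head operates on an equal slice of the hidden state.
    head_size = hidden_size // num_attention_heads
    # Under grouped-query attention, num_kv_heads <= num_attention_heads,
    # so the key/value projections can be narrower than hidden_size.
    return head_size * num_kv_heads
```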
```python
mlp_4h_to_h = self.args.ffn_hidden_size * self.args.hidden_size
weight = self.args.hidden_size
lm_head = embedding
return embedding + (input_norm + qkv + dense + layer_norm + mlp_h_to_4h + mlp_4h_to_h) * self.args.num_layers + weight + lm_head
```
What are `qkv` and `mlp_h_to_4h`? If the full form of `mlp_h_to_4h` is too long, then at least add a line comment explaining what it is where it is defined.
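A self-contained sketch of the per-term accounting this review asks to document. The term names mirror the PR; the exact formulas and comments are assumptions based on a standard pre-norm transformer block, not the PR's definitions:

```python
def transformer_parameter_count(hidden_size: int, ffn_hidden_size: int,
                                num_attention_heads: int, num_kv_heads: int,
                                num_layers: int, vocab_size: int) -> int:
    head_size = hidden_size // num_attention_heads
    dim_kv = head_size * num_kv_heads            # total key/value projection width
    embedding = vocab_size * hidden_size         # token embedding table
    input_norm = hidden_size                     # pre-attention norm weight
    # qkv: fused query/key/value projection; the query is hidden_size wide,
    # the key and value are each dim_kv wide under grouped-query attention.
    qkv = hidden_size * (hidden_size + 2 * dim_kv)
    dense = hidden_size * hidden_size            # attention output projection
    layer_norm = hidden_size                     # pre-MLP norm weight
    mlp_h_to_4h = hidden_size * ffn_hidden_size  # MLP up-projection (h -> 4h)
    mlp_4h_to_h = ffn_hidden_size * hidden_size  # MLP down-projection (4h -> h)
    weight = hidden_size                         # final norm weight
    lm_head = embedding                          # output head, tied to the embedding
    per_layer = input_norm + qkv + dense + layer_norm + mlp_h_to_4h + mlp_4h_to_h
    return embedding + per_layer * num_layers + weight + lm_head
```

Calling it with an illustrative shape, e.g. `transformer_parameter_count(4096, 11008, 32, 32, 32, 32000)`, returns roughly 5.3 billion under these assumed formulas.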
```python
def get_layer_parameters(self, layer_index):
    head_size = self.args.hidden_size // self.args.num_attention_heads
    dim_kv = head_size * self.args.num_kv_heads
```
Please use the full form of `dim_kv` here as well.
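A sketch of how a line comment at the definition site could address this; the comment wording is an assumption:

```python
def get_layer_parameters(self, layer_index):
    head_size = self.args.hidden_size // self.args.num_attention_heads
    # dim_kv: total width of the key/value projections
    # (num_kv_heads heads, each of size head_size)
    dim_kv = head_size * self.args.num_kv_heads
```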
Looks good. Thank you for all the changes.
In this PR, we addressed the issue that users previously had to manually enter layer parameters and optimization groups for checkpointing. #248