
Update doc and notes for BMTrain. #192

Merged
merged 1 commit on Jun 11, 2024
Changes from all commits
315 changes: 207 additions & 108 deletions bmtrain/block_layer.py

Large diffs are not rendered by default.

32 changes: 23 additions & 9 deletions bmtrain/hook_func.py
@@ -2,58 +2,70 @@
from .global_var import config
from .zero_context import ZeroContext


def zero_pre_forward(module, inputs):
"""Helper function for using ZeroContext to gather parmas before forward."""
enter = True
pipe = False
if module._mode == "PIPE":
enter = module._micro_idx == 0
pipe = True
if enter:
        zero_level = module._zero_level
forward_flag = 1 if zero_level == 2 else 0
if zero_level == 2 and not module._need_release:
            forward_flag = 2  # repeating forward in same layer
        if module.all_param_no_grad:  # only forward
forward_flag = 0
module._forward_block_ctx = ZeroContext(module, module._layer_dict, pipe=pipe)
module._forward_block_ctx.enter(forward_flag)


def zero_post_forward(module, inputs, outputs):
"""Helper function for module _forwar_block_ctx weather exits after forward."""
forward_flag = 1 if module._zero_level == 2 else 0
if module.all_param_no_grad:
forward_flag = 0
exit = True
if module._mode == "PIPE":
        exit = module._micro_idx == config["micros"] - 1

if exit:
module._forward_block_ctx.exit(forward_flag)


def zero_pre_backward(module, grad_outputs):
"""Helper function for using ZeroContext to init grad buffer before backward."""
backward_flag = 2 if module._zero_level == 2 else 0
if module._mode != "PIPE":
module._backward_block_ctx = ZeroContext(module, module._layer_dict)
module._backward_block_ctx.enter(backward_flag, True)
module.release_next_module(backward_flag)
else:
        if module._micro_idx == config["micros"] - 1:
            module._backward_block_ctx = ZeroContext(
                module, module._layer_dict, pipe=True
            )
module._backward_block_ctx.enter(backward_flag, True)


def zero_post_backward(module, grad_inputs, grad_outputs):
"""Helper function for module weather release after backward."""
backward_flag = 2 if module._zero_level == 2 else 0
if module._mode != "PIPE":
        if module._is_first_layer:
module.release(backward_flag)
else:
if module._micro_idx == 0:
module.release(backward_flag)
module._micro_idx -= 1


class OneStepNoGradFunc(torch.autograd.Function):
"""
    Used when requires_grad is False for all inputs.
"""

@staticmethod
def forward(ctx, module, placeholder, *x):
ctx.x = x
@@ -80,7 +92,8 @@ def backward(ctx, grads):
grads = []
for _ in x:
grads.append(None)
        return None, None, *grads


class PreHookFunc(torch.autograd.Function):
@staticmethod
@@ -94,6 +107,7 @@ def backward(ctx, *grads):
zero_post_backward(ctx.module, grads, None)
return None, *grads


class PostHookFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, module, *out):
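
For context on how helpers with these signatures are typically wired up: zero_pre_forward(module, inputs) and zero_post_forward(module, inputs, outputs) match the callback signatures of PyTorch's register_forward_pre_hook and register_forward_hook. Below is a minimal, self-contained toy (GatherCtx, toy_pre_forward, and toy_post_forward are made-up stand-ins, not BMTrain code) showing that pairing; the real helpers additionally consult _mode, _zero_level, and _micro_idx and drive a ZeroContext.

import torch
from torch import nn


# Toy stand-ins (not BMTrain classes): a gather/release pair registered with
# PyTorch's standard hook API, mirroring how zero_pre_forward and
# zero_post_forward bracket a module's forward pass.
class GatherCtx:
    def __init__(self, module):
        self.module = module

    def enter(self):
        print(f"gather params for {type(self.module).__name__}")

    def exit(self):
        print(f"release params for {type(self.module).__name__}")


def toy_pre_forward(module, inputs):  # same signature as zero_pre_forward
    module._toy_ctx = GatherCtx(module)
    module._toy_ctx.enter()


def toy_post_forward(module, inputs, outputs):  # same signature as zero_post_forward
    module._toy_ctx.exit()


layer = nn.Linear(4, 4)
layer.register_forward_pre_hook(toy_pre_forward)
layer.register_forward_hook(toy_post_forward)
layer(torch.randn(2, 4))  # prints "gather params ..." then "release params ..."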
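
PreHookFunc and PostHookFunc follow a common pattern: an identity-style torch.autograd.Function whose backward() fires the zero_pre_backward / zero_post_backward helpers from inside the autograd graph rather than from explicit calls in user code. A minimal sketch of that pattern (BackwardHookFunc and its callback are hypothetical, not BMTrain's classes):

import torch


class BackwardHookFunc(torch.autograd.Function):
    """Identity-like function that runs a callback during backward."""

    @staticmethod
    def forward(ctx, callback, *tensors):
        ctx.callback = callback
        return tensors if len(tensors) > 1 else tensors[0]

    @staticmethod
    def backward(ctx, *grads):
        ctx.callback()  # e.g. set up or release ZeRO grad buffers here
        # one gradient slot per forward input: None for the callback, then grads
        return (None, *grads)


x = torch.randn(3, requires_grad=True)
y = BackwardHookFunc.apply(lambda: print("backward hook fired"), x)
y.sum().backward()  # prints "backward hook fired"; x.grad is all ones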