Offload activation async support #156
Open
MayDomine wants to merge 144 commits into OpenBMB:dev from MayDomine:tp_offload
Changes from 132 commits

Commits (144)
bcea035
using hooks to implement ZeRO and Checkpoint
7b080e7
async backward
be5f9d7
async forward
dea1781
merge upstream
05bc553
fix
bdf7087
save cuda_rng_state
6a366e3
fix
25ef84f
fix
768f209
fix
324e0dd
remove __call__
0f4ddb5
refactor code structure
76c5c26
pipeline
16c0922
for low version
2d35ba0
for low torch version
bc48d83
for checkpoint
bd61071
remove unused code
de25455
remove duplicate code
fde122f
fix pipeline; checkpoint support low version
a897ad4
fix pipeline; checkpoint support low version
ca50795
merge remote
ec8385b
fix indent
9877a81
pipe support low version
28993b5
custom linear for zero3
4d43952
merge origin
e4eaebf
resolve conflict
cba7c55
resolve conflict
839a976
use torch.utils.checkpoint.checkpoint
d5bbf1a
custom hook
e92d0ef
optimize code structure
6ba753e
for hidden_state
b0a0da9
for input.requires_grad is False
f4a0e0b
fix
8faff0f
pipeline support return hidden_state
26c8c94
fix args
b7d1c8c
fix test
4303575
CheckpointBlock -> BMTBlock
8061b66
reset block name
845f210
pipeline support batch_related
0b14fe5
remove use_checkpoint from init_distributed
12e51e1
test
726aa2f
test for transformer and attn
MayDomine ae56de8
for requires_grad
27ae2b7
for requires_grad
fdc8231
fix for arg is not tensor
b0f7154
fix for arg is not a tensor
420b626
add test
b843489
Merge branch 'hook' of https://github.com/zkh2016/BMTrain into hook
ebc269f
merge enhance_ckp
2f1e766
enhance ckp
4336437
Merge branch 'hook' into test
683707d
test
4013502
test
1c532d4
test
1e993c6
refactor code
24d0f59
mv linear to bmt.nn.linear
ff72e66
for enhance_ckp
1fbf3b2
fix for all input not grad
ace5216
fix pre_module
52cd4e2
fix pre_module
0b0bd0b
fix for all input no grad
05b49f8
fix for all input no grad
98d5b32
activation offloading
MayDomine bd42ee4
Merge branch 'main' of https://github.com/OpenBMB/BMTrain into test
MayDomine c16127a
offload new version
MayDomine 64eb672
Merge branch 'main' of https://github.com/OpenBMB/BMTrain into hook
4861ec8
save_for_backward hook
MayDomine fc81971
offloading bug fix
MayDomine 88b5bd3
fix reentrant
9c2e47d
Merge branch 'hook' of https://github.com/zkh2016/BMTrain into hook
e93e6dc
Merge branch 'dev' into hook
fd49311
refactor CheckpointBlock
221bdc3
refactor pipe
76f74e5
Merge branch 'hook' of https://github.com/zkh2016/BMTrain into hook
9c63407
fix all input no grad
f72fcfc
fix hiddenstate
ebdf519
fix test
780ca20
fix
6df85e7
remove unused import
bb482d6
fix pre_module
1010d26
recovery some code
b580530
add test_no_grad.py
767a875
test unroll block list
d19a627
fix test_fp32
bf986a7
cross_entropy support fp32
b28cb3f
offload context
MayDomine 5e24661
Merge branch 'hook' into test
MayDomine f94afa2
cpm live for offloading test
MayDomine bc65a2e
Better hack for offload
MayDomine 76f8162
fix OFFLOAD _mode bug
MayDomine 0d4ea37
fix is_first_layer
6ffcf5c
Fix async bug
MayDomine 3063afb
tensor parallel
bdc1ed9
Merge branch 'fix_first_layer' into tensor_parallel
8648f5b
rm unused code
763b408
refactor nccl group; remove partition_modules in pipe_layer.py
4c50567
fix by review comment
825139c
fix topology
f08bc83
offload event wait
MayDomine 82e975c
Merge branch 'dev' of https://github.com/OpenBMB/BMTrain into offload
MayDomine 4ff0f41
fix topology
a5d7ba6
fix
2951d70
use ParallelEmbedding
2f4ca8a
Offload Correct Version
MayDomine 39319e1
overlap parallel linear backward
df3fd8f
add tp_comm_stream
99efba3
fix tp
Achazwl 85dd5ab
Merge branch 'tensor_parallel' into tp
Achazwl 76abcb4
Merge pull request #1 from Achazwl/tp
9f8a5b4
new hook storage
MayDomine 725fe57
Offload storage function fix
MayDomine ec63e1b
storage dont release fix
MayDomine f1b4fd7
fix load_state_dict
677a316
test parallel linear
743253e
mv zero_level to CheckpointBlock
8493828
use dataptr as storage id
MayDomine 4e8c462
merge dev
23d7bef
Merge with dev
MayDomine 8919f18
fix prev confilct
MayDomine 604ddfe
fix overlap
0aee817
gather once in atten
bd0bad0
fix sub grad_input in parallel linear
50cdcaf
Merge branch 'dev' into tensor_parallel
zkh2016 15460b6
fix gather_output
0e0e05c
Merge branch 'tensor_parallel' of https://github.com/zkh2016/BMTrain …
66a04f3
better overlap
MayDomine b44a62e
fix train.py
b208e9f
rm unused code
MayDomine 30090ef
Merge branch 'offload' into tp
MayDomine de32538
fix tp feature
MayDomine c64da6f
update pre module interface
MayDomine 1f3b5a3
Merge branch 'dev' of https://github.com/OpenBMB/BMTrain into tp
MayDomine ae99c77
Merge branch 'dev' of https://github.com/OpenBMB/BMTrain into tp
MayDomine 5819ce4
.gitignore back
MayDomine 832141a
example back to origin
MayDomine 8bd6475
delete test file
MayDomine a7270e3
format
MayDomine 47905b8
version modify
MayDomine 568b02a
reformat code
MayDomine b249adc
fix pre module
MayDomine a1b8eee
modify comment
MayDomine 92b8630
dont expose use offload interface outside
MayDomine 1fac581
print tools
MayDomine f66c162
high priority for offload stream
MayDomine aef4899
fix import
@@ -1,10 +1,157 @@
import torch
from .global_var import config
from .checkpointing import CheckpointBlockContext
from .distributed import all_gather, broadcast, all_reduce, send_activations, recv_activations

Review comment: these imports are unused and can be removed.

from collections import deque, OrderedDict
from contextlib import contextmanager
from .utils import round_up

class Offload_Dict:

    def __init__(self):
        self._offload_dict = OrderedDict()

    def add(self, tensor):
        tensor = tensor.contiguous()
        tensor_id = id(tensor)
        data_ptr = tensor.storage().data_ptr()
        if data_ptr not in self._offload_dict:
            self._offload_dict[data_ptr] = {}
            self._offload_dict[data_ptr]["stor"] = tensor.storage()
            self._offload_dict[data_ptr]["size"] = tensor.storage().size()
            self._offload_dict[data_ptr]["dtype"] = tensor.storage().dtype
            self._offload_dict[data_ptr]["tensors"] = {}

        self._offload_dict[data_ptr]["tensors"][id(tensor)] = {}
        self._offload_dict[data_ptr]["tensors"][id(tensor)]["numel"] = tensor.numel()
        self._offload_dict[data_ptr]["tensors"][id(tensor)]['dtype'] = tensor.dtype
        self._offload_dict[data_ptr]["tensors"][id(tensor)]['offset'] = tensor.storage_offset()
        self._offload_dict[data_ptr]["tensors"][id(tensor)]['tensor'] = tensor
        self._offload_dict[data_ptr]["tensors"][id(tensor)]["shape"] = tensor.shape
        self._device = "cuda"
        return (data_ptr, tensor_id)

    def get_total(self):
        fp16_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float16])
        fp32_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float32])
        return fp16_total, fp32_total

    def make_cpu_storage(self):
        fp16_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float16])
        fp32_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float32])
        fp16_storage = torch.HalfStorage(fp16_total).pin_memory()
        fp32_storage = torch.FloatStorage(fp32_total).pin_memory()
        self.fp16_storage = fp16_storage
        self.fp32_storage = fp32_storage
        self.fp16_total = fp16_total
        self.fp32_total = fp32_total

    def get(self, key):
        data_ptr, tensor_id = key
        return self._offload_dict[data_ptr]['tensors'][tensor_id]["tensor"]

    def pop_all(self):
        self._offload_dict.clear()

    def h2d_memcpy(self):
        fp16_storage_cuda = self.fp16_storage.cuda(non_blocking=True)
        fp32_storage_cuda = self.fp32_storage.cuda(non_blocking=True)
        for key, val in self._offload_dict.items():
            for id_val in val['tensors'].values():
                id_val['tensor'] = torch.tensor([], dtype=id_val['dtype'], device=fp16_storage_cuda.device)
                if id_val['dtype'] == torch.float16:
                    id_val['tensor'].set_(fp16_storage_cuda, id_val['abs_offset'], id_val['shape'])
                elif id_val['dtype'] == torch.float32:
                    id_val['tensor'].set_(fp32_storage_cuda, id_val['abs_offset'], id_val['shape'])

    def record_stream(self, stream):
        for key, val in self._offload_dict.items():
            for id_val in val['tensors'].values():
                id_val['tensor'].record_stream(stream)

    def d2h_memcpy(self):
        fp16_offset = 0
        fp32_offset = 0
        fp16_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float16])
        fp32_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float32])
        assert fp16_total <= self.fp16_total
        assert fp32_total <= self.fp32_total
        fp16_storage = self.fp16_storage
        fp32_storage = self.fp32_storage
        for key, val in self._offload_dict.items():
            assert val['dtype'] in [torch.float16, torch.float32]
            storage = fp16_storage if val['dtype'] == torch.float16 else fp32_storage
            offset = fp16_offset if val['dtype'] == torch.float16 else fp32_offset
            for id_val in val['tensors'].values():
                cpu_tensor = torch.tensor([], dtype=id_val['dtype'], device="cpu") \
                    .set_(storage, offset + id_val['offset'], id_val['shape'])
                id_val["abs_offset"] = offset + id_val['offset']
                id_val['tensor'] = cpu_tensor.copy_(id_val['tensor'], non_blocking=True)
            if val['dtype'] == torch.float16:
                fp16_offset += val['size']
            else:
                fp32_offset += val['size']
            val['stor'] = None

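The class above packs every saved fp16/fp32 activation that shares a CUDA storage into one pinned CPU buffer, so each block needs only a single asynchronous device-to-host copy per dtype. Below is a minimal, self-contained sketch of the underlying pattern (pinned buffer, non_blocking copies, events ordering the compute and copy streams); the names offload_stream, buf, calc_event, and offload_event are illustrative, not part of this PR:

import torch

x = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
offload_stream = torch.cuda.Stream()                     # side stream for copies
buf = torch.empty_like(x, device="cpu").pin_memory()     # pinned CPU buffer

calc_event = torch.cuda.Event()
torch.cuda.current_stream().record_event(calc_event)     # forward produced x

with torch.cuda.stream(offload_stream):
    offload_stream.wait_event(calc_event)                # don't copy before x is ready
    x.record_stream(offload_stream)                      # x is now used on this stream
    buf.copy_(x, non_blocking=True)                      # async D2H into pinned memory

# ...later, just before backward needs the activation again:
offload_event = torch.cuda.Event()
with torch.cuda.stream(offload_stream):
    x_back = buf.cuda(non_blocking=True)                 # async H2D reload
    offload_stream.record_event(offload_event)
torch.cuda.current_stream().wait_event(offload_event)    # compute waits for reload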
def find_pre_module_helper(m):
    if len(m) == 0:
        return None
    if m._mode == "OFFLOAD":
        return m
    else:
        return find_pre_module_helper(m.pre_module())

def offload_wrapper(offload_dict):
    def pack_hook(tensor):
        if isinstance(tensor, torch.nn.Parameter):
            return (tensor,)
        elif tensor.dtype not in [torch.float16]:
            return (tensor,)
        else:
            key = offload_dict.add(tensor)
            return (tensor.device, key)

    def unpack_hook(packed):
        if len(packed) == 2:
            device, key = packed
            tensor = offload_dict.get(key)
            assert tensor.device == device
            return tensor
        else:
            tensor, = packed
            return tensor

    return pack_hook, unpack_hook

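offload_wrapper builds a pack/unpack pair for PyTorch's saved-tensor hooks: at save time an fp16 activation is replaced by a small (device, key) handle, and at backward time the handle is resolved back to a tensor. The public counterpart of the private _push/_pop calls used in this PR is torch.autograd.graph.saved_tensors_hooks (available since PyTorch 1.10). A runnable sketch with a plain dict standing in for Offload_Dict:

import torch

storage = {}  # stand-in for Offload_Dict; a real offloader copies to pinned CPU memory

def pack_hook(tensor):
    # Mirror the PR's filter: keep parameters and non-fp16 tensors as-is.
    if isinstance(tensor, torch.nn.Parameter) or tensor.dtype != torch.float16:
        return (tensor,)
    key = id(tensor)
    storage[key] = tensor
    return (tensor.device, key)

def unpack_hook(packed):
    if len(packed) == 2:
        device, key = packed
        return storage[key].to(device)
    (tensor,) = packed
    return tensor

a = torch.randn(4, 4, dtype=torch.float16, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    out = (a * a).sum()
out.backward()   # unpack_hook resolves the (device, key) handle back to the tensor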
def offload_pre_hook(module, input):
    if hasattr(module, "_offload_hook"):
        pack_hook, unpack_hook = module._offload_hook
        torch._C._autograd._push_saved_tensors_default_hooks(
            pack_hook, unpack_hook
        )

def offload_post_hook(module, input, output):
    if hasattr(module, "_offload_hook"):
        torch._C._autograd._pop_saved_tensors_default_hooks()

def zero_pre_forward(module, inputs):
    enter = True
    pipe = False
    if module._mode == "OFFLOAD":
        if not hasattr(module, "_offload_dict"):
            module._offload_dict = Offload_Dict()
        pack_hook, unpack_hook = offload_wrapper(module._offload_dict)
        if module.offload_level == 1:
            for n, m in module.named_modules():
                if m.__class__.__name__ == "Linear" and not hasattr(m, "_offload_hook"):
                    m._offload_hook = (pack_hook, unpack_hook)
                    m.register_forward_pre_hook(offload_pre_hook)
                    m.register_forward_hook(offload_post_hook)
        elif module.offload_level == 2:
            if not hasattr(module, "_offload_hook"):
                module._offload_hook = (pack_hook, unpack_hook)
            torch._C._autograd._push_saved_tensors_default_hooks(
                pack_hook, unpack_hook
            )

    if module._mode == "PIPE":
        enter = module._micro_idx == 0
        pipe = True

@@ -25,14 +172,42 @@ def zero_post_forward(module, inputs, outputs):
    exit = True
    if module._mode == "PIPE":
        exit = module._micro_idx == config['micros'] - 1

    elif module._mode == "OFFLOAD":
        torch.cuda.current_stream().record_event(module.calc_event)
        pre_offload_module = find_pre_module_helper(module.pre_module())
        if pre_offload_module is not None:
            torch.cuda.current_stream().wait_event(pre_offload_module.offload_event)
        with torch.cuda.stream(config["offload_stream"]):
            config["offload_stream"].wait_event(module.calc_event)
            if not hasattr(module._offload_dict, "fp16_storage"):
                module._offload_dict.make_cpu_storage()
            module._offload_dict.record_stream(config["offload_stream"])
            module._offload_dict.d2h_memcpy()
            if len(module._next_module) > 0:
                config["offload_stream"].record_event(module.offload_event)
        if module.offload_level == 2:
            torch._C._autograd._pop_saved_tensors_default_hooks()
    if exit:
        module._forward_block_ctx.exit(forward_flag)
        module._ref_count += 1

def zero_pre_backward(module, grad_outputs):
    backward_flag = 2 if module._zero_level == 2 else 0
    if module._mode != "PIPE":
        if module._mode == "OFFLOAD" or (len(module._next_module) == 0):
            if len(module._next_module) != 0:
                current_stream = torch.cuda.current_stream()
                current_stream.wait_event(module.offload_event)
            pre_module = find_pre_module_helper(module.pre_module())
            if pre_module is not None:
                pre_module._on_device = True
                with torch.cuda.stream(config["offload_stream"]):
                    if (len(module._next_module) != 0):
                        torch.cuda.current_stream().wait_event(module.calc_event)
                    pre_module._offload_dict.h2d_memcpy()
                    torch.cuda.current_stream().record_event(pre_module.offload_event)
            if (len(module._next_module) != 0):
                module._offload_dict.record_stream(current_stream)
        module._backward_block_ctx = CheckpointBlockContext(module, module._layer_dict)
        module._backward_block_ctx.enter(backward_flag, True)
        if not module._is_last_layer:
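Together, zero_post_forward and zero_pre_backward chain events so the copy stream overlaps with compute: after block i's forward, its activations are copied out on offload_stream, and while block i runs backward the previous offload block's activations are prefetched back. A schematic sketch of that backward prefetch loop; blocks, reload_to_gpu, and backward_step are hypothetical stand-ins, not this PR's API:

import torch

offload_stream = torch.cuda.Stream()

def backward_with_prefetch(blocks, grad):
    # Walk blocks in reverse; reload block i-1 on the side stream while block i
    # runs its backward on the current (compute) stream. Assumes the last
    # block's activations are still resident on the GPU when backward starts.
    ready = {}
    for i in reversed(range(len(blocks))):
        if i > 0:
            with torch.cuda.stream(offload_stream):
                blocks[i - 1].reload_to_gpu(non_blocking=True)   # H2D copy (hypothetical)
                ready[i - 1] = torch.cuda.Event()
                offload_stream.record_event(ready[i - 1])
        if i in ready:
            torch.cuda.current_stream().wait_event(ready[i])     # reload finished
        grad = blocks[i].backward_step(grad)                     # hypothetical backward
    return grad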
@@ -45,6 +220,10 @@ def zero_pre_backward(module, grad_outputs):
def zero_post_backward(module, grad_inputs, grad_outputs):
    backward_flag = 2 if module._zero_level == 2 else 0
    if module._mode != "PIPE":
        if module._mode == "OFFLOAD":
            module._on_device = False
            module._offload_dict.pop_all()
            torch.cuda.current_stream().record_event(module.calc_event)
        if module._is_first_layer:
            module.backward_release(backward_flag)
        else:
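For reference, zero_pre_forward implements two granularities: offload_level == 1 installs the pack/unpack pair only around Linear submodules via forward pre/post hooks, while offload_level == 2 pushes the hooks for the whole block. A small runnable sketch of the level-1 scoping mechanism, using the public saved-tensor-hook context; the model and hook bodies here are illustrative:

import torch
import torch.nn as nn

seen = []  # track which submodules had hooks active

def pre_hook(module, inputs):
    # Counterpart of offload_pre_hook: activate saved-tensor hooks for this module only.
    ctx = torch.autograd.graph.saved_tensors_hooks(
        lambda t: ("packed", t), lambda packed: packed[1])
    ctx.__enter__()
    module._hook_ctx = ctx
    seen.append(type(module).__name__)

def post_hook(module, inputs, output):
    # Counterpart of offload_post_hook: deactivate them right after forward.
    module._hook_ctx.__exit__(None, None, None)

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
for m in model.modules():
    if isinstance(m, nn.Linear):          # level-1: only Linear layers
        m.register_forward_pre_hook(pre_hook)
        m.register_forward_hook(post_hook)

y = model(torch.randn(2, 8, requires_grad=True)).sum()
y.backward()
print(seen)  # ['Linear', 'Linear'] -- hooks were scoped to the Linear layers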
Review comment: this can be removed outright; there is no longer any need to distinguish between ZERO and BLOCK.