Offload activation async support #156

Open
wants to merge 144 commits into base: dev
Commits (144)
bcea035
using hooks to implement ZeRO and Checkpoint
Jul 24, 2023
7b080e7
async backward
Jul 25, 2023
be5f9d7
async forward
Jul 25, 2023
dea1781
merge upstream
Jul 25, 2023
05bc553
fix
Jul 25, 2023
bdf7087
save cuda_rng_state
Jul 26, 2023
6a366e3
fix
Jul 27, 2023
25ef84f
fix
Jul 27, 2023
768f209
fix
Jul 27, 2023
324e0dd
remove __call__
Jul 31, 2023
0f4ddb5
refactor code structure
Jul 31, 2023
76c5c26
pipeline
Jul 31, 2023
16c0922
for low version
Jul 31, 2023
2d35ba0
for low torch version
Jul 31, 2023
bc48d83
for checkpoint
Jul 31, 2023
bd61071
remove unused code
Jul 31, 2023
de25455
remove duplicate code
Jul 31, 2023
fde122f
fix pipeline; checkpoint support low version
Aug 1, 2023
a897ad4
fix pipeline; checkpoint support low version
Aug 1, 2023
ca50795
merge remote
Aug 1, 2023
ec8385b
fix indent
Aug 1, 2023
9877a81
pipe support low version
Aug 2, 2023
28993b5
custom linear for zero3
Aug 2, 2023
4d43952
merge origin
Aug 3, 2023
e4eaebf
resolve conflict
Aug 3, 2023
cba7c55
resolve conflict
Aug 3, 2023
839a976
use torch.utils.checkpoint.checkpoint
Aug 3, 2023
d5bbf1a
custom hook
Aug 4, 2023
e92d0ef
optimize code structure
Aug 4, 2023
6ba753e
for hidden_state
Aug 4, 2023
b0a0da9
for input.requires_grad is False
Aug 4, 2023
f4a0e0b
fix
Aug 5, 2023
8faff0f
pipeline support return hidden_state
Aug 6, 2023
26c8c94
fix args
Aug 7, 2023
b7d1c8c
fix test
Aug 7, 2023
4303575
CheckpointBlock -> BMTBlock
Aug 8, 2023
8061b66
reset block name
Aug 8, 2023
845f210
pipeline support batch_related
Aug 8, 2023
0b14fe5
remove use_checkpoint from init_distributed
Aug 9, 2023
12e51e1
test
Aug 10, 2023
726aa2f
test for transformer and attn
MayDomine Aug 10, 2023
ae56de8
for requires_grad
Aug 10, 2023
27ae2b7
for requires_grad
Aug 10, 2023
fdc8231
fix for arg is not tensor
Aug 10, 2023
b0f7154
fix for arg is not a tensor
Aug 10, 2023
420b626
add test
Aug 10, 2023
b843489
Merge branch 'hook' of https://github.com/zkh2016/BMTrain into hook
Aug 10, 2023
ebc269f
merge enhance_ckp
Aug 11, 2023
2f1e766
enhance ckp
Aug 11, 2023
4336437
Merge branch 'hook' into test
Aug 11, 2023
683707d
test
Aug 11, 2023
4013502
test
Aug 11, 2023
1c532d4
test
Aug 12, 2023
1e993c6
refactor code
Aug 12, 2023
24d0f59
mv linear to bmt.nn.linear
Aug 12, 2023
ff72e66
for enhance_ckp
Aug 12, 2023
1fbf3b2
fix for all input not grad
Aug 14, 2023
ace5216
fix pre_module
Aug 14, 2023
52cd4e2
fix pre_module
Aug 14, 2023
0b0bd0b
fix for all input no grad
Aug 14, 2023
05b49f8
fix for all input no grad
Aug 14, 2023
98d5b32
activation offloading
MayDomine Aug 15, 2023
bd42ee4
Merge branch 'main' of https://github.com/OpenBMB/BMTrain into test
MayDomine Aug 15, 2023
c16127a
offload new version
MayDomine Aug 16, 2023
64eb672
Merge branch 'main' of https://github.com/OpenBMB/BMTrain into hook
Aug 16, 2023
4861ec8
save_for_backward hook
MayDomine Aug 16, 2023
fc81971
offloading bug fix
MayDomine Aug 17, 2023
88b5bd3
fix reentrant
Aug 17, 2023
9c2e47d
Merge branch 'hook' of https://github.com/zkh2016/BMTrain into hook
Aug 17, 2023
e93e6dc
Merge branch 'dev' into hook
Aug 18, 2023
fd49311
refactor CheckpointBlock
Aug 20, 2023
221bdc3
refactor pipe
Aug 20, 2023
76f74e5
Merge branch 'hook' of https://github.com/zkh2016/BMTrain into hook
Aug 20, 2023
9c63407
fix all input no grad
Aug 20, 2023
f72fcfc
fix hiddenstate
Aug 20, 2023
ebdf519
fix test
Aug 21, 2023
780ca20
fix
Aug 21, 2023
6df85e7
remove unused import
Aug 21, 2023
bb482d6
fix pre_module
Aug 21, 2023
1010d26
recovery some code
Aug 21, 2023
b580530
add test_no_grad.py
Aug 21, 2023
767a875
test unroll block list
Aug 21, 2023
d19a627
fix test_fp32
Aug 21, 2023
bf986a7
cross_entropy support fp32
Aug 21, 2023
b28cb3f
offload context
MayDomine Aug 21, 2023
5e24661
Merge branch 'hook' into test
MayDomine Aug 21, 2023
f94afa2
cpm live for offloading test
MayDomine Aug 21, 2023
bc65a2e
Better hack for offload
MayDomine Aug 22, 2023
76f8162
fix OFFLOAD _mode bug
MayDomine Aug 22, 2023
0d4ea37
fix is_first_layer
Aug 22, 2023
6ffcf5c
Fix async bug
MayDomine Aug 23, 2023
3063afb
tensor parallel
Aug 23, 2023
bdc1ed9
Merge branch 'fix_first_layer' into tensor_parallel
Aug 23, 2023
8648f5b
rm unused code
Aug 23, 2023
763b408
refactor nccl group; remove partition_modules in pipe_layer.py
Aug 24, 2023
4c50567
fix by review comment
Aug 24, 2023
825139c
fix topology
Aug 24, 2023
f08bc83
offload event wait
MayDomine Aug 24, 2023
82e975c
Merge branch 'dev' of https://github.com/OpenBMB/BMTrain into offload
MayDomine Aug 24, 2023
4ff0f41
fix topology
Aug 24, 2023
a5d7ba6
fix
Aug 24, 2023
2951d70
use ParallelEmbedding
Aug 24, 2023
2f4ca8a
Offload Correct Version
MayDomine Aug 24, 2023
39319e1
overlap parallel linear backward
Aug 24, 2023
df3fd8f
add tp_comm_stream
Aug 24, 2023
99efba3
fix tp
Achazwl Aug 24, 2023
85dd5ab
Merge branch 'tensor_parallel' into tp
Achazwl Aug 24, 2023
76abcb4
Merge pull request #1 from Achazwl/tp
Aug 24, 2023
9f8a5b4
new hook storage
MayDomine Aug 25, 2023
725fe57
Offload storage function fix
MayDomine Aug 25, 2023
ec63e1b
storage dont release fix
MayDomine Aug 25, 2023
f1b4fd7
fix load_state_dict
Aug 25, 2023
677a316
test parallel linear
Aug 25, 2023
743253e
mv zero_level to CheckpointBlock
Aug 25, 2023
8493828
use dataptr as storage id
MayDomine Aug 25, 2023
4e8c462
merge dev
Aug 25, 2023
23d7bef
Merge with dev
MayDomine Aug 25, 2023
8919f18
fix prev confilct
MayDomine Aug 25, 2023
604ddfe
fix overlap
Aug 25, 2023
0aee817
gather once in atten
Aug 25, 2023
bd0bad0
fix sub grad_input in parallel linear
Aug 25, 2023
50cdcaf
Merge branch 'dev' into tensor_parallel
zkh2016 Aug 26, 2023
15460b6
fix gather_output
Aug 26, 2023
0e0e05c
Merge branch 'tensor_parallel' of https://github.com/zkh2016/BMTrain …
Aug 26, 2023
66a04f3
better overlap
MayDomine Aug 26, 2023
b44a62e
fix train.py
Aug 26, 2023
b208e9f
rm unused code
MayDomine Aug 26, 2023
30090ef
Merge branch 'offload' into tp
MayDomine Aug 26, 2023
de32538
fix tp feature
MayDomine Aug 29, 2023
c64da6f
update pre module interface
MayDomine Aug 29, 2023
1f3b5a3
Merge branch 'dev' of https://github.com/OpenBMB/BMTrain into tp
MayDomine Aug 29, 2023
ae99c77
Merge branch 'dev' of https://github.com/OpenBMB/BMTrain into tp
MayDomine Aug 29, 2023
5819ce4
.gitignore back
MayDomine Aug 29, 2023
832141a
example back to origin
MayDomine Aug 29, 2023
8bd6475
delete test file
MayDomine Aug 29, 2023
a7270e3
format
MayDomine Aug 29, 2023
47905b8
version modify
MayDomine Aug 29, 2023
568b02a
reformat code
MayDomine Aug 29, 2023
b249adc
fix pre module
MayDomine Aug 29, 2023
a1b8eee
modify comment
MayDomine Aug 29, 2023
92b8630
dont expose use offload interface outside
MayDomine Aug 29, 2023
1fac581
print tools
MayDomine Aug 29, 2023
f66c162
high priority for offload stream
MayDomine Aug 29, 2023
aef4899
fix import
MayDomine Aug 29, 2023
34 changes: 23 additions & 11 deletions bmtrain/block_layer.py
@@ -66,7 +66,7 @@ class CheckpointBlock(torch.nn.Module):
>>> y2, ... = transformer_block(x)
>>> assert torch.allclose(y1, y2)
"""
def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, zero_level=3):
def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, offload_level=0, zero_level=3):
super().__init__()
self._module = inner_module
self._inputs = None
@@ -80,7 +80,8 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, zero_lev
self._ready = False
# sort parameters by name
ordered_parameters = list(self._module.named_parameters())

use_offload = offload_level in [1,2]
assert not (use_checkpoint and use_offload), "It does not make sense to use offload and checkpointing at the same time"
# calc total number of parameters
for name, param in ordered_parameters:
if not isinstance(param, DistributedParameter):
@@ -202,6 +203,11 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, zero_lev
self._pre_module = [] #save the pre module of self
self._ref_count = 0 #incremental in forward and decreasing in backward
self._mode = "BLOCK" #BLOCK or ZERO or PIPE
self.offload_level = offload_level
if use_offload:
self._mode = "OFFLOAD"
self._on_device = False

self.all_input_no_grad = False
self.all_param_no_grad = False
self._zero_level = zero_level
@@ -212,12 +218,16 @@ def set_pre_module(self, pre_module):
pre_module._next_module.append(self)

def pre_module(self):
assert len(self._pre_module) == self._ref_count, "{} != {}".format(len(self._pre_module), self._ref_count)
return self._pre_module[self._ref_count-1]
if len(self._pre_module) > 0:
return self._pre_module[self._ref_count-1]
else:
return None

def next_module(self):
assert len(self._next_module) == self._ref_count, "{} != {}".format(len(self._next_module), self._ref_count)
return self._next_module[self._ref_count-1]
if len(self._next_module) > 0:
return self._next_module[self._ref_count-1]
else:
return None

def backward_release(self, flag):
if self._ref_count == 1 and self._backward_block_ctx is not None:
@@ -536,19 +546,21 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False)

self._modules = {}
pre_module = None
offload = 0
for i, module in enumerate(modules):
if not isinstance(module, CheckpointBlock):
module = CheckpointBlock(module)

module._mode = "ZERO"
module._mode = "ZERO" if module._mode == "BLOCK" else module._mode
[Collaborator review comment] This can simply be removed now; there is no longer any need to distinguish between ZERO and BLOCK here.

module.set_pre_module(pre_module)
pre_module = module
module._is_first_layer = False
module._is_last_layer = False

if module._mode == "OFFLOAD":
offload+=1
module.calc_event = torch.cuda.Event()
module.offload_event = torch.cuda.Event()
self._modules[str(i)] = module
self.add_module(str(i), module)

self._modules[str(0)]._is_first_layer = True
self._modules[str(len(modules)-1)]._is_last_layer = True

@@ -575,7 +587,7 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False)
self.save_list = save_list
else:
self.save_list = [(i, i) for i in range(len(self))]

def __len__(self) -> int:
return len(self._modules)

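To make the new offload_level argument concrete, here is a minimal usage sketch (not part of this diff). It assumes CheckpointBlock, TransformerBlockList and bmt.nn.Linear are exported as in the rest of BMTrain, and uses toy layer sizes; treat it as an illustration of the constructor contract rather than a tested recipe.

```python
import bmtrain as bmt

bmt.init_distributed()

# offload_level: 0 = no offloading, 1 = offload activations saved by Linear
# submodules, 2 = offload all activations saved by the block.
# Activation offloading and activation recompute are mutually exclusive, so
# use_checkpoint must be False whenever offload_level is 1 or 2.
blocks = bmt.TransformerBlockList([
    bmt.CheckpointBlock(
        bmt.nn.Linear(1024, 1024),   # stand-in for a real transformer layer
        use_checkpoint=False,
        offload_level=2,
        zero_level=3,
    )
    for _ in range(4)
])
```
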
56 changes: 54 additions & 2 deletions bmtrain/hook_func.py
@@ -1,10 +1,30 @@
import torch
from .global_var import config
from .checkpointing import CheckpointBlockContext

from .distributed import all_gather, broadcast, all_reduce, send_activations, recv_activations
[Collaborator review comment] These imports are unused and can be removed.

from contextlib import contextmanager
from .utils import round_up, find_pre_module_helper
from .offload import Offload_Dict, offload_wrapper, offload_pre_hook, offload_post_hook
def zero_pre_forward(module, inputs):
enter = True
pipe = False
if module._mode == "OFFLOAD":
if not hasattr(module, "_offload_dict"):
module._offload_dict = Offload_Dict()
pack_hook, unpack_hook = offload_wrapper(module._offload_dict)
if module.offload_level == 1:
for n, m in module.named_modules():
if m.__class__.__name__ == "Linear" and not hasattr(m, "_offload_hook"):
m._offload_hook = (pack_hook, unpack_hook)
m.register_forward_pre_hook(offload_pre_hook)
m.register_forward_hook(offload_post_hook)
elif module.offload_level == 2:
if not hasattr(module, "_offload_hook"):
module._offload_hook = (pack_hook, unpack_hook)
torch._C._autograd._push_saved_tensors_default_hooks(
pack_hook, unpack_hook
)

if module._mode == "PIPE":
enter = module._micro_idx == 0
pipe = True
@@ -25,14 +45,42 @@ def zero_post_forward(module, inputs, outputs):
exit = True
if module._mode == "PIPE":
exit = module._micro_idx == config['micros'] - 1

elif module._mode == "OFFLOAD":
torch.cuda.current_stream().record_event(module.calc_event)
pre_offload_module = find_pre_module_helper(module.pre_module())
if pre_offload_module is not None:
torch.cuda.current_stream().wait_event(pre_offload_module.offload_event)
with torch.cuda.stream(config["offload_stream"]):
config["offload_stream"].wait_event(module.calc_event)
if not hasattr(module._offload_dict, "fp16_storage"):
module._offload_dict.make_cpu_storage()
module._offload_dict.record_stream(config["offload_stream"])
module._offload_dict.d2h_memcpy()
if len(module._next_module) > 0:
config["offload_stream"].record_event(module.offload_event)
if module.offload_level == 2:
torch._C._autograd._pop_saved_tensors_default_hooks()
if exit:
module._forward_block_ctx.exit(forward_flag)
module._ref_count += 1

def zero_pre_backward(module, grad_outputs):
backward_flag = 2 if module._zero_level == 2 else 0
if module._mode != "PIPE":
if module._mode == "OFFLOAD" or (len(module._next_module) == 0):
if len(module._next_module) != 0:
current_stream = torch.cuda.current_stream()
current_stream.wait_event(module.offload_event)
pre_module = find_pre_module_helper(module.pre_module())
if pre_module is not None:
pre_module._on_device = True
with torch.cuda.stream(config["offload_stream"]):
if (len(module._next_module) != 0):
torch.cuda.current_stream().wait_event(module.calc_event)
pre_module._offload_dict.h2d_memcpy()
torch.cuda.current_stream().record_event(pre_module.offload_event)
if (len(module._next_module) != 0):
module._offload_dict.record_stream(current_stream)
module._backward_block_ctx = CheckpointBlockContext(module, module._layer_dict)
module._backward_block_ctx.enter(backward_flag, True)
if not module._is_last_layer:
@@ -45,6 +93,10 @@ def zero_pre_backward(module, grad_outputs):
def zero_post_backward(module, grad_inputs, grad_outputs):
backward_flag = 2 if module._zero_level == 2 else 0
if module._mode != "PIPE":
if module._mode == "OFFLOAD":
module._on_device = False
module._offload_dict.pop_all()
torch.cuda.current_stream().record_event(module.calc_event)
if module._is_first_layer:
module.backward_release(backward_flag)
else:
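The hooks above rely on PyTorch's saved-tensor hooks, pushed through the private torch._C._autograd API. For readers unfamiliar with that mechanism, below is a self-contained sketch of the same offloading idea using the public torch.autograd.graph.saved_tensors_hooks context manager; it is illustrative only and not the code path used in this PR.

```python
import torch

def make_offload_hooks():
    def pack_hook(t):
        # Called when autograd saves a tensor for backward: stage it in
        # (optionally pinned) CPU memory and remember the original device.
        cpu = torch.empty(t.shape, dtype=t.dtype, device="cpu",
                          pin_memory=torch.cuda.is_available())
        cpu.copy_(t, non_blocking=True)
        return (t.device, cpu)

    def unpack_hook(packed):
        # Called when backward needs the tensor again: move it back.
        device, cpu = packed
        return cpu.to(device, non_blocking=True)

    return pack_hook, unpack_hook

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(512, 512, device=device, requires_grad=True)
pack, unpack = make_offload_hooks()
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = (x @ x).relu().sum()   # activations are saved through pack_hook
y.backward()                   # and restored through unpack_hook
```
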
1 change: 1 addition & 0 deletions bmtrain/init.py
@@ -72,6 +72,7 @@ def init_distributed(
config["rank"] = rank
config["world_size"] = world_size
config["calc_stream"] = torch.cuda.current_stream()
config["offload_stream"] = torch.cuda.Stream(priority=-1)
config["load_stream"] = torch.cuda.Stream(priority=-1)
config["tp_comm_stream"] = torch.cuda.Stream(priority=-1)
config["pp_comm_stream"] = torch.cuda.Stream(priority=-1)
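The offload_stream added here mirrors the existing side streams: a dedicated CUDA stream created at priority -1 (the higher priority in the default range) so device/host copies can overlap with compute on the default stream. A rough, BMTrain-independent illustration of that overlap pattern, assuming a CUDA device is available:

```python
import torch

offload_stream = torch.cuda.Stream(priority=-1)  # -1 = higher priority

x = torch.randn(4096, 4096, device="cuda")
cpu_buf = torch.empty(x.shape, dtype=x.dtype, device="cpu", pin_memory=True)

# Make the side stream wait for the producer of x, then queue the D2H copy on
# it so the copy can overlap with whatever the default stream does next.
offload_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(offload_stream):
    cpu_buf.copy_(x, non_blocking=True)
    x.record_stream(offload_stream)  # keep the allocator from reusing x too early

offload_stream.synchronize()  # block the host until the copy has landed in cpu_buf
```
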
121 changes: 121 additions & 0 deletions bmtrain/offload.py
@@ -0,0 +1,121 @@
import torch
from collections import OrderedDict

class Offload_Dict:

def __init__(self):
self._offload_dict = OrderedDict()

def add(self, tensor):
tensor = tensor.contiguous()
tensor_id = id(tensor)
data_ptr = tensor.storage().data_ptr()
if data_ptr not in self._offload_dict:
self._offload_dict[data_ptr] = {}
self._offload_dict[data_ptr]["stor"] = tensor.storage()
self._offload_dict[data_ptr]["size"] = tensor.storage().size()
self._offload_dict[data_ptr]["dtype"] = tensor.storage().dtype
self._offload_dict[data_ptr]["tensors"] = {}

self._offload_dict[data_ptr]["tensors"][id(tensor)] = {}
self._offload_dict[data_ptr]["tensors"][id(tensor)]["numel"] = tensor.numel()
self._offload_dict[data_ptr]["tensors"][id(tensor)]['dtype'] = tensor.dtype
self._offload_dict[data_ptr]["tensors"][id(tensor)]['offset'] = tensor.storage_offset()
self._offload_dict[data_ptr]["tensors"][id(tensor)]['tensor'] = tensor
self._offload_dict[data_ptr]["tensors"][id(tensor)]["shape"] = tensor.shape
self._device = "cuda"
return (data_ptr,tensor_id)

def get_total(self):
fp16_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float16])
fp32_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float32])
return fp16_total,fp32_total

def make_cpu_storage(self):
fp16_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float16])
fp32_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float32])
fp16_storage = torch.HalfStorage(fp16_total).pin_memory()
fp32_storage = torch.FloatStorage(fp32_total).pin_memory()
self.fp16_storage = fp16_storage
self.fp32_storage = fp32_storage
self.fp16_total = fp16_total
self.fp32_total = fp32_total

def get(self, key):
data_ptr, tensor_id = key
return self._offload_dict[data_ptr]['tensors'][tensor_id]["tensor"]

def pop_all(self):
self._offload_dict.clear()

def h2d_memcpy(self):
fp16_storage_cuda = self.fp16_storage.cuda(non_blocking=True)
fp32_storage_cuda = self.fp32_storage.cuda(non_blocking=True)
for key,val in self._offload_dict.items():
for id_val in val['tensors'].values():
id_val['tensor'] = torch.tensor([], dtype=id_val['dtype'],device=fp16_storage_cuda.device)
if id_val['dtype'] == torch.float16:
id_val['tensor'].set_(fp16_storage_cuda, id_val['abs_offset'], id_val['shape'])
elif id_val['dtype'] == torch.float32:
id_val['tensor'].set_(fp32_storage_cuda, id_val['abs_offset'], id_val['shape'])

def record_stream(self, stream):
for key, val in self._offload_dict.items():
for id_val in val['tensors'].values():
id_val['tensor'].record_stream(stream)

def d2h_memcpy(self):
fp16_offset = 0
fp32_offset = 0
fp16_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float16])
fp32_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float32])
assert fp16_total <= self.fp16_total
assert fp32_total <= self.fp32_total
fp16_storage = self.fp16_storage
fp32_storage = self.fp32_storage
for key,val in self._offload_dict.items():
assert val['dtype'] in [torch.float16, torch.float32]
storage = fp16_storage if val['dtype'] == torch.float16 else fp32_storage
offset = fp16_offset if val['dtype'] == torch.float16 else fp32_offset
for id_val in val['tensors'].values():
cpu_tensor = torch.tensor([], dtype=id_val['dtype'], device="cpu") \
.set_(storage, offset+id_val['offset'], id_val['shape'])
id_val["abs_offset"] = offset+id_val['offset']
id_val['tensor'] = cpu_tensor.copy_(id_val['tensor'], non_blocking=True)
if val['dtype'] == torch.float16:
fp16_offset += val['size']
else:
fp32_offset += val['size']
val['stor'] = None


def offload_wrapper(offload_dict):
def pack_hook(tensor):
if isinstance(tensor, torch.nn.Parameter):
return (tensor,)
elif tensor.dtype not in [torch.float16]:
return (tensor,)
else:
key = offload_dict.add(tensor)
return (tensor.device, key)
def unpack_hook(packed):
if len(packed) == 2:
device, key = packed
tensor = offload_dict.get(key)
assert tensor.device == device
return tensor
else:
tensor, = packed
return tensor
return pack_hook, unpack_hook

def offload_pre_hook(module, input):
if hasattr(module, "_offload_hook"):
pack_hook, unpack_hook = module._offload_hook
torch._C._autograd._push_saved_tensors_default_hooks(
pack_hook, unpack_hook
)

def offload_post_hook(module, input, output):
if hasattr(module, "_offload_hook"):
torch._C._autograd._pop_saved_tensors_default_hooks()
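
A rough interactive sketch of how Offload_Dict is meant to be used, outside the hook machinery above. It assumes a CUDA device; note that make_cpu_storage and the memcpy calls are normally driven by zero_post_forward/zero_pre_backward rather than called by hand, and that the typed-storage APIs it relies on are deprecated (though still functional) in recent PyTorch releases.

```python
import torch
from bmtrain.offload import Offload_Dict  # module added in this PR

odict = Offload_Dict()
act = torch.randn(256, 256, dtype=torch.float16, device="cuda")

key = odict.add(act)        # register the activation; returns (storage data_ptr, tensor id)
odict.make_cpu_storage()    # allocate pinned fp16/fp32 CPU storage for everything registered
odict.d2h_memcpy()          # copy the device storages into the pinned CPU storage
torch.cuda.synchronize()    # the copies above are issued with non_blocking=True

odict.h2d_memcpy()          # later (before backward) re-upload the CPU storage to the GPU
restored = odict.get(key)   # CUDA tensor viewing the re-uploaded storage, same shape as act
torch.cuda.synchronize()
```
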
20 changes: 19 additions & 1 deletion bmtrain/utils.py
@@ -32,7 +32,14 @@ def load_nccl_pypi():
if file_split[-1] == "so" or (len(file_split)>1 and file_split[-2] == "so"):
ctypes.CDLL(os.path.join(path, file_so))


def find_pre_module_helper(m):
[Collaborator review comment] Suggested rename: find_pre_offload_module_helper.

if m is None:
return m
if m._mode == "OFFLOAD":
return m
else:
return find_pre_module_helper(m.pre_module())

def round_up(x, d):
return (x + d - 1) // d * d

@@ -80,6 +87,17 @@ def print_rank(*args, rank=0, **kwargs):
if config["rank"] == rank:
print(*args, **kwargs)

def print_strategy(model):
print_rank(" "*24+"|"+" Offload Level |" + " ZeRO Level |"+" Activation Recompute |")
for idx,ckpt in enumerate(model):
print_rank(f"CheckpointBlock Layer {idx} |{ckpt.offload_level:^14} | {ckpt._zero_level:^10} | {ckpt.use_checkpoint.__repr__():^20} |")

def print_inspect(model):
model_inspect = bmt.inspect.inspect_model(model, "*")
print_rank(bmt.inspect.format_summary(model_inspect))



def see_memory(message, detail=False):
"""
Outputs a message followed by GPU memory status summary on rank 0.
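A small, hypothetical example of calling the new print_strategy helper, continuing the blocks list constructed in the block_layer.py sketch above; the column contents follow the f-string in the diff.

```python
from bmtrain.utils import print_strategy

# `blocks` is the TransformerBlockList of CheckpointBlocks built earlier.
print_strategy(blocks)
# Sample output (one row per block):
#                         | Offload Level | ZeRO Level | Activation Recompute |
# CheckpointBlock Layer 0 |      2        |     3      |        False         |
# CheckpointBlock Layer 1 |      2        |     3      |        False         |
```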