# train_utils.py
import argparse
import math
import os
import shutil
import socket

import deepspeed
import torch
import torch.nn as nn
import yaml

from distributed_utils import is_main_process
from normalizer import *
def parse_trainer_args(parser: argparse.ArgumentParser):
    parser = deepspeed.add_config_arguments(parser)
    # First pass: read the command line to locate the recipe file.
    args, _ = parser.parse_known_args()
    args = parser.parse_args()
    # Second pass: load the recipe YAML into the namespace as defaults,
    # letting explicit command-line arguments override recipe values.
    recipe_args = argparse.Namespace(**yaml.load(open(args.recipe_pth), Loader=yaml.FullLoader))
    args, _ = parser.parse_known_args(namespace=recipe_args)
    args = parser.parse_args(namespace=recipe_args)
    return args
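# Example usage (a minimal sketch; the '--recipe_pth' argument and the recipe YAML
# layout are assumptions about the caller, not defined in this file):
#   parser = argparse.ArgumentParser()
#   parser.add_argument('--recipe_pth', type=str, required=True)
#   args = parse_trainer_args(parser)   # e.g. python train.py --recipe_pth recipe.yml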
@torch.no_grad()
def compute_param_diff(model1, model2):
    """Return the L2 norm of (model1 - model2) over trainable parameters,
    the L2 norm of model1's parameters, and the number of elements counted."""
    ssq = 0
    param_norm_sq = 0
    num_elements = 0
    for param1, param2 in zip(model1.parameters(), model2.parameters()):
        if param1.requires_grad:
            ssq += ((param1 - param2) ** 2).sum()
            param_norm_sq += (param1 ** 2).sum()
            num_elements += param1.numel()
    return math.sqrt(ssq), math.sqrt(param_norm_sq), num_elements
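# Example usage (a minimal sketch; the deep-copied reference model is only for illustration):
#   import copy
#   ref_model = copy.deepcopy(model)
#   # ... run one or more training steps on `model` ...
#   diff_norm, param_norm, n = compute_param_diff(model, ref_model)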
def is_bn(m):
    return isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
                          GhostBatchNorm1d, GhostBatchNorm2d))
def get_data_pth(config_pth='./.config.yml'):
    config = yaml.load(open(config_pth), Loader=yaml.FullLoader)
    return config['imagenet_ffcv_train_pth'], config['imagenet_ffcv_val_pth']
def mkdir(path):
    # strip surrounding whitespace and any trailing backslash
    path = path.strip()
    path = path.rstrip("\\")
    # create the directory only if it does not already exist;
    # os.makedirs (unlike os.mkdir) also creates missing parent directories
    if not os.path.exists(path):
        os.makedirs(path)
        print(path + ' is successfully made')
        return True
    else:
        print(path + ' already exists')
        return False
def count_correct(output, target, topk=(1,)):
    """Count the number of correct predictions for the specified values of k."""
    with torch.no_grad():
        maxk = max(topk)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0)
            res.append(correct_k)
        return res
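# Example usage (a minimal sketch with random logits and labels; shapes are illustrative):
#   logits = torch.randn(32, 1000)
#   labels = torch.randint(0, 1000, (32,))
#   top1, top5 = count_correct(logits, labels, topk=(1, 5))
#   acc1 = top1.item() / labels.size(0)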
def if_enough_space(pth, thres=10):
    """Return True if the filesystem containing `pth` has more than `thres` GB free."""
    usage = shutil.disk_usage(pth)
    free_space_gb = usage.free / (1024 ** 3)  # bytes -> gigabytes
    if free_space_gb > thres:
        return True
    else:
        print('Not enough disk space, skip saving')
        return False
def print_lr(step_ctr, optimizer):
    if is_main_process():
        print("Step {}, learning rate {}".format(step_ctr, optimizer.param_groups[0]['lr']))
def adjust_client_lr(client_list, gamma):
    for client in client_list:
        client.decay_lr(gamma)

def print_client_lr(optimizer):
    print(f"learning rate {optimizer.param_groups[0]['lr']}")
@torch.no_grad()
def eval_param_norm(model):
    norm_sq = 0.
    for name, param in model.named_parameters():
        if param.requires_grad:
            norm_sq += torch.sum(param ** 2).item()
    return math.sqrt(norm_sq)
def group_weight(module):
    """Split parameters into a weight-decay group (conv/linear weights) and a
    no-decay group (biases and normalization parameters)."""
    group_decay = []
    group_no_decay = []
    for m in module.modules():
        if isinstance(m, nn.Linear):
            group_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, nn.Conv2d):
            group_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
            if m.weight is not None:
                group_no_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)
    assert len(list(module.parameters())) == len(group_decay) + len(group_no_decay)
    groups = [dict(params=group_decay), dict(params=group_no_decay, weight_decay=.0)]
    return groups
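# Example usage (a minimal sketch, assuming a torchvision ResNet; the optimizer
# hyperparameters are illustrative):
#   import torchvision
#   model = torchvision.models.resnet18()
#   param_groups = group_weight(model)
#   optimizer = torch.optim.SGD(param_groups, lr=0.1, momentum=0.9, weight_decay=1e-4)
#   # the second group overrides weight_decay to 0 for biases and norm parameters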
def yield_optimizer_state(model, optimizer, key):
    for p in model.parameters():
        yield optimizer.state[p][key]
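# Example usage (a minimal sketch; 'momentum_buffer' only exists after at least
# one SGD step with momentum > 0):
#   for buf in yield_optimizer_state(model, optimizer, 'momentum_buffer'):
#       print(buf.shape)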