-
Notifications
You must be signed in to change notification settings - Fork 0
/
functions.py
72 lines (57 loc) · 2.41 KB
/
functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
def print_tensor_details(tensor, name=None):
if name is None:
name = ""
print(f"{name} Device: {tensor.device}")
print(f"{name} Type: {tensor.dtype}")
print(f"{name} Shape: {tensor.shape}")
print(f"{name} dtype: {tensor.dtype}")
print(f"{name} Is Nan: {torch.isnan(tensor).any()}")
print(f"{name} Is Inf: {torch.isinf(tensor).any()}")
print(f"{name} Min: {torch.min(tensor)}")
print(f"{name} Max: {torch.max(tensor)}")
if tensor.dtype in [torch.float32, torch.float64, torch.float16, torch.bfloat16]:
print(f"{name} Mean: {torch.mean(tensor)}")
print(f"{name} Std: {torch.std(tensor)}")
if hasattr(tensor, "grad") and tensor.grad is not None:
print_tensor_details(tensor.grad, name=name + " Grad" if name else "Grad")
def print_vars(ref_dict, names=None):
    """Print debugging details for the tensor and container entries of a mapping.

    For each key of ``ref_dict`` (optionally restricted to ``names``):
    ``torch.Tensor`` values are described with ``print_tensor_details``, and
    dict / tuple / list values are outlined with ``analyze_hierarchy``. Each
    printed entry is followed by a separator line of 15 dashes. Values of any
    other type are skipped silently.

    Args:
        ref_dict: mapping from variable name to value (e.g. a namespace dict).
        names: optional collection of keys to restrict the output to. A single
            string is treated as one key, not as a collection of characters.
            ``None`` (the default) considers every key.
    """
    if names is not None:
        if isinstance(names, str):
            # `key in "some_name"` would be a substring test; treat a bare
            # string as a single requested name instead.
            names = (names,)
        else:
            # Materialize once: a generator/iterator would otherwise be
            # consumed by the first membership test, hiding later keys.
            names = tuple(names)
    # Iterate over a snapshot of the keys rather than the live keys view.
    for name in list(ref_dict.keys()):
        # Only debug the specific variables requested, if any were given.
        if names is not None and name not in names:
            continue
        value = ref_dict[name]
        if isinstance(value, torch.Tensor):
            print_tensor_details(value, name=name)
            print("---" * 5)
        elif isinstance(value, (dict, tuple, list)):
            # Containers: print a map of their hierarchy.
            analyze_hierarchy(value, name)
            print("---" * 5)
def recursive(obj, string_to_print, depth, keyname=None, *, _ancestors=None):
    """Append an indented outline of ``obj`` and any nested containers to a list.

    Every described object yields one line starting with a prefix of ``depth``
    copies of ``"- "``, followed by ``keyname`` when it is not None, then
    ``" <type>"``. Lists, tuples and dicts add ``", length:N"`` and are then
    expanded at ``depth + 1``. Dict values are labelled with their key, and
    list/tuple items are unlabelled. Any other object yields just its single
    line.

    A container that is already being expanded higher up the current path
    (a self-referential structure) is reported with ``", <circular reference>"``
    instead of being expanded again, rather than recursing until
    ``RecursionError``.

    Args:
        obj: the object to describe.
        string_to_print: list the lines are appended to (modified in place).
        depth: nesting level of ``obj``; controls the ``"- "`` prefix.
        keyname: optional label for this line (dict key, or the root name).
        _ancestors: internal; ids of containers on the current recursion path.

    Returns:
        ``string_to_print`` — the same list object, now extended.
    """
    if _ancestors is None:
        _ancestors = set()

    string = f"{depth * '- '}"
    if keyname is not None:
        string += f"{keyname}"
    string += f" {type(obj)}"

    if not isinstance(obj, (list, tuple, dict)):
        string_to_print += [f"{string}"]
        return string_to_print

    obj_id = id(obj)
    if obj_id in _ancestors:
        string_to_print += [f"{string}, <circular reference>"]
        return string_to_print

    string_to_print += [f"{string}, length:{len(obj)}"]
    _ancestors.add(obj_id)
    try:
        if isinstance(obj, dict):
            for key, value in obj.items():
                string_to_print = recursive(
                    value, string_to_print, depth + 1, keyname=key, _ancestors=_ancestors
                )
        else:
            for item in obj:
                string_to_print = recursive(
                    item, string_to_print, depth + 1, _ancestors=_ancestors
                )
    finally:
        # Only the current path counts: a container shared between siblings
        # (but not containing itself) is expanded each time it appears.
        _ancestors.discard(obj_id)
    return string_to_print
def analyze_hierarchy(obj, name):
    """Print an indented outline of a (possibly nested) dict, list, or tuple.

    The outline lines are produced by ``recursive`` starting at depth 0 with
    ``name`` as the root label, then printed together in one ``print`` call,
    separated by newlines.

    Args:
        obj: the container to decompose.
        name: label shown on the root line of the outline.
    """
    outline = recursive(obj, [], 0, keyname=name)
    print("\n".join(outline))