# collect_features.py
import torch
from tqdm import tqdm
import json
import gc
from safetensors.torch import save_file
import argparse

from src.utils import setup_seed
from src.datasets import ImageLabelDataset, make_transform
from src.feature_extractors import create_feature_extractor, collect_features
from guided_diffusion.guided_diffusion.script_util import model_and_diffusion_defaults, add_dict_to_argparser
from guided_diffusion.guided_diffusion.dist_util import dev
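
# Collects DDPM activations into feature datasets on disk. Two modes, chosen
# by the experiment config:
#   * without "best_t": sweep timesteps t = 50, 100, ..., 950 and save
#     per-pixel classifier features (X, y) for each t;
#   * with "best_t": save raw out-block activations at that single timestep,
#     for SAE training.
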
def prepare_data(args):
    feature_extractor = create_feature_extractor(**args)
    dataset = ImageLabelDataset(
        data_dir=args['training_path'],
        resolution=args['image_size'],
        num_images=args['training_number'],
        transform=make_transform(
            args['model_type'],
            args['image_size']
        )
    )
    # X holds one feature map per image; y holds the per-pixel labels.
    X = torch.zeros((len(dataset), *args['dim'][::-1]), dtype=torch.float)
    y = torch.zeros((len(dataset), *args['dim'][:-1]), dtype=torch.uint8)

    if 'share_noise' in args and args['share_noise']:
        # Use one fixed noise sample for every image so features are comparable.
        rnd_gen = torch.Generator(device=dev()).manual_seed(args['seed'])
        noise = torch.randn(1, 3, args['image_size'], args['image_size'],
                            generator=rnd_gen, device=dev())
    else:
        noise = None

    for row, (img, label) in enumerate(tqdm(dataset)):
        img = img[None].to(dev())
        features = feature_extractor(img, noise=noise)
        X[row] = collect_features(args, features).cpu()

        # Drop tiny annotations (fewer than 20 pixels): too noisy to train on.
        for target in range(args['number_class']):
            if target == args['ignore_label']:
                continue
            if 0 < (label == target).sum() < 20:
                print(f'Delete small annotation from image {dataset.image_paths[row]} | label {target}')
                label[label == target] = args['ignore_label']
        y[row] = label

    d = X.shape[1]
    print(f'Total dimension {d}')
    # Flatten to per-pixel rows: (N, C, H, W) -> (N*H*W, C).
    X = X.permute(1, 0, 2, 3).reshape(d, -1).permute(1, 0)
    y = y.flatten()
    return X[y != args['ignore_label']], y[y != args['ignore_label']]
def prepare_sae_data(args):
    assert "best_t" in args
    args["steps"] = [args["best_t"]]  # extract activations at a single timestep
    feature_extractor = create_feature_extractor(**args)
    dataset = ImageLabelDataset(
        data_dir=args['training_path'],
        resolution=args['image_size'],
        num_images=args['training_number'],
        transform=make_transform(
            args['model_type'],
            args['image_size']
        )
    )
    X = torch.zeros((len(dataset), *args["dim"]), dtype=torch.float)  # (N_IMGS, C, H, W)

    if 'share_noise' in args and args['share_noise']:
        rnd_gen = torch.Generator(device=dev()).manual_seed(args['seed'])
        noise = torch.randn(1, 3, args['image_size'], args['image_size'],
                            generator=rnd_gen, device=dev())
    else:
        noise = None

    for row, (img, _) in enumerate(tqdm(dataset)):
        img = img[None].to(dev())
        # The extractor returns a single-tensor list (one block, one step).
        X[row] = feature_extractor(img, noise=noise)[0]
    return X
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, model_and_diffusion_defaults())
    parser.add_argument('--exp', type=str)
    parser.add_argument('--seed', type=int, default=0)
    args = parser.parse_args()
    setup_seed(args.seed)

    # Load the experiment config; CLI flags override matching config keys.
    opts = json.load(open(args.exp, 'r'))
    opts.update(vars(args))
    opts['image_size'] = opts['dim'][0]
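
    # A hypothetical config for illustration (keys inferred from usage below;
    # the actual values live in the repo's experiment JSONs):
    # {
    #   "training_path": "datasets/...",
    #   "training_number": 20,
    #   "model_type": "ddpm",
    #   "dim": [256, 256, 8448],
    #   "number_class": 34,
    #   "ignore_label": 255,
    #   "share_noise": true,
    #   "blocks": [5, 6, 7, 8, 12],
    #   "best_t": 150
    # }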
    if opts.get("best_t") is None:  # mode 1: collect a classifier feature dataset per timestep
        opts['dim'][-1] = 512  # channel count for features from a single t
        for t in tqdm(range(50, 951, 50)):
            print(f"getting features for t={t}")
            opts["steps"] = [t]  # select a single step to collect features for
            X, y = prepare_data(opts)
            features = {f"x_{t}": X, "y": y}
            save_file(features, f"clf_features_20_{t}.safetensors")
            del features
            gc.collect()  # free the large tensors before the next timestep
    else:  # mode 2: collect out_block activations dataset for SAE training
        print("[COLLECTING FEATURES FOR SAE]")
        target_blocks = opts["blocks"].copy()
        block_sizes = [
            [512, 32, 32]
        ]  # IMPORTANT: fill this manually to match the out-sizes of the desired blocks
        features = dict()
        for b, out_size in zip(target_blocks, block_sizes):
            print(f"[WORKING WITH BLOCK={b}]")
            opts["blocks"] = [b]
            opts["dim"] = out_size
            X = prepare_sae_data(opts)
            features[f"{opts['best_t']}_{b}"] = X
        save_file(features, "sae_features_train.safetensors")
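
    # A minimal sketch of reading the saved features back, assuming the default
    # output names above. safe_open loads tensors lazily, so a single key can
    # be fetched without materializing the whole file:
    #
    #     from safetensors import safe_open
    #
    #     with safe_open("clf_features_20_150.safetensors", framework="pt") as f:
    #         X = f.get_tensor("x_150")
    #         y = f.get_tensor("y")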