baselines.py

import argparse
import os
from pathlib import Path

import torch

from baselines.apply_filter import apply_filter

BASELINES = {
    "no_filter",
    "basic_filter",
    "text_based",
    "image_based",
    "image_based_intersect_clip_score",
    "clip_score",
    "laion2b",
}

ARCH = {
    "b32",
    "l14",
}

CLUSTER_CENTROID_SCALES = [
    "small",
    "medium",
    "large",
    "xlarge",
]

def check_args(args):
    if args.name not in BASELINES:
        raise ValueError(f"--name must be in: {BASELINES}")

    # laion2b checks
    if args.name == "laion2b":
        if (
            args.fraction is not None
            or args.threshold is not None
            or args.arch is not None
            or args.image_based_scale is not None
        ):
            raise ValueError("laion2b does not support clip_score or image_based flags")

    # clip_score checks
    if "clip_score" in args.name:
        if args.fraction is None and args.threshold is None:
            raise ValueError(
                "--fraction or --threshold must be specified for clip_score baselines"
            )
        if args.fraction is not None and args.threshold is not None:
            raise ValueError(
                "specify either --fraction or --threshold for clip_score baselines, but not both"
            )
        if args.arch is None:
            raise ValueError(f"specify an architecture from {ARCH} for clip_score baselines")
    if args.fraction is not None and "clip_score" not in args.name:
        raise ValueError("--fraction is only used for clip_score baselines")
    if args.threshold is not None and "clip_score" not in args.name:
        raise ValueError("--threshold is only used for clip_score baselines")
    if args.arch is not None and "clip_score" not in args.name:
        raise ValueError("--arch is only used for clip_score baselines")

    # image_based checks
    if args.image_based_scale is None and "image_based" in args.name:
        raise ValueError(
            "--image_based_scale must be passed for image_based and image_based_intersect_clip_score baselines (for clustering)"
        )
    if args.image_based_scale is not None and "image_based" not in args.name:
        raise ValueError(
            "--image_based_scale should only be passed for image_based and image_based_intersect_clip_score baselines (for clustering)"
        )
    if "image_based" in args.name and not torch.cuda.is_available():
        raise ValueError(
            "gpus are needed for image_based baselines; torch.cuda.is_available() must return True"
        )

    # make sure the directory that will hold the output .npy exists
    npy_parent = Path(args.save_path).parent
    if not os.path.exists(npy_parent):
        print(f"creating: {npy_parent}")
        os.makedirs(npy_parent, exist_ok=True)
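
# Illustrative flag combinations that pass check_args (the values here are
# examples only, not from the repo; --metadata_dir and --save_path are also
# required in every case, and image_based baselines additionally need a gpu):
#   --name no_filter
#   --name laion2b
#   --name clip_score --arch b32 --fraction 0.3
#   --name clip_score --arch l14 --threshold 0.25
#   --name image_based --image_based_scale medium
#   --name image_based_intersect_clip_score --arch l14 --fraction 0.3 --image_based_scale medium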

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="This is a command line script for reproducing the main DataComp filtering baselines. The output of the script is a numpy file (.npy) containing the uids of the filtered subset in sorted binary format. Please see README.md for additional information."
    )
    parser.add_argument(
        "--name",
        type=str,
        required=True,
        choices=list(BASELINES),
        help="name of the baseline",
    )
    parser.add_argument(
        "--metadata_dir",
        type=str,
        required=True,
        help="directory (local or cloud) containing parquet and npz metadata",
    )
    parser.add_argument(
        "--save_path",
        type=str,
        required=True,
        help="path to the output .npy; note: cloud paths are not supported for this arg",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        required=False,
        default=None,
        help="a threshold to apply to the CLIP score (e.g., a value of 0.25 will only keep examples with CLIP score over 0.25)",
    )
    parser.add_argument(
        "--fraction",
        type=float,
        required=False,
        default=None,
        help="a fraction of the metadata to keep according to CLIP score (e.g., a value of 0.25 will keep the top 25 percent of examples in the pool by CLIP score)",
    )
    parser.add_argument(
        "--arch",
        type=str,
        required=False,
        choices=list(ARCH),
        help="kind of features (b32 or l14) on which to run the CLIP score filter",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        required=False,
        default=os.cpu_count(),
        help="number of workers, generally set to the number of cpu cores (workers read their metadata files and filter them in parallel)",
    )
    parser.add_argument(
        "--num_gpus",
        type=int,
        required=False,
        default=torch.cuda.device_count(),
        help="number of gpus for the image_based gpu implementation; num_gpus metadata files are processed in parallel, one per gpu worker. NOTE: this parameter is ignored for non-image_based baselines",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        required=False,
        default=1024,
        help="batch size for the image_based gpu implementation. NOTE: this parameter is ignored for non-image_based baselines",
    )
    parser.add_argument(
        "--image_based_scale",
        type=str,
        required=False,
        choices=CLUSTER_CENTROID_SCALES,
        default=None,
        help="datacomp scale, used for the clustering baselines",
    )

    args = parser.parse_args()

    # all error checking happens here and apply_filter assumes correct input
    check_args(args)

    # route the args to the correct baseline
    apply_filter(args)
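
# A minimal usage sketch (the paths and values below are hypothetical, for
# illustration only):
#
#   python baselines.py --name clip_score --arch l14 --fraction 0.3 \
#       --metadata_dir ./metadata --save_path ./subsets/clip_score_l14_top30.npy
#
# The resulting .npy of uids can then be loaded for inspection:
#
#   import numpy as np
#   uids = np.load("./subsets/clip_score_l14_top30.npy")  # hypothetical path
#   print(len(uids))  # number of retained samples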