-
Notifications
You must be signed in to change notification settings - Fork 55
/
Copy pathbatch_download.py
66 lines (54 loc) · 1.74 KB
/
batch_download.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import time
import argparse
from model_download import download_model_retry
from model_mirror import make_mirror
def readTaskFromList(fn):
    """Find the first pending task in a task-list file.

    The list file holds one model id per line. A line whose first character
    is ``*`` has already been processed (``writeFlagToList`` prepends that
    marker) and is skipped. Lines that are empty or contain only whitespace
    are also skipped, because they name no model.

    Args:
        fn: Path to the task-list file.

    Returns:
        A tuple ``(x, z, model_id)``:
        - ``x`` is the 0-based line index of the first pending task, or -1
          if there is none.
        - ``z`` is the total number of lines in the file, including flagged
          and blank lines.
        - ``model_id`` is line ``x`` with surrounding whitespace stripped, or
          ``""`` when there is no pending task.
    """
    with open(fn, 'r') as f:
        lines = f.readlines()
    for i, line in enumerate(lines):
        if line.startswith("*"):
            continue
        model_id = line.strip()
        # A blank line (commonly a trailing empty line) is not a task.
        # Returning it would make the caller process an empty model id and
        # then flag the blank line as done.
        if model_id:
            return i, len(lines), model_id
    return -1, len(lines), ""
def writeFlagToList(fn, x):
    """Mark line ``x`` of the task-list file as processed.

    Prepends ``*`` to the 0-based line ``x`` so that ``readTaskFromList``
    skips it on later reads. The file is re-read here instead of reusing an
    earlier snapshot, so lines added since the caller's last read are kept.

    The new contents go to a sibling ``<fn>.tmp`` file, which then replaces
    ``fn`` via ``os.replace``. An interruption while writing therefore leaves
    the original list intact instead of a truncated one.

    Args:
        fn: Path (str or path-like) of the task-list file.
        x: 0-based index of the line to flag.

    Raises:
        IndexError: If ``x`` is negative or not less than the file's line
            count.
    """
    import os

    with open(fn, 'r') as f:
        lines = f.readlines()
    # Reject negative indices explicitly. Python would otherwise accept them
    # and flag a line counted from the end of the file.
    if not 0 <= x < len(lines):
        raise IndexError(
            "line index %d out of range for %d-line file" % (x, len(lines)))
    lines[x] = "*" + lines[x]
    # Build the temp name with os.fspath so pathlib.Path arguments work too.
    tmp_fn = os.fspath(fn) + ".tmp"
    with open(tmp_fn, 'w') as f:
        f.writelines(lines)
    os.replace(tmp_fn, fn)
def downloadModelFromHg(model_id, repo_type):
    """Print a banner for ``model_id``, then download it.

    The download itself is delegated to ``download_model_retry`` (imported
    from ``model_download``), and its return value is discarded. The source
    is presumably the Hugging Face hub, going by the function name; the
    actual behaviour is defined in ``model_download``, so confirm there.
    """
    banner = "***** " + model_id + " *****"
    print(banner)
    download_model_retry(model_id, repo_type)
def mirrorModel(_root, model_id, repo_type):
    """Print a banner for ``model_id``, then mirror it under ``_root``.

    All the work is delegated to ``make_mirror`` (imported from
    ``model_mirror``), and its return value is discarded.
    """
    banner = "***** " + model_id + " *****"
    print(banner)
    make_mirror(_root, model_id, repo_type)
if __name__ == '__main__':
    # Work-queue loop. Each pass re-reads the list file and takes the first
    # unflagged line. The line is mirrored (when --mirror_root is set) or
    # downloaded, and then flagged with "*". Because the file is re-read on
    # every pass, lines added while the script runs are picked up. There is
    # no break: the loop runs until interrupted or until a call raises.
    parser = argparse.ArgumentParser()
    # --listfile has a default, so it must not also be required; with
    # required=True the default could never take effect.
    parser.add_argument(
        '--listfile', default="model_list.txt", type=str, required=False)
    parser.add_argument(
        '--repo_type', default="model", type=str, required=False)
    # Non-empty value: mirror into this root instead of downloading.
    parser.add_argument(
        '--mirror_root', default="", type=str, required=False)
    parser.add_argument(
        '--interval', default=10, type=float, required=False,
        help="seconds to sleep after each task and between idle polls")
    args = parser.parse_args()
    fn = args.listfile
    while True:
        x, z, model_id = readTaskFromList(fn)
        if x == -1:
            # Nothing pending. Printing x + 1 here would give a misleading
            # "0 of N" progress line.
            print('no pending task in ' + fn + ' (' + str(z) + ' lines)')
            time.sleep(args.interval)
            continue
        print(str(x + 1) + ' of ' + str(z))
        if args.mirror_root != "":
            mirrorModel(args.mirror_root, model_id, args.repo_type)
        else:
            downloadModelFromHg(model_id, args.repo_type)
        writeFlagToList(fn, x)
        time.sleep(args.interval)