forked from naderAsadi/AML
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
136 lines (105 loc) · 4.08 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import os
import copy
import numpy as np
from collections import OrderedDict as OD
from collections import defaultdict as DD
import torch
import torch.nn as nn
import torch.nn.functional as F
''' LOG '''
def logging_per_task(wandb, log, mode, metric, task=0, task_t=0, value=0):
    """Store ``value`` for ``metric`` under ``mode`` in ``log``.

    Metrics whose name contains ``'final'`` hold a single scalar and are
    overwritten directly. Every other metric is written at index
    ``[task_t, task]`` of the container already stored under that name.
    Only the "final" scalars are forwarded to ``wandb`` (when it is not None),
    as ``{mode + metric: value}``.
    """
    mode_log = log[mode]
    if 'final' not in metric:
        mode_log[metric][task_t, task] = value
        return

    mode_log[metric] = value
    if wandb is not None:
        wandb.log({mode + metric: value})
def print_(log, mode, task):
    """Print a one-line summary of the mean of each accuracy-like metric.

    ``log`` maps metric names to sequences of values. Empty sequences are
    skipped. For each remaining metric the mean is computed. Only metrics
    whose name contains ``'acc'`` or ``'gen'`` are printed, with the name
    left-justified to 12 characters and the mean to 4 decimals.
    """
    pieces = [mode + ' ' + str(task) + ' ']
    for metric_name, values in log.items():
        count = len(values)
        if count == 0:
            continue
        # The mean is computed before the name filter, as in the original.
        mean = sum(values) / count
        if 'acc' not in metric_name and 'gen' not in metric_name:
            continue
        pieces.append('{}\t {:.4f}\t'.format(metric_name.ljust(12), mean))
    print(''.join(pieces))
def get_logger(names, n_tasks=None):
    """Build the nested metric log for the train/valid/test splits.

    Returns a dict keyed by ``'train'``, ``'valid'`` and ``'test'``. Each value
    maps every name in ``names`` to its own ``np.zeros([n_tasks, n_tasks])``
    array, followed by the scalars ``'final_acc'`` and ``'final_forget'``
    (both ``0.``).

    Args:
        names: iterable of metric names. It is materialised once, so one-shot
            iterators (e.g. generators) populate all three modes.
        n_tasks (int): side length of each per-metric matrix. Required whenever
            ``names`` is non-empty.

    Returns:
        dict: ``{mode: {name: ndarray, 'final_acc': 0., 'final_forget': 0.}}``.

    Raises:
        TypeError: if ``names`` is non-empty and ``n_tasks`` is None.
    """
    # Materialise once; iterating a generator per mode would leave
    # 'valid' and 'test' without any metric matrices.
    names = list(names)
    if names and n_tasks is None:
        raise TypeError('get_logger: n_tasks is required when names is non-empty')

    log = {}
    for mode in ['train', 'valid', 'test']:
        # A fresh array per (mode, name), so entries never alias each other.
        mode_log = {name: np.zeros([n_tasks, n_tasks]) for name in names}
        mode_log['final_acc'] = 0.
        mode_log['final_forget'] = 0.
        log[mode] = mode_log
    return log
def get_confirm_token(response):
    """Return the value of the first cookie whose key starts with ``download_warning``.

    Cookies are scanned in ``response.cookies.items()`` order. Returns None
    when no such cookie exists.
    """
    warning_values = (
        cookie_value
        for cookie_key, cookie_value in response.cookies.items()
        if cookie_key.startswith('download_warning')
    )
    return next(warning_values, None)
def save_response_content(response, destination):
    """Write the body of ``response`` to the file ``destination``.

    The body is pulled via ``response.iter_content(32768)`` and written in
    binary mode. Falsy chunks (such as empty keep-alive chunks) are skipped.
    """
    chunk_bytes = 32768
    with open(destination, "wb") as sink:
        for piece in filter(None, response.iter_content(chunk_bytes)):
            sink.write(piece)
def get_temp_logger(exp, names):
    """Create an ordered log mapping each name in ``names`` to a new empty list.

    The returned ``OrderedDict`` also carries a ``print_`` attribute, so that
    ``log.print_(mode, task)`` calls the module-level ``print_`` on this log.
    ``exp`` is accepted but unused.
    """
    log = OD((name, []) for name in names)
    log.print_ = lambda a, b: print_(log, a, b)
    return log
import collections
import numpy as np
import torch
def sho_(x, nrow=8):
    """Save ``x`` as an image grid to ``tmp.png`` and open it in a viewer.

    ``x`` is first rescaled by ``x * .5 + .5`` (presumably mapping [-1, 1]
    to [0, 1] -- confirm with callers). A 5-D ``x`` has its first two
    dimensions merged into one batch, and ``nrow`` becomes ``x.size(1)``.
    """
    x = x * .5 + .5
    from torchvision.utils import save_image
    from PIL import Image

    if x.ndim != 5:
        images, per_row = x, nrow
    else:
        per_row = x.size(1)
        images = x.reshape(-1, *x.shape[2:])

    save_image(images, 'tmp.png', nrow=per_row)
    Image.open('tmp.png').show()
# https://github.com/tristandeleu/pytorch-meta/
# from torchvision.datasets.utils import _get_confirm_token, _save_response_content
def _quota_exceeded(response: "requests.models.Response"):
return False
# See https://github.com/pytorch/vision/issues/2992 for details
# return "Google Drive - Quota exceeded" in response.text
def download_file_from_google_drive(file_id, root, filename=None, md5=None):
    """Download a Google Drive file and place it in ``root``.

    If ``root/filename`` already exists, it is reused as-is and nothing is
    downloaded.

    Args:
        file_id (str): id of file to be downloaded
        root (str): Directory to place downloaded file in (``~`` is expanded;
            the directory is created if missing).
        filename (str, optional): Name to save the file under. If None (or
            empty), use the id of the file.
        md5 (str, optional): Expected MD5 checksum of the download.
            NOTE: currently ignored -- integrity checking is still a TODO.

    Raises:
        RuntimeError: if ``_quota_exceeded`` reports that the file's daily
            Google Drive quota is exhausted.
    """
    root = os.path.expanduser(root)
    if not filename:
        filename = file_id
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    # TODO: also require check_integrity(fpath, md5) once that helper is resolved.
    if os.path.isfile(fpath):
        print('Using downloaded and verified file: ' + fpath)
        return

    # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
    # Imported lazily so the cached-file path above works without `requests`.
    import requests

    url = "https://docs.google.com/uc?export=download"

    # The context manager closes the session's pooled connections when done.
    with requests.Session() as session:
        response = session.get(url, params={'id': file_id}, stream=True)

        # If Drive set a 'download_warning' cookie, repeat the request passing
        # that cookie's value as the 'confirm' parameter.
        token = get_confirm_token(response)
        if token:
            params = {'id': file_id, 'confirm': token}
            response = session.get(url, params=params, stream=True)

        if _quota_exceeded(response):
            msg = (
                f"The daily quota of the file {filename} is exceeded and it "
                f"can't be downloaded. This is a limitation of Google Drive "
                f"and can only be overcome by trying again later."
            )
            raise RuntimeError(msg)

        save_response_content(response, fpath)