-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
76 lines (64 loc) · 1.87 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# from: https://github.com/karpathy/arxiv-sanity-preserver/blob/master/utils.py
import errno
import os
import pickle
import re
import tempfile
from contextlib import contextmanager
# Context managers for atomic writes courtesy of
# http://stackoverflow.com/questions/2333872/atomic-writing-to-file-with-python
@contextmanager
def _tempfile(*args, **kws):
    """Context for a temporary file.

    Creates a fresh file with :func:`tempfile.mkstemp` on entry, closes its
    OS-level descriptor right away and yields only the path. On exit the
    file is deleted. If the file is already gone at that point, no error is
    raised; for example, :func:`open_atomic` may have moved it into place.

    Parameters
    ----------
    *args, **kws
        Forwarded unchanged to :func:`tempfile.mkstemp` (e.g. ``suffix``,
        ``prefix``, ``dir``).

    Yields
    ------
    str
        Path of the temporary file.

    Raises
    ------
    OSError
        If deleting the file fails for any reason other than the file no
        longer existing.
    """
    fd, name = tempfile.mkstemp(*args, **kws)
    # Only the name is handed out; callers reopen the path themselves.
    os.close(fd)
    try:
        yield name
    finally:
        try:
            os.remove(name)
        except OSError as e:
            # ENOENT: already moved/removed, so there is nothing to clean up.
            # Anything else is a real failure; bare `raise` keeps the traceback.
            if e.errno != errno.ENOENT:
                raise
@contextmanager
def open_atomic(filepath, *args, **kwargs):
    """Open a temporary file that atomically replaces *filepath* on exit.

    Writes go to a temporary file created in the same directory as
    *filepath*. When the ``with`` body finishes without raising, the file is
    closed and moved onto *filepath* with :func:`os.replace`. That call
    atomically overwrites an existing destination, also on Windows, where
    :func:`os.rename` refuses to overwrite. If the body raises, the temporary
    file is discarded and *filepath* is left untouched. Because the
    destination is only swapped at the very end, the same filename can still
    be read while the new contents are being written.

    NOTE(review): :func:`tempfile.mkstemp` creates the file readable/writable
    by the owner only, and the destination inherits those permissions --
    confirm that is acceptable for callers.

    Parameters
    ----------
    filepath : str
        Destination path.
    *args
        Positional arguments for :func:`open` (typically the mode, e.g.
        ``'wb'``).
    fsync : bool, optional
        Popped from the keyword arguments. If true, the data is flushed and
        ``os.fsync``-ed before the move. Defaults to ``False``.
    **kwargs
        Any other keyword arguments valid for :func:`open`.

    Yields
    ------
    file object
        The open temporary file.
    """
    fsync = kwargs.pop('fsync', False)
    # Same directory as the destination so the final move never crosses
    # filesystems (cross-device renames are not atomic / not allowed).
    with _tempfile(dir=os.path.dirname(filepath)) as tmppath:
        with open(tmppath, *args, **kwargs) as f:
            yield f
            if fsync:
                f.flush()
                os.fsync(f.fileno())
        # The file is closed at this point; swap it into place. _tempfile's
        # cleanup then finds the path gone, which it tolerates.
        os.replace(tmppath, filepath)
def safe_pickle_dump(obj, fname):
    """Pickle *obj* into *fname* via :func:`open_atomic`.

    The pickle is written to a temporary file that is moved onto *fname*
    only after pickling completes. If pickling raises, *fname* keeps its
    previous contents. The highest available pickle protocol is used.
    """
    with open_atomic(fname, 'wb') as handle:
        pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)
# arxiv utils
# -----------------------------------------------------------------------------
def strip_version(idstr):
    """Return *idstr* with any trailing arXiv version suffix removed.

    ``'1511.08198v1'`` becomes ``'1511.08198'``. An id without a version is
    returned unchanged. Only a final ``v<digits>`` is stripped. Old-style ids
    whose archive name itself contains a ``v`` are therefore preserved
    instead of being cut at the first ``v``; for example,
    ``'solv-int/9901001v2'`` becomes ``'solv-int/9901001'``.
    """
    return re.sub(r'v\d+\Z', '', idstr)
# "1511.08198v1" is an example of a valid arxiv id that we accept
def isvalidid(pid):
    """Check whether *pid* has the new-style arXiv id shape.

    Accepts ``<digits>.<digits>`` with an optional ``v<digits>`` version
    suffix, e.g. ``'1511.08198'`` or ``'1511.08198v1'``.

    Returns the match object (truthy) on success and ``None`` otherwise, so
    the result can be used directly in a boolean context.
    """
    # Raw string: '\d' / '\.' in a plain literal are invalid escapes that warn
    # on modern Python. '\Z' (unlike '$') rejects a trailing newline.
    return re.match(r'^\d+\.\d+(v\d+)?\Z', pid)