Merge pull request #350 from KantaTamura/cache-profile
Support PPE profiling for FileCache / MultiprocessFileCache
k5342 authored Oct 17, 2024
2 parents 96252c0 + 7461588 commit 09e15ed
Showing 7 changed files with 137 additions and 74 deletions.
File renamed without changes.
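
For orientation before the diffs, here is a minimal usage sketch of the new trace flag, mirroring the test added at the bottom of this change (the FileCache import path is assumed to be pfio.cache):

import json
import tempfile

import pytorch_pfn_extras as ppe
from pfio.cache import FileCache

with tempfile.TemporaryDirectory() as d:
    ppe.profiler.clear_tracer()

    # trace=True makes cache operations emit profiling events through
    # the pytorch_pfn_extras (PPE) tracer.
    cache = FileCache(1, dir=d, do_pickle=True, trace=True)
    cache.put(0, b"foo")
    assert cache.get(0) == b"foo"

    # Event names can be read back from the tracer state, exactly as
    # the new test at the bottom of this diff does.
    state = ppe.profiler.get_tracer().state_dict()
    names = [ev["name"] for ev in json.loads(state["_event_list"])]
    assert "pfio.cache.file:put" in names
    assert "pfio.cache.file:get" in names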
pfio/cache/file_cache.py: 38 additions & 28 deletions

@@ -9,6 +9,7 @@
 from struct import calcsize, pack, unpack
 
 from pfio import cache
+from pfio._profiler import record
 
 # Deprecated, but leaving for backward compatibility just in case any
 # system directly using this value
@@ -152,10 +153,11 @@ class FileCache(cache.Cache):
     '''
 
     def __init__(self, length, multithread_safe=False, do_pickle=False,
-                 dir=None, cache_size_limit=None, verbose=False):
+                 dir=None, cache_size_limit=None, verbose=False, trace=False):
         self._multithread_safe = multithread_safe
         self.length = length
         self.do_pickle = do_pickle
+        self.trace = trace
         if self.length <= 0 or (2 ** 64) <= self.length:
             raise ValueError("length has to be between 0 and 2^64")
 
@@ -217,45 +219,50 @@ def multithread_safe(self):
         return self._multithread_safe
 
     def get(self, i):
-        if self.closed:
-            return
-        data = self._get(i)
-        if self.do_pickle and data:
-            data = pickle.loads(data)
-        return data
+        with record("pfio.cache.file:get", trace=self.trace):
+            if self.closed:
+                return
+            data = self._get(i)
+            if self.do_pickle and data:
+                data = pickle.loads(data)
+            return data
 
     def _get(self, i):
         if i < 0 or self.length <= i:
             raise IndexError("index {} out of range ([0, {}])"
                              .format(i, self.length - 1))
 
         offset = self.buflen * i
-        with self.lock.rdlock():
-            buf = os.pread(self.cachefp.fileno(), self.buflen, offset)
+        with self.lock.rdlock(), record("pfio.cache.file:get:lock", trace=self.trace):
+            with record("pfio.cache.file:get:read_index", trace=self.trace):
+                buf = os.pread(self.cachefp.fileno(), self.buflen, offset)
             (o, l) = unpack('Qq', buf)
             if l < 0 or o < 0:
                 return None
 
-            data = os.pread(self.cachefp.fileno(), l, o)
+            with record("pfio.cache.file:get:read_data", trace=self.trace):
+                data = os.pread(self.cachefp.fileno(), l, o)
             assert len(data) == l
             return data
 
     def put(self, i, data):
-        if self._frozen or self.closed:
-            return False
+        with record("pfio.cache.file:put", trace=self.trace):
+            if self._frozen or self.closed:
+                return False
 
-        try:
-            if self.do_pickle:
-                data = pickle.dumps(data)
-            return self._put(i, data)
+            try:
+                if self.do_pickle:
+                    data = pickle.dumps(data)
+                return self._put(i, data)
 
-        except OSError as ose:
-            # Disk full (ENOSPC) possibly by cache; just warn and keep running
-            if ose.errno == errno.ENOSPC:
-                warnings.warn(ose.strerror, RuntimeWarning)
-                return False
-            else:
-                raise ose
+            except OSError as ose:
+                # Disk full (ENOSPC) possibly by cache;
+                # just warn and keep running
+                if ose.errno == errno.ENOSPC:
+                    warnings.warn(ose.strerror, RuntimeWarning)
+                    return False
+                else:
+                    raise ose
 
     def _put(self, i, data):
         if self.closed:
@@ -270,8 +277,9 @@ def _put(self, i, data):
             return False
 
         offset = self.buflen * i
-        with self.lock.wrlock():
-            buf = os.pread(self.cachefp.fileno(), self.buflen, offset)
+        with self.lock.wrlock(), record("pfio.cache.file:put:lock", trace=self.trace):
+            with record("pfio.cache.file:put:read_index", trace=self.trace):
+                buf = os.pread(self.cachefp.fileno(), self.buflen, offset)
             (o, l) = unpack('Qq', buf)
             if l >= 0 and o >= 0:
                 # Already data exists
@@ -296,13 +304,15 @@
             '''
             buf = pack('Qq', pos, len(data))
-            r = os.pwrite(self.cachefp.fileno(), buf, offset)
+            with record("pfio.cache.file:put:write_index", trace=self.trace):
+                r = os.pwrite(self.cachefp.fileno(), buf, offset)
             assert r == self.buflen
 
             current_pos = pos
             while current_pos - pos < len(data):
-                r = os.pwrite(self.cachefp.fileno(),
-                              data[current_pos-pos:], current_pos)
+                with record("pfio.cache.file:put:write_data", trace=self.trace):
+                    r = os.pwrite(self.cachefp.fileno(),
+                                  data[current_pos-pos:], current_pos)
                 assert r > 0
                 current_pos += r
             assert current_pos - pos == len(data)
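
The record helper imported from pfio._profiler above is used as a context manager that takes an event tag and a trace switch. As a rough sketch of the contract these call sites rely on (an illustrative assumption, not the actual pfio._profiler source), it could look like:

from contextlib import contextmanager


@contextmanager
def record(tag, trace=False):
    # With tracing disabled this must stay a cheap no-op, since get/put
    # sit on the data-loading hot path.
    if not trace:
        yield
        return
    try:
        import pytorch_pfn_extras as ppe
    except ImportError:
        yield
        return
    # Delegate to PPE; passing trace=True here is an assumption about how
    # events reach ppe.profiler.get_tracer(), which the new test reads back.
    with ppe.profiler.record(tag, trace=True):
        yield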
pfio/cache/multiprocess_file_cache.py: 62 additions & 43 deletions

@@ -8,6 +8,7 @@
 from struct import calcsize, pack, unpack
 
 from pfio import cache
+from pfio._profiler import record
 from pfio.cache.file_cache import _check_local, _default_cache_path
 
@@ -139,10 +140,12 @@ def __init__(self, image_paths):
     ''' # NOQA
 
     def __init__(self, length, do_pickle=False,
-                 dir=None, cache_size_limit=None, verbose=False):
+                 dir=None, cache_size_limit=None,
+                 verbose=False, trace=False):
         self.length = length
         self.do_pickle = do_pickle
         self.verbose = verbose
+        self.trace = trace
         if self.length <= 0 or (2 ** 64) <= self.length:
             raise ValueError("length has to be between 0 and 2^64")
 
@@ -207,12 +210,13 @@ def multithread_safe(self) -> bool:
         return True
 
     def get(self, i):
-        if self.closed:
-            return
-        data = self._get(i)
-        if self.do_pickle and data:
-            data = pickle.loads(data)
-        return data
+        with record("pfio.cache.multiprocessfile:get", trace=self.trace):
+            if self.closed:
+                return
+            data = self._get(i)
+            if self.do_pickle and data:
+                data = pickle.loads(data)
+            return data
 
     def _open_fds(self):
         pid = os.getpid()
@@ -227,34 +231,41 @@ def _get(self, i):
 
         self._open_fds()
         offset = self.buflen * i
 
         fcntl.flock(self.cache_fd, fcntl.LOCK_SH)
-        index_entry = os.pread(self.cache_fd, self.buflen, offset)
-        (o, l) = unpack('Qq', index_entry)
-        if l < 0 or o < 0:
-            fcntl.flock(self.cache_fd, fcntl.LOCK_UN)
-            return None
-
-        data = os.pread(self.cache_fd, l, o)
-        assert len(data) == l
-        fcntl.flock(self.cache_fd, fcntl.LOCK_UN)
-        return data
+        with record(
+                f"pfio.cache.multiprocessfile:get:lock-{self.cache_fd}",
+                trace=self.trace):
+            with record("pfio.cache.multiprocessfile:get:read_index", trace=self.trace):
+                index_entry = os.pread(self.cache_fd, self.buflen, offset)
+            (o, l) = unpack('Qq', index_entry)
+            if l < 0 or o < 0:
+                fcntl.flock(self.cache_fd, fcntl.LOCK_UN)
+                return None
+
+            with record("pfio.cache.multiprocessfile:get:read_data", trace=self.trace):
+                data = os.pread(self.cache_fd, l, o)
+            assert len(data) == l
+            fcntl.flock(self.cache_fd, fcntl.LOCK_UN)
+            return data
 
     def put(self, i, data):
-        if self._frozen or self.closed:
-            return False
+        with record("pfio.cache.multiprocessfile:put", trace=self.trace):
+            if self._frozen or self.closed:
+                return False
 
-        try:
-            if self.do_pickle:
-                data = pickle.dumps(data)
-            return self._put(i, data)
+            try:
+                if self.do_pickle:
+                    data = pickle.dumps(data)
+                return self._put(i, data)
 
-        except OSError as ose:
-            # Disk full (ENOSPC) possibly by cache; just warn and keep running
-            if ose.errno == errno.ENOSPC:
-                warnings.warn(ose.strerror, RuntimeWarning)
-                return False
-            else:
-                raise ose
+            except OSError as ose:
+                # Disk full (ENOSPC) possibly by cache;
+                # just warn and keep running
+                if ose.errno == errno.ENOSPC:
+                    warnings.warn(ose.strerror, RuntimeWarning)
+                    return False
+                else:
+                    raise ose
 
     def _put(self, i, data):
         if self.closed:
@@ -266,27 +277,35 @@ def _put(self, i, data):
 
         self._open_fds()
 
         index_ofst = self.buflen * i
         fcntl.flock(self.cache_fd, fcntl.LOCK_EX)
-        buf = os.pread(self.cache_fd, self.buflen, index_ofst)
-        (o, l) = unpack('Qq', buf)
-
-        if l >= 0 and o >= 0:
-            # Already data exists
-            fcntl.flock(self.cache_fd, fcntl.LOCK_UN)
-            return False
-
-        data_pos = os.lseek(self.cache_fd, 0, os.SEEK_END)
-        if self.cache_size_limit:
-            if self.cache_size_limit < (data_pos + len(data)):
-                self._frozen = True
-                fcntl.flock(self.cache_fd, fcntl.LOCK_UN)
-                return False
-
-        index_entry = pack('Qq', data_pos, len(data))
-        assert os.pwrite(self.cache_fd, index_entry, index_ofst) == self.buflen
-        assert os.pwrite(self.cache_fd, data, data_pos) == len(data)
-        os.fsync(self.cache_fd)
-        fcntl.flock(self.cache_fd, fcntl.LOCK_UN)
+        with record(
+                f"pfio.cache.multiprocessfile:put:lock-{self.cache_fd}",
+                trace=self.trace):
+            with record("pfio.cache.multiprocessfile:put:read_index", trace=self.trace):
+                buf = os.pread(self.cache_fd, self.buflen, index_ofst)
+            (o, l) = unpack('Qq', buf)
+
+            if l >= 0 and o >= 0:
+                # Already data exists
+                fcntl.flock(self.cache_fd, fcntl.LOCK_UN)
+                return False
+
+            with record("pfio.cache.multiprocessfile:put:seek", trace=self.trace):
+                data_pos = os.lseek(self.cache_fd, 0, os.SEEK_END)
+            if self.cache_size_limit:
+                if self.cache_size_limit < (data_pos + len(data)):
+                    self._frozen = True
+                    fcntl.flock(self.cache_fd, fcntl.LOCK_UN)
+                    return False
+
+            index_entry = pack('Qq', data_pos, len(data))
+            with record("pfio.cache.multiprocessfile:put:write_index", trace=self.trace):
+                assert os.pwrite(self.cache_fd, index_entry, index_ofst) == self.buflen
+            with record("pfio.cache.multiprocessfile:put:write_data", trace=self.trace):
+                assert os.pwrite(self.cache_fd, data, data_pos) == len(data)
+            with record("pfio.cache.multiprocessfile:put:sync", trace=self.trace):
+                os.fsync(self.cache_fd)
+            fcntl.flock(self.cache_fd, fcntl.LOCK_UN)
 
         return True
 
     def __enter__(self):
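
Note that the lock spans carry the cache file descriptor number in their tag (lock-{self.cache_fd}), so time spent under the flock is attributed per descriptor. A small sketch of aggregating the recorded events per tag; the "name" field is confirmed by the test below, while reading durations from a "dur" field is an assumption about the Chrome-trace-style event layout:

import json
from collections import defaultdict

import pytorch_pfn_extras as ppe

events = json.loads(ppe.profiler.get_tracer().state_dict()["_event_list"])

# Sum time per tag, e.g. to compare read_index vs. read_data vs. sync cost.
totals = defaultdict(int)
for ev in events:
    totals[ev["name"]] += ev.get("dur", 0)

for name, total in sorted(totals.items()):
    print(name, total)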
pfio/v2/local.py: 2 additions & 1 deletion

@@ -4,7 +4,8 @@
 import shutil
 from typing import Optional
 
-from ._profiler import record, record_iterable
+from pfio._profiler import record, record_iterable
+
 from .fs import FS, FileStat, format_repr
pfio/v2/s3.py: 2 additions & 1 deletion

@@ -10,7 +10,8 @@
 from botocore.config import Config
 from botocore.exceptions import ClientError
 
-from ._profiler import record, record_iterable
+from pfio._profiler import record, record_iterable
+
 from .fs import FS, FileStat, format_repr
 
 DEFAULT_MAX_BUFFER_SIZE = 16 * 1024 * 1024
pfio/v2/zip.py: 2 additions & 1 deletion

@@ -5,7 +5,8 @@
 from datetime import datetime
 from typing import Optional, Set
 
-from ._profiler import record, record_iterable
+from pfio._profiler import record, record_iterable
+
 from .fs import FS, FileStat, format_repr
 
 logger = logging.getLogger(__name__)
tests/cache_tests/test_cache.py: 31 additions & 0 deletions

@@ -1,4 +1,5 @@
 import hashlib
+import json
 import os
 import pickle
 import random
@@ -438,3 +439,33 @@ def test_default_cache_path(test_class):
     finally:
         if orig is not None:
             os.environ['XDG_CACHE_HOME'] = orig
+
+
+def test_filecache_profiling():
+    ppe = pytest.importorskip("pytorch_pfn_extras")
+
+    with tempfile.TemporaryDirectory() as d:
+        ppe.profiler.clear_tracer()
+
+        cache = FileCache(1, dir=d, do_pickle=True, trace=True)
+        cache.put(0, b"foo")
+        assert b"foo" == cache.get(0)
+
+        dict = ppe.profiler.get_tracer().state_dict()
+        keys = [event["name"] for event in json.loads(dict['_event_list'])]
+
+        assert "pfio.cache.file:put" in keys
+        assert "pfio.cache.file:get" in keys
+
+    with tempfile.TemporaryDirectory() as d:
+        ppe.profiler.clear_tracer()
+
+        cache = MultiprocessFileCache(1, dir=d, do_pickle=True, trace=True)
+        cache.put(0, b"foo")
+        assert b"foo" == cache.get(0)
+
+        dict = ppe.profiler.get_tracer().state_dict()
+        keys = [event["name"] for event in json.loads(dict['_event_list'])]
+
+        assert "pfio.cache.multiprocessfile:put" in keys
+        assert "pfio.cache.multiprocessfile:get" in keys

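The new test can be run on its own; it skips itself via importorskip when pytorch_pfn_extras is not installed:

python -m pytest tests/cache_tests/test_cache.py -k test_filecache_profiling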