Merge pull request #3 from msmk0/zipfile
Add support for loading dataset directly from a zipfile
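
A minimal usage sketch of the new behaviour (not part of this commit; the archive name train_sample.zip is a placeholder, and a plain directory path still works as before):

    from trackml.dataset import load_dataset

    # Events are read directly from the archive members, no extraction step.
    # 'train_sample.zip' is a hypothetical file containing event*-*.csv members.
    for event_id, hits, cells, particles, truth in load_dataset('train_sample.zip'):
        print(event_id, len(hits))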
dhrou authored Apr 23, 2018
2 parents c4d82d3 + efaf163 commit 78b71a3
Showing 1 changed file with 69 additions and 33 deletions.
trackml/dataset.py
@@ -3,10 +3,19 @@
 __authors__ = ['Moritz Kiehn', 'Sabrina Amrouche']
 
 import glob
+import os
 import os.path as op
+import re
+import zipfile
 
 import pandas
 
+CELLS_DTYPES = dict([
+    ('hit_id', 'i4'),
+    ('ch0', 'i4'),
+    ('ch1', 'i4'),
+    ('value', 'f4'),
+])
 HITS_DTYPES = dict([
     ('hit_id', 'i4'),
     ('x', 'f4'),
@@ -16,12 +25,6 @@
     ('layer_id', 'i4'),
     ('module_id', 'i4'),
 ])
-CELLS_DTYPES = dict([
-    ('hit_id', 'i4'),
-    ('ch0', 'i4'),
-    ('ch1', 'i4'),
-    ('value', 'f4'),
-])
 PARTICLES_DTYPES = dict([
     ('particle_id', 'i8'),
     ('vx', 'f4'),
@@ -44,12 +47,20 @@
     ('tpz', 'f4'),
     ('weight', 'f4'),
 ])
 
-def _load_event_data(prefix, name, dtype):
+DTYPES = {
+    'cells': CELLS_DTYPES,
+    'hits': HITS_DTYPES,
+    'particles': PARTICLES_DTYPES,
+    'truth': TRUTH_DTYPES,
+}
+DEFAULT_PARTS = ['hits', 'cells', 'particles', 'truth']
+
+def _load_event_data(prefix, name):
     """Load per-event data for one single type, e.g. hits, or particles.
     """
     expr = '{!s}-{}.csv*'.format(prefix, name)
     files = glob.glob(expr)
+    dtype = DTYPES[name]
     if len(files) == 1:
         return pandas.read_csv(files[0], header=0, index_col=False, dtype=dtype)
     elif len(files) == 0:
@@ -60,30 +71,24 @@ def _load_event_data(prefix, name, dtype):
 def load_event_hits(prefix):
     """Load the hits information for a single event with the given prefix.
     """
-    return _load_event_data(prefix, 'hits', HITS_DTYPES)
+    return _load_event_data(prefix, 'hits')
 
 def load_event_cells(prefix):
     """Load the hit cells information for a single event with the given prefix.
     """
-    return _load_event_data(prefix, 'cells', CELLS_DTYPES)
+    return _load_event_data(prefix, 'cells')
 
 def load_event_particles(prefix):
     """Load the particles information for a single event with the given prefix.
     """
-    return _load_event_data(prefix, 'particles', PARTICLES_DTYPES)
+    return _load_event_data(prefix, 'particles')
 
 def load_event_truth(prefix):
     """Load only the truth information for a single event with the given prefix.
     """
-    return _load_event_data(prefix, 'truth', TRUTH_DTYPES)
-
-_LOAD_FUNCTIONS = {
-    'hits': load_event_hits,
-    'cells': load_event_cells,
-    'particles': load_event_particles,
-    'truth': load_event_truth, }
+    return _load_event_data(prefix, 'truth')
 
-def load_event(prefix, parts=['hits', 'cells', 'particles', 'truth']):
+def load_event(prefix, parts=DEFAULT_PARTS):
     """Load data for a single event with the given prefix.
 
     Parameters
@@ -100,15 +105,15 @@ def load_event(prefix, parts=['hits', 'cells', 'particles', 'truth']):
         element has field names identical to the CSV column names with
         appropriate types.
     """
-    return tuple(_LOAD_FUNCTIONS[_](prefix) for _ in parts)
+    return tuple(_load_event_data(prefix, name) for name in parts)
 
-def load_dataset(path, skip=None, nevents=None, **kw):
-    """Provide an iterator over (all) events in a dataset directory.
+def load_dataset(path, skip=None, nevents=None, parts=DEFAULT_PARTS):
+    """Provide an iterator over (all) events in a dataset.
 
     Parameters
     ----------
     path : str or pathlib.Path
-        Path to the dataset directory.
+        Path to a directory or a zip file containing event files.
     skip : int, optional
         Skip the first `skip` events.
     nevents : int, optional
@@ -123,13 +128,44 @@ def load_dataset(path, skip=None, nevents=None, **kw):
     *data
         Event data element as specified in `parts`.
     """
-    files = glob.glob(op.join(path, 'event*-*'))
-    names = set(op.basename(_).split('-', 1)[0] for _ in files)
-    names = sorted(names)
-    if skip is not None:
-        names = names[skip:]
-    if nevents is not None:
-        names = names[:nevents]
-    for name in names:
-        event_id = int(name[5:])
-        yield (event_id,) + load_event(op.join(path, name), **kw)
+    # extract a sorted list of event file prefixes.
+    def list_prefixes(files):
+        regex = re.compile(r'^event\d{9}-[a-zA-Z]+\.csv')
+        files = filter(regex.match, files)
+        prefixes = set(op.basename(_).split('-', 1)[0] for _ in files)
+        prefixes = sorted(prefixes)
+        if skip is not None:
+            prefixes = prefixes[skip:]
+        if nevents is not None:
+            prefixes = prefixes[:nevents]
+        return prefixes
+
+    # TODO use yield from when we increase the python requirement
+    if op.isdir(path):
+        for x in _iter_dataset_dir(path, list_prefixes(os.listdir(path)), parts):
+            yield x
+    else:
+        with zipfile.ZipFile(path, mode='r') as z:
+            for x in _iter_dataset_zip(z, list_prefixes(z.namelist()), parts):
+                yield x
+
+def _extract_event_id(prefix):
+    """Extract event_id from prefix, e.g. event_id=1 from `event000000001`.
+    """
+    return int(prefix[5:])
+
+def _iter_dataset_dir(directory, prefixes, parts):
+    """Iterate over selected event files inside a directory.
+    """
+    for p in prefixes:
+        yield (_extract_event_id(p),) + load_event(op.join(directory, p), parts)
+
+def _iter_dataset_zip(zipfile, prefixes, parts):
+    """Iterate over selected event files inside a zip archive.
+    """
+    for p in prefixes:
+        files = [zipfile.open('{}-{}.csv'.format(p, _), mode='r') for _ in parts]
+        dtypes = [DTYPES[_] for _ in parts]
+        data = tuple(pandas.read_csv(f, header=0, index_col=False, dtype=d)
+                     for f, d in zip(files, dtypes))
+        yield (_extract_event_id(p),) + data
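
For reference, a sketch of how the reworked pieces compose after this change (untested illustration; the paths and event counts are placeholders, and parts entries must be among 'hits', 'cells', 'particles', 'truth'):

    from trackml.dataset import load_event, load_dataset

    # Load only selected parts of one event by file prefix; the returned
    # tuple follows the order of the parts argument.
    hits, cells = load_event('train_sample/event000000001', parts=['hits', 'cells'])

    # skip/nevents window the sorted prefix list before any file is opened,
    # and the same call works for a directory or a zip archive.
    for event_id, hits, truth in load_dataset('train_sample.zip', skip=10,
                                              nevents=5, parts=['hits', 'truth']):
        print(event_id, len(hits), len(truth))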
