Skip to content

Commit

Permalink
add query method
Browse files Browse the repository at this point in the history
  • Loading branch information
jreadey committed Apr 21, 2016
1 parent 81dcdd2 commit a85c6b1
Show file tree
Hide file tree
Showing 2 changed files with 227 additions and 1 deletion.
175 changes: 175 additions & 0 deletions h5json/hdf5db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2183,6 +2183,181 @@ def getDatasetValuesByUuid(self, obj_uuid, slices=Ellipsis, format="json"):
values = values.tobytes()

return values

"""
doDatasetQueryByUuid: return rows based on query string
Return rows from a dataset that matches query string.
Note: Only supported for compound_type/one-dimensional datasets
"""
def doDatasetQueryByUuid(self, obj_uuid, query, start=0, stop=-1, step=1, limit=None):
self.log.info("doQueryByUuid - uuid: " + obj_uuid + " query:" + query)
self.log.info("start: " + str(start) + " stop: " + str(stop) + " step: " + str(step) + " limit: " + str(limit))
dset = self.getDatasetObjByUuid(obj_uuid)
if dset is None:
msg = "Dataset: " + obj_uuid + " not found"
self.log.info(msg)
raise IOError(errno.ENXIO, msg)

values = []
dt = dset.dtype
typeItem = getTypeItem(dt)
itemSize = getItemSize(typeItem)
if typeItem['class'] != "H5T_COMPOUND":
msg = "Only compound type datasets can be used as query target"
self.log.info(msg)
raise IOError(errno.EINVAL, msg)

if dset.shape is None:
# null space dataset (with h5py 2.6.0)
return None

rank = len(dset.shape)
if rank != 1:
msg = "One one-dimensional datasets can be used as query target"
self.log.info(msg)
raise IOError(errno.EINVAL, msg)


values = []
indexes = []
count = 0

num_elements = dset.shape[0]
if stop == -1:
stop = num_elements
elif stop > num_elements:
stop = num_elements
block_size = self._getBlockSize(dset)
self.log.info("block_size: " + str(block_size))

field_names = list(dset.dtype.fields.keys())
eval_str = self._getEvalStr(query, field_names)

while start < stop:
if limit and (count == limit):
break # no more rows for this batch
end = start + block_size
if end > stop:
end = stop
rows = dset[start:end] # read from dataset
where_result = np.where(eval(eval_str))
index = where_result[0].tolist()
if len(index) > 0:
for i in index:
row = rows[i]
item = self.bytesArrayToList(row)
values.append(item)
indexes.append(start + i)
count += 1
if limit and (count == limit):
break # no more rows for this batch

start = end # go to next block


# values = self.getDataValue(item_type, values, dimension=1, dims=(len(values),))

self.log.info("got " + str(count) + " query matches")
return (indexes, values)

"""
_getBlockSize: Get number of rows to read from disk
heurestic to get reasonable sized chunk of data to fetch.
make multiple of chunk_size if possible
"""
def _getBlockSize(self, dset):
target_block_size = 256 * 1000
if dset.chunks:
chunk_size = dset.chunks[0]
if chunk_size < target_block_size:
block_size = (target_block_size // chunk_size) * chunk_size
else:
block_size = target_block_size
else:
block_size = target_block_size
return block_size

"""
_getEvalStr: Get eval string for given query
Gets Eval string to use with numpy where method.
"""
def _getEvalStr(self, query, field_names):
i = 0
eval_str = ""
var_name = None
end_quote_char = None
var_count = 0
paren_count = 0
black_list = ( "import", ) # field names that are not allowed
self.log.info("getEvalStr(" + query + ")")
for item in black_list:
if item in field_names:
msg = "invalid field name"
self.log.info("EINVAL: " + msg)
raise IOError(errno.EINVAL, msg)
while i < len(query):
ch = query[i]
if (i+1) < len(query):
ch_next = query[i+1]
else:
ch_next = None
if var_name and not ch.isalnum():
# end of variable
if var_name not in field_names:
# invalid
msg = "unknown field name"
self.log.info("EINVAL: " + msg)
raise IOError(errno.EINVAL, msg)
eval_str += "rows['" + var_name + "']"
var_name = None
var_count += 1

if end_quote_char:
if ch == end_quote_char:
# end of literal
end_quote_char = None
eval_str += ch
elif ch in ("'", '"'):
end_quote_char = ch
eval_str += ch
elif ch.isalpha():
if ch == 'b' and ch_next in ("'", '"'):
eval_str += 'b' # start of a byte string literal
elif var_name is None:
var_name = ch # start of a variable
else:
var_name += ch
elif ch == '(' and end_quote_char is None:
paren_count += 1
eval_str += ch
elif ch == ')' and end_quote_char is None:
paren_count -= 1
if paren_count < 0:
msg = "Mismatched paren"
self.log.info("EINVAL: " + msg)
raise IOError(errno.EINVAL, msg)
eval_str += ch
else:
# just add to eval_str
eval_str += ch
i = i+1
if end_quote_char:
msg = "no matching quote character"
self.log.info("EINVAL: " + msg)
raise IOError(errno.EINVAL, msg)
if var_count == 0:
msg = "No field value"
self.log.info("EINVAL: " + msg)
raise IOError(errno.EINVAL, msg)
if paren_count != 0:
msg = "Mismatched paren"
self.log.info("EINVAL: " + msg)
raise IOError(errno.EINVAL, msg)

return eval_str

"""
Get values from dataset identified by obj_uuid using the given
Expand Down
53 changes: 52 additions & 1 deletion test/unit/hdf5dbTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1162,7 +1162,58 @@ def testRootAcl(self):
self.assertEqual(acl['delete'], 0)
self.assertEqual(acl['readACL'], 0)
self.assertEqual(acl['updateACL'], 0)


def testGetEvalStr(self):
queries = { "date == 23": "rows['date'] == 23",
"wind == b'W 5'": "rows['wind'] == b'W 5'",
"temp > 61": "rows['temp'] > 61",
"(date >=22) & (date <= 24)": "(rows['date'] >=22) & (rows['date'] <= 24)",
"(date == 21) & (temp > 70)": "(rows['date'] == 21) & (rows['temp'] > 70)",
"(wind == b'E 7') | (wind == b'S 7')": "(rows['wind'] == b'E 7') | (rows['wind'] == b'S 7')" }

fields = ["date", "wind", "temp" ]
filepath = getFile('empty.h5', 'getevalstring.h5')
with Hdf5db(filepath, app_logger=self.log) as db:

for query in queries.keys():
eval_str = db._getEvalStr(query, fields)
self.assertEqual(eval_str, queries[query])
#print(query, "->", eval_str)

def testBadQuery(self):
queries = ( "foobar", # no variable used
"wind = b'abc", # non-closed literal
"(wind = b'N') & (temp = 32", # missing paren
"foobar > 42", # invalid field name
"import subprocess; subprocess.call(['ls', '/'])") # injection attack

fields = ("date", "wind", "temp" )
filepath = getFile('empty.h5', 'badquery.h5')
with Hdf5db(filepath, app_logger=self.log) as db:

for query in queries:
try:
eval_str = db._getEvalStr(query, fields)
self.assertTrue(False) # shouldn't get here
except IOError as e:
pass # ok

def testInjectionBlock(self):
queries = (
"import subprocess; subprocess.call(['ls', '/'])", ) # injection attack

fields = ("import", "subprocess", "call" )
filepath = getFile('empty.h5', 'injectionblock.h5')
with Hdf5db(filepath, app_logger=self.log) as db:

for query in queries:
try:
eval_str = db._getEvalStr(query, fields)
self.assertTrue(False) # shouldn't get here
except IOError as e:
pass # ok




if __name__ == '__main__':
Expand Down

0 comments on commit a85c6b1

Please sign in to comment.