diff --git a/h5json/hdf5db.py b/h5json/hdf5db.py index 900117f..a03c0f8 100644 --- a/h5json/hdf5db.py +++ b/h5json/hdf5db.py @@ -2183,6 +2183,181 @@ def getDatasetValuesByUuid(self, obj_uuid, slices=Ellipsis, format="json"): values = values.tobytes() return values + + """ + doDatasetQueryByUuid: return rows based on query string + Return rows from a dataset that matches query string. + + Note: Only supported for compound_type/one-dimensional datasets + """ + def doDatasetQueryByUuid(self, obj_uuid, query, start=0, stop=-1, step=1, limit=None): + self.log.info("doQueryByUuid - uuid: " + obj_uuid + " query:" + query) + self.log.info("start: " + str(start) + " stop: " + str(stop) + " step: " + str(step) + " limit: " + str(limit)) + dset = self.getDatasetObjByUuid(obj_uuid) + if dset is None: + msg = "Dataset: " + obj_uuid + " not found" + self.log.info(msg) + raise IOError(errno.ENXIO, msg) + + values = [] + dt = dset.dtype + typeItem = getTypeItem(dt) + itemSize = getItemSize(typeItem) + if typeItem['class'] != "H5T_COMPOUND": + msg = "Only compound type datasets can be used as query target" + self.log.info(msg) + raise IOError(errno.EINVAL, msg) + + if dset.shape is None: + # null space dataset (with h5py 2.6.0) + return None + + rank = len(dset.shape) + if rank != 1: + msg = "One one-dimensional datasets can be used as query target" + self.log.info(msg) + raise IOError(errno.EINVAL, msg) + + + values = [] + indexes = [] + count = 0 + + num_elements = dset.shape[0] + if stop == -1: + stop = num_elements + elif stop > num_elements: + stop = num_elements + block_size = self._getBlockSize(dset) + self.log.info("block_size: " + str(block_size)) + + field_names = list(dset.dtype.fields.keys()) + eval_str = self._getEvalStr(query, field_names) + + while start < stop: + if limit and (count == limit): + break # no more rows for this batch + end = start + block_size + if end > stop: + end = stop + rows = dset[start:end] # read from dataset + where_result = np.where(eval(eval_str)) + index = where_result[0].tolist() + if len(index) > 0: + for i in index: + row = rows[i] + item = self.bytesArrayToList(row) + values.append(item) + indexes.append(start + i) + count += 1 + if limit and (count == limit): + break # no more rows for this batch + + start = end # go to next block + + + # values = self.getDataValue(item_type, values, dimension=1, dims=(len(values),)) + + self.log.info("got " + str(count) + " query matches") + return (indexes, values) + + """ + _getBlockSize: Get number of rows to read from disk + + heurestic to get reasonable sized chunk of data to fetch. + make multiple of chunk_size if possible + """ + def _getBlockSize(self, dset): + target_block_size = 256 * 1000 + if dset.chunks: + chunk_size = dset.chunks[0] + if chunk_size < target_block_size: + block_size = (target_block_size // chunk_size) * chunk_size + else: + block_size = target_block_size + else: + block_size = target_block_size + return block_size + + """ + _getEvalStr: Get eval string for given query + + Gets Eval string to use with numpy where method. + """ + def _getEvalStr(self, query, field_names): + i = 0 + eval_str = "" + var_name = None + end_quote_char = None + var_count = 0 + paren_count = 0 + black_list = ( "import", ) # field names that are not allowed + self.log.info("getEvalStr(" + query + ")") + for item in black_list: + if item in field_names: + msg = "invalid field name" + self.log.info("EINVAL: " + msg) + raise IOError(errno.EINVAL, msg) + while i < len(query): + ch = query[i] + if (i+1) < len(query): + ch_next = query[i+1] + else: + ch_next = None + if var_name and not ch.isalnum(): + # end of variable + if var_name not in field_names: + # invalid + msg = "unknown field name" + self.log.info("EINVAL: " + msg) + raise IOError(errno.EINVAL, msg) + eval_str += "rows['" + var_name + "']" + var_name = None + var_count += 1 + + if end_quote_char: + if ch == end_quote_char: + # end of literal + end_quote_char = None + eval_str += ch + elif ch in ("'", '"'): + end_quote_char = ch + eval_str += ch + elif ch.isalpha(): + if ch == 'b' and ch_next in ("'", '"'): + eval_str += 'b' # start of a byte string literal + elif var_name is None: + var_name = ch # start of a variable + else: + var_name += ch + elif ch == '(' and end_quote_char is None: + paren_count += 1 + eval_str += ch + elif ch == ')' and end_quote_char is None: + paren_count -= 1 + if paren_count < 0: + msg = "Mismatched paren" + self.log.info("EINVAL: " + msg) + raise IOError(errno.EINVAL, msg) + eval_str += ch + else: + # just add to eval_str + eval_str += ch + i = i+1 + if end_quote_char: + msg = "no matching quote character" + self.log.info("EINVAL: " + msg) + raise IOError(errno.EINVAL, msg) + if var_count == 0: + msg = "No field value" + self.log.info("EINVAL: " + msg) + raise IOError(errno.EINVAL, msg) + if paren_count != 0: + msg = "Mismatched paren" + self.log.info("EINVAL: " + msg) + raise IOError(errno.EINVAL, msg) + + return eval_str """ Get values from dataset identified by obj_uuid using the given diff --git a/test/unit/hdf5dbTest.py b/test/unit/hdf5dbTest.py index b2707ea..275ec60 100755 --- a/test/unit/hdf5dbTest.py +++ b/test/unit/hdf5dbTest.py @@ -1162,7 +1162,58 @@ def testRootAcl(self): self.assertEqual(acl['delete'], 0) self.assertEqual(acl['readACL'], 0) self.assertEqual(acl['updateACL'], 0) - + + def testGetEvalStr(self): + queries = { "date == 23": "rows['date'] == 23", + "wind == b'W 5'": "rows['wind'] == b'W 5'", + "temp > 61": "rows['temp'] > 61", + "(date >=22) & (date <= 24)": "(rows['date'] >=22) & (rows['date'] <= 24)", + "(date == 21) & (temp > 70)": "(rows['date'] == 21) & (rows['temp'] > 70)", + "(wind == b'E 7') | (wind == b'S 7')": "(rows['wind'] == b'E 7') | (rows['wind'] == b'S 7')" } + + fields = ["date", "wind", "temp" ] + filepath = getFile('empty.h5', 'getevalstring.h5') + with Hdf5db(filepath, app_logger=self.log) as db: + + for query in queries.keys(): + eval_str = db._getEvalStr(query, fields) + self.assertEqual(eval_str, queries[query]) + #print(query, "->", eval_str) + + def testBadQuery(self): + queries = ( "foobar", # no variable used + "wind = b'abc", # non-closed literal + "(wind = b'N') & (temp = 32", # missing paren + "foobar > 42", # invalid field name + "import subprocess; subprocess.call(['ls', '/'])") # injection attack + + fields = ("date", "wind", "temp" ) + filepath = getFile('empty.h5', 'badquery.h5') + with Hdf5db(filepath, app_logger=self.log) as db: + + for query in queries: + try: + eval_str = db._getEvalStr(query, fields) + self.assertTrue(False) # shouldn't get here + except IOError as e: + pass # ok + + def testInjectionBlock(self): + queries = ( + "import subprocess; subprocess.call(['ls', '/'])", ) # injection attack + + fields = ("import", "subprocess", "call" ) + filepath = getFile('empty.h5', 'injectionblock.h5') + with Hdf5db(filepath, app_logger=self.log) as db: + + for query in queries: + try: + eval_str = db._getEvalStr(query, fields) + self.assertTrue(False) # shouldn't get here + except IOError as e: + pass # ok + + if __name__ == '__main__':