Skip to content

Commit

Permalink
fix for failing where in queries
Browse files Browse the repository at this point in the history
  • Loading branch information
jreadey committed Mar 20, 2024
1 parent 5554b83 commit e0e12d7
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 46 deletions.
41 changes: 27 additions & 14 deletions hsds/chunk_dn.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ async def PUT_Chunk(request):
if rank != 1:
log.error("expected one-dimensional array for PUT query")
raise HTTPInternalServerError()

try:
parser = BooleanParser(query)
except Exception as e:
Expand Down Expand Up @@ -487,19 +488,31 @@ async def GET_Chunk(request):
select_dt = chunk_arr.dtype

if query:
try:
parser = BooleanParser(query)
except Exception as e:
msg = f"query: {query} is not valid, got exception: {e}"
log.error(msg)
raise HTTPInternalServerError()
try:
eval_str = parser.getEvalStr()
except Exception as e:
msg = f"query: {query} unable to get eval str, got exception: {e}"
log.error(msg)
raise HTTPInternalServerError()
log.debug(f"got eval str: {eval_str} for query: {query}")
# if there's a where clause, just use the expression
# part with BooleanParser
# TBD: Remove when BooleanParser knows how to use where keyword
if query.startswith("where"):
query_expr = None
else:
n = query.find(" where ")
if n > 0:
query_expr = query[:n]
else:
query_expr = query
if query_expr:
try:
parser = BooleanParser(query_expr)
except Exception as e:
msg = f"query: {query} is not valid, got exception: {e}"
log.error(msg)
raise HTTPInternalServerError()
try:
eval_str = parser.getEvalStr()
except Exception as e:
msg = f"query: {query} unable to get eval str, got exception: {e}"
log.error(msg)
raise HTTPInternalServerError()
log.debug(f"got eval str: {eval_str} for query: {query}")

# run given query
try:
Expand All @@ -508,7 +521,7 @@ async def GET_Chunk(request):
"chunk_layout": dims,
"chunk_arr": chunk_arr,
"slices": selection,
"query": eval_str,
"query": query,
"limit": limit,
"select_dt": select_dt,
}
Expand Down
25 changes: 23 additions & 2 deletions hsds/dset_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,10 +412,31 @@ def getParser(query, dtype):
""" get query BooleanParser. If query contains variables that
aren't part of the data type, throw a HTTPBadRequest exception. """

# separate out the where clause if any
if query.startswith("where"):
where_in = query
expr = None
else:
n = query.find(" where ")
if n > 0:
where_in = query[(n + 1):]
expr = query[:n]
else:
where_in = None
expr = query

if where_in:
log.debug(f"got where in clause: {where_in}")
# TBD: do full syntax check on this

if not expr:
# just a where clause
return None

try:
parser = BooleanParser(query)
parser = BooleanParser(expr)
except Exception:
msg = f"query: {query} is not valid"
msg = f"query: {expr} is not valid"
log.warn(msg)
raise HTTPBadRequest(reason=msg)

Expand Down
81 changes: 57 additions & 24 deletions hsds/util/chunkUtil.py
Original file line number Diff line number Diff line change
Expand Up @@ -1216,7 +1216,6 @@ def _getEvalStr(query, arr_name, field_names):
msg = "Mismatched paren"
log.warn("Bad query: " + msg)
raise ValueError(msg)
log.debug(f"eval_str: {eval_str}")
return eval_str


Expand Down Expand Up @@ -1313,25 +1312,34 @@ def chunkQuery(
raise ValueError(msg)
log.debug(f"tbd: where_elements_arr; {where_elements_arr}")
log.debug(f"tbd: chunk_sel[{where_field}]: {chunk_sel[where_field]}")
isin_arr = np.isin(chunk_sel[where_field], where_elements_arr)
log.debug(f"tbd: isin_arr: {isin_arr}")
isin_mask = np.isin(chunk_sel[where_field], where_elements_arr)
log.debug(f"tbd: isin_arr: {isin_mask}")

if not np.any(isin_arr):
if not np.any(isin_mask):
# all false
log.debug("query - no rows found for where elements")
return None
log.debug(f"tbd - isin_mask: {isin_mask}")

chunk_sel = chunk_sel[isin_arr]
log.debug(f"tbd - chunk_sel after boolean selection: {chunk_sel}")
isin_indices = np.where(isin_mask)
if not isinstance(isin_indices, tuple):
log.warn(f"expected where_indices of tuple but got: {type(isin_indices)}")
return None
if len(isin_indices) == 0:
log.warn("chunkQuery - got empty tuple where in result")
return None
log.debug(f"tbd: isin_indices: {isin_indices}")

if len(chunk_sel) == 0:
log.debug("query - no elements matched where list")
isin_indices = isin_indices[0]
if not isinstance(isin_indices, np.ndarray):
log.warn(f"expected isin_indices of ndarray but got: {type(isin_indices)}")
return None
if not eval_str:
# can just return the array
return chunk_sel
log.debug(f"tbd - isin_indices: {isin_indices}")
nrows = isin_indices.shape[0]
log.debug(f"tbd - isin_indices nrows: {nrows}")
elif eval_str:
log.debug("no where keyword")
isin_indices = None
else:
log.warn("query - no eval and no where in, returning None")
return None
Expand All @@ -1353,20 +1361,45 @@ def chunkQuery(
else:
replace_mask = None

where_indices = np.where(eval(eval_str))
if not isinstance(where_indices, tuple):
log.warn(f"expected where_indices of tuple but got: {type(where_indices)}")
return None
if len(where_indices) == 0:
log.warn("chunkQuery - got empty tuple where result")
return None
if eval_str:
where_indices = np.where(eval(eval_str))
if not isinstance(where_indices, tuple):
log.warn(f"expected where_indices of tuple but got: {type(where_indices)}")
return None
if len(where_indices) == 0:
log.warn("chunkQuery - got empty tuple where result")
return None

where_indices = where_indices[0]
if not isinstance(where_indices, np.ndarray):
log.warn(f"expected where_indices of ndarray but got: {type(where_indices)}")
return None
nrows = where_indices.shape[0]
log.debug(f"chunkQuery - {nrows} found")
where_indices = where_indices[0]
log.debug(f"tbd - where_indices: {where_indices}")
if not isinstance(where_indices, np.ndarray):
log.warn(f"expected where_indices of ndarray but got: {type(where_indices)}")
return None
nrows = where_indices.shape[0]
log.debug(f"chunkQuery - {nrows} where rows found")
else:
where_indices = None

log.debug("tbd - check isin_indices")
if isin_indices is None:
pass # skip intersection
else:
if where_indices is None:
# just use the isin_indices
where_indices = isin_indices
else:
# intersect the two sets of indices
intersect = np.intersect1d(where_indices, isin_indices)
log.debug(f"tbd - intersect: {intersect}")

nrows = intersect.shape[0]
if nrows == 0:
log.debug("chunkQuery - no rows found after intersect with is in")
return None
else:
log.debug(f"chunkQuery - intersection, {nrows} found")
# use the intersection as our new where index
where_indices = intersect

if limit > 0 and nrows > limit:
# truncate to limit rows
Expand Down
19 changes: 13 additions & 6 deletions tests/integ/query_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,15 +169,11 @@ def verifyQueryRsp(rsp, expected_indices=None, expect_bin=None):
kwargs["expect_bin"] = False

# items in list
# TBD - needs update for chunk_dn.py to work
"""
params = {"query": "where stock_symbol in (b'AAPL', b'EBAY')"}
params = {"query": "open < 4000 where stock_symbol in (b'AAPL', b'EBAY')"}
rsp = self.session.get(req, params=params, headers=query_headers)
self.assertEqual(rsp.status_code, 200)
kwargs["expected_indices"] = [0, 1, 3, 4, 6, 7, 9, 10]
print(rsp.text)
verifyQueryRsp(rsp, **kwargs)
"""

# read first row with AAPL
params = {"query": "stock_symbol == b'AAPL'", "Limit": 1}
Expand Down Expand Up @@ -214,10 +210,21 @@ def verifyQueryRsp(rsp, expected_indices=None, expect_bin=None):
kwargs["expected_indices"] = (4, 7, 10)
verifyQueryRsp(rsp, **kwargs)

params = {"query": "where stock_symbol in (b'AAPL', b'EBAY')"}
rsp = self.session.get(req, params=params, headers=query_headers)
self.assertEqual(rsp.status_code, 200)
kwargs["expected_indices"] = [0, 1, 3, 4, 6, 7, 9, 10]
verifyQueryRsp(rsp, **kwargs)
params = {"query": "open < 3000 where stock_symbol in (b'AAPL', b'EBAY')"}
rsp = self.session.get(req, params=params, headers=query_headers)
self.assertEqual(rsp.status_code, 200)
kwargs["expected_indices"] = [6, 7, 9, 10]
verifyQueryRsp(rsp, **kwargs)

# combine with Limit
params["Limit"] = 2
rsp = self.session.get(req, params=params, headers=query_headers)
kwargs["expected_indices"] = (4, 7)
kwargs["expected_indices"] = (6, 7)
verifyQueryRsp(rsp, **kwargs)

# try bad Limit
Expand Down

0 comments on commit e0e12d7

Please sign in to comment.