Skip to content

Commit

Permalink
add variable_names, update to v0.1.3
Browse files Browse the repository at this point in the history
  • Loading branch information
yymao committed Nov 12, 2017
1 parent 74b5d60 commit da55289
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 5 deletions.
45 changes: 42 additions & 3 deletions easyquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,20 @@
http://opensource.org/licenses/MIT
"""

import warnings
import numpy as np
import numexpr as ne

if not hasattr(list, 'copy'):
try:
from builtins import list
except ImportError:
raise ImportError('Please install python package "future"')
import numpy as np
import numexpr as ne


__all__ = ['Query']
__version__ = '0.1.2'
__version__ = '0.1.3'


def _is_string_like(obj):
"""
Expand Down Expand Up @@ -73,6 +77,7 @@ class Query(object):
def __init__(self, *queries):
self._operator = None
self._operands = None
self._variable_names = None
self._query_class = type(self)

if len(queries) == 1:
Expand Down Expand Up @@ -284,6 +289,40 @@ def copy(self):
return out


@staticmethod
def _get_variable_names(basic_query):
if _is_string_like(basic_query):
return tuple(set(ne.necompiler.precompile(basic_query)[-1]))

elif callable(basic_query):
warnings.warn('`variable_names` does not support a single callable query')
return tuple()

elif isinstance(basic_query, tuple) and len(basic_query) > 1 and callable(basic_query[0]):
return tuple(set(basic_query[1:]))


@property
def variable_names(self):
if self._variable_names is None:

if self._operator is None:
if self._operands is None:
self._variable_names = tuple()
else:
self._variable_names = self._get_variable_names(self._operands)

elif self._operator == 'NOT':
self._variable_names = self._operands.variable_names

else:
v = list()
for op in self._operands:
v.extend(op.variable_names)
self._variable_names = tuple(set(v))

return self._variable_names


_query_class = Query

Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@

setup(
name='easyquery',
version='0.1.2',
version='0.1.3',
description='Create easy-to-use Query objects that can apply on NumPy structured arrays, astropy Table, and Pandas DataFrame.',
url='https://github.com/yymao/easyquery',
download_url = 'https://github.com/yymao/easyquery/archive/v0.1.2.zip',
download_url = 'https://github.com/yymao/easyquery/archive/v0.1.3.zip',
author='Yao-Yuan Mao',
author_email='[email protected]',
maintainer='Yao-Yuan Mao',
Expand Down
18 changes: 18 additions & 0 deletions tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,24 @@ def test_derive_class():
do_compound_query(t, DictQuery, check_query_on_dict_table)


def test_variable_names():
q1 = Query('log(a) > b**2.0')
q2 = Query((lambda x, y: x + y < 1, 'c', 'd'))
q3 = q1 & 'a + 2'
q4 = ~q2
q5 = q1 ^ q2
q6 = Query('sin(5)')
q7 = Query()

assert set(q1.variable_names) == {'a', 'b'}
assert set(q2.variable_names) == {'c', 'd'}
assert set(q3.variable_names) == {'a', 'b'}
assert set(q4.variable_names) == {'c', 'd'}
assert set(q5.variable_names) == {'a', 'b', 'c', 'd'}
assert set(q6.variable_names) == set()
assert set(q7.variable_names) == set()


if __name__ == '__main__':
test_valid_init()
test_invalid_init()
Expand Down

0 comments on commit da55289

Please sign in to comment.