Make compatible with json_tricks numpy handling (#6)
Added support for handling np.ndarray objects in the json_tricks style: we can now load files written by json_tricks and write files that it can read. I also copied over the sparse matrix handler from core.
algrs authored Mar 15, 2017
1 parent 2969f95 commit d7e2f2f
Showing 4 changed files with 137 additions and 9 deletions.
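
For illustration, a minimal round-trip sketch of the new behavior (not part of the commit), assuming numpy is installed and that json refers to baiji.serialization.json:

import numpy as np
from baiji.serialization import json

arr = np.array([[859.0, 859.0], [217.0, 106.0]], dtype=np.float32)

# Arrays are written json_tricks-style, as a dict with '__ndarray__',
# 'dtype', and 'shape' keys, so json_tricks can read the output.
serialized = json.dumps({"foo": arr})

# Loading decodes the annotated dict back into an ndarray with the
# recorded dtype.
restored = json.loads(serialized)["foo"]
assert isinstance(restored, np.ndarray)
assert restored.dtype == np.float32
np.testing.assert_array_equal(restored, arr)

# primitive=True writes plain nested lists instead, for consumers that
# do not understand the annotations (the dtype is not recoverable).
json.dumps({"foo": arr}, primitive=True)  # '{"foo": [[859.0, 859.0], [217.0, 106.0]]}'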
2 changes: 1 addition & 1 deletion baiji/serialization/__init__.py
@@ -7,4 +7,4 @@
 from pkgutil import extend_path
 __path__ = extend_path(__path__, __name__)
 
-__version__ = '1.1.0'
+__version__ = '1.2.0'
95 changes: 87 additions & 8 deletions baiji/serialization/json.py
@@ -7,20 +7,99 @@ def dump(obj, f, *args, **kwargs):
     from baiji.serialization.util.openlib import ensure_file_open_and_call
     return ensure_file_open_and_call(f, _dump, 'w', obj, *args, **kwargs)
 
+def dumps(*args, **kwargs):
+    return json.dumps(*args, **_dump_args(kwargs))
+
+def _dump(f, obj, *args, **kwargs):
+    return json.dump(obj, f, *args, **_dump_args(kwargs))
+
 def load(f, *args, **kwargs):
     from baiji.serialization.util.openlib import ensure_file_open_and_call
     return ensure_file_open_and_call(f, _load, 'r', *args, **kwargs)
 
+def _load(f, *args, **kwargs):
+    return json.load(f, *args, **_load_args(kwargs))
+
 def loads(*args, **kwargs):
-    return json.loads(*args, **kwargs)
+    return json.loads(*args, **_load_args(kwargs))
 
-def dumps(*args, **kwargs):
-    kwargs.update(for_json=True)
-    return json.dumps(*args, **kwargs)
+def _load_args(kwargs):
+    kwargs.update(object_hook=_json_decode)
+    return kwargs
 
-def _dump(f, obj, *args, **kwargs):
+def _dump_args(kwargs):
+    if kwargs.get('primitive', False):
+        kwargs.update(default=_json_encode_primitive)
+    else:
+        kwargs.update(default=_json_encode)
+    if 'primitive' in kwargs:
+        del kwargs['primitive']
     kwargs.update(for_json=True)
-    return json.dump(obj, f, *args, **kwargs)
+    return kwargs
 
-def _load(f, *args, **kwargs):
-    return json.load(f, *args, **kwargs)
+def _json_decode(dct):
+    '''
+    Handle custom JSON decoding: dicts carrying double-underscore annotations are deserialized directly to objects.
+    '''
+    if '__ndarray__' in dct:
+        try:
+            import numpy as np
+        except ImportError:
+            raise ImportError("JSON file contains numpy arrays; install numpy to load it")
+        if 'dtype' in dct:
+            dtype = np.dtype(dct['dtype'])
+        else:
+            dtype = np.float64
+        return np.array(dct['__ndarray__'], dtype=dtype)
+    elif '__scipy.sparse.sparsematrix__' in dct:
+        if not all(k in dct for k in ('dtype', 'shape', 'data', 'format', 'row', 'col')):
+            return dct
+        try:
+            import numpy as np
+            import scipy.sparse as sp
+        except ImportError:
+            raise ImportError("JSON file contains scipy.sparse arrays; install numpy and scipy to load it")
+        coo = sp.coo_matrix((dct['data'], (dct['row'], dct['col'])), shape=dct['shape'], dtype=np.dtype(dct['dtype']))
+        return coo.asformat(dct['format'])
+    else:
+        return dct
+
+def _json_encode_primitive(obj):
+    try:
+        import numpy as np
+        if isinstance(obj, np.ndarray):
+            return obj.tolist()
+    except ImportError:
+        # Clearly there won't be any numpy arrays to encode...
+        pass
+    return None
+
+def _json_encode(obj):
+    try:
+        import numpy as np
+        if isinstance(obj, np.ndarray):
+            return {
+                '__ndarray__': obj.tolist(),
+                'dtype': obj.dtype.name,
+                'shape': obj.shape,
+            }
+    except ImportError:
+        # Clearly there won't be any numpy arrays to encode...
+        pass
+    try:
+        import scipy.sparse as sp
+        if sp.isspmatrix(obj):
+            coo = obj.tocoo()
+            return {
+                '__scipy.sparse.sparsematrix__': True,
+                'format': obj.getformat(),
+                'dtype': obj.dtype.name,
+                'shape': obj.shape,
+                'data': coo.data,
+                'row': coo.row,
+                'col': coo.col,
+            }
+    except ImportError:
+        # Clearly there won't be any scipy.sparse arrays to encode...
+        pass
+    return None
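
Sparse matrices get the same treatment: the encoder stores the COO triplet (data, row, col) alongside the original storage format, and the decoder rebuilds a COO matrix and converts it back via asformat. A small sketch of the round trip (not part of the commit), assuming numpy and scipy are installed:

import scipy.sparse as sp
from baiji.serialization import json

# sp.eye(3) is a 3x3 identity matrix in DIA format.
serialized = json.dumps({"m": sp.eye(3)})

# The nested 'data', 'row', and 'col' arrays decode first (object_hook
# runs bottom-up), then the sparse dict is rebuilt and reformatted.
restored = json.loads(serialized)["m"]
assert restored.getformat() == "dia"
assert restored.shape == (3, 3)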
47 changes: 47 additions & 0 deletions baiji/serialization/test_json.py
@@ -74,3 +74,50 @@ def test_json_load_path(self):
         with open(path, 'w') as f:
             f.write(r'["File Test"]')
         self.assertEqual(json.load(path), [u'File Test'])
+
+
+    def test_json_load_ndarray_tricks_compatible_1d(self):
+        import numpy as np
+        res = json.loads('{"foo": {"__ndarray__": [859.033935546875, 859.033935546875], "dtype": "float32", "shape": [2]}}')
+        res_array = res["foo"]
+        self.assertIsInstance(res_array, np.ndarray)
+        self.assertEqual(res_array.shape, (2, ))
+        self.assertEqual(res_array.dtype, np.float32)
+        np.testing.assert_equal(res_array, np.array([859.033935546875, 859.033935546875]))
+
+    def test_json_load_ndarray_tricks_compatible_2d(self):
+        import numpy as np
+        res = json.loads('{"foo": {"__ndarray__": [[859.0, 859.0], [217.0, 106.0], [302.0, 140.0]], "dtype": "float32", "shape": [3, 2]}}')
+        res_array = res["foo"]
+        self.assertIsInstance(res_array, np.ndarray)
+        self.assertEqual(res_array.shape, (3, 2))
+        self.assertEqual(res_array.dtype, np.float32)
+        np.testing.assert_almost_equal(res_array, np.array([[859.0, 859.0], [217.0, 106.0], [302.0, 140.0]]))
+
+    def test_json_dump_ndarray_tricks_compatible(self):
+        import numpy as np
+        self.assertEqual(
+            json.dumps({"foo": np.array([[859.0, 859.0], [217.0, 106.0], [302.0, 140.0]], dtype=np.float32)}),
+            r'{"foo": {"dtype": "float32", "shape": [3, 2], "__ndarray__": [[859.0, 859.0], [217.0, 106.0], [302.0, 140.0]]}}')
+
+    def test_json_dump_ndarray_tricks_compatible_primitive_option(self):
+        import numpy as np
+        self.assertEqual(
+            json.dumps({"foo": np.array([[859.0, 859.0], [217.0, 106.0], [302.0, 140.0]], dtype=np.float32)}, primitive=True),
+            r'{"foo": [[859.0, 859.0], [217.0, 106.0], [302.0, 140.0]]}')
+
+    def test_json_load_sparse_matrix(self):
+        import numpy as np
+        import scipy.sparse as sp
+        res = json.loads(r'{"foo": {"format": "dia", "dtype": "float32", "shape": [3, 3], "__scipy.sparse.sparsematrix__": true, "data": {"dtype": "float64", "shape": [3], "__ndarray__": [1.0, 1.0, 1.0]}, "col": {"dtype": "int32", "shape": [3], "__ndarray__": [0, 1, 2]}, "row": {"dtype": "int32", "shape": [3], "__ndarray__": [0, 1, 2]}}}')
+        res_array = res["foo"]
+        self.assertIsInstance(res_array, sp.dia.dia_matrix)
+        self.assertEqual(res_array.shape, (3, 3))
+        self.assertEqual(res_array.dtype, np.float32)
+        np.testing.assert_almost_equal(res_array.todense(), np.eye(3))
+
+    def test_json_dump_sparse_matrix(self):
+        import scipy.sparse as sp
+        self.assertEqual(
+            json.dumps({"foo": sp.eye(3)}),
+            r'{"foo": {"format": "dia", "dtype": "float64", "shape": [3, 3], "__scipy.sparse.sparsematrix__": true, "data": {"dtype": "float64", "shape": [3], "__ndarray__": [1.0, 1.0, 1.0]}, "col": {"dtype": "int32", "shape": [3], "__ndarray__": [0, 1, 2]}, "row": {"dtype": "int32", "shape": [3], "__ndarray__": [0, 1, 2]}}}')
2 changes: 2 additions & 0 deletions requirements_dev.txt
@@ -1,4 +1,6 @@
 -r requirements.txt
 
+numpy>=1.10.0
 nose2==0.5.0
 pylint==1.5.4
+scipy>=0.18.0
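
With numpy and scipy added to the dev requirements, the new tests should run under the project's existing nose2 setup; one plausible invocation (exact commands may vary):

pip install -r requirements_dev.txt
nose2 baiji.serialization.test_json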
