Skip to content

Commit

Permalink
MNT Refactor with a SaveState and avoid duplicate numpy arrays (skops…
Browse files Browse the repository at this point in the history
  • Loading branch information
BenjaminBossan authored Oct 6, 2022
1 parent 8fe2577 commit 4d8d70f
Show file tree
Hide file tree
Showing 7 changed files with 237 additions and 87 deletions.
57 changes: 32 additions & 25 deletions skops/io/_general.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,31 @@
from __future__ import annotations

import json
from functools import partial
from types import FunctionType
from typing import Any

import numpy as np

from ._utils import _import_obj, get_instance, get_module, get_state, gettype
from ._utils import SaveState, _import_obj, get_instance, get_module, get_state, gettype
from .exceptions import UnsupportedTypeException


def dict_get_state(obj, dst):
def dict_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
}

key_types = get_state([type(key) for key in obj.keys()], dst)
key_types = get_state([type(key) for key in obj.keys()], save_state)
content = {}
for key, value in obj.items():
if isinstance(value, property):
continue
if np.isscalar(key) and hasattr(key, "item"):
# convert numpy value to python object
key = key.item()
content[key] = get_state(value, dst)
key = key.item() # type: ignore
content[key] = get_state(value, save_state)
res["content"] = content
res["key_types"] = key_types
return res
Expand All @@ -36,14 +39,14 @@ def dict_get_instance(state, src):
return content


def list_get_state(obj, dst):
def list_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
}
content = []
for value in obj:
content.append(get_state(value, dst))
content.append(get_state(value, save_state))
res["content"] = content
return res

Expand All @@ -55,12 +58,12 @@ def list_get_instance(state, src):
return content


def tuple_get_state(obj, dst):
def tuple_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
}
content = tuple(get_state(value, dst) for value in obj)
content = tuple(get_state(value, save_state) for value in obj)
res["content"] = content
return res

Expand All @@ -86,7 +89,7 @@ def isnamedtuple(t):
return content


def function_get_state(obj, dst):
def function_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(obj),
Expand All @@ -103,16 +106,16 @@ def function_get_instance(state, src):
return loaded


def partial_get_state(obj, dst):
def partial_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
_, _, (func, args, kwds, namespace) = obj.__reduce__()
res = {
"__class__": "partial", # don't allow any subclass
"__module__": get_module(type(obj)),
"content": {
"func": get_state(func, dst),
"args": get_state(args, dst),
"kwds": get_state(kwds, dst),
"namespace": get_state(namespace, dst),
"func": get_state(func, save_state),
"args": get_state(args, save_state),
"kwds": get_state(kwds, save_state),
"namespace": get_state(namespace, save_state),
},
}
return res
Expand All @@ -129,7 +132,7 @@ def partial_get_instance(state, src):
return instance


def type_get_state(obj, dst):
def type_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
# To serialize a type, we first need to set the metadata to tell that it's
# a type, then store the type's info itself in the content field.
res = {
Expand All @@ -148,7 +151,7 @@ def type_get_instance(state, src):
return loaded


def slice_get_state(obj, dst):
def slice_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
Expand All @@ -168,13 +171,19 @@ def slice_get_instance(state, src):
return slice(start, stop, step)


def object_get_state(obj, dst):
def object_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
# This method is for objects which can either be persisted with json, or
# the ones for which we can get/set attributes through
# __getstate__/__setstate__ or reading/writing to __dict__.
try:
# if we can simply use json, then we're done.
return json.dumps(obj)
obj_str = json.dumps(obj)
return {
"__class__": "str",
"__module__": "builtins",
"content": obj_str,
"is_json": True,
}
except Exception:
pass

Expand All @@ -192,18 +201,16 @@ def object_get_state(obj, dst):
else:
return res

content = get_state(attrs, dst)
content = get_state(attrs, save_state)
# it's sufficient to store the "content" because we know that this dict can
# only have str type keys
res["content"] = content
return res


def object_get_instance(state, src):
try:
return json.loads(state)
except Exception:
pass
if state.get("is_json", False):
return json.loads(state["content"])

cls = gettype(state)

Expand All @@ -225,7 +232,7 @@ def object_get_instance(state, src):
return instance


def unsupported_get_state(obj, dst):
def unsupported_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
raise UnsupportedTypeException(obj)


Expand Down
75 changes: 41 additions & 34 deletions skops/io/_numpy.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,49 @@
from __future__ import annotations

import io
from pathlib import Path
from uuid import uuid4
from typing import Any

import numpy as np

from ._general import function_get_instance
from ._utils import _import_obj, get_instance, get_module, get_state
from ._utils import SaveState, _import_obj, get_instance, get_module, get_state
from .exceptions import UnsupportedTypeException


def ndarray_get_state(obj, dst):
def ndarray_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
}

# First, try to save object with np.save and allow_pickle=False, which
# should generally work as long as the dtype is not object.
try:
f_name = f"{uuid4()}.npy"
with open(Path(dst) / f_name, "wb") as f:
np.save(f, obj, allow_pickle=False)
res.update(type="numpy", file=f_name)
# If the dtype is object, np.save should not work with
# allow_pickle=False, therefore we convert them to a list and
# recursively call get_state on it.
if obj.dtype == object:
obj_serialized = get_state(obj.tolist(), save_state)
res["content"] = obj_serialized["content"]
res["type"] = "json"
res["shape"] = get_state(obj.shape, save_state)
else:
# Memoize the object and then check if it's file name (containing
# the object id) already exists. If it does, there is no need to
# save the object again. Memoizitation is necessary since for
# ephemeral objects, the same id might otherwise be reused.
obj_id = save_state.memoize(obj)
f_name = f"{obj_id}.npy"
path = save_state.path / f_name
if not path.exists():
with open(path, "wb") as f:
np.save(f, obj, allow_pickle=False)
res.update(type="numpy", file=f_name)
except ValueError:
# Object arrays cannot be saved with allow_pickle=False, therefore we
# convert them to a list and recursively call get_state on it. For this,
# we expect the dtype to be object.
if obj.dtype != object:
raise UnsupportedTypeException(
f"numpy arrays of dtype {obj.dtype} are not supported yet, please "
"open an issue at https://github.com/skops-dev/skops/issues and "
"report your error"
)

obj_serialized = get_state(obj.tolist(), dst)
res["content"] = obj_serialized["content"]
res["type"] = "json"
res["shape"] = get_state(obj.shape, dst)
# Couldn't save the numpy array with either method
raise UnsupportedTypeException(
f"numpy arrays of dtype {obj.dtype} are not supported yet, please "
"open an issue at https://github.com/skops-dev/skops/issues and "
"report your error"
)

return res

Expand Down Expand Up @@ -67,13 +74,13 @@ def ndarray_get_instance(state, src):
return val


def maskedarray_get_state(obj, dst):
def maskedarray_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
"content": {
"data": get_state(obj.data, dst),
"mask": get_state(obj.mask, dst),
"data": get_state(obj.data, save_state),
"mask": get_state(obj.mask, save_state),
},
}
return res
Expand All @@ -85,8 +92,8 @@ def maskedarray_get_instance(state, src):
return np.ma.MaskedArray(data, mask)


def random_state_get_state(obj, dst):
content = get_state(obj.get_state(legacy=False), dst)
def random_state_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
content = get_state(obj.get_state(legacy=False), save_state)
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
Expand All @@ -103,7 +110,7 @@ def random_state_get_instance(state, src):
return random_state


def random_generator_get_state(obj, dst):
def random_generator_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
bit_generator_state = obj.bit_generator.state
res = {
"__class__": obj.__class__.__name__,
Expand All @@ -128,7 +135,7 @@ def random_generator_get_instance(state, src):
# For numpy.ufunc we need to get the type from the type's module, but for other
# functions we get it from objet's module directly. Therefore sett a especial
# get_state method for them here. The load is the same as other functions.
def ufunc_get_state(obj, dst):
def ufunc_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__, # ufunc
"__module__": get_module(type(obj)), # numpy
Expand All @@ -140,14 +147,14 @@ def ufunc_get_state(obj, dst):
return res


def dtype_get_state(obj, dst):
def dtype_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
# we use numpy's internal save mechanism to store the dtype by
# saving/loading an empty array with that dtype.
tmp = np.ndarray(0, dtype=obj)
tmp: np.typing.NDArray = np.ndarray(0, dtype=obj)
res = {
"__class__": "dtype",
"__module__": "numpy",
"content": ndarray_get_state(tmp, dst),
"content": ndarray_get_state(tmp, save_state),
}
return res

Expand Down
16 changes: 8 additions & 8 deletions skops/io/_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,7 @@

import skops

from ._utils import _get_instance, _get_state, get_instance, get_state

# For now, there is just one protocol version
PROTOCOL = 0

from ._utils import SaveState, _get_instance, _get_state, get_instance, get_state

# We load the dispatch functions from the corresponding modules and register
# them.
Expand Down Expand Up @@ -53,9 +49,13 @@ def save(obj, file):
"""
with tempfile.TemporaryDirectory() as dst:
with open(Path(dst) / "schema.json", "w") as f:
state = get_state(obj, dst)
state["protocol"] = PROTOCOL
path = Path(dst)
with open(path / "schema.json", "w") as f:
save_state = SaveState(path=path)
state = get_state(obj, save_state)
save_state.clear_memo()

state["protocol"] = save_state.protocol
state["_skops_version"] = skops.__version__
json.dump(state, f, indent=2)

Expand Down
22 changes: 15 additions & 7 deletions skops/io/_scipy.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,31 @@
from __future__ import annotations

import io
from pathlib import Path
from uuid import uuid4
from typing import Any

from scipy.sparse import load_npz, save_npz, spmatrix

from ._utils import get_module
from ._utils import SaveState, get_module


def sparse_matrix_get_state(obj, dst):
def sparse_matrix_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
}

f_name = f"{uuid4()}.npz"
save_npz(Path(dst) / f_name, obj)
# Memoize the object and then check if it's file name (containing the object
# id) already exists. If it does, there is no need to save the object again.
# Memoizitation is necessary since for ephemeral objects, the same id might
# otherwise be reused.
obj_id = save_state.memoize(obj)
f_name = f"{obj_id}.npz"
path = save_state.path / f_name
if not path.exists():
save_npz(path, obj)

res["type"] = "scipy"
res["file"] = f_name

return res


Expand Down
Loading

0 comments on commit 4d8d70f

Please sign in to comment.