diff --git a/skops/io/_general.py b/skops/io/_general.py index 1b046bc5..88413472 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -1,28 +1,31 @@ +from __future__ import annotations + import json from functools import partial from types import FunctionType +from typing import Any import numpy as np -from ._utils import _import_obj, get_instance, get_module, get_state, gettype +from ._utils import SaveState, _import_obj, get_instance, get_module, get_state, gettype from .exceptions import UnsupportedTypeException -def dict_get_state(obj, dst): +def dict_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), } - key_types = get_state([type(key) for key in obj.keys()], dst) + key_types = get_state([type(key) for key in obj.keys()], save_state) content = {} for key, value in obj.items(): if isinstance(value, property): continue if np.isscalar(key) and hasattr(key, "item"): # convert numpy value to python object - key = key.item() - content[key] = get_state(value, dst) + key = key.item() # type: ignore + content[key] = get_state(value, save_state) res["content"] = content res["key_types"] = key_types return res @@ -36,14 +39,14 @@ def dict_get_instance(state, src): return content -def list_get_state(obj, dst): +def list_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), } content = [] for value in obj: - content.append(get_state(value, dst)) + content.append(get_state(value, save_state)) res["content"] = content return res @@ -55,12 +58,12 @@ def list_get_instance(state, src): return content -def tuple_get_state(obj, dst): +def tuple_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), } - content = tuple(get_state(value, dst) for value in obj) + content = tuple(get_state(value, save_state) for value in obj) res["content"] = content return res @@ -86,7 +89,7 @@ def isnamedtuple(t): return content -def function_get_state(obj, dst): +def function_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(obj), @@ -103,16 +106,16 @@ def function_get_instance(state, src): return loaded -def partial_get_state(obj, dst): +def partial_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: _, _, (func, args, kwds, namespace) = obj.__reduce__() res = { "__class__": "partial", # don't allow any subclass "__module__": get_module(type(obj)), "content": { - "func": get_state(func, dst), - "args": get_state(args, dst), - "kwds": get_state(kwds, dst), - "namespace": get_state(namespace, dst), + "func": get_state(func, save_state), + "args": get_state(args, save_state), + "kwds": get_state(kwds, save_state), + "namespace": get_state(namespace, save_state), }, } return res @@ -129,7 +132,7 @@ def partial_get_instance(state, src): return instance -def type_get_state(obj, dst): +def type_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: # To serialize a type, we first need to set the metadata to tell that it's # a type, then store the type's info itself in the content field. res = { @@ -148,7 +151,7 @@ def type_get_instance(state, src): return loaded -def slice_get_state(obj, dst): +def slice_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), @@ -168,13 +171,19 @@ def slice_get_instance(state, src): return slice(start, stop, step) -def object_get_state(obj, dst): +def object_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: # This method is for objects which can either be persisted with json, or # the ones for which we can get/set attributes through # __getstate__/__setstate__ or reading/writing to __dict__. try: # if we can simply use json, then we're done. - return json.dumps(obj) + obj_str = json.dumps(obj) + return { + "__class__": "str", + "__module__": "builtins", + "content": obj_str, + "is_json": True, + } except Exception: pass @@ -192,7 +201,7 @@ def object_get_state(obj, dst): else: return res - content = get_state(attrs, dst) + content = get_state(attrs, save_state) # it's sufficient to store the "content" because we know that this dict can # only have str type keys res["content"] = content @@ -200,10 +209,8 @@ def object_get_state(obj, dst): def object_get_instance(state, src): - try: - return json.loads(state) - except Exception: - pass + if state.get("is_json", False): + return json.loads(state["content"]) cls = gettype(state) @@ -225,7 +232,7 @@ def object_get_instance(state, src): return instance -def unsupported_get_state(obj, dst): +def unsupported_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: raise UnsupportedTypeException(obj) diff --git a/skops/io/_numpy.py b/skops/io/_numpy.py index 3f8ccf55..912bb81e 100644 --- a/skops/io/_numpy.py +++ b/skops/io/_numpy.py @@ -1,42 +1,49 @@ +from __future__ import annotations + import io -from pathlib import Path -from uuid import uuid4 +from typing import Any import numpy as np from ._general import function_get_instance -from ._utils import _import_obj, get_instance, get_module, get_state +from ._utils import SaveState, _import_obj, get_instance, get_module, get_state from .exceptions import UnsupportedTypeException -def ndarray_get_state(obj, dst): +def ndarray_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), } - # First, try to save object with np.save and allow_pickle=False, which - # should generally work as long as the dtype is not object. try: - f_name = f"{uuid4()}.npy" - with open(Path(dst) / f_name, "wb") as f: - np.save(f, obj, allow_pickle=False) - res.update(type="numpy", file=f_name) + # If the dtype is object, np.save should not work with + # allow_pickle=False, therefore we convert them to a list and + # recursively call get_state on it. + if obj.dtype == object: + obj_serialized = get_state(obj.tolist(), save_state) + res["content"] = obj_serialized["content"] + res["type"] = "json" + res["shape"] = get_state(obj.shape, save_state) + else: + # Memoize the object and then check if it's file name (containing + # the object id) already exists. If it does, there is no need to + # save the object again. Memoizitation is necessary since for + # ephemeral objects, the same id might otherwise be reused. + obj_id = save_state.memoize(obj) + f_name = f"{obj_id}.npy" + path = save_state.path / f_name + if not path.exists(): + with open(path, "wb") as f: + np.save(f, obj, allow_pickle=False) + res.update(type="numpy", file=f_name) except ValueError: - # Object arrays cannot be saved with allow_pickle=False, therefore we - # convert them to a list and recursively call get_state on it. For this, - # we expect the dtype to be object. - if obj.dtype != object: - raise UnsupportedTypeException( - f"numpy arrays of dtype {obj.dtype} are not supported yet, please " - "open an issue at https://github.com/skops-dev/skops/issues and " - "report your error" - ) - - obj_serialized = get_state(obj.tolist(), dst) - res["content"] = obj_serialized["content"] - res["type"] = "json" - res["shape"] = get_state(obj.shape, dst) + # Couldn't save the numpy array with either method + raise UnsupportedTypeException( + f"numpy arrays of dtype {obj.dtype} are not supported yet, please " + "open an issue at https://github.com/skops-dev/skops/issues and " + "report your error" + ) return res @@ -67,13 +74,13 @@ def ndarray_get_instance(state, src): return val -def maskedarray_get_state(obj, dst): +def maskedarray_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), "content": { - "data": get_state(obj.data, dst), - "mask": get_state(obj.mask, dst), + "data": get_state(obj.data, save_state), + "mask": get_state(obj.mask, save_state), }, } return res @@ -85,8 +92,8 @@ def maskedarray_get_instance(state, src): return np.ma.MaskedArray(data, mask) -def random_state_get_state(obj, dst): - content = get_state(obj.get_state(legacy=False), dst) +def random_state_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: + content = get_state(obj.get_state(legacy=False), save_state) res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), @@ -103,7 +110,7 @@ def random_state_get_instance(state, src): return random_state -def random_generator_get_state(obj, dst): +def random_generator_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: bit_generator_state = obj.bit_generator.state res = { "__class__": obj.__class__.__name__, @@ -128,7 +135,7 @@ def random_generator_get_instance(state, src): # For numpy.ufunc we need to get the type from the type's module, but for other # functions we get it from objet's module directly. Therefore sett a especial # get_state method for them here. The load is the same as other functions. -def ufunc_get_state(obj, dst): +def ufunc_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, # ufunc "__module__": get_module(type(obj)), # numpy @@ -140,14 +147,14 @@ def ufunc_get_state(obj, dst): return res -def dtype_get_state(obj, dst): +def dtype_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: # we use numpy's internal save mechanism to store the dtype by # saving/loading an empty array with that dtype. - tmp = np.ndarray(0, dtype=obj) + tmp: np.typing.NDArray = np.ndarray(0, dtype=obj) res = { "__class__": "dtype", "__module__": "numpy", - "content": ndarray_get_state(tmp, dst), + "content": ndarray_get_state(tmp, save_state), } return res diff --git a/skops/io/_persist.py b/skops/io/_persist.py index aac9902e..da9f8982 100644 --- a/skops/io/_persist.py +++ b/skops/io/_persist.py @@ -9,11 +9,7 @@ import skops -from ._utils import _get_instance, _get_state, get_instance, get_state - -# For now, there is just one protocol version -PROTOCOL = 0 - +from ._utils import SaveState, _get_instance, _get_state, get_instance, get_state # We load the dispatch functions from the corresponding modules and register # them. @@ -53,9 +49,13 @@ def save(obj, file): """ with tempfile.TemporaryDirectory() as dst: - with open(Path(dst) / "schema.json", "w") as f: - state = get_state(obj, dst) - state["protocol"] = PROTOCOL + path = Path(dst) + with open(path / "schema.json", "w") as f: + save_state = SaveState(path=path) + state = get_state(obj, save_state) + save_state.clear_memo() + + state["protocol"] = save_state.protocol state["_skops_version"] = skops.__version__ json.dump(state, f, indent=2) diff --git a/skops/io/_scipy.py b/skops/io/_scipy.py index e2c7d8ae..215597b8 100644 --- a/skops/io/_scipy.py +++ b/skops/io/_scipy.py @@ -1,23 +1,31 @@ +from __future__ import annotations + import io -from pathlib import Path -from uuid import uuid4 +from typing import Any from scipy.sparse import load_npz, save_npz, spmatrix -from ._utils import get_module +from ._utils import SaveState, get_module -def sparse_matrix_get_state(obj, dst): +def sparse_matrix_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), } - f_name = f"{uuid4()}.npz" - save_npz(Path(dst) / f_name, obj) + # Memoize the object and then check if it's file name (containing the object + # id) already exists. If it does, there is no need to save the object again. + # Memoizitation is necessary since for ephemeral objects, the same id might + # otherwise be reused. + obj_id = save_state.memoize(obj) + f_name = f"{obj_id}.npz" + path = save_state.path / f_name + if not path.exists(): + save_npz(path, obj) + res["type"] = "scipy" res["file"] = f_name - return res diff --git a/skops/io/_sklearn.py b/skops/io/_sklearn.py index d68621e6..da32b7e5 100644 --- a/skops/io/_sklearn.py +++ b/skops/io/_sklearn.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from typing import Any + from sklearn.cluster import Birch from sklearn.covariance._graph_lasso import _DictWithDeprecatedKeys from sklearn.linear_model._sgd_fast import ( @@ -15,7 +19,7 @@ from sklearn.utils import Bunch from ._general import dict_get_instance, dict_get_state, unsupported_get_state -from ._utils import get_instance, get_module, get_state, gettype +from ._utils import SaveState, get_instance, get_module, get_state, gettype from .exceptions import UnsupportedTypeException ALLOWED_SGD_LOSSES = { @@ -32,7 +36,7 @@ UNSUPPORTED_TYPES = {Birch} -def reduce_get_state(obj, dst): +def reduce_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: # This method is for objects for which we have to use the __reduce__ # method to get the state. res = { @@ -56,7 +60,7 @@ def reduce_get_state(obj, dst): # As a good example, this makes Tree object to be serializable. reduce = obj.__reduce__() res["__reduce__"] = {} - res["__reduce__"]["args"] = get_state(reduce[1], dst) + res["__reduce__"]["args"] = get_state(reduce[1], save_state) if len(reduce) == 3: # reduce includes what's needed for __getstate__ and we don't need to @@ -74,7 +78,7 @@ def reduce_get_state(obj, dst): f"Objects of type {res['__class__']} not supported yet" ) - res["content"] = get_state(attrs, dst) + res["content"] = get_state(attrs, save_state) return res @@ -118,15 +122,17 @@ def bunch_get_instance(state, src): return Bunch(**content) -def _DictWithDeprecatedKeys_get_state(obj, dst): +def _DictWithDeprecatedKeys_get_state( + obj: Any, save_state: SaveState +) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), } content = {} - content["main"] = dict_get_state(obj, dst) + content["main"] = dict_get_state(obj, save_state) content["_deprecated_key_to_new_key"] = dict_get_state( - obj._deprecated_key_to_new_key, dst + obj._deprecated_key_to_new_key, save_state ) res["content"] = content return res diff --git a/skops/io/_utils.py b/skops/io/_utils.py index 27fcf96f..61437894 100644 --- a/skops/io/_utils.py +++ b/skops/io/_utils.py @@ -1,8 +1,13 @@ +from __future__ import annotations + import importlib import json # type: ignore import sys +from dataclasses import dataclass, field from functools import _find_impl, get_cache_token, update_wrapper # type: ignore +from pathlib import Path from types import FunctionType +from typing import Any from skops.utils.fixes import GenericAlias @@ -210,6 +215,45 @@ def get_module(obj): return whichmodule(obj, obj.__name__) +# For now, there is just one protocol version +DEFAULT_PROTOCOL = 0 + + +@dataclass(frozen=True) +class SaveState: + """State required for saving the objects + + This state is passed to each ``get_state_*`` function. + + Parameters + ---------- + path: pathlib.Path + The path to the directory to store the object in. + + protocol: int + The protocol of the persistence format. Right now, there is only + protocol 0, but this leaves the door open for future changes. + + """ + + path: Path + protocol: int = DEFAULT_PROTOCOL + memo: dict[int, Any] = field(default_factory=dict) + + def memoize(self, obj: Any) -> int: + # Currenlty, the only purpose for saving the object id is to make sure + # that for the length of the context that the main object is being + # saved, all attributes persist, so that the same id cannot be re-used + # for different objects. + obj_id = id(obj) + if obj_id not in self.memo: + self.memo[obj_id] = obj + return obj_id + + def clear_memo(self) -> None: + self.memo.clear() + + @singledispatch def _get_state(obj, dst): # This function should never be called directly. Instead, it is used to diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index f640cb7a..068f6af6 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -75,13 +75,11 @@ def debug_get_state(func): # Check consistency of argument names, output type, and that the output, # if a dict, has certain keys, or if not a dict, is a primitive type. signature = inspect.signature(func) - assert list(signature.parameters.keys()) == ["obj", "dst"] + assert list(signature.parameters.keys()) == ["obj", "save_state"] @wraps(func) - def wrapper(obj, dst): - assert isinstance(dst, str) - - result = func(obj, dst) + def wrapper(obj, save_state): + result = func(obj, save_state) if isinstance(result, dict): assert "__class__" in result @@ -663,7 +661,7 @@ def fit(self, X, y=None, **fit_params): schema = json.loads(ZipFile(f_name).read("schema.json")) # check some schema metainfo - assert schema["protocol"] == skops.io._persist.PROTOCOL + assert schema["protocol"] == skops.io._utils.DEFAULT_PROTOCOL assert schema["_skops_version"] == skops.__version__ # additionally, check following metainfo: class, module, and version @@ -704,3 +702,83 @@ def fit(self, X, y=None, **fit_params): # change across versions, e.g. 'scipy.sparse.csr' moving to # 'scipy.sparse._csr'. assert val_state["__module__"].startswith(val_expected["__module__"]) + + +class EstimatorIdenticalArrays(BaseEstimator): + """Estimator that stores multiple references to the same array""" + + def fit(self, X, y=None, **fit_params): + # each block below should reference the same file + self.X = X + self.X_2 = X + self.X_list = [X, X] + self.X_dict = {"a": X, 2: X} + + # copies are not deduplicated + X_copy = X.copy() + self.X_copy = X_copy + self.X_copy2 = X_copy + + # transposed matrices are not the same + X_T = X.T + self.X_T = X_T + self.X_T2 = X_T + + # slices are not the same + self.vector = X[0] + + self.vector_2 = X[0] + + self.scalar = X[0, 0] + + self.scalar_2 = X[0, 0] + + # deduplication should work on sparse matrices + X_sparse = sparse.csr_matrix(X) + self.X_sparse = X_sparse + self.X_sparse2 = X_sparse + + return self + + +def test_identical_numpy_arrays_not_duplicated(tmp_path): + # Test that identical numpy arrays are not stored multiple times + X = np.random.random((10, 5)) + estimator = EstimatorIdenticalArrays().fit(X) + f_name = tmp_path / "file.skops" + loaded = save_load_round(estimator, f_name) + assert_params_equal(estimator.__dict__, loaded.__dict__) + + # check number of numpy arrays stored on disk + with ZipFile(f_name, "r") as input_zip: + files = input_zip.namelist() + # expected number of files are: + # schema, X, X_copy, X_t, 2 vectors, 2 scalars, X_sparse = 9 + expected_files = 9 + num_files = len(files) + assert num_files == expected_files + + +class NumpyDtypeObjectEstimator(BaseEstimator): + """An estimator with a numpy array of dtype object""" + + def fit(self, X, y=None, **fit_params): + self.obj_ = np.zeros(3, dtype=object) + return self + + +def test_numpy_dtype_object_does_not_store_broken_file(tmp_path): + # This addresses a specific bug where trying to store an object numpy array + # resulted in the creation of a broken .npy file being left over. This is + # because numpy tries to write to the file until it encounters an error and + # raises, but then doesn't clean up said file. Before the bugfix in #150, we + # would include that broken file in the zip archive, although we wouldn't do + # anything with it. Here we test that no such file exists. + estimator = NumpyDtypeObjectEstimator().fit(None) + f_name = tmp_path / "file.skops" + save_load_round(estimator, f_name) + with ZipFile(f_name, "r") as input_zip: + files = input_zip.namelist() + + # this estimator should not have any numpy file + assert not any(file.endswith(".npy") for file in files)