Skip to content

Commit

Permalink
Added NumPy support for utbot-python (#2742)
Browse files Browse the repository at this point in the history
  • Loading branch information
AndrewPolyakov1 authored Jun 3, 2024
1 parent 2b37f07 commit 228072f
Show file tree
Hide file tree
Showing 30 changed files with 595 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ class TestWriter {

fun generateTestCode(): String {
val (importLines, code) = testCode.fold(mutableListOf<String>() to StringBuilder()) { acc, s ->
val lines = s.split(System.lineSeparator())
// val lines = s.split(System.lineSeparator())
val lines = s.split("(\\r\\n|\\r|\\n)".toRegex())
val firstClassIndex = lines.indexOfFirst { it.startsWith("class") }
lines.take(firstClassIndex).forEach { line -> if (line !in acc.first) acc.first.add(line) }
lines.drop(firstClassIndex).forEach { line -> acc.second.append(line + System.lineSeparator()) }
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "utbot-executor"
version = "1.9.19"
version = "1.10.0"
description = ""
authors = ["Vyacheslav Tamarin <[email protected]>"]
readme = "README.md"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def serialize_objects_dump(objs: List[Any], clear_visited: bool = False) -> Tupl
serializer.write_object_to_memory(obj)
for obj in objs
]

return ids, serializer.memory, serialize_memory_dump(serializer.memory)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,14 @@
ListMemoryObject,
DictMemoryObject,
ReduceMemoryObject,
MemoryDump, IteratorMemoryObject,
MemoryDump, IteratorMemoryObject, NdarrayMemoryObject,
)
from utbot_executor.deep_serialization.utils import PythonId, TypeInfo
try:
import numpy as np
except ImportError:
import sys
print("numpy is not installed", file=sys.stderr)


class MemoryObjectEncoder(json.JSONEncoder):
Expand All @@ -27,6 +32,10 @@ def default(self, o):
}
if isinstance(o, ReprMemoryObject):
base_json["value"] = o.value
elif isinstance(o, NdarrayMemoryObject):
base_json["items"] = o.items
base_json["comparable"] = True
base_json["dimensions"] = o.dimensions
elif isinstance(o, (ListMemoryObject, DictMemoryObject)):
base_json["items"] = o.items
elif isinstance(o, IteratorMemoryObject):
Expand All @@ -53,6 +62,10 @@ def default(self, o):
"kind": o.kind,
"module": o.module,
}
if isinstance(o, np.ndarray):
raise NotImplementedError("np.ndarray is not supported")
if isinstance(o, np.bool_):
return bool(o)
return json.JSONEncoder.default(self, o)


Expand All @@ -75,6 +88,17 @@ def as_reduce_object(dct: Dict) -> Union[MemoryObject, Dict]:
)
obj.comparable = dct["comparable"]
return obj

if dct["strategy"] == "ndarray":
obj = NdarrayMemoryObject.__new__(NdarrayMemoryObject)
obj.items = dct["items"]
obj.typeinfo = TypeInfo(
kind=dct["typeinfo"]["kind"], module=dct["typeinfo"]["module"]
)
obj.comparable = dct["comparable"]
obj.dimensions = dct["dimensions"]
return obj

if dct["strategy"] == "dict":
obj = DictMemoryObject.__new__(DictMemoryObject)
obj.items = dct["items"]
Expand Down Expand Up @@ -138,6 +162,11 @@ def reload_id(self) -> MemoryDump:
new_memory_object.items = [
self.dump_id_to_real_id[id_] for id_ in new_memory_object.items
]
elif isinstance(new_memory_object, NdarrayMemoryObject):
new_memory_object.items = [
self.dump_id_to_real_id[id_] for id_ in new_memory_object.items
]
new_memory_object.dimensions = obj.dimensions
elif isinstance(new_memory_object, IteratorMemoryObject):
new_memory_object.items = [
self.dump_id_to_real_id[id_] for id_ in new_memory_object.items
Expand Down Expand Up @@ -198,6 +227,27 @@ def load_object(self, python_id: PythonId) -> object:

for item in dump_object.items:
real_object.append(self.load_object(item))

elif isinstance(dump_object, NdarrayMemoryObject):
real_object = []

id_ = PythonId(str(id(real_object)))
self.dump_id_to_real_id[python_id] = id_
self.memory[id_] = real_object

temp_list = []
for item in dump_object.items:
temp_list.append(self.load_object(item))

dt = np.dtype(type(temp_list[0]) if len(temp_list) > 0 else np.int64)
temp_list = np.array(temp_list, dtype=dt)

real_object = np.ndarray(
shape=dump_object.dimensions,
dtype=dt,
buffer=temp_list
)

elif isinstance(dump_object, DictMemoryObject):
real_object = {}

Expand Down Expand Up @@ -250,7 +300,7 @@ def load_object(self, python_id: PythonId) -> object:
for key, dictitem in dictitems.items():
real_object[key] = dictitem
else:
raise TypeError(f"Invalid type {dump_object}")
raise TypeError(f"Invalid type {dump_object}, type: {type(dump_object)}")

id_ = PythonId(str(id(real_object)))
self.dump_id_to_real_id[python_id] = id_
Expand Down Expand Up @@ -279,6 +329,10 @@ def main():
"builtins.tuple",
"builtins.bytes",
"builtins.type",
"numpy.ndarray"
]
)
print(loader.load_object(PythonId("140239390887040")))

if __name__ == '__main__':
main()
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
import typing
from itertools import zip_longest
from typing import Any, Callable, Dict, List, Optional, Set, Type, Iterable
try:
import numpy as np
except ImportError:
import sys
print("numpy is not installed", file=sys.stderr)

from utbot_executor.deep_serialization.config import PICKLE_PROTO
from utbot_executor.deep_serialization.iterator_wrapper import IteratorWrapper
Expand Down Expand Up @@ -41,7 +46,7 @@ def __init__(self, obj: object) -> None:
self.id_ = PythonId(str(id(self.obj)))

def _initialize(
self, deserialized_obj: object = None, comparable: bool = True
self, deserialized_obj: object = None, comparable: bool = True
) -> None:
self.deserialized_obj = deserialized_obj
self.comparable = comparable
Expand Down Expand Up @@ -111,14 +116,42 @@ def initialize(self) -> None:
elif self.typeinfo.fullname == "builtins.set":
deserialized_obj = set(deserialized_obj)


comparable = all(serializer.get_by_id(elem).comparable for elem in self.items)

super()._initialize(deserialized_obj, comparable)

class NdarrayMemoryObject(MemoryObject):
    """Memory object for a numpy array.

    The array is stored as a flat list of element ids (``items``, in the
    order produced by ``ndarray.flatten()``) together with the original
    shape (``dimensions``).
    """

    strategy: str = "ndarray"
    # Class-level defaults; instances rebind both attributes before use.
    items: List[PythonId] = []
    dimensions: List[int] = []

    def __init__(self, ndarray_object: object) -> None:
        # Per-instance list, so appends in initialize() never touch the
        # class-level default above.
        self.items: List[PythonId] = []
        super().__init__(ndarray_object)

    def initialize(self) -> None:
        """Write every element of the wrapped array to the serializer memory."""
        serializer = PythonSerializer()
        self.deserialized_obj = []  # for recursive collections
        self.comparable = False  # for recursive collections

        # NOTE(review): ``shape`` is a tuple although the attribute is
        # annotated as List[int] - confirm consumers accept both.
        self.dimensions = self.obj.shape

        # flatten() already returns a fresh one-dimensional copy, so no extra
        # copy() is needed; an empty array simply yields no iterations.
        for elem in self.obj.flatten():
            elem_id = serializer.write_object_to_memory(elem)
            self.items.append(elem_id)
            self.deserialized_obj.append(serializer[elem_id])

        # all() of an empty sequence is True, which covers the empty array.
        comparable = all(
            serializer.get_by_id(elem_id).comparable for elem_id in self.items
        )
        super()._initialize(self.deserialized_obj, comparable)

def __repr__(self) -> str:
if hasattr(self, "obj"):
return str(self.obj)
return f"{self.typeinfo.kind}{self.items}"
return f"{self.typeinfo.kind}{self.items}{self.dimensions}"


class DictMemoryObject(MemoryObject):
Expand Down Expand Up @@ -264,10 +297,10 @@ def constructor_builder(self) -> typing.Tuple[typing.Any, typing.Callable]:

is_reconstructor = constructor_kind.qualname == "copyreg._reconstructor"
is_reduce_user_type = (
len(self.reduce_value[1]) == 3
and isinstance(self.reduce_value[1][0], type(self.obj))
and self.reduce_value[1][1] is object
and self.reduce_value[1][2] is None
len(self.reduce_value[1]) == 3
and isinstance(self.reduce_value[1][0], type(self.obj))
and self.reduce_value[1][1] is object
and self.reduce_value[1][2] is None
)
is_reduce_ex_user_type = len(self.reduce_value[1]) == 1 and isinstance(
self.reduce_value[1][0], type(self.obj)
Expand All @@ -294,8 +327,8 @@ def constructor_builder(self) -> typing.Tuple[typing.Any, typing.Callable]:
len(inspect.signature(init_method).parameters),
)
if (
not init_from_object
and len(inspect.signature(init_method).parameters) == 1
not init_from_object
and len(inspect.signature(init_method).parameters) == 1
) or init_from_object:
logging.debug("init with one argument! %s", init_method)
constructor_arguments = []
Expand All @@ -317,9 +350,9 @@ def constructor_builder(self) -> typing.Tuple[typing.Any, typing.Callable]:
if is_reconstructor and is_user_type:
constructor_arguments = self.reduce_value[1]
if (
len(constructor_arguments) == 3
and constructor_arguments[-1] is None
and constructor_arguments[-2] == object
len(constructor_arguments) == 3
and constructor_arguments[-1] is None
and constructor_arguments[-2] == object
):
del constructor_arguments[1:]
callable_constructor = object.__new__
Expand Down Expand Up @@ -392,6 +425,12 @@ def get_serializer(obj: object) -> Optional[Type[MemoryObject]]:
return ListMemoryObject
return None

class NdarrayMemoryObjectProvider(MemoryObjectProvider):
    """Provider that selects NdarrayMemoryObject for plain numpy arrays."""

    @staticmethod
    def get_serializer(obj: object) -> Optional[Type[MemoryObject]]:
        """Return NdarrayMemoryObject when ``obj`` is exactly a ``numpy.ndarray``.

        Returns None for every other object, including ndarray subclasses,
        and also when numpy could not be imported.
        """
        # The module-level numpy import is wrapped in try/except ImportError,
        # so ``np`` may be missing from the module namespace. Look it up
        # defensively instead of raising NameError for every object.
        numpy_module = globals().get("np")
        if numpy_module is None:
            return None
        if type(obj) is numpy_module.ndarray:
            return NdarrayMemoryObject
        return None

class DictMemoryObjectProvider(MemoryObjectProvider):
@staticmethod
Expand Down Expand Up @@ -425,6 +464,7 @@ def get_serializer(obj: object) -> Optional[Type[MemoryObject]]:
return None



class ReprMemoryObjectProvider(MemoryObjectProvider):
@staticmethod
def get_serializer(obj: object) -> Optional[Type[MemoryObject]]:
Expand All @@ -450,6 +490,7 @@ class PythonSerializer:
visited: Set[PythonId] = set()

providers: List[MemoryObjectProvider] = [
NdarrayMemoryObjectProvider,
ListMemoryObjectProvider,
DictMemoryObjectProvider,
IteratorMemoryObjectProvider,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ def get_constructor_info(constructor: object, obj: object) -> TypeInfo:


def has_reduce(py_object: object) -> bool:
if get_kind(py_object).module == "numpy":
return False
reduce = getattr(py_object, "__reduce__", None)
if reduce is None:
return False
Expand Down Expand Up @@ -161,6 +163,10 @@ def check_eval(py_object: object) -> bool:
except Exception:
return False

try:
import numpy as np
except ImportError:
pass

def has_repr(py_object: object) -> bool:
reprable_types = [
Expand All @@ -171,11 +177,13 @@ def has_repr(py_object: object) -> bool:
bytes,
bytearray,
str,
# tuple,
# list,
# dict,
# set,
# frozenset,
np.int64,
np.int32,
np.int16,
np.int8,
np.float32,
np.float16,
np.float64,
type,
]
if type(py_object) in reprable_types:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@


def _update_states(
init_memory_dump: MemoryDump, before_memory_dump: MemoryDump
init_memory_dump: MemoryDump, before_memory_dump: MemoryDump
) -> MemoryDump:
for id_, obj in before_memory_dump.objects.items():
if id_ in init_memory_dump.objects:
Expand Down Expand Up @@ -87,8 +87,9 @@ def add_imports(imports: Iterable[str]):
globals()[submodule_name] = importlib.import_module(
submodule_name
)
except ModuleNotFoundError:
except ModuleNotFoundError as e:
logging.warning("Import submodule %s failed", submodule_name)
raise e
logging.debug("Submodule #%d: OK", i)

def run_function(self, request: ExecutionRequest) -> ExecutionResponse:
Expand All @@ -115,6 +116,9 @@ def run_reduce_function(self, request: ExecutionRequest) -> ExecutionResponse:
self.add_imports(request.imports)
loader.add_syspaths(request.syspaths)
loader.add_imports(request.imports)
except ModuleNotFoundError as _:
logging.debug("Error \n%s", traceback.format_exc())
return ExecutionFailResponse("fail", traceback.format_exc())
except Exception as _:
logging.debug("Error \n%s", traceback.format_exc())
return ExecutionFailResponse("fail", traceback.format_exc())
Expand Down Expand Up @@ -246,9 +250,9 @@ def run_pickle_function(self, request: ExecutionRequest) -> ExecutionResponse:


def _serialize_state(
args: List[Any],
kwargs: Dict[str, Any],
result: Any = None,
args: List[Any],
kwargs: Dict[str, Any],
result: Any = None,
) -> Tuple[List[PythonId], Dict[str, PythonId], PythonId, MemoryDump, str]:
"""Serialize objects from args, kwargs and result.
Expand All @@ -267,13 +271,13 @@ def _serialize_state(


def _run_calculate_function_value(
function: types.FunctionType,
args: List[Any],
kwargs: Dict[str, Any],
fullpath: str,
state_init: str,
tracer: UtTracer,
state_assertions: bool,
function: types.FunctionType,
args: List[Any],
kwargs: Dict[str, Any],
fullpath: str,
state_init: str,
tracer: UtTracer,
state_assertions: bool,
) -> ExecutionResponse:
"""Calculate function evaluation result.
Expand Down
Loading

0 comments on commit 228072f

Please sign in to comment.