Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

[SideEffect] add side effects support for list #220

Merged
merged 17 commits into from
Jul 4, 2023
Merged
28 changes: 28 additions & 0 deletions sot/opcode_translator/executor/function_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from .variables import (
ContainerVariable,
DictVariable,
ListVariable,
PaddleLayerVariable,
TensorVariable,
VariableBase,
Expand Down Expand Up @@ -288,12 +289,22 @@ def _find_tensor_outputs(
self, outputs: list[VariableBase]
) -> list[TensorVariable]:
output_tensors: list[TensorVariable] = []
# Find Tensor Variables from outputs.
for output in outputs:
if isinstance(output.tracker, DummyTracker):
if isinstance(output, TensorVariable):
output_tensors.append(output)
else:
# Guard output that can not be traced.
self.add_global_guarded_variable(output)
# Find Tensor Variables from side effects Variables.
for side_effect_var in self.side_effects.variables:
if side_effect_var.proxy.has_changed:
for var in side_effect_var.flatten_items():
if isinstance(var.tracker, DummyTracker) and isinstance(
var, TensorVariable
):
output_tensors.append(var)
return output_tensors

def restore_side_effects(self, variables: list[VariableBase]):
Expand Down Expand Up @@ -327,3 +338,20 @@ def restore_side_effects(self, variables: list[VariableBase]):
self.pycode_gen.gen_pop_top()
self.pycode_gen.gen_call_method(1) # call update
self.pycode_gen.gen_pop_top()
elif isinstance(var, ListVariable):
# old_list[:] = new_list

# Reference to the original list.
# load new_list to stack.
var._reconstruct(self.pycode_gen)
# load old_list[:] to stack.
var.reconstruct(self.pycode_gen)
self.pycode_gen.gen_load_const(None)
self.pycode_gen.gen_load_const(None)
self.pycode_gen.gen_build_slice(2)

# Generate side effects of other variables.
self.restore_side_effects(variables[1:])

# Call STORE_SUBSCR to apply side effects.
self.pycode_gen.gen_store_subscr()
175 changes: 154 additions & 21 deletions sot/opcode_translator/executor/mutable_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,41 +9,94 @@
R = TypeVar("R")

DataGetter: TypeAlias = Callable[[Any, Any], Any]
MutableDataT = TypeVar("MutableDataT", bound="MutableData")


class Mutation:
...
ABBR: str


class MutationSet(Mutation):
    """
    Record of assigning ``value`` to the slot identified by ``key``.

    Shared by MutableDictLikeData (``key`` is a dict key) and
    MutableListLikeData (``key`` is a list index).
    """

    # One-letter tag; MutableData.__repr__ joins these to summarise its records.
    ABBR = "S"

    def __init__(self, key, value):
        self.key, self.value = key, value

    def __repr__(self):
        return f"MutationSet({self.key}, {self.value})"


class MutationDel(Mutation):
    """
    Record of removing the entry identified by ``key``.

    Shared by MutableDictLikeData (``key`` is a dict key) and
    MutableListLikeData (``key`` is a list index).
    """

    # One-letter tag; MutableData.__repr__ joins these to summarise its records.
    ABBR = "D"

    def __init__(self, key):
        # Only the location is stored; a deletion carries no value.
        self.key = key

    def __repr__(self):
        return f"MutationDel({self.key})"


class MutationNew(Mutation):
def __init__(self, name, value):
self.name = name
"""
Adding a new value.
This mutation is only used for MutableDictLikeData.
"""

ABBR = "N"

def __init__(self, key, value):
self.key = key
self.value = value

def __repr__(self):
return f"MutationNew({self.name}, {self.value})"
return f"MutationNew({self.key}, {self.value})"


class MutationSet(Mutation):
def __init__(self, name, value):
self.name = name
class MutationInsert(Mutation):
"""
Inserting a value.
This mutation is only used for MutableListLikeData.
"""

ABBR = "I"

def __init__(self, index, value):
self.index = index
self.value = value

def __repr__(self):
return f"MutationSet({self.name}, {self.value})"
return f"MutationInsert({self.index}, {self.value})"


class MutationDel(Mutation):
def __init__(self, name):
self.name = name
class MutationPermutate(Mutation):
"""
Permutating all the values.
This mutation is only used for MutableListLikeData.
"""

ABBR = "P"

def __init__(self, permutation):
self.permutation = permutation

def __repr__(self):
return f"MutationDel({self.name})"
return f"MutationPermutate({self.permutation})"


def record_mutation(
mutation_fn: Callable[Concatenate[MutableDictLikeData, P], Mutation]
) -> Callable[Concatenate[MutableDictLikeData, P], None]:
mutation_fn: Callable[Concatenate[MutableDataT, P], Mutation]
) -> Callable[Concatenate[MutableDataT, P], None]:
def wrapper(self, *args: P.args, **kwargs: P.kwargs):
mutation = mutation_fn(self, *args, **kwargs)
self.records.append(mutation)
Expand All @@ -56,6 +109,8 @@ class MutableData:
An intermediate data structure between data and variable, it records all the mutations.
"""

read_cache: list[Any] | dict[str, Any]

class Empty:
def __repr__(self):
return "Empty()"
Expand Down Expand Up @@ -86,16 +141,35 @@ def get(self, key):
def set(self, key, value):
raise NotImplementedError()

def apply(
self, mutation: Mutation, write_cache: dict[str, Any] | list[Any]
):
raise NotImplementedError()

def reproduce(self, version: int | None = None):
if version is None:
version = self.version
write_cache = self.read_cache.copy()
for mutation in self.records[:version]:
self.apply(mutation, write_cache)
return write_cache

def __repr__(self) -> str:
records_abbrs = "".join([mutation.ABBR for mutation in self.records])
return f"{self.__class__.__name__}({records_abbrs})"


class MutableDictLikeData(MutableData):
read_cache: dict[str, Any]

def __init__(self, data: Any, getter: DataGetter):
super().__init__(data, getter)
self.read_cache = {}

def clear(self):
self.read_cache.clear()

def get(self, key):
def get(self, key: Any):
# TODO(SigureMo): Optimize performance of this.
write_cache = self.reproduce(self.version)
if key not in write_cache:
Expand All @@ -106,13 +180,13 @@ def get_all(self):
original_keys = list(self.original_data.keys())
for mutation in self.records:
if isinstance(mutation, MutationNew):
original_keys.append(mutation.name)
original_keys.append(mutation.key)
elif isinstance(mutation, MutationDel):
original_keys.remove(mutation.name)
original_keys.remove(mutation.key)
return {key: self.get(key) for key in original_keys}

@record_mutation
def set(self, key, value) -> Mutation:
def set(self, key: Any, value: Any) -> Mutation:
is_new = False
if self.is_empty(self.get(key)):
is_new = True
Expand All @@ -126,11 +200,11 @@ def delete(self, key):

def apply(self, mutation: Mutation, write_cache: dict[str, Any]):
if isinstance(mutation, MutationNew):
write_cache[mutation.name] = mutation.value
write_cache[mutation.key] = mutation.value
elif isinstance(mutation, MutationSet):
write_cache[mutation.name] = mutation.value
write_cache[mutation.key] = mutation.value
elif isinstance(mutation, MutationDel):
write_cache[mutation.name] = MutableData.Empty()
write_cache[mutation.key] = MutableData.Empty()
else:
raise ValueError(f"Unknown mutation type {mutation}")

Expand All @@ -141,3 +215,62 @@ def reproduce(self, version: int | None = None):
for mutation in self.records[:version]:
self.apply(mutation, write_cache)
return write_cache


class MutableListLikeData(MutableData):
    """
    Mutation-recording intermediate layer for list-like data.

    Every element of the wrapped data is read once, through ``getter``, into
    ``read_cache`` when the object is created. Writes never touch the wrapped
    data: each one is turned into a Mutation and appended to ``self.records``
    by ``record_mutation``. Reads replay the recorded mutations on a copy of
    ``read_cache`` (see ``reproduce``).
    """

    read_cache: list[Any]

    def __init__(self, data: Any, getter: DataGetter):
        """
        Args:
            data: The list-like object to wrap; must support ``len()``.
            getter: Callable ``(data, index) -> element`` used to read items.
        """
        # NOTE(review): ``original_data`` and ``getter`` are presumably set by
        # MutableData.__init__ -- confirm against the base class.
        super().__init__(data, getter)
        self.read_cache = [
            self.getter(self.original_data, idx) for idx in range(len(data))
        ]

    def clear(self):
        """Empty the read cache in place."""
        self.read_cache[:] = []

    @property
    def length(self):
        """Number of elements after replaying the recorded mutations."""
        return len(self.reproduce())

    def get(self, key):
        """Return the element (or slice) at ``key`` in the replayed list."""
        write_cache = self.reproduce(self.version)
        return write_cache[key]

    def get_all(self):
        """Return the whole replayed list as a new list object."""
        return self.reproduce(self.version)

    @record_mutation
    def set(self, key: int, value: Any):
        """Record ``data[key] = value``; a negative ``key`` counts from the end."""
        return MutationSet(self._regularize_index(key), value)

    @record_mutation
    def delete(self, key: int):
        """Record ``del data[key]``; a negative ``key`` counts from the end."""
        return MutationDel(self._regularize_index(key))

    @record_mutation
    def insert(self, index: int, value: Any):
        """
        Record ``data.insert(index, value)``.

        Follows ``list.insert``: a negative ``index`` counts from the end, and
        one that reaches past the front inserts at position 0.
        """
        index = self._regularize_index(index)
        # If the index is still negative it pointed before the first element.
        # Recording it unchanged would make ``apply`` insert relative to the
        # end of the list instead of at the front, so clamp it to 0.
        return MutationInsert(max(index, 0), value)

    @record_mutation
    def permutate(self, permutation: list[int]):
        """
        Record a reordering of the list.

        ``permutation[i]`` is the current index of the element that ends up at
        position ``i``.
        """
        return MutationPermutate(permutation)

    def _regularize_index(self, index: int):
        """Convert a negative index to its offset from the front."""
        # ``length`` replays all records, so only compute it when needed.
        if index < 0:
            index += self.length
        return index

    def apply(self, mutation: Mutation, write_cache: list[Any]):
        """
        Apply one recorded mutation to ``write_cache`` in place.

        Raises:
            ValueError: If ``mutation`` is not a list mutation type.
        """
        if isinstance(mutation, MutationSet):
            write_cache[mutation.key] = mutation.value
        elif isinstance(mutation, MutationDel):
            # Slice-based removal: an index past the end leaves the cache
            # unchanged instead of raising.
            write_cache[:] = (
                write_cache[: mutation.key] + write_cache[mutation.key + 1 :]
            )
        elif isinstance(mutation, MutationInsert):
            write_cache.insert(mutation.index, mutation.value)
        elif isinstance(mutation, MutationPermutate):
            # The new order is built fully before it overwrites the cache.
            write_cache[:] = [write_cache[i] for i in mutation.permutation]
        else:
            raise ValueError(f"Unknown mutation type {mutation}")
7 changes: 7 additions & 0 deletions sot/opcode_translator/executor/opcode_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1313,6 +1313,13 @@ def DICT_MERGE(self, instr: Instruction):
self._stack[-instr.arg], dict_value
)

def LIST_APPEND(self, instr: Instruction):
    """
    Execute LIST_APPEND: pop the top of the stack and append it to the list
    variable that is then ``instr.arg`` slots from the top.
    """
    item = self.pop()
    assert instr.argval > 0
    append_fn = BuiltinVariable(
        list.append, self._graph, tracker=DanglingTracker()
    )
    # The index is taken after the pop, so it counts from the new stack top.
    append_fn(self._stack[-instr.arg], item)

def LIST_EXTEND(self, instr: Instruction):
list_value = self.pop()
assert instr.argval > 0
Expand Down
6 changes: 6 additions & 0 deletions sot/opcode_translator/executor/pycode_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,9 @@ def gen_store_fast(self, name):
idx = self._code_options["co_varnames"].index(name)
self._add_instr("STORE_FAST", arg=idx, argval=name)

def gen_store_subscr(self):
    """Emit a STORE_SUBSCR instruction (no operand)."""
    opname = "STORE_SUBSCR"
    self._add_instr(opname)

def gen_subscribe(self):
    """Emit a BINARY_SUBSCR instruction (no operand)."""
    opname = "BINARY_SUBSCR"
    self._add_instr(opname)

Expand All @@ -508,6 +511,9 @@ def gen_build_list(self, count):
def gen_build_map(self, count):
    """Emit a BUILD_MAP instruction whose arg and argval are both ``count``."""
    operand = {"arg": count, "argval": count}
    self._add_instr("BUILD_MAP", **operand)

def gen_build_slice(self, argc):
    """Emit a BUILD_SLICE instruction whose arg and argval are both ``argc``."""
    operand = {"arg": argc, "argval": argc}
    self._add_instr("BUILD_SLICE", **operand)

def gen_unpack_sequence(self, count):
    """Emit an UNPACK_SEQUENCE instruction whose arg and argval are both ``count``."""
    operand = {"arg": count, "argval": count}
    self._add_instr("UNPACK_SEQUENCE", **operand)

Expand Down
3 changes: 2 additions & 1 deletion sot/opcode_translator/executor/side_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ def __init__(self):
self.variables: list[VariableBase] = []

def record_variable(self, variable: VariableBase):
self.variables.append(variable)
if variable not in self.variables:
self.variables.append(variable)

def get_proxy(
self,
Expand Down
2 changes: 1 addition & 1 deletion sot/opcode_translator/executor/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def gen_instructions(self, codegen: PyCodeGen):
codegen.gen_load_const(self.value)

def trace_value_from_frame(self):
return StringifyExpression(f"{self.value}", {})
return StringifyExpression(f"{self.value!r}", {})

def __repr__(self) -> str:
return f"ConstTracker(value={self.value})"
Expand Down
Loading