Merge pull request #3929 from tybug/shrinker-ir-descendant
Migrate `pass_to_descendant` and `redistribute_block_pairs` shrinker passes
Zac-HD authored Mar 19, 2024
2 parents 51bb792 + 01c18a9 commit 501d2ba
Showing 5 changed files with 365 additions and 34 deletions.
3 changes: 3 additions & 0 deletions hypothesis-python/RELEASE.rst
@@ -0,0 +1,3 @@
RELEASE_TYPE: patch

This patch continues our work on refactoring the shrinker (:issue:`3921`).
57 changes: 56 additions & 1 deletion hypothesis-python/src/hypothesis/internal/conjecture/data.py
@@ -324,6 +324,11 @@ def length(self) -> int:
"""The number of bytes in this example."""
return self.end - self.start

@property
def ir_length(self) -> int:
"""The number of ir nodes in this example."""
return self.ir_end - self.ir_start

@property
def children(self) -> "List[Example]":
"""The list of all examples with this as a parent, in increasing index
@@ -465,7 +470,11 @@ def freeze(self) -> None:
def record_ir_draw(self, ir_type, value, *, kwargs, was_forced):
self.trail.append(IR_NODE_RECORD)
node = IRNode(
ir_type=ir_type, value=value, kwargs=kwargs, was_forced=was_forced
ir_type=ir_type,
value=value,
kwargs=kwargs,
was_forced=was_forced,
index=len(self.ir_nodes),
)
self.ir_nodes.append(node)

@@ -950,18 +959,64 @@ class IRNode:
value: IRType = attr.ib()
kwargs: IRKWargsType = attr.ib()
was_forced: bool = attr.ib()
index: Optional[int] = attr.ib(default=None)

def copy(self, *, with_value: IRType) -> "IRNode":
# we may want to allow this combination in the future, but for now it's
# a footgun.
assert not self.was_forced, "modifying a forced node doesn't make sense"
# explicitly not copying index. node indices are only assigned via
# ExampleRecord. This prevents footguns with relying on stale indices
# after copying.
return IRNode(
ir_type=self.ir_type,
value=with_value,
kwargs=self.kwargs,
was_forced=self.was_forced,
)

@property
def trivial(self):
"""
A node is trivial if it cannot be simplified any further. This does not
mean that modifying a trivial node can't produce simpler test cases when
viewing the tree as a whole. Just that when viewing this node in
isolation, this is the simplest the node can get.
"""
if self.was_forced:
return True

if self.ir_type == "integer":
shrink_towards = self.kwargs["shrink_towards"]
min_value = self.kwargs["min_value"]
max_value = self.kwargs["max_value"]

if min_value is not None:
shrink_towards = max(min_value, shrink_towards)
if max_value is not None:
shrink_towards = min(max_value, shrink_towards)

return self.value == shrink_towards
if self.ir_type == "float":
# floats shrink "like integers" (for now, anyway), except shrink_towards
# is not configurable and is always 0.
shrink_towards = 0
shrink_towards = max(self.kwargs["min_value"], shrink_towards)
shrink_towards = min(self.kwargs["max_value"], shrink_towards)

return ir_value_equal("float", self.value, shrink_towards)
if self.ir_type == "boolean":
return self.value is False
if self.ir_type == "string":
# smallest size and contains only the smallest-in-shrink-order character.
minimal_char = self.kwargs["intervals"].char_in_shrink_order(0)
return self.value == (minimal_char * self.kwargs["min_size"])
if self.ir_type == "bytes":
# smallest size and all-zero value.
return len(self.value) == self.kwargs["size"] and not any(self.value)

raise NotImplementedError(f"unhandled ir_type {self.ir_type}")

def __eq__(self, other):
if not isinstance(other, IRNode):
return NotImplemented
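For intuition, the integer branch of the new trivial property reads as: clamp the shrink target into the permitted range, then check whether the value already equals it. A minimal standalone sketch of that check (the helper name and signature are invented for illustration and are not the Hypothesis API):

def integer_is_trivial(value, *, min_value=None, max_value=None, shrink_towards=0):
    # Mirror of the integer case above: clamp shrink_towards into
    # [min_value, max_value], then compare it against the current value.
    if min_value is not None:
        shrink_towards = max(min_value, shrink_towards)
    if max_value is not None:
        shrink_towards = min(max_value, shrink_towards)
    return value == shrink_towards

assert integer_is_trivial(0)  # unbounded integer: the target is 0
assert integer_is_trivial(5, min_value=5, max_value=10)  # target clamped up to 5
assert not integer_is_trivial(7, min_value=5, max_value=10)  # 7 can still shrink to 5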
92 changes: 59 additions & 33 deletions hypothesis-python/src/hypothesis/internal/conjecture/shrinker.py
@@ -20,7 +20,13 @@
prefix_selection_order,
random_selection_order,
)
from hypothesis.internal.conjecture.data import ConjectureData, ConjectureResult, Status
from hypothesis.internal.conjecture.data import (
ConjectureData,
ConjectureResult,
Status,
bits_to_bytes,
ir_value_permitted,
)
from hypothesis.internal.conjecture.dfa import ConcreteDFA
from hypothesis.internal.conjecture.floats import is_simple
from hypothesis.internal.conjecture.junkdrawer import (
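Of the imports above, bits_to_bytes and ir_value_permitted are the new ones, and both are used by the migrated redistribute_block_pairs below: bits_to_bytes to compare the byte width of two integer values, and ir_value_permitted to check that a proposed value still satisfies a node's kwargs. As a rough illustration of the width comparison, assuming the conventional round-up conversion (this sketch is not the library's definition):

def approx_bits_to_bytes(n: int) -> int:
    # assumed behaviour: round a bit-length up to whole bytes
    return (n + 7) // 8

# 300 and 400 both have a 9-bit length, hence a 2-byte width, so the pass
# below would consider pairing two such integer nodes.
assert approx_bits_to_bytes((300).bit_length()) == approx_bits_to_bytes((400).bit_length()) == 2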
@@ -377,7 +383,6 @@ def calls(self):

def consider_new_tree(self, tree):
data = self.engine.ir_tree_to_data(tree)

return self.consider_new_buffer(data.buffer)

def consider_new_buffer(self, buffer):
@@ -825,12 +830,10 @@ def pass_to_descendant(self, chooser):
)

ls = self.examples_by_label[label]

i = chooser.choose(range(len(ls) - 1))

ancestor = ls[i]

if i + 1 == len(ls) or ls[i + 1].start >= ancestor.end:
if i + 1 == len(ls) or ls[i + 1].ir_start >= ancestor.ir_end:
return

@self.cached(label, i)
@@ -839,22 +842,22 @@ def descendants():
hi = len(ls)
while lo + 1 < hi:
mid = (lo + hi) // 2
if ls[mid].start >= ancestor.end:
if ls[mid].ir_start >= ancestor.ir_end:
hi = mid
else:
lo = mid
return [t for t in ls[i + 1 : hi] if t.length < ancestor.length]
return [t for t in ls[i + 1 : hi] if t.ir_length < ancestor.ir_length]

descendant = chooser.choose(descendants, lambda ex: ex.length > 0)
descendant = chooser.choose(descendants, lambda ex: ex.ir_length > 0)

assert ancestor.start <= descendant.start
assert ancestor.end >= descendant.end
assert descendant.length < ancestor.length
assert ancestor.ir_start <= descendant.ir_start
assert ancestor.ir_end >= descendant.ir_end
assert descendant.ir_length < ancestor.ir_length

self.incorporate_new_buffer(
self.buffer[: ancestor.start]
+ self.buffer[descendant.start : descendant.end]
+ self.buffer[ancestor.end :]
self.consider_new_tree(
self.nodes[: ancestor.ir_start]
+ self.nodes[descendant.ir_start : descendant.ir_end]
+ self.nodes[ancestor.ir_end :]
)

def lower_common_block_offset(self):
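The new consider_new_tree call is plain list surgery on the ir node sequence: keep everything before the ancestor, substitute the descendant's nodes for the ancestor's span, and keep everything after the ancestor. A standalone sketch of that splice (the function name and the toy list are illustrative, not Hypothesis internals):

def replace_span(nodes, ancestor_start, ancestor_end, descendant_start, descendant_end):
    # Replace the ancestor's span with the (strictly smaller) descendant's span.
    assert ancestor_start <= descendant_start <= descendant_end <= ancestor_end
    return (
        nodes[:ancestor_start]
        + nodes[descendant_start:descendant_end]
        + nodes[ancestor_end:]
    )

# ancestor covers indices 1..5, descendant covers indices 2..4
assert replace_span(list("abcdefg"), 1, 5, 2, 4) == ["a", "c", "d", "f", "g"]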
@@ -1221,7 +1224,6 @@ def minimize_floats(self, chooser):
and not is_simple(node.value),
)

i = self.nodes.index(node)
# the Float shrinker was only built to handle positive floats. We'll
# shrink the positive portion and reapply the sign after, which is
# equivalent to this shrinker's previous behavior. We'll want to refactor
@@ -1231,9 +1233,9 @@
Float.shrink(
abs(node.value),
lambda val: self.consider_new_tree(
self.nodes[:i]
self.nodes[: node.index]
+ [node.copy(with_value=sign * val)]
+ self.nodes[i + 1 :]
+ self.nodes[node.index + 1 :]
),
random=self.random,
node=node,
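The lambda above shrinks only the magnitude and reattaches the original sign when rebuilding the node, so negative floats shrink the same way their positive counterparts do. The sign variable itself comes from context not shown in this hunk; a self-contained sketch of the idea (names here are illustrative):

import math

def reattach_sign(original: float, shrunk_magnitude: float) -> float:
    # Recover the sign of the original draw and apply it to the shrunk
    # magnitude, mirroring the `sign * val` expression above.
    sign = math.copysign(1.0, original)
    return sign * shrunk_magnitude

assert reattach_sign(-37.5, 2.0) == -2.0
assert reattach_sign(37.5, 2.0) == 2.0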
@@ -1245,32 +1247,56 @@ def redistribute_block_pairs(self, chooser):
to exceed some bound, lowering one of them requires raising the
other. This pass enables that."""

block = chooser.choose(self.blocks, lambda b: not b.all_zero)
node = chooser.choose(
self.nodes, lambda node: node.ir_type == "integer" and not node.trivial
)

for j in range(block.index + 1, len(self.blocks)):
next_block = self.blocks[j]
if next_block.length == block.length:
# The preconditions for this pass are that the two integer draws are only
# separated by non-integer nodes, and have the same size value in bytes.
#
# This isn't particularly principled. For instance, this wouldn't reduce
# e.g. @given(integers(), integers(), integers()) where the sum property
# involves the first and last integers.
#
# A better approach may be choosing *two* such integer nodes arbitrarily
# from the list, instead of conditionally scanning forward.

for j in range(node.index + 1, len(self.nodes)):
next_node = self.nodes[j]
if next_node.ir_type == "integer" and bits_to_bytes(
node.value.bit_length()
) == bits_to_bytes(next_node.value.bit_length()):
break
else:
return

buffer = self.buffer
if next_node.was_forced:
# avoid modifying a forced node. Note that it's fine for next_node
# to be trivial, because we're going to explicitly make it *not*
# trivial by adding to its value.
return

m = int_from_bytes(buffer[block.start : block.end])
n = int_from_bytes(buffer[next_block.start : next_block.end])
m = node.value
n = next_node.value

def boost(k):
if k > m:
return False
attempt = bytearray(buffer)
attempt[block.start : block.end] = int_to_bytes(m - k, block.length)
try:
attempt[next_block.start : next_block.end] = int_to_bytes(
n + k, next_block.length
)
except OverflowError:

node_value = m - k
next_node_value = n + k
if (not ir_value_permitted(node_value, "integer", node.kwargs)) or (
not ir_value_permitted(next_node_value, "integer", next_node.kwargs)
):
return False
return self.consider_new_buffer(attempt)

return self.consider_new_tree(
self.nodes[: node.index]
+ [node.copy(with_value=node_value)]
+ self.nodes[node.index + 1 : next_node.index]
+ [next_node.copy(with_value=next_node_value)]
+ self.nodes[next_node.index + 1 :]
)

find_integer(boost)
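The redistribution itself is arithmetic on the pair: find the largest k for which lowering one value by k and raising the other by k still reproduces the failure, which find_integer probes for boost. A simplified, self-contained stand-in for that search (it assumes the predicate is monotone, which the real find_integer does not require; the acceptable predicate and toy bound are invented for this example):

def find_largest_k(acceptable, upper_bound):
    # Binary search for the largest k in [0, upper_bound] with acceptable(k)
    # true, assuming acceptable(0) is true and acceptable is monotone.
    lo, hi = 0, upper_bound
    while lo < hi:
        mid = (lo + hi + 1) // 2
        if acceptable(mid):
            lo = mid
        else:
            hi = mid - 1
    return lo

# toy failure condition: the test keeps failing as long as the first value stays >= 3
m, n = 10, 1
k = find_largest_k(lambda k: (m - k) >= 3, m)
assert (m - k, n + k) == (3, 8)  # weight moved from the first integer onto the second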
