Skip to content

Commit

Permalink
[issue1133] Fix bug in invariant synthesis and revise its implementat…
Browse files Browse the repository at this point in the history
…ion.

It was possible that an unbalanced action passed the balance check. This was
a conceptual gap in the original algorithm for which we had to significantly
revise the invariant synthesis. We used the opportunity to revise wide parts of
its implementation, restructuring and renaming the different components to
improve clarity.
  • Loading branch information
roeger authored Feb 1, 2024
1 parent 07e922e commit 492f71c
Show file tree
Hide file tree
Showing 4 changed files with 469 additions and 351 deletions.
249 changes: 133 additions & 116 deletions src/translate/constraints.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,30 @@
import itertools
from typing import Iterable, List, Tuple

class NegativeClause:
# disjunction of inequalities
def __init__(self, parts):
class InequalityDisjunction:
def __init__(self, parts: List[Tuple[str, str]]):
self.parts = parts
assert len(parts)

def __str__(self):
disj = " or ".join(["(%s != %s)" % (v1, v2)
for (v1, v2) in self.parts])
return "(%s)" % disj

def is_satisfiable(self):
for part in self.parts:
if part[0] != part[1]:
return True
return False

def apply_mapping(self, m):
new_parts = [(m.get(v1, v1), m.get(v2, v2)) for (v1, v2) in self.parts]
return NegativeClause(new_parts)
disj = " or ".join([f"({v1} != {v2})" for (v1, v2) in self.parts])
return f"({disj})"


class Assignment:
def __init__(self, equalities):
self.equalities = tuple(equalities)
# represents a conjunction of expressions ?x = ?y or ?x = d
# with ?x, ?y being variables and d being a domain value
class EqualityConjunction:
def __init__(self, equalities: List[Tuple[str, str]]):
self.equalities = equalities
# A conjunction of expressions x = y, where x,y are either strings
# that denote objects or variables, or ints that denote invariant
# parameters.

self.consistent = None
self.mapping = None
self.eq_classes = None
self._consistent = None
self._representative = None # dictionary
self._eq_classes = None

def __str__(self):
conj = " and ".join(["(%s = %s)" % (v1, v2)
for (v1, v2) in self.equalities])
return "(%s)" % conj
conj = " and ".join([f"({v1} = {v2})" for (v1, v2) in self.equalities])
return f"({conj})"

def _compute_equivalence_classes(self):
eq_classes = {}
Expand All @@ -48,113 +37,141 @@ def _compute_equivalence_classes(self):
c1.update(c2)
for elem in c2:
eq_classes[elem] = c1
self.eq_classes = eq_classes
self._eq_classes = eq_classes

def _compute_mapping(self):
if not self.eq_classes:
def _compute_representatives(self):
if not self._eq_classes:
self._compute_equivalence_classes()

# create mapping: each key is mapped to the smallest
# element in its equivalence class (with objects being
# smaller than variables)
mapping = {}
for eq_class in self.eq_classes.values():
variables = [item for item in eq_class if item.startswith("?")]
constants = [item for item in eq_class if not item.startswith("?")]
if len(constants) >= 2:
self.consistent = False
self.mapping = None
# Choose a representative for each equivalence class. Objects are
# prioritized over variables and ints, but at most one object per
# equivalence class is allowed (otherwise the conjunction is
# inconsistent).
representative = {}
for eq_class in self._eq_classes.values():
if next(iter(eq_class)) in representative:
continue # we already processed this equivalence class
variables = [item for item in eq_class if isinstance(item, int) or
item.startswith("?")]
objects = [item for item in eq_class if not isinstance(item, int)
and not item.startswith("?")]

if len(objects) >= 2:
self._consistent = False
self._representative = None
return
if constants:
set_val = constants[0]
if objects:
set_val = objects[0]
else:
set_val = min(variables)
set_val = variables[0]
for entry in eq_class:
mapping[entry] = set_val
self.consistent = True
self.mapping = mapping
representative[entry] = set_val
self._consistent = True
self._representative = representative

def is_consistent(self):
if self.consistent is None:
self._compute_mapping()
return self.consistent
if self._consistent is None:
self._compute_representatives()
return self._consistent

def get_mapping(self):
if self.consistent is None:
self._compute_mapping()
return self.mapping
def get_representative(self):
if self._consistent is None:
self._compute_representatives()
return self._representative


class ConstraintSystem:
"""A ConstraintSystem stores two parts, both talking about the equality or
inequality of strings and ints (strings representing objects or
variables, ints representing invariant parameters):
- equality_DNFs is a list containing lists of EqualityConjunctions.
Each EqualityConjunction represents an expression of the form
(x1 = y1 and ... and xn = yn). A list of EqualityConjunctions can be
interpreted as a disjunction of such expressions. So
self.equality_DNFs represents a formula of the form "⋀ ⋁ ⋀ (x = y)"
as a list of lists of EqualityConjunctions.
- ineq_disjunctions is a list of InequalityDisjunctions. Each of them
represents a expression of the form (u1 != v1 or ... or um !=i vm).
- not_constant is a list of strings.
We say that the system is solvable if we can pick from each list of
EqualityConjunctions in equality_DNFs one EquivalenceConjunction such
that the finest equivalence relation induced by all the equivalences in
the conjunctions is
- consistent, i.e. no equivalence class contains more than one object,
- for every disjunction in ineq_disjunctions there is at least one
inequality such that the two terms are in different equivalence
classes.
- every element of not_constant is not in the same equivalence class
as a constant.
We refer to the equivalence relation as the solution of the system."""

def __init__(self):
self.combinatorial_assignments = []
self.neg_clauses = []
self.equality_DNFs = []
self.ineq_disjunctions = []
self.not_constant = []

def __str__(self):
combinatorial_assignments = []
for comb_assignment in self.combinatorial_assignments:
disj = " or ".join([str(assig) for assig in comb_assignment])
equality_DNFs = []
for eq_DNF in self.equality_DNFs:
disj = " or ".join([str(eq_conjunction)
for eq_conjunction in eq_DNF])
disj = "(%s)" % disj
combinatorial_assignments.append(disj)
assigs = " and\n".join(combinatorial_assignments)

neg_clauses = [str(clause) for clause in self.neg_clauses]
neg_clauses = " and ".join(neg_clauses)
return assigs + "(" + neg_clauses + ")"

def _all_clauses_satisfiable(self, assignment):
mapping = assignment.get_mapping()
for neg_clause in self.neg_clauses:
clause = neg_clause.apply_mapping(mapping)
if not clause.is_satisfiable():
return False
return True

def _combine_assignments(self, assignments):
new_equalities = []
for a in assignments:
new_equalities.extend(a.equalities)
return Assignment(new_equalities)

def add_assignment(self, assignment):
self.add_assignment_disjunction([assignment])

def add_assignment_disjunction(self, assignments):
self.combinatorial_assignments.append(assignments)

def add_negative_clause(self, negative_clause):
self.neg_clauses.append(negative_clause)

def combine(self, other):
"""Combines two constraint systems to a new system"""
combined = ConstraintSystem()
combined.combinatorial_assignments = (self.combinatorial_assignments +
other.combinatorial_assignments)
combined.neg_clauses = self.neg_clauses + other.neg_clauses
return combined

def copy(self):
other = ConstraintSystem()
other.combinatorial_assignments = list(self.combinatorial_assignments)
other.neg_clauses = list(self.neg_clauses)
return other

def dump(self):
print("AssignmentSystem:")
for comb_assignment in self.combinatorial_assignments:
disj = " or ".join([str(assig) for assig in comb_assignment])
print(" ASS: ", disj)
for neg_clause in self.neg_clauses:
print(" NEG: ", str(neg_clause))
equality_DNFs.append(disj)
eq_part = " and\n".join(equality_DNFs)

ineq_disjunctions = [str(clause) for clause in self.ineq_disjunctions]
ineq_part = " and ".join(ineq_disjunctions)
return f"{eq_part} ({ineq_part}) (not constant {self.not_constant}"

def _combine_equality_conjunctions(self, eq_conjunctions:
Iterable[EqualityConjunction]) -> None:
all_eq = itertools.chain.from_iterable(c.equalities
for c in eq_conjunctions)
return EqualityConjunction(list(all_eq))

def add_equality_conjunction(self, eq_conjunction: EqualityConjunction):
self.add_equality_DNF([eq_conjunction])

def add_equality_DNF(self, equality_DNF: List[EqualityConjunction]) -> None:
self.equality_DNFs.append(equality_DNF)

def add_inequality_disjunction(self, ineq_disj: InequalityDisjunction):
self.ineq_disjunctions.append(ineq_disj)

def add_not_constant(self, not_constant: str) -> None:
self.not_constant.append(not_constant)

def extend(self, other: "ConstraintSystem") -> None:
self.equality_DNFs.extend(other.equality_DNFs)
self.ineq_disjunctions.extend(other.ineq_disjunctions)
self.not_constant.extend(other.not_constant)

def is_solvable(self):
"""Check whether the combinatorial assignments include at least
one consistent assignment under which the negative clauses
are satisfiable"""
for assignments in itertools.product(*self.combinatorial_assignments):
combined = self._combine_assignments(assignments)
# cf. top of class for explanation
def inequality_disjunction_ok(ineq_disj, representative):
for inequality in ineq_disj.parts:
a, b = inequality
if representative.get(a, a) != representative.get(b, b):
return True
return False

for eq_conjunction in itertools.product(*self.equality_DNFs):
combined = self._combine_equality_conjunctions(eq_conjunction)
if not combined.is_consistent():
continue
if self._all_clauses_satisfiable(combined):
return True
# check whether with the finest equivalence relation induced by the
# combined equality conjunction there is no element of not_constant
# in the same equivalence class as a constant and that in each
# inequality disjunction there is an inequality where the two terms
# are in different equivalence classes.
representative = combined.get_representative()
if any(not isinstance(representative.get(s, s), int) and
representative.get(s, s)[0] != "?"
for s in self.not_constant):
continue
if any(not inequality_disjunction_ok(d, representative)
for d in self.ineq_disjunctions):
continue
return True
return False
17 changes: 13 additions & 4 deletions src/translate/invariant_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,12 @@ def get_fluents(task):
def get_initial_invariants(task):
for predicate in get_fluents(task):
all_args = list(range(len(predicate.arguments)))
for omitted_arg in [-1] + all_args:
order = [i for i in all_args if i != omitted_arg]
part = invariants.InvariantPart(predicate.name, order, omitted_arg)
part = invariants.InvariantPart(predicate.name, all_args, None)
yield invariants.Invariant((part,))
for omitted in range(len(predicate.arguments)):
inv_args = (all_args[0:omitted] + [invariants.COUNTED] +
all_args[omitted:-1])
part = invariants.InvariantPart(predicate.name, inv_args, omitted)
yield invariants.Invariant((part,))

def find_invariants(task, reachable_action_params):
Expand Down Expand Up @@ -118,7 +121,13 @@ def useful_groups(invariants, initial_facts):
if isinstance(atom, pddl.Assign):
continue
for invariant in predicate_to_invariants.get(atom.predicate, ()):
group_key = (invariant, tuple(invariant.get_parameters(atom)))
parameters = invariant.get_parameters(atom)
# we need to make the parameters dictionary hashable, so
# we store the values as a tuple
parameters_tuple = tuple(parameters[var]
for var in range(invariant.arity()))

group_key = (invariant, parameters_tuple)
if group_key not in nonempty_groups:
nonempty_groups.add(group_key)
else:
Expand Down
Loading

0 comments on commit 492f71c

Please sign in to comment.