Skip to content

Commit

Permalink
Implements partial refactor to allow for Boolean & Numeric interop (v…
Browse files Browse the repository at this point in the history
…ia `FormulaTerm` wrapper) (#96)

These changes do not complete the full SMT-like refactor previously discussed, but *do* allow for interoperability between `Formula` and `Term` objects and treats the boolean sort differently so that it can be used in numeric arithmetic (as needed for standard patterns in RDDL). There are a handful hacks that were required to preserve the FSTRIPS *Effect syntax that should be fixed in the future.

The main changes include:
1. Changes the Boolean sort to be 0/1 valued and a child of the Naturals sort when the `Arithmetic` theory is included in a language (and a standalone child of `Object` otherwise)
2. Adds the `FormulaTerm` wrapper class, which is used as a container for `Formula` objects that must be treated as `Term` objects for arithmetic, etc.
3. All of `Predicate`, `Formula`, `Term`, `Function` objects are now associated with a `FirstOrderLanguage` (previously only `Function` and `Term` objects had a language property). Languages are inherited up from "subterms", when, for example, a `CompoundFormula` is constructed. This is essential for being able to construct `FormulaTerm` wrappers when needed, since `Term` objects always need an associated language.
4. Implements a set of tests (primarily focused on RDDL use cases) that adds to the existing test suite

The contributing commit messages follow:

* implements basic functionality for FormulaTerm wrapper function and makes the Boolean types a core builtin

* completes partial implementation of Term and Formula type conversion with wrappers

* implements major refactor components. includes modifications to tests to reflect API changes to FSTRIPS *Effect(s)

* fixes bug where multiple writes without reset would dump repeated obj domain lists in rddl instance

* implements basic RDDL writer integration test & makes it pass

* completes implemenation of academic_advising rddl writer integration test (passing)

* un-breaks strips *Effect building API with "`Pass` replaces `Tautology`" workaround

This is a bit of an ugly workaround, but it will work for now. we have
a special `Pass` type that is EXACTLY only ever used when we need to
construct or check against an FSTRIPS effect that is not a conditional
effect. Previously, the condition had been set as a default parameter
to `Tautology` for any regular FSTRIPS effect. However, since *Effects are
built outside of the context of a language, this did not work after we
needed a language for Tautology and Contradiction (so that other
Formula types could inherit them).

There are a few more permanent approaches to this. We could either
break the API and build Effects in the context of a language, or we
could figure out another way to have a universal Tautology that
somehow does not need to be in the context of a language. Notably this
problem will STILL BE AN ISSUE if we go to a fully boolean-valued
Function refactor (eliminating Predicates, etc), rather than the
current partial refactor that involves wrappers between Formula and
Term.

* updates tests related to Term/Formula interop and arithmetic with Booleans

* makes some code style fixes

* re-enables a test that is now passing again after pull from upstream devel

* adds option for 2018-style RDDL file format writing -- maintains default to pre-2018

* adds condition to avoid adding :numeric-fluents simply due to the Boolean sort being attached to the language

* fixes stdout name issue on macos
  • Loading branch information
mejrpete authored Oct 27, 2020
1 parent c16ef4a commit 233b568
Show file tree
Hide file tree
Showing 42 changed files with 880 additions and 282 deletions.
4 changes: 2 additions & 2 deletions src/tarski/analysis/csp_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
collect_effect_free_parameters
from ..grounding.common import StateVariableLite
from ..syntax import QuantifiedFormula, Quantifier, Contradiction, CompoundFormula, Atom, CompoundTerm, \
is_neg, symref, Constant, Variable, Tautology, top
is_neg, symref, Constant, Variable, Tautology
from ..syntax.ops import collect_unique_nodes, flatten
from ..syntax.transform import to_prenex_negation_normal_form

Expand Down Expand Up @@ -127,7 +127,7 @@ def compile_schema_csp(self, action, simplifier):
if precondition is False:
return None
if precondition is True:
precondition = top
precondition = self.lang.top()

csp = CSPInformation()
csp.parameter_index = [self.variable(p, csp, "param") for p in action.parameters]
Expand Down
23 changes: 22 additions & 1 deletion src/tarski/fol.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

from . import errors as err
from .errors import UndefinedElement
from .syntax import Function, Constant, Variable, Sort, inclusion_closure, Predicate, Interval
from .syntax import Function, Constant, Variable, Sort, inclusion_closure, Predicate, Interval, Tautology, Contradiction
from .syntax.formulas import FormulaTerm
from .syntax.algebra import Matrix
from . import modules

Expand Down Expand Up @@ -160,6 +161,14 @@ def variable(self, name: str, sort: Union[Sort, str]):
sort = self._retrieve_sort(sort)
return Variable(name, sort)

#todo: [John Peterson] not ideal to have to add this just to be able to fix booleans done 2 ways
def change_parent(self, sort: Sort, parent: Sort):
if parent.language is not self:
raise err.LanguageError("Tried to set as parent a sort from a different language")

self.immediate_parent[sort] = parent
self.ancestor_sorts[sort].update(inclusion_closure(parent))

def set_parent(self, sort: Sort, parent: Sort):
if parent.language is not self:
raise err.LanguageError("Tried to set as parent a sort from a different language")
Expand Down Expand Up @@ -250,6 +259,12 @@ def _check_name_not_defined(self, name, where, exception):
if name in self._global_index:
raise err.DuplicateDefinition(name, self._global_index[name])

def top(self):
return Tautology(self)

def bot(self):
return Contradiction(self)

def predicate(self, name: str, *args):
self._check_name_not_defined(name, self._predicates, err.DuplicatePredicateDefinition)

Expand Down Expand Up @@ -333,6 +348,12 @@ def __str__(self):
f"{len(self._functions)} functions and {len(self.constants())} constants"
__repr__ = __str__

#todo: [John Peterson] I'm not sure if this should be here. We
#need access to the language's sorts to be able to inject the
#necessary special boolean sort. Reevaluate as a todo.
def generate_formula_term(self, formula):
return FormulaTerm(formula)

def register_operator_handler(self, operator, t1, t2, handler):
self._operators[(operator, t1, t2)] = handler

Expand Down
1 change: 0 additions & 1 deletion src/tarski/fstrips/fstrips.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from .. import theories as ths
from .errors import InvalidEffectError


class BaseEffect:
""" A base class for all FSTRIPS effects, which might have an (optional) condition. """
def __init__(self, condition):
Expand Down
17 changes: 10 additions & 7 deletions src/tarski/fstrips/manipulation/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,19 @@
from ...evaluators.simple import evaluate
from ...grounding.ops import approximate_symbol_fluency
from ...syntax.terms import Constant, Variable, CompoundTerm
from ...syntax.formulas import CompoundFormula, QuantifiedFormula, Atom, Tautology, Contradiction, Connective, is_neg, \
Quantifier, unwrap_conjunction_or_atom, is_eq_atom, land, exists
from ...syntax.formulas import CompoundFormula, QuantifiedFormula, Atom, Pass, Tautology,\
Contradiction, Connective, is_neg, Quantifier, unwrap_conjunction_or_atom, is_eq_atom, land, exists
from ...syntax.transform.substitutions import substitute_expression
from ...syntax.util import get_symbols
from ...syntax.walker import FOLWalker
from ...syntax.ops import flatten
from ...syntax import symref


def bool_to_expr(val):
def bool_to_expr(val, lang):
if not isinstance(val, bool):
return val
return Tautology() if val else Contradiction()
return Tautology(lang) if val else Contradiction(lang)


class Simplify:
Expand Down Expand Up @@ -83,7 +83,7 @@ def simplify(self, inplace=False, remove_unused_symbols=False):
def simplify_action(self, action, inplace=False):
simple = action if inplace else copy.deepcopy(action)
simple.precondition = self.simplify_expression(simple.precondition, inplace=True)
if simple.precondition in (False, Contradiction):
if simple.precondition is False or isinstance(simple.precondition, Contradiction):
return None

# Filter out those effects that are None, e.g. because they are not applicable:
Expand All @@ -107,6 +107,9 @@ def simplify_expression(self, node, inplace=True):
if isinstance(node, Tautology):
return True

if isinstance(node, Pass):
return True

if isinstance(node, (CompoundTerm, Atom)):
node.subterms = [self.simplify_expression(st) for st in node.subterms]
if not self.node_can_be_statically_evaluated(node):
Expand Down Expand Up @@ -157,14 +160,14 @@ def simplify_effect(self, effect, inplace=True):
effect = effect if inplace else copy.deepcopy(effect)

if isinstance(effect, (AddEffect, DelEffect)):
effect.condition = bool_to_expr(self.simplify_expression(effect.condition))
effect.condition = bool_to_expr(self.simplify_expression(effect.condition), self.problem.language)
if isinstance(effect.condition, Contradiction):
return None
effect.atom = self.simplify_expression(effect.atom)
return effect

if isinstance(effect, FunctionalEffect):
effect.condition = bool_to_expr(self.simplify_expression(effect.condition))
effect.condition = bool_to_expr(self.simplify_expression(effect.condition), self.problem.language)
if isinstance(effect.condition, Contradiction):
return None
effect.lhs = self.simplify_expression(effect.lhs)
Expand Down
18 changes: 9 additions & 9 deletions src/tarski/fstrips/representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .problem import Problem
from . import fstrips as fs
from ..syntax import Formula, CompoundTerm, Atom, CompoundFormula, QuantifiedFormula, is_and, is_neg, exists, symref,\
VariableBinding, Constant, Tautology, land, Term
VariableBinding, Constant, Tautology, land, Term, Pass
from ..syntax.ops import collect_unique_nodes, flatten, free_variables, all_variables
from ..syntax.sorts import compute_signature_bindings
from ..syntax.transform.substitutions import enumerate_substitutions
Expand Down Expand Up @@ -95,7 +95,7 @@ def transform_to_strips(what: Union[Problem, Action]):

def is_atomic_effect(eff: BaseEffect):
""" An effect is atomic if it is a single, unconditional effect. """
return isinstance(eff, SingleEffect) and isinstance(eff.condition, Tautology)
return isinstance(eff, SingleEffect) and isinstance(eff.condition, (Tautology, Pass))


def is_propositional_effect(eff: BaseEffect):
Expand Down Expand Up @@ -123,10 +123,10 @@ def compute_effect_set_conflicts(effects):
if not is_atomic_effect(eff) or not is_propositional_effect(eff):
raise RepresentationError(f"Don't know how to compute conflicts for effect {eff}")
pol = isinstance(eff, AddEffect) # i.e. polarity will be true if add effect, false otherwise
prev = polarities.get(eff.atom, None)
prev = polarities.get(symref(eff.atom), None)
if prev is not None and prev != pol:
conflicts.add(eff.atom)
polarities[eff.atom] = pol
conflicts.add(symref(eff.atom))
polarities[symref(eff.atom)] = pol
return conflicts


Expand Down Expand Up @@ -220,9 +220,9 @@ def collect_literals_from_conjunction(phi: Formula) -> Optional[Set[Tuple[Atom,

def _collect_literals_from_conjunction(f, literals: Set[Tuple[Atom, bool]]):
if isinstance(f, Atom):
literals.add((f, True))
literals.add((symref(f), True))
elif is_neg(f) and isinstance(f.subformulas[0], Atom):
literals.add((f.subformulas[0], False))
literals.add((symref(f.subformulas[0]), False))
elif is_and(f):
for sub in f.subformulas:
if not _collect_literals_from_conjunction(sub, literals):
Expand Down Expand Up @@ -465,7 +465,7 @@ def compile_action_negated_preconditions_away(action: Action, negpreds, inplace=
if not isinstance(eff, SingleEffect):
raise RepresentationError(f"Cannot compile away negated conditions for effect '{eff}'")

if not isinstance(eff.condition, Tautology):
if not isinstance(eff.condition, (Tautology, Pass)):
eff.condition = compile_away_formula_negated_literals(eff.condition, negpreds, inplace=True)

return action
Expand Down Expand Up @@ -567,7 +567,7 @@ def expand_universal_effect(effect):
if not isinstance(effect, UniversalEffect):
return [effect]

assert isinstance(effect.condition, Tautology) # TODO Lift this restriction
assert isinstance(effect.condition, (Tautology, Pass)) # TODO Lift this restriction
expanded = []
for subst in enumerate_substitutions(effect.variables):
for sub in effect.effects:
Expand Down
8 changes: 5 additions & 3 deletions src/tarski/fstrips/walker.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,13 @@ def visit_effect(self, effect, inplace=True):
return self.visit(effect)

def visit_expression(self, node, inplace=True):
from ..syntax import CompoundFormula, QuantifiedFormula, Atom, Tautology, Contradiction, Constant, Variable,\
CompoundTerm, IfThenElse # pylint: disable=import-outside-toplevel # Avoiding circular references
# pylint: disable=import-outside-toplevel
from ..syntax import CompoundFormula, QuantifiedFormula, Atom, \
Tautology, Contradiction, Pass, Constant, Variable, \
CompoundTerm, IfThenElse
node = node if inplace else copy.deepcopy(node)

if isinstance(node, (Variable, Constant, Contradiction, Tautology)):
if isinstance(node, (Variable, Constant, Contradiction, Tautology, Pass)):
pass

elif isinstance(node, (CompoundTerm, Atom)):
Expand Down
6 changes: 3 additions & 3 deletions src/tarski/io/_fstrips/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from ...fstrips import FunctionalEffect
from ...fstrips.action import AdditiveActionCost, generate_zero_action_cost
from ...fstrips.representation import is_typed_problem
from ...syntax import Interval, CompoundTerm, Tautology, BuiltinFunctionSymbol
from ...syntax import Interval, CompoundTerm, Tautology, BuiltinFunctionSymbol, Pass
from ... import theories
from ...syntax.util import get_symbols
from ...theories import Theory
Expand Down Expand Up @@ -58,7 +58,7 @@ def get_requirements_string(problem):
# Let's check now whether the problem has any predicate or function symbol *other than "total-cost"* which
# has some arithmetic parameter or result. If so, we add the ":numeric-fluents" requirement.
for symbol in get_symbols(problem.language, type_='all', include_builtin=False):
if any(isinstance(s, Interval) for s in symbol.sort) and symbol.name != 'total-cost':
if any((isinstance(s, Interval) and s.name != 'Boolean') for s in symbol.sort) and symbol.name != 'total-cost':
requirements.add(":numeric-fluents")

return requirements
Expand Down Expand Up @@ -113,7 +113,7 @@ def process_cost_effects(effects):
def process_cost_effect(eff):
""" Check if the given effect is a cost effect. If it is, return the additive cost; if it is not, return None. """
if isinstance(eff, FunctionalEffect) and isinstance(eff.lhs, CompoundTerm) and eff.lhs.symbol.name == "total-cost":
if not isinstance(eff.condition, Tautology):
if not isinstance(eff.condition, Pass):
raise TarskiError(f'Don\'t know how to process conditional cost effects such as {eff}')
if not isinstance(eff.rhs, CompoundTerm) or eff.rhs.symbol.name != BuiltinFunctionSymbol.ADD:
raise TarskiError(f'Don\'t know how to process non-additive cost effects such as {eff}')
Expand Down
70 changes: 34 additions & 36 deletions src/tarski/io/_fstrips/parser/lexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,45 +641,45 @@ class fstripsLexer(Lexer):
modeNames = [ "DEFAULT_MODE" ]

literalNames = [ "<INVALID>",
"'('", "'define'", "')'", "'domain'", "':requirements'", "':types'",
"'-'", "'either'", "':functions'", "':constants'", "':predicates'",
"':parameters'", "':constraint'", "':condition'", "':event'",
"'#t'", "':derived'", "'assign'", "'*'", "'+'", "'/'", "'^'",
"'max'", "'min'", "'sin'", "'cos'", "'sqrt'", "'tan'", "'acos'",
"'asin'", "'atan'", "'exp'", "'abs'", "'>'", "'<'", "'='", "'>='",
"'<='", "'problem'", "':domain'", "':objects'", "':bounds'",
"'['", "'..'", "']'", "':goal'", "':constraints'", "'preference'",
"':metric'", "'minimize'", "'maximize'", "'(total-time)'", "'is-violated'",
"':terminal'", "':stage'", "'at-end'", "'always'", "'sometime'",
"'within'", "'at-most-once'", "'sometime-after'", "'sometime-before'",
"'always-within'", "'hold-during'", "'hold-after'", "'scale-up'",
"'('", "'define'", "')'", "'domain'", "':requirements'", "':types'",
"'-'", "'either'", "':functions'", "':constants'", "':predicates'",
"':parameters'", "':constraint'", "':condition'", "':event'",
"'#t'", "':derived'", "'assign'", "'*'", "'+'", "'/'", "'^'",
"'max'", "'min'", "'sin'", "'cos'", "'sqrt'", "'tan'", "'acos'",
"'asin'", "'atan'", "'exp'", "'abs'", "'>'", "'<'", "'='", "'>='",
"'<='", "'problem'", "':domain'", "':objects'", "':bounds'",
"'['", "'..'", "']'", "':goal'", "':constraints'", "'preference'",
"':metric'", "'minimize'", "'maximize'", "'(total-time)'", "'is-violated'",
"':terminal'", "':stage'", "'at-end'", "'always'", "'sometime'",
"'within'", "'at-most-once'", "'sometime-after'", "'sometime-before'",
"'always-within'", "'hold-during'", "'hold-after'", "'scale-up'",
"'scale-down'", "'int'", "'float'", "'object'", "'number'" ]

symbolicNames = [ "<INVALID>",
"REQUIRE_KEY", "K_AND", "K_NOT", "K_OR", "K_IMPLY", "K_EXISTS",
"K_FORALL", "K_WHEN", "K_ACTION", "K_INCREASE", "K_DECREASE",
"K_SCALEUP", "K_SCALEDOWN", "INT_T", "FLOAT_T", "OBJECT_T",
"NUMBER_T", "NAME", "EXTNAME", "VARIABLE", "NUMBER", "LINE_COMMENT",
"REQUIRE_KEY", "K_AND", "K_NOT", "K_OR", "K_IMPLY", "K_EXISTS",
"K_FORALL", "K_WHEN", "K_ACTION", "K_INCREASE", "K_DECREASE",
"K_SCALEUP", "K_SCALEDOWN", "INT_T", "FLOAT_T", "OBJECT_T",
"NUMBER_T", "NAME", "EXTNAME", "VARIABLE", "NUMBER", "LINE_COMMENT",
"WHITESPACE", "K_INIT", "K_PRECONDITION", "K_EFFECT" ]

ruleNames = [ "T__0", "T__1", "T__2", "T__3", "T__4", "T__5", "T__6",
"T__7", "T__8", "T__9", "T__10", "T__11", "T__12", "T__13",
"T__14", "T__15", "T__16", "T__17", "T__18", "T__19",
"T__20", "T__21", "T__22", "T__23", "T__24", "T__25",
"T__26", "T__27", "T__28", "T__29", "T__30", "T__31",
"T__32", "T__33", "T__34", "T__35", "T__36", "T__37",
"T__38", "T__39", "T__40", "T__41", "T__42", "T__43",
"T__44", "T__45", "T__46", "T__47", "T__48", "T__49",
"T__50", "T__51", "T__52", "T__53", "T__54", "T__55",
"T__56", "T__57", "T__58", "T__59", "T__60", "T__61",
"T__62", "T__63", "T__64", "REQUIRE_KEY", "K_AND", "K_NOT",
"K_OR", "K_IMPLY", "K_EXISTS", "K_FORALL", "K_WHEN", "K_ACTION",
"K_INCREASE", "K_DECREASE", "K_SCALEUP", "K_SCALEDOWN",
"INT_T", "FLOAT_T", "OBJECT_T", "NUMBER_T", "NAME", "EXTNAME",
"DIGIT", "LETTER", "ANY_CHAR_WO_HYPHEN", "ANY_CHAR", "VARIABLE",
"NUMBER", "LINE_COMMENT", "WHITESPACE", "K_INIT", "K_PRECONDITION",
"K_EFFECT", "A", "B", "C", "D", "E", "F", "G", "H", "I",
"J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T",
ruleNames = [ "T__0", "T__1", "T__2", "T__3", "T__4", "T__5", "T__6",
"T__7", "T__8", "T__9", "T__10", "T__11", "T__12", "T__13",
"T__14", "T__15", "T__16", "T__17", "T__18", "T__19",
"T__20", "T__21", "T__22", "T__23", "T__24", "T__25",
"T__26", "T__27", "T__28", "T__29", "T__30", "T__31",
"T__32", "T__33", "T__34", "T__35", "T__36", "T__37",
"T__38", "T__39", "T__40", "T__41", "T__42", "T__43",
"T__44", "T__45", "T__46", "T__47", "T__48", "T__49",
"T__50", "T__51", "T__52", "T__53", "T__54", "T__55",
"T__56", "T__57", "T__58", "T__59", "T__60", "T__61",
"T__62", "T__63", "T__64", "REQUIRE_KEY", "K_AND", "K_NOT",
"K_OR", "K_IMPLY", "K_EXISTS", "K_FORALL", "K_WHEN", "K_ACTION",
"K_INCREASE", "K_DECREASE", "K_SCALEUP", "K_SCALEDOWN",
"INT_T", "FLOAT_T", "OBJECT_T", "NUMBER_T", "NAME", "EXTNAME",
"DIGIT", "LETTER", "ANY_CHAR_WO_HYPHEN", "ANY_CHAR", "VARIABLE",
"NUMBER", "LINE_COMMENT", "WHITESPACE", "K_INIT", "K_PRECONDITION",
"K_EFFECT", "A", "B", "C", "D", "E", "F", "G", "H", "I",
"J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T",
"U", "V", "W", "X", "Y", "Z" ]

grammarFileName = "fstrips.g4"
Expand All @@ -690,5 +690,3 @@ def __init__(self, input=None, output:TextIO = sys.stdout):
self._interp = LexerATNSimulator(self, self.atn, self.decisionsToDFA, PredictionContextCache())
self._actions = None
self._predicates = None


2 changes: 0 additions & 2 deletions src/tarski/io/_fstrips/parser/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,5 +1041,3 @@ def enterAlternativeAlwaysConstraint(self, ctx:fstripsParser.AlternativeAlwaysCo
# Exit a parse tree produced by fstripsParser#AlternativeAlwaysConstraint.
def exitAlternativeAlwaysConstraint(self, ctx:fstripsParser.AlternativeAlwaysConstraintContext):
pass


Loading

0 comments on commit 233b568

Please sign in to comment.