diff --git a/src/tarski/analysis/csp_schema.py b/src/tarski/analysis/csp_schema.py index 9a0f47ae..850302e5 100644 --- a/src/tarski/analysis/csp_schema.py +++ b/src/tarski/analysis/csp_schema.py @@ -8,7 +8,7 @@ collect_effect_free_parameters from ..grounding.common import StateVariableLite from ..syntax import QuantifiedFormula, Quantifier, Contradiction, CompoundFormula, Atom, CompoundTerm, \ - is_neg, symref, Constant, Variable, Tautology, top + is_neg, symref, Constant, Variable, Tautology from ..syntax.ops import collect_unique_nodes, flatten from ..syntax.transform import to_prenex_negation_normal_form @@ -127,7 +127,7 @@ def compile_schema_csp(self, action, simplifier): if precondition is False: return None if precondition is True: - precondition = top + precondition = self.lang.top() csp = CSPInformation() csp.parameter_index = [self.variable(p, csp, "param") for p in action.parameters] diff --git a/src/tarski/fol.py b/src/tarski/fol.py index 4532da05..d9bbe574 100644 --- a/src/tarski/fol.py +++ b/src/tarski/fol.py @@ -6,7 +6,8 @@ from . import errors as err from .errors import UndefinedElement -from .syntax import Function, Constant, Variable, Sort, inclusion_closure, Predicate, Interval +from .syntax import Function, Constant, Variable, Sort, inclusion_closure, Predicate, Interval, Tautology, Contradiction +from .syntax.formulas import FormulaTerm from .syntax.algebra import Matrix from . import modules @@ -160,6 +161,14 @@ def variable(self, name: str, sort: Union[Sort, str]): sort = self._retrieve_sort(sort) return Variable(name, sort) + #todo: [John Peterson] not ideal to have to add this just to be able to fix booleans done 2 ways + def change_parent(self, sort: Sort, parent: Sort): + if parent.language is not self: + raise err.LanguageError("Tried to set as parent a sort from a different language") + + self.immediate_parent[sort] = parent + self.ancestor_sorts[sort].update(inclusion_closure(parent)) + def set_parent(self, sort: Sort, parent: Sort): if parent.language is not self: raise err.LanguageError("Tried to set as parent a sort from a different language") @@ -250,6 +259,12 @@ def _check_name_not_defined(self, name, where, exception): if name in self._global_index: raise err.DuplicateDefinition(name, self._global_index[name]) + def top(self): + return Tautology(self) + + def bot(self): + return Contradiction(self) + def predicate(self, name: str, *args): self._check_name_not_defined(name, self._predicates, err.DuplicatePredicateDefinition) @@ -333,6 +348,12 @@ def __str__(self): f"{len(self._functions)} functions and {len(self.constants())} constants" __repr__ = __str__ + #todo: [John Peterson] I'm not sure if this should be here. We + #need access to the language's sorts to be able to inject the + #necessary special boolean sort. Reevaluate as a todo. + def generate_formula_term(self, formula): + return FormulaTerm(formula) + def register_operator_handler(self, operator, t1, t2, handler): self._operators[(operator, t1, t2)] = handler diff --git a/src/tarski/fstrips/fstrips.py b/src/tarski/fstrips/fstrips.py index 6b5acec4..23896f2b 100644 --- a/src/tarski/fstrips/fstrips.py +++ b/src/tarski/fstrips/fstrips.py @@ -6,7 +6,6 @@ from .. import theories as ths from .errors import InvalidEffectError - class BaseEffect: """ A base class for all FSTRIPS effects, which might have an (optional) condition. """ def __init__(self, condition): diff --git a/src/tarski/fstrips/manipulation/simplify.py b/src/tarski/fstrips/manipulation/simplify.py index 5d9fa5fd..c4f38bf5 100644 --- a/src/tarski/fstrips/manipulation/simplify.py +++ b/src/tarski/fstrips/manipulation/simplify.py @@ -7,8 +7,8 @@ from ...evaluators.simple import evaluate from ...grounding.ops import approximate_symbol_fluency from ...syntax.terms import Constant, Variable, CompoundTerm -from ...syntax.formulas import CompoundFormula, QuantifiedFormula, Atom, Tautology, Contradiction, Connective, is_neg, \ - Quantifier, unwrap_conjunction_or_atom, is_eq_atom, land, exists +from ...syntax.formulas import CompoundFormula, QuantifiedFormula, Atom, Pass, Tautology,\ + Contradiction, Connective, is_neg, Quantifier, unwrap_conjunction_or_atom, is_eq_atom, land, exists from ...syntax.transform.substitutions import substitute_expression from ...syntax.util import get_symbols from ...syntax.walker import FOLWalker @@ -16,10 +16,10 @@ from ...syntax import symref -def bool_to_expr(val): +def bool_to_expr(val, lang): if not isinstance(val, bool): return val - return Tautology() if val else Contradiction() + return Tautology(lang) if val else Contradiction(lang) class Simplify: @@ -83,7 +83,7 @@ def simplify(self, inplace=False, remove_unused_symbols=False): def simplify_action(self, action, inplace=False): simple = action if inplace else copy.deepcopy(action) simple.precondition = self.simplify_expression(simple.precondition, inplace=True) - if simple.precondition in (False, Contradiction): + if simple.precondition is False or isinstance(simple.precondition, Contradiction): return None # Filter out those effects that are None, e.g. because they are not applicable: @@ -107,6 +107,9 @@ def simplify_expression(self, node, inplace=True): if isinstance(node, Tautology): return True + if isinstance(node, Pass): + return True + if isinstance(node, (CompoundTerm, Atom)): node.subterms = [self.simplify_expression(st) for st in node.subterms] if not self.node_can_be_statically_evaluated(node): @@ -157,14 +160,14 @@ def simplify_effect(self, effect, inplace=True): effect = effect if inplace else copy.deepcopy(effect) if isinstance(effect, (AddEffect, DelEffect)): - effect.condition = bool_to_expr(self.simplify_expression(effect.condition)) + effect.condition = bool_to_expr(self.simplify_expression(effect.condition), self.problem.language) if isinstance(effect.condition, Contradiction): return None effect.atom = self.simplify_expression(effect.atom) return effect if isinstance(effect, FunctionalEffect): - effect.condition = bool_to_expr(self.simplify_expression(effect.condition)) + effect.condition = bool_to_expr(self.simplify_expression(effect.condition), self.problem.language) if isinstance(effect.condition, Contradiction): return None effect.lhs = self.simplify_expression(effect.lhs) diff --git a/src/tarski/fstrips/representation.py b/src/tarski/fstrips/representation.py index 4a95f742..17a9c3fe 100644 --- a/src/tarski/fstrips/representation.py +++ b/src/tarski/fstrips/representation.py @@ -5,7 +5,7 @@ from .problem import Problem from . import fstrips as fs from ..syntax import Formula, CompoundTerm, Atom, CompoundFormula, QuantifiedFormula, is_and, is_neg, exists, symref,\ - VariableBinding, Constant, Tautology, land, Term + VariableBinding, Constant, Tautology, land, Term, Pass from ..syntax.ops import collect_unique_nodes, flatten, free_variables, all_variables from ..syntax.sorts import compute_signature_bindings from ..syntax.transform.substitutions import enumerate_substitutions @@ -95,7 +95,7 @@ def transform_to_strips(what: Union[Problem, Action]): def is_atomic_effect(eff: BaseEffect): """ An effect is atomic if it is a single, unconditional effect. """ - return isinstance(eff, SingleEffect) and isinstance(eff.condition, Tautology) + return isinstance(eff, SingleEffect) and isinstance(eff.condition, (Tautology, Pass)) def is_propositional_effect(eff: BaseEffect): @@ -123,10 +123,10 @@ def compute_effect_set_conflicts(effects): if not is_atomic_effect(eff) or not is_propositional_effect(eff): raise RepresentationError(f"Don't know how to compute conflicts for effect {eff}") pol = isinstance(eff, AddEffect) # i.e. polarity will be true if add effect, false otherwise - prev = polarities.get(eff.atom, None) + prev = polarities.get(symref(eff.atom), None) if prev is not None and prev != pol: - conflicts.add(eff.atom) - polarities[eff.atom] = pol + conflicts.add(symref(eff.atom)) + polarities[symref(eff.atom)] = pol return conflicts @@ -220,9 +220,9 @@ def collect_literals_from_conjunction(phi: Formula) -> Optional[Set[Tuple[Atom, def _collect_literals_from_conjunction(f, literals: Set[Tuple[Atom, bool]]): if isinstance(f, Atom): - literals.add((f, True)) + literals.add((symref(f), True)) elif is_neg(f) and isinstance(f.subformulas[0], Atom): - literals.add((f.subformulas[0], False)) + literals.add((symref(f.subformulas[0]), False)) elif is_and(f): for sub in f.subformulas: if not _collect_literals_from_conjunction(sub, literals): @@ -465,7 +465,7 @@ def compile_action_negated_preconditions_away(action: Action, negpreds, inplace= if not isinstance(eff, SingleEffect): raise RepresentationError(f"Cannot compile away negated conditions for effect '{eff}'") - if not isinstance(eff.condition, Tautology): + if not isinstance(eff.condition, (Tautology, Pass)): eff.condition = compile_away_formula_negated_literals(eff.condition, negpreds, inplace=True) return action @@ -567,7 +567,7 @@ def expand_universal_effect(effect): if not isinstance(effect, UniversalEffect): return [effect] - assert isinstance(effect.condition, Tautology) # TODO Lift this restriction + assert isinstance(effect.condition, (Tautology, Pass)) # TODO Lift this restriction expanded = [] for subst in enumerate_substitutions(effect.variables): for sub in effect.effects: diff --git a/src/tarski/fstrips/walker.py b/src/tarski/fstrips/walker.py index b219d0fb..7bbeb61b 100644 --- a/src/tarski/fstrips/walker.py +++ b/src/tarski/fstrips/walker.py @@ -121,11 +121,13 @@ def visit_effect(self, effect, inplace=True): return self.visit(effect) def visit_expression(self, node, inplace=True): - from ..syntax import CompoundFormula, QuantifiedFormula, Atom, Tautology, Contradiction, Constant, Variable,\ - CompoundTerm, IfThenElse # pylint: disable=import-outside-toplevel # Avoiding circular references + # pylint: disable=import-outside-toplevel + from ..syntax import CompoundFormula, QuantifiedFormula, Atom, \ + Tautology, Contradiction, Pass, Constant, Variable, \ + CompoundTerm, IfThenElse node = node if inplace else copy.deepcopy(node) - if isinstance(node, (Variable, Constant, Contradiction, Tautology)): + if isinstance(node, (Variable, Constant, Contradiction, Tautology, Pass)): pass elif isinstance(node, (CompoundTerm, Atom)): diff --git a/src/tarski/io/_fstrips/common.py b/src/tarski/io/_fstrips/common.py index dec17303..7629bda7 100644 --- a/src/tarski/io/_fstrips/common.py +++ b/src/tarski/io/_fstrips/common.py @@ -2,7 +2,7 @@ from ...fstrips import FunctionalEffect from ...fstrips.action import AdditiveActionCost, generate_zero_action_cost from ...fstrips.representation import is_typed_problem -from ...syntax import Interval, CompoundTerm, Tautology, BuiltinFunctionSymbol +from ...syntax import Interval, CompoundTerm, Tautology, BuiltinFunctionSymbol, Pass from ... import theories from ...syntax.util import get_symbols from ...theories import Theory @@ -58,7 +58,7 @@ def get_requirements_string(problem): # Let's check now whether the problem has any predicate or function symbol *other than "total-cost"* which # has some arithmetic parameter or result. If so, we add the ":numeric-fluents" requirement. for symbol in get_symbols(problem.language, type_='all', include_builtin=False): - if any(isinstance(s, Interval) for s in symbol.sort) and symbol.name != 'total-cost': + if any((isinstance(s, Interval) and s.name != 'Boolean') for s in symbol.sort) and symbol.name != 'total-cost': requirements.add(":numeric-fluents") return requirements @@ -113,7 +113,7 @@ def process_cost_effects(effects): def process_cost_effect(eff): """ Check if the given effect is a cost effect. If it is, return the additive cost; if it is not, return None. """ if isinstance(eff, FunctionalEffect) and isinstance(eff.lhs, CompoundTerm) and eff.lhs.symbol.name == "total-cost": - if not isinstance(eff.condition, Tautology): + if not isinstance(eff.condition, Pass): raise TarskiError(f'Don\'t know how to process conditional cost effects such as {eff}') if not isinstance(eff.rhs, CompoundTerm) or eff.rhs.symbol.name != BuiltinFunctionSymbol.ADD: raise TarskiError(f'Don\'t know how to process non-additive cost effects such as {eff}') diff --git a/src/tarski/io/_fstrips/parser/lexer.py b/src/tarski/io/_fstrips/parser/lexer.py index 500531ae..6ee2c36b 100644 --- a/src/tarski/io/_fstrips/parser/lexer.py +++ b/src/tarski/io/_fstrips/parser/lexer.py @@ -641,45 +641,45 @@ class fstripsLexer(Lexer): modeNames = [ "DEFAULT_MODE" ] literalNames = [ "", - "'('", "'define'", "')'", "'domain'", "':requirements'", "':types'", - "'-'", "'either'", "':functions'", "':constants'", "':predicates'", - "':parameters'", "':constraint'", "':condition'", "':event'", - "'#t'", "':derived'", "'assign'", "'*'", "'+'", "'/'", "'^'", - "'max'", "'min'", "'sin'", "'cos'", "'sqrt'", "'tan'", "'acos'", - "'asin'", "'atan'", "'exp'", "'abs'", "'>'", "'<'", "'='", "'>='", - "'<='", "'problem'", "':domain'", "':objects'", "':bounds'", - "'['", "'..'", "']'", "':goal'", "':constraints'", "'preference'", - "':metric'", "'minimize'", "'maximize'", "'(total-time)'", "'is-violated'", - "':terminal'", "':stage'", "'at-end'", "'always'", "'sometime'", - "'within'", "'at-most-once'", "'sometime-after'", "'sometime-before'", - "'always-within'", "'hold-during'", "'hold-after'", "'scale-up'", + "'('", "'define'", "')'", "'domain'", "':requirements'", "':types'", + "'-'", "'either'", "':functions'", "':constants'", "':predicates'", + "':parameters'", "':constraint'", "':condition'", "':event'", + "'#t'", "':derived'", "'assign'", "'*'", "'+'", "'/'", "'^'", + "'max'", "'min'", "'sin'", "'cos'", "'sqrt'", "'tan'", "'acos'", + "'asin'", "'atan'", "'exp'", "'abs'", "'>'", "'<'", "'='", "'>='", + "'<='", "'problem'", "':domain'", "':objects'", "':bounds'", + "'['", "'..'", "']'", "':goal'", "':constraints'", "'preference'", + "':metric'", "'minimize'", "'maximize'", "'(total-time)'", "'is-violated'", + "':terminal'", "':stage'", "'at-end'", "'always'", "'sometime'", + "'within'", "'at-most-once'", "'sometime-after'", "'sometime-before'", + "'always-within'", "'hold-during'", "'hold-after'", "'scale-up'", "'scale-down'", "'int'", "'float'", "'object'", "'number'" ] symbolicNames = [ "", - "REQUIRE_KEY", "K_AND", "K_NOT", "K_OR", "K_IMPLY", "K_EXISTS", - "K_FORALL", "K_WHEN", "K_ACTION", "K_INCREASE", "K_DECREASE", - "K_SCALEUP", "K_SCALEDOWN", "INT_T", "FLOAT_T", "OBJECT_T", - "NUMBER_T", "NAME", "EXTNAME", "VARIABLE", "NUMBER", "LINE_COMMENT", + "REQUIRE_KEY", "K_AND", "K_NOT", "K_OR", "K_IMPLY", "K_EXISTS", + "K_FORALL", "K_WHEN", "K_ACTION", "K_INCREASE", "K_DECREASE", + "K_SCALEUP", "K_SCALEDOWN", "INT_T", "FLOAT_T", "OBJECT_T", + "NUMBER_T", "NAME", "EXTNAME", "VARIABLE", "NUMBER", "LINE_COMMENT", "WHITESPACE", "K_INIT", "K_PRECONDITION", "K_EFFECT" ] - ruleNames = [ "T__0", "T__1", "T__2", "T__3", "T__4", "T__5", "T__6", - "T__7", "T__8", "T__9", "T__10", "T__11", "T__12", "T__13", - "T__14", "T__15", "T__16", "T__17", "T__18", "T__19", - "T__20", "T__21", "T__22", "T__23", "T__24", "T__25", - "T__26", "T__27", "T__28", "T__29", "T__30", "T__31", - "T__32", "T__33", "T__34", "T__35", "T__36", "T__37", - "T__38", "T__39", "T__40", "T__41", "T__42", "T__43", - "T__44", "T__45", "T__46", "T__47", "T__48", "T__49", - "T__50", "T__51", "T__52", "T__53", "T__54", "T__55", - "T__56", "T__57", "T__58", "T__59", "T__60", "T__61", - "T__62", "T__63", "T__64", "REQUIRE_KEY", "K_AND", "K_NOT", - "K_OR", "K_IMPLY", "K_EXISTS", "K_FORALL", "K_WHEN", "K_ACTION", - "K_INCREASE", "K_DECREASE", "K_SCALEUP", "K_SCALEDOWN", - "INT_T", "FLOAT_T", "OBJECT_T", "NUMBER_T", "NAME", "EXTNAME", - "DIGIT", "LETTER", "ANY_CHAR_WO_HYPHEN", "ANY_CHAR", "VARIABLE", - "NUMBER", "LINE_COMMENT", "WHITESPACE", "K_INIT", "K_PRECONDITION", - "K_EFFECT", "A", "B", "C", "D", "E", "F", "G", "H", "I", - "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", + ruleNames = [ "T__0", "T__1", "T__2", "T__3", "T__4", "T__5", "T__6", + "T__7", "T__8", "T__9", "T__10", "T__11", "T__12", "T__13", + "T__14", "T__15", "T__16", "T__17", "T__18", "T__19", + "T__20", "T__21", "T__22", "T__23", "T__24", "T__25", + "T__26", "T__27", "T__28", "T__29", "T__30", "T__31", + "T__32", "T__33", "T__34", "T__35", "T__36", "T__37", + "T__38", "T__39", "T__40", "T__41", "T__42", "T__43", + "T__44", "T__45", "T__46", "T__47", "T__48", "T__49", + "T__50", "T__51", "T__52", "T__53", "T__54", "T__55", + "T__56", "T__57", "T__58", "T__59", "T__60", "T__61", + "T__62", "T__63", "T__64", "REQUIRE_KEY", "K_AND", "K_NOT", + "K_OR", "K_IMPLY", "K_EXISTS", "K_FORALL", "K_WHEN", "K_ACTION", + "K_INCREASE", "K_DECREASE", "K_SCALEUP", "K_SCALEDOWN", + "INT_T", "FLOAT_T", "OBJECT_T", "NUMBER_T", "NAME", "EXTNAME", + "DIGIT", "LETTER", "ANY_CHAR_WO_HYPHEN", "ANY_CHAR", "VARIABLE", + "NUMBER", "LINE_COMMENT", "WHITESPACE", "K_INIT", "K_PRECONDITION", + "K_EFFECT", "A", "B", "C", "D", "E", "F", "G", "H", "I", + "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z" ] grammarFileName = "fstrips.g4" @@ -690,5 +690,3 @@ def __init__(self, input=None, output:TextIO = sys.stdout): self._interp = LexerATNSimulator(self, self.atn, self.decisionsToDFA, PredictionContextCache()) self._actions = None self._predicates = None - - diff --git a/src/tarski/io/_fstrips/parser/listener.py b/src/tarski/io/_fstrips/parser/listener.py index 63dd4e21..f9304bef 100644 --- a/src/tarski/io/_fstrips/parser/listener.py +++ b/src/tarski/io/_fstrips/parser/listener.py @@ -1041,5 +1041,3 @@ def enterAlternativeAlwaysConstraint(self, ctx:fstripsParser.AlternativeAlwaysCo # Exit a parse tree produced by fstripsParser#AlternativeAlwaysConstraint. def exitAlternativeAlwaysConstraint(self, ctx:fstripsParser.AlternativeAlwaysConstraintContext): pass - - diff --git a/src/tarski/io/_fstrips/parser/parser.py b/src/tarski/io/_fstrips/parser/parser.py index f29a2cc3..8e58a6bf 100644 --- a/src/tarski/io/_fstrips/parser/parser.py +++ b/src/tarski/io/_fstrips/parser/parser.py @@ -398,58 +398,58 @@ class fstripsParser ( Parser ): sharedContextCache = PredictionContextCache() - literalNames = [ "", "'('", "'define'", "')'", "'domain'", - "':requirements'", "':types'", "'-'", "'either'", "':functions'", - "':constants'", "':predicates'", "':parameters'", "':constraint'", - "':condition'", "':event'", "'#t'", "':derived'", "'assign'", - "'*'", "'+'", "'/'", "'^'", "'max'", "'min'", "'sin'", - "'cos'", "'sqrt'", "'tan'", "'acos'", "'asin'", "'atan'", - "'exp'", "'abs'", "'>'", "'<'", "'='", "'>='", "'<='", - "'problem'", "':domain'", "':objects'", "':bounds'", - "'['", "'..'", "']'", "':goal'", "':constraints'", - "'preference'", "':metric'", "'minimize'", "'maximize'", - "'(total-time)'", "'is-violated'", "':terminal'", "':stage'", - "'at-end'", "'always'", "'sometime'", "'within'", "'at-most-once'", - "'sometime-after'", "'sometime-before'", "'always-within'", - "'hold-during'", "'hold-after'", "", "", - "", "", "", "", - "", "", "", "", - "", "'scale-up'", "'scale-down'", "'int'", + literalNames = [ "", "'('", "'define'", "')'", "'domain'", + "':requirements'", "':types'", "'-'", "'either'", "':functions'", + "':constants'", "':predicates'", "':parameters'", "':constraint'", + "':condition'", "':event'", "'#t'", "':derived'", "'assign'", + "'*'", "'+'", "'/'", "'^'", "'max'", "'min'", "'sin'", + "'cos'", "'sqrt'", "'tan'", "'acos'", "'asin'", "'atan'", + "'exp'", "'abs'", "'>'", "'<'", "'='", "'>='", "'<='", + "'problem'", "':domain'", "':objects'", "':bounds'", + "'['", "'..'", "']'", "':goal'", "':constraints'", + "'preference'", "':metric'", "'minimize'", "'maximize'", + "'(total-time)'", "'is-violated'", "':terminal'", "':stage'", + "'at-end'", "'always'", "'sometime'", "'within'", "'at-most-once'", + "'sometime-after'", "'sometime-before'", "'always-within'", + "'hold-during'", "'hold-after'", "", "", + "", "", "", "", + "", "", "", "", + "", "'scale-up'", "'scale-down'", "'int'", "'float'", "'object'", "'number'" ] - symbolicNames = [ "", "", "", "", - "", "", "", "", - "", "", "", "", - "", "", "", "", - "", "", "", "", - "", "", "", "", - "", "", "", "", - "", "", "", "", - "", "", "", "", - "", "", "", "", - "", "", "", "", - "", "", "", "", - "", "", "", "", - "", "", "", "", - "", "", "", "", - "", "", "", "", - "", "", "REQUIRE_KEY", "K_AND", - "K_NOT", "K_OR", "K_IMPLY", "K_EXISTS", "K_FORALL", - "K_WHEN", "K_ACTION", "K_INCREASE", "K_DECREASE", - "K_SCALEUP", "K_SCALEDOWN", "INT_T", "FLOAT_T", "OBJECT_T", - "NUMBER_T", "NAME", "EXTNAME", "VARIABLE", "NUMBER", - "LINE_COMMENT", "WHITESPACE", "K_INIT", "K_PRECONDITION", - "K_EFFECT", "DOMAIN", "DOMAIN_NAME", "REQUIREMENTS", - "TYPES", "EITHER_TYPE", "CONSTANTS", "FUNCTIONS", - "FREE_FUNCTIONS", "PREDICATES", "ACTION", "CONSTRAINT", - "EVENT", "GLOBAL_CONSTRAINT", "DURATIVE_ACTION", "PROBLEM", - "PROBLEM_NAME", "PROBLEM_DOMAIN", "OBJECTS", "INIT", - "FUNC_HEAD", "PRECONDITION", "EFFECT", "AND_GD", "OR_GD", - "NOT_GD", "IMPLY_GD", "EXISTS_GD", "FORALL_GD", "COMPARISON_GD", - "AND_EFFECT", "FORALL_EFFECT", "WHEN_EFFECT", "ASSIGN_EFFECT", - "NOT_EFFECT", "PRED_HEAD", "GOAL", "BINARY_OP", "EQUALITY_CON", - "MULTI_OP", "MINUS_OP", "UNARY_MINUS", "INIT_EQ", - "INIT_AT", "NOT_PRED_INIT", "PRED_INST", "PROBLEM_CONSTRAINT", + symbolicNames = [ "", "", "", "", + "", "", "", "", + "", "", "", "", + "", "", "", "", + "", "", "", "", + "", "", "", "", + "", "", "", "", + "", "", "", "", + "", "", "", "", + "", "", "", "", + "", "", "", "", + "", "", "", "", + "", "", "", "", + "", "", "", "", + "", "", "", "", + "", "", "", "", + "", "", "REQUIRE_KEY", "K_AND", + "K_NOT", "K_OR", "K_IMPLY", "K_EXISTS", "K_FORALL", + "K_WHEN", "K_ACTION", "K_INCREASE", "K_DECREASE", + "K_SCALEUP", "K_SCALEDOWN", "INT_T", "FLOAT_T", "OBJECT_T", + "NUMBER_T", "NAME", "EXTNAME", "VARIABLE", "NUMBER", + "LINE_COMMENT", "WHITESPACE", "K_INIT", "K_PRECONDITION", + "K_EFFECT", "DOMAIN", "DOMAIN_NAME", "REQUIREMENTS", + "TYPES", "EITHER_TYPE", "CONSTANTS", "FUNCTIONS", + "FREE_FUNCTIONS", "PREDICATES", "ACTION", "CONSTRAINT", + "EVENT", "GLOBAL_CONSTRAINT", "DURATIVE_ACTION", "PROBLEM", + "PROBLEM_NAME", "PROBLEM_DOMAIN", "OBJECTS", "INIT", + "FUNC_HEAD", "PRECONDITION", "EFFECT", "AND_GD", "OR_GD", + "NOT_GD", "IMPLY_GD", "EXISTS_GD", "FORALL_GD", "COMPARISON_GD", + "AND_EFFECT", "FORALL_EFFECT", "WHEN_EFFECT", "ASSIGN_EFFECT", + "NOT_EFFECT", "PRED_HEAD", "GOAL", "BINARY_OP", "EQUALITY_CON", + "MULTI_OP", "MINUS_OP", "UNARY_MINUS", "INIT_EQ", + "INIT_AT", "NOT_PRED_INIT", "PRED_INST", "PROBLEM_CONSTRAINT", "PROBLEM_METRIC" ] RULE_pddlDoc = 0 @@ -519,24 +519,24 @@ class fstripsParser ( Parser ): RULE_stageCost = 64 RULE_conGD = 65 - ruleNames = [ "pddlDoc", "domain", "domainName", "requireDef", "declaration_of_types", - "numericBuiltinType", "builtinType", "possibly_typed_name_list", - "name_list_with_type", "possibly_typed_type_list", "possibly_typed_variable_list", - "variable_list_with_type", "typename", "primitive_type", - "function_definition_block", "single_function_definition", - "typed_function_definition", "untyped_function_definition", - "logical_symbol_name", "constant_declaration", "predicate_definition_block", - "single_predicate_definition", "predicate", "function_name", - "structureDef", "actionDef", "constraintDef", "eventDef", - "actionName", "constraintSymbol", "eventSymbol", "actionDefBody", - "precondition", "goalDesc", "atomicTermFormula", "term", - "functionTerm", "derivedDef", "effect", "single_effect", - "atomic_effect", "builtin_binary_function", "builtin_unary_function", - "builtin_binary_predicate", "assignOp", "problem", "problemMeta", - "problemDecl", "problemDomain", "object_declaration", - "boundsDecl", "typeBoundsDefinition", "init", "init_element", - "flat_term", "flat_atom", "constant_name", "goal", "probConstraints", - "prefConGD", "metricSpec", "optimization", "metricFExp", + ruleNames = [ "pddlDoc", "domain", "domainName", "requireDef", "declaration_of_types", + "numericBuiltinType", "builtinType", "possibly_typed_name_list", + "name_list_with_type", "possibly_typed_type_list", "possibly_typed_variable_list", + "variable_list_with_type", "typename", "primitive_type", + "function_definition_block", "single_function_definition", + "typed_function_definition", "untyped_function_definition", + "logical_symbol_name", "constant_declaration", "predicate_definition_block", + "single_predicate_definition", "predicate", "function_name", + "structureDef", "actionDef", "constraintDef", "eventDef", + "actionName", "constraintSymbol", "eventSymbol", "actionDefBody", + "precondition", "goalDesc", "atomicTermFormula", "term", + "functionTerm", "derivedDef", "effect", "single_effect", + "atomic_effect", "builtin_binary_function", "builtin_unary_function", + "builtin_binary_predicate", "assignOp", "problem", "problemMeta", + "problemDecl", "problemDomain", "object_declaration", + "boundsDecl", "typeBoundsDefinition", "init", "init_element", + "flat_term", "flat_atom", "constant_name", "goal", "probConstraints", + "prefConGD", "metricSpec", "optimization", "metricFExp", "terminalCost", "stageCost", "conGD" ] EOF = Token.EOF @@ -1056,7 +1056,7 @@ def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): def getRuleIndex(self): return fstripsParser.RULE_numericBuiltinType - + def copyFrom(self, ctx:ParserRuleContext): super().copyFrom(ctx) @@ -1182,7 +1182,7 @@ def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): def getRuleIndex(self): return fstripsParser.RULE_builtinType - + def copyFrom(self, ctx:ParserRuleContext): super().copyFrom(ctx) @@ -1279,7 +1279,7 @@ def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): def getRuleIndex(self): return fstripsParser.RULE_possibly_typed_name_list - + def copyFrom(self, ctx:ParserRuleContext): super().copyFrom(ctx) @@ -1373,7 +1373,7 @@ def possibly_typed_name_list(self): elif la_ == 2: localctx = fstripsParser.ComplexNameListContext(self, localctx) self.enterOuterAlt(localctx, 2) - self.state = 198 + self.state = 198 self._errHandler.sync(self) _alt = 1 while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: @@ -1383,7 +1383,7 @@ def possibly_typed_name_list(self): else: raise NoViableAltException(self) - self.state = 200 + self.state = 200 self._errHandler.sync(self) _alt = self._interp.adaptivePredict(self._input,11,self._ctx) @@ -1451,13 +1451,13 @@ def name_list_with_type(self): self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 211 + self.state = 211 self._errHandler.sync(self) _la = self._input.LA(1) while True: self.state = 210 self.match(fstripsParser.NAME) - self.state = 213 + self.state = 213 self._errHandler.sync(self) _la = self._input.LA(1) if not (_la==fstripsParser.NAME): @@ -1485,7 +1485,7 @@ def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): def getRuleIndex(self): return fstripsParser.RULE_possibly_typed_type_list - + def copyFrom(self, ctx:ParserRuleContext): super().copyFrom(ctx) @@ -1594,7 +1594,7 @@ def possibly_typed_type_list(self): elif la_ == 2: localctx = fstripsParser.TypedTypenameListContext(self, localctx) self.enterOuterAlt(localctx, 2) - self.state = 225 + self.state = 225 self._errHandler.sync(self) _alt = 1 while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: @@ -1604,7 +1604,7 @@ def possibly_typed_type_list(self): else: raise NoViableAltException(self) - self.state = 227 + self.state = 227 self._errHandler.sync(self) _alt = self._interp.adaptivePredict(self._input,16,self._ctx) @@ -1644,7 +1644,7 @@ def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): def getRuleIndex(self): return fstripsParser.RULE_possibly_typed_variable_list - + def copyFrom(self, ctx:ParserRuleContext): super().copyFrom(ctx) @@ -1738,7 +1738,7 @@ def possibly_typed_variable_list(self): elif la_ == 2: localctx = fstripsParser.TypedVariableListContext(self, localctx) self.enterOuterAlt(localctx, 2) - self.state = 244 + self.state = 244 self._errHandler.sync(self) _alt = 1 while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: @@ -1748,7 +1748,7 @@ def possibly_typed_variable_list(self): else: raise NoViableAltException(self) - self.state = 246 + self.state = 246 self._errHandler.sync(self) _alt = self._interp.adaptivePredict(self._input,20,self._ctx) @@ -1816,13 +1816,13 @@ def variable_list_with_type(self): self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 257 + self.state = 257 self._errHandler.sync(self) _la = self._input.LA(1) while True: self.state = 256 self.match(fstripsParser.VARIABLE) - self.state = 259 + self.state = 259 self._errHandler.sync(self) _la = self._input.LA(1) if not (_la==fstripsParser.VARIABLE): @@ -1850,7 +1850,7 @@ def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): def getRuleIndex(self): return fstripsParser.RULE_typename - + def copyFrom(self, ctx:ParserRuleContext): super().copyFrom(ctx) @@ -1926,13 +1926,13 @@ def typename(self): self.match(fstripsParser.T__0) self.state = 265 self.match(fstripsParser.T__7) - self.state = 267 + self.state = 267 self._errHandler.sync(self) _la = self._input.LA(1) while True: self.state = 266 self.primitive_type() - self.state = 269 + self.state = 269 self._errHandler.sync(self) _la = self._input.LA(1) if not (((((_la - 79)) & ~0x3f) == 0 and ((1 << (_la - 79)) & ((1 << (fstripsParser.INT_T - 79)) | (1 << (fstripsParser.FLOAT_T - 79)) | (1 << (fstripsParser.OBJECT_T - 79)) | (1 << (fstripsParser.NUMBER_T - 79)) | (1 << (fstripsParser.NAME - 79)))) != 0)): @@ -3106,7 +3106,7 @@ def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): def getRuleIndex(self): return fstripsParser.RULE_precondition - + def copyFrom(self, ctx:ParserRuleContext): super().copyFrom(ctx) @@ -3203,7 +3203,7 @@ def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): def getRuleIndex(self): return fstripsParser.RULE_goalDesc - + def copyFrom(self, ctx:ParserRuleContext): super().copyFrom(ctx) @@ -3669,7 +3669,7 @@ def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): def getRuleIndex(self): return fstripsParser.RULE_term - + def copyFrom(self, ctx:ParserRuleContext): super().copyFrom(ctx) @@ -3854,7 +3854,7 @@ def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): def getRuleIndex(self): return fstripsParser.RULE_functionTerm - + def copyFrom(self, ctx:ParserRuleContext): super().copyFrom(ctx) @@ -4095,7 +4095,7 @@ def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): def getRuleIndex(self): return fstripsParser.RULE_effect - + def copyFrom(self, ctx:ParserRuleContext): super().copyFrom(ctx) @@ -4213,7 +4213,7 @@ def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): def getRuleIndex(self): return fstripsParser.RULE_single_effect - + def copyFrom(self, ctx:ParserRuleContext): super().copyFrom(ctx) @@ -4438,7 +4438,7 @@ def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): def getRuleIndex(self): return fstripsParser.RULE_atomic_effect - + def copyFrom(self, ctx:ParserRuleContext): super().copyFrom(ctx) @@ -5213,13 +5213,13 @@ def boundsDecl(self): self.match(fstripsParser.T__0) self.state = 600 self.match(fstripsParser.T__41) - self.state = 602 + self.state = 602 self._errHandler.sync(self) _la = self._input.LA(1) while True: self.state = 601 self.typeBoundsDefinition() - self.state = 604 + self.state = 604 self._errHandler.sync(self) _la = self._input.LA(1) if not (_la==fstripsParser.T__0): @@ -5385,7 +5385,7 @@ def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): def getRuleIndex(self): return fstripsParser.RULE_init_element - + def copyFrom(self, ctx:ParserRuleContext): super().copyFrom(ctx) @@ -5670,7 +5670,7 @@ def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): def getRuleIndex(self): return fstripsParser.RULE_constant_name - + def copyFrom(self, ctx:ParserRuleContext): super().copyFrom(ctx) @@ -5870,7 +5870,7 @@ def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): def getRuleIndex(self): return fstripsParser.RULE_prefConGD - + def copyFrom(self, ctx:ParserRuleContext): super().copyFrom(ctx) @@ -6065,7 +6065,7 @@ def prefConGD(self): elif la_ == 4: localctx = fstripsParser.PlainConstraintListContext(self, localctx) self.enterOuterAlt(localctx, 4) - self.state = 703 + self.state = 703 self._errHandler.sync(self) _alt = 1 while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: @@ -6075,7 +6075,7 @@ def prefConGD(self): else: raise NoViableAltException(self) - self.state = 705 + self.state = 705 self._errHandler.sync(self) _alt = self._interp.adaptivePredict(self._input,56,self._ctx) @@ -6100,7 +6100,7 @@ def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): def getRuleIndex(self): return fstripsParser.RULE_metricSpec - + def copyFrom(self, ctx:ParserRuleContext): super().copyFrom(ctx) @@ -6219,7 +6219,7 @@ def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): def getRuleIndex(self): return fstripsParser.RULE_metricFExp - + def copyFrom(self, ctx:ParserRuleContext): super().copyFrom(ctx) @@ -6501,7 +6501,7 @@ def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): def getRuleIndex(self): return fstripsParser.RULE_conGD - + def copyFrom(self, ctx:ParserRuleContext): super().copyFrom(ctx) @@ -6879,13 +6879,13 @@ def conGD(self): self.match(fstripsParser.T__0) self.state = 742 self.match(fstripsParser.K_AND) - self.state = 744 + self.state = 744 self._errHandler.sync(self) _la = self._input.LA(1) while True: self.state = 743 self.conGD() - self.state = 746 + self.state = 746 self._errHandler.sync(self) _la = self._input.LA(1) if not (_la==fstripsParser.T__0): @@ -7075,8 +7075,3 @@ def conGD(self): finally: self.exitRule() return localctx - - - - - diff --git a/src/tarski/io/_fstrips/parser/visitor.py b/src/tarski/io/_fstrips/parser/visitor.py index e61738b6..d35389b5 100644 --- a/src/tarski/io/_fstrips/parser/visitor.py +++ b/src/tarski/io/_fstrips/parser/visitor.py @@ -585,4 +585,4 @@ def visitAlternativeAlwaysConstraint(self, ctx:fstripsParser.AlternativeAlwaysCo -del fstripsParser \ No newline at end of file +del fstripsParser diff --git a/src/tarski/io/_fstrips/reader.py b/src/tarski/io/_fstrips/reader.py index 4b9b7e81..7b46da5b 100644 --- a/src/tarski/io/_fstrips/reader.py +++ b/src/tarski/io/_fstrips/reader.py @@ -197,7 +197,7 @@ def visitActionDefBody(self, ctx): return prec, effs def visitTrivialPrecondition(self, ctx): - return Tautology() + return Tautology(self.problem.language) def visitRegularPrecondition(self, ctx): return self.visit(ctx.goalDesc()) @@ -256,7 +256,7 @@ def visitAndGoalDesc(self, ctx): # The PDDL spec allows for and AND with zero or a single conjunct (e.g. (and p), which Tarski does (rightly) not # We thus treat those cases specially. if len(conjuncts) == 0: - return Tautology() + return Tautology(self.problem.language) elif len(conjuncts) == 1: return conjuncts[0] return CompoundFormula(Connective.And, conjuncts) diff --git a/src/tarski/io/fstrips.py b/src/tarski/io/fstrips.py index 9d66f50e..fd66a5d8 100644 --- a/src/tarski/io/fstrips.py +++ b/src/tarski/io/fstrips.py @@ -6,7 +6,7 @@ from ..theories import load_theory, Theory from .common import load_tpl from ..model import ExtensionalFunctionDefinition -from ..syntax import Tautology, Contradiction, Atom, CompoundTerm, CompoundFormula, QuantifiedFormula, \ +from ..syntax import Pass, Tautology, Contradiction, Atom, CompoundTerm, CompoundFormula, QuantifiedFormula, \ Term, Variable, Constant, Formula, symref, BuiltinPredicateSymbol from ..syntax.sorts import parent, Interval, ancestors @@ -308,6 +308,8 @@ def print_formula(formula, indentation=0): assert isinstance(formula, Formula) if isinstance(formula, Tautology): return "(and )" + if isinstance(formula, Pass): + return "(and )" elif isinstance(formula, Contradiction): return "(= 0 1)" # PDDL HACK =) elif isinstance(formula, Atom): @@ -354,7 +356,7 @@ def print_unconditional_effect(eff, indentation=0): def print_effect(eff, indentation=0): - conditional = not isinstance(eff.condition, Tautology) + conditional = not isinstance(eff.condition, (Tautology, Pass)) if conditional: return indent( diff --git a/src/tarski/io/rddl.py b/src/tarski/io/rddl.py index 4bf2b7a5..a686c93f 100644 --- a/src/tarski/io/rddl.py +++ b/src/tarski/io/rddl.py @@ -8,7 +8,7 @@ from ..fol import FirstOrderLanguage from ..syntax import implies, land, lor, neg, Connective, Quantifier, CompoundTerm, Interval, Atom, IfThenElse, \ Contradiction, Tautology, CompoundFormula, forall, ite, AggregateCompoundTerm, QuantifiedFormula, Term, Function, \ - Variable, Predicate, Constant, Formula, builtins + Variable, Predicate, Constant, Formula, builtins, FormulaTerm from ..syntax import arithmetic as tm from ..syntax.temporal import ltl as tt from ..syntax.builtins import create_atom, BuiltinPredicateSymbol as BPS, BuiltinFunctionSymbol as BFS @@ -317,6 +317,7 @@ class Requirements(Enum): CONTINUOUS = "continuous" MULTIVALUED = "multivalued" REWARD_DET = "reward-deterministic" + PRECONDITIONS = "preconditions" INTERMEDIATE_NODES = "intermediate-nodes" PARTIALLY_OBS = "partially-observed" CONCURRENT = "concurrent" @@ -415,7 +416,30 @@ def __init__(self, task): self.non_fluent_signatures = set() self.interm_signatures = set() - def write_model(self, filename): + def rddl_2018_format(self): + tpl = load_tpl("rddl_model_2018.tpl") + domain_content = tpl.format( + domain_name=self.task.domain_name, + req_list=self.get_requirements(), + type_list=self.get_types(), + pvar_list=self.get_pvars(), + cpfs_list=self.get_cpfs(), + reward_expr=self.get_reward(), + action_precondition_list=self.get_preconditions(), + state_invariant_list=self.get_state_invariants(), + domain_non_fluents='{}_non_fluents'.format(self.task.instance_name), + object_list=self.get_objects(), + non_fluent_expr=self.get_non_fluent_init(), + instance_name=self.task.instance_name, + init_state_fluent_expr=self.get_state_fluent_init(), + non_fluents_ref='{}_non_fluents'.format(self.task.instance_name), + max_nondef_actions=self.get_max_nondef_actions(), + horizon=self.get_horizon(), + discount=self.get_discount() + ) + return domain_content + + def rddl_pre_2018_format(self): tpl = load_tpl("rddl_model.tpl") content = tpl.format( domain_name=self.task.domain_name, @@ -436,8 +460,23 @@ def write_model(self, filename): horizon=self.get_horizon(), discount=self.get_discount() ) + return content + + + def write_model(self, filename, format_2018_style=False): with open(filename, 'w') as file: + if format_2018_style: + content = self.rddl_2018_format() + else: + content = self.rddl_pre_2018_format() file.write(content) + self.reset() + + def reset(self): + self.need_obj_decl = [] + self.need_constraints = {} + self.non_fluent_signatures = set() + self.interm_signatures = set() def get_requirements(self): return ', '.join([str(r) for r in self.task.requirements]) @@ -610,7 +649,9 @@ def rewrite(self, expr): if len(re_st) > 0: # MRJ: Random variables need parenthesis, other functions need # brackets... - if expr.symbol.symbol in builtins.get_random_binary_functions(): + relevant_functions = builtins.get_random_binary_functions() + relevant_functions += builtins.get_random_unary_functions() + if expr.symbol.symbol in relevant_functions: st_str = '({})'.format(','.join(re_st)) else: st_str = '[{}]'.format(','.join(re_st)) @@ -645,13 +686,13 @@ def rewrite(self, expr): re_sf = [self.rewrite(st) for st in expr.subformulas] re_sym = symbol_map[expr.connective] if len(re_sf) == 1: - return '{}{}'.format(re_sym, re_sf) + return '{}{}'.format(re_sym, re_sf[0]) return '({} {} {})'.format(re_sf[0], re_sym, re_sf[1]) elif isinstance(expr, QuantifiedFormula): re_f = self.rewrite(expr.formula) re_vars = ['?{} : {}'.format(x.symbol, x.sort.name) for x in expr.variables] re_sym = symbol_map[expr.quantifier] - return '{}_{{{}}} ({})'.format(re_sym, ','.join(re_vars), re_f) + return '{}_{{{}}} [{}]'.format(re_sym, ','.join(re_vars), re_f) elif isinstance(expr, AggregateCompoundTerm): re_expr = self.rewrite(expr.subterm) re_vars = ['?{} : {}'.format(x.symbol, x.sort.name) for x in expr.bound_vars] @@ -659,7 +700,9 @@ def rewrite(self, expr): re_sym = 'sum' elif expr.symbol == BFS.MUL: re_sym = 'prod' - return '{}_{{{}}} ({})'.format(re_sym, ','.join(re_vars), re_expr) + return '{}_{{{}}} [{}]'.format(re_sym, ','.join(re_vars), re_expr) + elif isinstance(expr, FormulaTerm): + return self.rewrite(expr.formula) raise RuntimeError(f"Unknown expression type for '{expr}'") @staticmethod diff --git a/src/tarski/io/templates/rddl_model_2018.tpl b/src/tarski/io/templates/rddl_model_2018.tpl new file mode 100644 index 00000000..54a8bf6a --- /dev/null +++ b/src/tarski/io/templates/rddl_model_2018.tpl @@ -0,0 +1,37 @@ +domain {domain_name} {{ + + requirements {{ {req_list} }}; + + types {{ {type_list} }}; + + pvariables {{ +{pvar_list} + }}; + + cpfs {{ +{cpfs_list} + }}; + + reward = {reward_expr}; + + action-preconditions {{ +{action_precondition_list} + }}; + +}} + +instance {instance_name} {{ + + domain = {domain_name}; + + {object_list} + + {non_fluent_expr} + + init-state {{ + {init_state_fluent_expr} + }}; + + horizon = {horizon}; + discount = {discount}; +}} \ No newline at end of file diff --git a/src/tarski/reachability/asp.py b/src/tarski/reachability/asp.py index 421e65c4..13fa4a41 100644 --- a/src/tarski/reachability/asp.py +++ b/src/tarski/reachability/asp.py @@ -7,7 +7,7 @@ from ..syntax.transform import remove_quantifiers, QuantifierEliminationMode from ..syntax.builtins import symbol_complements from ..syntax.ops import free_variables -from ..syntax import Formula, Atom, CompoundFormula, Connective, Term, Variable, Constant, Tautology, \ +from ..syntax import Formula, Atom, CompoundFormula, Connective, Term, Variable, Constant, Tautology, Pass, \ BuiltinPredicateSymbol, QuantifiedFormula, Quantifier, CompoundTerm from ..syntax.sorts import parent, Interval from ..fstrips import Problem, SingleEffect, AddEffect, DelEffect, FunctionalEffect @@ -156,7 +156,7 @@ def process_formula(self, f: Formula): """ Process a given formula and return the corresponding LP rule body, along with declaring in the given LP any number of extra rules necessary to ensure equivalence of the body with the truth value of the formula. """ - if isinstance(f, Tautology): + if isinstance(f, (Tautology, Pass)): return [] elif isinstance(f, Atom): diff --git a/src/tarski/syntax/__init__.py b/src/tarski/syntax/__init__.py index 2db60cfc..d0eed037 100644 --- a/src/tarski/syntax/__init__.py +++ b/src/tarski/syntax/__init__.py @@ -5,8 +5,8 @@ from .terms import Term, Constant, Variable, CompoundTerm, IfThenElse, ite, AggregateCompoundTerm from .util import termlists_are_equal, termlist_hash from .formulas import land, lor, neg, implies, forall, exists, equiv, Connective, Atom, Formula,\ - CompoundFormula, QuantifiedFormula, Tautology, Contradiction, top, bot, Quantifier, VariableBinding, \ - is_neg, is_and, is_or + CompoundFormula, QuantifiedFormula, Pass, Tautology, Contradiction, Quantifier, VariableBinding, \ + is_neg, is_and, is_or, FormulaTerm, top from .builtins import BuiltinFunctionSymbol, BuiltinPredicateSymbol from .symrefs import symref from .transform.substitutions import create_substitution, substitute_expression diff --git a/src/tarski/syntax/arithmetic/__init__.py b/src/tarski/syntax/arithmetic/__init__.py index 3f95d88e..d62f5f0b 100644 --- a/src/tarski/syntax/arithmetic/__init__.py +++ b/src/tarski/syntax/arithmetic/__init__.py @@ -4,7 +4,8 @@ import copy from ..transform.substitutions import substitute_expression -from ...syntax import Term, AggregateCompoundTerm, CompoundTerm, Constant, Variable, IfThenElse, create_substitution +from ...syntax import Term, AggregateCompoundTerm, CompoundTerm, Constant, \ + Variable, IfThenElse, create_substitution, Formula from ...syntax.algebra import Matrix from ... import errors as err from ... grounding.naive import instantiation @@ -22,7 +23,7 @@ def sumterm(*args): if not isinstance(x, Variable): raise err.SyntacticError(msg='sum(x0,...,xn,expr) require each\ argument xi to be an instance of Variable') - if not isinstance(expr, Term): + if not isinstance(expr, (Term, Formula)): raise err.SyntacticError(msg='sum(x0,x1,...,xn,expr) requires last \ argument "expr" to be an instance of Term, got "{}"'.format(expr)) return AggregateCompoundTerm(BuiltinFunctionSymbol.ADD, variables, expr) diff --git a/src/tarski/syntax/arithmetic/random.py b/src/tarski/syntax/arithmetic/random.py index 6b508dae..756b435d 100644 --- a/src/tarski/syntax/arithmetic/random.py +++ b/src/tarski/syntax/arithmetic/random.py @@ -25,3 +25,11 @@ def gamma(shape, scale): np = modules.import_numpy() return np.random.gamma(shape, scale) return gamma_func(shape, scale) + +def bernoulli(p): + try: + bernoulli_func = p.language.get_function(bfs.BERNOULLI) + except AttributeError: + np = modules.import_numpy() + return np.random.random(p) + return bernoulli_func(p) diff --git a/src/tarski/syntax/builtins.py b/src/tarski/syntax/builtins.py index b2cd6828..d0259d69 100644 --- a/src/tarski/syntax/builtins.py +++ b/src/tarski/syntax/builtins.py @@ -1,7 +1,7 @@ from enum import Enum # A table with the negated counterparts of builtin predicates. -symbol_complements = {"=": "!=", "!=": "=", "<": ">=", "<=": ">", ">": "<=", ">=": "<"} +symbol_complements = {"=": "!=", "!=": "=", "<": ">=", "<=": ">", ">": "<=", ">=": "<", "&": "|", "|":"&"} class BuiltinPredicateSymbol(Enum): @@ -11,6 +11,8 @@ class BuiltinPredicateSymbol(Enum): LE = "<=" GT = ">" GE = ">=" + AND = "&" + OR = "|" def __str__(self): return self.value.lower() @@ -77,6 +79,9 @@ def negate_builtin_atom(atom): def is_builtin_function(fun): return isinstance(fun.symbol, BuiltinFunctionSymbol) +def get_boolean_predicates(): + return [BuiltinPredicateSymbol.AND, BuiltinPredicateSymbol.OR] + def get_equality_predicates(): return [BuiltinPredicateSymbol.EQ, BuiltinPredicateSymbol.NE] @@ -111,8 +116,8 @@ def get_random_binary_functions(): def get_random_unary_functions(): - # BFS = BuiltinFunctionSymbol - return [] + BFS = BuiltinFunctionSymbol + return [BFS.BERNOULLI] def get_predicate_from_symbol(symbol: str): diff --git a/src/tarski/syntax/factory.py b/src/tarski/syntax/factory.py index bc10c86b..9add4583 100644 --- a/src/tarski/syntax/factory.py +++ b/src/tarski/syntax/factory.py @@ -34,6 +34,7 @@ def create_arithmetic_term(symbol: BuiltinFunctionSymbol, lhs, rhs): # if lang.has_function(overload): # fun = lang.get_function(overload) # else: + fun = lang.get_function(symbol) return fun(lhs, rhs) diff --git a/src/tarski/syntax/formulas.py b/src/tarski/syntax/formulas.py index 3a3a6153..68a3a078 100644 --- a/src/tarski/syntax/formulas.py +++ b/src/tarski/syntax/formulas.py @@ -4,7 +4,7 @@ from typing import List from .. import errors as err -from .builtins import BuiltinPredicateSymbol +from .builtins import BuiltinPredicateSymbol, BuiltinFunctionSymbol from .terms import Variable, Term from .util import termlists_are_equal, termlist_hash from .predicate import Predicate @@ -45,47 +45,118 @@ def __or__(self, rhs): def __invert__(self): return neg(self) + def __eq__(self, rhs): + return self.language.dispatch_operator(BuiltinPredicateSymbol.EQ, Term, Term, self, rhs) + + def __ne__(self, rhs): + return self.language.dispatch_operator(BuiltinPredicateSymbol.NE, Term, Term, self, rhs) + + def __lt__(self, rhs): + return self.language.dispatch_operator(BuiltinPredicateSymbol.LT, Term, Term, self, rhs) + def __gt__(self, rhs): - return implies(self, rhs) + return self.language.dispatch_operator(BuiltinPredicateSymbol.GT, Term, Term, self, rhs) - def __eq__(self, other): - raise NotImplementedError() # To be subclassed + def __le__(self, rhs): + return self.language.dispatch_operator(BuiltinPredicateSymbol.LE, Term, Term, self, rhs) - def __hash__(self): - raise NotImplementedError() # To be subclassed + def __ge__(self, rhs): + return self.language.dispatch_operator(BuiltinPredicateSymbol.GE, Term, Term, self, rhs) + + def __add__(self, rhs): + return self.language.dispatch_operator(BuiltinFunctionSymbol.ADD, Term, Term, self, rhs) + + def __sub__(self, rhs): + return self.language.dispatch_operator(BuiltinFunctionSymbol.SUB, Term, Term, self, rhs) + + def __mul__(self, rhs): + return self.language.dispatch_operator(BuiltinFunctionSymbol.MUL, Term, Term, self, rhs) + + def __matmul__(self, rhs): + return self.language.dispatch_operator(BuiltinFunctionSymbol.MATMUL, Term, Term, self, rhs) + + def __truediv__(self, rhs): + return self.language.dispatch_operator(BuiltinFunctionSymbol.DIV, Term, Term, self, rhs) + + def __radd__(self, lhs): + return self.language.dispatch_operator(BuiltinFunctionSymbol.ADD, Term, Term, lhs, self) + + def __rsub__(self, lhs): + return self.language.dispatch_operator(BuiltinFunctionSymbol.SUB, Term, Term, lhs, self) + + def __rmul__(self, lhs): + return self.language.dispatch_operator(BuiltinFunctionSymbol.MUL, Term, Term, lhs, self) + + def __rtruediv__(self, lhs): + return self.language.dispatch_operator(BuiltinFunctionSymbol.DIV, Term, Term, lhs, self) + + def __pow__(self, rhs): + return self.language.dispatch_operator(BuiltinFunctionSymbol.POW, Term, Term, self, rhs) + + def __mod__(self, rhs): + return self.language.dispatch_operator(BuiltinFunctionSymbol.MOD, Term, Term, self, rhs) + + __hash__ = None # type: ignore def hash(self): - # Define a shortcut for uniformity with the Term class - return self.__hash__() + raise NotImplementedError() # To be subclassed def is_syntactically_equal(self, other): """ Return true if this formula and other are strictly syntactically equivalent. This is equivalent to self == other, but is provided for reasons of uniformity with the Term class. """ - return self.__eq__(other) + raise NotImplementedError + + def build_formulaterm(self): + return FormulaTerm(self) + +class Pass(Formula): + + def __str__(self): + return "T" + __repr__ = __str__ + + def is_syntactically_equal(self, other): + return self.__class__ is other.__class__ + + def hash(self): + return hash(self.__class__) + +top = Pass() class Tautology(Formula): + + def __init__(self, lang): + """Requires a FirstOrderLanguage to be constructed""" + super().__init__() + self.language = lang + def __str__(self): return "T" __repr__ = __str__ - def __eq__(self, other): + def is_syntactically_equal(self, other): return self.__class__ is other.__class__ - def __hash__(self): + def hash(self): return hash(self.__class__) class Contradiction(Formula): + def __init__(self, lang): + """Requires a FirstOrderLanguage to be constructed""" + super().__init__() + self.language = lang + def __str__(self): return "F" __repr__ = __str__ - def __eq__(self, other): + def is_syntactically_equal(self, other): return self.__class__ is other.__class__ - def __hash__(self): + def hash(self): return hash(self.__class__) @@ -94,6 +165,7 @@ class CompoundFormula(Formula): def __init__(self, connective, subformulas): super().__init__() + self.language = subformulas[0].language self.connective = connective self.subformulas = subformulas self._check_well_formed() @@ -117,18 +189,18 @@ def __str__(self): return "({})".format(inner) __repr__ = __str__ - def __eq__(self, other): + def is_syntactically_equal(self, other): return self.__class__ is other.__class__ and \ self.connective == other.connective and \ - self.subformulas == other.subformulas + termlists_are_equal(self.subformulas, other.subformulas) - def __hash__(self): + def hash(self): element_hashes = [self.__class__, self.connective] # TODO: formulas need to be flattened if we want to hash them, # it would be good to check if there is a better way of flattening # than this for phi in self.subformulas: - element_hashes += [hash(phi)] + element_hashes += [phi.hash()] return hash(tuple(element_hashes)) @@ -138,6 +210,7 @@ def __init__(self, quantifier: Quantifier, variables: List[Variable], formula: F self.variables = variables self.formula = formula self._check_well_formed() + self.language = formula.language def _check_well_formed(self): if len(self.variables) == 0: @@ -148,18 +221,14 @@ def __str__(self): return '{} {} : ({})'.format(self.quantifier, vars_, self.formula) __repr__ = __str__ - def __eq__(self, other): + def is_syntactically_equal(self, other): return self.__class__ is other.__class__ \ and self.quantifier == other.quantifier \ and termlists_are_equal(self.variables, other.variables) \ - and self.formula == other.formula - - def __hash__(self): - return hash((self.__class__, self.quantifier, termlist_hash(self.variables), self.formula)) - + and self.formula.is_syntactically_equal(other.formula) -top = Tautology() -bot = Contradiction() + def hash(self): + return hash((self.__class__, self.quantifier, termlist_hash(self.variables), self.formula.hash())) def _to_binary_tree(args, connective): @@ -177,11 +246,14 @@ def _create_compound(args, connective, flat): return CompoundFormula(connective, args) return _to_binary_tree(args, connective) - def land(*args, flat=False): """ Create an and-formula with the given subformulas. If binary is true, the and-formula will be shaped as a binary tree (e.g. (...((p1 and p2) and p3) and ...))), otherwise it will have a flat structure. This is an implementation detail, but might be relevant performance-wise when dealing with large structures """ + #todo: [John Peterson] --- this had originally allowed not giving + #it arguments (which returned a Tautology) --- not possible + #without a language, so we're making do here by removing it (see + #lor for equiv with Contradiction) if not args: return top return _create_compound(args, Connective.And, flat) @@ -191,8 +263,6 @@ def lor(*args, flat=False): """ Create an or-formula with the given subformulas. If binary is true, the or-formula will be shaped as a binary tree (e.g. (...((p1 or p2) or p3) or ...))), otherwise it will have a flat structure. This is an implementation detail, but might be relevant performance-wise when dealing with large structures """ - if not args: - return bot return _create_compound(args, Connective.Or, flat) @@ -295,6 +365,7 @@ def __init__(self, predicate, arguments): self.predicate = predicate self.subterms = arguments self._check_well_formed() + self.language = predicate.language @property def symbol(self): @@ -327,12 +398,12 @@ def __str__(self): return '{}({})'.format(self.predicate.symbol, ','.join(str(t) for t in self.subterms)) __repr__ = __str__ - def __eq__(self, other): + def is_syntactically_equal(self, other): return self.__class__ is other.__class__ and \ self.predicate == other.predicate and \ termlists_are_equal(self.subterms, other.subterms) - def __hash__(self): + def hash(self): return hash((self.__class__, self.predicate, termlist_hash(self.subterms))) @@ -388,3 +459,56 @@ def __iter__(self): def __str__(self): return f"Variables({','.join(map(str, self._v_values))})" __repr__ = __str__ + + +class FormulaTerm(Term): + """A Term wrapper for (boolean) formulas""" + def __init__(self, formula): + self.symbol = formula_symbol_extractor(formula) + self.formula = formula + self._sort = formula.language.Boolean + + @property + def language(self): + return self.formula.language + + @property + def sort(self): + return self._sort + + def __str__(self): + return str(self.formula) + + __repr__ = __str__ + + def hash(self): + #[John Peterson] todo: should this actually work this way? I'm + #not sure if we want these to hash the same as the underlying + #formula + return hash(self.formula) + + def is_syntactically_equal(self, other): + return self.formula.is_syntactically_equal(other.formula) + +def formula_symbol_extractor(f): + """Depending on the type of Formula, extracts and appropriate symbol for the FormulaTerm wrapper""" + if isinstance(f, CompoundFormula): + symbol = f.connective + elif isinstance(f, Atom): + symbol = f.predicate + elif isinstance(f, QuantifiedFormula): + symbol = f.quantifier + else: + raise NotImplementedError() #unknown formula type + return symbol + + +def formula_arity_extractor(f): + if isinstance(f, CompoundFormula): + if f.connective == Connective.Not: + arity = 1 + else: + arity = 2 + else: + raise NotImplementedError() #unknown formula type + return arity diff --git a/src/tarski/syntax/ops.py b/src/tarski/syntax/ops.py index f692ee9f..7ef9c293 100644 --- a/src/tarski/syntax/ops.py +++ b/src/tarski/syntax/ops.py @@ -5,7 +5,7 @@ from .sorts import children, compute_direct_sort_map, Interval from .visitors import CollectFreeVariables from .terms import Term, Constant, Variable -from .formulas import CompoundFormula, Connective +from .formulas import CompoundFormula, Connective, Formula from .symrefs import symref @@ -14,6 +14,12 @@ def cast_to_closest_common_numeric_ancestor(lang, lhs, rhs): applied to a 3 and Constant(2, Int), it should return Constant(3, Int), Constant(2, Int). Non-arithmetic objects should be left unchanged. """ + + if isinstance(lhs, Formula): + lhs = lang.generate_formula_term(lhs) + if isinstance(rhs, Formula): + rhs = lang.generate_formula_term(rhs) + if isinstance(lhs, Term) and isinstance(rhs, Term): return lhs, rhs diff --git a/src/tarski/syntax/sorts.py b/src/tarski/syntax/sorts.py index 40ef9bb2..0f38d8f0 100644 --- a/src/tarski/syntax/sorts.py +++ b/src/tarski/syntax/sorts.py @@ -155,7 +155,7 @@ def dump(self): return dict(name=self.name, domain=[self.lower_bound, self.upper_bound]) def domain(self): - if self.builtin or self.upper_bound - self.lower_bound > 9999: # Yes, very hacky + if self.upper_bound - self.lower_bound > 9999: # Yes, very hacky raise err.TarskiError(f'Cannot iterate over interval with range [{self.lower_bound}, {self.upper_bound}]') from . import Constant # pylint: disable=import-outside-toplevel # Avoiding circular references return (Constant(x, self) for x in range(self.lower_bound, self.upper_bound+1)) @@ -199,17 +199,6 @@ def float_encode_fn(x): return float(x) -def build_the_bools(lang): - bools = lang.sort('Boolean') - # TODO: we really should be setting builtin to True, but at the moment this is undesirable, as in many places in - # the code we seem to assume that "builtin" sorts are kind of "numeric" sorts, which leads us to try to do - # things with the new Bool sort that cannot be done, e.g. to cast string object "True" to a value, etc. - # bools.builtin = True - lang.constant('True', bools) - lang.constant('False', bools) - return bools - - def build_the_naturals(lang): the_nats = Interval('Natural', lang, int_encode_fn, 0, 2 ** 32 - 1, builtin=True) the_nats.builtin = True @@ -227,12 +216,22 @@ def build_the_reals(lang): reals.builtin = True return reals +def build_the_bools(lang): + bools = Interval('Boolean', lang, int_encode_fn, 0, 1, builtin=True) + bools.builtin = True + return bools + +def attach_the_non_arithmetic_bools(lang): + _ = lang.attach_sort(build_the_bools(lang), lang.ns.object) def attach_arithmetic_sorts(lang): real_t = lang.attach_sort(build_the_reals(lang), lang.ns.object) int_t = lang.attach_sort(build_the_integers(lang), real_t) - _ = lang.attach_sort(build_the_naturals(lang), int_t) - + nat_t = lang.attach_sort(build_the_naturals(lang), int_t) + if not lang.has_sort("Boolean"): + _ = lang.attach_sort(build_the_bools(lang), nat_t) + else: + lang.change_parent(lang.get_sort("Boolean"), nat_t) def compute_signature_bindings(signature): """ Return an exhaustive list of all possible bindings compatible with the given signature, i.e. diff --git a/src/tarski/syntax/terms.py b/src/tarski/syntax/terms.py index b7102cd9..48ed6647 100644 --- a/src/tarski/syntax/terms.py +++ b/src/tarski/syntax/terms.py @@ -109,13 +109,13 @@ def __divmod__(self, rhs): return self.language.dispatch_operator('divmod', Term, Term, self, rhs) def __and__(self, rhs): - return self.language.dispatch_operator('&', Term, Term, self, rhs) + return self.language.dispatch_operator(BuiltinPredicateSymbol.AND, Term, Term, self, rhs) def __xor__(self, rhs): return self.language.dispatch_operator('^', Term, Term, self, rhs) def __or__(self, rhs): - return self.language.dispatch_operator('|', Term, Term, self, rhs) + return self.language.dispatch_operator(BuiltinPredicateSymbol.OR, Term, Term, self, rhs) def is_syntactically_equal(self, other): """ Return true if this term and other are strictly syntactically equivalent. @@ -222,6 +222,8 @@ class AggregateCompoundTerm(Term): def __init__(self, operator, bound_vars, subterm: Term): self.symbol = operator self.bound_vars = bound_vars + if not isinstance(subterm, Term): + subterm = subterm.build_formulaterm() self.subterm = subterm # TODO: type checking? @property @@ -261,6 +263,12 @@ def __init__(self, condition, subterms: Tuple[Term, Term]): self.symbol = subterms[0].language.get('ite') self.condition = condition + #if either of the subterms are boolean Formulae, wrap them as Terms + if not isinstance(subterms[0], Term): + subterms = (subterms[0].build_formulaterm(), subterms[1]) + if not isinstance(subterms[1], Term): + subterms = (subterms[0], subterms[1].build_formulaterm()) + # Our implementation of ite requires both branches to have equal sort if subterms[0].sort != subterms[1].sort: if parent(subterms[0].sort) == subterms[1].sort: @@ -290,7 +298,7 @@ def __str__(self): __repr__ = __str__ def hash(self): - return hash(('ite', self.condition, termlist_hash(self.subterms))) + return hash(('ite', self.condition.hash(), termlist_hash(self.subterms))) def is_syntactically_equal(self, other): return self.__class__ is other.__class__ and \ @@ -302,7 +310,6 @@ def is_syntactically_equal(self, other): def ite(c, t1: Term, t2: Term): return IfThenElse(c, (t1, t2)) - class Constant(Term): def __init__(self, name, sort: Sort): self.name = name diff --git a/src/tarski/syntax/transform/quantifier_elimination.py b/src/tarski/syntax/transform/quantifier_elimination.py index be8b4f5a..961bc2eb 100644 --- a/src/tarski/syntax/transform/quantifier_elimination.py +++ b/src/tarski/syntax/transform/quantifier_elimination.py @@ -7,7 +7,7 @@ from ... import errors as err from .substitutions import create_substitution, substitute_expression -from ..formulas import land, lor, Quantifier, QuantifiedFormula, Atom, Tautology, Contradiction, CompoundFormula +from ..formulas import land, lor, Quantifier, QuantifiedFormula, Atom, Tautology, Contradiction, CompoundFormula, Pass from .errors import TransformationError @@ -37,7 +37,7 @@ def _eliminate_exists(self): return self.mode in (QuantifierEliminationMode.All, QuantifierEliminationMode.Exists) def _convert(self, phi): - if isinstance(phi, (Atom, Tautology, Contradiction)): + if isinstance(phi, (Atom, Tautology, Contradiction, Pass)): return phi # Already quantifier-free if isinstance(phi, CompoundFormula): diff --git a/src/tarski/syntax/walker.py b/src/tarski/syntax/walker.py index a1fae2ac..6cbda024 100644 --- a/src/tarski/syntax/walker.py +++ b/src/tarski/syntax/walker.py @@ -58,11 +58,11 @@ def run(self, expression, inplace=True): def visit_expression(self, node, inplace=True): # pylint: disable=import-outside-toplevel # Avoiding circular references - from .formulas import CompoundFormula, QuantifiedFormula, Atom, Tautology, Contradiction + from .formulas import CompoundFormula, QuantifiedFormula, Atom, Tautology, Contradiction, Pass from .terms import Constant, Variable, CompoundTerm, IfThenElse # pylint: disable=import-outside-toplevel node = node if inplace else copy.deepcopy(node) - if isinstance(node, (Variable, Constant, Contradiction, Tautology)): + if isinstance(node, (Variable, Constant, Contradiction, Tautology, Pass)): pass elif isinstance(node, (CompoundTerm, Atom)): diff --git a/src/tarski/theories.py b/src/tarski/theories.py index 0c53d95b..9427d460 100644 --- a/src/tarski/theories.py +++ b/src/tarski/theories.py @@ -3,7 +3,7 @@ from typing import Union, List, Optional from tarski.errors import DuplicateTheoryDefinition -from .syntax.sorts import attach_arithmetic_sorts, build_the_bools +from .syntax.sorts import attach_arithmetic_sorts, attach_the_non_arithmetic_bools from .fol import FirstOrderLanguage from .syntax import builtins, Term from .syntax.factory import create_atom, create_arithmetic_term @@ -36,7 +36,13 @@ def language(name='L', theories: Optional[List[Union[str, Theory]]] = None): """ theories = theories or [] lang = FirstOrderLanguage(name) - _ = [load_theory(lang, t) for t in theories] + + #todo: [John Peterson] would like to either do this differently, or + #eliminate the boolean theory alltogether + load_theory(lang, Theory.BOOLEAN) + for t in theories: + if t not in ("boolean", Theory.BOOLEAN): + load_theory(lang,t) return lang @@ -72,7 +78,12 @@ def has_theory(lang, theory: Union[Theory, str]): def load_bool_theory(lang): - build_the_bools(lang) + if not lang.has_sort("Boolean"): + attach_the_non_arithmetic_bools(lang) + for pred in builtins.get_boolean_predicates(): + lang.register_operator_handler(pred, Term, Term, create_casting_handler(lang, pred, create_atom)) + p = lang.predicate(pred, lang.Boolean, lang.Boolean) + p.builtin = True def load_equality_theory(lang): @@ -129,7 +140,7 @@ def load_random_theory(lang): f.builtin = True for fun in builtins.get_random_unary_functions(): lang.register_unary_operator_handler(fun, Term, create_casting_handler(lang, fun, create_arithmetic_term)) - f = lang.function(fun, lang.Real, lang.Real) + f = lang.function(fun, lang.Real, lang.Boolean) f.builtin = True diff --git a/src/tarski/utils/command.py b/src/tarski/utils/command.py index d9a4f417..ad3b9c7a 100644 --- a/src/tarski/utils/command.py +++ b/src/tarski/utils/command.py @@ -70,7 +70,11 @@ def silentremove(filename): libc = ctypes.CDLL(None) -c_stdout = ctypes.c_void_p.in_dll(libc, 'stdout') +if sys.platform == "darwin": + stdout_name = '__stdoutp' +elif sys.platform == "linux": + stdout_name = 'stdout' +c_stdout = ctypes.c_void_p.in_dll(libc, stdout_name) @contextmanager diff --git a/tests/common/gridworld.py b/tests/common/gridworld.py index 193f71f3..85df9f6b 100644 --- a/tests/common/gridworld.py +++ b/tests/common/gridworld.py @@ -15,7 +15,7 @@ def generate_small_gridworld(): ypos = lang.function('Y', coord_t) problem.action(name='move-up', parameters=[], - precondition=Tautology(), + precondition=Tautology(lang), # effects=[fs.FunctionalEffect(ypos(), ypos() + 1)]) effects=[ypos() << ypos() + 1]) diff --git a/tests/fol/test_fol_accessors.py b/tests/fol/test_fol_accessors.py index dffe7877..b6b2f9c7 100644 --- a/tests/fol/test_fol_accessors.py +++ b/tests/fol/test_fol_accessors.py @@ -1,10 +1,10 @@ import pytest -from tarski import FirstOrderLanguage, errors +from tarski import FirstOrderLanguage, errors, language def test_namespace_accessor(): - lang = FirstOrderLanguage() + lang = language() # Let's test different ways of accessing the same things assert lang.get_sort('object') == lang.get("object") == lang.ns.object diff --git a/tests/fol/test_sorts.py b/tests/fol/test_sorts.py index 94d498ad..351f61c4 100644 --- a/tests/fol/test_sorts.py +++ b/tests/fol/test_sorts.py @@ -6,7 +6,7 @@ from tarski.benchmarks.counters import generate_fstrips_counters_problem from tarski.syntax import symref from tarski.syntax.ops import compute_sort_id_assignment -from tarski.syntax.sorts import parent, ancestors, compute_signature_bindings, compute_direct_sort_map +from tarski.syntax.sorts import parent, ancestors, compute_signature_bindings, compute_direct_sort_map, Interval from tarski.theories import Theory @@ -201,3 +201,46 @@ def test_sort_domain_retrieval(): with pytest.raises(err.TarskiError): lang.Integer.domain() # Domain too large to iterate over it + + +def test_boolean_sort_no_arithmetic_theory(): + lang = tarski.language(theories=[]) + assert lang.Boolean in lang.sorts, 'the Boolean sort should always be attached' + assert lang.Object in lang.sorts, 'the object sort should be the only other sort' + assert len(lang.sorts) == 2, 'the Boolean and Object sorts should be the only sorts' + + assert isinstance(lang.Boolean, Interval), 'the Boolean sort is an interval' + assert lang.Boolean.cardinality() == 2, 'Boolean sort is of cardinality 2' + assert lang.Boolean.builtin == True + assert parent(lang.Boolean) == lang.Object, 'since there are no numeric sorts, the Boolean sort has parent Object' + +def test_ensure_boolean_theory_does_nothing_no_arithmetic_theory(): + lang = tarski.language(theories=[Theory.BOOLEAN]) + assert lang.Boolean in lang.sorts, 'the Boolean sort should always be attached' + assert lang.Object in lang.sorts, 'the object sort should be the only other sort' + assert len(lang.sorts) == 2, 'the Boolean and Object sorts should be the only sorts' + + assert isinstance(lang.Boolean, Interval), 'the Boolean sort is an interval' + assert lang.Boolean.cardinality() == 2, 'Boolean sort is of cardinality 2' + assert lang.Boolean.builtin == True + assert parent(lang.Boolean) == lang.Object, 'since there are no numeric sorts, the Boolean sort has parent Object' + +def test_boolean_sort_with_arithmetic_theory(): + lang = tarski.language(theories=[Theory.ARITHMETIC]) + assert lang.Boolean in lang.sorts, 'the Boolean sort should always be attached' + assert len(lang.sorts) == 5, 'sorts Boolean, Integer, Naturals, Reals and Object attached' + + assert isinstance(lang.Boolean, Interval), 'the Boolean sort is an interval' + assert lang.Boolean.cardinality() == 2, 'Boolean sort is of cardinality 2' + assert lang.Boolean.builtin == True + assert parent(lang.Boolean) == lang.Natural, 'when there are numeric sorts, the parent of Boolean is Naturals' + +def test_ensure_boolean_theory_does_nothing_with_arithmetic_theory(): + lang = tarski.language(theories=[Theory.BOOLEAN, Theory.ARITHMETIC]) + assert lang.Boolean in lang.sorts, 'the Boolean sort should always be attached' + assert len(lang.sorts) == 5, 'sorts Boolean, Integer, Naturals, Reals and Object attached' + + assert isinstance(lang.Boolean, Interval), 'the Boolean sort is an interval' + assert lang.Boolean.cardinality() == 2, 'Boolean sort is of cardinality 2' + assert lang.Boolean.builtin == True + assert parent(lang.Boolean) == lang.Natural, 'when there are numeric sorts, the parent of Boolean is Naturals' diff --git a/tests/fol/test_syntax.py b/tests/fol/test_syntax.py index 6b9a3957..f427967b 100755 --- a/tests/fol/test_syntax.py +++ b/tests/fol/test_syntax.py @@ -6,7 +6,7 @@ from tarski import theories, Term, Constant from tarski.fstrips import fstrips from tarski.syntax import symref, CompoundFormula, Atom, ite, AggregateCompoundTerm, CompoundTerm, lor, Tautology, \ - Contradiction, land, top, bot + Contradiction, land from tarski.theories import Theory from tarski import errors as err from tarski import fstrips as fs @@ -18,15 +18,15 @@ def test_language_creation(): lang = theories.language() sorts = sorted(x.name for x in lang.sorts) - assert sorts == ['object'] + assert set(sorts) == set(['object', 'Boolean']) lang = fstrips.language("test", theories=[]) sorts = sorted(x.name for x in lang.sorts) - assert sorts == ['object'] + assert set(sorts) == set(['object', 'Boolean']) lang = fstrips.language("test") sorts = sorted(x.name for x in lang.sorts) - assert sorts == ['object'] # The default equality theory should not import the arithmetic sorts either + assert set(sorts) == set(['object', 'Boolean'])# The default equality theory should not import the arithmetic sorts either def test_builtin_constants(): @@ -314,7 +314,7 @@ def test_matrices_constants(): assert isinstance(A.matrix[i, j], Term) -def test_term_hash_raises_exception(): +def test_term_and_formula_hash_raises_exception(): # from tarski.fstrips import language # from tarski.syntax import symref lang = fs.language("test") @@ -339,10 +339,13 @@ def test_term_hash_raises_exception(): # Atoms and in general formulas can be used without problem atom = f(c) == c - counter[atom] += 2 - assert counter[atom] == 2 + with pytest.raises(TypeError): + counter[atom] += 2 + counter[symref(atom)] += 2 + assert counter[symref(atom)] == 2 +@pytest.mark.skip(reason="todo: [John Peterson] re-enable these tests for shorthands once Contradiction and Tautology are fixed") def test_syntax_shorthands(): assert lor(*[]) == Contradiction(), "a lor(·) of no disjuncts is False" assert land(*[]) == Tautology(), "a land(·) of no conjuncts is True" diff --git a/tests/fstrips/contingent/localize.py b/tests/fstrips/contingent/localize.py index c0397a7a..4d6d2210 100644 --- a/tests/fstrips/contingent/localize.py +++ b/tests/fstrips/contingent/localize.py @@ -32,14 +32,16 @@ def create_small_task(): P.goal = G P.constraints += [constraint] - P.action('move_up', [], Tautology(), [fs.FunctionalEffect(y(), y() + 1)]) - P.action('move_down', [], Tautology(), [fs.FunctionalEffect(y(), y() - 1)]) - P.action('move_left', [], Tautology(), [fs.FunctionalEffect(x(), x() - 1)]) - P.action('move_right', [], Tautology(), [fs.FunctionalEffect(x(), x() + 1)]) - - P.sensor('sense_wall_up', [], Tautology(), y() == 4) - P.sensor('sense_wall_down', [], Tautology(), y() == -4) - P.sensor('sense_wall_left', [], Tautology(), x() == -4) - P.sensor('sense_wall_right', [], Tautology(), x() == 4) + lang = P.language + + P.action('move_up', [], Tautology(lang), [fs.FunctionalEffect(y(), y() + 1)]) + P.action('move_down', [], Tautology(lang), [fs.FunctionalEffect(y(), y() - 1)]) + P.action('move_left', [], Tautology(lang), [fs.FunctionalEffect(x(), x() - 1)]) + P.action('move_right', [], Tautology(lang), [fs.FunctionalEffect(x(), x() + 1)]) + + P.sensor('sense_wall_up', [], Tautology(lang), y() == 4) + P.sensor('sense_wall_down', [], Tautology(lang), y() == -4) + P.sensor('sense_wall_left', [], Tautology(lang), x() == -4) + P.sensor('sense_wall_right', [], Tautology(lang), x() == 4) return P diff --git a/tests/fstrips/contingent/test_sensors.py b/tests/fstrips/contingent/test_sensors.py index 5f213af5..1a87230a 100644 --- a/tests/fstrips/contingent/test_sensors.py +++ b/tests/fstrips/contingent/test_sensors.py @@ -10,11 +10,11 @@ def test_sensor_creation(): nav = grid_navigation.generate_single_agent_language() y = nav.get_function('y') - _ = contingent.Sensor(nav, 'sense_wall_up', [], Tautology(), y() == 4) + _ = contingent.Sensor(nav, 'sense_wall_up', [], Tautology(nav), y() == 4) def test_sensor_duplicate(): task = localize.create_small_task() y = task.language.get_function('y') with pytest.raises(contingent.errors.DuplicateSensorDefinition): - _ = task.sensor('sense_wall_up', [], Tautology(), y() == 4) + _ = task.sensor('sense_wall_up', [], Tautology(task.language), y() == 4) diff --git a/tests/fstrips/hybrid/test_differential.py b/tests/fstrips/hybrid/test_differential.py index c89356b1..a2eee844 100644 --- a/tests/fstrips/hybrid/test_differential.py +++ b/tests/fstrips/hybrid/test_differential.py @@ -10,5 +10,5 @@ def test_diff_constraint_creation(): x, y, f = [particles.get_function(name) for name in ['x', 'y', 'f']] p1, p2, p3, p4 = [particles.get_constant(name) for name in ['p1', 'p2', 'p3', 'p4']] - constraint = hybrid.DifferentialConstraint(particles, 'test', [], top, x(p1), f(p1) * 2.0) + constraint = hybrid.DifferentialConstraint(particles, 'test', [], particles.top(), x(p1), f(p1) * 2.0) assert isinstance(constraint, hybrid.DifferentialConstraint) diff --git a/tests/fstrips/test_fstrips_operations.py b/tests/fstrips/test_fstrips_operations.py index 4ab428bb..483d750a 100644 --- a/tests/fstrips/test_fstrips_operations.py +++ b/tests/fstrips/test_fstrips_operations.py @@ -19,7 +19,7 @@ def test_symbol_classification_in_gripper(): assert (len(fluent), len(static)) == (4, 3) fluent, static = approximate_symbol_fluency(prob, include_builtin=True) - assert (len(fluent), len(static)) == (4, 5) # Same as before plus "=" and "!=" + assert (len(fluent), len(static)) == (4, 7) # Same as before plus "=" and "!=" def test_symbol_classification_with_nested_effect_heads(): diff --git a/tests/fstrips/test_representation.py b/tests/fstrips/test_representation.py index 116a202f..d722c2a0 100644 --- a/tests/fstrips/test_representation.py +++ b/tests/fstrips/test_representation.py @@ -5,7 +5,7 @@ identify_cost_related_functions, compute_delete_free_relaxation, is_delete_free, is_strips_problem, \ is_conjunction_of_positive_atoms, is_strips_effect_set, compile_away_formula_negated_literals, \ compile_action_negated_preconditions_away, compile_negated_preconditions_away, compute_complementary_atoms -from tarski.syntax import exists, land, neg +from tarski.syntax import exists, land, neg, symref from tarski.fstrips import representation as rep, AddEffect, DelEffect from tarski.syntax.ops import flatten @@ -48,8 +48,8 @@ def test_literal_collection(): clear, loc, b1, b2, b3 = lang.get('clear', 'loc', 'b1', 'b2', 'b3') x = lang.variable('x', lang.ns.block) - assert rep.collect_literals_from_conjunction(clear(b1)) == {(clear(b1), True)} - assert rep.collect_literals_from_conjunction(~clear(b1)) == {(clear(b1), False)} + assert rep.collect_literals_from_conjunction(clear(b1)) == {(symref(clear(b1)), True)} + assert rep.collect_literals_from_conjunction(~clear(b1)) == {(symref(clear(b1)), False)} assert len(rep.collect_literals_from_conjunction(clear(b1) & ~clear(b2))) == 2 assert len(rep.collect_literals_from_conjunction(land(clear(b1), clear(b2), clear(b3)))) == 3 diff --git a/tests/fstrips/test_simplify.py b/tests/fstrips/test_simplify.py index cb428ebc..11726fda 100644 --- a/tests/fstrips/test_simplify.py +++ b/tests/fstrips/test_simplify.py @@ -3,7 +3,7 @@ from tarski.fstrips import UniversalEffect from tarski.fstrips.manipulation import Simplify from tarski.fstrips.manipulation.simplify import simplify_existential_quantification -from tarski.syntax import symref, land, lor, neg, bot, top, forall, exists +from tarski.syntax import symref, land, lor, neg, forall, exists, Tautology, Contradiction def test_simplifier(): @@ -57,6 +57,9 @@ def test_simplification_of_negation(): lang = problem.language b1, clear, on, ontable, handempty, holding = lang.get('b1', 'clear', 'on', 'ontable', 'handempty', 'holding') + top = Tautology(problem.language) + bot = Contradiction(problem.language) + s = Simplify(problem, problem.init) cb1 = clear(b1) assert str(s.simplify_expression(land(cb1, neg(bot)))) == 'clear(b1)' @@ -100,6 +103,8 @@ def test_simplification_of_ex_quantification(): z = lang.variable('z', counter) two, three, six = [lang.constant(c, val_t) for c in (2, 3, 6)] + top = Tautology(problem.language) + phi = exists(z, land(x == z, top, value(z) < six)) assert simplify_existential_quantification(phi, inplace=False) == land(top, value(x) < six), \ "z has been replaced by x and removed from the quantification list, thus removing the quantifier" diff --git a/tests/io/test_fstrips_parsing.py b/tests/io/test_fstrips_parsing.py index f6957ec1..7ee09aee 100644 --- a/tests/io/test_fstrips_parsing.py +++ b/tests/io/test_fstrips_parsing.py @@ -4,7 +4,7 @@ from tarski.fstrips import AddEffect, FunctionalEffect from tarski.fstrips.errors import InvalidEffectError from tarski.io.fstrips import ParsingError, FstripsReader -from tarski.syntax import Atom, CompoundFormula, Tautology +from tarski.syntax import Atom, CompoundFormula, Tautology, Pass from tarski.syntax.util import get_symbols from tarski.theories import Theory @@ -320,5 +320,5 @@ def test_increase_effects(): ], r=reader(strict_with_requirements=False)) # This uses one single reader for all tests increase = output[1][0] - assert isinstance(increase, FunctionalEffect) and isinstance(increase.condition, Tautology) + assert isinstance(increase, FunctionalEffect) and isinstance(increase.condition, Pass) assert str(increase.rhs) == '+(total-cost(), 1)' diff --git a/tests/io/test_rddl_writer.py b/tests/io/test_rddl_writer.py index 6510adc2..474bd094 100644 --- a/tests/io/test_rddl_writer.py +++ b/tests/io/test_rddl_writer.py @@ -472,3 +472,139 @@ def test_parametrized_model_with_random_vars_and_waypoints_boolean(): assert mr_reader.rddl_model.domain.name == 'lqg_nav_2d_multi_unit_bool_waypoints' mr_reader.translate_rddl_model() assert mr_reader.language is not None + +def test_rddl_integration_academic_advising_example_write(): + lang = tarski.language('standard', [Theory.EQUALITY, Theory.ARITHMETIC, Theory.RANDOM]) + the_task = Task(lang, 'academic_advising', 'instance_001') + + the_task.requirements = [rddl.Requirements.REWARD_DET, rddl.Requirements.PRECONDITIONS] + + the_task.parameters.discount = 1.0 + the_task.parameters.horizon = 20 + the_task.parameters.max_nondef_actions = 1 + + #sorts + course = lang.sort('course') + + # variables + c = lang.variable('c', course) + c2 = lang.variable('c2', course) + + # non fluents + PREREQ = lang.predicate('PREREQ', course, course) + PRIOR_PROB_PASS_NO_PREREQ = lang.function('PRIOR_PROB_PASS_NO_PREREQ', course, lang.Real) + PRIOR_PROB_PASS = lang.function('PRIOR_PROB_PASS', course, lang.Real) + PROGRAM_REQUIREMENT = lang.predicate('PROGRAM_REQUIREMENT', course) + COURSE_COST = lang.function('COURSE_COST', lang.Real) + COURSE_RETAKE_COST = lang.function('COURSE_RETAKE_COST', lang.Real) + PROGRAM_INCOMPLETE_PENALTY = lang.function('PROGRAM_INCOMPLETE_PENALTY', lang.Real) + COURSES_PER_SEMESTER = lang.function('COURSES_PER_SEMESTER', lang.Real) + + # state fluents + passed = lang.predicate('passed', course) + taken = lang.predicate('taken', course) + + # action fluents + take_course = lang.predicate('take-course', course) + + one = lang.constant(1, lang.Real) + # cpfs + the_task.add_cpfs(passed(c), ite(take_course(c) & ~(exists(c2, PREREQ(c2, c))), + bernoulli(PRIOR_PROB_PASS_NO_PREREQ(c)), + ite(take_course(c), + bernoulli((one - PRIOR_PROB_PASS(c)) + * (sumterm(c2, (PREREQ(c2, c) & passed(c2)))) + / (one + sumterm(c2, (PREREQ(c2, c))))), + passed(c)))) + + the_task.add_cpfs(taken(c), taken(c) | take_course(c)) + + # cost function + the_task.reward = ( sumterm(c, COURSE_COST() * (take_course(c) & ~taken(c))) + + sumterm(c, COURSE_RETAKE_COST() * (take_course(c) & taken(c))) + + (PROGRAM_INCOMPLETE_PENALTY() * ~(forall(c, PROGRAM_REQUIREMENT(c) > passed(c))))) + + # constraints + the_task.add_constraint(forall(c, take_course(c) > ~passed(c)), rddl.ConstraintType.ACTION) + the_task.add_constraint(sumterm(c, take_course(c)) <= COURSES_PER_SEMESTER(), rddl.ConstraintType.ACTION) + + # fluent metadata + the_task.declare_state_fluent(passed(c), 'false') + the_task.declare_state_fluent(taken(c), 'false') + the_task.declare_action_fluent(take_course(c), 'false') + the_task.declare_non_fluent(PREREQ(c, c2), 'false') + the_task.declare_non_fluent(PRIOR_PROB_PASS_NO_PREREQ(c), 0.8) + the_task.declare_non_fluent(PRIOR_PROB_PASS(c), 0.2) + the_task.declare_non_fluent(PROGRAM_REQUIREMENT(c), 'false') + the_task.declare_non_fluent(COURSE_COST(), -1) + the_task.declare_non_fluent(COURSE_RETAKE_COST(), -2) + the_task.declare_non_fluent(PROGRAM_INCOMPLETE_PENALTY(), -5) + the_task.declare_non_fluent(COURSES_PER_SEMESTER(), 1) + + #constants + c0000 = lang.constant('c0000', course) + c0001 = lang.constant('c0001', course) + c0002 = lang.constant('c0002', course) + c0003 = lang.constant('c0003', course) + c0004 = lang.constant('c0004', course) + c0100 = lang.constant('c0100', course) + c0101 = lang.constant('c0101', course) + c0102 = lang.constant('c0102', course) + c0103 = lang.constant('c0103', course) + c0200 = lang.constant('c0200', course) + c0201 = lang.constant('c0201', course) + c0202 = lang.constant('c0202', course) + c0300 = lang.constant('c0300', course) + c0301 = lang.constant('c0301', course) + c0302 = lang.constant('c0302', course) + + the_task.x0.setx(PRIOR_PROB_PASS_NO_PREREQ(c0000), 0.80) + the_task.x0.setx(PRIOR_PROB_PASS_NO_PREREQ(c0001), 0.55) + the_task.x0.setx(PRIOR_PROB_PASS_NO_PREREQ(c0002), 0.67) + the_task.x0.setx(PRIOR_PROB_PASS_NO_PREREQ(c0003), 0.78) + the_task.x0.setx(PRIOR_PROB_PASS_NO_PREREQ(c0004), 0.75) + the_task.x0.setx(PRIOR_PROB_PASS(c0100), 0.22) + the_task.x0.setx(PRIOR_PROB_PASS(c0101), 0.45) + the_task.x0.setx(PRIOR_PROB_PASS(c0102), 0.41) + the_task.x0.setx(PRIOR_PROB_PASS(c0103), 0.44) + the_task.x0.setx(PRIOR_PROB_PASS(c0200), 0.14) + the_task.x0.setx(PRIOR_PROB_PASS(c0201), 0.07) + the_task.x0.setx(PRIOR_PROB_PASS(c0202), 0.24) + the_task.x0.setx(PRIOR_PROB_PASS(c0300), 0.23) + the_task.x0.setx(PRIOR_PROB_PASS(c0301), 0.08) + the_task.x0.setx(PRIOR_PROB_PASS(c0302), 0.16) + + + the_task.x0.add(PREREQ(c0003, c0100), 'true') + the_task.x0.add(PREREQ(c0000, c0100), 'true') + the_task.x0.add(PREREQ(c0004, c0100), 'true') + the_task.x0.add(PREREQ(c0001, c0101), 'true') + the_task.x0.add(PREREQ(c0002, c0101), 'true') + the_task.x0.add(PREREQ(c0000, c0102), 'true') + the_task.x0.add(PREREQ(c0004, c0102), 'true') + the_task.x0.add(PREREQ(c0001, c0103), 'true') + the_task.x0.add(PREREQ(c0001, c0200), 'true') + the_task.x0.add(PREREQ(c0101, c0200), 'true') + the_task.x0.add(PREREQ(c0103, c0201), 'true') + the_task.x0.add(PREREQ(c0002, c0202), 'true') + the_task.x0.add(PREREQ(c0200, c0300), 'true') + the_task.x0.add(PREREQ(c0201, c0301), 'true') + the_task.x0.add(PREREQ(c0201, c0301), 'true') + the_task.x0.add(PREREQ(c0200, c0302), 'true') + + the_task.x0.add(PROGRAM_REQUIREMENT(c0300), 'true') + the_task.x0.add(PROGRAM_REQUIREMENT(c0202), 'true') + the_task.x0.add(PROGRAM_REQUIREMENT(c0101), 'true') + the_task.x0.add(PROGRAM_REQUIREMENT(c0002), 'true') + the_task.x0.add(PROGRAM_REQUIREMENT(c0001), 'true') + + the_task.x0.add(passed(c0000), 'false') + + the_writer = rddl.Writer(the_task) + rddl_filename = os.path.join(tempfile.gettempdir(), 'academic_advising_001.rddl') + the_writer.write_model(rddl_filename) + mr_reader = rddl.Reader(rddl_filename) + assert mr_reader.rddl_model is not None + assert mr_reader.rddl_model.domain.name == 'academic_advising' + mr_reader.translate_rddl_model() + assert mr_reader.language is not None diff --git a/tests/rddl/test_boolean_term_interop.py b/tests/rddl/test_boolean_term_interop.py new file mode 100644 index 00000000..8d0352d8 --- /dev/null +++ b/tests/rddl/test_boolean_term_interop.py @@ -0,0 +1,145 @@ +import pytest +import tarski +from tarski.syntax import * +from tarski.theories import Theory + + +def test_bool_arithmetic_atom(): + lang = tarski.language(theories=[Theory.ARITHMETIC]) + f = lang.predicate('f') + phi = lang.constant(1, lang.Integer) * f() + assert isinstance(phi, Term), "phi should be a Term" + +def test_bool_arithmetic_compoundterm(): + lang = tarski.language(theories=[Theory.ARITHMETIC]) + f = lang.predicate('f') + phi = lang.constant(1, lang.Integer) * (lang.constant(0, lang.Integer) > lang.constant(1, lang.Integer)) + assert isinstance(phi, Term), "phi should be a Term" + +def test_bool_constants(): + lang = tarski.language(theories=[Theory.ARITHMETIC]) + true = lang.constant(1, lang.Boolean) + assert isinstance(true, Constant), "Boolean constants should be declareable" + +def test_complex_mixed_type_formula(): + lang = tarski.language(theories=[Theory.ARITHMETIC, Theory.EQUALITY]) + x = lang.predicate("x") + y = lang.predicate("y") + f = ((x == y) + lang.constant(1, lang.Integer)) + assert isinstance(f, CompoundTerm) , "should construct a correct CompoundTerm with mixed Bool and Int" + z = lang.predicate("z") + g = ((x() == y()) & z()) + assert isinstance(g, CompoundFormula), "should construct a CompoundFormula when allowed" + + a = lang.predicate("a") + b = lang.predicate("b") + g = ((x == y) + (a() & b())) + + assert isinstance(g, CompoundTerm), "should construct a correct CompoundTerm with mixed Bool and Int" + +def test_equiv_simple(): + lang = tarski.language(theories=[Theory.ARITHMETIC, Theory.EQUALITY]) + c1 = lang.constant(1, lang.Boolean) + c2 = lang.constant(1, lang.Boolean) + + v = lang.variable('v', lang.Boolean) + + a = (v == c1) + b = (v == c2) + + assert a.is_syntactically_equal(b) + +def test_equiv_mixed_type(): + lang = tarski.language(theories=[Theory.ARITHMETIC, Theory.EQUALITY]) + c1 = lang.constant(True, lang.Boolean) + c2 = lang.constant(True, lang.Boolean) + + v = lang.variable('v', lang.Boolean) + x = lang.variable('x', lang.Integer) + y = lang.variable('y', lang.Integer) + + a = (x < y) | (c1 & v) + b = (x > y) | (c1 & v) + assert not a.is_syntactically_equal(b) + +def test_quantified_over_bool_functions(): + lang = tarski.language(theories=[Theory.ARITHMETIC, Theory.BOOLEAN, Theory.EQUALITY]) + + c1 = lang.constant(True, lang.Boolean) + v = lang.variable('v', lang.Boolean) + x = lang.variable('x', lang.Integer) + y = lang.variable('y', lang.Integer) + b = (x > y) | (c1 & v) + + qf = tarski.syntax.formulas.forall(v,x,y, (b == 1)) + assert isinstance(qf, QuantifiedFormula), "should have built a quantified formula" + #todo: better test + with pytest.raises(Exception): + qf = tarski.syntax.formulas.forall(v,x,y, b + 5) #should raise an exception if top level type is not boolean + + +def test_if_then_else_boolean_codomain(): + lang = tarski.language(theories=[Theory.ARITHMETIC, Theory.BOOLEAN, Theory.EQUALITY]) + + v = lang.predicate('v') + x = lang.variable('x', lang.Integer) + y = lang.variable('y', lang.Integer) + + ite = tarski.syntax.ite(x < y, v(), lang.constant(0, lang.Boolean)) + assert isinstance(ite, Term), "if then else should be a term" + assert ite.sort == lang.Boolean, "if then else term should have boolean codomain" + +def test_if_then_else_numeric_codomain(): + lang = tarski.language(theories=[Theory.ARITHMETIC, Theory.BOOLEAN, Theory.EQUALITY]) + + v = lang.predicate('v') + x = lang.variable('x', lang.Integer) + y = lang.variable('y', lang.Integer) + + ite = tarski.syntax.ite(x < y, (x + 1), lang.constant(3, lang.Integer)) + assert isinstance(ite, Term), "if then else should be a term" + assert ite.sort == lang.Real, "if then else term should have numeric codomain" #todo: actually might need to be integer + + +def test_if_then_else_fails_on_conflicting_codomain(): + lang = tarski.language(theories=[Theory.ARITHMETIC, Theory.BOOLEAN, Theory.EQUALITY]) + + v = lang.variable('v', lang.Boolean) + x = lang.variable('x', lang.Integer) + y = lang.variable('y', lang.Integer) + + ite = tarski.syntax.ite(x < y, (x + 1), lang.constant(3, lang.Integer)) + + with pytest.raises(Exception): + ite = tarski.syntax.ite(x < y, lang.constant(True, lang.Boolean), lang.constant(3, lang.Integer)) + +def test_sum_operates_over_predicates(): + from tarski.syntax.arithmetic import sumterm + + lang = tarski.language(theories=[Theory.ARITHMETIC, Theory.EQUALITY]) + box_t = lang.sort("box") + v = lang.variable('v', box_t) + boxes = [lang.constant(f"b{k}", box_t) for k in range(5)] + + visible = lang.predicate('visible', box_t) + s = (sumterm(v, visible(v)) >= lang.constant(3, lang.Integer)) + assert isinstance(s, Formula), "after summing and making a logical comparison, we should have a formula" + + +def test_nested_quantification_in_ite(): + lang = tarski.language(theories=[Theory.ARITHMETIC, Theory.EQUALITY]) + + a_sort = lang.sort('a_sort') + v = lang.variable('v', a_sort) + vp = lang.variable('vp', a_sort) + p1 = lang.predicate('predicate1', a_sort) + p2 = lang.predicate('predicate2', a_sort) + p3 = lang.predicate('predicate3') + weight = lang.function('weight', lang.Real) + nested = ite(p1(v) & ~(exists(vp, p2(vp))), + weight(), + 1 - (weight() * p3())) + assert isinstance(nested, Term) + assert isinstance(nested.condition, Formula) + assert nested.subterms[0].sort == lang.Real + assert nested.subterms[1].sort == lang.Real