diff --git a/include/loki/details/pddl/declarations.hpp b/include/loki/details/pddl/declarations.hpp index c9f16b58..4addb5e0 100644 --- a/include/loki/details/pddl/declarations.hpp +++ b/include/loki/details/pddl/declarations.hpp @@ -20,6 +20,7 @@ #include #include +#include #include #include @@ -42,6 +43,7 @@ using Requirements = const RequirementsImpl*; class TypeImpl; using Type = const TypeImpl*; using TypeList = std::vector; +using TypeSet = std::unordered_set; class ObjectImpl; using Object = const ObjectImpl*; diff --git a/include/loki/details/pddl/exceptions.hpp b/include/loki/details/pddl/exceptions.hpp index 3b378a9c..33aa539f 100644 --- a/include/loki/details/pddl/exceptions.hpp +++ b/include/loki/details/pddl/exceptions.hpp @@ -155,6 +155,12 @@ class MismatchedFunctionSkeletonTermListError : public SemanticParserError MismatchedFunctionSkeletonTermListError(const FunctionSkeleton& function_skeleton, const TermList& term_list, const std::string& error_handler_output); }; +class IncompatibleObjectTypeError : public SemanticParserError +{ +public: + IncompatibleObjectTypeError(const Object& object, const Variable& variable, const std::string& error_handler_output); +}; + class UnexpectedDerivedPredicateInEffect : public SemanticParserError { public: diff --git a/include/loki/details/pddl/scope.hpp b/include/loki/details/pddl/scope.hpp index b72286f3..b4e7ea1a 100644 --- a/include/loki/details/pddl/scope.hpp +++ b/include/loki/details/pddl/scope.hpp @@ -43,39 +43,34 @@ namespace loki /// The position points to the matched location /// in the input stream and is used for error reporting. template -using BindingValueType = std::tuple, std::optional>; +using BindingValueType = std::pair>; -/// @brief Datastructure to store bindings of a type T. +/// @brief Encapsulates the result of search for a binding with the corresponding ErrorHandler. template -using BindingMapType = std::unordered_map>; - -/// @brief Encapsulates bindings for different types. -template -class Bindings -{ -private: - std::tuple...> bindings; - -public: - /// @brief Returns a binding if it exists. - template - std::optional> get(const std::string& key) const; +using BindingSearchResult = std::tuple, const PDDLErrorHandler&>; - /// @brief Inserts a binding of type T - template - void insert(const std::string& key, const PDDLElement& binding, const std::optional& position); -}; +/// @brief Datastructure to store bindings of a type T. +template +using Bindings = std::unordered_map>; /// @brief Wraps bindings in a scope with reference to a parent scope. class Scope { private: + const PDDLErrorHandler& m_error_handler; const Scope* m_parent_scope; - Bindings bindings; + Bindings m_types; + Bindings m_objects; + Bindings m_function_skeletons; + Bindings m_variables; + Bindings m_predicates; + Bindings m_derived_predicates; + + std::unordered_map m_variable_types; public: - explicit Scope(const Scope* parent_scope = nullptr); + Scope(const PDDLErrorHandler& error_handler, const Scope* parent_scope = nullptr); // delete copy and move to avoid dangling references. Scope(const Scope& other) = delete; @@ -83,19 +78,28 @@ class Scope Scope(Scope&& other) = delete; Scope& operator=(Scope&& other) = delete; - /// @brief Returns a binding if it exists. - template - std::optional> get(const std::string& name) const; + /// @brief Return a binding if it exists. + std::optional> get_type(const std::string& name) const; + std::optional> get_object(const std::string& name) const; + std::optional> get_function_skeleton(const std::string& name) const; + std::optional> get_variable(const std::string& name) const; + std::optional> get_predicate(const std::string& name) const; + std::optional> get_derived_predicate(const std::string& name) const; + const TypeSet& get_variable_types(const Variable& variable) const; + + /// @brief Insert a binding. + void insert_type(const std::string& name, const Type& type, const std::optional& position); + void insert_object(const std::string& name, const Object& object, const std::optional& position); + void insert_function_skeleton(const std::string& name, const FunctionSkeleton& function_skeleton, const std::optional& position); + void insert_variable(const std::string& name, const Variable& variable, const std::optional& position); + void insert_predicate(const std::string& name, const Predicate& predicate, const std::optional& position); + void insert_derived_predicate(const std::string& name, const Predicate& derived_predicate, const std::optional& position); + void insert_variable_types(const Variable& variable, const TypeSet& types); - /// @brief Insert a binding of type T. - template - void insert(const std::string& name, const PDDLElement& element, const std::optional& position); + /// @brief Get the error handler to print an error message. + const PDDLErrorHandler& get_error_handler() const; }; -/// @brief Encapsulates the result of search for a binding with the corresponding ErrorHandler. -template -using ScopeStackSearchResult = std::tuple, const std::optional, const PDDLErrorHandler&>; - /// @brief Implements a scoping mechanism to store bindings which are mappings from name to a pointer to a PDDL object /// type and a position in the input stream that can be used to construct error messages with the given ErrorHandler. /// @@ -113,12 +117,11 @@ using ScopeStackSearchResult = std::tuple, const std::optio class ScopeStack { private: - std::deque> m_stack; - const PDDLErrorHandler& m_error_handler; - const ScopeStack* m_parent; + std::deque> m_stack; + public: ScopeStack(const PDDLErrorHandler& error_handler, const ScopeStack* parent = nullptr); @@ -134,16 +137,9 @@ class ScopeStack /// @brief Deletes the topmost scope from the stack. void close_scope(); - /// @brief Returns a binding if it exists. - template - std::optional> get(const std::string& name) const; - - /// @brief Insert a binding of type T. - template - void insert(const std::string& name, const PDDLElement& element, const std::optional& position); - - /// @brief Get the error handler to print an error message. - const PDDLErrorHandler& get_error_handler() const; + /// @brief Return a binding if it exists. + Scope& top(); + const Scope& top() const; // For testing purposes only. const std::deque>& get_stack() const; @@ -151,6 +147,4 @@ class ScopeStack } -#include "scope.tpp" - #endif \ No newline at end of file diff --git a/include/loki/details/pddl/scope.tpp b/include/loki/details/pddl/scope.tpp deleted file mode 100644 index 2e2561cc..00000000 --- a/include/loki/details/pddl/scope.tpp +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright (C) 2023 Dominik Drexler - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ - -namespace loki -{ - -template -template -std::optional> Bindings::get(const std::string& key) const -{ - const auto& t_bindings = std::get>(bindings); - auto it = t_bindings.find(key); - if (it != t_bindings.end()) - { - return { it->second }; - } - return std::nullopt; -} - -template -template -void Bindings::insert(const std::string& key, const PDDLElement& element, const std::optional& position) -{ - assert(element); - auto& t_bindings = std::get>(bindings); - assert(!t_bindings.count(key)); - t_bindings.emplace(key, std::make_tuple(element, position)); -} - -template -std::optional> Scope::get(const std::string& name) const -{ - const auto result = bindings.get(name); - if (result.has_value()) - return result.value(); - if (m_parent_scope) - { - return m_parent_scope->get(name); - } - return std::nullopt; -} - -template -void Scope::insert(const std::string& name, const PDDLElement& element, const std::optional& position) -{ - assert(element); - assert(!this->get(name)); - bindings.insert(name, element, position); -} - -template -std::optional> ScopeStack::get(const std::string& name) const -{ - assert(!m_stack.empty()); - auto result = m_stack.back()->get(name); - if (result.has_value()) - { - return std::make_tuple(std::get<0>(result.value()), std::get<1>(result.value()), std::cref(m_error_handler)); - } - if (m_parent) - return m_parent->get(name); - return std::nullopt; -} - -/// @brief Insert a binding of type T. -template -void ScopeStack::insert(const std::string& name, const PDDLElement& element, const std::optional& position) -{ - assert(!m_stack.empty()); - m_stack.back()->insert(name, element, position); -} - -} \ No newline at end of file diff --git a/include/loki/details/pddl/type.hpp b/include/loki/details/pddl/type.hpp index e2fc724e..cf10b72f 100644 --- a/include/loki/details/pddl/type.hpp +++ b/include/loki/details/pddl/type.hpp @@ -49,6 +49,10 @@ class TypeImpl : public Base const std::string& get_name() const; const TypeList& get_bases() const; }; + +/// @brief Collects all types from a hierarchy. +extern TypeSet collect_types_from_hierarchy(const TypeList& types); + } #endif diff --git a/src/pddl/exceptions.cpp b/src/pddl/exceptions.cpp index 13e2a2c3..de5d6e4d 100644 --- a/src/pddl/exceptions.cpp +++ b/src/pddl/exceptions.cpp @@ -21,7 +21,9 @@ #include "loki/details/pddl/function.hpp" #include "loki/details/pddl/function_skeleton.hpp" #include "loki/details/pddl/object.hpp" +#include "loki/details/pddl/parameter.hpp" #include "loki/details/pddl/predicate.hpp" +#include "loki/details/pddl/variable.hpp" #include @@ -152,6 +154,13 @@ UnexpectedDerivedPredicateInEffect::UnexpectedDerivedPredicateInEffect(const std { } +IncompatibleObjectTypeError::IncompatibleObjectTypeError(const Object& object, const Variable& variable, const std::string& error_handler_output) : + SemanticParserError("The object with name \"" + object->get_name() + "\" does not satisfy the type requirement of variable with name \"" + + variable->get_name() + "\".", + error_handler_output) +{ +} + ExpectedDerivedPredicate::ExpectedDerivedPredicate(const std::string& name, const std::string& error_handler_output) : SemanticParserError("The predicate with name \"" + name + "\" is not a derived predicate.", error_handler_output) { diff --git a/src/pddl/parser.cpp b/src/pddl/parser.cpp index 46a96286..df7f04da 100644 --- a/src/pddl/parser.cpp +++ b/src/pddl/parser.cpp @@ -64,7 +64,7 @@ Domain parse(const ast::Domain& domain_node, Context& context) { if (!context.requirements->test(RequirementEnum::TYPING)) { - throw UndefinedRequirementError(RequirementEnum::TYPING, context.scopes.get_error_handler()(domain_node.types.value(), "")); + throw UndefinedRequirementError(RequirementEnum::TYPING, context.scopes.top().get_error_handler()(domain_node.types.value(), "")); } types = parse(domain_node.types.value(), context); } @@ -127,7 +127,7 @@ Problem parse(const ast::Problem& problem_node, Context& context, const Domain& const auto domain_name = parse(problem_node.domain_name.name); if (domain_name != domain->get_name()) { - throw MismatchedDomainError(domain, domain_name, context.scopes.get_error_handler()(problem_node.domain_name, "")); + throw MismatchedDomainError(domain, domain_name, context.scopes.top().get_error_handler()(problem_node.domain_name, "")); } /* Problem name section */ const auto problem_name = parse(problem_node.problem_name.name); diff --git a/src/pddl/parser/common.cpp b/src/pddl/parser/common.cpp index 768f17ae..1cba78ad 100644 --- a/src/pddl/parser/common.cpp +++ b/src/pddl/parser/common.cpp @@ -45,10 +45,10 @@ Term TermDeclarationTermVisitor::operator()(const ast::Name& node) const { const auto constant_name = parse(node); // Test for undefined constant. - const auto binding = context.scopes.get(constant_name); + const auto binding = context.scopes.top().get_object(constant_name); if (!binding.has_value()) { - throw UndefinedConstantError(constant_name, context.scopes.get_error_handler()(node, "")); + throw UndefinedConstantError(constant_name, context.scopes.top().get_error_handler()(node, "")); } // Constant are not tracked and hence must not be untracked. // Construct Term and return it @@ -63,17 +63,17 @@ Term TermDeclarationTermVisitor::operator()(const ast::Variable& node) const { const auto variable = parse(node, context); // Test for multiple definition - const auto binding = context.scopes.get(variable->get_name()); + const auto binding = context.scopes.top().get_variable(variable->get_name()); if (binding.has_value()) { - const auto message_1 = context.scopes.get_error_handler()(node, "Defined here:"); + const auto message_1 = context.scopes.top().get_error_handler()(node, "Defined here:"); const auto [_constant, position, error_handler] = binding.value(); assert(position.has_value()); const auto message_2 = error_handler(position.value(), "First defined here:"); throw MultiDefinitionVariableError(variable->get_name(), message_1 + message_2); } // Add binding to scope - context.scopes.insert(variable->get_name(), variable, node); + context.scopes.top().insert_variable(variable->get_name(), variable, node); // Construct Term and return it const auto term = context.factories.get_or_create_term_variable(variable); // Add position of PDDL object @@ -87,10 +87,10 @@ Term TermReferenceTermVisitor::operator()(const ast::Name& node) const { const auto object_name = parse(node); // Test for undefined constant. - const auto binding = context.scopes.get(object_name); + const auto binding = context.scopes.top().get_object(object_name); if (!binding.has_value()) { - throw UndefinedConstantError(object_name, context.scopes.get_error_handler()(node, "")); + throw UndefinedConstantError(object_name, context.scopes.top().get_error_handler()(node, "")); } // Construct Term and return it const auto [object, _position, _error_handler] = binding.value(); @@ -105,10 +105,10 @@ Term TermReferenceTermVisitor::operator()(const ast::Variable& node) const { const auto variable = parse(node, context); // Test for undefined variable - const auto binding = context.scopes.get(variable->get_name()); + const auto binding = context.scopes.top().get_variable(variable->get_name()); if (!binding.has_value()) { - throw UndefinedVariableError(variable->get_name(), context.scopes.get_error_handler()(node, "")); + throw UndefinedVariableError(variable->get_name(), context.scopes.top().get_error_handler()(node, "")); } // Construct Term and return it const auto term = context.factories.get_or_create_term_variable(variable); diff --git a/src/pddl/parser/conditions.cpp b/src/pddl/parser/conditions.cpp index 958373e0..705c12f3 100644 --- a/src/pddl/parser/conditions.cpp +++ b/src/pddl/parser/conditions.cpp @@ -50,7 +50,7 @@ Condition parse(const ast::GoalDescriptorLiteral& node, Context& context) // requires :negative-preconditions if (!context.requirements->test(RequirementEnum::NEGATIVE_PRECONDITIONS)) { - throw UndefinedRequirementError(RequirementEnum::NEGATIVE_PRECONDITIONS, context.scopes.get_error_handler()(node, "")); + throw UndefinedRequirementError(RequirementEnum::NEGATIVE_PRECONDITIONS, context.scopes.top().get_error_handler()(node, "")); } context.references.untrack(RequirementEnum::NEGATIVE_PRECONDITIONS); const auto condition = context.factories.get_or_create_condition_literal(parse(node.literal, context)); @@ -71,7 +71,7 @@ Condition parse(const ast::GoalDescriptorOr& node, Context& context) // requires :disjunctive-preconditions if (!context.requirements->test(RequirementEnum::DISJUNCTIVE_PRECONDITIONS)) { - throw UndefinedRequirementError(RequirementEnum::DISJUNCTIVE_PRECONDITIONS, context.scopes.get_error_handler()(node, "")); + throw UndefinedRequirementError(RequirementEnum::DISJUNCTIVE_PRECONDITIONS, context.scopes.top().get_error_handler()(node, "")); } context.references.untrack(RequirementEnum::DISJUNCTIVE_PRECONDITIONS); auto condition_list = parse(node.goal_descriptors, context); @@ -85,7 +85,7 @@ Condition parse(const ast::GoalDescriptorNot& node, Context& context) // requires :disjunctive-preconditions if (!context.requirements->test(RequirementEnum::NEGATIVE_PRECONDITIONS)) { - throw UndefinedRequirementError(RequirementEnum::NEGATIVE_PRECONDITIONS, context.scopes.get_error_handler()(node, "")); + throw UndefinedRequirementError(RequirementEnum::NEGATIVE_PRECONDITIONS, context.scopes.top().get_error_handler()(node, "")); } context.references.untrack(RequirementEnum::NEGATIVE_PRECONDITIONS); auto child_condition = parse(node.goal_descriptor, context); @@ -98,7 +98,7 @@ Condition parse(const ast::GoalDescriptorImply& node, Context& context) { if (!context.requirements->test(RequirementEnum::DISJUNCTIVE_PRECONDITIONS)) { - throw UndefinedRequirementError(RequirementEnum::DISJUNCTIVE_PRECONDITIONS, context.scopes.get_error_handler()(node, "")); + throw UndefinedRequirementError(RequirementEnum::DISJUNCTIVE_PRECONDITIONS, context.scopes.top().get_error_handler()(node, "")); } context.references.untrack(RequirementEnum::DISJUNCTIVE_PRECONDITIONS); auto condition_left = parse(node.goal_descriptor_left, context); @@ -112,7 +112,7 @@ Condition parse(const ast::GoalDescriptorExists& node, Context& context) { if (!context.requirements->test(RequirementEnum::EXISTENTIAL_PRECONDITIONS)) { - throw UndefinedRequirementError(RequirementEnum::EXISTENTIAL_PRECONDITIONS, context.scopes.get_error_handler()(node, "")); + throw UndefinedRequirementError(RequirementEnum::EXISTENTIAL_PRECONDITIONS, context.scopes.top().get_error_handler()(node, "")); } context.references.untrack(RequirementEnum::EXISTENTIAL_PRECONDITIONS); context.scopes.open_scope(); @@ -144,7 +144,7 @@ Condition parse(const ast::GoalDescriptorForall& node, Context& context) { if (!context.requirements->test(RequirementEnum::UNIVERSAL_PRECONDITIONS)) { - throw UndefinedRequirementError(RequirementEnum::UNIVERSAL_PRECONDITIONS, context.scopes.get_error_handler()(node, "")); + throw UndefinedRequirementError(RequirementEnum::UNIVERSAL_PRECONDITIONS, context.scopes.top().get_error_handler()(node, "")); } context.references.untrack(RequirementEnum::UNIVERSAL_PRECONDITIONS); return parse_condition_forall(node.typed_list_of_variables, node.goal_descriptor, context); @@ -154,7 +154,7 @@ Condition parse(const ast::GoalDescriptorFunctionComparison& node, Context& cont { if (!context.requirements->test(RequirementEnum::NUMERIC_FLUENTS)) { - throw UndefinedRequirementError(RequirementEnum::NUMERIC_FLUENTS, context.scopes.get_error_handler()(node, "")); + throw UndefinedRequirementError(RequirementEnum::NUMERIC_FLUENTS, context.scopes.top().get_error_handler()(node, "")); } throw NotImplementedError("parse(const ast::GoalDescriptorFunctionComparison& node, Context& context)"); } @@ -177,7 +177,7 @@ Condition parse(const ast::ConstraintGoalDescriptorForall& node, Context& contex { if (!context.requirements->test(RequirementEnum::UNIVERSAL_PRECONDITIONS)) { - throw UndefinedRequirementError(RequirementEnum::UNIVERSAL_PRECONDITIONS, context.scopes.get_error_handler()(node, "")); + throw UndefinedRequirementError(RequirementEnum::UNIVERSAL_PRECONDITIONS, context.scopes.top().get_error_handler()(node, "")); } context.references.untrack(RequirementEnum::UNIVERSAL_PRECONDITIONS); return parse_condition_forall(node.typed_list_of_variables, node.constraint_goal_descriptor, context); @@ -253,7 +253,7 @@ Condition parse(const ast::PreconditionGoalDescriptorPreference& node, Context& { if (!context.requirements->test(RequirementEnum::PREFERENCES)) { - throw UndefinedRequirementError(RequirementEnum::PREFERENCES, context.scopes.get_error_handler()(node, "")); + throw UndefinedRequirementError(RequirementEnum::PREFERENCES, context.scopes.top().get_error_handler()(node, "")); } throw NotImplementedError("parse(const ast::PreconditionGoalDescriptorPreference& node, Context& context)"); } @@ -262,7 +262,7 @@ Condition parse(const ast::PreconditionGoalDescriptorForall& node, Context& cont { if (!context.requirements->test(RequirementEnum::UNIVERSAL_PRECONDITIONS)) { - throw UndefinedRequirementError(RequirementEnum::UNIVERSAL_PRECONDITIONS, context.scopes.get_error_handler()(node, "")); + throw UndefinedRequirementError(RequirementEnum::UNIVERSAL_PRECONDITIONS, context.scopes.top().get_error_handler()(node, "")); } context.references.untrack(RequirementEnum::UNIVERSAL_PRECONDITIONS); return parse_condition_forall(node.typed_list_of_variables, node.precondition_goal_descriptor, context); diff --git a/src/pddl/parser/constants.cpp b/src/pddl/parser/constants.cpp index 05057a53..5f45b33a 100644 --- a/src/pddl/parser/constants.cpp +++ b/src/pddl/parser/constants.cpp @@ -28,10 +28,10 @@ namespace loki static void test_multiple_definition(const Object& constant, const ast::Name& node, const Context& context) { const auto constant_name = constant->get_name(); - const auto binding = context.scopes.get(constant_name); + const auto binding = context.scopes.top().get_object(constant_name); if (binding.has_value()) { - const auto message_1 = context.scopes.get_error_handler()(node, "Defined here:"); + const auto message_1 = context.scopes.top().get_error_handler()(node, "Defined here:"); auto message_2 = std::string(""); const auto [_object, position, error_handler] = binding.value(); if (position.has_value()) @@ -45,7 +45,7 @@ static void test_multiple_definition(const Object& constant, const ast::Name& no static void insert_context_information(const Object& constant, const ast::Name& node, Context& context) { context.positions.push_back(constant, node); - context.scopes.insert(constant->get_name(), constant, node); + context.scopes.top().insert_object(constant->get_name(), constant, node); } static Object parse_constant_definition(const ast::Name& node, const TypeList& type_list, Context& context) @@ -76,8 +76,8 @@ ConstantListVisitor::ConstantListVisitor(Context& context_) : context(context_) ObjectList ConstantListVisitor::operator()(const std::vector& name_nodes) { // std::vector has single base type "object" - assert(context.scopes.get("object").has_value()); - const auto [type, _position, _error_handler] = context.scopes.get("object").value(); + assert(context.scopes.top().get_type("object").has_value()); + const auto [type, _position, _error_handler] = context.scopes.top().get_type("object").value(); return parse_constant_definitions(name_nodes, TypeList { type }, context); } @@ -85,7 +85,7 @@ ObjectList ConstantListVisitor::operator()(const ast::TypedListOfNamesRecursivel { if (!context.requirements->test(RequirementEnum::TYPING)) { - throw UndefinedRequirementError(RequirementEnum::TYPING, context.scopes.get_error_handler()(typed_list_of_names_recursively_node, "")); + throw UndefinedRequirementError(RequirementEnum::TYPING, context.scopes.top().get_error_handler()(typed_list_of_names_recursively_node, "")); } context.references.untrack(RequirementEnum::TYPING); const auto type_list = boost::apply_visitor(TypeReferenceTypeVisitor(context), typed_list_of_names_recursively_node.type); diff --git a/src/pddl/parser/effects.cpp b/src/pddl/parser/effects.cpp index b25fee8a..ebfea369 100644 --- a/src/pddl/parser/effects.cpp +++ b/src/pddl/parser/effects.cpp @@ -71,7 +71,7 @@ Effect parse(const ast::EffectProductionLiteral& node, Context& context) if (context.derived_predicates.count(literal->get_atom()->get_predicate())) { - throw UnexpectedDerivedPredicateInEffect(literal->get_atom()->get_predicate()->get_name(), context.scopes.get_error_handler()(node, "")); + throw UnexpectedDerivedPredicateInEffect(literal->get_atom()->get_predicate()->get_name(), context.scopes.top().get_error_handler()(node, "")); } context.positions.push_back(effect, node); @@ -82,16 +82,16 @@ Effect parse(const ast::EffectProductionNumericFluentTotalCost& node, Context& c { if (!context.requirements->test(RequirementEnum::ACTION_COSTS)) { - throw UndefinedRequirementError(RequirementEnum::ACTION_COSTS, context.scopes.get_error_handler()(node, "")); + throw UndefinedRequirementError(RequirementEnum::ACTION_COSTS, context.scopes.top().get_error_handler()(node, "")); } context.references.untrack(RequirementEnum::ACTION_COSTS); const auto assign_operator_increase = parse(node.assign_operator_increase); auto function_name = parse(node.function_symbol_total_cost.name); assert(function_name == "total-cost"); - auto binding = context.scopes.get(function_name); + auto binding = context.scopes.top().get_function_skeleton(function_name); if (!binding.has_value()) { - throw UndefinedFunctionSkeletonError(function_name, context.scopes.get_error_handler()(node.function_symbol_total_cost, "")); + throw UndefinedFunctionSkeletonError(function_name, context.scopes.top().get_error_handler()(node.function_symbol_total_cost, "")); } const auto [function_skeleton, _position, _error_handler] = binding.value(); const auto function = context.factories.get_or_create_function(function_skeleton, TermList {}); @@ -106,7 +106,7 @@ Effect parse(const ast::EffectProductionNumericFluentGeneral& node, Context& con { if (!context.requirements->test(RequirementEnum::NUMERIC_FLUENTS)) { - throw UndefinedRequirementError(RequirementEnum::NUMERIC_FLUENTS, context.scopes.get_error_handler()(node, "")); + throw UndefinedRequirementError(RequirementEnum::NUMERIC_FLUENTS, context.scopes.top().get_error_handler()(node, "")); } context.references.untrack(RequirementEnum::NUMERIC_FLUENTS); const auto assign_operator = parse(node.assign_operator); @@ -149,7 +149,7 @@ Effect parse(const ast::EffectConditional& node, Context& context) // requires :conditional-effects if (!context.requirements->test(RequirementEnum::CONDITIONAL_EFFECTS)) { - throw UndefinedRequirementError(RequirementEnum::CONDITIONAL_EFFECTS, context.scopes.get_error_handler()(node, "")); + throw UndefinedRequirementError(RequirementEnum::CONDITIONAL_EFFECTS, context.scopes.top().get_error_handler()(node, "")); } context.references.untrack(RequirementEnum::CONDITIONAL_EFFECTS); const auto effect = boost::apply_visitor(EffectVisitor(context), node); diff --git a/src/pddl/parser/functions.cpp b/src/pddl/parser/functions.cpp index 1d158f97..a6585d09 100644 --- a/src/pddl/parser/functions.cpp +++ b/src/pddl/parser/functions.cpp @@ -95,7 +95,7 @@ Function parse(const ast::FunctionHead& node, Context& context) } if (function_skeleton->get_parameters().size() != term_list.size()) { - throw MismatchedFunctionSkeletonTermListError(function_skeleton, term_list, context.scopes.get_error_handler()(node, "")); + throw MismatchedFunctionSkeletonTermListError(function_skeleton, term_list, context.scopes.top().get_error_handler()(node, "")); } const auto function = context.factories.get_or_create_function(function_skeleton, term_list); context.positions.push_back(function, node); @@ -107,10 +107,10 @@ Function parse(const ast::FunctionHead& node, Context& context) FunctionSkeleton parse_function_skeleton_reference(const ast::FunctionSymbol& node, Context& context) { auto function_name = parse(node.name); - auto binding = context.scopes.get(function_name); + auto binding = context.scopes.top().get_function_skeleton(function_name); if (!binding.has_value()) { - throw UndefinedFunctionSkeletonError(function_name, context.scopes.get_error_handler()(node, "")); + throw UndefinedFunctionSkeletonError(function_name, context.scopes.top().get_error_handler()(node, "")); } const auto [function_skeleton, _position, _error_handler] = binding.value(); context.references.untrack(function_skeleton); @@ -120,10 +120,10 @@ FunctionSkeleton parse_function_skeleton_reference(const ast::FunctionSymbol& no static void test_multiple_definition(const FunctionSkeleton& function_skeleton, const ast::Name& node, const Context& context) { const auto function_name = function_skeleton->get_name(); - const auto binding = context.scopes.get(function_name); + const auto binding = context.scopes.top().get_function_skeleton(function_name); if (binding.has_value()) { - const auto message_1 = context.scopes.get_error_handler()(node, "Defined here:"); + const auto message_1 = context.scopes.top().get_error_handler()(node, "Defined here:"); auto message_2 = std::string(""); const auto [_function_skeleton, position, error_handler] = binding.value(); if (position.has_value()) @@ -137,7 +137,7 @@ static void test_multiple_definition(const FunctionSkeleton& function_skeleton, static void insert_context_information(const FunctionSkeleton& function_skeleton, const ast::Name& node, Context& context) { context.positions.push_back(function_skeleton, node); - context.scopes.insert(function_skeleton->get_name(), function_skeleton, node); + context.scopes.top().insert_function_skeleton(function_skeleton->get_name(), function_skeleton, node); } FunctionSkeleton parse(const ast::AtomicFunctionSkeletonTotalCost& node, Context& context) @@ -153,8 +153,8 @@ FunctionSkeleton parse(const ast::AtomicFunctionSkeletonTotalCost& node, Context context.references.untrack(RequirementEnum::ACTION_COSTS); context.references.untrack(RequirementEnum::NUMERIC_FLUENTS); - assert(context.scopes.get("number").has_value()); - const auto [type, _position, _error_handler] = context.scopes.get("number").value(); + assert(context.scopes.top().get_type("number").has_value()); + const auto [type, _position, _error_handler] = context.scopes.top().get_type("number").value(); auto function_name = parse(node.function_symbol.name); auto function_skeleton = context.factories.get_or_create_function_skeleton(function_name, ParameterList {}, type); @@ -181,8 +181,8 @@ FunctionSkeleton parse(const ast::AtomicFunctionSkeletonGeneral& node, Context& auto function_parameters = boost::apply_visitor(ParameterListVisitor(context), node.arguments); context.scopes.close_scope(); - assert(context.scopes.get("number").has_value()); - const auto [type, _position, _error_handler] = context.scopes.get("number").value(); + assert(context.scopes.top().get_type("number").has_value()); + const auto [type, _position, _error_handler] = context.scopes.top().get_type("number").value(); auto function_name = parse(node.function_symbol.name); auto function_skeleton = context.factories.get_or_create_function_skeleton(function_name, function_parameters, type); @@ -313,7 +313,7 @@ Function parse(const ast::BasicFunctionTerm& node, Context& context) } if (function_skeleton->get_parameters().size() != term_list.size()) { - throw MismatchedFunctionSkeletonTermListError(function_skeleton, term_list, context.scopes.get_error_handler()(node, "")); + throw MismatchedFunctionSkeletonTermListError(function_skeleton, term_list, context.scopes.top().get_error_handler()(node, "")); } const auto function = context.factories.get_or_create_function(function_skeleton, term_list); context.positions.push_back(function, node); diff --git a/src/pddl/parser/ground_literal.cpp b/src/pddl/parser/ground_literal.cpp index f67d9028..3dd9b01f 100644 --- a/src/pddl/parser/ground_literal.cpp +++ b/src/pddl/parser/ground_literal.cpp @@ -28,10 +28,10 @@ namespace loki Atom parse(const ast::AtomicFormulaOfNamesPredicate& node, Context& context) { const auto name = parse(node.predicate.name); - const auto binding = context.scopes.get(name); + const auto binding = context.scopes.top().get_predicate(name); if (!binding.has_value()) { - throw UndefinedPredicateError(name, context.scopes.get_error_handler()(node, "")); + throw UndefinedPredicateError(name, context.scopes.top().get_error_handler()(node, "")); } auto term_list = TermList(); for (const auto& name_node : node.names) @@ -41,7 +41,7 @@ Atom parse(const ast::AtomicFormulaOfNamesPredicate& node, Context& context) const auto [predicate, _position, _error_handler] = binding.value(); if (predicate->get_parameters().size() != term_list.size()) { - throw MismatchedPredicateTermListError(predicate, term_list, context.scopes.get_error_handler()(node, "")); + throw MismatchedPredicateTermListError(predicate, term_list, context.scopes.top().get_error_handler()(node, "")); } const auto atom = context.factories.get_or_create_atom(predicate, term_list); context.positions.push_back(atom, node); @@ -52,10 +52,10 @@ Atom parse(const ast::AtomicFormulaOfNamesEquality& node, Context& context) { if (!context.requirements->test(RequirementEnum::EQUALITY)) { - throw UndefinedRequirementError(RequirementEnum::EQUALITY, context.scopes.get_error_handler()(node, "")); + throw UndefinedRequirementError(RequirementEnum::EQUALITY, context.scopes.top().get_error_handler()(node, "")); } - assert(context.scopes.get("=").has_value()); - const auto [equal_predicate, _position, _error_handler] = context.scopes.get("=").value(); + assert(context.scopes.top().get_predicate("=").has_value()); + const auto [equal_predicate, _position, _error_handler] = context.scopes.top().get_predicate("=").value(); const auto term_left = context.factories.get_or_create_term_object(parse_object_reference(node.name_left, context)); const auto term_right = context.factories.get_or_create_term_object(parse_object_reference(node.name_right, context)); const auto atom = context.factories.get_or_create_atom(equal_predicate, TermList { term_left, term_right }); diff --git a/src/pddl/parser/literal.cpp b/src/pddl/parser/literal.cpp index 65a99f2f..a4c95ddc 100644 --- a/src/pddl/parser/literal.cpp +++ b/src/pddl/parser/literal.cpp @@ -26,10 +26,10 @@ namespace loki Atom parse(const ast::AtomicFormulaOfTermsPredicate& node, Context& context) { auto predicate_name = parse(node.predicate.name); - auto binding = context.scopes.get(predicate_name); + auto binding = context.scopes.top().get_predicate(predicate_name); if (!binding.has_value()) { - throw UndefinedPredicateError(predicate_name, context.scopes.get_error_handler()(node.predicate, "")); + throw UndefinedPredicateError(predicate_name, context.scopes.top().get_error_handler()(node.predicate, "")); } auto term_list = TermList(); for (const auto& term_node : node.terms) @@ -39,7 +39,7 @@ Atom parse(const ast::AtomicFormulaOfTermsPredicate& node, Context& context) const auto [predicate, _position, _error_handler] = binding.value(); if (predicate->get_parameters().size() != term_list.size()) { - throw MismatchedPredicateTermListError(predicate, term_list, context.scopes.get_error_handler()(node, "")); + throw MismatchedPredicateTermListError(predicate, term_list, context.scopes.top().get_error_handler()(node, "")); } context.references.untrack(predicate); const auto atom = context.factories.get_or_create_atom(predicate, term_list); @@ -52,11 +52,11 @@ Atom parse(const ast::AtomicFormulaOfTermsEquality& node, Context& context) // requires :equality if (!context.requirements->test(RequirementEnum::EQUALITY)) { - throw UndefinedRequirementError(RequirementEnum::EQUALITY, context.scopes.get_error_handler()(node, "")); + throw UndefinedRequirementError(RequirementEnum::EQUALITY, context.scopes.top().get_error_handler()(node, "")); } context.references.untrack(RequirementEnum::EQUALITY); - assert(context.scopes.get("=").has_value()); - const auto [equal_predicate, _position, _error_handler] = context.scopes.get("=").value(); + assert(context.scopes.top().get_predicate("=").has_value()); + const auto [equal_predicate, _position, _error_handler] = context.scopes.top().get_predicate("=").value(); auto left_term = boost::apply_visitor(TermReferenceTermVisitor(context), node.term_left); auto right_term = boost::apply_visitor(TermReferenceTermVisitor(context), node.term_right); const auto atom = context.factories.get_or_create_atom(equal_predicate, TermList { left_term, right_term }); diff --git a/src/pddl/parser/objects.cpp b/src/pddl/parser/objects.cpp index f4745594..75ecd626 100644 --- a/src/pddl/parser/objects.cpp +++ b/src/pddl/parser/objects.cpp @@ -31,10 +31,10 @@ ObjectListVisitor::ObjectListVisitor(Context& context_) : context(context_) {} Object parse_object_reference(const ast::Name& name_node, Context& context) { const auto name = parse(name_node); - const auto binding = context.scopes.get(name); + const auto binding = context.scopes.top().get_object(name); if (!binding.has_value()) { - throw UndefinedObjectError(name, context.scopes.get_error_handler()(name_node, "")); + throw UndefinedObjectError(name, context.scopes.top().get_error_handler()(name_node, "")); } const auto [object, _position, _error_handler] = binding.value(); context.positions.push_back(object, name_node); @@ -44,10 +44,10 @@ Object parse_object_reference(const ast::Name& name_node, Context& context) static void test_multiple_definition_object(const std::string& object_name, const ast::Name& name_node, const Context& context) { - const auto binding = context.scopes.get(object_name); + const auto binding = context.scopes.top().get_object(object_name); if (binding.has_value()) { - const auto message_1 = context.scopes.get_error_handler()(name_node, "Defined here:"); + const auto message_1 = context.scopes.top().get_error_handler()(name_node, "Defined here:"); auto message_2 = std::string(""); const auto [_object, position, error_handler] = binding.value(); if (position.has_value()) @@ -64,7 +64,7 @@ static Object parse_object_definition(const ast::Name& name_node, const TypeList test_multiple_definition_object(name, name_node, context); const auto object = context.factories.get_or_create_object(name, type_list); context.positions.push_back(object, name_node); - context.scopes.insert(name, object, name_node); + context.scopes.top().insert_object(name, object, name_node); return object; } @@ -81,8 +81,8 @@ static ObjectList parse_object_definitions(const std::vector& name_no ObjectList ObjectListVisitor::operator()(const std::vector& name_nodes) { // std::vector has single base type "object" - assert(context.scopes.get("object").has_value()); - const auto [type, _position, _error_handler] = context.scopes.get("object").value(); + assert(context.scopes.top().get_type("object").has_value()); + const auto [type, _position, _error_handler] = context.scopes.top().get_type("object").value(); auto object_list = parse_object_definitions(name_nodes, TypeList { type }, context); return object_list; } @@ -91,7 +91,7 @@ ObjectList ObjectListVisitor::operator()(const ast::TypedListOfNamesRecursively& { if (!context.requirements->test(RequirementEnum::TYPING)) { - throw UndefinedRequirementError(RequirementEnum::TYPING, context.scopes.get_error_handler()(node, "")); + throw UndefinedRequirementError(RequirementEnum::TYPING, context.scopes.top().get_error_handler()(node, "")); } context.references.untrack(RequirementEnum::TYPING); // TypedListOfNamesRecursively has user defined base types diff --git a/src/pddl/parser/parameters.cpp b/src/pddl/parser/parameters.cpp index 3c186d55..c9dd4dbc 100644 --- a/src/pddl/parser/parameters.cpp +++ b/src/pddl/parser/parameters.cpp @@ -28,10 +28,10 @@ namespace loki static void test_multiple_definition(const Variable& variable, const ast::Variable& node, const Context& context) { - const auto binding = context.scopes.get(variable->get_name()); + const auto binding = context.scopes.top().get_variable(variable->get_name()); if (binding.has_value()) { - const auto message_1 = context.scopes.get_error_handler()(node, "Defined here:"); + const auto message_1 = context.scopes.top().get_error_handler()(node, "Defined here:"); auto message_2 = std::string(""); const auto [_variable, position, error_handler] = binding.value(); if (position.has_value()) @@ -44,7 +44,7 @@ static void test_multiple_definition(const Variable& variable, const ast::Variab static void insert_context_information(const Variable& variable, const ast::Variable& node, Context& context) { - context.scopes.insert(variable->get_name(), variable, node); + context.scopes.top().insert_variable(variable->get_name(), variable, node); } static Parameter parse_parameter_definition(const ast::Variable& variable_node, const TypeList& type_list, Context& context) @@ -83,7 +83,7 @@ ParameterList ParameterListVisitor::operator()(const ast::TypedListOfVariablesRe // requires :typing if (!context.requirements->test(RequirementEnum::TYPING)) { - throw UndefinedRequirementError(RequirementEnum::TYPING, context.scopes.get_error_handler()(node, "")); + throw UndefinedRequirementError(RequirementEnum::TYPING, context.scopes.top().get_error_handler()(node, "")); } context.references.untrack(RequirementEnum::TYPING); const auto type_list = boost::apply_visitor(TypeReferenceTypeVisitor(context), node.type); diff --git a/src/pddl/parser/predicates.cpp b/src/pddl/parser/predicates.cpp index bbfee032..b1106159 100644 --- a/src/pddl/parser/predicates.cpp +++ b/src/pddl/parser/predicates.cpp @@ -27,10 +27,10 @@ namespace loki static void test_multiple_definition(const Predicate& predicate, const ast::Predicate& node, const Context& context) { const auto predicate_name = predicate->get_name(); - const auto binding = context.scopes.get(predicate_name); + const auto binding = context.scopes.top().get_predicate(predicate_name); if (binding.has_value()) { - const auto message_1 = context.scopes.get_error_handler()(node, "Defined here:"); + const auto message_1 = context.scopes.top().get_error_handler()(node, "Defined here:"); auto message_2 = std::string(""); const auto [_predicate, position, error_handler] = binding.value(); if (position.has_value()) @@ -44,7 +44,7 @@ static void test_multiple_definition(const Predicate& predicate, const ast::Pred static void insert_context_information(const Predicate& predicate, const ast::Predicate& node, Context& context) { context.positions.push_back(predicate, node); - context.scopes.insert(predicate->get_name(), predicate, node); + context.scopes.top().insert_predicate(predicate->get_name(), predicate, node); } static Predicate parse_predicate_definition(const ast::AtomicFormulaSkeleton& node, Context& context) diff --git a/src/pddl/parser/reference_utils.cpp b/src/pddl/parser/reference_utils.cpp index 7a993f02..90ae513f 100644 --- a/src/pddl/parser/reference_utils.cpp +++ b/src/pddl/parser/reference_utils.cpp @@ -36,12 +36,39 @@ void test_variable_references(const ParameterList& parameter_list, const Context { if (context.references.exists(parameter->get_variable())) { - const auto [variable, position, error_handler] = context.scopes.get(parameter->get_variable()->get_name()).value(); + const auto [variable, position, error_handler] = context.scopes.top().get_variable(parameter->get_variable()->get_name()).value(); throw UnusedVariableError(variable->get_name(), error_handler(position.value(), "")); } } } +void track_variable_type_information(const ParameterList& parameter_list, Context& context) +{ + for (const auto& parameter : parameter_list) + { + context.scopes.top().insert_variable_types(parameter->get_variable(), collect_types_from_hierarchy(parameter->get_bases())); + } +} + +void test_object_type_consistent_with_variable(const Object& object, const Variable& variable, const Context& context) +{ + const auto& variable_types = context.scopes.top().get_variable_types(variable); + bool is_consistent = false; + for (const auto& type : collect_types_from_hierarchy(object->get_bases())) + { + if (variable_types.count(type)) + { + is_consistent = true; + break; + } + } + if (!is_consistent) + { + const auto [_object, position, error_handler] = context.scopes.top().get_object(object->get_name()).value(); + throw IncompatibleObjectTypeError(object, variable, error_handler(position.value(), "")); + } +} + void track_predicate_references(const PredicateList& predicate_list, Context& context) { for (const auto& predicate : predicate_list) @@ -56,7 +83,7 @@ void test_predicate_references(const PredicateList& predicate_list, const Contex { if (context.references.exists(predicate)) { - const auto [_predicate, position, error_handler] = context.scopes.get(predicate->get_name()).value(); + const auto [_predicate, position, error_handler] = context.scopes.top().get_predicate(predicate->get_name()).value(); throw UnusedPredicateError(predicate->get_name(), error_handler(position.value(), "")); } } @@ -76,7 +103,7 @@ void test_function_skeleton_references(const FunctionSkeletonList& function_skel { if (context.references.exists(function_skeleton)) { - const auto [_function_skeleton, position, error_handler] = context.scopes.get(function_skeleton->get_name()).value(); + const auto [_function_skeleton, position, error_handler] = context.scopes.top().get_function_skeleton(function_skeleton->get_name()).value(); throw UnusedFunctionSkeletonError(function_skeleton->get_name(), error_handler(position.value(), "")); } } @@ -96,7 +123,7 @@ void test_object_references(const ObjectList& object_list, const Context& contex { if (context.references.exists(object)) { - const auto [_object, position, error_handler] = context.scopes.get(object->get_name()).value(); + const auto [_object, position, error_handler] = context.scopes.top().get_object(object->get_name()).value(); throw UnusedObjectError(object->get_name(), error_handler(position.value(), "")); } } diff --git a/src/pddl/parser/reference_utils.hpp b/src/pddl/parser/reference_utils.hpp index b6579785..37bb209f 100644 --- a/src/pddl/parser/reference_utils.hpp +++ b/src/pddl/parser/reference_utils.hpp @@ -31,6 +31,10 @@ extern void track_variable_references(const ParameterList& parameter_list, Conte extern void test_variable_references(const ParameterList& parameter_list, const Context& context); +extern void track_variable_type_information(const ParameterList& parameter_list, Context& context); + +extern void test_object_type_consistent_with_variable(const Object& object, const Variable& variable, const Context& context); + extern void track_predicate_references(const PredicateList& predicate_list, Context& context); extern void test_predicate_references(const PredicateList& predicate_list, const Context& context); diff --git a/src/pddl/parser/requirements.cpp b/src/pddl/parser/requirements.cpp index b4f262ba..0a28c83d 100644 --- a/src/pddl/parser/requirements.cpp +++ b/src/pddl/parser/requirements.cpp @@ -99,7 +99,7 @@ RequirementEnumSet parse(const ast::RequirementObjectFluents& node, Context& con { // Track context.references.track(RequirementEnum::OBJECT_FLUENTS); - throw UnsupportedRequirementError(RequirementEnum::OBJECT_FLUENTS, context.scopes.get_error_handler()(node, "")); + throw UnsupportedRequirementError(RequirementEnum::OBJECT_FLUENTS, context.scopes.top().get_error_handler()(node, "")); return { RequirementEnum::OBJECT_FLUENTS }; } @@ -138,7 +138,7 @@ RequirementEnumSet parse(const ast::RequirementAdl&, Context& context) RequirementEnumSet parse(const ast::RequirementDurativeActions& node, Context& context) { - throw UnsupportedRequirementError(RequirementEnum::DURATIVE_ACTIONS, context.scopes.get_error_handler()(node, "")); + throw UnsupportedRequirementError(RequirementEnum::DURATIVE_ACTIONS, context.scopes.top().get_error_handler()(node, "")); // Track context.references.track(RequirementEnum::DURATIVE_ACTIONS); @@ -154,7 +154,7 @@ RequirementEnumSet parse(const ast::RequirementDerivedPredicates&, Context& cont RequirementEnumSet parse(const ast::RequirementTimedInitialLiterals& node, Context& context) { - throw UnsupportedRequirementError(RequirementEnum::TIMED_INITIAL_LITERALS, context.scopes.get_error_handler()(node, "")); + throw UnsupportedRequirementError(RequirementEnum::TIMED_INITIAL_LITERALS, context.scopes.top().get_error_handler()(node, "")); // Track context.references.track(RequirementEnum::TIMED_INITIAL_LITERALS); @@ -164,7 +164,7 @@ RequirementEnumSet parse(const ast::RequirementTimedInitialLiterals& node, Conte RequirementEnumSet parse(const ast::RequirementPreferences& node, Context& context) { - throw UnsupportedRequirementError(RequirementEnum::PREFERENCES, context.scopes.get_error_handler()(node, "")); + throw UnsupportedRequirementError(RequirementEnum::PREFERENCES, context.scopes.top().get_error_handler()(node, "")); // Track context.references.track(RequirementEnum::PREFERENCES); @@ -173,7 +173,7 @@ RequirementEnumSet parse(const ast::RequirementPreferences& node, Context& conte RequirementEnumSet parse(const ast::RequirementConstraints& node, Context& context) { - throw UnsupportedRequirementError(RequirementEnum::CONSTRAINTS, context.scopes.get_error_handler()(node, "")); + throw UnsupportedRequirementError(RequirementEnum::CONSTRAINTS, context.scopes.top().get_error_handler()(node, "")); // Track context.references.track(RequirementEnum::CONSTRAINTS); diff --git a/src/pddl/parser/structure.cpp b/src/pddl/parser/structure.cpp index 1b3fadab..0f774786 100644 --- a/src/pddl/parser/structure.cpp +++ b/src/pddl/parser/structure.cpp @@ -64,14 +64,14 @@ Axiom parse(const ast::Axiom& node, Context& context) { if (!context.requirements->test(RequirementEnum::DERIVED_PREDICATES)) { - throw UndefinedRequirementError(RequirementEnum::DERIVED_PREDICATES, context.scopes.get_error_handler()(node, "")); + throw UndefinedRequirementError(RequirementEnum::DERIVED_PREDICATES, context.scopes.top().get_error_handler()(node, "")); } context.references.untrack(RequirementEnum::DERIVED_PREDICATES); const auto literal = parse(node.literal, context); if (context.derived_predicates.count(literal->get_atom()->get_predicate()) == 0) { - throw ExpectedDerivedPredicate(literal->get_atom()->get_predicate()->get_name(), context.scopes.get_error_handler()(node, "")); + throw ExpectedDerivedPredicate(literal->get_atom()->get_predicate()->get_name(), context.scopes.top().get_error_handler()(node, "")); } const auto condition = parse(node.goal_descriptor, context); diff --git a/src/pddl/parser/types.cpp b/src/pddl/parser/types.cpp index 2652b771..6e16fdb2 100644 --- a/src/pddl/parser/types.cpp +++ b/src/pddl/parser/types.cpp @@ -69,7 +69,7 @@ TypeReferenceTypeVisitor::TypeReferenceTypeVisitor(const Context& context_) : co TypeList TypeReferenceTypeVisitor::operator()(const ast::TypeObject&) { - const auto binding = context.scopes.get("object"); + const auto binding = context.scopes.top().get_type("object"); assert(binding.has_value()); const auto [type, _position, _error_handler] = binding.value(); return { type }; @@ -77,7 +77,7 @@ TypeList TypeReferenceTypeVisitor::operator()(const ast::TypeObject&) TypeList TypeReferenceTypeVisitor::operator()(const ast::TypeNumber&) { - const auto binding = context.scopes.get("number"); + const auto binding = context.scopes.top().get_type("number"); assert(binding.has_value()); const auto [type, _position, _error_handler] = binding.value(); return { type }; @@ -86,10 +86,10 @@ TypeList TypeReferenceTypeVisitor::operator()(const ast::TypeNumber&) TypeList TypeReferenceTypeVisitor::operator()(const ast::Name& node) { auto name = parse(node); - auto binding = context.scopes.get(name); + auto binding = context.scopes.top().get_type(name); if (!binding.has_value()) { - throw UndefinedTypeError(name, context.scopes.get_error_handler()(node, "")); + throw UndefinedTypeError(name, context.scopes.top().get_error_handler()(node, "")); } const auto [type, _position, _error_handler] = binding.value(); context.positions.push_back(type, node); @@ -112,10 +112,10 @@ TypeList TypeReferenceTypeVisitor::operator()(const ast::TypeEither& node) static void test_multiple_definition(const Type& type, const ast::Name& node, const Context& context) { const auto type_name = type->get_name(); - const auto binding = context.scopes.get(type_name); + const auto binding = context.scopes.top().get_type(type_name); if (binding.has_value()) { - const auto message_1 = context.scopes.get_error_handler()(node, "Defined here:"); + const auto message_1 = context.scopes.top().get_error_handler()(node, "Defined here:"); auto message_2 = std::string(""); const auto [_type, position, error_handler] = binding.value(); if (position.has_value()) @@ -130,20 +130,20 @@ static void test_reserved_type(const Type& type, const ast::Name& node, const Co { if (type->get_name() == "object") { - throw ReservedTypeError("object", context.scopes.get_error_handler()(node, "")); + throw ReservedTypeError("object", context.scopes.top().get_error_handler()(node, "")); } // We also reserve type name number although PDDL specification allows it. // However, this allows using regular types as function types for simplicity. if (type->get_name() == "number") { - throw ReservedTypeError("number", context.scopes.get_error_handler()(node, "")); + throw ReservedTypeError("number", context.scopes.top().get_error_handler()(node, "")); } } static void insert_context_information(const Type& type, const ast::Name& node, Context& context) { context.positions.push_back(type, node); - context.scopes.insert(type->get_name(), type, node); + context.scopes.top().insert_type(type->get_name(), type, node); } static Type parse_type_definition(const ast::Name& node, const TypeList& type_list, Context& context) @@ -171,8 +171,8 @@ TypeDeclarationTypedListOfNamesVisitor::TypeDeclarationTypedListOfNamesVisitor(C TypeList TypeDeclarationTypedListOfNamesVisitor::operator()(const std::vector& name_nodes) { // std::vector has single base type "object" - assert(context.scopes.get("object").has_value()); - const auto [type_object, _position, _error_handler] = context.scopes.get("object").value(); + assert(context.scopes.top().get_type("object").has_value()); + const auto [type_object, _position, _error_handler] = context.scopes.top().get_type("object").value(); const auto type_list = parse_type_definitions(name_nodes, TypeList { type_object }, context); return type_list; } @@ -182,7 +182,7 @@ TypeList TypeDeclarationTypedListOfNamesVisitor::operator()(const ast::TypedList // requires :typing if (!context.requirements->test(RequirementEnum::TYPING)) { - throw UndefinedRequirementError(RequirementEnum::TYPING, context.scopes.get_error_handler()(typed_list_of_names_recursively_node, "")); + throw UndefinedRequirementError(RequirementEnum::TYPING, context.scopes.top().get_error_handler()(typed_list_of_names_recursively_node, "")); } context.references.untrack(RequirementEnum::TYPING); // TypedListOfNamesRecursively has user defined base types. diff --git a/src/pddl/scope.cpp b/src/pddl/scope.cpp index d4be1d38..ed4a4994 100644 --- a/src/pddl/scope.cpp +++ b/src/pddl/scope.cpp @@ -19,11 +19,150 @@ namespace loki { -Scope::Scope(const Scope* parent_scope) : m_parent_scope(parent_scope) {} +Scope::Scope(const PDDLErrorHandler& error_handler, const Scope* parent_scope) : m_error_handler(error_handler), m_parent_scope(parent_scope) {} + +std::optional> Scope::get_type(const std::string& name) const +{ + const auto it = m_types.find(name); + if (it != m_types.end()) + return std::make_tuple(it->second.first, it->second.second, m_error_handler); + if (m_parent_scope) + { + return m_parent_scope->get_type(name); + } + return std::nullopt; +} + +std::optional> Scope::get_object(const std::string& name) const +{ + const auto it = m_objects.find(name); + if (it != m_objects.end()) + return std::make_tuple(it->second.first, it->second.second, m_error_handler); + if (m_parent_scope) + { + return m_parent_scope->get_object(name); + } + return std::nullopt; +} + +std::optional> Scope::get_function_skeleton(const std::string& name) const +{ + const auto it = m_function_skeletons.find(name); + if (it != m_function_skeletons.end()) + return std::make_tuple(it->second.first, it->second.second, m_error_handler); + if (m_parent_scope) + { + return m_parent_scope->get_function_skeleton(name); + } + return std::nullopt; +} + +std::optional> Scope::get_variable(const std::string& name) const +{ + const auto it = m_variables.find(name); + if (it != m_variables.end()) + return std::make_tuple(it->second.first, it->second.second, m_error_handler); + if (m_parent_scope) + { + return m_parent_scope->get_variable(name); + } + return std::nullopt; +} + +std::optional> Scope::get_predicate(const std::string& name) const +{ + const auto it = m_predicates.find(name); + if (it != m_predicates.end()) + return std::make_tuple(it->second.first, it->second.second, m_error_handler); + if (m_parent_scope) + { + return m_parent_scope->get_predicate(name); + } + return std::nullopt; +} + +std::optional> Scope::get_derived_predicate(const std::string& name) const +{ + const auto it = m_derived_predicates.find(name); + if (it != m_derived_predicates.end()) + return std::make_tuple(it->second.first, it->second.second, m_error_handler); + if (m_parent_scope) + { + return m_parent_scope->get_derived_predicate(name); + } + return std::nullopt; +} + +const TypeSet& Scope::get_variable_types(const Variable& variable) const +{ + const auto it = m_variable_types.find(variable); + if (it != m_variable_types.end()) + { + return it->second; + } + if (m_parent_scope) + { + return m_parent_scope->get_variable_types(variable); + } + else + { + throw std::logic_error("Expected types of variable to be known."); + } +} + +void Scope::insert_type(const std::string& name, const Type& element, const std::optional& position) +{ + assert(!this->get_type(name)); + m_types.emplace(name, BindingValueType(element, position)); +} + +void Scope::insert_object(const std::string& name, const Object& element, const std::optional& position) +{ + assert(!this->get_object(name)); + m_objects.emplace(name, BindingValueType(element, position)); +} + +void Scope::insert_function_skeleton(const std::string& name, const FunctionSkeleton& element, const std::optional& position) +{ + assert(!this->get_function_skeleton(name)); + m_function_skeletons.emplace(name, BindingValueType(element, position)); +} + +void Scope::insert_variable(const std::string& name, const Variable& element, const std::optional& position) +{ + assert(!this->get_variable(name)); + m_variables.emplace(name, BindingValueType(element, position)); +} + +void Scope::insert_predicate(const std::string& name, const Predicate& element, const std::optional& position) +{ + assert(!this->get_predicate(name)); + m_predicates.emplace(name, BindingValueType(element, position)); +} + +void Scope::insert_derived_predicate(const std::string& name, const Predicate& element, const std::optional& position) +{ + assert(!this->get_derived_predicate(name)); + m_derived_predicates.emplace(name, BindingValueType(element, position)); +} + +void Scope::insert_variable_types(const Variable& variable, const TypeSet& types) +{ + assert(!this->m_variable_types.count(variable)); + m_variable_types.emplace(variable, types); +} + +const PDDLErrorHandler& Scope::get_error_handler() const { return m_error_handler; } ScopeStack::ScopeStack(const PDDLErrorHandler& error_handler, const ScopeStack* parent) : m_error_handler(error_handler), m_parent(parent) {} -void ScopeStack::open_scope() { m_stack.push_back(m_stack.empty() ? std::make_unique() : std::make_unique(m_stack.back().get())); } +void ScopeStack::open_scope() +{ + // Link to parent Scope across parent ScopeStacks. + m_stack.push_back(m_stack.empty() ? + (m_parent ? std::make_unique(m_error_handler, m_parent->m_stack.back().get()) : std::make_unique(m_error_handler)) : + std::make_unique(m_error_handler, m_stack.back().get())); +} void ScopeStack::close_scope() { @@ -31,7 +170,12 @@ void ScopeStack::close_scope() m_stack.pop_back(); } -const PDDLErrorHandler& ScopeStack::get_error_handler() const { return m_error_handler; } +Scope& ScopeStack::top() +{ + assert(!m_stack.empty()); + return *m_stack.back(); +} +const Scope& ScopeStack::top() const { return top(); } const std::deque>& ScopeStack::get_stack() const { return m_stack; } diff --git a/src/pddl/type.cpp b/src/pddl/type.cpp index 2eb869ab..dedaec92 100644 --- a/src/pddl/type.cpp +++ b/src/pddl/type.cpp @@ -59,4 +59,19 @@ const std::string& TypeImpl::get_name() const { return m_name; } const TypeList& TypeImpl::get_bases() const { return m_bases; } +static void collect_types_from_hierarchy_recursively(const TypeList& types, TypeSet& ref_result) +{ + for (const auto& type : types) + { + ref_result.insert(type); + collect_types_from_hierarchy_recursively(type->get_bases(), ref_result); + } +} + +TypeSet collect_types_from_hierarchy(const TypeList& types) +{ + auto result = TypeSet {}; + collect_types_from_hierarchy_recursively(types, result); + return result; +} } diff --git a/src/utils/parser.cpp b/src/utils/parser.cpp index 24f70a4a..c3d2e06d 100644 --- a/src/utils/parser.cpp +++ b/src/utils/parser.cpp @@ -63,8 +63,8 @@ DomainParser::DomainParser(const fs::path& file_path) : // Create base types. const auto base_type_object = context.factories.get_or_create_type("object", TypeList()); const auto base_type_number = context.factories.get_or_create_type("number", TypeList()); - context.scopes.insert("object", base_type_object, {}); - context.scopes.insert("number", base_type_number, {}); + context.scopes.top().insert_type("object", base_type_object, {}); + context.scopes.top().insert_type("number", base_type_number, {}); // Create equal predicate with name "=" and two parameters "?left_arg" and "?right_arg" const auto binary_parameterlist = @@ -73,9 +73,8 @@ DomainParser::DomainParser(const fs::path& file_path) : }; const auto equal_predicate = context.factories.get_or_create_predicate("=", binary_parameterlist); - context.scopes.insert("=", equal_predicate, {}); + context.scopes.top().insert_predicate("=", equal_predicate, {}); - std::cout << "Started semantic domain file: " << file_path << std::endl; m_domain = parse(node, context); // Only the global scope remains