diff --git a/compiler/passes/convert-typed-uast.cpp b/compiler/passes/convert-typed-uast.cpp index d485326c2af9..9dbe16c4eb3c 100644 --- a/compiler/passes/convert-typed-uast.cpp +++ b/compiler/passes/convert-typed-uast.cpp @@ -1485,6 +1485,18 @@ Type* TConverter::helpConvertType(const types::Type* t) { case typetags::CPtrType: return helpConvertPtrType(t->toPtrType()); case typetags::HeapBufferType: return helpConvertPtrType(t->toPtrType()); + // Interfaces require something clever (creating a constrained generic + // function), and we don't have that yet. + case typetags::InterfaceType: + CHPL_UNIMPL("convert interface type"); + return dtUnknown; // TODO + + // placeholders only occur in interface resolution and should not be + // reachable + case typetags::PlaceholderType: + INT_FATAL("should not be reachable"); + return dtUnknown; + // implementation detail tags (should not be reachable) case typetags::START_ManageableType: case typetags::END_ManageableType: diff --git a/frontend/include/chpl/parsing/parsing-queries.h b/frontend/include/chpl/parsing/parsing-queries.h index 15cba9fdbd87..8d36e1e45870 100644 --- a/frontend/include/chpl/parsing/parsing-queries.h +++ b/frontend/include/chpl/parsing/parsing-queries.h @@ -521,6 +521,11 @@ bool idIsPrivateDecl(Context* context, ID id); */ bool idIsFunction(Context* context, ID id); +/** + Returns true if the ID is an interface + */ +bool idIsInterface(Context* context, ID id); + /** Returns true if the ID is marked 'extern'. */ @@ -557,6 +562,11 @@ const ID& idToParentId(Context* context, ID id); */ ID idToParentFunctionId(Context* context, ID id); +/** + Returns the parent interface ID given an ID. + */ +ID idToParentInterfaceId(Context* context, ID id); + /** Returns the parent AST node given an AST node */ diff --git a/frontend/include/chpl/resolution/ResolutionContext.h b/frontend/include/chpl/resolution/ResolutionContext.h index eda8f68ea353..a5e7478f9c5a 100644 --- a/frontend/include/chpl/resolution/ResolutionContext.h +++ b/frontend/include/chpl/resolution/ResolutionContext.h @@ -40,6 +40,7 @@ class ResolvedFunction; class TypedFnSignature; class UntypedFnSignature; class MatchingIdsWithName; +class ImplementationWitness; /** This class is used to manage stack frames that may be necessary while @@ -113,7 +114,7 @@ class ResolutionContext { when they are destroyed. */ class Frame { public: - enum Kind { FUNCTION, MODULE, UNKNOWN }; + enum Kind { FUNCTION, MODULE, INTERFACE, UNKNOWN }; private: friend class ResolutionContext; @@ -126,6 +127,8 @@ class ResolutionContext { Resolver* rv_ = nullptr; const ResolvedFunction* rf_ = nullptr; + const types::InterfaceType* ift_ = nullptr; + const ImplementationWitness* witness_ = nullptr; int64_t index_ = BASE_FRAME_INDEX; Store cachedResults_; Kind kind_ = UNKNOWN; @@ -137,6 +140,9 @@ class ResolutionContext { Frame(const ResolvedFunction* rf, int64_t index) : rf_(rf), index_(index), kind_(FUNCTION) { } + Frame(const types::InterfaceType* ift, const ImplementationWitness* witness, int64_t index) + : ift_(ift), witness_(witness), index_(index), kind_(INTERFACE) { + } public: ~Frame() = default; @@ -155,12 +161,14 @@ class ResolutionContext { } Resolver* rv() { return rv_; } - const ResolvedFunction* rf() { return rf_; } + const ResolvedFunction* rf() const { return rf_; } + const types::InterfaceType* ift() const { return ift_; } + const ImplementationWitness* witness() const { return witness_; } bool isEmpty() { return !rv() && !rf(); } const ID& id() const; const TypedFnSignature* signature() const; - const ResolutionResultByPostorderID* resolutionById() const; + const types::QualifiedType typeForContainedId(ResolutionContext* rc, const ID& id) const; bool isUnstable() const; template @@ -183,6 +191,7 @@ class ResolutionContext { Frame baseFrame_; const Frame* pushFrame(const ResolvedFunction* rf); + const Frame* pushFrame(const types::InterfaceType* t, const ImplementationWitness* witness); const Frame* pushFrame(Resolver* rv, Frame::Kind kind); void popFrame(const ResolvedFunction* rf); void popFrame(Resolver* rv); diff --git a/frontend/include/chpl/resolution/can-pass.h b/frontend/include/chpl/resolution/can-pass.h index 0f6607f20204..ef547fab0fe5 100644 --- a/frontend/include/chpl/resolution/can-pass.h +++ b/frontend/include/chpl/resolution/can-pass.h @@ -204,6 +204,17 @@ CanPassResult canPassScalar(Context* context, return CanPassResult::canPassScalar(context, actualType, formalType); } +/* Returns true if, all other things equal, a type with substitutions + 'instances' is an instantiation of a type with substitutions 'generics'. + + If 'allowMissing' is true, considers missing substitutions in 'generics' + to be "any type". Otherwise, requires that each susbtitution in + instances is matched by an existing substitution in generics. */ +bool canInstantiateSubstitutions(Context* context, + const SubstitutionsMap& instances, + const SubstitutionsMap& generics, + bool allowMissing); + /* When trying to combine two kinds, you can't just pick one. For instance, if any type in the list is a value, the result should be a value, and if any type in the list is const, the diff --git a/frontend/include/chpl/resolution/interface-types.h b/frontend/include/chpl/resolution/interface-types.h new file mode 100644 index 000000000000..96e3633dcace --- /dev/null +++ b/frontend/include/chpl/resolution/interface-types.h @@ -0,0 +1,184 @@ +/* + * Copyright 2024 Hewlett Packard Enterprise Development LP + * Other additional copyright holders may be indicated within. + * + * The entirety of this work is licensed under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef CHPL_RESOLUTION_INTERFACE_TYPES_H +#define CHPL_RESOLUTION_INTERFACE_TYPES_H + +#include "chpl/framework/ID.h" +#include "chpl/framework/update-functions.h" +#include "chpl/types/InterfaceType.h" +#include "chpl/util/hash.h" + +namespace chpl { +namespace resolution { + +/* + Represent a resolved form of an 'implements'-like statement, which asserts + that a (possibly-generic) list of types satisfies an interface. All + of the following are transformed into implementation points: + + record R : I {} + implements R(I); + R implements I; + + The types of the "actuals" (R in the above examples) are stored as substitutions + for the InterfaceType. This means that in the case of a concrete implementation + point, its interface() will be equal to constraint being searched for. + */ +class ImplementationPoint { + private: + // The interface being implemented (instantiated with the types of the arguments) + const types::InterfaceType* interface_; + // The ID of the implementation statement + ID id_; + + ImplementationPoint(const types::InterfaceType* interface, + ID id) + : interface_(interface), id_(id) {} + + static owned const& + getImplementationPoint(Context* context, + const types::InterfaceType* interface, + ID id); + + public: + static const ImplementationPoint* + get(Context* context, const types::InterfaceType* interface, ID id); + + static bool update(owned& lhs, + owned& rhs) { + return defaultUpdateOwned(lhs, rhs); + } + bool operator==(const ImplementationPoint& other) const { + return interface_ == other.interface_ && + id_ == other.id_; + } + bool operator!=(const ImplementationPoint& other) const { + return !(*this == other); + } + void mark(Context* context) const; + void stringify(std::ostream& ss, chpl::StringifyKind stringKind) const; + + /* get the (possibly-generic) interface being implemented, which includes + substitutions from the actuals. */ + const types::InterfaceType* interface() const { return interface_; } + + /* get the ID of the implementation point (statement or inheritance expression) */ + const ID& id() const { return id_; } +}; + +/* + Represents evidence that a particular type or list of types implements + a given interface. This includes: + + * associated constraints (e.g., a 'totalOrder' interface requires a + 'partialOrder' interface to be satisfied). + * associated types (e.g., a 'collection' interface requires a 'Element' + type to be defined; in the standard library, the 'contextManager' interface + specifies the type of the underlying resource). + * required functions (e.g., a 'hashable' interface requires a 'hash' function + to be defined). + */ +class ImplementationWitness { + public: + using ConstraintMap = std::unordered_map; + using AssociatedTypeMap = types::PlaceholderMap; + using FunctionMap = std::unordered_map; + + private: + + ConstraintMap associatedConstraints_; + AssociatedTypeMap associatedTypes_; + FunctionMap requiredFns_; + + ImplementationWitness(ConstraintMap associatedConstraints, + AssociatedTypeMap associatedTypes, + FunctionMap requiredFns) + : associatedConstraints_(std::move(associatedConstraints)), + associatedTypes_(std::move(associatedTypes)), + requiredFns_(std::move(requiredFns)) {} + + static const owned& + getImplementationWitness(Context* context, + ConstraintMap associatedConstraints, + AssociatedTypeMap associatedTypes, + FunctionMap requiredFns); + + public: + static ImplementationWitness* + get(Context* context, + ConstraintMap associatedConstraints, + AssociatedTypeMap associatedTypes, + FunctionMap requiredFns); + + static bool update(owned& lhs, + owned& rhs) { + return defaultUpdateOwned(lhs, rhs); + } + bool operator==(const ImplementationWitness& other) const { + return associatedConstraints_ == other.associatedConstraints_ && + associatedTypes_ == other.associatedTypes_ && + requiredFns_ == other.requiredFns_; + } + bool operator!=(const ImplementationWitness& other) const { + return !(*this == other); + } + void mark(Context* context) const { + chpl::mark{}(context, associatedTypes_); + chpl::mark{}(context, requiredFns_); + chpl::mark{}(context, associatedConstraints_); + } + void stringify(std::ostream& ss, chpl::StringifyKind stringKind) const; + + /** Get the associated constraints. */ + const ConstraintMap& associatedConstraints() const { + return associatedConstraints_; + } + + /** Get the associated types. */ + const AssociatedTypeMap& associatedTypes() const { + return associatedTypes_; + } + + /** Get the required functions. */ + const FunctionMap& requiredFns() const { + return requiredFns_; + } +}; + +} // end namespace resolution +} // end namespace chpl + + +namespace std { + +template<> struct hash { + size_t operator()(const chpl::resolution::ImplementationWitness::ConstraintMap& key) const { + return chpl::hashUnorderedMap(key); + } +}; + +template<> struct hash { + size_t operator()(const chpl::resolution::ImplementationWitness::FunctionMap& key) const { + return chpl::hashUnorderedMap(key); + } +}; + +} +#endif diff --git a/frontend/include/chpl/resolution/resolution-error-classes-list.h b/frontend/include/chpl/resolution/resolution-error-classes-list.h index b9ef31241201..b9519543b821 100644 --- a/frontend/include/chpl/resolution/resolution-error-classes-list.h +++ b/frontend/include/chpl/resolution/resolution-error-classes-list.h @@ -55,8 +55,18 @@ ERROR_CLASS(IncompatibleKinds, types::QualifiedType::Kind, const uast::AstNode*, ERROR_CLASS(IncompatibleRangeBounds, const uast::Range*, types::QualifiedType, types::QualifiedType) ERROR_CLASS(IncompatibleTypeAndInit, const uast::AstNode*, const uast::AstNode*, const uast::AstNode*, const types::Type*, const types::Type*) ERROR_CLASS(IncompatibleYieldTypes, const uast::AstNode*, std::vector>) +ERROR_CLASS(InterfaceAmbiguousFn, const types::InterfaceType*, ID, const uast::Function*, std::vector) +ERROR_CLASS(InterfaceInvalidIntent, const types::InterfaceType*, ID, const resolution::TypedFnSignature*, const resolution::TypedFnSignature*) +ERROR_CLASS(InterfaceMissingAssociatedType, const types::InterfaceType*, ID, const uast::Variable*, resolution::CallInfo, std::vector) +ERROR_CLASS(InterfaceMissingFn, const types::InterfaceType*, ID, const resolution::TypedFnSignature*, resolution::CallInfo, std::vector) +ERROR_CLASS(InterfaceMultipleImplements, const uast::AggregateDecl*, const types::InterfaceType*, ID, ID) +ERROR_CLASS(InterfaceNaryInInherits, const uast::AggregateDecl*, const types::InterfaceType*, ID) +ERROR_CLASS(InterfaceReorderedFnFormals, const types::InterfaceType*, ID, const resolution::TypedFnSignature*, const resolution::TypedFnSignature*) ERROR_CLASS(InvalidClassCast, const uast::PrimCall*, types::QualifiedType) ERROR_CLASS(InvalidDomainCall, const uast::FnCall*, std::vector) +ERROR_CLASS(InvalidImplementsActual, const uast::Implements*, const uast::AstNode*, types::QualifiedType) +ERROR_CLASS(InvalidImplementsArity, const uast::Implements*, const types::InterfaceType*, std::vector) +ERROR_CLASS(InvalidImplementsInterface, const uast::Implements*, types::QualifiedType) ERROR_CLASS(InvalidIndexCall, const uast::FnCall*, types::QualifiedType) ERROR_CLASS(InvalidNewTarget, const uast::New*, types::QualifiedType) ERROR_CLASS(InvalidParamCast, const uast::AstNode*, types::QualifiedType, types::QualifiedType) @@ -71,6 +81,7 @@ ERROR_CLASS(MultipleInheritance, const uast::Class*, const uast::AstNode*, const ERROR_CLASS(MultipleQuestionArgs, const uast::FnCall*, const uast::AstNode*, const uast::AstNode*) ERROR_CLASS(NestedClassFieldRef, const uast::TypeDecl*, const uast::TypeDecl*, const uast::AstNode*, ID) ERROR_CLASS(NoMatchingCandidates, const uast::AstNode*, resolution::CallInfo, std::vector) +ERROR_CLASS(NonClassInheritance, const uast::AggregateDecl*, const uast::AstNode*, const types::Type*) ERROR_CLASS(NonIterable, const uast::AstNode*, const uast::AstNode*, types::QualifiedType, std::vector>) ERROR_CLASS(NoMatchingEnumValue, const uast::AstNode*, const types::EnumType*, types::QualifiedType) ERROR_CLASS(NotInModule, const uast::Dot*, ID, UniqueString, ID, bool) diff --git a/frontend/include/chpl/resolution/resolution-queries.h b/frontend/include/chpl/resolution/resolution-queries.h index feb35562e427..dffdc8a0b41e 100644 --- a/frontend/include/chpl/resolution/resolution-queries.h +++ b/frontend/include/chpl/resolution/resolution-queries.h @@ -21,6 +21,7 @@ #define CHPL_RESOLUTION_RESOLUTION_QUERIES_H #include "chpl/resolution/resolution-types.h" +#include "chpl/resolution/interface-types.h" #include "chpl/resolution/scope-types.h" namespace chpl { @@ -34,6 +35,13 @@ namespace resolution { */ const ResolutionResultByPostorderID& resolveModuleStmt(Context* context, ID id); +/** + Specialized version of resolveModuleStmt when the statement is an + 'implements'. This does the work of constructing an 'ImplementationPoint'. + */ +const ImplementationPoint* resolveImplementsStmt(Context* rc, + ID id); + /** Resolve the contents of a Module */ @@ -122,12 +130,25 @@ typedSignatureInitial(ResolutionContext* rc, const UntypedFnSignature* untyped); const TypedFnSignature* typedSignatureInitialForId(ResolutionContext* rc, ID id); +/** + Compute an initial TypedFnSignature, but using placeholder types for + type queries and "any type" markers. This TypedFnSignature can serve + as a template for satisfying interface. + */ +const TypedFnSignature* +typedSignatureTemplateForId(ResolutionContext* rc, ID id); + /** Returns a Type that represents the initial type provided by a TypeDecl (e.g. Class, Record, etc). This type does not store the fields. */ const types::Type* initialTypeForTypeDecl(Context* context, ID declId); +/** + Returns a Type that represents the initial type provided by an Interface + declaration. */ +const types::Type* initialTypeForInterface(Context* context, ID declId); + /** Resolve a single field decl (which could be e.g. a MultiDecl) within a CompositeType. @@ -276,6 +297,16 @@ const ResolvedFunction* resolveFunction(ResolutionContext* rc, const TypedFnSignature* sig, const PoiScope* poiScope); + +/** + Given a scope corresponding to a module, find all visible + implementation points for a particular interface. + */ +const std::vector* +visibileImplementationPointsForInterface(Context* context, + const Scope* scope, + ID interfaceId); + /** Helper to resolve a concrete function using the above queries. Will return `nullptr` if the function is generic or has a `where false`. @@ -460,6 +491,30 @@ const TypedFnSignature* tryResolveDeinit(Context* context, const types::Type* t, const PoiScope* poiScope = nullptr); +/** + Given an instantiated interface constraint, such as 'hashable(int)', + search an implementation point that matches and verify its validity. + If no matching implementation point is found, returns nullptr. + + An implementation point can be invalid if it accepts the expected actuals + from the interface, but the types do not provide the required functions, + associated types, etc. + */ +const ImplementationWitness* findMatchingImplementationPoint(ResolutionContext* rc, + const types::InterfaceType* ift, + const CallScopeInfo& inScopes); + +/** + Given the location of an implementation point, check that the constraints + of the interface are satisfied at that position. This is used as part of + 'findMatchingImplementationPoint', but can be used standalone if a desired + implementation point is already known. + */ +const ImplementationWitness* checkInterfaceConstraints(ResolutionContext* rc, + const types::InterfaceType* ift, + const ID& implPointId, + const CallScopeInfo& inScopes); + /** Given a type 't', compute whether or not 't' is default initializable. If 't' is a generic type, it is considered non-default-initializable. diff --git a/frontend/include/chpl/resolution/resolution-types.h b/frontend/include/chpl/resolution/resolution-types.h index 249681c9990d..ba59c514ad97 100644 --- a/frontend/include/chpl/resolution/resolution-types.h +++ b/frontend/include/chpl/resolution/resolution-types.h @@ -45,6 +45,10 @@ namespace resolution { using SubstitutionsMap = types::CompositeType::SubstitutionsMap; +SubstitutionsMap substituteInMap(Context* context, + const SubstitutionsMap& substituteIn, + const types::PlaceholderMap& subs); + /** In some situations, we may decide not to resolve a call. This could @@ -945,6 +949,9 @@ class TypedFnSignature { std::vector formalTypes, const TypedFnSignature* inferredFrom); + const TypedFnSignature* substitute(Context* context, + const types::PlaceholderMap& subs) const; + bool operator==(const TypedFnSignature& other) const { return untypedSignature_ == other.untypedSignature_ && formalTypes_ == other.formalTypes_ && @@ -1246,6 +1253,8 @@ enum CandidateFailureReason { FAIL_WHERE_CLAUSE, /* A parenful call to a parenless function or vice versa. */ FAIL_PARENLESS_MISMATCH, + /* An interface tried to resolve an associated type function but it didn't return a type */ + FAIL_INTERFACE_NOT_TYPE_INTENT, /* Some other, generic reason. */ FAIL_CANDIDATE_OTHER, }; diff --git a/frontend/include/chpl/types/ArrayType.h b/frontend/include/chpl/types/ArrayType.h index 9241eff73f46..bd8097589b7c 100644 --- a/frontend/include/chpl/types/ArrayType.h +++ b/frontend/include/chpl/types/ArrayType.h @@ -72,6 +72,13 @@ class ArrayType final : public CompositeType { const QualifiedType& domainType, const QualifiedType& eltType); + const Type* substitute(Context* context, + const PlaceholderMap& subs) const override { + return getArrayType(context, + domainType().substitute(context, subs), + eltType().substitute(context, subs)); + } + QualifiedType domainType() const { auto it = subs_.find(domainId); if (it != subs_.end()) { diff --git a/frontend/include/chpl/types/BasicClassType.h b/frontend/include/chpl/types/BasicClassType.h index 6aae0c7005a6..630971d75886 100644 --- a/frontend/include/chpl/types/BasicClassType.h +++ b/frontend/include/chpl/types/BasicClassType.h @@ -20,6 +20,7 @@ #ifndef CHPL_TYPES_BASIC_CLASS_TYPE_H #define CHPL_TYPES_BASIC_CLASS_TYPE_H +#include "chpl/resolution/resolution-types.h" #include "chpl/types/ManageableType.h" #include "chpl/framework/global-strings.h" @@ -74,6 +75,14 @@ class BasicClassType final : public ManageableType { const BasicClassType* instantiatedFrom, CompositeType::SubstitutionsMap subs); + const Type* substitute(Context* context, + const PlaceholderMap& subs) const override { + return get(context, id(), name(), + Type::substitute(context, parentType_, subs), + Type::substitute(context, (const BasicClassType*) instantiatedFrom_, subs), + resolution::substituteInMap(context, subs_, subs)); + } + static const BasicClassType* getRootClassType(Context* context); static const BasicClassType* getReduceScanOpType(Context* context); diff --git a/frontend/include/chpl/types/CPtrType.h b/frontend/include/chpl/types/CPtrType.h index c45df14e940c..8be9c65a3289 100644 --- a/frontend/include/chpl/types/CPtrType.h +++ b/frontend/include/chpl/types/CPtrType.h @@ -62,6 +62,11 @@ class CPtrType final : public PtrType { return isConst() ? getConstId(context) : getId(context); } + const Type* substitute(Context* context, + const PlaceholderMap& subs) const override { + return get(context, Type::substitute(context, eltType_, subs)); + } + const CPtrType* withoutConst(Context* context) const; bool isConst() const { return isConst_; } diff --git a/frontend/include/chpl/types/ClassType.h b/frontend/include/chpl/types/ClassType.h index 37d60cd2a514..ba010d6170e2 100644 --- a/frontend/include/chpl/types/ClassType.h +++ b/frontend/include/chpl/types/ClassType.h @@ -83,6 +83,13 @@ class ClassType final : public Type { const Type* manager, ClassTypeDecorator decorator); + const Type* substitute(Context* context, + const PlaceholderMap& subs) const override { + return get(context, + Type::substitute(context, manageableType_, subs), + Type::substitute(context, manager_, subs), decorator_); + } + /** Returns the ClassTypeDecorator for this ClassType. This decorator indicates the memory management strategy. */ ClassTypeDecorator decorator() const { return decorator_; } diff --git a/frontend/include/chpl/types/CompositeType.h b/frontend/include/chpl/types/CompositeType.h index 4b529809f289..42352422c92e 100644 --- a/frontend/include/chpl/types/CompositeType.h +++ b/frontend/include/chpl/types/CompositeType.h @@ -126,7 +126,7 @@ class CompositeType : public Type { bool compositeTypeContentsMatchInner(const CompositeType* other) const { return id_ == other->id_ && - name_ != other->name_ && + name_ == other->name_ && instantiatedFrom_ == other->instantiatedFrom_ && subs_ == other->subs_; } @@ -149,6 +149,11 @@ class CompositeType : public Type { public: virtual ~CompositeType() = 0; // this is an abstract base class + /* print the substitutions map like it would be printed for a composite type. */ + static void stringifySubstitutions(std::ostream& ss, + chpl::StringifyKind stringKind, + const SubstitutionsMap& subs); + virtual void stringify(std::ostream& ss, chpl::StringifyKind stringKind) const override; diff --git a/frontend/include/chpl/types/DomainType.h b/frontend/include/chpl/types/DomainType.h index aeda1ae87492..8674111bba70 100644 --- a/frontend/include/chpl/types/DomainType.h +++ b/frontend/include/chpl/types/DomainType.h @@ -21,6 +21,7 @@ #define CHPL_TYPES_DOMAIN_TYPE_H #include "chpl/types/CompositeType.h" +#include "chpl/resolution/resolution-types.h" namespace chpl { namespace types { @@ -101,6 +102,14 @@ class DomainType final : public CompositeType { const QualifiedType& idxType, const QualifiedType& parSafe); + const Type* substitute(Context* context, + const PlaceholderMap& subs) const override { + return getDomainType(context, id(), name(), + Type::substitute(context, (const DomainType*) instantiatedFrom_, subs), + resolution::substituteInMap(context, subs_, subs), + kind_).get(); + } + /** Get the default distribution type */ static const QualifiedType& getDefaultDistType(Context* context); diff --git a/frontend/include/chpl/types/FnIteratorType.h b/frontend/include/chpl/types/FnIteratorType.h index d6846ba68e41..3884a66d9c4d 100644 --- a/frontend/include/chpl/types/FnIteratorType.h +++ b/frontend/include/chpl/types/FnIteratorType.h @@ -60,6 +60,11 @@ class FnIteratorType final : public IteratorType { const resolution::PoiScope* poiScope, const resolution::TypedFnSignature* iteratorFn); + virtual const Type* substitute(Context* context, + const PlaceholderMap& subs) const override { + return get(context, poiScope_, iteratorFn_->substitute(context, subs)); + } + const resolution::TypedFnSignature* iteratorFn() const { return iteratorFn_; } diff --git a/frontend/include/chpl/types/HeapBufferType.h b/frontend/include/chpl/types/HeapBufferType.h index fbffd57a0518..75fbca251d6e 100644 --- a/frontend/include/chpl/types/HeapBufferType.h +++ b/frontend/include/chpl/types/HeapBufferType.h @@ -52,6 +52,11 @@ class HeapBufferType final : public PtrType { return getId(context); } + const Type* substitute(Context* context, + const PlaceholderMap& subs) const override { + return get(context, Type::substitute(context, eltType_, subs)); + } + virtual void stringify(std::ostream& ss, chpl::StringifyKind stringKind) const override; }; diff --git a/frontend/include/chpl/types/InterfaceType.h b/frontend/include/chpl/types/InterfaceType.h new file mode 100644 index 000000000000..4b65cb0dd406 --- /dev/null +++ b/frontend/include/chpl/types/InterfaceType.h @@ -0,0 +1,132 @@ +/* + * Copyright 2024 Hewlett Packard Enterprise Development LP + * Other additional copyright holders may be indicated within. + * + * The entirety of this work is licensed under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef CHPL_TYPES_INTERFACE_TYPE_H +#define CHPL_TYPES_INTERFACE_TYPE_H + +#include "chpl/resolution/resolution-types.h" +#include "chpl/types/Type.h" +#include "chpl/types/CompositeType.h" +#include "chpl/types/QualifiedType.h" + +namespace chpl { +namespace types { + +/* + A type representing an interface constraint, either generic (as in + 'hashable') or instantiated with a set of types as actuals (e.g., + 'hashable(int)'). + + This type does not mean that the actuals ('int' in the above example) + satisfy the interface; rather, it can be thought of as a claim that + the interface is satisfied, which may be true or false. + + The following cases both create instances of InterfaceType: + + record R : I {} // the 'I' has type InterfaceType(name='I', ...) + implements R(I); // the 'R(I)' has type InterfaceType(name='I', subs = {... => R }) + +*/ +class InterfaceType final : public Type { + public: + using SubstitutionsMap = CompositeType::SubstitutionsMap; + + private: + // The ID of the interface's declaration + ID id_; + // The name of the interface, for convenience + UniqueString name_; + // The substitutions for the interface's formals. If the interface is + // generic, this will be empty; otherwise, this will have exactly one + // substitution for each of the interface's formals. + SubstitutionsMap subs_; + + // check that the substitutions are valid for the given interface. + // executed from an assertion, so does not haev an impact in release mode. + static bool validateSubstitutions(Context* context, + const ID& id, + SubstitutionsMap& subs); + + InterfaceType(ID id, UniqueString name, SubstitutionsMap subs) + : Type(typetags::InterfaceType), id_(std::move(id)), name_(std::move(name)), + subs_(std::move(subs)) {} + + bool contentsMatchInner(const Type* other) const override { + auto rhs = (const InterfaceType*) other; + return id_ == rhs->id_ && + name_ == rhs->name_ && + subs_ == rhs->subs_; + } + + void markUniqueStringsInner(Context* context) const override { + id_.mark(context); + name_.mark(context); + ::chpl::mark{}(context, subs_); + } + + // computed from substitutions, c.f. CompositeType + Genericity genericity() const override { + return MAYBE_GENERIC; + } + + static owned const& + getInterfaceType(Context* context, + ID id, + UniqueString name, + SubstitutionsMap subs); + + public: + static const InterfaceType* get(Context* context, + ID id, + UniqueString name, + SubstitutionsMap subs); + + /* Get an interface type, assigning the types in the list to the interface's + formals in order. If the number of types does not match the interface's + number of formals, returns nullptr. + */ + static const InterfaceType* withTypes(Context* context, + const InterfaceType* ift, + std::vector types); + + /** Returns true if 'this' is an instantiation of genericType */ + bool isInstantiationOf(Context* context, + const InterfaceType* genericType) const; + + const Type* substitute(Context* context, + const PlaceholderMap& subs) const override { + return get(context, id_, name_, resolution::substituteInMap(context, subs_, subs)); + } + + /* Get the ID of the interface. */ + const ID& id() const { return id_; } + + /* Get the name of the interface. */ + UniqueString name() const { return name_; } + + /* Get the substitutions for the interface (mapping formals to thir types). */ + const SubstitutionsMap& substitutions() const { return subs_; } + + void stringify(std::ostream& ss, StringifyKind stringKind) const override; +}; + +} // end namespace types +} // end namespace chpl + +#endif diff --git a/frontend/include/chpl/types/LoopExprIteratorType.h b/frontend/include/chpl/types/LoopExprIteratorType.h index 9dd9306ec34d..82da82853f0f 100644 --- a/frontend/include/chpl/types/LoopExprIteratorType.h +++ b/frontend/include/chpl/types/LoopExprIteratorType.h @@ -114,6 +114,13 @@ class LoopExprIteratorType final : public IteratorType { QualifiedType iterand, ID sourceLocation); + const Type* substitute(Context* context, + const PlaceholderMap& subs) const override { + return get(context, yieldType_.substitute(context, subs), + poiScope_, isZippered_, supportsParallel_, + iterand_.substitute(context, subs), sourceLocation_); + } + const QualifiedType& yieldType() const { return yieldType_; } diff --git a/frontend/include/chpl/types/PlaceholderType.h b/frontend/include/chpl/types/PlaceholderType.h new file mode 100644 index 000000000000..d1580380bc8f --- /dev/null +++ b/frontend/include/chpl/types/PlaceholderType.h @@ -0,0 +1,111 @@ +/* + * Copyright 2024 Hewlett Packard Enterprise Development LP + * Other additional copyright holders may be indicated within. + * + * The entirety of this work is licensed under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef CHPL_TYPES_PLACEHOLDER_TYPE_H +#define CHPL_TYPES_PLACEHOLDER_TYPE_H + +#include "chpl/types/Type.h" +#include "chpl/resolution/resolution-types.h" + +namespace chpl { +namespace types { + +/* An opaque placeholder type that was introduced for some AST node. + Currently, these are used for the interface resolution process, where + they serve both as temporary types for interface formals like 'Self', + and as the types of 'type queries'. In the following code: + + interface I { + proc Self.foo(x: ?t1, y: ?t2); + } + + All of 'Self', 't1', and 't2' will have type 'PlaceholderType(..)', with IDs given + by the (implicit) formal declaration and the type queries, respectively. + + Because placeholder types are unique (created from a particular position + in the AST), they cannot be passed to any concrete function argument. This + means that when searcing for witnesses for 'foo', functions with concrete 'x' + will always be rejected (as desired). Moreover, as part of checking applicabiity, + placeholder types will be substituted for type queries in the candidate + (non-interface) function, and thus enforce a matching pattern of genericity. + + As an example, the following function will match the constraint above: + + proc R.foo(x, y) {} + + because 'x' will be instantiated with 'PlaceholderType(t1)', and + 'y' with 'PlaceholderType(t2)'. However, the following function will not match: + + proc R.foo(x, y: x.type) {} + + Because the type of 'y' will be 'PlaceholderType(t1)', which is not + equal to 'PlaceholderType(t2)'. + + Placeholder types can be eliminated from types using the 'substitute' method, + which will replace them with the corresponding "real" type. + */ +class PlaceholderType final : public Type { + private: + // the syntactic ID for which this placeholder type was created + ID id_; + + PlaceholderType(ID id) + : Type(typetags::PlaceholderType), id_(std::move(id)) {} + + bool contentsMatchInner(const Type* other) const override { + auto rhs = (const PlaceholderType*) other; + return id_ == rhs->id_; + } + + void markUniqueStringsInner(Context* context) const override { + id_.mark(context); + } + + Genericity genericity() const override { + return CONCRETE; + } + + static owned const& + getPlaceholderType(Context* context, + ID id); + + public: + static const PlaceholderType* get(Context* context, + ID id); + + const Type* substitute(Context* context, + const PlaceholderMap& subs) const override { + auto it = subs.find(id_); + if (it != subs.end()) { + return it->second; + } + return this; + } + + /* Get the ID for this placeholder type. */ + const ID& id() const { return id_; } + + void stringify(std::ostream& ss, + chpl::StringifyKind stringKind) const override; +}; + +} // end namespace types +} // end namespace chpl + +#endif diff --git a/frontend/include/chpl/types/PromotionIteratorType.h b/frontend/include/chpl/types/PromotionIteratorType.h index 3c1d650b32ec..a97df0feb0ef 100644 --- a/frontend/include/chpl/types/PromotionIteratorType.h +++ b/frontend/include/chpl/types/PromotionIteratorType.h @@ -69,6 +69,12 @@ class PromotionIteratorType final : public IteratorType { const resolution::TypedFnSignature* scalarFn, resolution::SubstitutionsMap promotedFormals); + virtual const Type* substitute(Context* context, + const PlaceholderMap& subs) const override { + return get(context, poiScope_, scalarFn_->substitute(context, subs), + resolution::substituteInMap(context, promotedFormals_, subs)); + } + const resolution::TypedFnSignature* scalarFn() const { return scalarFn_; } diff --git a/frontend/include/chpl/types/QualifiedType.h b/frontend/include/chpl/types/QualifiedType.h index fe7e9ed6c04c..ec149ada967b 100644 --- a/frontend/include/chpl/types/QualifiedType.h +++ b/frontend/include/chpl/types/QualifiedType.h @@ -90,6 +90,13 @@ class QualifiedType final { CHPL_ASSERT(param_ == nullptr || kind_ == Kind::PARAM); } + /** replaces placeholders (as in PlaceholderType in the type) according + to their values in the 'subs' map. See also Type::substitute. */ + const QualifiedType substitute(Context* context, + const PlaceholderMap& subs) const { + return QualifiedType(kind_, Type::substitute(context, type_, subs), param_); + } + /** Returns the kind of the expression this QualifiedType represents */ Kind kind() const { return kind_; } /** diff --git a/frontend/include/chpl/types/RecordType.h b/frontend/include/chpl/types/RecordType.h index 29d97ce499ea..8bde58934133 100644 --- a/frontend/include/chpl/types/RecordType.h +++ b/frontend/include/chpl/types/RecordType.h @@ -20,6 +20,7 @@ #ifndef CHPL_TYPES_RECORD_TYPE_H #define CHPL_TYPES_RECORD_TYPE_H +#include "chpl/resolution/resolution-types.h" #include "chpl/types/CompositeType.h" namespace chpl { @@ -59,6 +60,13 @@ class RecordType final : public CompositeType { const RecordType* instantiatedFrom, CompositeType::SubstitutionsMap subs); + const Type* substitute(Context* context, + const PlaceholderMap& subs) const override { + return get(context, id(), name(), + Type::substitute(context, (const RecordType*) instantiatedFrom_, subs), + resolution::substituteInMap(context, subs_, subs)); + } + ~RecordType() = default; diff --git a/frontend/include/chpl/types/TupleType.h b/frontend/include/chpl/types/TupleType.h index a17a8ad843db..07e50d79abae 100644 --- a/frontend/include/chpl/types/TupleType.h +++ b/frontend/include/chpl/types/TupleType.h @@ -21,6 +21,7 @@ #define CHPL_TYPES_TUPLE_TYPE_H #include "chpl/types/CompositeType.h" +#include "chpl/resolution/resolution-types.h" namespace chpl { namespace types { @@ -100,6 +101,14 @@ class TupleType final : public CompositeType { QualifiedType paramSize, QualifiedType starEltType); + const Type* substitute(Context* context, + const PlaceholderMap& subs) const override { + return getTupleType(context, + Type::substitute(context, (const TupleType*) instantiatedFrom_, subs), + resolution::substituteInMap(context, subs_, subs), + isVarArgTuple_).get(); + } + /** Return the generic tuple type `_tuple` */ static const TupleType* getGenericTupleType(Context* context); diff --git a/frontend/include/chpl/types/Type.h b/frontend/include/chpl/types/Type.h index dd80157d0a59..6cd493a54af6 100644 --- a/frontend/include/chpl/types/Type.h +++ b/frontend/include/chpl/types/Type.h @@ -25,13 +25,16 @@ #include "chpl/framework/mark-functions.h" #include "chpl/types/TypeTag.h" #include "chpl/uast/Pragma.h" +#include "chpl/util/hash.h" #include namespace chpl { + namespace uast { class Decl; } + namespace types { @@ -54,6 +57,7 @@ namespace types { #undef TYPE_DECL class Type; +using PlaceholderMap = std::unordered_map; namespace detail { @@ -159,6 +163,31 @@ class Type { bool completeMatch(const Type* other) const; + /** replaces placeholders (as in PlaceholderType in the type) according + to their values in the 'subs' map. */ + virtual const Type* substitute(Context* context, + const PlaceholderMap& subs) const { + return this; + } + + /** For a given subclass of 'Type', replaces placeholders (as in + PlaceholderType in the type) according to their values in the 'subs' map, + handling the case in which the type is null. + + Since replacing placeholders ought not to change which subclass + the type is, asserts and casts the result back to the same subclass. */ + template + static const TargetType* substitute(Context* context, + const TargetType* type, + const PlaceholderMap& subs) { + if (!type) return type; + auto substituted = type->substitute(context, subs); + CHPL_ASSERT(substituted); + auto cast = substituted->template to(); + CHPL_ASSERT(cast); + return cast; + } + virtual void stringify(std::ostream& ss, chpl::StringifyKind stringKind) const; /** Check if this type is particular subclass. The call someType->is() @@ -323,6 +352,11 @@ class Type { namespace detail { +template <> +inline bool typeIs(const Type* type) { + return true; +} + /// \cond DO_NOT_DOCUMENT #define TYPE_IS(NAME) \ template <> \ @@ -343,6 +377,16 @@ namespace detail { #undef TYPE_END_SUBCLASSES #undef TYPE_IS +template <> +inline const Type* typeToConst(const Type* type) { + return type; +} + +template <> +inline Type* typeTo(Type* type) { + return type; +} + /// \cond DO_NOT_DOCUMENT #define TYPE_TO(NAME) \ template <> \ @@ -395,4 +439,12 @@ template<> struct mark { // TODO: is there a reasonable way to define std::less on Type*? // Comparing pointers would lead to some nondeterministic ordering. +namespace std { + template<> struct hash { + inline size_t operator()(const chpl::types::PlaceholderMap& k) const{ + return chpl::hashUnorderedMap(k); + } + }; +} + #endif diff --git a/frontend/include/chpl/types/UnionType.h b/frontend/include/chpl/types/UnionType.h index 34f6062785cf..263affa34b5f 100644 --- a/frontend/include/chpl/types/UnionType.h +++ b/frontend/include/chpl/types/UnionType.h @@ -20,6 +20,7 @@ #ifndef CHPL_TYPES_UNION_TYPE_H #define CHPL_TYPES_UNION_TYPE_H +#include "chpl/resolution/resolution-types.h" #include "chpl/types/CompositeType.h" namespace chpl { @@ -56,6 +57,13 @@ class UnionType final : public CompositeType { const UnionType* instantiatedFrom, CompositeType::SubstitutionsMap subs); + const Type* substitute(Context* context, + const PlaceholderMap& subs) const override { + return get(context, id_, name_, + Type::substitute(context, (UnionType*) instantiatedFrom_, subs), + resolution::substituteInMap(context, subs_, subs)); + } + ~UnionType() = default; /** If this type represents an instantiated type, diff --git a/frontend/include/chpl/types/all-types.h b/frontend/include/chpl/types/all-types.h index 9f7e61a31d1b..0b3676a004d8 100644 --- a/frontend/include/chpl/types/all-types.h +++ b/frontend/include/chpl/types/all-types.h @@ -36,12 +36,14 @@ #include "chpl/types/FnIteratorType.h" #include "chpl/types/HeapBufferType.h" #include "chpl/types/ImagType.h" +#include "chpl/types/InterfaceType.h" #include "chpl/types/IntType.h" #include "chpl/types/IteratorType.h" #include "chpl/types/LoopExprIteratorType.h" #include "chpl/types/NilType.h" #include "chpl/types/NothingType.h" #include "chpl/types/Param.h" +#include "chpl/types/PlaceholderType.h" #include "chpl/types/PrimitiveType.h" #include "chpl/types/PtrType.h" #include "chpl/types/PromotionIteratorType.h" diff --git a/frontend/include/chpl/types/type-classes-list.h b/frontend/include/chpl/types/type-classes-list.h index 91c3d347b6f2..72c28454fc13 100644 --- a/frontend/include/chpl/types/type-classes-list.h +++ b/frontend/include/chpl/types/type-classes-list.h @@ -44,6 +44,7 @@ TYPE_NODE(CStringType) TYPE_NODE(ErroneousType) TYPE_NODE(NilType) TYPE_NODE(NothingType) +TYPE_NODE(PlaceholderType) TYPE_NODE(UnknownType) TYPE_NODE(VoidType) @@ -100,6 +101,7 @@ TYPE_BEGIN_SUBCLASSES(DeclaredType) TYPE_NODE(EnumType) TYPE_NODE(ExternType) TYPE_NODE(FunctionType) + TYPE_NODE(InterfaceType) TYPE_BEGIN_SUBCLASSES(CompositeType) TYPE_NODE(ArrayType) diff --git a/frontend/include/chpl/uast/Interface.h b/frontend/include/chpl/uast/Interface.h index b4fd7dd0b305..11a2ec83788a 100644 --- a/frontend/include/chpl/uast/Interface.h +++ b/frontend/include/chpl/uast/Interface.h @@ -168,7 +168,7 @@ class Interface final : public NamedDecl { Return the i'th interface formal. */ const AstNode* formal(int i) const { - CHPL_ASSERT(i >= 0 && i < numBodyStmts_); + CHPL_ASSERT(i >= 0 && i < numInterfaceFormals_); auto ret = child(i + interfaceFormalsChildNum_); CHPL_ASSERT(ret); return ret; diff --git a/frontend/include/chpl/util/hash.h b/frontend/include/chpl/util/hash.h index d258f49a12fc..8a0572c0a394 100644 --- a/frontend/include/chpl/util/hash.h +++ b/frontend/include/chpl/util/hash.h @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -128,6 +129,20 @@ inline size_t hashSet(const std::set& key) { return ret; } +template +struct FirstElementComparator { + bool operator()(const std::pair& a, const std::pair& b) const { + return a.first < b.first; + } +}; + +template +inline size_t hashUnorderedMap(const std::unordered_map& key) { + std::vector> sorted(key.begin(), key.end()); + std::sort(sorted.begin(), sorted.end(), FirstElementComparator()); + return hashVector(sorted); +} + template inline size_t hashOwned(const chpl::owned& key) { size_t ret = 0; @@ -145,8 +160,8 @@ inline size_t hashPair(const std::pair& key) { return ret; } -template -inline size_t hashMap(const std::map& key) { +template +inline size_t hashMap(const Map& key) { size_t ret = 0; // Just iterate and hash, relying on std::map being a sorted container. diff --git a/frontend/lib/parsing/parsing-queries.cpp b/frontend/lib/parsing/parsing-queries.cpp index 2358f3dd01da..050beca70c8e 100644 --- a/frontend/lib/parsing/parsing-queries.cpp +++ b/frontend/lib/parsing/parsing-queries.cpp @@ -1337,21 +1337,33 @@ bool idIsNestedFunction(Context* context, ID id) { if (id.isEmpty() || !idIsFunction(context, id)) return false; for (auto up = id.parentSymbolId(context); up; up = up.parentSymbolId(context)) { - if (idIsFunction(context, up)) return true; + if (idIsFunction(context, up) || idIsInterface(context, up)) return true; } return false; } -bool idIsFunction(Context* context, ID id) { - // Functions always have their own ID symbol scope, - // and if it's not a function, we can return false - // without doing further work. +template +bool idIsSymbolDefiningScope(Context* context, ID id, Predicate&& predicate) { if (!id.isSymbolDefiningScope()) { return false; } AstTag tag = idToTag(context, id); - return asttags::isFunction(tag); + return predicate(tag); +} + +bool idIsFunction(Context* context, ID id) { + // Functions always have their own ID symbol scope, + // and if it's not a function, we can return false + // without doing further work. + return idIsSymbolDefiningScope(context, id, asttags::isFunction); +} + +bool idIsInterface(Context* context, ID id) { + // Interfaces always have their own ID symbol scope, + // and if it's not an interface, we can return false + // without doing further work. + return idIsSymbolDefiningScope(context, id, asttags::isInterface); } static bool @@ -1492,16 +1504,25 @@ const ID& idToParentId(Context* context, ID id) { return QUERY_END(result); } -ID idToParentFunctionId(Context* context, ID id) { +template +ID idToParentSymbolId(Context* context, ID id, Predicate&& predicate) { if (id.isEmpty()) return {}; for (auto up = id; up; up = up.parentSymbolId(context)) { if (up == id) continue; - // Get the first parent function (a parent could be a record/class/etc). - if (parsing::idIsFunction(context, up)) return up; + // Get the first matching symbol + if (predicate(context, up)) return up; } return {}; } +ID idToParentFunctionId(Context* context, ID id) { + return idToParentSymbolId(context, id, parsing::idIsFunction); +} + +ID idToParentInterfaceId(Context* context, ID id) { + return idToParentSymbolId(context, id, parsing::idIsInterface); +} + const uast::AstNode* parentAst(Context* context, const uast::AstNode* node) { if (node == nullptr) return nullptr; auto parentId = idToParentId(context, node->id()); diff --git a/frontend/lib/resolution/CMakeLists.txt b/frontend/lib/resolution/CMakeLists.txt index 959f34a51298..1d300fe2d8af 100644 --- a/frontend/lib/resolution/CMakeLists.txt +++ b/frontend/lib/resolution/CMakeLists.txt @@ -29,6 +29,7 @@ target_sources(ChplFrontend-obj disambiguation.cpp extern-blocks.cpp intents.cpp + interface-types.cpp maybe-const.cpp prims.cpp ResolutionContext.cpp diff --git a/frontend/lib/resolution/ResolutionContext.cpp b/frontend/lib/resolution/ResolutionContext.cpp index 79935aa110e6..21c983227604 100644 --- a/frontend/lib/resolution/ResolutionContext.cpp +++ b/frontend/lib/resolution/ResolutionContext.cpp @@ -19,6 +19,7 @@ #include "chpl/resolution/ResolutionContext.h" #include "chpl/resolution/resolution-types.h" +#include "chpl/types/PlaceholderType.h" #include "Resolver.h" namespace chpl { @@ -59,7 +60,10 @@ canUseGlobalCache(Context* context, const MatchingIdsWithName& ids) { } const ID& ResolutionContext::Frame::id() const { - if (auto ast = rv_->symbol) return ast->id(); + if (rv_) { + if (auto ast = rv_->symbol) return ast->id(); + } + if (ift_) return ift_->id(); return EMPTY_AST_ID; } @@ -70,9 +74,40 @@ const TypedFnSignature* ResolutionContext::Frame::signature() const { return nullptr; } -const ResolutionResultByPostorderID* -ResolutionContext::Frame::resolutionById() const { - return rv_ ? &rv_->byPostorder : nullptr; +static types::QualifiedType placeholderForId(ResolutionContext* rc, const ID& id) { + return types::QualifiedType(types::QualifiedType::TYPE, + types::PlaceholderType::get(rc->context(), id)); +} + +const types::QualifiedType +ResolutionContext::Frame::typeForContainedId(ResolutionContext* rc, const ID& id) const { + if (rv_) { + return rv_->byPostorder.byId(id).type(); + } + + // For interfaces, just return placeholders for the associated types + // and the interface parameters. + // + // In the future, to make generics interoperate with interface-based + // generics, we will need to do more here. This is pending design discussion, + // though. + if (ift_) { + auto subIt = ift_->substitutions().find(id); + if (subIt != ift_->substitutions().end()) + return placeholderForId(rc, subIt->first); + + if (witness_) { + // search associated types + auto atIt = witness_->associatedTypes().find(id); + if (atIt != witness_->associatedTypes().end()) + return placeholderForId(rc, atIt->first); + + // TODO: search additional constraints, required functions? + } + } else { + CHPL_ASSERT(witness_ == nullptr); + } + return types::QualifiedType(); } const ResolutionContext::Frame* ResolutionContext:: @@ -93,13 +128,23 @@ pushFrame(const ResolvedFunction* rf) { return ret; } +const ResolutionContext::Frame* ResolutionContext:: +pushFrame(const types::InterfaceType* ift, const ImplementationWitness* witness) { + int64_t index = (int64_t) frames_.size(); + frames_.push_back({ift, witness, index}); + auto ret = lastFrame(); + if (ret->isUnstable()) numUnstableFrames_++; + return ret; +} + bool ResolutionContext::Frame::isUnstable() const { - return rv_ != nullptr; + return rv_ != nullptr || ift_ != nullptr || witness_ != nullptr; } void ResolutionContext::popFrame(Resolver* rv) { CHPL_ASSERT(!frames_.empty() && "Frame stack underflow!"); - CHPL_ASSERT(frames_.back().rv() == rv); + CHPL_ASSERT(frames_.back().rv() == rv || + frames_.back().ift()); if (frames_.empty()) return; if (frames_.back().isUnstable()) numUnstableFrames_--; diff --git a/frontend/lib/resolution/Resolver.cpp b/frontend/lib/resolution/Resolver.cpp index b947214fd362..af744055a191 100644 --- a/frontend/lib/resolution/Resolver.cpp +++ b/frontend/lib/resolution/Resolver.cpp @@ -243,6 +243,31 @@ Resolver::createForModuleStmt(ResolutionContext* rc, const Module* mod, return ret; } +Resolver +Resolver::createForInterfaceStmt(ResolutionContext* rc, + const uast::Interface* interface, + const types::InterfaceType* ift, + const ImplementationWitness* witness, + const uast::AstNode* stmt, + ResolutionResultByPostorderID& byPostorder) { + const AstNode* symbol = interface; + const Block* fnBody = nullptr; + if (auto fn = stmt->toFunction()) { + symbol = fn; + fnBody = fn->body(); + } + + auto ret = Resolver(rc->context(), symbol, byPostorder, nullptr); + ret.curStmt = stmt; + ret.byPostorder.setupForSymbol(symbol); + ret.rc = rc; + ret.signatureOnly = true; + ret.fnBody = fnBody; + rc->pushFrame(ift, witness); + ret.didPushFrame = true; + return ret; +} + Resolver Resolver::createForScopeResolvingModuleStmt( Context* context, const Module* mod, @@ -612,7 +637,7 @@ isOuterVariable(Resolver& rv, const Identifier* ident, const ID& target) { auto tag = parsing::idToTag(context, targetParentSymbolId); - if (tag == asttags::Function) return true; + if (tag == asttags::Function || tag == asttags::Interface) return true; if (tag == asttags::Module) { // Module-scope variables are not considered outer-variables. However, @@ -1527,6 +1552,14 @@ static QualifiedType computeTypeDefaults(Resolver& resolver, return type; } +static const Type* getAnyType(Resolver& resolver, const ID& anchor) { + // If we use placeholders, we don't create 'AnyTypes' anywhere, + // and instead invent new placeholder types. + return resolver.usePlaceholders + ? PlaceholderType::get(resolver.context, anchor)->to() + : AnyType::get(resolver.context)->to(); +} + // useType will be used to set the type if it is not nullptr void Resolver::resolveNamedDecl(const NamedDecl* decl, const Type* useType) { if (scopeResolveOnly) @@ -1690,7 +1723,7 @@ void Resolver::resolveNamedDecl(const NamedDecl* decl, const Type* useType) { // primary method. This does not, however, mean that its type should be // AnyType; it is not adjusted here. - typeExprT = QualifiedType(QualifiedType::TYPE, AnyType::get(context)); + typeExprT = QualifiedType(QualifiedType::TYPE, getAnyType(*this, decl->id())); } else if (isFieldOrFormal) { // figure out if we should potentially infer the type from the init expr // (we do so if it's not a field or a formal) @@ -2228,8 +2261,11 @@ void Resolver::resolveTupleDecl(const TupleDecl* td, // Note: we seem to rely on tuple components being 'var', and relying on // the tuple's kind instead. Without this, the current instantiation // logic won't allow, for example, passing (1, 2, 3) to (?, ?, ?). - auto anyType = QualifiedType(QualifiedType::VAR, AnyType::get(context)); - std::vector eltTypes(td->numDecls(), anyType); + std::vector eltTypes; + for (auto decl : td->decls()) { + eltTypes.push_back(QualifiedType(QualifiedType::VAR, + getAnyType(*this, decl->id()))); + } auto tup = TupleType::getQualifiedTuple(context, eltTypes); useT = QualifiedType(declKind, tup); } else { @@ -2650,6 +2686,9 @@ QualifiedType Resolver::typeForId(const ID& id, bool localGenericToUnknown) { return QualifiedType(QualifiedType::TYPE, t); } else if (asttags::isModule(tag)) { return QualifiedType(QualifiedType::MODULE, nullptr); + } else if (asttags::isInterface(tag)) { + const Type* t = initialTypeForInterface(context, id); + return QualifiedType(QualifiedType::TYPE, t); } if (asttags::isFunction(tag)) { @@ -3377,9 +3416,10 @@ bool Resolver::lookupOuterVariable(QualifiedType& out, // Otherwise, it's a variable, so walk up parent frames and look up // the variable's type using the resolution results. - } else if (ID parentFn = parsing::idToParentFunctionId(context, target)) { + } else if (parsing::idToParentFunctionId(context, target) || + parsing::idToParentInterfaceId(context, target)) { if (auto f = rc->findFrameWithId(target)) { - type = f->resolutionById()->byId(target).type(); + type = f->typeForContainedId(rc, target); outerVariables.add(mention, target, type); } } @@ -3397,7 +3437,7 @@ void Resolver::resolveIdentifier(const Identifier* ident) { CHPL_ASSERT(declStack.size() > 0); const Decl* inDecl = declStack.back(); if (inDecl->isVarLikeDecl() && ident->name() == USTR("?")) { - result.setType(QualifiedType(QualifiedType::TYPE, AnyType::get(context))); + result.setType(QualifiedType(QualifiedType::TYPE, getAnyType(*this, ident->id()))); return; } @@ -3638,6 +3678,17 @@ void Resolver::exit(const uast::Init* init) { } bool Resolver::enter(const TypeQuery* tq) { + if (usePlaceholders) { + // If we're resolving an interface, create a placeholder for the type + // query. This way, we get a concrete type for `foo(?x)`, which is + // desireable when validating user-provided functions against the + // interface signature. + ResolvedExpression& result = byPostorder.byAst(tq); + result.setType(QualifiedType(QualifiedType::TYPE, + PlaceholderType::get(context, tq->id()))); + return false; + } + if (skipTypeQueries) { return false; } @@ -5379,7 +5430,7 @@ static bool handleArrayTypeExpr(Resolver& rv, if (loop->numStmts() == 1) { bodyType = rv.byPostorder.byAst(loop->stmt(0)).type(); } else { - bodyType = QualifiedType(QualifiedType::TYPE, AnyType::get(rv.context)); + bodyType = QualifiedType(QualifiedType::TYPE, getAnyType(rv, loop->id())); } // The body wasn't a type, so this isn't an array type expression diff --git a/frontend/lib/resolution/Resolver.h b/frontend/lib/resolution/Resolver.h index fd0a74c11ad7..e619847815b7 100644 --- a/frontend/lib/resolution/Resolver.h +++ b/frontend/lib/resolution/Resolver.h @@ -89,6 +89,7 @@ struct Resolver { const PoiScope* poiScope = nullptr; const uast::Decl* ignoreSubstitutionFor = nullptr; bool skipTypeQueries = false; + bool usePlaceholders = false; // internal variables ResolutionContext emptyResolutionContext; @@ -177,6 +178,15 @@ struct Resolver { const uast::AstNode* modStmt, ResolutionResultByPostorderID& byPostorder); + static Resolver + createForInterfaceStmt(ResolutionContext* rc, + const uast::Interface* interface, + const types::InterfaceType* ift, + const ImplementationWitness* witness, + const uast::AstNode* stmt, + ResolutionResultByPostorderID& byPostorder); + + // set up Resolver to scope resolve a Module static Resolver createForScopeResolvingModuleStmt( diff --git a/frontend/lib/resolution/can-pass.cpp b/frontend/lib/resolution/can-pass.cpp index 40910fbe3aaa..78e5794eea07 100644 --- a/frontend/lib/resolution/can-pass.cpp +++ b/frontend/lib/resolution/can-pass.cpp @@ -906,6 +906,14 @@ CanPassResult CanPassResult::canInstantiate(Context* context, // TODO: check for constrained generic types + if (auto actualIt = actualT->toInterfaceType()) { + if (auto formalIt = formalT->toInterfaceType()) { + if (actualIt->isInstantiationOf(context, formalIt)) { + return instantiate(); + } + } + } + if (auto actualCt = actualT->toClassType()) { // check for instantiating classes if (auto formalCt = formalT->toClassType()) { @@ -1204,6 +1212,53 @@ CanPassResult CanPassResult::canPass(Context* context, return got; } +bool canInstantiateSubstitutions(Context* context, + const SubstitutionsMap& instances, + const SubstitutionsMap& generics, + bool allowMissing) { + // Check to see if the substitutions in `instaces` are all instantiations + // of the substitutions in `generics` + // + // check, for each substitution in mySubs, that it matches + // or is an instantiation of pSubs. + + for (const auto& mySubPair : instances) { + ID mySubId = mySubPair.first; + QualifiedType mySubType = mySubPair.second; + + // look for a substitution in pSubs with the same ID + auto pSearch = generics.find(mySubId); + if (pSearch != generics.end()) { + QualifiedType pSubType = pSearch->second; + // check the types + auto r = canPass(context, mySubType, pSubType); + if (r.passes() && !r.promotes() && !r.converts()) { + // instantiation and same-type passing are allowed here + } else { + // it was not an instantiation + return false; + } + } else if (!allowMissing) { + // If the ID isn't found, then that means the generic component doesn't + // exist in the other type, which means this cannot be an instantiation + // of the other type. + // + // How could we reach this condition? One path here involves passing a + // tuple to a tuple formal with a fewer number of elements. For example, + // passing "(1, 2, 3)" to "(int, ?)". + return false; + } else { + // A substitution is missing in the partial type, but we have one. + // For a composite type, that might just mean that we are foo(X, Y), + // while the partial is foo(X, ?) -- partially generic. So, this is + // fine, don't return false. + } + } + + return true; + +} + void KindProperties::invalidate() { isRef = isConst = isType = isParam = isValid = false; } diff --git a/frontend/lib/resolution/default-functions.cpp b/frontend/lib/resolution/default-functions.cpp index 7812a16f2a68..4c894bd9a7b2 100644 --- a/frontend/lib/resolution/default-functions.cpp +++ b/frontend/lib/resolution/default-functions.cpp @@ -1348,16 +1348,39 @@ getCompilerGeneratedBinaryOp(Context* context, return getCompilerGeneratedBinaryOpQuery(context, lhs, rhs, name); } -const BuilderResult& -buildTypeConstructor(Context* context, ID typeID) { - QUERY_BEGIN(buildTypeConstructor, context, typeID); +static owned typeConstructorFnForInterface(Context* context, + const Interface* itf, + Builder* builder, + Location fnLoc) { + AstList formals; + for (auto formal : itf->formals()) { + formals.push_back(formal->copy()); + } - auto bld = Builder::createForGeneratedCode(context, typeID); - auto builder = bld.get(); - auto dummyLoc = parsing::locateId(context, typeID); + auto genFn = Function::build(builder, fnLoc, {}, + Decl::Visibility::PUBLIC, + Decl::Linkage::DEFAULT_LINKAGE, + /*linkageName=*/{}, + itf->name(), + /*inline=*/false, /*override=*/false, + Function::Kind::PROC, + /*receiver=*/{}, + Function::ReturnIntent::DEFAULT_RETURN_INTENT, + // throws, primaryMethod, parenless + false, false, false, + std::move(formals), + // returnType, where, lifetime, body + {}, {}, {}, {}); + + + return genFn; +} +static owned typeConstructorFnForComposite(Context* context, + const CompositeType* ct, + Builder* builder, + Location fnLoc) { AstList formals; - auto ct = initialTypeForTypeDecl(context, typeID)->getCompositeType(); if (auto bct = ct->toBasicClassType()) { auto parent = bct->parentClassType(); @@ -1392,7 +1415,7 @@ buildTypeConstructor(Context* context, ID typeID) { auto typeExpr = fieldDecl->typeExpression(); auto initExpr = fieldDecl->initExpression(); auto kind = fieldDecl->kind() == Variable::PARAM ? Variable::PARAM : Variable::TYPE; - owned formal = Formal::build(builder, dummyLoc, + owned formal = Formal::build(builder, fnLoc, /*attributeGroup=*/nullptr, fieldDecl->name(), (Formal::Intent)kind, @@ -1402,7 +1425,7 @@ buildTypeConstructor(Context* context, ID typeID) { } } - auto genFn = Function::build(builder, dummyLoc, {}, + auto genFn = Function::build(builder, fnLoc, {}, Decl::Visibility::PUBLIC, Decl::Linkage::DEFAULT_LINKAGE, /*linkageName=*/{}, @@ -1416,6 +1439,26 @@ buildTypeConstructor(Context* context, ID typeID) { std::move(formals), // returnType, where, lifetime, body {}, {}, {}, {}); + return genFn; +} + +const BuilderResult& +buildTypeConstructor(Context* context, ID typeID) { + QUERY_BEGIN(buildTypeConstructor, context, typeID); + + auto bld = Builder::createForGeneratedCode(context, typeID); + auto builder = bld.get(); + auto dummyLoc = parsing::locateId(context, typeID); + + auto tag = parsing::idToTag(context, typeID); + owned genFn; + if (asttags::isInterface(tag)) { + auto itf = parsing::idToAst(context, typeID)->toInterface(); + genFn = typeConstructorFnForInterface(context, itf, builder, dummyLoc); + } else { + auto ct = initialTypeForTypeDecl(context, typeID)->getCompositeType(); + genFn = typeConstructorFnForComposite(context, ct, builder, dummyLoc); + } builder->noteChildrenLocations(genFn.get(), dummyLoc); builder->addToplevelExpression(std::move(genFn)); diff --git a/frontend/lib/resolution/intents.cpp b/frontend/lib/resolution/intents.cpp index eb387233bf91..72e9160a5491 100644 --- a/frontend/lib/resolution/intents.cpp +++ b/frontend/lib/resolution/intents.cpp @@ -98,6 +98,9 @@ static QualifiedType::Kind defaultIntentForType(const Type* t, return QualifiedType::CONST_IN; } + if (t->isPlaceholderType()) + return QualifiedType::DEFAULT_INTENT; + // Otherwise, it should be a generic type that we will // instantiate before computing the final intent. CHPL_ASSERT(t->genericity() != Type::CONCRETE); diff --git a/frontend/lib/resolution/interface-types.cpp b/frontend/lib/resolution/interface-types.cpp new file mode 100644 index 000000000000..b4859295574a --- /dev/null +++ b/frontend/lib/resolution/interface-types.cpp @@ -0,0 +1,105 @@ +/* + * Copyright 2024 Hewlett Packard Enterprise Development LP + * Other additional copyright holders may be indicated within. + * + * The entirety of this work is licensed under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "chpl/resolution/interface-types.h" + +namespace chpl { +namespace resolution { + +using namespace uast; +using namespace types; + +owned const& +ImplementationPoint::getImplementationPoint(Context* context, const InterfaceType* interface, + ID id) { + QUERY_BEGIN(getImplementationPoint, context, interface, id); + auto result = toOwned(new ImplementationPoint(interface, std::move(id))); + return QUERY_END(result); +} + +const ImplementationPoint* +ImplementationPoint::get(Context* context, const InterfaceType* interface, + ID id) { + return getImplementationPoint(context, interface, std::move(id)).get(); +} + +void ImplementationPoint::mark(Context* context) const { + interface_->mark(context); + id_.mark(context); +} + +void ImplementationPoint::stringify(std::ostream& ss, chpl::StringifyKind stringKind) const { + ss << "implements("; + interface_->stringify(ss, stringKind); + ss << " via )"; + id_.stringify(ss, stringKind); +} + +const owned& +ImplementationWitness::getImplementationWitness(Context* context, + ConstraintMap associatedConstraints, + AssociatedTypeMap associatedTypes, + FunctionMap requiredFns) { + QUERY_BEGIN(getImplementationWitness, context, associatedConstraints, + associatedTypes, requiredFns); + + auto result = toOwned(new ImplementationWitness(std::move(associatedConstraints), + std::move(associatedTypes), + std::move(requiredFns))); + + return QUERY_END(result); +} + +ImplementationWitness* ImplementationWitness::get(Context* context, + ConstraintMap associatedConstraints, + AssociatedTypeMap associatedTypes, + FunctionMap requiredFns) { + return getImplementationWitness(context, std::move(associatedConstraints), + std::move(associatedTypes), + std::move(requiredFns)).get(); +} + +void ImplementationWitness::stringify(std::ostream& ss, chpl::StringifyKind stringKind) const { + ss << "witness("; + ss << "constraints: "; + for (auto& c : associatedConstraints_) { + c.first.stringify(ss, stringKind); + ss << " => "; + c.second->stringify(ss, stringKind); + ss << ", "; + } + ss << "associated types: "; + for (auto& a : associatedTypes_) { + a.first.stringify(ss, stringKind); + ss << " => "; + a.second->stringify(ss, stringKind); + ss << ", "; + } + ss << "required functions: "; + for (auto& f : requiredFns_) { + f.first.stringify(ss, stringKind); + ss << " => "; + f.second.stringify(ss, stringKind); + ss << ", "; + } + ss << ")"; +} + +} // end namespace resolution +} // end namespace chpl diff --git a/frontend/lib/resolution/prims.cpp b/frontend/lib/resolution/prims.cpp index 06fc14916779..c7a37be0f3ff 100644 --- a/frontend/lib/resolution/prims.cpp +++ b/frontend/lib/resolution/prims.cpp @@ -106,6 +106,11 @@ static QualifiedType makeParamBool(Context* context, bool b) { BoolParam::get(context, b) }; } +static QualifiedType makeParamInt(Context* context, int64_t i) { + return { QualifiedType::PARAM, IntType::get(context, 0), + IntParam::get(context, i) }; +} + static QualifiedType makeParamString(Context* context, UniqueString s) { return { QualifiedType::PARAM, RecordType::getStringType(context), StringParam::get(context, s) }; @@ -283,6 +288,44 @@ static QualifiedType primCallResolves(ResolutionContext* rc, BoolParam::get(context, callAndFnResolved)); } +static QualifiedType primImplementsInterface(Context* context, + const PrimCall* astForErr, + const CallInfo& ci, + const Scope* inScope, + const PoiScope* inPoiScope) { + if (ci.numActuals() != 2) return QualifiedType(); + + auto& type = ci.actual(0).type(); + auto& ifqt = ci.actual(1).type(); + + if (ifqt.kind() != QualifiedType::TYPE || + ifqt.isUnknownOrErroneous() || + !ifqt.type()->isInterfaceType()) return QualifiedType(); + + auto ift = ifqt.type()->toInterfaceType(); + auto instantiatedIft = InterfaceType::withTypes(context, ift, { type }); + if (!instantiatedIft) return QualifiedType(); + + ResolutionContext rc(context); + auto inScopes = CallScopeInfo::forNormalCall(inScope, inPoiScope); + auto witness = + findMatchingImplementationPoint(&rc, instantiatedIft, inScopes); + + if (witness) { + return makeParamInt(context, 0); + } + + // try automatically satisfy the interface if it's in the standard modules. + if (parsing::idIsInBundledModule(context, ift->id())) { + auto runResult = context->runAndTrackErrors([&](Context* context) { + return checkInterfaceConstraints(&rc, instantiatedIft, astForErr->id(), inScopes); + }); + witness = runResult.result(); + } + + return makeParamInt(context, witness ? 1 : 2); +} + static QualifiedType computeDomainType(Context* context, const CallInfo& ci) { if (ci.numActuals() == 3) { auto type = DomainType::getRectangularType(context, @@ -1262,10 +1305,13 @@ CallResolutionResult resolvePrimCall(ResolutionContext* rc, break; case PRIM_RESOLVES: - case PRIM_IMPLEMENTS_INTERFACE: CHPL_UNIMPL("various primitives"); break; + case PRIM_IMPLEMENTS_INTERFACE: + type = primImplementsInterface(context, call, ci, inScope, inPoiScope); + break; + case PRIM_IS_STAR_TUPLE_TYPE: if (ci.numActuals() == 1) { bool result = false; diff --git a/frontend/lib/resolution/resolution-error-classes-list.cpp b/frontend/lib/resolution/resolution-error-classes-list.cpp index 7bf9e9f9d8f3..90559a5c1815 100644 --- a/frontend/lib/resolution/resolution-error-classes-list.cpp +++ b/frontend/lib/resolution/resolution-error-classes-list.cpp @@ -610,6 +610,282 @@ void ErrorIncompatibleYieldTypes::write(ErrorWriterBase& wr) const { } } +static std::string buildTupleDeclName(const uast::TupleDecl* tup) { + std::string ret = "("; + int count = 0; + for (auto decl : tup->decls()) { + if (count != 0) { + ret += ","; + } + count += 1; + + if (decl->isTupleDecl()) { + ret += buildTupleDeclName(decl->toTupleDecl()); + } else { + ret += decl->toFormal()->name().str(); + } + } + + if (count == 1) { + ret += ","; + } + + ret += ")"; + + return ret; +} + +template +static void printRejectedCandidates(ErrorWriterBase& wr, + const ID& anchorId, + const resolution::CallInfo& ci, + const std::vector& rejected, + const char* passedThingArticle, + const char* passedThing, + const char* expectedThingArticle, + const char* expectedThing, + GetActual&& getActual) { + unsigned int printCount = 0; + static const unsigned int maxPrintCount = 2; + for (auto& candidate : rejected) { + if (printCount == maxPrintCount) break; + printCount++; + + auto reason = candidate.reason(); + wr.message(""); + if (reason == resolution::FAIL_CANNOT_PASS && + /* skip printing detailed info_ here because computing the formal-actual + map will go poorly with an unknown formal. */ + candidate.formalReason() != resolution::FAIL_UNKNOWN_FORMAL_TYPE) { + auto fn = candidate.initialForErr(); + resolution::FormalActualMap fa(fn, ci); + auto badPass = fa.byFormalIdx(candidate.formalIdx()); + auto formalDecl = badPass.formal(); + const uast::AstNode* actualExpr = getActual(badPass.actualIdx()); + + wr.note(fn->id(), "the following candidate didn't match because ", passedThingArticle, " ", passedThing, " couldn't be passed to ", expectedThingArticle, " ", expectedThing, ":"); + wr.code(fn->id(), { formalDecl }); + + std::string formalName; + if (auto named = formalDecl->toNamedDecl()) { + formalName = "'" + named->name().str() + "'"; + } else if (formalDecl->isTupleDecl()) { + formalName = "'" + buildTupleDeclName(formalDecl->toTupleDecl()) + "'"; + } + + if (badPass.formalType().isUnknown()) { + // The formal type can be unknown in an initial instantiation if it + // depends on the previous formals' types. In that case, don't print it + // and say something nicer. + wr.message("The instantiated type of ", expectedThing, " ", formalName, + " does not allow ", passedThing, "s of type '", badPass.actualType().type(), "'."); + } else { + wr.message("The ", expectedThing, " ", formalName, " expects ", badPass.formalType(), + ", but the ", passedThing, " was ", badPass.actualType(), "."); + } + + if (actualExpr) { + wr.code(actualExpr, { actualExpr }); + } + + auto formalReason = candidate.formalReason(); + if (formalReason == resolution::FAIL_INCOMPATIBLE_NILABILITY) { + auto formalDec = badPass.formalType().type()->toClassType()->decorator(); + auto actualDec = badPass.actualType().type()->toClassType()->decorator(); + + wr.message("The ", expectedThing, " expects a ", nilabilityStr(formalDec), " class, " + "but the ", passedThing, " is ", nilabilityStr(actualDec), "."); + } else if (formalReason == resolution::FAIL_INCOMPATIBLE_MGR) { + auto formalMgr = badPass.formalType().type()->toClassType()->manager(); + auto actualMgr = badPass.actualType().type()->toClassType()->manager(); + + wr.message("A class with '", actualMgr, "' management cannot be passed to ", expectedThingArticle, " ", expectedThing, " with '", formalMgr, "' management."); + } else if (formalReason == resolution::FAIL_EXPECTED_SUBTYPE) { + wr.message("Formals with kind '", badPass.formalType().kind(), + "' expect the ", passedThing, " to be a subtype, but '", badPass.actualType().type(), + "' is not a subtype of '", badPass.formalType().type(), "'."); + } else if (formalReason == resolution::FAIL_INCOMPATIBLE_TUPLE_SIZE) { + auto formalTup = badPass.formalType().type()->toTupleType(); + auto actualTup = badPass.actualType().type()->toTupleType(); + + wr.message("A tuple with ", actualTup->numElements(), + " elements cannot be passed to a tuple formal with ", + formalTup->numElements(), " elements."); + } else if (formalReason == resolution::FAIL_INCOMPATIBLE_TUPLE_STAR) { + auto formalTup = badPass.formalType().type()->toTupleType(); + auto actualTup = badPass.actualType().type()->toTupleType(); + + const char* formalStr = formalTup->isStarTuple() ? "is" : "is not"; + const char* actualStr = actualTup->isStarTuple() ? "is" : "is not"; + + wr.message("A ", expectedThing, " that ", formalStr, " a star tuple cannot accept ", passedThingArticle," ", passedThing," that ", actualStr, "."); + } else if (formalReason == resolution::FAIL_NOT_EXACT_MATCH) { + wr.message("The 'ref' intent requires the ", expectedThing, " and ", passedThing, " types to match exactly."); + } + } else { + std::string reasonStr = ""; + if (reason == resolution::FAIL_FORMAL_ACTUAL_MISMATCH) { + reasonStr = std::string("the provided ") + passedThing + "s could not be mapped to its " + expectedThing + "s:"; + } else if (reason == resolution::FAIL_VARARG_MISMATCH) { + reasonStr = "the number of varargs was incorrect:"; + } else if (reason == resolution::FAIL_WHERE_CLAUSE) { + reasonStr = "the 'where' clause evaluated to 'false':"; + } else if (reason == resolution::FAIL_PARENLESS_MISMATCH) { + if (ci.isParenless()) { + reasonStr = "it is parenful, but the call was parenless:"; + } else { + reasonStr = "it is parenless, but the call was parenful:"; + } + } else if (reason == resolution::FAIL_INTERFACE_NOT_TYPE_INTENT) { + reasonStr = "it did not return a type as required from associated type procedures:"; + } + + if (reasonStr.empty()) { + wr.note(candidate.idForErr(), "the following candidate didn't match:"); + } else { + wr.note(candidate.idForErr(), "the following candidate didn't match ", + "because ", reasonStr); + } + wr.code(candidate.idForErr()); + } + } + + if (printCount < rejected.size()) { + wr.message(""); + wr.note(locationOnly(anchorId), "omitting ", rejected.size() - printCount, " more candidates that didn't match."); + } +} + +void ErrorInterfaceAmbiguousFn::write(ErrorWriterBase& wr) const { + auto interface = std::get(info_); + auto implPoint = std::get(info_); + auto fn = std::get(info_); + auto candidate = std::get>(info_); + + wr.heading(kind_, type_, implPoint, "unable to disambiguate candidates for function '", fn->name(), "'."); + wr.codeForDef(fn->id()); + wr.message("Required by the interface '", interface->name(), "':"); + wr.codeForDef(interface->id()); + wr.note(implPoint, "while checking the implementation point here:"); + wr.code(implPoint, { implPoint }); + + unsigned int printCount = 0; + static const unsigned int maxPrintCount = 2; + for (auto sig : candidate) { + if (printCount == maxPrintCount) break; + printCount++; + + wr.message(""); + wr.note(sig->id(), "one candidate was here:"); + wr.code(sig->id()); + } +} + +void ErrorInterfaceInvalidIntent::write(ErrorWriterBase& wr) const { + auto interface = std::get(info_); + auto implPoint = std::get(info_); + auto fnTemplate = std::get<2>(info_); + auto fnReal = std::get<3>(info_); + + wr.heading(kind_, type_, implPoint, "candidate for function '", fnTemplate->untyped()->name(), "' has mismatched return intent."); + wr.codeForDef(fnTemplate->id()); + wr.message("Required by the interface '", interface->name(), "':"); + wr.codeForDef(interface->id()); + wr.note(implPoint, "while checking the implementation point here:"); + wr.code(implPoint, { implPoint }); + wr.note(fnReal->id(), "the provided candidate does not have a matching return intent:"); + wr.codeForDef(fnReal->id()); +} + +void ErrorInterfaceMissingAssociatedType::write(ErrorWriterBase& wr) const { + auto interface = std::get(info_); + auto implPoint = std::get(info_); + auto var = std::get(info_); + auto ci = std::get(info_); + auto rejected = std::get>(info_); + + wr.heading(kind_, type_, implPoint, "unable to find matching candidates for associated type '", var->name(), "'."); + wr.codeForDef(var->id()); + wr.message("Required by the interface '", interface->name(), "':"); + wr.codeForDef(interface->id()); + wr.note(implPoint, "while checking the implementation point here:"); + wr.code(implPoint, { implPoint }); + wr.message("Associated types are resolved as 'type' calls on types constrained by the interface."); + + printRejectedCandidates(wr, implPoint, ci, rejected, "an", "actual", "a", "formal", [](int) -> const uast::AstNode* { + return nullptr; + }); +} + +void ErrorInterfaceMissingFn::write(ErrorWriterBase& wr) const { + auto interface = std::get(info_); + auto implPoint = std::get(info_); + auto fn = std::get(info_); + auto ci = std::get(info_); + auto rejected = std::get>(info_); + + wr.heading(kind_, type_, implPoint, "unable to find matching candidates for function '", fn->untyped()->name(), "'."); + wr.codeForDef(fn->id()); + wr.message("Required by the interface '", interface->name(), "':"); + wr.codeForDef(interface->id()); + wr.note(implPoint, "while checking the implementation point here:"); + wr.code(implPoint, { implPoint }); + + printRejectedCandidates(wr, implPoint, ci, rejected, "a", "required formal", "a", "cadidate formal", [fn](int idx) -> const uast::AstNode* { + if (idx >= 0 && idx < fn->numFormals()) { + return fn->untyped()->formalDecl(idx); + } + return nullptr; + }); +} + +void ErrorInterfaceMultipleImplements::write(ErrorWriterBase& wr) const { + auto ad = std::get(info_); + auto interface = std::get(info_); + auto implPoint = std::get<2>(info_); + auto otherImplPoint = std::get<3>(info_); + + wr.heading(kind_, type_, ad, "multiple implementations of interface '", interface->name(), "' found."); + wr.message("While analyzing the definition of type '", ad->name(), "', defined here:"); + wr.codeForDef(ad); + wr.note(implPoint, "the interface '", interface->name(), "' is first implemented here:"); + wr.code(implPoint, { implPoint }); + wr.note(otherImplPoint, "it is also implemented here:"); + wr.code(otherImplPoint, { otherImplPoint }); +} + +void ErrorInterfaceNaryInInherits::write(ErrorWriterBase& wr) const { + auto ad = std::get(info_); + auto interface = std::get(info_); + auto implPoint = std::get(info_); + + wr.heading(kind_, type_, ad, "cannot use interface '", interface->name(), "' in inheritance expression as it is not a unary interface."); + wr.message("While analyzing the definition of type '", ad->name(), "', defined here:"); + wr.codeForDef(ad); + wr.note(implPoint, "found the interface '", interface->name(), "' in an inheritance list here:"); + wr.code(implPoint, { implPoint }); + wr.note(interface->id(), "However, the interface '", interface->name(), "' is defined to be a non-unary interface here:"); + wr.codeForDef(interface->id()); + wr.message("Only unary interfaces (those with a single type parameter like 'Self') can be used in inheritance expressions."); + wr.message("To implement n-ary interfaces, consider using a standalone 'implements' statement."); +} + +void ErrorInterfaceReorderedFnFormals::write(ErrorWriterBase& wr) const { + auto interface = std::get(info_); + auto implPoint = std::get(info_); + auto fnTemplate = std::get<2>(info_); + auto fnReal = std::get<3>(info_); + + wr.heading(kind_, type_, implPoint, "candidate for function '", fnTemplate->untyped()->name(), "' does not have the same order of formal names."); + wr.codeForDef(fnTemplate->id()); + wr.message("Required by the interface '", interface->name(), "':"); + wr.codeForDef(interface->id()); + wr.note(implPoint, "while checking the implementation point here:"); + wr.code(implPoint, { implPoint }); + wr.note(fnReal->id(), "the provided candidate defined here does not have matching formals:"); + wr.codeForDef(fnReal->id()); +} + void ErrorInvalidClassCast::write(ErrorWriterBase& wr) const { auto primCall = std::get(info_); auto& type = std::get(info_); @@ -696,6 +972,37 @@ void ErrorInvalidDomainCall::write(ErrorWriterBase& wr) const { } } +void ErrorInvalidImplementsActual::write(ErrorWriterBase& wr) const { + auto impl = std::get(info_); + auto actual = std::get(info_); + auto& qt = std::get(info_); + + wr.heading(kind_, type_, impl, "invalid use of 'implements' with an actual that is ", qt, "."); + wr.codeForLocation(impl); + wr.message("The actual is provided here:"); + wr.code(actual, { actual }); + wr.message("Only 'type' actuals are allowed in implementation points."); +} + +void ErrorInvalidImplementsArity::write(ErrorWriterBase& wr) const { + auto impl = std::get(info_); + auto interface = std::get(info_); + auto& actuals = std::get>(info_); + std::ignore = actuals; + + wr.heading(kind_, type_, impl, "wrong number of actuals in 'implements' statement for interface '", interface->name(), "'."); + wr.codeForLocation(impl); +} + +void ErrorInvalidImplementsInterface::write(ErrorWriterBase& wr) const { + auto impl = std::get(info_); + auto& qt = std::get(info_); + + wr.heading(kind_, type_, impl, "invalid 'implements' statement."); + wr.codeForLocation(impl); + wr.message("The statement attempts to implement ", qt, ", which is not an interface."); +} + void ErrorInvalidIndexCall::write(ErrorWriterBase& wr) const { auto fnCall = std::get(info_); auto& type = std::get(info_); @@ -980,31 +1287,6 @@ void ErrorNestedClassFieldRef::write(ErrorWriterBase& wr) const { wr.codeForDef(id); } -static std::string buildTupleDeclName(const uast::TupleDecl* tup) { - std::string ret = "("; - int count = 0; - for (auto decl : tup->decls()) { - if (count != 0) { - ret += ","; - } - count += 1; - - if (decl->isTupleDecl()) { - ret += buildTupleDeclName(decl->toTupleDecl()); - } else { - ret += decl->toFormal()->name().str(); - } - } - - if (count == 1) { - ret += ","; - } - - ret += ")"; - - return ret; -} - void ErrorNoMatchingCandidates::write(ErrorWriterBase& wr) const { auto node = std::get(info_); auto call = node->toCall(); @@ -1014,115 +1296,24 @@ void ErrorNoMatchingCandidates::write(ErrorWriterBase& wr) const { wr.heading(kind_, type_, node, "unable to resolve call to '", ci.name(), "': no matching candidates."); wr.code(node); - unsigned int printCount = 0; - static const unsigned int maxPrintCount = 2; - for (auto& candidate : rejected) { - if (printCount == maxPrintCount) break; - printCount++; - - auto reason = candidate.reason(); - wr.message(""); - if (reason == resolution::FAIL_CANNOT_PASS && - /* skip printing detailed info_ here because computing the formal-actual - map will go poorly with an unknown formal. */ - candidate.formalReason() != resolution::FAIL_UNKNOWN_FORMAL_TYPE) { - auto fn = candidate.initialForErr(); - resolution::FormalActualMap fa(fn, ci); - auto badPass = fa.byFormalIdx(candidate.formalIdx()); - auto formalDecl = badPass.formal(); - const uast::AstNode* actualExpr = nullptr; - if (call && 0 <= badPass.actualIdx() && badPass.actualIdx() < call->numActuals()) { - actualExpr = call->actual(badPass.actualIdx()); - } - - wr.note(fn->id(), "the following candidate didn't match because an actual couldn't be passed to a formal:"); - wr.code(fn->id(), { formalDecl }); - - std::string formalName; - if (auto named = formalDecl->toNamedDecl()) { - formalName = "'" + named->name().str() + "'"; - } else if (formalDecl->isTupleDecl()) { - formalName = "'" + buildTupleDeclName(formalDecl->toTupleDecl()) + "'"; - } - - if (badPass.formalType().isUnknown()) { - // The formal type can be unknown in an initial instantiation if it - // depends on the previous formals' types. In that case, don't print it - // and say something nicer. - wr.message("The instantiated type of formal ", formalName, - " does not allow actuals of type '", badPass.actualType().type(), "'."); - } else { - wr.message("The formal ", formalName, " expects ", badPass.formalType(), - ", but the actual was ", badPass.actualType(), "."); - } - - if (actualExpr) { - wr.code(actualExpr, { actualExpr }); - } - - auto formalReason = candidate.formalReason(); - if (formalReason == resolution::FAIL_INCOMPATIBLE_NILABILITY) { - auto formalDec = badPass.formalType().type()->toClassType()->decorator(); - auto actualDec = badPass.actualType().type()->toClassType()->decorator(); - - wr.message("The formal expects a ", nilabilityStr(formalDec), " class, " - "but the actual is ", nilabilityStr(actualDec), "."); - } else if (formalReason == resolution::FAIL_INCOMPATIBLE_MGR) { - auto formalMgr = badPass.formalType().type()->toClassType()->manager(); - auto actualMgr = badPass.actualType().type()->toClassType()->manager(); - - wr.message("A class with '", actualMgr, "' management cannot be passed to a formal with '", formalMgr, "' management."); - } else if (formalReason == resolution::FAIL_EXPECTED_SUBTYPE) { - wr.message("Formals with kind '", badPass.formalType().kind(), - "' expect the actual to be a subtype, but '", badPass.actualType().type(), - "' is not a subtype of '", badPass.formalType().type(), "'."); - } else if (formalReason == resolution::FAIL_INCOMPATIBLE_TUPLE_SIZE) { - auto formalTup = badPass.formalType().type()->toTupleType(); - auto actualTup = badPass.actualType().type()->toTupleType(); - - wr.message("A tuple with ", actualTup->numElements(), - " elements cannot be passed to a tuple formal with ", - formalTup->numElements(), " elements."); - } else if (formalReason == resolution::FAIL_INCOMPATIBLE_TUPLE_STAR) { - auto formalTup = badPass.formalType().type()->toTupleType(); - auto actualTup = badPass.actualType().type()->toTupleType(); - - const char* formalStr = formalTup->isStarTuple() ? "is" : "is not"; - const char* actualStr = actualTup->isStarTuple() ? "is" : "is not"; - - wr.message("A formal that ", formalStr, " a star tuple cannot accept an actual actual that ", actualStr, "."); - } else if (formalReason == resolution::FAIL_NOT_EXACT_MATCH) { - wr.message("The 'ref' intent requires the formal and actual types to match exactly."); - } - } else { - const char* reasonStr = nullptr; - if (reason == resolution::FAIL_FORMAL_ACTUAL_MISMATCH) { - reasonStr = "the provided actuals could not be mapped to its formals:"; - } else if (reason == resolution::FAIL_VARARG_MISMATCH) { - reasonStr = "the number of varargs was incorrect:"; - } else if (reason == resolution::FAIL_WHERE_CLAUSE) { - reasonStr = "the 'where' clause evaluated to 'false':"; - } else if (reason == resolution::FAIL_PARENLESS_MISMATCH) { - if (ci.isParenless()) { - reasonStr = "it is parenful, but the call was parenless:"; - } else { - reasonStr = "it is parenless, but the call was parenful:"; - } - } - if (!reasonStr) { - wr.note(candidate.idForErr(), "the following candidate didn't match:"); - } else { - wr.note(candidate.idForErr(), "the following candidate didn't match ", - "because ", reasonStr); - } - wr.code(candidate.idForErr()); + printRejectedCandidates(wr, node->id(), ci, rejected, "an", "actual", "a", "formal", [call](int idx) -> const uast::AstNode* { + if (call && 0 <= idx && idx < call->numActuals()) { + return call->actual(idx); } - } + return nullptr; + }); +} - if (printCount < rejected.size()) { - wr.message(""); - wr.note(locationOnly(node), "omitting ", rejected.size() - printCount, " more candidates that didn't match."); - } +void ErrorNonClassInheritance::write(ErrorWriterBase& wr) const { + auto ad = std::get(info_); + auto inheritanceExpr = std::get(info_); + auto& type = std::get(info_); + + wr.heading(kind_, type_, ad, "attempt for non-class type to inherit from a type."); + wr.message("While analyzing the definition of type '", ad->name(), "', defined here:"); + wr.codeForDef(ad); + wr.note(inheritanceExpr, "found an inheritance expression referring to type '", type, "' here:"); + wr.code(inheritanceExpr, { inheritanceExpr }); } static void printTheseResults( diff --git a/frontend/lib/resolution/resolution-queries.cpp b/frontend/lib/resolution/resolution-queries.cpp index 3e82c365041e..95a6594aeba8 100644 --- a/frontend/lib/resolution/resolution-queries.cpp +++ b/frontend/lib/resolution/resolution-queries.cpp @@ -170,6 +170,19 @@ static void updateTypeForModuleLevelSplitInit(Context* context, ID id, lhs.setType(useType); } +static void checkImplementationPoint(ResolutionContext* rc, const ImplementationPoint* implPoint) { + if (getTypeGenericity(rc->context(), implPoint->interface()) == types::Type::CONCRETE) { + auto inScope = scopeForId(rc->context(), implPoint->id()); + auto inScopes = CallScopeInfo::forNormalCall(inScope, nullptr); + std::ignore = checkInterfaceConstraints(rc, implPoint->interface(), implPoint->id(), inScopes); + // checkInterfaceConstraints emits an error already, nothing to do. + } +} + +static const std::map>& +collectImplementationPointsInModule(Context* context, + const Module* mod); + const ResolutionResultByPostorderID& resolveModule(Context* context, ID id) { QUERY_BEGIN(resolveModule, context, id); @@ -229,6 +242,14 @@ const ResolutionResultByPostorderID& resolveModule(Context* context, ID id) { } checkThrows(rc, result, mod); callInitDeinit(r); + + // check interface implementations in this module + auto implPoints = collectImplementationPointsInModule(context, mod); + for (const auto& pair : implPoints) { + for (const auto implPoint : pair.second) { + checkImplementationPoint(rc, implPoint); + } + } } } @@ -575,7 +596,8 @@ static bool errorIfParentFramesNotPresent(ResolutionContext* rc, static const TypedFnSignature* typedSignatureInitialImpl(ResolutionContext* rc, - const UntypedFnSignature* untypedSig) { + const UntypedFnSignature* untypedSig, + bool usePlaceholders) { Context* context = rc->context(); const TypedFnSignature* result = nullptr; const AstNode* ast = parsing::idToAst(context, untypedSig->id()); @@ -621,12 +643,15 @@ typedSignatureInitialImpl(ResolutionContext* rc, ResolutionResultByPostorderID r; auto visitor = Resolver::createForInitialSignature(rc, fn, r); + visitor.usePlaceholders = usePlaceholders; // visit the formals, but not the return type or body for (auto formal : fn->formals()) formal->traverse(visitor); if (!visitor.outerVariables.isEmpty()) { - CHPL_ASSERT(parentSignature); + // outer variables can come from a parent function or from + // an interface containing the function. + CHPL_ASSERT(parentSignature || parsing::idToParentInterfaceId(context, fn->id())); // Outer variables can't be typed without stack frames, so give up. if (errorIfParentFramesNotPresent(rc, untypedSig)) return nullptr; @@ -654,7 +679,7 @@ typedSignatureInitialImpl(ResolutionContext* rc, } if (!visitor.outerVariables.isEmpty()) { - CHPL_ASSERT(parentSignature); + CHPL_ASSERT(parentSignature || parsing::idToParentInterfaceId(context, fn->id())); // Outer variables can't be typed without stack frames, so give up. if (errorIfParentFramesNotPresent(rc, untypedSig)) return nullptr; @@ -684,7 +709,7 @@ const TypedFnSignature* const& typedSignatureInitial(ResolutionContext* rc, const UntypedFnSignature* untypedSig) { CHPL_RESOLUTION_QUERY_BEGIN(typedSignatureInitial, rc, untypedSig); - auto ret = typedSignatureInitialImpl(rc, untypedSig); + auto ret = typedSignatureInitialImpl(rc, untypedSig, /* usePlaceholders */ false); return CHPL_RESOLUTION_QUERY_END(ret); } @@ -703,6 +728,21 @@ typedSignatureInitialForId(ResolutionContext* rc, ID id) { return typedSignatureInitialForIdQuery(rc, std::move(id)); } +static const TypedFnSignature* const& +typedSignatureTemplateForIdQuery(ResolutionContext* rc, ID id) { + CHPL_RESOLUTION_QUERY_BEGIN(typedSignatureTemplateForIdQuery, rc, id); + Context* context = rc->context(); + const UntypedFnSignature* uSig = UntypedFnSignature::get(context, id); + const TypedFnSignature* ret = uSig ? typedSignatureInitialImpl(rc, uSig, /* usePlaceholders */ true) + : nullptr; + return CHPL_RESOLUTION_QUERY_END(ret); +} + +const TypedFnSignature* +typedSignatureTemplateForId(ResolutionContext* rc, ID id) { + return typedSignatureTemplateForIdQuery(rc, id); +} + // initedInParent is true if the decl variable is inited due to a parent // uast node. This comes up for TupleDecls. static void helpSetFieldTypes(const CompositeType* ct, @@ -785,6 +825,21 @@ const Type* initialTypeForTypeDecl(Context* context, ID declId) { return initialTypeForTypeDeclQuery(context, declId); } +static const Type* const& +initialTypeForInterfaceQuery(Context* context, ID declId) { + QUERY_BEGIN(initialTypeForInterfaceQuery, context, declId); + const Type* result = nullptr; + auto ast = parsing::idToAst(context, declId); + if (auto itf = ast->toInterface()) { + result = InterfaceType::get(context, itf->id(), itf->name(), /* subs */ {}); + } + return QUERY_END(result); +} + +const Type* initialTypeForInterface(Context* context, ID declId) { + return initialTypeForInterfaceQuery(context, declId); +} + const ResolvedFields& resolveFieldDecl(Context* context, const CompositeType* ct, ID fieldId, @@ -1244,6 +1299,32 @@ static Type::Genericity getFieldsGenericity(Context* context, return g; } +static Type::Genericity getInterfaceActualsGenericity(Context* context, + const InterfaceType* ift, + std::set& ignore) { + // add the current type to the ignore set, and stop now + // if it is already in the ignore set. + auto it = ignore.insert(ift); + if (it.second == false) { + // set already contained ct, so stop & consider it concrete + return Type::CONCRETE; + } + + if (ift->substitutions().empty()) return Type::GENERIC; + + auto itf = parsing::idToAst(context, ift->id())->toInterface(); + CHPL_ASSERT(itf); + for (auto formal : itf->formals()) { + // if the substitutions aren't empty, expect substitutions for all types + auto& qt = ift->substitutions().at(formal->id()); + if (getTypeGenericityIgnoring(context, qt.type(), ignore) != Type::CONCRETE) { + return Type::GENERIC; + } + } + + return Type::CONCRETE; +} + Type::Genericity getTypeGenericityIgnoring(Context* context, const Type* t, std::set& ignore) { if (t == nullptr) @@ -1271,7 +1352,11 @@ Type::Genericity getTypeGenericityIgnoring(Context* context, const Type* t, // MAYBE_GENERIC should only be returned for CompositeType / // ClassType right now. - CHPL_ASSERT(t->isCompositeType() || t->isClassType()); + CHPL_ASSERT(t->isCompositeType() || t->isClassType() || t->isInterfaceType()); + + if (auto ift = t->toInterfaceType()) { + return getInterfaceActualsGenericity(context, ift, ignore); + } // the tuple type that isn't an instantiation is a generic type if (auto tt = t->toTupleType()) { @@ -1548,7 +1633,15 @@ typeConstructorInitialQuery(Context* context, const Type* t) const TypedFnSignature* result = nullptr; - ID id = t->getCompositeType()->id(); + ID id; + if (auto ct = t->getCompositeType()) { + id = ct->id(); + } else if (auto ift = t->toInterfaceType()) { + id = ift->id(); + } else { + CHPL_ASSERT(false && "invalid argument to typeConstructorInitialQuery"); + } + UniqueString name; std::vector formals; @@ -2812,6 +2905,156 @@ const ResolvedFunction* resolveFunction(ResolutionContext* rc, return helpResolveFunction(rc, sig, poiScope, skipIfRunning); } +static const ImplementationPoint* const& +resolveImplementsStmtQuery(Context* context, ID id) { + QUERY_BEGIN(resolveImplementsStmtQuery, context, id); + const ImplementationPoint* result = nullptr; + + auto byPostorder = resolveModuleStmt(context, id); + auto ast = parsing::idToAst(context, id); + CHPL_ASSERT(ast->isImplements()); + auto impl = ast->toImplements(); + + auto interfaceExpr = impl->interfaceExpr(); + QualifiedType interfaceQt; + if (auto interfaceIdent = interfaceExpr->toIdentifier()) { + interfaceQt = byPostorder.byAst(interfaceIdent).type(); + } else if (auto interfaceCall = interfaceExpr->toFnCall()) { + interfaceQt = byPostorder.byAst(interfaceCall->calledExpression()).type(); + } + + if (!interfaceQt.isType() || interfaceQt.isUnknown() || + interfaceQt.type()->toInterfaceType() == nullptr) { + CHPL_REPORT(context, InvalidImplementsInterface, impl, interfaceQt); + } else { + auto genericIft = interfaceQt.type()->toInterfaceType(); + std::vector actuals; + + auto addActual = [&byPostorder, &actuals, context, impl](const AstNode* actual) { + auto& actualType = byPostorder.byAst(actual).type(); + if (actualType.isUnknownOrErroneous()) { + return false; + } else if (!actualType.isType()) { + CHPL_REPORT(context, InvalidImplementsActual, impl, actual, actualType); + return false; + } + actuals.push_back(actualType); + return true; + }; + + bool addPoint = true; + if (auto typeIdent = impl->typeIdent()) { + addPoint &= addActual(typeIdent); + } + if (auto interfaceCall = impl->interfaceExpr()->toFnCall()) { + for (auto actual : interfaceCall->actuals()) { + if (!addPoint) break; // already found a broken actual + addPoint &= addActual(actual); + } + } + + if (addPoint) { + auto ift = InterfaceType::withTypes(context, genericIft, actuals); + if (!ift) { + CHPL_REPORT(context, InvalidImplementsArity, impl, genericIft, actuals); + } else { + result = ImplementationPoint::get(context, ift, id); + } + } + } + + return QUERY_END(result); +} + +const ImplementationPoint* resolveImplementsStmt(Context* context, + ID id) { + return resolveImplementsStmtQuery(context, id); +} + +static const std::map>& +collectImplementationPointsInModule(Context* context, + const uast::Module* module) { + QUERY_BEGIN(collectImplementationPointsInModule, context, module); + std::map> byInterfaceId; + + for (auto stmt : module->stmts()) { + if (auto ad = stmt->toAggregateDecl()) { + auto& implPoints = getImplementedInterfaces(context, ad); + for (auto implPoint : implPoints) { + byInterfaceId[implPoint->interface()->id()].push_back(implPoint); + } + } else if (auto implements = stmt->toImplements()) { + auto implPoint = resolveImplementsStmt(context, implements->id()); + if (implPoint) { + byInterfaceId[implPoint->interface()->id()].push_back(implPoint); + } + } + } + + return QUERY_END(byInterfaceId); +} + +static const std::map>& +collectImplementationPointsInScope(Context* context, + const Scope* scope) { + QUERY_BEGIN(collectImplementationPointsInScope, context, scope); + CHPL_ASSERT(scope->moduleScope() == scope); + + auto module = parsing::idToAst(context, scope->id())->toModule(); + auto& result = collectImplementationPointsInModule(context, module); + + return QUERY_END(result); +} + +static void +helpCollectVisibileImplementationPoints(Context* context, + const Scope* scope, + std::unordered_set& seen, + std::map>& into) { + auto insertResult = seen.insert(scope); + if (!insertResult.second) return; + + auto& inScope = collectImplementationPointsInScope(context, scope); + for (auto& pointsForScope : inScope) { + auto& copyInto = into[pointsForScope.first]; + for (auto implPoint : pointsForScope.second) { + copyInto.push_back(implPoint); + } + } + + if (auto visStmts = resolveVisibilityStmts(context, scope)) { + for (auto visClause : visStmts->visibilityClauses()) { + auto nextScope = visClause.scope(); + if (nextScope && asttags::isModule(nextScope->tag())) { + helpCollectVisibileImplementationPoints(context, nextScope, seen, into); + } + } + } +} + +static const std::map>& +visibleImplementationPoints(Context* context, + const Scope* scope) { + QUERY_BEGIN(visibleImplementationPoints, context, scope); + std::map> result; + std::unordered_set seen; + helpCollectVisibileImplementationPoints(context, scope, seen, result); + return QUERY_END(result); +} + +const std::vector* +visibileImplementationPointsForInterface(Context* context, + const Scope* scope, + ID id) { + auto& allInstantiationPoints = visibleImplementationPoints(context, scope); + auto it = allInstantiationPoints.find(id); + if (it != allInstantiationPoints.end()) { + return &it->second; + } + + return nullptr; +} + const ResolvedFunction* resolveConcreteFunction(Context* context, ID id) { if (id.isEmpty()) return nullptr; @@ -5139,6 +5382,343 @@ const TypedFnSignature* tryResolveDeinit(Context* context, return c.mostSpecific().only().fn(); } +static bool +matchImplementationPoint(ResolutionContext* rc, + const InterfaceType* ift, + const ImplementationPoint* implPoint, + bool& outIsGeneric) { + if (ift->id() != implPoint->interface()->id()) return false; + + // Use 'const var' so that canPass doesn't say that a type instantiates + // itself. + auto actualT = QualifiedType(QualifiedType::CONST_VAR, ift); + auto formalT = QualifiedType(QualifiedType::CONST_VAR, implPoint->interface()); + auto got = canPass(rc->context(), actualT, formalT); + + outIsGeneric = got.instantiates(); + return got.passes(); +} + +const ImplementationWitness* findMatchingImplementationPoint(ResolutionContext* rc, + const types::InterfaceType* ift, + const CallScopeInfo& inScopes) { + auto implPoints = + visibileImplementationPointsForInterface(rc->context(), inScopes.lookupScope()->moduleScope(), ift->id()); + + // TODO: this matches production, in which the first matching generic + // implementation is used if no concrete one is found. It's probably + // better to use the same disambiguation rules as functions, though. + // I don't see a particularly nice way to do that, though. + const ImplementationPoint* generic = nullptr; + if (implPoints) { + for (auto implPoint : *implPoints) { + bool isGeneric = false; + if (matchImplementationPoint(rc, ift, implPoint, isGeneric)) { + if (isGeneric && generic == nullptr) { + generic = implPoint; + } else if (!isGeneric) { + // For a concrete instantiation point, the current search scope + // is irrelevant; use the point's scope for the search. + + auto implScope = scopeForId(rc->context(), implPoint->id()); + auto checkScope = CallScopeInfo::forNormalCall(implScope, nullptr); + if (auto witness = checkInterfaceConstraints(rc, ift, implPoint->id(), checkScope)) { + return witness; + } + } + } + } + } + + if (generic) { + // For a generic instantiation point, construct a new PoI scope from + // the current search scope, and use the point's scope for the search. + + auto implScope = scopeForId(rc->context(), generic->id()); + auto poiScope = pointOfInstantiationScope(rc->context(), inScopes.callScope(), inScopes.poiScope()); + auto checkScope = CallScopeInfo::forNormalCall(implScope, poiScope); + if (auto witness = checkInterfaceConstraints(rc, ift, generic->id(), checkScope)) { + return witness; + } + } + + return nullptr; +} + +static ID searchFunctionByTemplate(ResolutionContext* rc, + const InterfaceType* iftForErr, + const ID& implPointIdForErr, + const Function* fn, + const TypedFnSignature* tfs, + const CallScopeInfo& inScopes) { + std::vector rejected; + std::vector ambiguous; + + std::vector actuals; + for (int i = 0; i < tfs->numFormals(); i++) { + auto decl = tfs->untyped()->formalDecl(i); + auto name = UniqueString(); + if (auto formal = decl->toFormal()) { + name = formal->name(); + } else if (decl->isVarArgFormal()) { + CHPL_UNIMPL("vararg formals in interface function requirements"); + return ID(); + } + actuals.emplace_back(tfs->formalType(i), name); + } + + CallInfo ci { + tfs->untyped()->name(), + /* calledType */ QualifiedType(), + /* isMethodCall */ fn->isMethod(), + /* hasQuestionArg */ false, + /* isParenless */ fn->isParenless(), + std::move(actuals) + }; + + // TODO: how to note this? + auto c = + resolveGeneratedCall(rc->context(), fn, ci, inScopes); + + bool failed = c.exprType().isUnknownOrErroneous(); + if (failed && fn->body()) { + // template has a default implementation; return it, we're good. + return fn->id(); + } else if (failed && c.mostSpecific().isAmbiguous()) { + // TODO: no way at this time to collected candidates rejected due to + // ambiguity, so just report an empty list. + CHPL_REPORT(rc->context(), InterfaceAmbiguousFn, iftForErr, implPointIdForErr, + fn, std::move(ambiguous)); + return ID(); + } else if (failed) { + // Failed to find a call, not due to ambiguity. Re-run call and gather + // rejected candidates. + resolveGeneratedCall(rc->context(), fn, ci, inScopes, &rejected); + CHPL_REPORT(rc->context(), InterfaceMissingFn, iftForErr, implPointIdForErr, + tfs, ci, std::move(rejected)); + return ID(); + } + + CHPL_ASSERT(!failed); + + const TypedFnSignature* foundFn = nullptr; + if (c.mostSpecific().numBest() > 1) { + CHPL_UNIMPL("return intent overloading in interface constraint checking"); + return ID(); + } else { + // There's only one function; we should still check its intent. + foundFn = c.mostSpecific().only().fn(); + auto foundIntent = parsing::idToFnReturnIntent(rc->context(), foundFn->id()); + + if (foundIntent == fn->returnIntent()) { + // fine + } else { + CHPL_REPORT(rc->context(), InterfaceInvalidIntent, iftForErr, implPointIdForErr, + tfs, foundFn); + return ID(); + } + } + + // Validate that the formal names are in the right order. We could do this + // by resolving a generated call with every actual being named, but this + // will miss some cases, like when the template is: + // + // proc f(x: int, y: int) + // + // and the actual function is: + // + // proc f(y: int, x: int): + // + // such calls would resolve match with and without named actuals, but the + // actuals are not in the right order compared to the template. + FormalActualMap faMap(foundFn, ci); + + int lastActualPosition = -1; + for (auto& formalActual : faMap.byFormals()) { + if (formalActual.actualIdx() == -1) { + // Allow defaulted formals to be skipped vs. a template, so that a template + // + // proc foo(); + // + // Can match: + // + // proc foo(x: int = 10) {} + continue; + } + + if (formalActual.actualIdx() <= lastActualPosition) { + // the actuals in the call (which are formals from the template) + // are re-ordered compared to the foundFn. This is an error. + CHPL_REPORT(rc->context(), InterfaceReorderedFnFormals, iftForErr, implPointIdForErr, + tfs, foundFn); + return ID(); + } + lastActualPosition = formalActual.actualIdx(); + } + + // ordering is fine, so foundFn is good to go. + return foundFn->id(); +} + +static QualifiedType searchForAssociatedType(ResolutionContext* rc, + const InterfaceType* iftForErr, + const ID& implPointIdForErr, + const QualifiedType& receiverType, + const Variable* td, + const CallScopeInfo& inScopes) { + // Set up a parenless type-proc call to compute associated type + auto ci = CallInfo( + td->name(), + /* calledType */ QualifiedType(), + /* isMethodCall */ true, + /* hasQuestionArg */ false, + /* isParenless */ true, + /* actuals */ { { receiverType, USTR("this") } } + ); + + // TODO: how to note this? + auto c = + resolveGeneratedCall(rc->context(), td, ci, inScopes); + + std::vector rejected; + bool failed = c.exprType().isUnknownOrErroneous(); + bool notType = false; + if (!failed) { + notType = !c.exprType().isType(); + if (notType) { + rejected.push_back( + ApplicabilityResult::failure(c.mostSpecific().only().fn()->id(), + FAIL_INTERFACE_NOT_TYPE_INTENT)); + } + } else { + resolveGeneratedCall(rc->context(), td, ci, inScopes, &rejected); + } + + if (failed || notType) { + CHPL_REPORT(rc->context(), InterfaceMissingAssociatedType, iftForErr, + implPointIdForErr, td, ci, std::move(rejected)); + return QualifiedType(); + } + + return c.exprType(); +} + +static const ImplementationWitness* const& +checkInterfaceConstraintsQuery(ResolutionContext* rc, + const InterfaceType* ift, + const ID& implPointIdForErr, + const Scope* inScope, + const PoiScope* inPoiScope) { + CHPL_RESOLUTION_QUERY_BEGIN(checkInterfaceConstraintsQuery, rc, ift, implPointIdForErr, inScope, inPoiScope); + + auto inScopes = CallScopeInfo::forNormalCall(inScope, inPoiScope); + + const ImplementationWitness* result = nullptr; + auto itf = parsing::idToAst(rc->context(), ift->id())->toInterface(); + + // First, process any associated constraints, and create a "phase 1" + // implementation witness with this information. + // TODO: not used in production today, so not implemented, + ImplementationWitness::ConstraintMap associatedConstraints; + auto witness1 = ImplementationWitness::get(rc->context(), associatedConstraints, {}, {}); + + // Next, process all the associated types, and create a "phase 2" + // implementation witness. + + // Here, interface formals aren't 'outer variables' since they live in the + // same symbol (the interface), so insert them into 'byPostorderForAssociatedTypes' + // as a shortcut for 'resolveNamedDecl' given ift->subs(). + ResolutionResultByPostorderID byPostorderForAssociatedTypes; + for (auto& sub : ift->substitutions()) { + byPostorderForAssociatedTypes.byId(sub.first).setType(sub.second); + } + auto associatedReceiverType = QualifiedType(); // cached across iterations + ImplementationWitness::AssociatedTypeMap associatedTypes; + for (auto stmt : itf->stmts()) { + auto td = stmt->toVariable(); + if (!td) continue; + + // Only associated type are valid declarations in this position + CHPL_ASSERT(td->storageKind() == QualifiedType::TYPE); + + ResolutionResultByPostorderID byPostorder; + auto resolver = Resolver::createForInterfaceStmt(rc, itf, ift, witness1, stmt, byPostorder); + td->traverse(resolver); + if (associatedReceiverType.kind() != QualifiedType::TYPE) { + // for associated types of multi-type interfaces, resolve the call + // on a tuple. + if (itf->numFormals() > 1) { + std::vector formalTypes; + for (auto formal : itf->formals()) { + formalTypes.push_back(ift->substitutions().at(formal->id()).type()); + } + + associatedReceiverType = + QualifiedType(QualifiedType::TYPE, + TupleType::getValueTuple(rc->context(), formalTypes)); + } else { + associatedReceiverType = ift->substitutions().at(itf->formal(0)->id()); + } + } + + auto foundQt = searchForAssociatedType(rc, ift, implPointIdForErr, + associatedReceiverType, td, inScopes); + if (foundQt.isUnknownOrErroneous()) { + result = nullptr; + return CHPL_RESOLUTION_QUERY_END(result); + } else { + associatedTypes.emplace(td->id(), foundQt.type()); + } + } + auto witness2 = ImplementationWitness::get(rc->context(), associatedConstraints, associatedTypes, {}); + + // Next, process all the functions; if all of these are found, we can construct + // a final witness with all the required information. + ImplementationWitness::FunctionMap functions; + for (auto stmt : itf->stmts()) { + auto fn = stmt->toFunction(); + if (!fn) continue; + + // Note: construct a resolver with the witness above, which pushes + // an interface frame onto the ResolutionContext. This is required for + // resolving typed signatures in the interface. + ResolutionResultByPostorderID byPostorder; + auto resolver = Resolver::createForInterfaceStmt(rc, itf, ift, witness2, stmt, byPostorder); + + + // Construct an initial typed signature, which will have opaque placeholder + // types for 'Self', associated types, etc. Then, replace the placeholders + // with the substitutions we've determined, so that we may proceed to + // type checking. + auto tfs = typedSignatureTemplateForId(rc, fn->id()); + PlaceholderMap allPlaceholders; + for (auto& [id, t] : associatedTypes) allPlaceholders.emplace(id, t); + for (auto& [id, qt] : ift->substitutions()) allPlaceholders.emplace(id, qt.type()); + tfs = tfs->substitute(rc->context(), std::move(allPlaceholders)); + auto foundId = searchFunctionByTemplate(rc, ift, implPointIdForErr, fn, tfs, inScopes); + + if (!foundId) { + result = nullptr; + return CHPL_RESOLUTION_QUERY_END(result); + } else { + functions.emplace(fn->id(), foundId); + } + } + + result = ImplementationWitness::get(rc->context(), associatedConstraints, associatedTypes, functions); + return CHPL_RESOLUTION_QUERY_END(result); +} + +const ImplementationWitness* +checkInterfaceConstraints(ResolutionContext* rc, + const InterfaceType* ift, + const ID& implPointId, + const CallScopeInfo& inScopes) { + return checkInterfaceConstraintsQuery(rc, ift, implPointId, + inScopes.callScope(), + inScopes.poiScope()); +} + static const TypedFnSignature* tryResolveAssignHelper(Context* context, const uast::AstNode* astForScopeOrErr, diff --git a/frontend/lib/resolution/resolution-types.cpp b/frontend/lib/resolution/resolution-types.cpp index 04452a83c245..e8458fcc4cb7 100644 --- a/frontend/lib/resolution/resolution-types.cpp +++ b/frontend/lib/resolution/resolution-types.cpp @@ -46,6 +46,16 @@ namespace resolution { using namespace uast; using namespace types; +SubstitutionsMap substituteInMap(Context* context, + const SubstitutionsMap& substituteIn, + const PlaceholderMap& subs) { + SubstitutionsMap into; + for (auto [id, qt] : substituteIn) { + into.emplace(id, qt.substitute(context, subs)); + } + return into; +} + const owned& UntypedFnSignature::getUntypedFnSignature(Context* context, ID id, UniqueString name, @@ -958,6 +968,27 @@ TypedFnSignature::getInferred( inferredFrom->outerVariables()).get(); } +const TypedFnSignature* +TypedFnSignature::substitute(Context* context, + const PlaceholderMap& subs) const { + std::vector newFormalTypes; + for (const auto& formalType : formalTypes_) { + newFormalTypes.push_back(formalType.substitute(context, subs)); + } + + // TODO: do we need to substitute in outer variables' stored types? + + return getTypedFnSignature(context, untyped(), + std::move(newFormalTypes), + whereClauseResult(), + needsInstantiation(), + isRefinementOnly_, + instantiatedFrom(), + parentFn(), + formalsInstantiatedBitmap(), + outerVariables()).get(); +} + void TypedFnSignature::stringify(std::ostream& ss, chpl::StringifyKind stringKind) const { diff --git a/frontend/lib/resolution/return-type-inference.cpp b/frontend/lib/resolution/return-type-inference.cpp index 3fe88d68a928..dc618cdde5a7 100644 --- a/frontend/lib/resolution/return-type-inference.cpp +++ b/frontend/lib/resolution/return-type-inference.cpp @@ -36,6 +36,7 @@ #include "Resolver.h" #include +#include #include #include #include @@ -55,6 +56,127 @@ static QualifiedType adjustForReturnIntent(uast::Function::ReturnIntent ri, QualifiedType retType); +/* pair (interface ID, implementation point ID) */ +using ImplementedInterface = + std::pair; + +/* (parent class, implemented interfaces) tuple resulting from processing + the inheritance expressions of a class/record. */ +using InheritanceExprResolutionResult = + std::pair>; + +static const InheritanceExprResolutionResult& +processInheritanceExpressionsForAggregateQuery(Context* context, + const AggregateDecl* ad, + SubstitutionsMap substitutions, + const PoiScope* poiScope) { + QUERY_BEGIN(processInheritanceExpressionsForAggregateQuery, context, ad, substitutions, poiScope); + const BasicClassType* parentClassType = nullptr; + const AstNode* parentClassNode = nullptr; + std::vector implementationPoints; + auto c = ad->toClass(); + + for (auto inheritExpr : ad->inheritExprs()) { + // Resolve the parent class type expression + ResolutionResultByPostorderID r; + auto visitor = + Resolver::createForParentClass(context, ad, inheritExpr, + substitutions, + poiScope, r); + inheritExpr->traverse(visitor); + + auto& rr = r.byAst(inheritExpr); + QualifiedType qt = rr.type(); + const BasicClassType* newParentClassType = nullptr; + if (auto t = qt.type()) { + if (auto bct = t->toBasicClassType()) { + newParentClassType = bct; + } else if (auto ct = t->toClassType()) { + // safe because it's checked for null later. + newParentClassType = ct->basicClassType(); + } + } + + bool foundParentClass = qt.isType() && newParentClassType != nullptr; + if (!c && foundParentClass) { + CHPL_REPORT(context, NonClassInheritance, ad, inheritExpr, newParentClassType); + } else if (foundParentClass) { + // It's a valid parent class; is it the only one? (error otherwise). + if (parentClassType) { + CHPL_ASSERT(parentClassNode); + reportInvalidMultipleInheritance(context, c, parentClassNode, inheritExpr); + } else { + parentClassType = newParentClassType; + parentClassNode = inheritExpr; + } + + // OK + } else if (qt.isType() && qt.type() && qt.type()->isInterfaceType()) { + auto ift = qt.type()->toInterfaceType(); + if (!ift->substitutions().empty()) { + context->error(inheritExpr, "cannot specify instantiated interface type in inheritance expression"); + } else { + implementationPoints.emplace_back(ift, inheritExpr->id()); + } + } else { + context->error(inheritExpr, "invalid parent class expression"); + parentClassType = BasicClassType::getRootClassType(context); + parentClassNode = inheritExpr; + } + } + + InheritanceExprResolutionResult result { + parentClassType, std::move(implementationPoints) + }; + return QUERY_END(result); +} + +static const std::vector& +getImplementedInterfacesQuery(Context* context, + const AggregateDecl* ad) { + QUERY_BEGIN(getImplementedInterfacesQuery, context, ad); + std::vector result; + std::map seen; + auto inheritanceResult = + processInheritanceExpressionsForAggregateQuery(context, ad, {}, nullptr); + auto& implementationPoints = inheritanceResult.second; + + auto initialType = QualifiedType(QualifiedType::TYPE, + initialTypeForTypeDecl(context, ad->id())); + + for (auto& implementedInterface : implementationPoints) { + auto insertionResult = seen.insert({ implementedInterface.first, implementedInterface.second }); + + if (!insertionResult.second) { + // We already saw an 'implements' for this interface + CHPL_REPORT(context, InterfaceMultipleImplements, ad, + implementedInterface.first, insertionResult.first->second, + implementedInterface.second); + } else { + auto ift = InterfaceType::withTypes(context, implementedInterface.first, + { initialType }); + if (!ift) { + // we gave it a single type, but got null back, which means it's + // not a unary interface. + CHPL_REPORT(context, InterfaceNaryInInherits, ad, + implementedInterface.first, implementedInterface.second); + } else { + auto implPoint = + ImplementationPoint::get(context, ift, implementedInterface.second); + result.push_back(implPoint); + } + } + }; + + return QUERY_END(result); +} + +const std::vector& +getImplementedInterfaces(Context* context, + const AggregateDecl* ad) { + return getImplementedInterfacesQuery(context, ad); +} + // Get a Type for an AggregateDecl // poiScope, instantiatedFrom are nullptr if not instantiating const CompositeType* helpGetTypeForDecl(Context* context, @@ -81,44 +203,10 @@ const CompositeType* helpGetTypeForDecl(Context* context, const CompositeType* ret = nullptr; if (const Class* c = ad->toClass()) { - const BasicClassType* parentClassType = nullptr; - const AstNode* lastParentClass = nullptr; - for (auto inheritExpr : c->inheritExprs()) { - // Resolve the parent class type expression - ResolutionResultByPostorderID r; - auto visitor = - Resolver::createForParentClass(context, c, inheritExpr, - substitutions, - poiScope, r); - inheritExpr->traverse(visitor); - - auto& rr = r.byAst(inheritExpr); - QualifiedType qt = rr.type(); - if (auto t = qt.type()) { - if (auto bct = t->toBasicClassType()) { - parentClassType = bct; - } else if (auto ct = t->toClassType()) { - // safe because it's checked for null later. - parentClassType = ct->basicClassType(); - } - } - - if (qt.isType() && parentClassType != nullptr) { - // It's a valid parent class; is it the only one? (error otherwise). - if (lastParentClass) { - reportInvalidMultipleInheritance(context, c, lastParentClass, inheritExpr); - } - lastParentClass = inheritExpr; - - // OK - } else if (!rr.toId().isEmpty() && - parsing::idToTag(context, rr.toId()) == uast::asttags::Interface) { - // OK, It's an interface. - } else { - context->error(inheritExpr, "invalid parent class expression"); - parentClassType = BasicClassType::getRootClassType(context); - } - } + const BasicClassType* parentClassType = + processInheritanceExpressionsForAggregateQuery(context, ad, + substitutions, + poiScope).first; // All the parent expressions could've been interfaces, and we just // inherit from object. @@ -717,9 +805,13 @@ returnTypeForTypeCtorQuery(Context* context, // handle type construction const AggregateDecl* ad = nullptr; - if (!untyped->id().isEmpty()) - if (auto ast = parsing::idToAst(context, untyped->compilerGeneratedOrigin())) + const Interface* itf = nullptr; + if (!untyped->id().isEmpty()) { + if (auto ast = parsing::idToAst(context, untyped->compilerGeneratedOrigin())) { ad = ast->toAggregateDecl(); + itf = ast->toInterface(); + } + } if (ad) { // compute instantiatedFrom @@ -789,6 +881,19 @@ returnTypeForTypeCtorQuery(Context* context, result = theType; + } else if (itf) { + SubstitutionsMap subs; + + CHPL_ASSERT(sig->numFormals() == itf->numFormals()); + + int nFormals = sig->numFormals(); + for (int i = 0; i < nFormals; i++) { + auto& formalType = sig->formalType(i); + auto& formalId = itf->formal(i)->id(); + subs.emplace(formalId, formalType); + } + + result = InterfaceType::get(context, itf->id(), itf->name(), std::move(subs)); } else { // built-in type construction should be handled // by resolveFnCallSpecialType and not reach this point. diff --git a/frontend/lib/resolution/return-type-inference.h b/frontend/lib/resolution/return-type-inference.h index 349a163b2cf8..abd97112a132 100644 --- a/frontend/lib/resolution/return-type-inference.h +++ b/frontend/lib/resolution/return-type-inference.h @@ -21,6 +21,7 @@ #define RETURN_TYPE_INFERENCE_H #include "chpl/resolution/resolution-types.h" +#include "chpl/resolution/interface-types.h" namespace chpl { namespace uast { @@ -29,6 +30,12 @@ namespace uast { namespace resolution { struct Resolver; +// this helper function returns the list of interfaces that +// are implemented by the given decl. It's defined here because this list +// is computed as part of 'helpGetTypeForDecl'. +const std::vector& +getImplementedInterfaces(Context* context, + const uast::AggregateDecl* ad); // this helper function computes a CompositeType based upon // a decl and some substitutions diff --git a/frontend/lib/types/CMakeLists.txt b/frontend/lib/types/CMakeLists.txt index ff8ff22f6085..56c4bce5ef40 100644 --- a/frontend/lib/types/CMakeLists.txt +++ b/frontend/lib/types/CMakeLists.txt @@ -37,11 +37,13 @@ target_sources(ChplFrontend-obj FnIteratorType.cpp HeapBufferType.cpp ImagType.cpp + InterfaceType.cpp IntType.cpp LoopExprIteratorType.cpp NilType.cpp NothingType.cpp Param.cpp + PlaceholderType.cpp PrimitiveType.cpp PtrType.cpp PromotionIteratorType.cpp diff --git a/frontend/lib/types/CompositeType.cpp b/frontend/lib/types/CompositeType.cpp index 550fb9d3ee81..7bf4da797a9b 100644 --- a/frontend/lib/types/CompositeType.cpp +++ b/frontend/lib/types/CompositeType.cpp @@ -42,52 +42,64 @@ using namespace resolution; bool CompositeType::areSubsInstantiationOf(Context* context, const CompositeType* partial) const { - // Check to see if the substitutions of `this` are all instantiations - // of the field types of `partial` - // // Note: Assumes 'this' and 'partial' share a root instantiation. + return canInstantiateSubstitutions(context, + substitutions(), + partial->substitutions(), + /* allowMissing */ !partial->isTupleType()); +} - const SubstitutionsMap& mySubs = substitutions(); - const SubstitutionsMap& pSubs = partial->substitutions(); - - // check, for each substitution in mySubs, that it matches - // or is an instantiation of pSubs. +CompositeType::~CompositeType() { +} - for (const auto& mySubPair : mySubs) { - ID mySubId = mySubPair.first; - QualifiedType mySubType = mySubPair.second; +using SubstitutionPair = CompositeType::SubstitutionPair; - // look for a substitution in pSubs with the same ID - auto pSearch = pSubs.find(mySubId); - if (pSearch != pSubs.end()) { - QualifiedType pSubType = pSearch->second; - // check the types - auto r = canPass(context, mySubType, pSubType); - if (r.passes() && !r.promotes() && !r.converts()) { - // instantiation and same-type passing are allowed here +static void stringifySortedSubstitutions(std::ostream& ss, + chpl::StringifyKind stringKind, + const std::vector& sorted, + bool& emittedField) { + for (const auto& sub : sorted) { + if (emittedField) ss << ", "; + + if (stringKind != StringifyKind::CHPL_SYNTAX) { + sub.first.stringify(ss, stringKind); + ss << ":"; + sub.second.stringify(ss, stringKind); + } else { + if (sub.second.isType() || (sub.second.isParam() && sub.second.param() == nullptr)) { + sub.second.type()->stringify(ss, stringKind); + } else if (sub.second.isParam()) { + sub.second.param()->stringify(ss, stringKind); } else { - // it was not an instantiation - return false; + // Some odd configuration; fall back to printing the qualified type. + CHPL_UNIMPL("attempting to stringify odd type representation as Chapel syntax"); + sub.second.stringify(ss, stringKind); } - } else { - // If the ID isn't found, then that means the generic component doesn't - // exist in the other type, which means this cannot be an instantiation - // of the other type. - // - // Currently this check assumes that 'this' and 'partial' share a root - // instantiation, so how could we reach this condition? One path here - // involves passing a tuple to a tuple formal with a fewer number of - // elements. For example, passing "(1, 2, 3)" to "(int, ?)". - return false; } - } - return true; + emittedField = true; + } } -CompositeType::~CompositeType() { +static +std::vector +sortedSubstitutionsMap(const CompositeType::SubstitutionsMap& subs) { + // since it's an unordered map, iteration will occur in a + // nondeterministic order. + // it's important to sort the keys / iterate in a deterministic order here, + // so we create a vector of pair and sort that instead + std::vector v(subs.begin(), subs.end()); + std::sort(v.begin(), v.end(), FirstElementComparator()); + return v; } +void CompositeType::stringifySubstitutions(std::ostream& ss, + chpl::StringifyKind stringKind, + const SubstitutionsMap& subs) { + bool emittedField = false; + auto sorted = sortedSubstitutionsMap(subs); + stringifySortedSubstitutions(ss, stringKind, sorted, emittedField); +} void CompositeType::stringify(std::ostream& ss, chpl::StringifyKind stringKind) const { // compute the parent class type for BasicClassType @@ -133,27 +145,7 @@ void CompositeType::stringify(std::ostream& ss, emittedField = true; } - for (const auto& sub : sorted) { - if (emittedField) ss << ", "; - - if (stringKind != StringifyKind::CHPL_SYNTAX) { - sub.first.stringify(ss, stringKind); - ss << ":"; - sub.second.stringify(ss, stringKind); - } else { - if (sub.second.isType() || (sub.second.isParam() && sub.second.param() == nullptr)) { - sub.second.type()->stringify(ss, stringKind); - } else if (sub.second.isParam()) { - sub.second.param()->stringify(ss, stringKind); - } else { - // Some odd configuration; fall back to printing the qualified type. - CHPL_UNIMPL("attempting to stringify odd type representation as Chapel syntax"); - sub.second.stringify(ss, stringKind); - } - } - - emittedField = true; - } + stringifySortedSubstitutions(ss, stringKind, sorted, emittedField); ss << ")"; } } @@ -263,35 +255,12 @@ const ClassType* CompositeType::getErrorType(Context* context) { return ClassType::get(context, bct, /* manager */ nullptr, dec); } - -using SubstitutionPair = CompositeType::SubstitutionPair; - -struct SubstitutionsMapCmp { - bool operator()(const SubstitutionPair& x, const SubstitutionPair& y) { - return x.first < y.first; - } -}; - -static -std::vector -sortedSubstitutionsMap(const CompositeType::SubstitutionsMap& subs) { - // since it's an unordered map, iteration will occur in a - // nondeterministic order. - // it's important to sort the keys / iterate in a deterministic order here, - // so we create a vector of pair and sort that instead - std::vector v(subs.begin(), subs.end()); - SubstitutionsMapCmp cmp; - std::sort(v.begin(), v.end(), cmp); - return v; -} - std::vector CompositeType::sortedSubstitutions(void) const { return sortedSubstitutionsMap(subs_); } size_t hashSubstitutionsMap(const CompositeType::SubstitutionsMap& subs) { - auto sorted = sortedSubstitutionsMap(subs); - return hashVector(sorted); + return hashUnorderedMap(subs); } void stringifySubstitutionsMap(std::ostream& streamOut, diff --git a/frontend/lib/types/InterfaceType.cpp b/frontend/lib/types/InterfaceType.cpp new file mode 100644 index 000000000000..e52826d285ff --- /dev/null +++ b/frontend/lib/types/InterfaceType.cpp @@ -0,0 +1,117 @@ +/* + * Copyright 2024 Hewlett Packard Enterprise Development LP + * Other additional copyright holders may be indicated within. + * + * The entirety of this work is licensed under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "chpl/types/InterfaceType.h" +#include "chpl/framework/query-impl.h" +#include "chpl/parsing/parsing-queries.h" +#include "chpl/resolution/can-pass.h" +#include "chpl/uast/Interface.h" + +namespace chpl { +namespace types { + +bool InterfaceType::validateSubstitutions(Context* context, + const ID& id, + SubstitutionsMap& subs) { + // Just a generic instance of the interface + if (subs.empty()) return true; + + auto ast = parsing::idToAst(context, id); + if (!ast) return false; + auto ifc = ast->toInterface(); + if (!ifc) return false; + + CHPL_ASSERT(ifc->numFormals() >= 0); + if (subs.size() != (size_t) ifc->numFormals()) return false; + for (auto fml : ifc->formals()) { + if (subs.count(fml->id()) == 0) return false; + } + + return true; +} + +owned const& +InterfaceType::getInterfaceType(Context* context, ID id, UniqueString name, SubstitutionsMap subs) { + QUERY_BEGIN(getInterfaceType, context, id, name, subs); + CHPL_ASSERT(validateSubstitutions(context, id, subs)); + auto result = toOwned(new InterfaceType(id, name, subs)); + return QUERY_END(result); +} + +const InterfaceType* InterfaceType::get(Context* context, ID id, UniqueString name, SubstitutionsMap subs) { + return getInterfaceType(context, id, name, subs).get(); +} + +static const InterfaceType* const& +interfaceTypeWithTypesQuery(Context* context, + const InterfaceType* ift, + std::vector types) { + QUERY_BEGIN(interfaceTypeWithTypesQuery, context, ift, types); + const InterfaceType* res = nullptr; + + if (ift->substitutions().size() > 0) { + // don't allow instantiating already-instantiated interfaces + } else { + auto ast = parsing::idToAst(context, ift->id()); + CHPL_ASSERT(ast); + auto itf = ast->toInterface(); + CHPL_ASSERT(itf); + + CHPL_ASSERT(itf->numFormals() >= 0); + if (types.size() != (size_t) itf->numFormals()) { + // not good, wrong instantiation + } else { + InterfaceType::SubstitutionsMap subs; + auto typesIt = types.begin(); + for (auto formal : itf->formals()) { + // Force the intent to TYPE + auto newType = QualifiedType(QualifiedType::TYPE, (typesIt++)->type()); + subs.emplace(formal->id(), std::move(newType)); + } + + res = InterfaceType::get(context, itf->id(), itf->name(), std::move(subs)); + } + } + + return QUERY_END(res); +} + +const InterfaceType* InterfaceType::withTypes(Context* context, + const InterfaceType* ift, + std::vector types) { + return interfaceTypeWithTypesQuery(context, ift, std::move(types)); +} + +bool InterfaceType::isInstantiationOf(Context* context, const InterfaceType* other) const { + if (id_ != other->id()) return false; + CHPL_ASSERT(name_ == other->name()); + return resolution::canInstantiateSubstitutions(context, subs_, other->substitutions(), /* allowMissing */ false); +} + +void InterfaceType::stringify(std::ostream& ss, StringifyKind stringKind) const { + ss << name_; + if (!subs_.empty()) { + ss << "("; + CompositeType::stringifySubstitutions(ss, stringKind, subs_); + ss << ")"; + } +} + +} // end namespace types +} // end namespace chpl diff --git a/frontend/lib/types/PlaceholderType.cpp b/frontend/lib/types/PlaceholderType.cpp new file mode 100644 index 000000000000..58e577e1ab76 --- /dev/null +++ b/frontend/lib/types/PlaceholderType.cpp @@ -0,0 +1,45 @@ +/* + * Copyright 2024 Hewlett Packard Enterprise Development LP + * Other additional copyright holders may be indicated within. + * + * The entirety of this work is licensed under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "chpl/types/PlaceholderType.h" +#include "chpl/framework/query-impl.h" + +namespace chpl { +namespace types { + +owned const& +PlaceholderType::getPlaceholderType(Context* context, ID id) { + QUERY_BEGIN(getPlaceholderType, context, id); + auto result = toOwned(new PlaceholderType(id)); + return QUERY_END(result); +} + +const PlaceholderType* PlaceholderType::get(Context* context, ID id) { + return getPlaceholderType(context, id).get(); +} + +void PlaceholderType::stringify(std::ostream& ss, + chpl::StringifyKind stringKind) const { + ss << "PlaceholderType("; + id_.stringify(ss, stringKind); + ss << ")"; +} + +} // end namespace types +} // end namespace chpl diff --git a/frontend/test/resolution/CMakeLists.txt b/frontend/test/resolution/CMakeLists.txt index e0622bb801aa..9d04c598a0c7 100644 --- a/frontend/test/resolution/CMakeLists.txt +++ b/frontend/test/resolution/CMakeLists.txt @@ -49,6 +49,7 @@ comp_unit_test(testHeapBuffer) comp_unit_test(testIf) comp_unit_test(testInitSemantics) comp_unit_test(testInteractive) +comp_unit_test(testInterfaces) comp_unit_test(testIterators) comp_unit_test(testLibrary) comp_unit_test(testLocation) diff --git a/frontend/test/resolution/testInterfaces.cpp b/frontend/test/resolution/testInterfaces.cpp new file mode 100644 index 000000000000..edf2cb33bb67 --- /dev/null +++ b/frontend/test/resolution/testInterfaces.cpp @@ -0,0 +1,744 @@ +/* + * Copyright 2021-2024 Hewlett Packard Enterprise Development LP + * Other additional copyright holders may be indicated within. + * + * The entirety of this work is licensed under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "chpl/uast/BuilderResult.h" +#include "test-resolution.h" + +#include "chpl/parsing/parsing-queries.h" +#include "chpl/resolution/resolution-queries.h" +#include "chpl/resolution/scope-queries.h" +#include "chpl/types/IntType.h" +#include "chpl/types/QualifiedType.h" +#include "chpl/uast/Comment.h" +#include "chpl/uast/FnCall.h" +#include "chpl/uast/Identifier.h" +#include "chpl/uast/Module.h" +#include "chpl/uast/Variable.h" + +static const std::string INDENT = " "; +static const bool NOT_A_TYPE_METHOD = false; +static const bool IS_A_TYPE_METHOD = true; + +struct InterfaceSource; + +static std::string intercalate(const std::vector& lines, const std::string& separator) { + std::string result; + bool first = true; + for (const auto& line : lines) { + if (!first) { + result += separator; + } + result += line; + first = false; + } + return result; +} + +struct RecordSource { + std::string typeName; + std::vector> methods; + std::vector interfaces; + + RecordSource(std::string name) : typeName(std::move(name)) {} + + RecordSource& addMethod(bool isTypeMethod, std::string sig) { + methods.push_back({isTypeMethod, std::move(sig)}); + return *this; + } + + RecordSource& addInterfaceConstraint(InterfaceSource& interface) { + interfaces.push_back(&interface); + return *this; + } + + std::string declLine(bool includeInterfaces) const; + + std::vector getMethods(const std::string& indent, + const std::string& prefix) const { + std::vector methodLines; + for (const auto& line : methods) { + auto isTypeMethod = std::get<0>(line); + auto& sig = std::get<1>(line); + methodLines.push_back(indent + "proc " + (isTypeMethod ? "type " : "") + prefix + sig); + } + return methodLines; + } + + std::vector primaryLines(bool includeInterfaces) const { + std::vector primary; + primary.push_back(declLine(includeInterfaces) + " {"); + for (const auto& line : getMethods(INDENT, "")) { + primary.push_back(line); + } + primary.push_back("}"); + return primary; + } + + std::vector definitionOnly(bool includeInterfaces) const { + std::vector defOnly; + defOnly.push_back(declLine(includeInterfaces) + " {}"); + return defOnly; + } + + std::vector methodsOnly() const { + return getMethods("", typeName + "."); + } +}; + +struct InterfaceSource { + std::string interfaceName; + std::vector linesInside; + + template + InterfaceSource(std::string name, Args&& ... args) : interfaceName(std::move(name)) { + (linesInside.push_back(std::move(args)), ...); + } + + std::vector allLines() const { + std::vector all; + all.push_back("interface " + interfaceName + " {"); + for (const auto& line : linesInside) { + all.push_back(INDENT + line); + } + all.push_back("}"); + return all; + } + + std::string singleImplements(const RecordSource& record) const { + return record.typeName + " implements " + interfaceName + ";"; + } + + template + std::string generalImplements(Records&& ... records) const { + std::string line = "implements " + interfaceName + "("; + bool first = true; + auto appendRecord = [&line, &first](const auto& record) { + if (!first) { + line += ", "; + } + line += record.typeName; + first = false; + }; + + (appendRecord(records), ...); + return line + ");"; + } +}; + +std::string RecordSource::declLine(bool includeInterfaces) const { + std::string line = "record " + typeName; + if (!includeInterfaces) { + return line; + } + + if (!interfaces.empty()) { + line += " : "; + bool first = true; + for (const auto& interface : interfaces) { + if (!first) { + line += ", "; + } + line += interface->interfaceName; + first = false; + } + } + return line; +} + +struct ModuleSource { + std::string moduleName; + std::vector usesImports; + std::vector linesInside; + std::vector> checks; + + enum MethodKind { + /* when adding a record, don't add its methods */ + M_NONE, + /* when adding a record, add its methods as primary methods */ + M_PRIMARY, + /* when adding a record, add its methods outside of its body as secondary methods */ + M_SECONDARY, + }; + + enum ImplementationKind { + /* when adding a record, don't add an interface 'implements' */ + I_NONE, + /* when adding a record, include the interfaces it implements as part of its declaration */ + I_DECL, + /* when adding a record, include the interfaces it implements as a statement in the form 'R implements I' */ + I_SINGLE, + /* when adding a record, include the interfaces it implements as a statement in the form 'implements I(R)' */ + I_GENERAL, + }; + + ModuleSource(std::string name) : moduleName(std::move(name)) {} + + ModuleSource& addRecord(const RecordSource& record, MethodKind methodKind, ImplementationKind implKind) { + bool includeInterfaces = implKind == I_DECL; + bool includeMethods = methodKind == M_PRIMARY; + for (const auto& line : includeMethods ? record.primaryLines(includeInterfaces) : record.definitionOnly(includeInterfaces)) { + linesInside.push_back(line); + } + if (methodKind == M_SECONDARY) { + for (const auto& line : record.methodsOnly()) { + linesInside.push_back(line); + } + } + + if (implKind == I_SINGLE) { + for (const auto& interface : record.interfaces) { + linesInside.push_back(interface->singleImplements(record)); + } + } else if (implKind == I_GENERAL) { + for (const auto& interface : record.interfaces) { + linesInside.push_back(interface->generalImplements(record)); + } + } + return *this; + } + + ModuleSource& addRecordMethods(const RecordSource& record) { + for (const auto& line : record.methodsOnly()) { + linesInside.push_back(line); + } + return *this; + } + + ModuleSource& addInterface(const InterfaceSource& interface) { + for (const auto& line : interface.allLines()) { + linesInside.push_back(line); + } + return *this; + } + + ModuleSource& addUsesImport(std::string import) { + usesImports.push_back(std::move(import)); + return *this; + } + + ModuleSource& addSingleImplements(const InterfaceSource& interface, const RecordSource& record) { + linesInside.push_back(interface.singleImplements(record)); + return *this; + } + + template + ModuleSource& addGeneralImplements(const InterfaceSource& interface, Records&& ... records) { + linesInside.push_back(interface.generalImplements(records...)); + return *this; + } + + ModuleSource& addLine(const std::string& line) { + linesInside.push_back(line); + return *this; + } + + ModuleSource& addCheck(const std::string& check) { + static int checkNum = 0; + checks.push_back({"check" + std::to_string(checkNum++), check}); + return *this; + } + + std::vector allLines() const { + std::vector all; + all.push_back("module " + moduleName + " {"); + for (const auto& line : usesImports) { + all.push_back(INDENT + line); + } + all.push_back(""); + for (const auto& line : linesInside) { + all.push_back(INDENT + line); + } + for (const auto& check : checks) { + all.push_back(INDENT + "param " + check.first + " = " + check.second + ";"); + } + all.push_back("}"); + return all; + } + + void validateChecks(Context* context, const chpl::uast::Module* mod, bool expectError) const { + std::vector checkNames; + for (const auto& check : checks) { + checkNames.push_back(check.first); + } + + auto types = resolveTypesOfVariables(context, mod, checkNames); + for (auto& [name, type] : types) { + std::cout << "checking " << name << std::endl; + assert(expectError ? (type.isParamFalse() || type.isUnknownOrErroneous()) : type.isParamTrue()); + } + } +}; + +static bool findError(const std::vector>& errors, ErrorType type) { + for (auto& err : errors) { + if (err->type() == type) { + return true; + } + } + return false; +} + +static void testSingleInterface(const InterfaceSource& interface, + const RecordSource& record, + chpl::optional expectedError = chpl::empty) { + Context ctx; + Context* context = &ctx; + ErrorGuard guard(context); + + auto validateAndAdvance = [context, &guard, expectedError](const ModuleSource& src, const Module* mod) { + src.validateChecks(context, mod, (bool) expectedError); + std::cout << std::endl; + + if (expectedError) { + assert(findError(guard.errors(), *expectedError)); + guard.realizeErrors(); + } else { + assert(guard.realizeErrors() == 0); + } + + context->advanceToNextRevision(false); + }; + + // First, place the interface and the type in the same module. + { + for (auto methodKind : { ModuleSource::M_PRIMARY, ModuleSource::M_SECONDARY }) { + for (auto implKind : { ModuleSource::I_DECL, ModuleSource::I_SINGLE, ModuleSource::I_GENERAL }) { + auto module = ModuleSource("M") + .addInterface(interface) + .addRecord(record, methodKind, implKind) + .addCheck("__primitive(\"implements interface\", " + record.typeName + ", " + interface.interfaceName + ") == 0"); + auto source = intercalate(module.allLines(), "\n"); + + std::cout << "--- testing program ---" << std::endl; + std::cout << source << std::endl << std::endl; + + auto filePath = UniqueString::get(context, "M.chpl"); + setFileText(context, filePath, source); + auto& modVec = parseToplevel(context, filePath); + assert(modVec.size() == 1); + + validateAndAdvance(module, modVec[0]); + } + } + } + + // Then, split the interface, record definition, and checks into three modules. + { + for (auto methodKind : { ModuleSource::M_PRIMARY, ModuleSource::M_SECONDARY }) { + for (auto implKind : { ModuleSource::I_DECL, ModuleSource::I_SINGLE, ModuleSource::I_GENERAL }) { + for (auto importTechnique : { std::string("use MRec;"), + std::string("import MRec.{") + record.typeName + "};" }) { + auto moduleLib = ModuleSource("MLib") + .addInterface(interface); + auto moduleRec = ModuleSource("MRec") + .addUsesImport("use MLib;") + .addRecord(record, methodKind, implKind); + auto moduleCheck = ModuleSource("MCheck") + .addUsesImport("use MLib;") + .addUsesImport(importTechnique) + .addCheck("__primitive(\"implements interface\", " + record.typeName + ", " + interface.interfaceName + ") == 0"); + + auto source = + intercalate(moduleLib.allLines(), "\n") + "\n\n" + + intercalate(moduleRec.allLines(), "\n") + "\n\n" + + intercalate(moduleCheck.allLines(), "\n"); + + std::cout << "--- testing program ---" << std::endl; + std::cout << source << std::endl << std::endl; + + auto filePath = UniqueString::get(context, "file.chpl"); + setFileText(context, filePath, source); + auto& modVec = parseToplevel(context, filePath); + + validateAndAdvance(moduleCheck, modVec[2]); + } + } + } + } + + // Then, define the record in one place, but provide the required methods in + // another place. + { + for (auto implKind : { ModuleSource::I_SINGLE, ModuleSource::I_GENERAL }) { + for (auto importTechnique : { std::string("use MRec;"), + std::string("import MRec.{") + record.typeName + "};" }) { + auto moduleLib = ModuleSource("MLib") + .addInterface(interface); + auto moduleRec = ModuleSource("MRec") + .addUsesImport("use MLib;") + .addRecord(record, ModuleSource::M_NONE, ModuleSource::I_NONE); + auto moduleCheck = ModuleSource("MCheck") + .addUsesImport("use MLib;") + .addUsesImport(importTechnique) + .addRecordMethods(record); + + implKind == ModuleSource::I_SINGLE ? moduleCheck.addSingleImplements(interface, record) + : moduleCheck.addGeneralImplements(interface, record); + moduleCheck.addCheck("__primitive(\"implements interface\", " + record.typeName + ", " + interface.interfaceName + ") == 0"); + + auto source = + intercalate(moduleLib.allLines(), "\n") + "\n\n" + + intercalate(moduleRec.allLines(), "\n") + "\n\n" + + intercalate(moduleCheck.allLines(), "\n"); + + std::cout << "--- testing program ---" << std::endl; + std::cout << source << std::endl << std::endl; + + auto filePath = UniqueString::get(context, "file.chpl"); + setFileText(context, filePath, source); + auto& modVec = parseToplevel(context, filePath); + + validateAndAdvance(moduleCheck, modVec[2]); + } + } + } +} + +static void testRequiredMethodNoArgs() { + auto i = InterfaceSource("myInterface", "proc Self.foo();"); + auto r1 = RecordSource("myRec") + .addMethod(NOT_A_TYPE_METHOD, "foo() {}") + .addInterfaceConstraint(i); + testSingleInterface(i, r1); + + auto r2 = RecordSource("myRec") + .addMethod(NOT_A_TYPE_METHOD, "foo(x: int) {}") + .addInterfaceConstraint(i); + testSingleInterface(i, r2, ErrorType::InterfaceMissingFn); +} + +static void testRequiredMethodNoArgsDefault() { + auto i = InterfaceSource("myInterface", "proc Self.foo();"); + auto r1 = RecordSource("myRec") + .addMethod(NOT_A_TYPE_METHOD, "foo(x: int = 10) {}") + .addInterfaceConstraint(i); + testSingleInterface(i, r1); +} + +static void testRequiredMethodOneArg() { + auto i = InterfaceSource("myInterface", "proc Self.foo(x: int);"); + auto r1 = RecordSource("myRec") + .addMethod(NOT_A_TYPE_METHOD, "foo(x: int) {}") + .addInterfaceConstraint(i); + testSingleInterface(i, r1); + + auto r2 = RecordSource("myRec") + .addMethod(NOT_A_TYPE_METHOD, "foo() {}") + .addInterfaceConstraint(i); + testSingleInterface(i, r2, ErrorType::InterfaceMissingFn); + + auto r3 = RecordSource("myRec") + .addMethod(NOT_A_TYPE_METHOD, "foo(x: bool) {}") + .addInterfaceConstraint(i); + testSingleInterface(i, r3, ErrorType::InterfaceMissingFn); +} + +static void testRequiredMethodAmbiguous() { + auto i = InterfaceSource("myInterface", "proc Self.foo(x: int);"); + auto r1 = RecordSource("myRec") + .addMethod(NOT_A_TYPE_METHOD, "foo(x: int) {}") + .addMethod(NOT_A_TYPE_METHOD, "foo(x: int) {}") + .addInterfaceConstraint(i); + testSingleInterface(i, r1, ErrorType::InterfaceAmbiguousFn); +} + +static void testRequiredMethodWrongIntent() { + auto i = InterfaceSource("myInterface", "proc Self.foo(x: int);"); + auto r1 = RecordSource("myRec") + .addMethod(NOT_A_TYPE_METHOD, "foo(x: int) type do return int;") + .addInterfaceConstraint(i); + testSingleInterface(i, r1, ErrorType::InterfaceInvalidIntent); +} + +static void testTwoRequiredMethods() { + auto i = InterfaceSource("myInterface", + "proc Self.foo(x: int);", + "proc Self.bar(y: real);"); + + auto r1 = RecordSource("myRec") + .addMethod(NOT_A_TYPE_METHOD, "foo(x: int) {}") + .addMethod(NOT_A_TYPE_METHOD, "bar(y: real) {}") + .addInterfaceConstraint(i); + testSingleInterface(i, r1); + + auto r2 = RecordSource("myRec") + .addMethod(NOT_A_TYPE_METHOD, "foo(x: int) {}") + .addMethod(NOT_A_TYPE_METHOD, "bar() {}") + .addInterfaceConstraint(i); + testSingleInterface(i, r2, ErrorType::InterfaceMissingFn); + + auto r3 = RecordSource("myRec") + .addMethod(NOT_A_TYPE_METHOD, "foo(x: int) {}") + .addMethod(NOT_A_TYPE_METHOD, "bar(y: int) {}") + .addInterfaceConstraint(i); + testSingleInterface(i, r3, ErrorType::InterfaceMissingFn); + + auto r4 = RecordSource("myRec") + .addMethod(NOT_A_TYPE_METHOD, "foo(x: int) {}") + .addInterfaceConstraint(i); + testSingleInterface(i, r4, ErrorType::InterfaceMissingFn); + + auto r5 = RecordSource("myRec") + .addMethod(NOT_A_TYPE_METHOD, "bar(y: real) {}") + .addInterfaceConstraint(i); + testSingleInterface(i, r5, ErrorType::InterfaceMissingFn); +} + +static void testBasicGeneric() { + auto i = InterfaceSource("myInterface", + "proc Self.foo(x);"); + auto r1 = RecordSource("myRec") + .addMethod(NOT_A_TYPE_METHOD, "foo(x) {}") + .addInterfaceConstraint(i); + testSingleInterface(i, r1); + + // top-level type queries and missing types should be cross-compatible. + auto r2 = RecordSource("myRec") + .addMethod(NOT_A_TYPE_METHOD, "foo(x: ?t) {}") + .addInterfaceConstraint(i); + testSingleInterface(i, r2); + + // too specific: interface requires fully generic, we're giving it a specific type + auto r3 = RecordSource("myRec") + .addMethod(NOT_A_TYPE_METHOD, "foo(x: int) {}") + .addInterfaceConstraint(i); + testSingleInterface(i, r3, ErrorType::InterfaceMissingFn); + +} + +static void testBasicGenericTypeQuery() { + auto i = InterfaceSource("myInterface", + "proc Self.foo(x: ?tq);"); + // top-level type queries and missing types should be cross-compatible. + auto r1 = RecordSource("myRec") + .addMethod(NOT_A_TYPE_METHOD, "foo(x) {}") + .addInterfaceConstraint(i); + testSingleInterface(i, r1); + + auto r2 = RecordSource("myRec") + .addMethod(NOT_A_TYPE_METHOD, "foo(x: ?t) {}") + .addInterfaceConstraint(i); + testSingleInterface(i, r2); + + // too specific: interface requires fully generic, we're giving it a specific type + auto r3 = RecordSource("myRec") + .addMethod(NOT_A_TYPE_METHOD, "foo(x: int) {}") + .addInterfaceConstraint(i); + testSingleInterface(i, r3, ErrorType::InterfaceMissingFn); + +} + +static void testDependentGeneric() { + auto i = InterfaceSource("myInterface", + "proc Self.foo(x: ?tq, y: tq);"); + auto r1 = RecordSource("myRec") + .addMethod(NOT_A_TYPE_METHOD, "foo(x, y: x.type) {}") + .addInterfaceConstraint(i); + testSingleInterface(i, r1); + + auto r2 = RecordSource("myRec") + .addMethod(NOT_A_TYPE_METHOD, "foo(x: ?t, y: t) {}") + .addInterfaceConstraint(i); + testSingleInterface(i, r2); +} + +static void testAssociatedType() { + auto i = InterfaceSource("myInterface", + "type someType;"); + auto r1 = RecordSource("myRec") + .addMethod(IS_A_TYPE_METHOD, "someType type do return int;") + .addInterfaceConstraint(i); + testSingleInterface(i, r1); + + auto r2 = RecordSource("myRec") + .addInterfaceConstraint(i); + testSingleInterface(i, r2, ErrorType::InterfaceMissingAssociatedType); + + auto r3 = RecordSource("myRec") + .addMethod(IS_A_TYPE_METHOD, "wrongName type do return int;") + .addInterfaceConstraint(i); + testSingleInterface(i, r2, ErrorType::InterfaceMissingAssociatedType); + + auto r4 = RecordSource("myRec") + .addMethod(NOT_A_TYPE_METHOD, "someType type do return int;") + .addInterfaceConstraint(i); + testSingleInterface(i, r4, ErrorType::InterfaceMissingAssociatedType); + + auto r5 = RecordSource("myRec") + .addMethod(IS_A_TYPE_METHOD, "someType do return 42;") + .addInterfaceConstraint(i); + testSingleInterface(i, r5, ErrorType::InterfaceMissingAssociatedType); +} + +static void testAssociatedTypeInFn() { + auto i = InterfaceSource("myInterface", + "type someType;" + "proc Self.foo(x: someType);"); + auto r1 = RecordSource("myRec") + .addMethod(IS_A_TYPE_METHOD, "someType type do return int;") + .addMethod(NOT_A_TYPE_METHOD, "foo(x: int) {}") + .addInterfaceConstraint(i); + testSingleInterface(i, r1); + + auto r2 = RecordSource("myRec") + .addMethod(IS_A_TYPE_METHOD, "someType type do return int;") + .addMethod(NOT_A_TYPE_METHOD, "foo(x: bool) {}") + .addInterfaceConstraint(i); + testSingleInterface(i, r2, ErrorType::InterfaceMissingFn); +} + +static void testFormalNaming() { + auto i = InterfaceSource("myInterface", + "proc Self.foo(x: int);"); + auto r1 = RecordSource("myRec") + .addMethod(NOT_A_TYPE_METHOD, "foo(x: int) {}") + .addInterfaceConstraint(i); + testSingleInterface(i, r1); + + auto r2 = RecordSource("myRec") + .addMethod(NOT_A_TYPE_METHOD, "foo(y: int) {}") + .addInterfaceConstraint(i); + testSingleInterface(i, r2, ErrorType::InterfaceMissingFn); +} + +static void testFormalOrdering() { + auto i = InterfaceSource("myInterface", + "proc Self.foo(x: int, y: int, z: int);"); + auto r1 = RecordSource("myRec") + .addMethod(NOT_A_TYPE_METHOD, "foo(x: int, y: int, z: int) {}") + .addInterfaceConstraint(i); + testSingleInterface(i, r1); + + auto r2 = RecordSource("myRec") + .addMethod(NOT_A_TYPE_METHOD, "foo(y: int, x: int, z: int) {}") + .addInterfaceConstraint(i); + testSingleInterface(i, r2, ErrorType::InterfaceReorderedFnFormals); +} + +static void expectError(const std::string& program, ErrorType error) { + Context ctx; + Context* context = &ctx; + ErrorGuard guard(context); + + auto filePath = UniqueString::get(context, "file.chpl"); + setFileText(context, filePath, program); + auto& modVec = parseToplevel(context, filePath); + assert(modVec.size() == 1); + resolveModule(context, modVec[0]->id()); + + assert(findError(guard.errors(), error)); + guard.realizeErrors(); +} + +static void testImplementsInvalidInterface() { + expectError( + R"""( + record I { + + } + int implements I; + )""", ErrorType::InvalidImplementsInterface); +} + +static void testImplementsInvalidActual() { + expectError( + R"""( + interface I { + + } + var x = 3; + x implements I; + )""", ErrorType::InvalidImplementsActual); + + expectError( + R"""( + interface I { + + } + var x = 3; + implements I(x); + )""", ErrorType::InvalidImplementsActual); +}; + +static void testImplementsDuplicate() { + expectError( + R"""( + interface I { + + } + record R : I, I { + + } + )""", ErrorType::InterfaceMultipleImplements); +}; + +static void testInvalidImplementsArity() { + expectError( + R"""( + interface Unary {} + implements Unary(int, bool); + )""", ErrorType::InvalidImplementsArity); + + expectError( + R"""( + interface Unary(Self) {} + implements Unary(int, bool); + )""", ErrorType::InvalidImplementsArity); + + expectError( + R"""( + interface Binary(L, R) {} + implements Binary(int); + )""", ErrorType::InvalidImplementsArity); + + expectError( + R"""( + interface Binary(L, R) {} + implements Binary(int, bool, real); + )""", ErrorType::InvalidImplementsArity); + + expectError( + R"""( + interface Binary(L, R) {} + record R : Binary {} + )""", ErrorType::InterfaceNaryInInherits); +}; + +int main() { + // tests for "basic" interface resolution (unary interfaces) + testRequiredMethodNoArgs(); + testRequiredMethodNoArgsDefault(); + testRequiredMethodOneArg(); + testRequiredMethodAmbiguous(); + testRequiredMethodWrongIntent(); + testTwoRequiredMethods(); + testBasicGeneric(); + testBasicGenericTypeQuery(); + testDependentGeneric(); + testAssociatedType(); + testAssociatedTypeInFn(); + testFormalNaming(); + testFormalOrdering(); + + // tests for the various error message cases + testImplementsInvalidInterface(); + testImplementsInvalidActual(); + testImplementsDuplicate(); + testInvalidImplementsArity(); +} diff --git a/frontend/test/resolution/testResolverVerboseErrors.cpp b/frontend/test/resolution/testResolverVerboseErrors.cpp index d88388c1212e..553fc42825e0 100644 --- a/frontend/test/resolution/testResolverVerboseErrors.cpp +++ b/frontend/test/resolution/testResolverVerboseErrors.cpp @@ -207,7 +207,7 @@ static const char* errorStarVsNotStar = R"""( 3 | f((1.0, 1.0, true)); | ⎺⎺⎺⎺⎺⎺⎺⎺⎺⎺⎺⎺⎺⎺⎺⎺ | - A formal that is a star tuple cannot accept an actual actual that is not. + A formal that is a star tuple cannot accept an actual that is not. )"""; static const char* progVarArgMismatch = R"""( diff --git a/frontend/test/test-resolution.cpp b/frontend/test/test-resolution.cpp index 3e11ff13073b..83848565bfb9 100644 --- a/frontend/test/test-resolution.cpp +++ b/frontend/test/test-resolution.cpp @@ -210,6 +210,21 @@ const Variable* findVariable(const ModuleVec& vec, const char* name) { return nullptr; } +std::unordered_map +resolveTypesOfVariables(Context* context, + const Module* mod, + const std::vector& variables) { + std::unordered_map toReturn; + auto& rr = resolveModule(context, mod->id()); + for (auto& variable : variables) { + if (auto varAst = findVariable(mod, variable.c_str())) { + toReturn[variable] = rr.byAst(varAst).type(); + } + } + assert(variables.size() == toReturn.size()); + return toReturn; +} + std::unordered_map resolveTypesOfVariables(Context* context, std::string program, diff --git a/frontend/test/test-resolution.h b/frontend/test/test-resolution.h index 93bc71e8bf07..1bb01a9c5aa7 100644 --- a/frontend/test/test-resolution.h +++ b/frontend/test/test-resolution.h @@ -69,6 +69,11 @@ void testCall(const char* testName, const Variable* findVariable(const AstNode* ast, const char* name); const Variable* findVariable(const ModuleVec& vec, const char* name); +std::unordered_map +resolveTypesOfVariables(Context* context, + const Module* mod, + const std::vector& variables); + std::unordered_map resolveTypesOfVariables(Context* context, std::string program, const std::vector& variables);