diff --git a/cassis/typesystem.py b/cassis/typesystem.py index b3dc2ad..e16b14e 100644 --- a/cassis/typesystem.py +++ b/cassis/typesystem.py @@ -847,7 +847,9 @@ def __iter__(self) -> Iterator[Type]: return self.get_types() def contains_type(self, typename: str) -> bool: - """Checks whether this type system contains a type with name `typename`. + """Checks whether this type system contains a type with name `typename`. If a type is a short name (i.e. it + does not contain a dot) and there is exactly one type with this short name, this method will consider the + type to be included in the type system. Args: typename: The name of type whose existence is to be checked. @@ -855,7 +857,14 @@ def contains_type(self, typename: str) -> bool: Returns: `True` if a type with `typename` exists, else `False`. """ - return typename in self._types + + if "." in typename: + return typename in self._types + + try: + return self.get_type(typename) is not None + except TypeNotFoundError: + return False def create_type(self, name: str, supertypeName: str = TYPE_NAME_ANNOTATION, description: str = None) -> Type: """Creates a new type and return it. @@ -887,7 +896,8 @@ def create_type(self, name: str, supertypeName: str = TYPE_NAME_ANNOTATION, desc return new_type def get_type(self, type_name: str) -> Type: - """Finds a type by name in the type system of this CAS. + """Finds a type by name in the type system of this CAS. If a type is a short name (i.e. it + does not contain a dot) and there is exactly one type with this short name, this method will return this type. Args: typename: The name of the type to retrieve @@ -897,10 +907,22 @@ def get_type(self, type_name: str) -> Type: Raises: Exception: If no type with `typename` could be found. """ - if self.contains_type(type_name): + if type_name in self._types: return self._types[type_name] - else: - raise TypeNotFoundError(f"Type with name [{type_name}] not found!") + + if "." not in type_name: + types_by_simple_name = defaultdict(list) + for tn, t in self._types.items(): + types_by_simple_name[t.short_name].append(t) + + if types_by_simple_name[type_name]: + candidates = types_by_simple_name[type_name] + if len(candidates) == 1: + return candidates[0] + elif len(candidates) > 1: + raise TypeNotFoundError(f"Multiple types with short name [{type_name}] found: {candidates}") + + raise TypeNotFoundError(f"Type with name [{type_name}] not found!") def get_types(self, built_in: bool = False) -> Iterator[Type]: """Returns all types of this type system. Normally, this excludes the built-in types diff --git a/tests/test_cas.py b/tests/test_cas.py index a455aea..f09f06d 100644 --- a/tests/test_cas.py +++ b/tests/test_cas.py @@ -143,7 +143,9 @@ def test_select(small_typesystem_xml, tokens, sentences): cas = Cas(typesystem=ts) cas.add_all(tokens + sentences) + assert list(cas.select("Token")) == tokens assert list(cas.select("cassis.Token")) == tokens + assert list(cas.select("Sentence")) == sentences assert list(cas.select("cassis.Sentence")) == sentences assert list(cas.select(ts.get_type("cassis.Token"))) == tokens assert list(cas.select(ts.get_type("cassis.Sentence"))) == sentences diff --git a/tests/test_typesystem.py b/tests/test_typesystem.py index 36fba9e..d5a84c8 100644 --- a/tests/test_typesystem.py +++ b/tests/test_typesystem.py @@ -515,10 +515,19 @@ def test_type_name(): cas = Cas() Annotation = cas.typesystem.get_type(TYPE_NAME_ANNOTATION) + assert cas.typesystem.contains_type(TYPE_NAME_ANNOTATION) assert Annotation.name == TYPE_NAME_ANNOTATION assert Annotation.short_name == "Annotation" +def test_get_type_by_short_name(): + cas = Cas() + + Annotation = cas.typesystem.get_type("Annotation") + assert cas.typesystem.contains_type("Annotation") + assert Annotation.name == TYPE_NAME_ANNOTATION + + def test_get_types(): cas = Cas()