Data pack attribute interface implementation #926

Open · wants to merge 15 commits into base: master
151 changes: 148 additions & 3 deletions forte/data/data_pack.py
@@ -1308,7 +1308,6 @@ def _generate_annotation_entry_data(
data: Dict,
cont: Optional[Annotation],
) -> Dict:

components, unit, fields = self._parse_request_args(a_type, a_args)

a_dict: Dict[str, Any] = {}
@@ -1391,7 +1390,6 @@ def _generate_link_entry_data(
data: Dict,
cont: Optional[Annotation],
) -> Dict:

components, unit, fields = self._parse_request_args(a_type, a_args)

if unit is not None:
@@ -1658,7 +1656,6 @@ def require_annotations(entry_class=Annotation) -> bool:
range_span=range_annotation # type: ignore
and (range_begin, range_end),
):

# Filter by components
if components is not None:
if not self.is_created_by(
@@ -1691,6 +1688,154 @@ def require_annotations(entry_class=Annotation) -> bool:
# type_name does not exist in DataStore
yield from []

def get_attributes_of_type(
self,
s_entry_type: str,
attributes_names: List[str],
range_annotation: Optional[
Union[Annotation, AudioAnnotation, int]
] = None,
components: Optional[Union[str, Iterable[str]]] = None,
include_sub_type: bool = True,
get_raw: bool = False,
) -> Iterable:  # yields (entry, attribute-dict) pairs
# pylint: disable=protected-access
def require_annotations(entry_class=Annotation) -> bool:
if self._data_store._is_subclass(s_entry_type, entry_class):
return True

curr_class: Type[EntryType] = as_entry_type(s_entry_type) # type: ignore
if issubclass(curr_class, Link):
return issubclass(
curr_class.ParentType, entry_class # type: ignore
) and issubclass(curr_class.ChildType, entry_class) # type: ignore
if issubclass(curr_class, Group):
return issubclass(curr_class.MemberType, entry_class)
return False

# If we don't have any annotations but the items to check require them,
# then we simply yield from an empty list.
# Changed from using len(annotations) to num_annotations directly to
# improve performance.
if (
self.num_annotations == 0
and isinstance(range_annotation, Annotation)
and require_annotations(Annotation)
) or (
self.num_audio_annotations == 0
and isinstance(range_annotation, AudioAnnotation)
and require_annotations(AudioAnnotation)
):
yield from []
return

# If ``s_entry_type`` and ``range_annotation`` are for different types of
# payload, then we yield from an empty list with a warning.
if (
require_annotations(Annotation)
and isinstance(range_annotation, AudioAnnotation)
) or (
require_annotations(AudioAnnotation)
and isinstance(range_annotation, Annotation)
):
logger.warning(
"Incompatible combination of ``s_entry_type`` and "
"``range_annotation`` found in the input of "
"`DataPack.get_attributes_of_type()`. An empty iterator will be "
"returned when the inputs contain multi-media entries. Please "
"double check the input arguments and make sure they are "
"associated with the same type of payload (i.e., either text "
"or audio)."
)
yield from []
return

# If range_annotation is specified, we record its begin and
# end index
range_begin: int
range_end: int

if range_annotation is not None:
if isinstance(range_annotation, AnnotationLikeEntries):
range_begin = range_annotation.begin
range_end = range_annotation.end
else:
# range_annotation is given by the tid of the entry it
# represents
range_raw = self._data_store.transform_data_store_entry(
self.get_entry_raw(range_annotation)
)
range_begin = range_raw[BEGIN_ATTR_NAME]
range_end = range_raw[END_ATTR_NAME]

try:
for attrs_from_ds in self._data_store.get_attributes_of_type(
type_name=s_entry_type,
attributes_names=attributes_names,
include_sub_type=include_sub_type,
range_span=range_annotation # type: ignore
and (range_begin, range_end),
):
# Filter by components
if components is not None:
if not self.is_created_by(
attrs_from_ds["tid"],
# entry_data[TID_INDEX],
components,
):
continue

entry: Union[Entry, Dict[str, Any]]
if get_raw:
entry = None # not implemented
# entry = self._data_store.transform_data_store_entry(
# attrs_from_ds
# )
else:
entry = self.get_entry(
tid=attrs_from_ds["tid"]
) # entry_data[TID_INDEX]
# entry_data["text"] = self.get_span_text(entry_data['begin'],
# entry_data['end'])

# Filter out incompatible audio span comparison for Links and Groups
if (
self._data_store._is_subclass( # pylint: disable=W0212
s_entry_type, (Link, Group)
)
and isinstance(range_annotation, AudioAnnotation)
and not self._index.in_audio_span(
entry, range_annotation.span
)
):
continue

yield entry, attrs_from_ds
except ValueError:
# type_name does not exist in DataStore
yield from []

def get_attributes_of_tids(self, tids: List[int], attr_names: List[str]) -> List:
r"""This function returns the value of attributes listed in
``attr_names`` for entries in listed in the ``tids``. It locates
the entries data with ``tid`` and put attributes listed in
``attr_name`` in a dict for each entry (tid).

Args:
tids: List of unique ids (tids) of the entry.
attr_names: List of name of the attribute.

Returns:
A list of dict with ``attr_name`` as key for atrributes
of the entries listed in``tids``.

Raises:
KeyError: when ``tid`` or ``attr_name`` is not found.
"""

tids_attrs = self._data_store.get_attributes_of_tids(tids, attr_names)

return tids_attrs

def update(self, datapack: "DataPack"):
r"""Update the attributes and properties of the current DataPack with
another DataPack.
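The two DataPack-level additions above can be exercised roughly as follows. This is a minimal sketch, not part of the diff: it assumes a pipeline has already populated `pack` with `ft.onto.base_ontology.Token` entries carrying a `pos` attribute, and those type and attribute names are illustrative choices rather than something this PR mandates.

from forte.data.data_pack import DataPack

def show_token_pos(pack: DataPack) -> None:
    # get_attributes_of_type yields (entry, attrs) pairs, where `attrs`
    # is a dict that always contains the entry's "tid" plus the
    # requested attribute names.
    tids = []
    for entry, attrs in pack.get_attributes_of_type(
        "ft.onto.base_ontology.Token", ["pos"]
    ):
        print(attrs["tid"], attrs["pos"], entry.text)
        tids.append(attrs["tid"])

    # get_attributes_of_tids returns one dict per tid, keyed by the
    # requested attribute names.
    print(pack.get_attributes_of_tids(tids, ["pos"]))

Note that with the default ``get_raw=False`` the full entry object is still retrieved alongside the attribute dict; the raw-dict path is left unimplemented in this revision.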
174 changes: 170 additions & 4 deletions forte/data/data_store.py
@@ -806,7 +806,6 @@ def fetch_entry_type_data(
else:
attr_fields: Dict = self._get_entry_attributes_by_class(type_name)
for attr_name, attr_info in attr_fields.items():

attr_class = get_origin(attr_info.type)
# Since we store the class specified by get_origin,
# if the output is None, we store the class for it,
@@ -1047,7 +1046,6 @@ def _add_entry_raw(
self._is_subclass(type_name, cls)
for cls in (list(SinglePackEntries) + list(MultiPackEntries))
):

try:
self.__elements[type_name].append(entry)
except KeyError:
@@ -1246,7 +1244,6 @@ def add_entry_raw(
allow_duplicate: bool = True,
attribute_data: Optional[List] = None,
) -> int:

r"""
This function provides a general implementation to add all
types of entries to the data store. It can add namely
@@ -1487,6 +1484,173 @@ def get_attribute(self, tid: int, attr_name: str) -> Any:

return entry[attr_id]

def get_attributes_of_tid(self, tid: int, attr_names: List[str]) -> dict:
r"""This function returns the value of attributes listed in
``attr_names`` for the entry with ``tid``. It locates the entry data
with ``tid`` and finds attributes listed in ``attr_names`` and return
as a dict.

Args:
tid: Unique id of the entry.
attr_names: List of names of the attribute.

Returns:
A dict with keys listed in ``attr_names`` for attributes of the
entry with ``tid``.

Raises:
KeyError: when ``tid`` or ``attr_name`` is not found.
"""
entry, entry_type = self.get_entry(tid)
attrs: dict = {}
for attr_name in attr_names:
try:
attr_id = self._get_type_attribute_dict(entry_type)[attr_name][
constants.ATTR_INDEX_KEY
]
except KeyError as e:
raise KeyError(
f"{entry_type} has no {attr_name} attribute."
) from e
attrs[attr_name] = entry[attr_id]

return attrs

def get_attributes_of_tids(
self, tids: List[int], attr_names: List[str]
) -> List[Any]:
r"""This function returns the value of attributes listed in
``attr_names`` for entries in listed in the ``tids``. It locates
the entries data with ``tid`` and put attributes listed in
``attr_name`` in a dict for each entry (tid).

Args:
tids: List of unique ids (tids) of the entry.
attr_names: List of name of the attribute.

Returns:
A list of dict with ``attr_name`` as key for atrributes
of the entries listed in``tids``.

Raises:
KeyError: when ``tid`` or ``attr_name`` is not found.
"""
tids_attrs = []
for tid in tids:
entry, entry_type = self.get_entry(tid)
attrs: dict = {}
for attr_name in attr_names:
try:
attr_id = self._get_type_attribute_dict(entry_type)[
attr_name
][constants.ATTR_INDEX_KEY]
except KeyError as e:
raise KeyError(
f"{entry_type} has no {attr_name} attribute."
) from e
attrs[attr_name] = entry[attr_id]

tids_attrs.append(attrs)

return tids_attrs

def get_attributes_of_type(
self,
type_name: str,
attributes_names: List[str],
include_sub_type: bool = True,
range_span: Optional[Tuple[int, int]] = None,
) -> Iterator[dict]:
r"""This function fetches required attributes of entries from the
data store of type ``type_name``. If `include_sub_type` is set to
True and ``type_name`` is in [Annotation], this function also
fetches entries of subtype of ``type_name``. Otherwise, it only
fetches entries of type ``type_name``.

Args:
type_name: The fully qualified name of the entry.
attributes_names: list of attributes to be fetched for each entry
include_sub_type: A boolean to indicate whether get its subclass.
range_span: A tuple that contains the begin and end indices
of the searching range of entries.

Returns:
An iterator of the attributes of the entry in dict matching the
provided arguments.
"""

entry_class = get_class(type_name)
all_types = set()
if include_sub_type:
for type in self.__elements:
if issubclass(get_class(type), entry_class):
all_types.add(type)
else:
all_types.add(type_name)
all_types = list(all_types)
all_types.sort()

if self._is_annotation(type_name):
if range_span is None:
# yield from self.co_iterator_annotation_like(all_types)
for entry in self.co_iterator_annotation_like(all_types):
attrs: dict = {"tid": entry[0]}
for attr_name in attributes_names:
try:
attr_id = self._get_type_attribute_dict(type_name)[
attr_name
][constants.ATTR_INDEX_KEY]
except KeyError as e:
raise KeyError(
f"{type_name} has no {attr_name} attribute."
) from e
attrs[attr_name] = entry[attr_id]

yield attrs
else:
for entry in self.co_iterator_annotation_like(
all_types, range_span=range_span
):
attrs = {"tid": entry[0]}
for attr_name in attributes_names:
try:
attr_id = self._get_type_attribute_dict(type_name)[
attr_name
][constants.ATTR_INDEX_KEY]
except KeyError as e:
raise KeyError(
f"{type_name} has no {attr_name} attribute."
) from e
attrs[attr_name] = entry[attr_id]

yield attrs # attrs instead of entry
elif issubclass(entry_class, Link):
raise NotImplementedError(
f"{type_name} of Link is not currently supported."
)
elif issubclass(entry_class, Group):
raise NotImplementedError(
f"{type_name} of Group is not currently supported."
)
else:
if type_name not in self.__elements:
raise ValueError(f"type {type_name} does not exist")
# yield from self.iter(type_name)
for entry in self.iter(type_name):
attrs = {"tid": entry[0]}
for attr_name in attributes_names:
try:
attr_id = self._get_type_attribute_dict(type_name)[
attr_name
][constants.ATTR_INDEX_KEY]
except KeyError as e:
raise KeyError(
f"{type_name} has no {attr_name} attribute."
) from e
attrs[attr_name] = entry[attr_id]

yield attrs

def _get_attr(self, tid: int, attr_id: int) -> Any:
r"""This function locates the entry data with ``tid`` and gets the value
of ``attr_id`` of this entry. Called by `get_attribute()`.
@@ -1870,7 +2034,9 @@ def co_iterator_annotation_like(
self.get_datastore_attr_idx(tn, constants.BEGIN_ATTR_NAME),
self.get_datastore_attr_idx(tn, constants.END_ATTR_NAME),
)
except IndexError as e: # all_entries_range[tn][0] will be caught here.
except (
IndexError
) as e: # all_entries_range[tn][0] will be caught here.
raise ValueError(
f"Entry list of type name, {tn} which is"
" one list item of input argument `type_names`,"
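The DataStore-level methods that the DataPack wrappers delegate to can be sketched in a similar way. Again this is illustrative only: it assumes a `DataStore` instance `store` that already holds `ft.onto.base_ontology.Sentence` entries, and the type name, attribute name, and span values are placeholders.

from typing import List
from forte.data.data_store import DataStore

def inspect_sentences(store: DataStore) -> None:
    # Iterate over the requested attributes of every Sentence entry,
    # optionally restricted to a character span via `range_span`.
    tids: List[int] = []
    for attrs in store.get_attributes_of_type(
        "ft.onto.base_ontology.Sentence",
        ["speaker"],
        include_sub_type=True,
        range_span=(0, 200),
    ):
        # Each `attrs` dict looks like {"tid": ..., "speaker": ...}.
        tids.append(attrs["tid"])

    if tids:
        # Single-entry and batched lookups by tid return the same
        # attribute dicts, without the "tid" key.
        print(store.get_attributes_of_tid(tids[0], ["speaker"]))
        print(store.get_attributes_of_tids(tids, ["speaker"]))

Unsupported cases are surfaced explicitly: Link and Group type names raise NotImplementedError, and a type name absent from the store raises ValueError, which the DataPack wrapper converts into an empty iterator.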
2 changes: 1 addition & 1 deletion forte/data/ontology/ontology_code_generator.py
@@ -41,7 +41,7 @@
except ImportError:
# Try backported to PY<39 `importlib_resources`.
import importlib_resources as resources # type: ignore
from importlib_resources.abc import Traversable # type: ignore
from importlib_resources.abc import Traversable

from forte.data.ontology import top, utils
from forte.data.ontology.code_generation_exceptions import (