Skip to content

Commit

Permalink
Refactor grobid sections (#281)
Browse files Browse the repository at this point in the history
* some kind of progress, still need to address overlap in sentences crossing paragraphs

* ok cool this seems to be working!

* make heading spans part of section

* make sentences have unique ids, give paragraphs and sections ids

* fix 'coords' error

* pad_x for sentences

* IT WORKS we get nice spans for sentences for this one specific sha now

* remove spanless results (useless)

* lil rename

* mmda version bump

* just return list

* oops delete my thoughts

* oops fix my error made when switching to just list being returned

* new fix_overlaps param
  • Loading branch information
geli-gel authored Nov 9, 2023
1 parent 750dfe6 commit dd65039
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 71 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = 'mmda'
version = '0.9.15'
version = '0.9.16'
description = 'MMDA - multimodal document analysis'
authors = [
{name = 'Allen Institute for Artificial Intelligence', email = '[email protected]'},
Expand Down
140 changes: 90 additions & 50 deletions src/mmda/parsers/grobid_augment_existing_document_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
"""
from grobid_client.grobid_client import GrobidClient
from typing import List, Optional
from typing import List, Optional, Tuple, Dict

import logging
import os
import xml.etree.ElementTree as et

from mmda.parsers.parser import Parser
from mmda.types import Metadata
from mmda.types.annotation import BoxGroup, Box, SpanGroup
from mmda.types.annotation import BoxGroup, Box, SpanGroup, Span
from mmda.types.document import Document
from mmda.types.names import PagesField, RowsField, TokensField
from mmda.utils.tools import box_groups_to_span_groups
Expand Down Expand Up @@ -104,32 +104,69 @@ def _parse_xml_onto_doc(self, xml: str, doc: Document) -> Document:
# sentences within the body text, also tagged by paragraphs.
# We use these to annotate the document in order to provide a hierarchical structure:
# e.g. doc.sections.header, doc.sections[0].paragraphs[0].sentences[0]
section_box_groups, heading_box_groups, paragraph_box_groups, sentence_box_groups = \
self._get_structured_body_text_box_groups(xml_root)
doc.annotate(
sections=box_groups_to_span_groups(
section_box_groups, doc, center=True
)
)
doc.annotate(
headings=box_groups_to_span_groups(
heading_box_groups, doc, center=True
)
)
doc.annotate(
paragraphs=box_groups_to_span_groups(
paragraph_box_groups, doc, center=True
)
)
doc.annotate(
sentences=box_groups_to_span_groups(
sentence_box_groups, doc, center=True
)
)
section_headings_and_sentence_box_groups_in_paragraphs = \
self._get_structured_sentence_box_groups(xml_root)

heading_span_groups = []
paragraph_span_groups = []
section_span_groups = []
sentence_span_groups = []

unallocated_section_tokens_dict: Dict[int, SpanGroup] = dict()

for heading_box_group, paragraphs in section_headings_and_sentence_box_groups_in_paragraphs:
section_spans = []
if heading_box_group:
heading_span_group_in_list = (
box_groups_to_span_groups(
[heading_box_group],
doc,
center=True,
unallocated_tokens_dict=unallocated_section_tokens_dict,
fix_overlaps=True,
)
)
heading_span_group = heading_span_group_in_list[0]
heading_span_groups.append(heading_span_group)
section_spans.extend(heading_span_group.spans)
this_section_paragraph_span_groups = []
for sentence_box_groups in paragraphs:
this_paragraph_sentence_span_groups = box_groups_to_span_groups(
sentence_box_groups,
doc,
center=True,
pad_x=True,
unallocated_tokens_dict=unallocated_section_tokens_dict,
fix_overlaps=True,
)
if all([sg.spans for sg in this_paragraph_sentence_span_groups]):
sentence_span_groups.extend(this_paragraph_sentence_span_groups)
paragraph_spans = []
for sg in this_paragraph_sentence_span_groups:
paragraph_spans.extend(sg.spans)
# TODO add boxes to paragraph spangroups
this_section_paragraph_span_groups.append(SpanGroup(spans=paragraph_spans))
paragraph_span_groups.extend(this_section_paragraph_span_groups)
for sg in this_section_paragraph_span_groups:
section_spans.extend(sg.spans)
# TODO add boxes to section spangroups
section_span_groups.append(SpanGroup(spans=section_spans))

# ensure unique IDs within annotations
all_section_span_groups = [heading_span_groups, sentence_span_groups, paragraph_span_groups, section_span_groups]
for span_groups in all_section_span_groups:
for i, span_group in enumerate(span_groups):
span_group.id = i

doc.annotate(headings=heading_span_groups)
doc.annotate(sentences=sentence_span_groups)
doc.annotate(paragraphs=paragraph_span_groups)
doc.annotate(sections=section_span_groups)


return doc

def _xml_coords_to_boxes(self, coords_attribute: str):
def _xml_coords_to_boxes(self, coords_attribute: str) -> List[Box]:
coords_list = coords_attribute.split(";")
boxes = []
for coords in coords_list:
Expand Down Expand Up @@ -176,7 +213,11 @@ def _get_box_groups(
elements = item_list_root.findall(f".//tei:{item_tag}", NS)

for e in elements:
coords_string = e.attrib["coords"]
try:
coords_string = e.attrib["coords"]
except KeyError:
logging.warning(f"Element with '{item_tag}' tag missing 'coords' attribute")
continue
boxes = self._xml_coords_to_boxes(coords_string)

grobid_id = e.attrib[ID_ATTR_KEY] if ID_ATTR_KEY in e.keys() else None
Expand Down Expand Up @@ -208,7 +249,11 @@ def _get_heading_box_group(
box_group = None
heading_element = section_div.find(f".//tei:head", NS)
if heading_element is not None: # elements evaluate as False if no children
coords_string = heading_element.attrib["coords"]
try:
coords_string = heading_element.attrib["coords"]
except KeyError:
logging.warning(f"Heading element missing 'coords' attribute")
return None
boxes = self._xml_coords_to_boxes(coords_string)
number = heading_element.attrib["n"] if "n" in heading_element.keys() else None
section_title = heading_element.text
Expand All @@ -218,34 +263,29 @@ def _get_heading_box_group(
)
return box_group

def _get_structured_body_text_box_groups(
def _get_structured_sentence_box_groups(
self,
root: et.Element
) -> (List[BoxGroup], List[BoxGroup], List[BoxGroup], List[BoxGroup]):
) -> List[Tuple[Optional[BoxGroup], List[List[BoxGroup]]]]:
section_list_root = root.find(f".//tei:body", NS)

body_sections: List[BoxGroup] = []
body_headings: List[BoxGroup] = []
body_paragraphs: List[BoxGroup] = []
body_sentences: List[BoxGroup] = []

section_divs = section_list_root.findall(f"./tei:div", NS)

section_structures = []
for div in section_divs:
section_boxes: List[Box] = []
heading_box_group = self._get_heading_box_group(div)
if heading_box_group:
body_headings.append(heading_box_group)
section_boxes.extend(heading_box_group.boxes)
paragraphs: List[List[BoxGroup]] = []
for p in div.findall(f"./tei:p", NS):
paragraph_boxes: List[Box] = []
paragraph_sentences: List[BoxGroup] = []
sentence_box_groups: List[BoxGroup] = []
for s in p.findall(f"./tei:s", NS):
sentence_boxes = self._xml_coords_to_boxes(s.attrib["coords"])
paragraph_sentences.append(BoxGroup(boxes=sentence_boxes))
paragraph_boxes.extend(sentence_boxes)
body_paragraphs.append(BoxGroup(boxes=paragraph_boxes))
section_boxes.extend(paragraph_boxes)
body_sentences.extend(paragraph_sentences)
body_sections.append(BoxGroup(boxes=section_boxes))

return body_sections, body_headings, body_paragraphs, body_sentences
try:
coords_string = s.attrib["coords"]
except KeyError:
logging.warning(f"Sentence element missing 'coords' attribute")
continue
sentence_boxes = self._xml_coords_to_boxes(coords_string)
sentence_box_groups.append(BoxGroup(boxes=sentence_boxes))
paragraphs.append(sentence_box_groups)

section_structures.append([heading_box_group, paragraphs])

return section_structures
85 changes: 65 additions & 20 deletions src/mmda/utils/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections import defaultdict
from itertools import groupby
import itertools
from typing import List, Dict, Tuple
from typing import List, Dict, Tuple, Optional, Union

import numpy as np

Expand Down Expand Up @@ -41,20 +41,38 @@ def allocate_overlapping_tokens_for_box(


def box_groups_to_span_groups(
box_groups: List[BoxGroup], doc: Document, pad_x: bool = False, center: bool = False
box_groups: List[BoxGroup],
doc,
pad_x: bool = False,
center: bool = False,
unallocated_tokens_dict: Optional[Dict[int, SpanGroup]] = None,
fix_overlaps: bool = False,
) -> List[SpanGroup]:
"""Generate SpanGroups from BoxGroups.
"""Generate SpanGroups from BoxGroups given they can only generate spans of tokens not already allocated
Args
`box_groups` (List[BoxGroup])
`doc` (Document) base document annotated with pages, tokens, rows
`center` (bool) if True, considers tokens to be overlapping with boxes only if their centers overlap
`center` (bool) if True, considers tokens to be overlapping with boxes only if their centers overlap
`unallocated_tokens_dict` (Optional[Dict]) of token spangroups keyed by page. If provided, will use as starting
point for determining if token is already allocated. Assumes the tokens within are of the same type as the
`doc` (i.e., tokens from both doc and the dict both have their box data in either Span.box or
SpanGroup.boxgroup)
`fix_overlaps` (bool) if True, will attempt to fix overlapping spans within a SpanGroup by omitting
spans from already allocated tokens that end up contained in the derived_spans that come from MergeSpans.
This allows for the possibility of a BoxGroup that covers text to end up with a SpanGroup that is missing
spans or even has no spans since a previous BoxGroup already allocated all the underlying tokens. This
reduces the possibility of SpanGroup overlap errors, but may not return the desired SpanGroups.
Returns
List[SpanGroup] with each SpanGroup.spans corresponding to spans (sans boxes) of allocated tokens per box_group,
List[SpanGroup] with each SpanGroup.spans corresponding to spans (sans boxes) of allocated tokens per box_group,
and each SpanGroup.box_group containing original box_groups. Only the list is returned; if
`unallocated_tokens_dict` was provided, it is updated in place so that it holds the tokens still
unallocated on each page after this call.
"""
assert all([isinstance(group, BoxGroup) for group in box_groups])

all_page_tokens = dict()
unallocated_tokens = unallocated_tokens_dict if unallocated_tokens_dict is not None else dict()
avg_token_widths = dict()
derived_span_groups = []
token_box_in_box_group = None
Expand All @@ -66,8 +84,8 @@ def box_groups_to_span_groups(
for box in box_group.boxes:

# Caching the page tokens to avoid duplicated search
if box.page not in all_page_tokens:
cur_page_tokens = all_page_tokens[box.page] = doc.pages[
if box.page not in unallocated_tokens:
cur_page_tokens = unallocated_tokens[box.page] = doc.pages[
box.page
].tokens
if token_box_in_box_group is None:
Expand All @@ -89,7 +107,7 @@ def box_groups_to_span_groups(
avg_token_widths[box.page] = np.average([t.spans[0].box.w for t in cur_page_tokens])

else:
cur_page_tokens = all_page_tokens[box.page]
cur_page_tokens = unallocated_tokens[box.page]

# Find all the tokens within the box
tokens_in_box, remaining_tokens = allocate_overlapping_tokens_for_box(
Expand All @@ -101,7 +119,7 @@ def box_groups_to_span_groups(
y=0.0,
center=center
)
all_page_tokens[box.page] = remaining_tokens
unallocated_tokens[box.page] = remaining_tokens
all_tokens_overlapping_box_group.extend(tokens_in_box)

merge_spans = (
Expand All @@ -123,15 +141,47 @@ def box_groups_to_span_groups(
# tokens overlapping with derived spans:
sg_tokens = doc.find_overlapping(SpanGroup(spans=derived_spans), "tokens")

def omit_span_from_derived_spans(t_span):
    """Cut an already-allocated token span out of ``derived_spans``.

    merge_spans finds the minimum number of spans, so a merged span can end
    up covering a token that was already allocated to a different box group.
    Removing that token's range here avoids SpanGroup overlap errors.

    ``derived_spans`` (a list in the enclosing scope) is updated in place,
    and the affected span objects are mutated rather than copied.

    Args:
        t_span: span of the unusable token; only ``start`` and ``end`` are read.
    """
    # Collect the surviving spans in a fresh list instead of inserting into /
    # removing from the list being iterated, which would skip the element
    # that follows a removed span.
    kept = []
    for d_span in derived_spans:
        if d_span.start == t_span.start and t_span.end < d_span.end:
            # unusable token_span is at start of derived_span
            d_span.start = t_span.end
            kept.append(d_span)
        elif d_span.end == t_span.end and d_span.start < t_span.start < d_span.end:
            # unusable token_span is at end of derived_span: the span must
            # stop where the token begins (assigning t_span.end is a no-op)
            d_span.end = t_span.start
            kept.append(d_span)
        elif d_span.start < t_span.start < d_span.end and t_span.end < d_span.end:
            # unusable token_span is encompassed by derived_span: split it.
            # Remember the original end before shrinking d_span, otherwise
            # the right-hand remainder would be built from the shrunken end.
            original_end = d_span.end
            d_span.end = t_span.start
            kept.append(d_span)
            kept.append(Span(t_span.end, original_end))
        elif d_span.start == t_span.start and d_span.end == t_span.end:
            # unusable token_span is equal to derived_span: drop it entirely
            continue
        else:
            # token does not fall inside this span; leave it untouched
            kept.append(d_span)
    # Slice-assign so the enclosing scope's list object itself is updated.
    derived_spans[:] = kept

# remove any additional tokens added to the spangroup via MergeSpans from the list of available page tokens
# (this can happen if the MergeSpans algorithm merges tokens that are not adjacent, e.g. if `center` is True and
# a token is not found to be overlapping with the box, but MergeSpans decides it is close enough to be merged)
for sg_token in sg_tokens:
if sg_token not in all_tokens_overlapping_box_group:
if token_box_in_box_group and sg_token in all_page_tokens[sg_token.box_group.boxes[0].page]:
all_page_tokens[sg_token.box_group.boxes[0].page].remove(sg_token)
elif not token_box_in_box_group and sg_token in all_page_tokens[sg_token.spans[0].box.page]:
all_page_tokens[sg_token.spans[0].box.page].remove(sg_token)
# if token not removed from unallocated_tokens yet, do it now
if token_box_in_box_group:
if sg_token in unallocated_tokens[sg_token.box_group.boxes[0].page]:
unallocated_tokens[sg_token.box_group.boxes[0].page].remove(sg_token)
# otherwise, if it is in neither all_tokens_overlapping_box_group nor unallocated_tokens,
# the assumption is that the token has already been allocated by a different box_group, so, we need
# to remove it from our derived spans to avoid 'SpanGroup overlap' error.
else:
if fix_overlaps:
omit_span_from_derived_spans(sg_token.spans[0])
else:
if sg_token in unallocated_tokens[sg_token.spans[0].box.page]:
unallocated_tokens[sg_token.spans[0].box.page].remove(sg_token)
# same scenario as above.
else:
if fix_overlaps:
omit_span_from_derived_spans(sg_token.spans[0])

derived_span_groups.append(
SpanGroup(
Expand All @@ -148,21 +198,16 @@ def box_groups_to_span_groups(
"future Spans wont contain box). Ensure Document is annotated with tokens "
"having box stored in SpanGroup box_group.boxes")

del all_page_tokens

derived_span_groups = sorted(
derived_span_groups, key=lambda span_group: span_group.start
)
# ensure they are ordered based on span indices

for box_id, span_group in enumerate(derived_span_groups):
span_group.id = box_id

# return self._annotate_span_group(
# span_groups=derived_span_groups, field_name=field_name
# )
return derived_span_groups


class MergeSpans:
"""
Given w=width and h=height merge neighboring spans which are w, h or less apart or by merging neighboring spans
Expand Down

0 comments on commit dd65039

Please sign in to comment.