Skip to content

Commit

Permalink
Fix MD parser for inconsistency tables (#10488)
Browse files Browse the repository at this point in the history
* Fix MD parser for inconsistency tables

* cf

* cr

* comment

* unit tests

* fix
  • Loading branch information
hatianzhang authored Feb 7, 2024
1 parent c70abf6 commit 371ef2c
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 31 deletions.
53 changes: 30 additions & 23 deletions llama_index/node_parser/relational/base_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def extract_elements(self, text: str, **kwargs: Any) -> List[Element]:

def get_table_elements(self, elements: List[Element]) -> List[Element]:
    """Get table elements.

    Args:
        elements: Parsed elements to filter.

    Returns:
        The elements whose ``type`` is ``"table"`` or ``"table_text"``,
        in their original order.
    """
    # One membership test covers both table kinds; a bare `type == "table"`
    # check would drop tables that were kept as raw text ("table_text").
    return [e for e in elements if e.type in ("table", "table_text")]

def get_text_elements(self, elements: List[Element]) -> List[Element]:
"""Get text elements."""
Expand All @@ -146,7 +146,7 @@ def extract_table_summaries(self, elements: List[Element]) -> None:

table_context_list = []
for idx, element in tqdm(enumerate(elements)):
if element.type != "table":
if element.type not in ("table", "table_text"):
continue
table_context = str(element.element)
if idx > 0 and str(elements[idx - 1].element).lower().strip().startswith(
Expand Down Expand Up @@ -249,8 +249,8 @@ def get_nodes_from_elements(self, elements: List[Element]) -> List[BaseNode]:
nodes = []
cur_text_el_buffer: List[str] = []
for element in elements:
if element.type == "table":
# flush text buffer
if element.type == "table" or element.type == "table_text":
# flush text buffer for table
if len(cur_text_el_buffer) > 0:
cur_text_nodes = self._get_nodes_from_buffer(
cur_text_el_buffer, node_parser
Expand All @@ -259,7 +259,27 @@ def get_nodes_from_elements(self, elements: List[Element]) -> List[BaseNode]:
cur_text_el_buffer = []

table_output = cast(TableOutput, element.table_output)
table_df = cast(pd.DataFrame, element.table)
table_md = ""
if element.type == "table":
table_df = cast(pd.DataFrame, element.table)
# We serialize the table as markdown as it allows better accuracy.
# We do not use the table_df.to_markdown() method as it generates
# a table in a token-hungry format.
table_md = "|"
for col_name, col in table_df.items():
table_md += f"{col_name}|"
table_md += "\n|"
for col_name, col in table_df.items():
table_md += f"---|"
table_md += "\n"
for row in table_df.itertuples():
table_md += "|"
for col in row[1:]:
table_md += f"{col}|"
table_md += "\n"
elif element.type == "table_text":
# if the table is not a perfect table, we still want to keep its original text
table_md = str(element.element)
table_id = element.id + "_table"
table_ref_id = element.id + "_table_ref"

Expand All @@ -284,29 +304,16 @@ def get_nodes_from_elements(self, elements: List[Element]) -> List[BaseNode]:
index_id=table_id,
)

# We serialize the table as markdown as it allows better accuracy.
# We do not use the table_df.to_markdown() method as it generates
# a table in a token-hungry format.
table_md = "|"
for col_name, col in table_df.items():
table_md += f"{col_name}|"
table_md += "\n|"
for col_name, col in table_df.items():
table_md += f"---|"
table_md += "\n"
for row in table_df.itertuples():
table_md += "|"
for col in row[1:]:
table_md += f"{col}|"
table_md += "\n"

table_str = table_summary + "\n" + table_md

text_node = TextNode(
text=table_str,
id_=table_id,
metadata={
# serialize the table as a dictionary string
"table_df": str(table_df.to_dict()),
# for a perfect table, serialize its dataframe as a dictionary string; otherwise keep the table text
"table_df": str(table_df.to_dict())
if element.type == "table"
else table_md,
# add table summary for retrieval purposes
"table_summary": table_summary,
},
Expand Down
26 changes: 20 additions & 6 deletions llama_index/node_parser/relational/markdown_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,27 +144,41 @@ def extract_elements(
for idx, element in enumerate(elements):
if element.type == "table":
should_keep = True
perfect_table = True

# verify that the table (markdown) has the same number of columns in each row
table_lines = element.element.split("\n")
table_columns = [len(line.split("|")) for line in table_lines]
if len(set(table_columns)) > 1:
should_keep = False
# if the table's rows have differing numbers of columns, it's not a perfect table
# we will store the raw text for such tables instead of converting them to a dataframe
perfect_table = False

# verify that the table (markdown) has at least 2 rows
if len(table_lines) < 2:
should_keep = False

# apply the table filter, now only filter empty tables
if should_keep and table_filters is not None:
if should_keep and perfect_table and table_filters is not None:
should_keep = all(tf(element) for tf in table_filters)

# if the element is a table, convert it to a dataframe
if should_keep:
table = md_to_df(element.element)
elements[idx] = Element(
id=f"id_{idx}", type="table", element=element, table=table
)
if perfect_table:
table = md_to_df(element.element)

elements[idx] = Element(
id=f"id_{idx}", type="table", element=element, table=table
)
else:
# for non-perfect tables, we will store the raw text
# and give it a different type to differentiate it from perfect tables
elements[idx] = Element(
id=f"id_{idx}",
type="table_text",
element=element,
# table=table
)
else:
elements[idx] = Element(
id=f"id_{idx}",
Expand Down
7 changes: 5 additions & 2 deletions tests/node_parser/test_markdown_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,13 @@ def test_md_table_extraction_broken_table() -> None:
print(f"Number of nodes: {len(nodes)}")
for i, node in enumerate(nodes, start=0):
print(f"Node {i}: {node}, Type: {type(node)}")
assert len(nodes) == 3
assert len(nodes) == 6
assert isinstance(nodes[0], TextNode)
assert isinstance(nodes[1], IndexNode)
assert isinstance(nodes[2], TextNode)
assert isinstance(nodes[3], TextNode)
assert isinstance(nodes[4], IndexNode)
assert isinstance(nodes[5], TextNode)


def test_complex_md() -> None:
Expand Down Expand Up @@ -2645,4 +2648,4 @@ def test_llama2_bad_md() -> None:
node_parser = MarkdownElementNodeParser(llm=MockLLM())

nodes = node_parser.get_nodes_from_documents([test_data])
assert len(nodes) == 208
assert len(nodes) == 224

0 comments on commit 371ef2c

Please sign in to comment.