diff --git a/llama_index/node_parser/relational/base_element.py b/llama_index/node_parser/relational/base_element.py index 106dd10b6197b..3dcbf1887e159 100644 --- a/llama_index/node_parser/relational/base_element.py +++ b/llama_index/node_parser/relational/base_element.py @@ -126,7 +126,7 @@ def extract_elements(self, text: str, **kwargs: Any) -> List[Element]: def get_table_elements(self, elements: List[Element]) -> List[Element]: """Get table elements.""" - return [e for e in elements if e.type == "table"] + return [e for e in elements if e.type == "table" or e.type == "table_text"] def get_text_elements(self, elements: List[Element]) -> List[Element]: """Get text elements.""" @@ -146,7 +146,7 @@ def extract_table_summaries(self, elements: List[Element]) -> None: table_context_list = [] for idx, element in tqdm(enumerate(elements)): - if element.type != "table": + if element.type not in ("table", "table_text"): continue table_context = str(element.element) if idx > 0 and str(elements[idx - 1].element).lower().strip().startswith( @@ -249,8 +249,8 @@ def get_nodes_from_elements(self, elements: List[Element]) -> List[BaseNode]: nodes = [] cur_text_el_buffer: List[str] = [] for element in elements: - if element.type == "table": - # flush text buffer + if element.type == "table" or element.type == "table_text": + # flush text buffer for table if len(cur_text_el_buffer) > 0: cur_text_nodes = self._get_nodes_from_buffer( cur_text_el_buffer, node_parser @@ -259,7 +259,27 @@ def get_nodes_from_elements(self, elements: List[Element]) -> List[BaseNode]: cur_text_el_buffer = [] table_output = cast(TableOutput, element.table_output) - table_df = cast(pd.DataFrame, element.table) + table_md = "" + if element.type == "table": + table_df = cast(pd.DataFrame, element.table) + # We serialize the table as markdown as it allow better accuracy + # We do not use the table_df.to_markdown() method as it generate + # a table with a token hungry format. + table_md = "|" + for col_name, col in table_df.items(): + table_md += f"{col_name}|" + table_md += "\n|" + for col_name, col in table_df.items(): + table_md += f"---|" + table_md += "\n" + for row in table_df.itertuples(): + table_md += "|" + for col in row[1:]: + table_md += f"{col}|" + table_md += "\n" + elif element.type == "table_text": + # if the table is non-perfect table, we still want to keep the original text of table + table_md = str(element.element) table_id = element.id + "_table" table_ref_id = element.id + "_table_ref" @@ -284,29 +304,16 @@ def get_nodes_from_elements(self, elements: List[Element]) -> List[BaseNode]: index_id=table_id, ) - # We serialize the table as markdown as it allow better accuracy - # We do not use the table_df.to_markdown() method as it generate - # a table with a token hngry format. - table_md = "|" - for col_name, col in table_df.items(): - table_md += f"{col_name}|" - table_md += "\n|" - for col_name, col in table_df.items(): - table_md += f"---|" - table_md += "\n" - for row in table_df.itertuples(): - table_md += "|" - for col in row[1:]: - table_md += f"{col}|" - table_md += "\n" - table_str = table_summary + "\n" + table_md + text_node = TextNode( text=table_str, id_=table_id, metadata={ - # serialize the table as a dictionary string - "table_df": str(table_df.to_dict()), + # serialize the table as a dictionary string for dataframe of perfect table + "table_df": str(table_df.to_dict()) + if element.type == "table" + else table_md, # add table summary for retrieval purposes "table_summary": table_summary, }, diff --git a/llama_index/node_parser/relational/markdown_element.py b/llama_index/node_parser/relational/markdown_element.py index 4e176894471f9..7abea2608744e 100644 --- a/llama_index/node_parser/relational/markdown_element.py +++ b/llama_index/node_parser/relational/markdown_element.py @@ -144,27 +144,41 @@ def extract_elements( for idx, element in enumerate(elements): if element.type == "table": should_keep = True + perfect_table = True # verify that the table (markdown) have the same number of columns on each rows table_lines = element.element.split("\n") table_columns = [len(line.split("|")) for line in table_lines] if len(set(table_columns)) > 1: - should_keep = False + # if the table have different number of columns on each rows, it's not a perfect table + # we will store the raw text for such tables instead of converting them to a dataframe + perfect_table = False # verify that the table (markdown) have at least 2 rows if len(table_lines) < 2: should_keep = False # apply the table filter, now only filter empty tables - if should_keep and table_filters is not None: + if should_keep and perfect_table and table_filters is not None: should_keep = all(tf(element) for tf in table_filters) # if the element is a table, convert it to a dataframe if should_keep: - table = md_to_df(element.element) - elements[idx] = Element( - id=f"id_{idx}", type="table", element=element, table=table - ) + if perfect_table: + table = md_to_df(element.element) + + elements[idx] = Element( + id=f"id_{idx}", type="table", element=element, table=table + ) + else: + # for non-perfect tables, we will store the raw text + # and give it a different type to differentiate it from perfect tables + elements[idx] = Element( + id=f"id_{idx}", + type="table_text", + element=element, + # table=table + ) else: elements[idx] = Element( id=f"id_{idx}", diff --git a/tests/node_parser/test_markdown_element.py b/tests/node_parser/test_markdown_element.py index 96b593109ae3b..597968e5a9e82 100644 --- a/tests/node_parser/test_markdown_element.py +++ b/tests/node_parser/test_markdown_element.py @@ -76,10 +76,13 @@ def test_md_table_extraction_broken_table() -> None: print(f"Number of nodes: {len(nodes)}") for i, node in enumerate(nodes, start=0): print(f"Node {i}: {node}, Type: {type(node)}") - assert len(nodes) == 3 + assert len(nodes) == 6 assert isinstance(nodes[0], TextNode) assert isinstance(nodes[1], IndexNode) assert isinstance(nodes[2], TextNode) + assert isinstance(nodes[3], TextNode) + assert isinstance(nodes[4], IndexNode) + assert isinstance(nodes[5], TextNode) def test_complex_md() -> None: @@ -2645,4 +2648,4 @@ def test_llama2_bad_md() -> None: node_parser = MarkdownElementNodeParser(llm=MockLLM()) nodes = node_parser.get_nodes_from_documents([test_data]) - assert len(nodes) == 208 + assert len(nodes) == 224