Skip to content

Commit

Permalink
Add COBOL parser and splitter (langchain-ai#11674)
Browse files Browse the repository at this point in the history
- **Description:** Add COBOL parser and splitter
  - **Issue:** n/a
  - **Dependencies:** n/a
  - **Tag maintainer:** @baskaryan 
  - **Twitter handle:** erhartford

---------

Co-authored-by: Bagatur <[email protected]>
Co-authored-by: Eugene Yurtsev <[email protected]>
  • Loading branch information
3 people authored Oct 23, 2023
1 parent bb137fd commit 8c150ad
Show file tree
Hide file tree
Showing 5 changed files with 216 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import re
from typing import Callable, List

from langchain.document_loaders.parsers.language.code_segmenter import CodeSegmenter


class CobolSegmenter(CodeSegmenter):
    """Code segmenter for `COBOL`.

    The segmenter works line by line on the source text. Lines are classified
    as division headers, section headers or paragraph headers using the
    class-level patterns; everything else is treated as body code.
    """

    # A paragraph header: a single name made of letters, digits and hyphens that
    # ends with a period. It is matched against the first space-delimited token
    # of a stripped line.
    PARAGRAPH_PATTERN = re.compile(r"^[A-Z0-9\-]+(\s+.*)?\.$", re.IGNORECASE)
    # A division header at the start of a line (leading whitespace allowed).
    DIVISION_PATTERN = re.compile(
        r"^\s*(IDENTIFICATION|DATA|PROCEDURE|ENVIRONMENT)\s+DIVISION.*$", re.IGNORECASE
    )
    # A section header such as "WORKING-STORAGE SECTION.". The trailing period is
    # escaped so that only a literal "." terminates the header.
    SECTION_PATTERN = re.compile(r"^\s*[A-Z0-9\-]+\s+SECTION\.$", re.IGNORECASE)

    # Division names kept as headers by ``simplify_code``; compared against the
    # upper-cased line so detection is case-insensitive like the patterns above.
    _DIVISION_HEADERS = (
        "IDENTIFICATION DIVISION",
        "ENVIRONMENT DIVISION",
        "DATA DIVISION",
        "PROCEDURE DIVISION",
    )

    def __init__(self, code: str):
        """Store the source and pre-split it into lines.

        Args:
            code: The COBOL source text to segment.
        """
        super().__init__(code)
        self.source_lines: List[str] = self.code.splitlines()

    def is_valid(self) -> bool:
        """Return True if at least one line is a division header."""
        # Identify presence of any division to validate COBOL code
        return any(self.DIVISION_PATTERN.match(line) for line in self.source_lines)

    def _extract_code(self, start_idx: int, end_idx: int) -> str:
        """Return source lines ``[start_idx, end_idx)`` joined by newlines.

        Trailing newline characters are removed from the result.
        """
        return "\n".join(self.source_lines[start_idx:end_idx]).rstrip("\n")

    def _is_relevant_code(self, line: str) -> bool:
        """Check if a line is part of the procedure division or a relevant section."""
        if "PROCEDURE DIVISION" in line.upper():
            return True
        # Add additional conditions for relevant sections if needed
        return False

    def _is_paragraph_or_section(self, line: str) -> bool:
        """Return True if ``line`` is a paragraph header or a section header."""
        stripped = line.strip()
        # Only the first space-delimited token is tested as a paragraph name, so
        # statements such as "STOP RUN." are not mistaken for paragraphs.
        first_token = stripped.split(" ")[0]
        return bool(
            self.PARAGRAPH_PATTERN.match(first_token)
            or self.SECTION_PATTERN.match(stripped)
        )

    def _is_division_header(self, line: str) -> bool:
        """Return True if ``line`` mentions one of the four division headers."""
        upper_line = line.upper()
        return any(header in upper_line for header in self._DIVISION_HEADERS)

    def _process_lines(
        self, func: Callable[[List[str], int, int], None]
    ) -> List[str]:
        """A generic function to process COBOL lines based on provided func.

        Args:
            func: Callback invoked as ``func(elements, start_idx, end_idx)`` for
                every paragraph/section found after the procedure division
                begins; ``end_idx`` is exclusive.

        Returns:
            The ``elements`` list after all callbacks have run.
        """
        elements: List[str] = []
        start_idx = None
        inside_relevant_section = False

        for i, line in enumerate(self.source_lines):
            if self._is_relevant_code(line):
                inside_relevant_section = True

            if inside_relevant_section and self._is_paragraph_or_section(line):
                # A new header closes the element that was open, if any.
                if start_idx is not None:
                    func(elements, start_idx, i)
                start_idx = i

        # Handle the last element if exists
        if start_idx is not None:
            func(elements, start_idx, len(self.source_lines))

        return elements

    def extract_functions_classes(self) -> List[str]:
        """Return the text of each paragraph/section of the procedure division."""

        def extract_func(elements: List[str], start_idx: int, end_idx: int) -> None:
            elements.append(self._extract_code(start_idx, end_idx))

        return self._process_lines(extract_func)

    def simplify_code(self) -> str:
        """Return the source reduced to its headers.

        Division, section and paragraph headers are kept verbatim. Each run of
        other lines that follows a header is collapsed into a single
        ``* OMITTED CODE *`` marker. Lines before the first header are dropped.
        """
        simplified_lines: List[str] = []
        inside_relevant_section = False
        # Tracks whether the marker was already emitted after the last header.
        omitted_code_added = False

        for line in self.source_lines:
            if self._is_division_header(line) or self._is_paragraph_or_section(line):
                # Entering a new division/section/paragraph: keep the header and
                # allow one marker for the body that follows it.
                inside_relevant_section = True
                omitted_code_added = False
                simplified_lines.append(line)
            elif inside_relevant_section and not omitted_code_added:
                simplified_lines.append("* OMITTED CODE *")
                omitted_code_added = True

        return "\n".join(simplified_lines)
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,21 @@
from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseBlobParser
from langchain.document_loaders.blob_loaders import Blob
from langchain.document_loaders.parsers.language.cobol import CobolSegmenter
from langchain.document_loaders.parsers.language.javascript import JavaScriptSegmenter
from langchain.document_loaders.parsers.language.python import PythonSegmenter
from langchain.text_splitter import Language

# Lookup from a short string key to the matching ``Language`` member.
# NOTE(review): presumably the key is a file extension taken from the blob's
# source path — confirm against the parser code that reads this table.
LANGUAGE_EXTENSIONS: Dict[str, str] = {
"py": Language.PYTHON,
"js": Language.JS,
"cobol": Language.COBOL,
}

# Lookup from a ``Language`` member to the segmenter class registered for it.
LANGUAGE_SEGMENTERS: Dict[str, Any] = {
Language.PYTHON: PythonSegmenter,
Language.JS: JavaScriptSegmenter,
Language.COBOL: CobolSegmenter,
}


Expand Down
33 changes: 33 additions & 0 deletions libs/langchain/langchain/text_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,7 @@ class Language(str, Enum):
HTML = "html"
SOL = "sol"
CSHARP = "csharp"
COBOL = "cobol"


class RecursiveCharacterTextSplitter(TextSplitter):
Expand Down Expand Up @@ -1305,6 +1306,38 @@ def get_separators_for_language(language: Language) -> List[str]:
" ",
"",
]
elif language == Language.COBOL:
return [
# Split along divisions
"\nIDENTIFICATION DIVISION.",
"\nENVIRONMENT DIVISION.",
"\nDATA DIVISION.",
"\nPROCEDURE DIVISION.",
# Split along sections within DATA DIVISION
"\nWORKING-STORAGE SECTION.",
"\nLINKAGE SECTION.",
"\nFILE SECTION.",
# Split along sections within ENVIRONMENT DIVISION
"\nINPUT-OUTPUT SECTION.",
# Split along paragraphs and common statements
"\nOPEN ",
"\nCLOSE ",
"\nREAD ",
"\nWRITE ",
"\nIF ",
"\nELSE ",
"\nMOVE ",
"\nPERFORM ",
"\nUNTIL ",
"\nVARYING ",
"\nACCEPT ",
"\nDISPLAY ",
"\nSTOP RUN.",
# Split by the normal type of lines
"\n",
" ",
"",
]

else:
raise ValueError(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from langchain.document_loaders.parsers.language.cobol import CobolSegmenter

# Small COBOL program shared by the tests below: identification and data
# divisions followed by a procedure division with two paragraphs.
EXAMPLE_CODE = """
IDENTIFICATION DIVISION.
PROGRAM-ID. SampleProgram.
DATA DIVISION.
WORKING-STORAGE SECTION.
01 SAMPLE-VAR PIC X(20) VALUE 'Sample Value'.
PROCEDURE DIVISION.
A000-INITIALIZE-PARA.
DISPLAY 'Initialization Paragraph'.
MOVE 'New Value' TO SAMPLE-VAR.
A100-PROCESS-PARA.
DISPLAY SAMPLE-VAR.
STOP RUN.
"""


def test_extract_functions_classes() -> None:
    """Each paragraph of the procedure division is returned as one element."""
    expected_paragraphs = [
        "A000-INITIALIZE-PARA.\n "
        "DISPLAY 'Initialization Paragraph'.\n "
        "MOVE 'New Value' TO SAMPLE-VAR.",
        "A100-PROCESS-PARA.\n DISPLAY SAMPLE-VAR.\n STOP RUN.",
    ]
    paragraphs = CobolSegmenter(EXAMPLE_CODE).extract_functions_classes()
    assert paragraphs == expected_paragraphs


def test_simplify_code() -> None:
    """Headers are kept and each body collapses to one omitted-code marker."""
    actual = CobolSegmenter(EXAMPLE_CODE).simplify_code()
    expected = (
        "IDENTIFICATION DIVISION.\n"
        "PROGRAM-ID. SampleProgram.\n"
        "DATA DIVISION.\n"
        "WORKING-STORAGE SECTION.\n"
        "* OMITTED CODE *\n"
        "PROCEDURE DIVISION.\n"
        "A000-INITIALIZE-PARA.\n"
        "* OMITTED CODE *\n"
        "A100-PROCESS-PARA.\n"
        "* OMITTED CODE *\n"
    )
    # Surrounding whitespace is not significant for this comparison.
    assert actual.strip() == expected.strip()
35 changes: 35 additions & 0 deletions libs/langchain/tests/unit_tests/test_text_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,41 @@ def test_javascript_code_splitter() -> None:
]


def test_cobol_code_splitter() -> None:
    """Split a small COBOL program with the COBOL separators and no overlap."""
    cobol_source = """
IDENTIFICATION DIVISION.
PROGRAM-ID. HelloWorld.
DATA DIVISION.
WORKING-STORAGE SECTION.
01 GREETING PIC X(12) VALUE 'Hello, World!'.
PROCEDURE DIVISION.
DISPLAY GREETING.
STOP RUN.
"""
    expected_chunks = [
        "IDENTIFICATION",
        "DIVISION.",
        "PROGRAM-ID.",
        "HelloWorld.",
        "DATA DIVISION.",
        "WORKING-STORAGE",
        "SECTION.",
        "01 GREETING",
        "PIC X(12)",
        "VALUE 'Hello,",
        "World!'.",
        "PROCEDURE",
        "DIVISION.",
        "DISPLAY",
        "GREETING.",
        "STOP RUN.",
    ]
    cobol_splitter = RecursiveCharacterTextSplitter.from_language(
        Language.COBOL, chunk_size=CHUNK_SIZE, chunk_overlap=0
    )
    assert cobol_splitter.split_text(cobol_source) == expected_chunks


def test_typescript_code_splitter() -> None:
splitter = RecursiveCharacterTextSplitter.from_language(
Language.TS, chunk_size=CHUNK_SIZE, chunk_overlap=0
Expand Down

0 comments on commit 8c150ad

Please sign in to comment.