diff --git a/gptme/codeblock.py b/gptme/codeblock.py index 711b23d8..7b815d76 100644 --- a/gptme/codeblock.py +++ b/gptme/codeblock.py @@ -66,22 +66,41 @@ def _extract_codeblocks(markdown: str) -> Generator[Codeblock, None, None]: # TODO: fix to actually be correct start_idx = sum(len(line) + 1 for line in lines[:idx]) stripped_line = line.strip() - if stripped_line.startswith("```"): - if not stack: # Start of a new block - stack.append(stripped_line[3:]) - current_lang = stripped_line[3:] - elif stripped_line[3:] and stack[-1] != stripped_line[3:]: # Nested start - current_block.append(line) - stack.append(stripped_line[3:]) - else: # End of a block - if len(stack) == 1: # Outermost block - yield Codeblock( - current_lang, "\n".join(current_block), start=start_idx - ) - current_block = [] - current_lang = "" - else: # Nested end + + # Handle multiple code block markers on the same line + pos = 0 + while pos < len(stripped_line): + if stripped_line[pos:].startswith("```"): + pos += 3 + # Find the language or end marker + lang_end = pos + while lang_end < len(stripped_line) and not stripped_line[ + lang_end: + ].startswith("```"): + lang_end += 1 + lang = stripped_line[pos:lang_end].strip() + + if not stack: # Start of a new block + stack.append(lang) + current_lang = lang + pos = lang_end + elif lang and stack[-1] != lang: # Nested start current_block.append(line) - stack.pop() - elif stack: + stack.append(lang) + pos = lang_end + else: # End of a block + if len(stack) == 1: # Outermost block + yield Codeblock( + current_lang, "\n".join(current_block), start=start_idx + ) + current_block = [] + current_lang = "" + else: # Nested end + current_block.append(line) + stack.pop() + pos = lang_end + 3 # Skip past the closing ``` + else: + pos += 1 + + if not stripped_line.startswith("```") and stack: current_block.append(line) diff --git a/tests/test_codeblock.py b/tests/test_codeblock.py index 2705eb2b..9027060d 100644 --- a/tests/test_codeblock.py +++ b/tests/test_codeblock.py @@ -68,6 +68,64 @@ def print_readme(): assert Codeblock.iter_from_markdown(markdown) == [] +def test_extract_codeblocks_nested_oneline(): + markdown = """ +```python +def print_readme(): + print(''' + ```echo test``` + ''') +``` +""" + blocks = list(Codeblock.iter_from_markdown(markdown)) + assert len(blocks) == 1 + assert blocks[0].lang == "python" + assert "```echo test```" in blocks[0].content + + +def test_extract_codeblocks_complete_nested(): + markdown = """ +```python +def print_readme(): + print('''Usage: +```javascript +console.log('hello') +``` + ''') +``` +""" + blocks = list(Codeblock.iter_from_markdown(markdown)) + assert len(blocks) == 1 + assert blocks[0].lang == "python" + assert "```javascript" in blocks[0].content + + +def test_extract_codeblocks_multiple_nested(): + markdown = """ +```python +def example(): + code = ''' +```javascript +console.log('hello') +``` + ''' + doc = ''' +```html +
test
+``` + ''' +``` +""" + blocks = list(Codeblock.iter_from_markdown(markdown)) + assert len(blocks) == 1 + assert blocks[0].lang == "python" + assert "```javascript" in blocks[0].content + assert "```html" in blocks[0].content + # check that entire content is extracted + assert blocks[0].content.count("```") == 4 + assert blocks[0].content.count("'''") == 4 + + def test_extract_codeblocks_empty(): assert Codeblock.iter_from_markdown("") == []