diff --git a/.github/workflows/type-check-and-lint.yml b/.github/workflows/type-check-lint-integ-tests.yml similarity index 57% rename from .github/workflows/type-check-and-lint.yml rename to .github/workflows/type-check-lint-integ-tests.yml index 5821493..6748ac1 100644 --- a/.github/workflows/type-check-and-lint.yml +++ b/.github/workflows/type-check-lint-integ-tests.yml @@ -1,4 +1,4 @@ -name: Type check and lint +name: Type check, lint, and integration tests on: push @@ -12,7 +12,11 @@ jobs: run: uv python install 3.12 - name: Install dependencies run: uv sync - - name: Run mypy + - name: Run mypy on src run: uv run mypy --strict src + - name: Run mypy on tests + run: uv run mypy --strict tests - name: Run ruff check - run: uv run ruff check \ No newline at end of file + run: uv run ruff check + - name: Run pytest + run: uv run pytest tests \ No newline at end of file diff --git a/src/yt2doc/cli.py b/src/yt2doc/cli.py index 4f20c60..980e487 100644 --- a/src/yt2doc/cli.py +++ b/src/yt2doc/cli.py @@ -87,7 +87,9 @@ def main( ignore_source_chapters: typing.Annotated[ bool, typer.Option( - "--ignore-source-chapters", "--ignore-chapters", help="Ignore original chapters from the source" + "--ignore-source-chapters", + "--ignore-chapters", + help="Ignore original chapters from the source", ), ] = False, ) -> None: diff --git a/src/yt2doc/factories.py b/src/yt2doc/factories.py index c893098..44940d6 100644 --- a/src/yt2doc/factories.py +++ b/src/yt2doc/factories.py @@ -14,6 +14,7 @@ from yt2doc.extraction.extractor import Extractor from yt2doc.formatting.formatter import MarkdownFormatter from yt2doc.formatting.llm_topic_segmenter import LLMTopicSegmenter +from yt2doc.formatting.llm_adapter import LLMAdapter from yt2doc.yt2doc import Yt2Doc @@ -54,10 +55,8 @@ def get_yt2doc( ), mode=instructor.Mode.JSON, ) - - llm_topic_segmenter = LLMTopicSegmenter( - llm_client=llm_client, llm_model=llm_model - ) + llm_adapter = LLMAdapter(llm_client=llm_client, llm_model=llm_model) + 
llm_topic_segmenter = LLMTopicSegmenter(llm_adapter=llm_adapter) formatter = MarkdownFormatter(sat=sat, topic_segmenter=llm_topic_segmenter) else: formatter = MarkdownFormatter(sat=sat) diff --git a/src/yt2doc/formatting/llm_topic_segmenter.py b/src/yt2doc/formatting/llm_topic_segmenter.py index f0dfa0e..9fb4827 100644 --- a/src/yt2doc/formatting/llm_topic_segmenter.py +++ b/src/yt2doc/formatting/llm_topic_segmenter.py @@ -1,9 +1,6 @@ import typing import logging -import instructor - -from pydantic import BaseModel, AfterValidator from tqdm import tqdm from yt2doc.formatting import interfaces @@ -12,37 +9,14 @@ class LLMTopicSegmenter: - def __init__(self, llm_client: instructor.Instructor, llm_model: str) -> None: - self.llm_client = llm_client - self.model = llm_model + def __init__(self, llm_adapter: interfaces.ILLMAdapter) -> None: + self.llm_adapter = llm_adapter def _get_title_for_chapter(self, paragraphs: typing.List[typing.List[str]]) -> str: truncated_paragraphs = [p[:10] for p in paragraphs] - truncated_text = "\n\n".join(["".join(p) for p in truncated_paragraphs]) - title = self.llm_client.chat.completions.create( - model=self.model, - response_model=str, - messages=[ - { - "role": "system", - "content": """ - Please generate a short title for the following text. - - Be VERY SUCCINCT. No more than 6 words. 
- """, - }, - { - "role": "user", - "content": """ - {{ text }} - """, - }, - ], - context={ - "text": truncated_text, - }, + return self.llm_adapter.generate_title_for_paragraphs( + paragraphs=truncated_paragraphs ) - return title def segment( self, paragraphs: typing.List[typing.List[str]] @@ -60,68 +34,17 @@ def segment( grouped_paragraphs_with_overlap, desc="Finding topic change points" ): truncate_sentence_index = 6 - truncated_grouped_paragraph_texts = [ - "".join(paragraph[:truncate_sentence_index]) - for paragraph in grouped_paragraphs + truncated_group_paragraphs = [ + paragraph[:truncate_sentence_index] for paragraph in grouped_paragraphs ] - def validate_paragraph_indexes(v: typing.List[int]) -> typing.List[int]: - n = len(truncated_grouped_paragraph_texts) - unique_values = set(v) - if len(unique_values) != len(v): - raise ValueError("All elements must be unique") - for i in v: - if i <= 0: - raise ValueError( - f"All elements must be greater than 0 and less than {n}. Paragraph index {i} is less than or equal to 0" - ) - if i >= n: - raise ValueError( - f"All elements must be greater than 0 and less than {n}. Paragraph index {i} is greater or equal to {n}" - ) - - return v - - class Result(BaseModel): - paragraph_indexes: typing.Annotated[ - typing.List[int], AfterValidator(validate_paragraph_indexes) - ] - - result = self.llm_client.chat.completions.create( - model=self.model, - response_model=Result, - messages=[ - { - "role": "system", - "content": """ - You are a smart assistant who reads paragraphs of text from an audio transcript and - find the paragraphs that significantly change topic from the previous paragraph. - - Make sure only mark paragraphs that talks about a VERY DIFFERENT topic from the previous one. - - The response should be an array of the index number of such paragraphs, such as `[1, 3, 5]` - - If there is no paragraph that changes topic, then return an empty list. 
- """, - }, - { - "role": "user", - "content": """ - {% for paragraph in paragraphs %} - - {{ paragraph }} - - {% endfor %} - """, - }, - ], - context={ - "paragraphs": truncated_grouped_paragraph_texts, - }, + paragraph_indexes = self.llm_adapter.get_topic_changing_paragraph_indexes( + paragraphs=truncated_group_paragraphs ) - logger.info(f"paragraph indexes from LLM: {result}") + + logger.info(f"paragraph indexes from LLM: {paragraph_indexes}") aligned_indexes = [ - start_index + index for index in sorted(result.paragraph_indexes) + start_index + index for index in sorted(paragraph_indexes) ] topic_changed_indexes += aligned_indexes diff --git a/tests/integ/formatting/test_formatter.py b/tests/integ/formatting/test_formatter.py index d69bd98..12d2069 100644 --- a/tests/integ/formatting/test_formatter.py +++ b/tests/integ/formatting/test_formatter.py @@ -10,6 +10,7 @@ from src.yt2doc.formatting.formatter import MarkdownFormatter from src.yt2doc.formatting.llm_topic_segmenter import LLMTopicSegmenter +from src.yt2doc.formatting.llm_adapter import LLMAdapter from src.yt2doc.extraction.interfaces import ChapteredTranscript, TranscriptChapter from src.yt2doc.transcription.interfaces import Segment @@ -34,9 +35,8 @@ def parse_whisper_line(line: str) -> Segment: @pytest.fixture -def mock_llm_client() -> Instructor: - mock_llm_instance = MagicMock(spec=Instructor) - return mock_llm_instance +def mock_llm_adapter() -> LLMAdapter: + return MagicMock(spec_set=LLMAdapter) @pytest.fixture @@ -131,11 +131,31 @@ def test_format_chaptered_transcript_basic( def test_markdown_formatter_with_segmentation( - mock_transcript_segments: typing.List[Segment], mock_llm_client: Instructor + mock_transcript_segments: typing.List[Segment], mock_llm_adapter: LLMAdapter ) -> None: # Arrange + def mocked_get_topic_change( + paragraphs: typing.List[typing.List[str]], + ) -> typing.List[int]: + if len(paragraphs) == 1: + return [] + return [1] + + 
 mock_llm_adapter.get_topic_changing_paragraph_indexes.side_effect = ( # type: ignore +        mocked_get_topic_change +    ) + +    def mock_generate_title_for_paragraphs( +        paragraphs: typing.List[typing.List[str]], +    ) -> str: +        return "Chapter Title" + +    mock_llm_adapter.generate_title_for_paragraphs.side_effect = ( # type: ignore +        mock_generate_title_for_paragraphs +    ) + sat = SaT("sat-3l") -    segmenter = LLMTopicSegmenter(llm_client=mock_llm_client) +    segmenter = LLMTopicSegmenter(llm_adapter=mock_llm_adapter) formatter = MarkdownFormatter(sat=sat, topic_segmenter=segmenter) segments_dicts = [ @@ -147,26 +167,15 @@ url="https://example.com/video", title="Test Video Title", language="en", -            chaptered_at_source=True, +            chaptered_at_source=False, chapters=[ TranscriptChapter( -                title="Chapter 1", +                title="Untitled Chapter", segments=segments_dicts, ), ], ) -    # Configure mock LLM response -    # mock_llm_client.chat.completions.create.return_value = MagicMock( -    #     choices=[ -    #         MagicMock( -    #             message=MagicMock( -    #                 content='{"chapters": [{"title": "Introduction", "text": "Test paragraph 1."}, {"title": "Discussion", "text": "Test paragraph 2."}]}' -    #             ) -    #         ) -    #     ] -    # ) -    # Act formatted_output = formatter.format_chaptered_transcript( chaptered_transcript=test_transcript @@ -174,8 +183,5 @@ # Assert assert "# Test Video Title" in formatted_output.transcript -    assert "## Introduction" in formatted_output.transcript -    assert "## Discussion" in formatted_output.transcript -    assert "Test paragraph 1." in formatted_output.transcript -    assert "Test paragraph 2." in formatted_output.transcript -    mock_llm_client.chat.completions.create.assert_called_once() +    assert "## Chapter Title" in formatted_output.transcript +    # mock_llm_adapter.generate_title_for_paragraphs.assert_called_once()