Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
shun-liang committed Nov 2, 2024
1 parent 3c7615a commit 991bc15
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 119 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Type check and lint
name: Type check, lint, and integration tests

on: push

Expand All @@ -12,7 +12,11 @@ jobs:
run: uv python install 3.12
- name: Install dependencies
run: uv sync
- name: Run mypy
- name: Run mypy on src
run: uv run mypy --strict src
- name: Run mypy on tests
run: uv run mypy --strict tests
- name: Run ruff check
run: uv run ruff check
run: uv run ruff check
- name: Run pytest
run: uv run pytest tests
4 changes: 3 additions & 1 deletion src/yt2doc/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ def main(
ignore_source_chapters: typing.Annotated[
bool,
typer.Option(
"--ignore-source-chapters", "--ignore-chapters", help="Ignore original chapters from the source"
"--ignore-source-chapters",
"--ignore-chapters",
help="Ignore original chapters from the source",
),
] = False,
) -> None:
Expand Down
7 changes: 3 additions & 4 deletions src/yt2doc/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from yt2doc.extraction.extractor import Extractor
from yt2doc.formatting.formatter import MarkdownFormatter
from yt2doc.formatting.llm_topic_segmenter import LLMTopicSegmenter
from yt2doc.formatting.llm_adapter import LLMAdapter
from yt2doc.yt2doc import Yt2Doc


Expand Down Expand Up @@ -54,10 +55,8 @@ def get_yt2doc(
),
mode=instructor.Mode.JSON,
)

llm_topic_segmenter = LLMTopicSegmenter(
llm_client=llm_client, llm_model=llm_model
)
llm_adapter = LLMAdapter(llm_client=llm_client, llm_model=llm_model)
llm_topic_segmenter = LLMTopicSegmenter(llm_adapter=llm_adapter)
formatter = MarkdownFormatter(sat=sat, topic_segmenter=llm_topic_segmenter)
else:
formatter = MarkdownFormatter(sat=sat)
Expand Down
99 changes: 11 additions & 88 deletions src/yt2doc/formatting/llm_topic_segmenter.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import typing
import logging

import instructor

from pydantic import BaseModel, AfterValidator
from tqdm import tqdm

from yt2doc.formatting import interfaces
Expand All @@ -12,37 +9,14 @@


class LLMTopicSegmenter:
def __init__(self, llm_client: instructor.Instructor, llm_model: str) -> None:
self.llm_client = llm_client
self.model = llm_model
def __init__(self, llm_adapter: interfaces.ILLMAdapter) -> None:
self.llm_adapter = llm_adapter

def _get_title_for_chapter(self, paragraphs: typing.List[typing.List[str]]) -> str:
truncated_paragraphs = [p[:10] for p in paragraphs]
truncated_text = "\n\n".join(["".join(p) for p in truncated_paragraphs])
title = self.llm_client.chat.completions.create(
model=self.model,
response_model=str,
messages=[
{
"role": "system",
"content": """
Please generate a short title for the following text.
Be VERY SUCCINCT. No more than 6 words.
""",
},
{
"role": "user",
"content": """
{{ text }}
""",
},
],
context={
"text": truncated_text,
},
return self.llm_adapter.generate_title_for_paragraphs(
paragraphs=truncated_paragraphs
)
return title

def segment(
self, paragraphs: typing.List[typing.List[str]]
Expand All @@ -60,68 +34,17 @@ def segment(
grouped_paragraphs_with_overlap, desc="Finding topic change points"
):
truncate_sentence_index = 6
truncated_grouped_paragraph_texts = [
"".join(paragraph[:truncate_sentence_index])
for paragraph in grouped_paragraphs
truncated_group_paragraphs = [
paragraph[:truncate_sentence_index] for paragraph in grouped_paragraphs
]

def validate_paragraph_indexes(v: typing.List[int]) -> typing.List[int]:
n = len(truncated_grouped_paragraph_texts)
unique_values = set(v)
if len(unique_values) != len(v):
raise ValueError("All elements must be unique")
for i in v:
if i <= 0:
raise ValueError(
f"All elements must be greater than 0 and less than {n}. Paragraph index {i} is less than or equal to 0"
)
if i >= n:
raise ValueError(
f"All elements must be greater than 0 and less than {n}. Paragraph index {i} is greater or equal to {n}"
)

return v

class Result(BaseModel):
paragraph_indexes: typing.Annotated[
typing.List[int], AfterValidator(validate_paragraph_indexes)
]

result = self.llm_client.chat.completions.create(
model=self.model,
response_model=Result,
messages=[
{
"role": "system",
"content": """
You are a smart assistant who reads paragraphs of text from an audio transcript and
find the paragraphs that significantly change topic from the previous paragraph.
Make sure only mark paragraphs that talks about a VERY DIFFERENT topic from the previous one.
The response should be an array of the index number of such paragraphs, such as `[1, 3, 5]`
If there is no paragraph that changes topic, then return an empty list.
""",
},
{
"role": "user",
"content": """
{% for paragraph in paragraphs %}
<paragraph {{ loop.index0 }}>
{{ paragraph }}
</ paragraph {{ loop.index0 }}>
{% endfor %}
""",
},
],
context={
"paragraphs": truncated_grouped_paragraph_texts,
},
paragraph_indexes = self.llm_adapter.get_topic_changing_paragraph_indexes(
paragraphs=truncated_group_paragraphs
)
logger.info(f"paragraph indexes from LLM: {result}")

logger.info(f"paragraph indexes from LLM: {paragraph_indexes}")
aligned_indexes = [
start_index + index for index in sorted(result.paragraph_indexes)
start_index + index for index in sorted(paragraph_indexes)
]
topic_changed_indexes += aligned_indexes

Expand Down
52 changes: 29 additions & 23 deletions tests/integ/formatting/test_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from src.yt2doc.formatting.formatter import MarkdownFormatter
from src.yt2doc.formatting.llm_topic_segmenter import LLMTopicSegmenter
from src.yt2doc.formatting.llm_adapter import LLMAdapter
from src.yt2doc.extraction.interfaces import ChapteredTranscript, TranscriptChapter
from src.yt2doc.transcription.interfaces import Segment

Expand All @@ -34,9 +35,8 @@ def parse_whisper_line(line: str) -> Segment:


@pytest.fixture
def mock_llm_client() -> Instructor:
mock_llm_instance = MagicMock(spec=Instructor)
return mock_llm_instance
def mock_llm_adapter() -> LLMAdapter:
return MagicMock(spec_set=LLMAdapter)


@pytest.fixture
Expand Down Expand Up @@ -131,11 +131,31 @@ def test_format_chaptered_transcript_basic(


def test_markdown_formatter_with_segmentation(
mock_transcript_segments: typing.List[Segment], mock_llm_client: Instructor
mock_transcript_segments: typing.List[Segment], mock_llm_adapter: LLMAdapter
) -> None:
# Arrange
def mocked_get_topic_change(
paragraphs: typing.List[typing.List[str]],
) -> typing.List[int]:
if len(paragraphs) == 1:
return []
return [1]

mock_llm_adapter.get_topic_changing_paragraph_indexes.side_effect = ( # type: ignore
mocked_get_topic_change
)

def mock_generate_title_for_paragraphs(
paragraphs: typing.List[typing.List[str]],
) -> str:
return f"Chapter Title"

mock_llm_adapter.generate_title_for_paragraphs.side_effect = ( # type: ignore
mock_generate_title_for_paragraphs
)

sat = SaT("sat-3l")
segmenter = LLMTopicSegmenter(llm_client=mock_llm_client)
segmenter = LLMTopicSegmenter(llm_adapter=mock_llm_adapter)
formatter = MarkdownFormatter(sat=sat, topic_segmenter=segmenter)

segments_dicts = [
Expand All @@ -147,35 +167,21 @@ def test_markdown_formatter_with_segmentation(
url="https://example.com/video",
title="Test Video Title",
language="en",
chaptered_at_source=True,
chaptered_at_source=False,
chapters=[
TranscriptChapter(
title="Chapter 1",
title="Untitled Chapter",
segments=segments_dicts,
),
],
)

# Configure mock LLM response
# mock_llm_client.chat.completions.create.return_value = MagicMock(
# choices=[
# MagicMock(
# message=MagicMock(
# content='{"chapters": [{"title": "Introduction", "text": "Test paragraph 1."}, {"title": "Discussion", "text": "Test paragraph 2."}]}'
# )
# )
# ]
# )

# Act
formatted_output = formatter.format_chaptered_transcript(
chaptered_transcript=test_transcript
)

# Assert
assert "# Test Video Title" in formatted_output.transcript
assert "## Introduction" in formatted_output.transcript
assert "## Discussion" in formatted_output.transcript
assert "Test paragraph 1." in formatted_output.transcript
assert "Test paragraph 2." in formatted_output.transcript
mock_llm_client.chat.completions.create.assert_called_once()
assert "## Chapter Title" in formatted_output.transcript
# mock_llm_client.chat.completions.create.assert_called_once()

0 comments on commit 991bc15

Please sign in to comment.