Integration tests on formatter #49

Merged: 10 commits, Nov 2, 2024
GitHub Actions CI workflow
@@ -1,4 +1,4 @@
name: Type check and lint
name: Type check, lint, and integration tests

on: push

@@ -12,7 +12,11 @@ jobs:
run: uv python install 3.12
- name: Install dependencies
run: uv sync
- name: Run mypy
- name: Run mypy on src
run: uv run mypy --strict src
- name: Run mypy on tests
run: uv run mypy --strict tests
- name: Run ruff check
run: uv run ruff check
- name: Run pytest
run: uv run python -m pytest tests
1 change: 1 addition & 0 deletions pyproject.toml
@@ -30,6 +30,7 @@ source = "vcs" # Use version control (e.g., git) for versioning
dev-dependencies = [
"ipython>=8.27.0",
"mypy>=1.11.2",
"pytest>=8.3.3",
"ruff>=0.6.3",
"types-tqdm>=4.66.0.20240417",
]
4 changes: 3 additions & 1 deletion src/yt2doc/cli.py
@@ -87,7 +87,9 @@ def main(
ignore_source_chapters: typing.Annotated[
bool,
typer.Option(
"--ignore-source-chapters", "--ignore-chapters", help="Ignore original chapters from the source"
"--ignore-source-chapters",
"--ignore-chapters",
help="Ignore original chapters from the source",
),
] = False,
) -> None:
5 changes: 3 additions & 2 deletions src/yt2doc/factories.py
@@ -14,6 +14,7 @@
from yt2doc.extraction.extractor import Extractor
from yt2doc.formatting.formatter import MarkdownFormatter
from yt2doc.formatting.llm_topic_segmenter import LLMTopicSegmenter
from yt2doc.formatting.llm_adapter import LLMAdapter
from yt2doc.yt2doc import Yt2Doc


@@ -54,8 +55,8 @@ def get_yt2doc(
),
mode=instructor.Mode.JSON,
)

llm_topic_segmenter = LLMTopicSegmenter(llm_client=llm_client, model=llm_model)
llm_adapter = LLMAdapter(llm_client=llm_client, llm_model=llm_model)
llm_topic_segmenter = LLMTopicSegmenter(llm_adapter=llm_adapter)
formatter = MarkdownFormatter(sat=sat, topic_segmenter=llm_topic_segmenter)
else:
formatter = MarkdownFormatter(sat=sat)
10 changes: 10 additions & 0 deletions src/yt2doc/formatting/interfaces.py
@@ -20,6 +20,16 @@ class FormattedPlaylist(BaseModel):
transcripts: typing.Sequence[FormattedTranscript]


class ILLMAdapter(typing.Protocol):
def get_topic_changing_paragraph_indexes(
self, paragraphs: typing.List[typing.List[str]]
) -> typing.List[int]: ...

def generate_title_for_paragraphs(
self, paragraphs: typing.List[typing.List[str]]
) -> str: ...


class ITopicSegmenter(typing.Protocol):
def segment(
self, paragraphs: typing.List[typing.List[str]]
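Because ILLMAdapter is a typing.Protocol, any object exposing these two method signatures satisfies it structurally, which is what lets the formatter and segmenter be exercised in tests without a live LLM. A minimal sketch of such a test double (illustrative only, not part of this PR; the class name is made up):

import typing


class FakeLLMAdapter:
    # Structurally satisfies ILLMAdapter; no inheritance required.
    def __init__(self, topic_indexes: typing.List[int], title: str) -> None:
        self._topic_indexes = topic_indexes
        self._title = title

    def get_topic_changing_paragraph_indexes(
        self, paragraphs: typing.List[typing.List[str]]
    ) -> typing.List[int]:
        # Return canned indexes instead of calling an LLM, keeping only
        # values that are valid for the given paragraphs.
        return [i for i in self._topic_indexes if 0 < i < len(paragraphs)]

    def generate_title_for_paragraphs(
        self, paragraphs: typing.List[typing.List[str]]
    ) -> str:
        return self._title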
100 changes: 100 additions & 0 deletions src/yt2doc/formatting/llm_adapter.py
@@ -0,0 +1,100 @@
import typing

from instructor import Instructor
from pydantic import BaseModel, AfterValidator


class LLMAdapter:
def __init__(self, llm_client: Instructor, llm_model: str) -> None:
self.llm_client = llm_client
self.llm_model = llm_model

def get_topic_changing_paragraph_indexes(
self, paragraphs: typing.List[typing.List[str]]
) -> typing.List[int]:
def validate_paragraph_indexes(v: typing.List[int]) -> typing.List[int]:
n = len(paragraphs)
unique_values = set(v)
if len(unique_values) != len(v):
raise ValueError("All elements must be unique")
for i in v:
if i <= 0:
raise ValueError(
f"All elements must be greater than 0 and less than {n}. Paragraph index {i} is less than or equal to 0"
)
if i >= n:
raise ValueError(
f"All elements must be greater than 0 and less than {n}. Paragraph index {i} is greater or equal to {n}"
)

return v

paragraph_texts = ["\n\n".join(p) for p in paragraphs]

class Result(BaseModel):
paragraph_indexes: typing.Annotated[
typing.List[int], AfterValidator(validate_paragraph_indexes)
]

result = self.llm_client.chat.completions.create(
model=self.llm_model,
response_model=Result,
messages=[
{
"role": "system",
"content": """
You are a smart assistant who reads paragraphs of text from an audio transcript and
find the paragraphs that significantly change topic from the previous paragraph.

Make sure only mark paragraphs that talks about a VERY DIFFERENT topic from the previous one.

The response should be an array of the index number of such paragraphs, such as `[1, 3, 5]`

If there is no paragraph that changes topic, then return an empty list.
""",
},
{
"role": "user",
"content": """
{% for paragraph in paragraphs %}
<paragraph {{ loop.index0 }}>
{{ paragraph }}
</ paragraph {{ loop.index0 }}>
{% endfor %}
""",
},
],
context={
"paragraphs": paragraph_texts,
},
)
return result.paragraph_indexes

def generate_title_for_paragraphs(
self, paragraphs: typing.List[typing.List[str]]
) -> str:
text = "\n\n".join(["".join(p) for p in paragraphs])
title = self.llm_client.chat.completions.create(
model=self.llm_model,
response_model=str,
messages=[
{
"role": "system",
"content": """
Please generate a short title for the following text.

Be VERY SUCCINCT. No more than 6 words.
""",
},
{
"role": "user",
"content": """
{{ text }}
""",
},
],
context={
"text": text,
},
)
return title
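For reference, the adapter is constructed in factories.py from an Instructor-wrapped client (see the factories diff above). A hedged usage sketch follows; the OpenAI-compatible endpoint, API key, and model name are assumptions for illustration, not values from this PR:

import instructor
from openai import OpenAI

from yt2doc.formatting.llm_adapter import LLMAdapter

# Assumed wiring: any OpenAI-compatible endpoint works with instructor's JSON mode;
# the real construction lives in yt2doc.factories.
llm_client = instructor.from_openai(
    OpenAI(base_url="http://localhost:11434/v1", api_key="ollama"),
    mode=instructor.Mode.JSON,
)
adapter = LLMAdapter(llm_client=llm_client, llm_model="llama3.1")

paragraphs = [
    ["The first paragraph covers project history. ", "It started years ago. "],
    ["Now for something entirely different: pricing. ", "Plans start small. "],
]
indexes = adapter.get_topic_changing_paragraph_indexes(paragraphs=paragraphs)
title = adapter.generate_title_for_paragraphs(paragraphs=paragraphs)
print(indexes, title)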
99 changes: 11 additions & 88 deletions src/yt2doc/formatting/llm_topic_segmenter.py
@@ -1,9 +1,6 @@
import typing
import logging

import instructor

from pydantic import BaseModel, AfterValidator
from tqdm import tqdm

from yt2doc.formatting import interfaces
@@ -12,37 +9,14 @@


class LLMTopicSegmenter:
def __init__(self, llm_client: instructor.Instructor, model: str) -> None:
self.llm_client = llm_client
self.model = model
def __init__(self, llm_adapter: interfaces.ILLMAdapter) -> None:
self.llm_adapter = llm_adapter

def _get_title_for_chapter(self, paragraphs: typing.List[typing.List[str]]) -> str:
truncated_paragraphs = [p[:10] for p in paragraphs]
truncated_text = "\n\n".join(["".join(p) for p in truncated_paragraphs])
title = self.llm_client.chat.completions.create(
model=self.model,
response_model=str,
messages=[
{
"role": "system",
"content": """
Please generate a short title for the following text.

Be VERY SUCCINCT. No more than 6 words.
""",
},
{
"role": "user",
"content": """
{{ text }}
""",
},
],
context={
"text": truncated_text,
},
return self.llm_adapter.generate_title_for_paragraphs(
paragraphs=truncated_paragraphs
)
return title

def segment(
self, paragraphs: typing.List[typing.List[str]]
@@ -60,68 +34,17 @@ def segment(
grouped_paragraphs_with_overlap, desc="Finding topic change points"
):
truncate_sentence_index = 6
truncated_grouped_paragraph_texts = [
"".join(paragraph[:truncate_sentence_index])
for paragraph in grouped_paragraphs
truncated_group_paragraphs = [
paragraph[:truncate_sentence_index] for paragraph in grouped_paragraphs
]

def validate_paragraph_indexes(v: typing.List[int]) -> typing.List[int]:
n = len(truncated_grouped_paragraph_texts)
unique_values = set(v)
if len(unique_values) != len(v):
raise ValueError("All elements must be unique")
for i in v:
if i <= 0:
raise ValueError(
f"All elements must be greater than 0 and less than {n}. Paragraph index {i} is less than or equal to 0"
)
if i >= n:
raise ValueError(
f"All elements must be greater than 0 and less than {n}. Paragraph index {i} is greater or equal to {n}"
)

return v

class Result(BaseModel):
paragraph_indexes: typing.Annotated[
typing.List[int], AfterValidator(validate_paragraph_indexes)
]

result = self.llm_client.chat.completions.create(
model=self.model,
response_model=Result,
messages=[
{
"role": "system",
"content": """
You are a smart assistant who reads paragraphs of text from an audio transcript and
find the paragraphs that significantly change topic from the previous paragraph.

Make sure only mark paragraphs that talks about a VERY DIFFERENT topic from the previous one.

The response should be an array of the index number of such paragraphs, such as `[1, 3, 5]`

If there is no paragraph that changes topic, then return an empty list.
""",
},
{
"role": "user",
"content": """
{% for paragraph in paragraphs %}
<paragraph {{ loop.index0 }}>
{{ paragraph }}
</ paragraph {{ loop.index0 }}>
{% endfor %}
""",
},
],
context={
"paragraphs": truncated_grouped_paragraph_texts,
},
paragraph_indexes = self.llm_adapter.get_topic_changing_paragraph_indexes(
paragraphs=truncated_group_paragraphs
)
logger.info(f"paragraph indexes from LLM: {result}")

logger.info(f"paragraph indexes from LLM: {paragraph_indexes}")
aligned_indexes = [
start_index + index for index in sorted(result.paragraph_indexes)
start_index + index for index in sorted(paragraph_indexes)
]
topic_changed_indexes += aligned_indexes

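The tests added under tests/ are not shown in this view, but the adapter seam above is what makes them possible: the segmenter can now be driven end to end with a stub in place of an LLM. A hypothetical sketch of that style of test (names and the assertion are illustrative, not taken from this PR):

import typing

from yt2doc.formatting.llm_topic_segmenter import LLMTopicSegmenter


class StubLLMAdapter:
    # Structurally satisfies interfaces.ILLMAdapter.
    def get_topic_changing_paragraph_indexes(
        self, paragraphs: typing.List[typing.List[str]]
    ) -> typing.List[int]:
        return [1]

    def generate_title_for_paragraphs(
        self, paragraphs: typing.List[typing.List[str]]
    ) -> str:
        return "Stub Title"


def test_segment_runs_without_a_real_llm() -> None:
    segmenter = LLMTopicSegmenter(llm_adapter=StubLLMAdapter())
    paragraphs = [
        ["First topic, sentence one. ", "First topic, sentence two. "],
        ["A very different topic begins here. ", "And continues. "],
    ]
    # The exact return shape is defined by interfaces.ITopicSegmenter (truncated
    # above); the point is that no network call is made.
    result = segmenter.segment(paragraphs=paragraphs)
    assert result is not None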