Skip to content

Commit

Permalink
Use ParagraphsSegmenter in MarkdownFormater
Browse files Browse the repository at this point in the history
  • Loading branch information
shun-liang committed Nov 4, 2024
1 parent 959f57b commit 7ad847d
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 108 deletions.
9 changes: 7 additions & 2 deletions src/yt2doc/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from yt2doc.formatting.formatter import MarkdownFormatter
from yt2doc.formatting.llm_topic_segmenter import LLMTopicSegmenter
from yt2doc.formatting.llm_adapter import LLMAdapter
from yt2doc.formatting.paragraphs_segmenter import ParagraphsSegmenter
from yt2doc.yt2doc import Yt2Doc


Expand Down Expand Up @@ -43,6 +44,7 @@ def get_yt2doc(
)

sat = SaT(sat_model)
paragraphs_segmenter = ParagraphsSegmenter(sat=sat)
if segment_unchaptered is True:
if llm_model is None:
raise LLMModelNotSpecified(
Expand All @@ -57,9 +59,12 @@ def get_yt2doc(
)
llm_adapter = LLMAdapter(llm_client=llm_client, llm_model=llm_model)
llm_topic_segmenter = LLMTopicSegmenter(llm_adapter=llm_adapter)
formatter = MarkdownFormatter(sat=sat, topic_segmenter=llm_topic_segmenter)
formatter = MarkdownFormatter(
paragraphs_segmenter=paragraphs_segmenter,
topic_segmenter=llm_topic_segmenter,
)
else:
formatter = MarkdownFormatter(sat=sat)
formatter = MarkdownFormatter(paragraphs_segmenter=paragraphs_segmenter)

video_info_extractor = MediaInfoExtractor(temp_dir=temp_dir)
transcriber = Transcriber(
Expand Down
44 changes: 19 additions & 25 deletions src/yt2doc/formatting/formatter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import typing
import logging

from wtpsplit import SaT

from yt2doc.extraction import interfaces as extraction_interfaces
from yt2doc.formatting import interfaces

Expand All @@ -12,24 +10,23 @@
class MarkdownFormatter:
def __init__(
self,
sat: SaT,
paragraphs_segmenter: interfaces.IParagraphsSegmenter,
topic_segmenter: typing.Optional[interfaces.ITopicSegmenter] = None,
) -> None:
self.sat = sat
self.paragraphs_segmenter = paragraphs_segmenter
self.topic_segmenter = topic_segmenter
self.video_title_template = "# {name}"
self.chapter_title_template = "## {name}"

def _paragraph_text(self, text: str) -> str:
if len(text) < 15:
return text
logger.info("Splitting text into paragraphs with Segment Any Text.")
paragraphed_sentences: typing.List[typing.List[str]] = self.sat.split(
text, do_paragraph_segmentation=True, verbose=True
)
paragraphs = ["".join(sentences) for sentences in paragraphed_sentences]
paragraphed_text = "\n\n".join(paragraphs)
return paragraphed_text
@staticmethod
def _paragraphs_to_text(
paragraphs: typing.Sequence[typing.Sequence[interfaces.Sentence]],
) -> str:
paragraph_texts = []
for paragraph in paragraphs:
paragraph_text = "".join(sentence.text for sentence in paragraph)
paragraph_texts.append(paragraph_text)
return "\n\n".join(paragraph_texts)

def format_chaptered_transcript(
self, chaptered_transcript: extraction_interfaces.ChapteredTranscript
Expand All @@ -42,24 +39,21 @@ def format_chaptered_transcript(
and len(chaptered_transcript.chapters) == 1
):
transcript_segments = chaptered_transcript.chapters[0].segments
full_text = "".join([segment.text for segment in transcript_segments])
logger.info(
"Splitting text into paragraphs with Segment Any Text for topic segmentation."
paragraphed_sentences = self.paragraphs_segmenter.segment(
transcription_segments=transcript_segments
)
paragraphed_sentences: typing.List[typing.List[str]] = self.sat.split(
full_text, do_paragraph_segmentation=True, verbose=True
chapters = self.topic_segmenter.segment(
sentences_in_paragraphs=paragraphed_sentences
)
chapters = self.topic_segmenter.segment(paragraphs=paragraphed_sentences)
chapter_and_text_list = [
(chapter.title, chapter.text) for chapter in chapters
(chapter.title, self._paragraphs_to_text(chapter.paragraphs))
for chapter in chapters
]

else:
for chapter in chaptered_transcript.chapters:
chapter_text = self._paragraph_text(
"".join(s.text for s in chapter.segments)
)
chapter_and_text_list.append((chapter.title, chapter_text.strip()))
chapter_full_text = "".join(s.text for s in chapter.segments)
chapter_and_text_list.append((chapter.title, chapter_full_text.strip()))

transcript_text = "\n\n".join(
[
Expand Down
6 changes: 3 additions & 3 deletions src/yt2doc/formatting/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class Sentence(BaseModel):

class Chapter(BaseModel):
title: str
text: str
paragraphs: typing.Sequence[typing.Sequence[Sentence]]


class FormattedTranscript(BaseModel):
Expand All @@ -29,7 +29,7 @@ class FormattedPlaylist(BaseModel):
class IParagraphsSegmenter(typing.Protocol):
def segment(
self, transcription_segments: typing.Sequence[transcription_interfaces.Segment]
) -> typing.Sequence[typing.Sequence[Sentence]]: ...
) -> typing.List[typing.List[Sentence]]: ...


class ILLMAdapter(typing.Protocol):
Expand All @@ -44,7 +44,7 @@ def generate_title_for_paragraphs(

class ITopicSegmenter(typing.Protocol):
def segment(
self, paragraphs: typing.List[typing.List[str]]
self, sentences_in_paragraphs: typing.List[typing.List[Sentence]]
) -> typing.Sequence[Chapter]: ...


Expand Down
75 changes: 43 additions & 32 deletions src/yt2doc/formatting/llm_topic_segmenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,14 @@ class LLMTopicSegmenter:
def __init__(self, llm_adapter: interfaces.ILLMAdapter) -> None:
self.llm_adapter = llm_adapter

def _get_title_for_chapter(self, paragraphs: typing.List[typing.List[str]]) -> str:
truncated_paragraphs = [p[:10] for p in paragraphs]
return self.llm_adapter.generate_title_for_paragraphs(
paragraphs=truncated_paragraphs
)

def segment(
self, paragraphs: typing.List[typing.List[str]]
self,
sentences_in_paragraphs: typing.List[typing.List[interfaces.Sentence]],
) -> typing.Sequence[interfaces.Chapter]:
group_size = 8
grouped_paragraphs_with_overlap = [
(i, paragraphs[i : i + group_size])
for i in range(0, len(paragraphs), group_size - 1)
(i, sentences_in_paragraphs[i : i + group_size])
for i in range(0, len(sentences_in_paragraphs), group_size - 1)
]
logger.info(
f"grouped_paragraphs_with_overlap: {grouped_paragraphs_with_overlap}"
Expand All @@ -37,9 +32,13 @@ def segment(
truncated_group_paragraphs = [
paragraph[:truncate_sentence_index] for paragraph in grouped_paragraphs
]
truncated_group_paragraphs_texts = [
[sentence.text for sentence in paragraph]
for paragraph in truncated_group_paragraphs
]

paragraph_indexes = self.llm_adapter.get_topic_changing_paragraph_indexes(
paragraphs=truncated_group_paragraphs
paragraphs=truncated_group_paragraphs_texts,
)

logger.info(f"paragraph indexes from LLM: {paragraph_indexes}")
Expand All @@ -49,34 +48,46 @@ def segment(
topic_changed_indexes += aligned_indexes

if len(topic_changed_indexes) == 0:
paragraph_texts = ["".join(paragraph) for paragraph in paragraphs]
text = "\n\n".join(paragraph_texts)
truncated_paragraphs_in_chapter = [p[:10] for p in sentences_in_paragraphs]
truncated_paragraphs_texts = [
[sentence.text for sentence in paragraph]
for paragraph in truncated_paragraphs_in_chapter
]
title = self.llm_adapter.generate_title_for_paragraphs(
paragraphs=truncated_paragraphs_texts
)
return [
interfaces.Chapter(
title=self._get_title_for_chapter(paragraphs=paragraphs),
text=text,
title=title,
paragraphs=sentences_in_paragraphs,
)
]

chapter_paragraphs: typing.List[typing.List[typing.List[str]]] = []
current_chapter_paragraphs: typing.List[typing.List[str]] = []
for index, paragraph in enumerate(paragraphs):
paragraphs_in_chapters: typing.List[
typing.List[typing.List[interfaces.Sentence]]
] = []
current_chapter_paragraphs: typing.List[typing.List[interfaces.Sentence]] = []
for index, sentences_in_paragraph in enumerate(sentences_in_paragraphs):
if index in topic_changed_indexes:
chapter_paragraphs.append(current_chapter_paragraphs)
paragraphs_in_chapters.append(current_chapter_paragraphs)
current_chapter_paragraphs = []
current_chapter_paragraphs.append(paragraph)
chapter_paragraphs.append(current_chapter_paragraphs)
current_chapter_paragraphs.append(sentences_in_paragraph)
paragraphs_in_chapters.append(current_chapter_paragraphs)

chapters: typing.List[interfaces.Chapter] = []
for paragraphs_in_chapter in tqdm(
paragraphs_in_chapters, desc="Generating titles for chapters"
):
truncated_paragraphs_in_chapter = [p[:10] for p in paragraphs_in_chapter]
truncated_paragraphs_texts = [
[sentence.text for sentence in paragraph]
for paragraph in truncated_paragraphs_in_chapter
]
title = self.llm_adapter.generate_title_for_paragraphs(
paragraphs=truncated_paragraphs_texts
)
chapters.append(
interfaces.Chapter(title=title, paragraphs=paragraphs_in_chapter)
)

chapter_titles_and_texts: typing.List[typing.Tuple[str, str]] = []
for chapter in tqdm(chapter_paragraphs, desc="Generating titles for chapters"):
paragraphs_: typing.List[str] = []
for paragraph in chapter:
paragraph_text = "".join(paragraph)
paragraphs_.append(paragraph_text)
title = self._get_title_for_chapter(paragraphs=chapter)
chapter_titles_and_texts.append((title, "\n\n".join(paragraphs_)))
chapters = [
interfaces.Chapter(title=title, text=text)
for title, text in chapter_titles_and_texts
]
return chapters
46 changes: 30 additions & 16 deletions src/yt2doc/formatting/paragraphs_segmenter.py
Original file line number Diff line number Diff line change
@@ -1,72 +1,86 @@
import typing
import logging

from wtpsplit import SaT

from yt2doc.formatting import interfaces
from yt2doc.transcription import interfaces as transcription_interfaces

logger = logging.getLogger(__file__)


class ParagraphsSegmenter:
def __init__(self, sat: SaT) -> None:
self.sat = sat

def segment(
self, transcription_segments: typing.Sequence[transcription_interfaces.Segment]
) -> typing.Sequence[typing.Sequence[interfaces.Sentence]]:
) -> typing.List[typing.List[interfaces.Sentence]]:
# Get sentences from SaT
full_text = "".join(s.text for s in transcription_segments)
paragraphed_texts = self.sat.split(full_text, do_paragraph_segmentation=True, verbose=True)
logger.info("Splitting text into paragraphs with Segment Any Text.")
paragraphed_texts = self.sat.split(
full_text, do_paragraph_segmentation=True, verbose=True
)

# Align timestamps
segments_text = "".join(s.text for s in transcription_segments)
segments_pos = 0 # Position in segments text
curr_segment_idx = 0 # Current segment index
curr_segment_offset = 0 # Position within current segment

result_paragraphs = []

for paragraph in paragraphed_texts:
result_sentences = []

for sentence in paragraph:
# Find matching position for this sentence
sentence_pos = 0 # Position in current sentence

# Find start position
start_segment_idx = curr_segment_idx

# Match characters exactly including spaces
while sentence_pos < len(sentence):
if segments_pos >= len(segments_text):
break

# Match characters exactly
if sentence[sentence_pos] == segments_text[segments_pos]:
sentence_pos += 1
segments_pos += 1
curr_segment_offset += 1
# Update segment index if needed
while (curr_segment_idx < len(transcription_segments) - 1 and
curr_segment_offset >= len(transcription_segments[curr_segment_idx].text)):
while curr_segment_idx < len(
transcription_segments
) - 1 and curr_segment_offset >= len(
transcription_segments[curr_segment_idx].text
):
curr_segment_offset = 0
curr_segment_idx += 1
else:
# If no match, move forward in segments
segments_pos += 1
curr_segment_offset += 1
while (curr_segment_idx < len(transcription_segments) - 1 and
curr_segment_offset >= len(transcription_segments[curr_segment_idx].text)):
while curr_segment_idx < len(
transcription_segments
) - 1 and curr_segment_offset >= len(
transcription_segments[curr_segment_idx].text
):
curr_segment_offset = 0
curr_segment_idx += 1

# Create sentence with aligned timestamp
result_sentences.append(
interfaces.Sentence(
text=sentence,
start_second=transcription_segments[start_segment_idx].start_second
start_second=transcription_segments[
start_segment_idx
].start_second,
)
)

result_paragraphs.append(result_sentences)

return result_paragraphs
21 changes: 17 additions & 4 deletions tests/integ/formatting/test_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from src.yt2doc.formatting.formatter import MarkdownFormatter
from src.yt2doc.formatting.llm_topic_segmenter import LLMTopicSegmenter
from src.yt2doc.formatting.llm_adapter import LLMAdapter
from src.yt2doc.formatting.paragraphs_segmenter import ParagraphsSegmenter
from src.yt2doc.extraction.interfaces import ChapteredTranscript, TranscriptChapter
from src.yt2doc.transcription.interfaces import Segment

Expand Down Expand Up @@ -93,10 +94,15 @@ def test_format_chaptered_transcript_basic(
) -> None:
# Arrange
sat = SaT("sat-3l")
formatter = MarkdownFormatter(sat=sat)
paragraphs_segmenter = ParagraphsSegmenter(sat=sat)
formatter = MarkdownFormatter(paragraphs_segmenter=paragraphs_segmenter)

segments_dicts = [
{"start_second": segment.start_second, "end_second": segment.end_second, "text": segment.text}
{
"start_second": segment.start_second,
"end_second": segment.end_second,
"text": segment.text,
}
for segment in mock_transcript_segments
]

Expand Down Expand Up @@ -154,11 +160,18 @@ def mock_generate_title_for_paragraphs(
)

sat = SaT("sat-3l")
paragraphs_segmenter = ParagraphsSegmenter(sat=sat)
segmenter = LLMTopicSegmenter(llm_adapter=mock_llm_adapter)
formatter = MarkdownFormatter(sat=sat, topic_segmenter=segmenter)
formatter = MarkdownFormatter(
paragraphs_segmenter=paragraphs_segmenter, topic_segmenter=segmenter
)

segments_dicts = [
{"start_second": segment.start_second, "end_second": segment.end_second, "text": segment.text}
{
"start_second": segment.start_second,
"end_second": segment.end_second,
"text": segment.text,
}
for segment in mock_transcript_segments
]

Expand Down
Loading

0 comments on commit 7ad847d

Please sign in to comment.