Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
shun-liang committed Nov 2, 2024
1 parent f1e5a90 commit 6dae8d2
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 16 deletions.
15 changes: 14 additions & 1 deletion src/yt2doc/formatting/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@
from pydantic import BaseModel

from yt2doc.extraction import interfaces as extraction_interfaces
from yt2doc.transcription import interfaces as transcription_interfaces


class Sentence(BaseModel):
start_second: float
end_second: float
text: str


class Chapter(BaseModel):
Expand All @@ -20,10 +27,16 @@ class FormattedPlaylist(BaseModel):
transcripts: typing.Sequence[FormattedTranscript]


class IParagraphsSegmenter(typing.Protocol):
def segment(
self, transcription_segments: typing.Sequence[transcription_interfaces.Segment]
) -> typing.Sequence[typing.Sequence[Sentence]]: ...


class ILLMAdapter(typing.Protocol):
def get_topic_changing_paragraph_indexes(
self, paragraphs: typing.List[typing.List[str]]
) -> typing.List[int]: ...
) -> typing.Sequence[int]: ...

def generate_title_for_paragraphs(
self, paragraphs: typing.List[typing.List[str]]
Expand Down
2 changes: 1 addition & 1 deletion src/yt2doc/formatting/llm_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def __init__(self, llm_client: Instructor, llm_model: str) -> None:

def get_topic_changing_paragraph_indexes(
self, paragraphs: typing.List[typing.List[str]]
) -> typing.List[int]:
) -> typing.Sequence[int]:
def validate_paragraph_indexes(v: typing.List[int]) -> typing.List[int]:
n = len(paragraphs)
unique_values = set(v)
Expand Down
23 changes: 23 additions & 0 deletions src/yt2doc/formatting/paragraphs_segmenter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import typing

from wtpsplit import SaT

from src.yt2doc.formatting import interfaces
from yt2doc.transcription import interfaces as transcription_interfaces


class ParagraphsSegmenter:
def __init__(self, sat: SaT) -> None:
self.sat = sat

def segment(
self, transcription_segments: typing.Sequence[transcription_interfaces.Segment]
) -> typing.Sequence[typing.Sequence[interfaces.Sentence]]:
transcription_segment_texts = [s.text for s in transcription_segments]
full_text = "".join(transcription_segment_texts)
paragraphed_texts: typing.List[typing.List[str]] = self.sat.split(
full_text, do_paragraph_segmentation=True, verbose=True
)

# Paragraph-Transcription Segment timestamp alignment

4 changes: 3 additions & 1 deletion src/yt2doc/transcription/faster_whisper_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def transcribe(
audio=audio_path, initial_prompt=initial_prompt, vad_filter=True
)
return (
interfaces.Segment(start=segment.start, end=segment.end, text=segment.text)
interfaces.Segment(
start_second=segment.start, end_second=segment.end, text=segment.text
)
for segment in segments
)
4 changes: 2 additions & 2 deletions src/yt2doc/transcription/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@


class Segment(BaseModel):
start: float
end: float
start_second: float
end_second: float
text: str


Expand Down
8 changes: 4 additions & 4 deletions src/yt2doc/transcription/transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,15 +161,15 @@ def transcribe(
)
for segment in segments:
aligned_segment = interfaces.Segment(
start=chapter.start_time + segment.start,
end=chapter.start_time + segment.end,
start_second=chapter.start_time + segment.start_second,
end_second=chapter.start_time + segment.end_second,
text=self._fix_comma(
segment_text=segment.text, language_code=language_code
),
)
chapter_segments.append(aligned_segment)
progress_bar.update(aligned_segment.end - current_timestamp)
current_timestamp = aligned_segment.end
progress_bar.update(aligned_segment.end_second - current_timestamp)
current_timestamp = aligned_segment.end_second

chaptered_transcriptions.append(
interfaces.ChapterTranscription(
Expand Down
4 changes: 2 additions & 2 deletions src/yt2doc/transcription/whisper_cpp_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def _parse_whisper_line(self, line: str) -> interfaces.Segment:
if match:
start_time, end_time, text = match.groups()
return interfaces.Segment(
start=self._time_to_seconds(start_time),
end=self._time_to_seconds(end_time),
start_second=self._time_to_seconds(start_time),
end_second=self._time_to_seconds(end_time),
text=text,
)
raise CannotParseWhisperCppLineException(
Expand Down
9 changes: 4 additions & 5 deletions tests/integ/formatting/test_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ def parse_whisper_line(line: str) -> Segment:
if match:
start_time, end_time, text = match.groups()
return Segment(
start=time_to_seconds(start_time),
end=time_to_seconds(end_time),
start_second=time_to_seconds(start_time),
end_second=time_to_seconds(end_time),
text=text,
)
assert False, "Line does not match expected whisper segment pattern."
Expand Down Expand Up @@ -96,7 +96,7 @@ def test_format_chaptered_transcript_basic(
formatter = MarkdownFormatter(sat=sat)

segments_dicts = [
{"start": segment.start, "end": segment.end, "text": segment.text}
{"start": segment.start_second, "end": segment.end_second, "text": segment.text}
for segment in mock_transcript_segments
]

Expand Down Expand Up @@ -158,7 +158,7 @@ def mock_generate_title_for_paragraphs(
formatter = MarkdownFormatter(sat=sat, topic_segmenter=segmenter)

segments_dicts = [
{"start": segment.start, "end": segment.end, "text": segment.text}
{"start": segment.start_second, "end": segment.end_second, "text": segment.text}
for segment in mock_transcript_segments
]

Expand All @@ -183,4 +183,3 @@ def mock_generate_title_for_paragraphs(
# Assert
assert "# Test Video Title" in formatted_output.transcript
assert "## Chapter Title" in formatted_output.transcript
# mock_llm_client.chat.completions.create.assert_called_once()

0 comments on commit 6dae8d2

Please sign in to comment.