Skip to content

Commit

Permalink
Add table of content feature and refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
shun-liang committed Nov 9, 2024
1 parent f717967 commit 1dd57df
Show file tree
Hide file tree
Showing 9 changed files with 157 additions and 86 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ dependencies = [
"faster-whisper>=1.0.3",
"ffmpeg-python>=0.2.0",
"instructor>=1.5.1",
"jinja2>=3.1.4",
"openai>=1.51.0",
"pathvalidate>=3.2.1",
"pydantic>=2.9.1",
"python-slugify>=8.0.4",
"torch>=2.4.1",
"tqdm>=4.66.5",
"typer-slim>=0.12.5",
Expand Down
6 changes: 6 additions & 0 deletions src/yt2doc/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ def main(
"--timestamp-paragraphs",
help="Prepend timestamp to paragraphs",
),
add_table_of_contents: bool = typer.Option(
False,
"--add-table-of-contents",
help="Add table of contents at the beginning of the document",
),
skip_cache: typing.Annotated[
bool,
typer.Option("--skip-cache", help="If should skip reading from cache"),
Expand Down Expand Up @@ -151,6 +156,7 @@ def main(
segment_unchaptered=segment_unchaptered,
ignore_source_chapters=ignore_source_chapters,
to_timestamp_paragraphs=to_timestamp_paragraphs,
add_table_of_contents=add_table_of_contents,
llm_model=llm_model,
llm_server=llm_server,
llm_api_key=llm_api_key,
Expand Down
3 changes: 3 additions & 0 deletions src/yt2doc/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def get_yt2doc(
segment_unchaptered: bool,
ignore_source_chapters: bool,
to_timestamp_paragraphs: bool,
add_table_of_contents: bool,
llm_model: typing.Optional[str],
llm_server: str,
llm_api_key: str,
Expand Down Expand Up @@ -63,12 +64,14 @@ def get_yt2doc(
formatter = MarkdownFormatter(
paragraphs_segmenter=paragraphs_segmenter,
to_timestamp_paragraphs=to_timestamp_paragraphs,
add_table_of_contents=add_table_of_contents,
topic_segmenter=llm_topic_segmenter,
)
else:
formatter = MarkdownFormatter(
paragraphs_segmenter=paragraphs_segmenter,
to_timestamp_paragraphs=to_timestamp_paragraphs,
add_table_of_contents=add_table_of_contents,
)

media_info_extractor = MediaInfoExtractor(temp_dir=temp_dir)
Expand Down
141 changes: 75 additions & 66 deletions src/yt2doc/formatting/formatter.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,30 @@
import typing
import logging

import jinja2

from datetime import timedelta
from pathlib import Path

from pydantic import BaseModel
from slugify import slugify

from yt2doc.extraction import interfaces as extraction_interfaces
from yt2doc.formatting import interfaces

logger = logging.getLogger(__file__)


class ParagraphToRender(BaseModel):
start_h_m_s: str
text: str


class ChapterToRender(BaseModel):
title: str
custom_id: str
start_second: float
full_text: str
start_h_m_s: str
paragraphs: typing.Sequence[ParagraphToRender]


class MarkdownFormatter:
Expand All @@ -34,51 +43,66 @@ def __init__(
self.add_table_of_contents = add_table_of_contents

@staticmethod
def _paragraphs_to_text(
paragraphs: typing.Sequence[typing.Sequence[interfaces.Sentence]],
video_id: str,
webpage_url_domain: str,
to_timestamp_paragraphs: bool,
def _start_second_to_start_h_m_s(
start_second: float, webpage_url_domain: str, video_id: str
) -> str:
paragraph_texts = []
for paragraph in paragraphs:
first_sentence = paragraph[0]
paragraph_text = "".join(sentence.text for sentence in paragraph)
paragraph_text = paragraph_text.strip()
if to_timestamp_paragraphs:
paragraph_start_second = round(first_sentence.start_second)
paragraph_start_h_m_s = str(timedelta(seconds=paragraph_start_second))
if webpage_url_domain == "youtube.com":
timestamp_prefix = f"[({paragraph_start_h_m_s})](https://youtu.be/{video_id}?t={paragraph_start_second})"
else:
timestamp_prefix = f"({paragraph_start_h_m_s})"
paragraph_text = f"{timestamp_prefix} {paragraph_text}"
paragraph_texts.append(paragraph_text)
return "\n\n".join(paragraph_texts)
rounded_start_second = round(start_second)
start_h_m_s = str(timedelta(rounded_start_second))
if webpage_url_domain == "youtube.com":
return f"[{start_h_m_s}](https://youtu.be/{video_id}?t={start_second})"
return start_h_m_s

@staticmethod
def _get_table_of_contents(
chapters: typing.Sequence[ChapterToRender],
def _render(
self,
title: str,
chapters: typing.Sequence[interfaces.Chapter],
video_url: str,
video_id: str,
webpage_url_domain: str,
) -> str:
chapter_links = []
chapters_to_render: typing.List[ChapterToRender] = []
for chapter in chapters:
chapter_start_second = round(chapter.start_second)
chapter_start_h_m_s = str(timedelta(seconds=chapter_start_second))
if webpage_url_domain == "youtube.com":
timestamp_prefix = f"[({chapter_start_h_m_s})](https://youtu.be/{video_id}?t={chapter_start_second})"
else:
timestamp_prefix = f"({chapter_start_h_m_s})"
chapter_link = f"{timestamp_prefix} {chapter.title}"
chapter_links.append(chapter_link)
table_of_content_text = f"### Table of Contents\n\n{"\n".join(chapter_links)}"
if len(chapter.paragraphs) == 0:
continue

paragraphs_to_render = [
ParagraphToRender(
text="".join([sentence.text for sentence in paragraph]),
start_h_m_s=self._start_second_to_start_h_m_s(
start_second=paragraph[0].start_second,
webpage_url_domain=webpage_url_domain,
video_id=video_id,
),
)
for paragraph in chapter.paragraphs
]
first_paragraph_to_render = paragraphs_to_render[0]
chapters_to_render.append(
ChapterToRender(
title=chapter.title,
custom_id=slugify(chapter.title),
start_h_m_s=first_paragraph_to_render.start_h_m_s,
paragraphs=paragraphs_to_render,
)
)

current_dir = Path(__file__).parent
jinja_environment = jinja2.Environment(
loader=jinja2.FileSystemLoader(current_dir)
)
template = jinja_environment.get_template("template.md")
rendered = template.render(
title=title,
chapters=[chapter.model_dump() for chapter in chapters_to_render],
video_url=video_url,
add_table_of_contents=self.add_table_of_contents,
to_timestamp_paragraphs=self.to_timestamp_paragraphs,
)
return rendered

def format_chaptered_transcript(
self, chaptered_transcript: extraction_interfaces.ChapteredTranscript
) -> interfaces.FormattedTranscript:
chapter_and_text_list: typing.List[typing.Tuple[str, str]] = []

if (
self.topic_segmenter is not None
and not chaptered_transcript.chaptered_at_source
Expand All @@ -91,42 +115,27 @@ def format_chaptered_transcript(
chapters = self.topic_segmenter.segment(
sentences_in_paragraphs=paragraphed_sentences
)
chapter_and_text_list = [
(
chapter.title,
self._paragraphs_to_text(
paragraphs=chapter.paragraphs,
video_id=chaptered_transcript.video_id,
webpage_url_domain=chaptered_transcript.webpage_url_domain,
to_timestamp_paragraphs=self.to_timestamp_paragraphs,
),
)
for chapter in chapters
]

else:
for chapter in chaptered_transcript.chapters:
paragraphed_sentences = self.paragraphs_segmenter.segment(
transcription_segments=chapter.segments
)
chapter_full_text = self._paragraphs_to_text(
paragraphs=paragraphed_sentences,
video_id=chaptered_transcript.video_id,
webpage_url_domain=chaptered_transcript.webpage_url_domain,
to_timestamp_paragraphs=self.to_timestamp_paragraphs,
chapters = [
interfaces.Chapter(
title=chapter.title,
paragraphs=self.paragraphs_segmenter.segment(chapter.segments),
)
chapter_and_text_list.append((chapter.title, chapter_full_text))

transcript_text = "\n\n".join(
[
f"{self.chapter_title_template.format(name=chapter_title)}\n\n{chapter_text}"
for chapter_title, chapter_text in chapter_and_text_list
for chapter in chaptered_transcript.chapters
]

rendered_transcript = self._render(
title=chaptered_transcript.title,
chapters=chapters,
video_url=chaptered_transcript.url,
video_id=chaptered_transcript.video_id,
webpage_url_domain=chaptered_transcript.webpage_url_domain,
)
transcript_text = f"{self.video_title_template.format(name=chaptered_transcript.title)}\n\n{chaptered_transcript.url}\n\n{transcript_text}"

return interfaces.FormattedTranscript(
title=chaptered_transcript.title,
transcript=transcript_text,
rendered_transcript=rendered_transcript,
)

def format_chaptered_playlist_transcripts(
Expand Down
2 changes: 1 addition & 1 deletion src/yt2doc/formatting/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class Chapter(BaseModel):

class FormattedTranscript(BaseModel):
title: str
transcript: str
rendered_transcript: str


class FormattedPlaylist(BaseModel):
Expand Down
21 changes: 21 additions & 0 deletions src/yt2doc/formatting/template.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# {{ title }}

{{ video_url }}

{% if add_table_of_contents %}
### Table of contents
{% for chapter in chapters %}
- ({{ chapter.start_h_m_s }}) {{ chapter.title }}
{% endfor %}

{% endif %}
{% for chapter in chapters %}
## {{ chapter.title }} {{ '{' }}{{ chapter.custom_id }}{{ '}' }}

{% for paragraph in chapter.paragraphs %}
{% if to_timestamp_paragraphs %}({{ paragraph.start_h_m_s }}) {% endif %}{{ paragraph.text }}
{% if not loop.last %}

{% endif %}
{% endfor %}
{% endfor %}
10 changes: 5 additions & 5 deletions src/yt2doc/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def write_video_transcript(
formatted_transcript: formatting_interfaces.FormattedTranscript,
) -> None:
if output_target is None or output_target == "-":
print(formatted_transcript.transcript + "\n")
print(formatted_transcript.rendered_transcript + "\n")
return

output_path = Path(output_target)
Expand All @@ -40,7 +40,7 @@ def write_video_transcript(
file_path = output_path

with open(file_path, "w+") as f:
f.write(formatted_transcript.transcript)
f.write(formatted_transcript.rendered_transcript)

def write_playlist(
self,
Expand All @@ -52,7 +52,7 @@ def write_playlist(
print(
"\n\n".join(
[
transcript.transcript
transcript.rendered_transcript
for transcript in formatted_playlist.transcripts
]
)
Expand All @@ -69,13 +69,13 @@ def write_playlist(
output_dir=output_path, title=transcript.title
)
with open(file_path, "w+") as f:
f.write(transcript.transcript)
f.write(transcript.rendered_transcript)
else:
with open(output_path, "w+") as f:
f.write(
"\n\n".join(
[
transcript.transcript
transcript.rendered_transcript
for transcript in formatted_playlist.transcripts
]
)
Expand Down
31 changes: 18 additions & 13 deletions tests/integ/formatting/test_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@ def test_format_chaptered_transcript_basic(
sat = SaT("sat-3l")
paragraphs_segmenter = ParagraphsSegmenter(sat=sat)
formatter = MarkdownFormatter(
paragraphs_segmenter=paragraphs_segmenter, to_timestamp_paragraphs=False
paragraphs_segmenter=paragraphs_segmenter,
to_timestamp_paragraphs=False,
add_table_of_contents=False,
)

segments_dicts = [
Expand Down Expand Up @@ -130,15 +132,15 @@ def test_format_chaptered_transcript_basic(

# Assert
assert formatted_output.title == "Test Video Title"
assert "# Test Video Title" in formatted_output.transcript
assert "https://example.com/video" in formatted_output.transcript
assert "## Chapter 1" in formatted_output.transcript
assert "# Test Video Title" in formatted_output.rendered_transcript
assert "https://example.com/video" in formatted_output.rendered_transcript
assert "## Chapter 1" in formatted_output.rendered_transcript
assert (
"Hi class. So today I'll be talking about climate change"
in formatted_output.transcript
in formatted_output.rendered_transcript
)
assert formatted_output.transcript.count("\n\n") > 6
assert "(0:00:00)" not in formatted_output.transcript
assert formatted_output.rendered_transcript.count("\n\n") > 6
assert "(0:00:00)" not in formatted_output.rendered_transcript


def test_markdown_formatter_with_segmentation(
Expand Down Expand Up @@ -171,6 +173,7 @@ def mock_generate_title_for_paragraphs(
formatter = MarkdownFormatter(
paragraphs_segmenter=paragraphs_segmenter,
to_timestamp_paragraphs=False,
add_table_of_contents=False,
topic_segmenter=segmenter,
)

Expand Down Expand Up @@ -204,10 +207,10 @@ def mock_generate_title_for_paragraphs(
)

# Assert
assert "# Test Video Title" in formatted_output.transcript
assert "## Chapter Title" in formatted_output.transcript
assert formatted_output.transcript.count("\n\n") > 6
assert "(0:00:00)" not in formatted_output.transcript
assert "# Test Video Title" in formatted_output.rendered_transcript
assert "## Chapter Title" in formatted_output.rendered_transcript
assert formatted_output.rendered_transcript.count("\n\n") > 6
assert "(0:00:00)" not in formatted_output.rendered_transcript


def test_format_chaptered_transcript_timestamp_paragraphs(
Expand All @@ -217,7 +220,9 @@ def test_format_chaptered_transcript_timestamp_paragraphs(
sat = SaT("sat-3l")
paragraphs_segmenter = ParagraphsSegmenter(sat=sat)
formatter = MarkdownFormatter(
paragraphs_segmenter=paragraphs_segmenter, to_timestamp_paragraphs=True
paragraphs_segmenter=paragraphs_segmenter,
to_timestamp_paragraphs=True,
add_table_of_contents=False,
)

segments_dicts = [
Expand Down Expand Up @@ -251,4 +256,4 @@ def test_format_chaptered_transcript_timestamp_paragraphs(

# Assert
assert formatted_output.title == "Test Video Title"
assert "(0:00:00)" in formatted_output.transcript
assert "(0:00:00)" in formatted_output.rendered_transcript
Loading

0 comments on commit 1dd57df

Please sign in to comment.