Skip to content

Commit

Permalink
Make interval in SyncMapFragment required
Browse files Browse the repository at this point in the history
  • Loading branch information
naglis committed Nov 14, 2024
1 parent 0e1baa6 commit c1ccc38
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 126 deletions.
16 changes: 8 additions & 8 deletions aeneas/adjustboundaryalgorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,15 +340,15 @@ def intervals_to_fragment_list(self, text_file, time_values):
)
self.log(" Creating HEAD fragment")
self.smflist.add(
SyncMapFragment(
SyncMapFragment.from_begin_end(
begin=time_values[0],
end=time_values[1],
# NOTE lines and filtered lines MUST be set,
# otherwise some output format might break
# when adding HEAD/TAIL to output
text_fragment=TextFragment(
identifier="HEAD", lines=[], filtered_lines=[]
),
begin=time_values[0],
end=time_values[1],
fragment_type=FragmentType.HEAD,
),
sort=False,
Expand All @@ -361,26 +361,26 @@ def intervals_to_fragment_list(self, text_file, time_values):
for i in range(1, len(time_values) - 2):
self.log([" Adding fragment %d ...", i])
self.smflist.add(
SyncMapFragment(
text_fragment=fragments[i - 1],
SyncMapFragment.from_begin_end(
begin=time_values[i],
end=time_values[i + 1],
text_fragment=fragments[i - 1],
fragment_type=FragmentType.REGULAR,
),
sort=False,
)
self.log([" Adding fragment %d ... done", i])
self.log(" Creating TAIL fragment")
self.smflist.add(
SyncMapFragment(
SyncMapFragment.from_begin_end(
begin=time_values[len(time_values) - 2],
end=end,
# NOTE lines and filtered lines MUST be set,
# otherwise some output format might break
# when adding HEAD/TAIL to output
text_fragment=TextFragment(
identifier="TAIL", lines=[], filtered_lines=[]
),
begin=time_values[len(time_values) - 2],
end=end,
fragment_type=FragmentType.TAIL,
),
sort=False,
Expand Down
5 changes: 3 additions & 2 deletions aeneas/syncmap/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def output_html_for_tuning(

def read(
self,
sync_map_format: SyncMapFormat,
sync_map_format: str,
input_file_path: str,
parameters: dict | None = None,
):
Expand Down Expand Up @@ -423,7 +423,8 @@ def read(
if language is not None:
self.log(["Overwriting language to '%s'", language])
for fragment in self.fragments:
fragment.text_fragment.language = language
if fragment.text_fragment is not None:
fragment.text_fragment.language = language

def write(
self,
Expand Down
59 changes: 28 additions & 31 deletions aeneas/syncmap/fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,23 +63,34 @@ class SyncMapFragment:

def __init__(
self,
interval: TimeInterval,
*,
text_fragment: TextFragment | None = None,
interval: TimeInterval | None = None,
begin: TimeValue | None = None,
end: TimeValue | None = None,
fragment_type: FragmentType = FragmentType.REGULAR,
confidence: float = 1.0,
):
self.text_fragment = text_fragment
if interval is not None:
self.interval = interval
elif begin is not None and end is not None:
self.interval = TimeInterval(begin, end)
else:
self.interval = None
self.interval = interval
self.fragment_type = fragment_type
self.confidence = confidence

@classmethod
def from_begin_end(
cls,
begin: TimeValue,
end: TimeValue,
*,
text_fragment: TextFragment | None = None,
fragment_type: FragmentType = FragmentType.REGULAR,
confidence: float = 1.0,
) -> "SyncMapFragment":
return cls(
interval=TimeInterval(begin, end),
text_fragment=text_fragment,
fragment_type=fragment_type,
confidence=confidence,
)

def __str__(self):
return "%s %d %.3f %.3f" % (
self.text_fragment.identifier,
Expand Down Expand Up @@ -124,14 +135,14 @@ def text_fragment(self, text_fragment: TextFragment | None):
self.__text_fragment = text_fragment

@property
def interval(self) -> TimeInterval | None:
def interval(self) -> TimeInterval:
"""
The time interval corresponding to this fragment.
"""
return self.__interval

@interval.setter
def interval(self, interval: TimeInterval | None):
def interval(self, interval: TimeInterval):
self.__interval = interval

@property
Expand Down Expand Up @@ -185,10 +196,10 @@ def pretty_print(self) -> str:
.. versionadded:: 1.7.0
"""
return "{}\t{:.3f}\t{:.3f}\t{}".format(
(self.identifier or ""),
(self.begin if self.begin is not None else TimeValue("-2.000")),
(self.end if self.end is not None else TimeValue("-1.000")),
(self.text or ""),
self.identifier or "",
self.interval.begin,
self.interval.end,
self.text or "",
)

@property
Expand All @@ -214,37 +225,27 @@ def text(self) -> str | None:
return self.text_fragment.text

@property
def begin(self) -> TimeValue | None:
def begin(self) -> TimeValue:
"""
The begin time of this sync map fragment.
"""
if self.interval is None:
return None
return self.interval.begin

@begin.setter
def begin(self, begin: TimeValue):
if self.interval is None:
raise TypeError("Attempting to set begin when interval is None")
if not isinstance(begin, TimeValue):
raise TypeError("The given begin value is not an instance of TimeValue")
self.interval.begin = begin

@property
def end(self) -> TimeValue | None:
def end(self) -> TimeValue:
"""
The end time of this sync map fragment.
:rtype: :class:`~aeneas.exacttiming.TimeValue`
"""
if self.interval is None:
return None
return self.interval.end

@end.setter
def end(self, end: TimeValue):
if self.interval is None:
raise TypeError("Attempting to set end when interval is None")
if not isinstance(end, TimeValue):
raise TypeError("The given end value is not an instance of TimeValue")
self.interval.end = end
Expand All @@ -254,11 +255,7 @@ def length(self) -> TimeValue:
"""
The audio duration of this sync map fragment,
as end time minus begin time.
:rtype: :class:`~aeneas.exacttiming.TimeValue`
"""
if self.interval is None:
return TimeValue("0.000")
return self.interval.length

@property
Expand Down
12 changes: 8 additions & 4 deletions aeneas/syncmap/fragmentlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def move_transition_point(self, fragment_index: int, value: TimeValue):

def fragments_ending_inside_nonspeech_intervals(
self, nonspeech_intervals: list[TimeInterval], tolerance: TimeValue
) -> list[tuple[TimeInterval, int]]:
) -> list[tuple[TimeInterval | None, int]]:
"""
Determine a list of pairs (nonspeech interval, fragment index),
such that the nonspeech interval contains exactly one fragment
Expand All @@ -463,7 +463,9 @@ def fragments_ending_inside_nonspeech_intervals(
self.log([" List end: %.3f", self.end])
nsi_index = 0
frag_index = 0
nsi_counter = [(n, []) for n in nonspeech_intervals]
nsi_counter: list[tuple[TimeInterval | None, list[int]]] = [
(n, []) for n in nonspeech_intervals
]
# NOTE the last fragment is not eligible to be returned
while (nsi_index < len(nonspeech_intervals)) and (frag_index < len(self) - 1):
nsi = nonspeech_intervals[nsi_index]
Expand Down Expand Up @@ -569,13 +571,13 @@ def inject_long_nonspeech_fragments(
identifier = "n%06d" % i
self.add(
SyncMapFragment(
interval=nsi,
text_fragment=TextFragment(
identifier=identifier,
language=None,
lines=lines,
filtered_lines=lines,
),
interval=nsi,
fragment_type=FragmentType.NONSPEECH,
),
sort=False,
Expand Down Expand Up @@ -664,7 +666,9 @@ def fix_zero_length_fragments(
self[i].interval,
]
)
moves = [(i, "ENLARGE", duration)]
moves: list[tuple[int, str, TimeValue | None]] = [
(i, "ENLARGE", duration)
]
slack = duration
j = i + 1
self.log([" Entered while with j == %d", j])
Expand Down
4 changes: 2 additions & 2 deletions aeneas/syncmap/smfbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def _add_fragment(
cls,
syncmap: "SyncMap",
identifier: str,
lines: typing.Sequence[str],
lines: list[str],
begin: TimeValue,
end: TimeValue,
language: Language | None = None,
Expand All @@ -88,7 +88,7 @@ def _add_fragment(
:type language: string
"""
syncmap.add_fragment(
SyncMapFragment(
SyncMapFragment.from_begin_end(
text_fragment=TextFragment(
identifier=identifier, lines=lines, language=language
),
Expand Down
14 changes: 8 additions & 6 deletions aeneas/tests/test_syncmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,16 @@ class TestSyncMap(BaseSyncMapCase):
NOT_EXISTING_SRT = gf.absolute_path("not_existing.srt", __file__)
EXISTING_SRT = gf.absolute_path("res/syncmaps/sonnet001.srt", __file__)
NOT_WRITEABLE_SRT = gf.absolute_path("x/y/z/not_writeable.srt", __file__)
EMPTY_INTERVAL = TimeInterval(begin=TimeValue("0.000"), end=TimeValue("0.000"))

def build_tree_from_intervals(
self, intervals: typing.Sequence[tuple[str, str]]
) -> Tree:
tree = Tree()
for begin, end in intervals:
interval = TimeInterval(begin=TimeValue(begin), end=TimeValue(end))
smf = SyncMapFragment(interval=interval)
smf = SyncMapFragment.from_begin_end(
begin=TimeValue(begin), end=TimeValue(end)
)
child = Tree(value=smf)
tree.add_child(child, as_last=True)

Expand Down Expand Up @@ -110,7 +112,7 @@ def test_fragments_tree_empty(self):
self.assertEqual(len(syn.fragments_tree), 0)

def test_fragments_tree_not_empty(self):
smf = SyncMapFragment()
smf = SyncMapFragment(interval=self.EMPTY_INTERVAL)
child = Tree(value=smf)
tree = Tree()
tree.add_child(child)
Expand All @@ -122,17 +124,17 @@ def test_is_single_level_true_empty(self):
self.assertTrue(syn.is_single_level)

def test_is_single_level_true_not_empty(self):
smf = SyncMapFragment()
smf = SyncMapFragment(interval=self.EMPTY_INTERVAL)
child = Tree(value=smf)
tree = Tree()
tree.add_child(child)
syn = SyncMap(tree=tree)
self.assertTrue(syn.is_single_level)

def test_is_single_level_false(self):
smf2 = SyncMapFragment()
smf2 = SyncMapFragment(interval=self.EMPTY_INTERVAL)
child2 = Tree(value=smf2)
smf = SyncMapFragment()
smf = SyncMapFragment(interval=self.EMPTY_INTERVAL)
child = Tree(value=smf)
child.add_child(child2)
tree = Tree()
Expand Down
Loading

0 comments on commit c1ccc38

Please sign in to comment.