Skip to content

Commit

Permalink
Merge branch 'develop' of github.com:pyannote/pyannote-core into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
hbredin committed Oct 11, 2021
2 parents 937bcda + d96f0e1 commit 0e3986a
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 64 deletions.
30 changes: 19 additions & 11 deletions pyannote/core/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class SlidingWindowFeature(np.lib.mixins.NDArrayOperatorsMixin):
"""

def __init__(
self, data: np.ndarray, sliding_window: SlidingWindow, labels: List[Text] = None
self, data: np.ndarray, sliding_window: SlidingWindow, labels: List[Text] = None
):
self.sliding_window: SlidingWindow = sliding_window
self.data = data
Expand Down Expand Up @@ -113,7 +113,7 @@ def next(self):
return self.__next__()

def iterfeatures(
self, window: Optional[bool] = False
self, window: Optional[bool] = False
) -> Iterator[Union[Tuple[np.ndarray, Segment], np.ndarray]]:
"""Feature vector iterator
Expand All @@ -132,11 +132,11 @@ def iterfeatures(
yield self.data[i]

def crop(
self,
focus: Union[Segment, Timeline],
mode: Alignment = "loose",
fixed: Optional[float] = None,
return_data: bool = True,
self,
focus: Union[Segment, Timeline],
mode: Alignment = "loose",
fixed: Optional[float] = None,
return_data: bool = True,
) -> Union[np.ndarray, "SlidingWindowFeature"]:
"""Extract frames
Expand Down Expand Up @@ -175,6 +175,10 @@ def crop(
)
raise ValueError(msg)

if (not return_data) and (fixed is not None):
msg = '"fixed" cannot be set when "return_data" is set to False.'
raise ValueError(msg)

ranges = self.sliding_window.crop(
focus, mode=mode, fixed=fixed, return_ranges=True
)
Expand All @@ -197,8 +201,9 @@ def crop(
# if all requested samples are out of bounds, skip
if end < 0 or start >= n_samples:
continue
# keep track of non-empty clipped ranges
clipped_ranges += [[max(start, 0), min(end, n_samples)]]
else:
# keep track of non-empty clipped ranges
clipped_ranges += [[max(start, 0), min(end, n_samples)]]

if clipped_ranges:
data = np.vstack([self.data[start:end, :] for start, end in clipped_ranges])
Expand All @@ -207,7 +212,7 @@ def crop(
shape = (0,) + self.data.shape[1:]
data = np.empty(shape)

# corner case when 'fixed' duration cropping is requested:
# corner case when "fixed" duration cropping is requested:
# correct number of samples even with out-of-bounds indices
if fixed is not None:
data = np.vstack(
Expand All @@ -228,19 +233,22 @@ def crop(

# wrap data in a SlidingWindowFeature and return
sliding_window = SlidingWindow(
start=self.sliding_window[ranges[0][0]].start,
start=self.sliding_window[clipped_ranges[0][0]].start,
duration=self.sliding_window.duration,
step=self.sliding_window.step,
)

return SlidingWindowFeature(data, sliding_window, labels=self.labels)

def _repr_png_(self):
from .notebook import MATPLOTLIB_IS_AVAILABLE, MATPLOTLIB_WARNING

if not MATPLOTLIB_IS_AVAILABLE:
warnings.warn(MATPLOTLIB_WARNING.format(klass=self.__class__.__name__))
return None

from .notebook import repr_feature

return repr_feature(self)

_HANDLED_TYPES = (np.ndarray, numbers.Number)
Expand Down
169 changes: 116 additions & 53 deletions pyannote/core/notebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@
pass
import numpy as np
from itertools import cycle, product, groupby
from .segment import Segment
from .segment import Segment, SlidingWindow
from .timeline import Timeline
from .annotation import Annotation
from .feature import SlidingWindowFeature
Expand All @@ -114,29 +114,29 @@
else:
MATPLOTLIB_IS_AVAILABLE = True

MATPLOTLIB_WARNING = \
"Couldn't import matplotlib to render the vizualization " \
"for object {klass}. To enable, install the required dependencies " \
MATPLOTLIB_WARNING = (
"Couldn't import matplotlib to render the vizualization "
"for object {klass}. To enable, install the required dependencies "
"with 'pip install pyannore.core[notebook]'"
)


class Notebook:

def __init__(self):
self.reset()

def reset(self):
from matplotlib.cm import get_cmap

linewidth = [3, 1]
linestyle = ['solid', 'dashed', 'dotted']
linestyle = ["solid", "dashed", "dotted"]

cm = get_cmap('Set1')
colors = [cm(1. * i / 8) for i in range(9)]
cm = get_cmap("Set1")
colors = [cm(1.0 * i / 8) for i in range(9)]

self._style_generator = cycle(product(linestyle, linewidth, colors))
self._style: Dict[Optional[Label], LabelStyle] = {
None: ('solid', 1, (0.0, 0.0, 0.0))
None: ("solid", 1, (0.0, 0.0, 0.0))
}
del self.crop
del self.width
Expand Down Expand Up @@ -168,17 +168,19 @@ def width(self):
self._width = 20

def __getitem__(self, label: Label) -> LabelStyle:
"""Get line style for a given label"""
if label not in self._style:
self._style[label] = next(self._style_generator)
return self._style[label]

def setup(self, ax=None, ylim=(0, 1), yaxis=False, time=True):
import matplotlib.pyplot as plt

if ax is None:
ax = plt.gca()
ax.set_xlim(self.crop)
if time:
ax.set_xlabel('Time')
ax.set_xlabel("Time")
else:
ax.set_xticklabels([])
ax.set_ylim(ylim)
Expand All @@ -194,13 +196,22 @@ def draw_segment(self, ax, segment: Segment, y, label=None, boundaries=True):
linestyle, linewidth, color = self[label]

# draw segment
ax.hlines(y, segment.start, segment.end, color,
linewidth=linewidth, linestyle=linestyle, label=label)
ax.hlines(
y,
segment.start,
segment.end,
color,
linewidth=linewidth,
linestyle=linestyle,
label=label,
)
if boundaries:
ax.vlines(segment.start, y + 0.05, y - 0.05,
color, linewidth=1, linestyle='solid')
ax.vlines(segment.end, y + 0.05, y - 0.05,
color, linewidth=1, linestyle='solid')
ax.vlines(
segment.start, y + 0.05, y - 0.05, color, linewidth=1, linestyle="solid"
)
ax.vlines(
segment.end, y + 0.05, y - 0.05, color, linewidth=1, linestyle="solid"
)

if label is None:
return
Expand Down Expand Up @@ -246,13 +257,11 @@ def get_y(self, segments: Iterable[Segment]) -> np.ndarray:
up_to.append(segment.end)

# from line numbers to actual y coordinates
y = 1. - 1. / (len(up_to) + 1) * (1 + np.array(y))
y = 1.0 - 1.0 / (len(up_to) + 1) * (1 + np.array(y))

return y

def __call__(self, resource: Resource,
time: bool = True,
legend: bool = True):
def __call__(self, resource: Resource, time: bool = True, legend: bool = True):

if isinstance(resource, Segment):
self.plot_segment(resource, time=time)
Expand All @@ -279,7 +288,7 @@ def plot_timeline(self, timeline: Timeline, ax=None, time=True):
if not self.crop and timeline:
self.crop = timeline.extent()

cropped = timeline.crop(self.crop, mode='loose')
cropped = timeline.crop(self.crop, mode="loose")

ax = self.setup(ax=ax, time=time)

Expand All @@ -293,15 +302,15 @@ def plot_annotation(self, annotation: Annotation, ax=None, time=True, legend=Tru
if not self.crop:
self.crop = annotation.get_timeline(copy=False).extent()

cropped = annotation.crop(self.crop, mode='intersection')
cropped = annotation.crop(self.crop, mode="intersection")
labels = cropped.labels()
segments = [s for s, _ in cropped.itertracks()]

ax = self.setup(ax=ax, time=time)

for (segment, track, label), y in zip(
cropped.itertracks(yield_label=True),
self.get_y(segments)):
cropped.itertracks(yield_label=True), self.get_y(segments)
):
self.draw_segment(ax, segment, y, label=label)

if legend:
Expand All @@ -313,28 +322,36 @@ def plot_annotation(self, annotation: Annotation, ax=None, time=True, legend=Tru

# this gets exactly one legend handle and one legend label per label
# (avoids repeated legends for repeated tracks with same label)
HL = groupby(sorted(zip(H, L), key=lambda h_l: h_l[1]),
key=lambda h_l: h_l[1])
HL = groupby(
sorted(zip(H, L), key=lambda h_l: h_l[1]), key=lambda h_l: h_l[1]
)
H, L = zip(*list((next(h_l)[0], l) for l, h_l in HL))
ax.legend(H, L, bbox_to_anchor=(0, 1), loc=3,
ncol=5, borderaxespad=0., frameon=False)

def plot_feature(self, feature: SlidingWindowFeature,
ax=None, time=True, ylim=None):
ax.legend(
H,
L,
bbox_to_anchor=(0, 1),
loc=3,
ncol=5,
borderaxespad=0.0,
frameon=False,
)

def plot_feature(
self, feature: SlidingWindowFeature, ax=None, time=True, ylim=None
):

if not self.crop:
self.crop = feature.getExtent()

window = feature.sliding_window
n, dimension = feature.data.shape
(start, stop), = window.crop(self.crop, mode='loose',
return_ranges=True)
((start, stop),) = window.crop(self.crop, mode="loose", return_ranges=True)
xlim = (window[start].middle, window[stop].middle)

start = max(0, start)
stop = min(stop, n)
t = window[0].middle + window.step * np.arange(start, stop)
data = feature[start: stop]
data = feature[start:stop]

if ylim is None:
m = np.nanmin(data)
Expand All @@ -352,50 +369,96 @@ def plot_feature(self, feature: SlidingWindowFeature,
def repr_segment(segment: Segment):
"""Get `png` data for `segment`"""
import matplotlib.pyplot as plt
figsize = plt.rcParams['figure.figsize']
plt.rcParams['figure.figsize'] = (notebook.width, 1)

figsize = plt.rcParams["figure.figsize"]
plt.rcParams["figure.figsize"] = (notebook.width, 1)
fig, ax = plt.subplots()
notebook.plot_segment(segment, ax=ax)
data = print_figure(fig, 'png')
data = print_figure(fig, "png")
plt.close(fig)
plt.rcParams['figure.figsize'] = figsize
plt.rcParams["figure.figsize"] = figsize
return data


def repr_timeline(timeline: Timeline):
"""Get `png` data for `timeline`"""
import matplotlib.pyplot as plt
figsize = plt.rcParams['figure.figsize']
plt.rcParams['figure.figsize'] = (notebook.width, 1)

figsize = plt.rcParams["figure.figsize"]
plt.rcParams["figure.figsize"] = (notebook.width, 1)
fig, ax = plt.subplots()
notebook.plot_timeline(timeline, ax=ax)
data = print_figure(fig, 'png')
data = print_figure(fig, "png")
plt.close(fig)
plt.rcParams['figure.figsize'] = figsize
plt.rcParams["figure.figsize"] = figsize
return data


def repr_annotation(annotation: Annotation):
"""Get `png` data for `annotation`"""
import matplotlib.pyplot as plt
figsize = plt.rcParams['figure.figsize']
plt.rcParams['figure.figsize'] = (notebook.width, 2)

figsize = plt.rcParams["figure.figsize"]
plt.rcParams["figure.figsize"] = (notebook.width, 2)
fig, ax = plt.subplots()
notebook.plot_annotation(annotation, ax=ax)
data = print_figure(fig, 'png')
data = print_figure(fig, "png")
plt.close(fig)
plt.rcParams['figure.figsize'] = figsize
plt.rcParams["figure.figsize"] = figsize
return data


def repr_feature(feature: SlidingWindowFeature):
"""Get `png` data for `feature`"""
import matplotlib.pyplot as plt
figsize = plt.rcParams['figure.figsize']
plt.rcParams['figure.figsize'] = (notebook.width, 2)
fig, ax = plt.subplots()
notebook.plot_feature(feature, ax=ax)
data = print_figure(fig, 'png')
plt.close(fig)
plt.rcParams['figure.figsize'] = figsize

figsize = plt.rcParams["figure.figsize"]

if feature.data.ndim == 2:

plt.rcParams["figure.figsize"] = (notebook.width, 2)
fig, ax = plt.subplots()
notebook.plot_feature(feature, ax=ax)
data = print_figure(fig, "png")
plt.close(fig)

elif feature.data.ndim == 3:

num_chunks = len(feature)

if notebook.crop is None:
notebook.crop = Segment(
start=feature.sliding_window.start,
end=feature.sliding_window[num_chunks - 1].end,
)
else:
feature = feature.crop(notebook.crop, mode="loose", return_data=False)

num_overlap = (
round(feature.sliding_window.duration // feature.sliding_window.step) + 1
)

num_overlap = min(num_chunks, num_overlap)

plt.rcParams["figure.figsize"] = (notebook.width, 1.5 * num_overlap)

fig, axes = plt.subplots(nrows=num_overlap, ncols=1,)
mini, maxi = np.nanmin(feature.data), np.nanmax(feature.data)
ylim = (mini - 0.2 * (maxi - mini), maxi + 0.2 * (maxi - mini))
for c, (window, data) in enumerate(feature):
ax = axes[c % num_overlap]
step = duration = window.duration / len(data)
frames = SlidingWindow(start=window.start, step=step, duration=duration)
window_feature = SlidingWindowFeature(data, frames, labels=feature.labels)
notebook.plot_feature(
window_feature,
ax=ax,
time=c % num_overlap == (num_overlap - 1),
ylim=ylim,
)
ax.set_prop_cycle(None)
data = print_figure(fig, "png")
plt.close(fig)

plt.rcParams["figure.figsize"] = figsize
return data

0 comments on commit 0e3986a

Please sign in to comment.