forked from kaixxx/noScribe
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdiarize.py
112 lines (89 loc) · 4.41 KB
/
diarize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# noScribe - AI-powered Audio Transcription
# Copyright (C) 2023 Kai Dröge
# ported to macOS by Philipp Schneider (gernophil)
# Diarization with PyAnnote (https://github.com/pyannote/pyannote-audio)
# usage: python diarize.py <device['cpu', 'mps']> <audio file> <output yaml-file>
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import os
import platform
import yaml
from pyannote.audio import Pipeline
import torch
from typing import Any, Mapping, Optional, Text
import sys
from pathlib import Path
from tempfile import TemporaryDirectory
# Directory containing this script; model files are resolved relative to it.
app_dir = os.path.abspath(os.path.dirname(__file__))
# Command-line interface: diarize.py <device> <audio file> <output yaml-file>.
# Validate the argument count up front so a wrong invocation produces a clear
# message (in the same "error ..." format used below) instead of a bare
# IndexError traceback from indexing sys.argv.
if len(sys.argv) < 4:
    print('error ', 'usage: python diarize.py <device> <audio file> <output yaml-file>', file=sys.stderr)
    sys.exit(1)  # return error code
device = sys.argv[1]
audio_file = sys.argv[2]
segments_yaml = sys.argv[3]
class SimpleProgressHook:
    """Progress hook handed to the pyannote pipeline call.

    Every invocation prints one line of the form
    ``progress <step_name> <percent>`` to stdout and flushes immediately, so a
    parent process reading the output line by line sees progress in real time.

    Args:
        parent: Stored on the instance; not used by this class.
        transient: Stored on the instance; not used by this class.
    """

    def __init__(self, parent, transient: bool = False):
        super().__init__()
        self.parent = parent
        self.transient = transient

    def __enter__(self):
        self.progress = 0
        return self

    def __exit__(self, *args):
        # Nothing to release; returning None lets exceptions propagate.
        pass

    def __call__(
        self,
        step_name: Text,
        step_artifact: Any,
        file: Optional[Mapping] = None,
        total: Optional[int] = None,
        completed: Optional[int] = None,
    ):
        """Print the progress of ``step_name`` as an integer percentage in [0, 100].

        Args:
            step_name: Name of the current pipeline step.
            step_artifact: Output of the step (not used here).
            file: The file being processed (not used here).
            total: Total number of work units of the step, if reported.
            completed: Number of work units finished so far, if reported.
        """
        # A call without counts is reported as a finished step.
        if completed is None:
            completed = total = 1
        if not hasattr(self, 'step_name') or step_name != self.step_name:
            self.step_name = step_name
        if not total:
            # total is missing or zero: the ratio is undefined. Dividing here
            # would raise and abort the whole diarization over a progress
            # message, so report the step as complete instead.
            progress_percent = 100
        else:
            progress_percent = int(completed / total * 100)
        progress_percent = max(0, min(progress_percent, 100))
        print(f'progress {step_name} {progress_percent}', flush=True)
# Start Diarization:
# Load the pyannote pipeline for the current platform, move it to the requested
# device, diarize audio_file and write the speaker segments (milliseconds) to
# segments_yaml. Any failure prints "error  <message>" to stderr and exits 1.
try:
    config_path = os.path.join(app_dir, 'models', 'pyannote_config.yaml')
    system = platform.system()
    if system == 'Windows':
        pipeline = Pipeline.from_pretrained(config_path)
    elif system in ("Darwin", "Linux"):  # macOS and Linux
        if device == 'mps' and not torch.backends.mps.is_available():  # should only happen on x86_64, but checked on all archs to be sure
            device = 'cpu'
            print("log: 'pyannote_xpu: mps' was selected, but mps is not available on this system!")
            print("log: This happens, because availability cannot be checked earlier.")
            print("log: 'pyannote_xpu: cpu' was set.")  # The string needs to be the same as in noScribe.py `if line.strip() == "log: 'pyannote_xpu: cpu' was set.":`.
        with open(config_path, 'r', encoding='utf-8') as yaml_file:
            pyannote_config = yaml.safe_load(yaml_file)
        params = pyannote_config['pipeline']['params']
        # Rebase the model paths onto app_dir: the first "/"-separated component
        # of the stored path is dropped and the rest is joined to app_dir.
        for key in ('embedding', 'segmentation'):
            params[key] = os.path.join(app_dir, *params[key].split("/")[1:])
        # The patched config only has to exist while from_pretrained reads it;
        # the context manager removes the temp dir right after loading instead
        # of leaving it around for the whole (long) diarization run.
        with TemporaryDirectory(prefix='noScribe_diarize') as tmpdir:
            patched_config_path = os.path.join(tmpdir, 'pyannote_config_macOS.yaml')
            with open(patched_config_path, 'w', encoding='utf-8') as yaml_file:
                yaml.safe_dump(pyannote_config, yaml_file)
            pipeline = Pipeline.from_pretrained(patched_config_path)
    else:
        raise Exception('Platform not supported yet.')
    pipeline.to(torch.device(device))

    with SimpleProgressHook(parent=None) as hook:
        diarization = pipeline(audio_file, hook=hook)  # apply the pipeline to the audio file

    # One dict per speaker turn, with times converted from seconds to whole milliseconds.
    seg_list = [
        {'start': int(segment.start * 1000),
         'end': int((segment.start + segment.duration) * 1000),
         'label': label}
        for segment, _, label in diarization.itertracks(yield_label=True)
    ]
    with open(segments_yaml, 'w', encoding='utf-8') as filestream:
        yaml.safe_dump(seg_list, filestream)
except Exception as e:
    print('error ', e, file=sys.stderr)
    sys.exit(1)  # return error code