Skip to content

Commit

Permalink
Merge pull request #117 from m42k0/main
Browse files Browse the repository at this point in the history
Add command line argument for number of speakers for diarization
  • Loading branch information
jordimas authored Dec 13, 2024
2 parents 5df6aef + 10fec03 commit f0be40a
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 2 deletions.
7 changes: 7 additions & 0 deletions src/whisper_ctranslate2/commandline.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,13 @@ def read_command_line():
help="Name to use to identify the speaker (e.g. SPEAKER_00).",
)

diarization_args.add_argument(
"--speaker_num",
type=int,
default=2,
help="Number of speakers to use for diarization.",
)

live_args = parser.add_argument_group("Live transcribe options")

live_args.add_argument(
Expand Down
4 changes: 3 additions & 1 deletion src/whisper_ctranslate2/diarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ def __init__(
self,
use_auth_token=None,
device: str = "cpu",
num_speakers=2,
):
self.device = device
self.use_auth_token = use_auth_token
self.num_speakers = num_speakers

def set_threads(self, threads):
torch.set_num_threads(threads)
Expand Down Expand Up @@ -51,7 +53,7 @@ def run_model(self, audio: str):
"waveform": torch.from_numpy(audio[None, :]),
"sample_rate": 16000,
}
segments = self.model(audio_data)
segments = self.model(audio_data, num_speakers=self.num_speakers)
return segments

def assign_speakers_to_segments(self, segments, transcript_result, speaker_name):
Expand Down
5 changes: 4 additions & 1 deletion src/whisper_ctranslate2/whisper_ctranslate2.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def main():
live_input_device_sample_rate: int = args.pop("live_input_device_sample_rate")
hf_token = args.pop("hf_token")
speaker_name = args.pop("speaker_name")
speaker_num = args.pop("speaker_num")
batched = args.pop("batched")
batch_size = args.pop("batch_size")

Expand Down Expand Up @@ -233,7 +234,9 @@ def main():
from .diarization import Diarization

diarization_device = "cpu" if device == "auto" else device
diarize_model = Diarization(use_auth_token=hf_token, device=diarization_device)
diarize_model = Diarization(
use_auth_token=hf_token, device=diarization_device, num_speakers=speaker_num
)
if threads > 0:
diarize_model.set_threads(threads)

Expand Down

0 comments on commit f0be40a

Please sign in to comment.