Skip to content

Commit

Permalink
Merge pull request #9 from desh2608/spyder
Browse files Browse the repository at this point in the history
  • Loading branch information
desh2608 authored Mar 6, 2021
2 parents adfa0f7 + 66090a3 commit 6896e97
Show file tree
Hide file tree
Showing 12 changed files with 629 additions and 3,363 deletions.
66 changes: 40 additions & 26 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ Official implementation for [DOVER-Lap: A method for combining overlap-aware dia

## Installation

To install, simply run:

```shell
pip install dover-lap
```
Expand All @@ -31,41 +29,56 @@ Usage: dover-lap [OPTIONS] OUTPUT_RTTM [INPUT_RTTMS]...
Apply the DOVER-Lap algorithm on the input RTTM files.

Options:
--custom-weight TEXT Weights for input RTTMs
--dover-weight FLOAT DOVER weighting factor [default: 0.1]
--weight-type [rank|custom] Specify whether to use rank weighting or
provide custom weights [default: rank]

--tie-breaking [uniform|all] Specify whether to assign tied regions to all
speakers or divide uniformly [default: all]

--second-maximal If this flag is set, run a second iteration of
the maximal matching for label mapping. It may
give better results sometimes. [default:
False]

-c, --channel INTEGER Use this value for output channel IDs
[default: 1]

-u, --uem-file PATH UEM file path
--help Show this message and exit.
--custom-weight TEXT Weights for input RTTMs
--dover-weight FLOAT DOVER weighting factor [default: 0.1]
--weight-type [rank|custom] Specify whether to use rank weighting or
provide custom weights [default: rank]

--tie-breaking [uniform|all] Specify whether to assign tied regions to
all speakers or divide uniformly [default:
all]

--second-maximal If this flag is set, run a second iteration
of the maximal matching for greedy label
mapping [default: False]

--sort-first If this flag is set, sort inputs by DER
first before label mapping (only applicable
when label mapping type is hungarian)
[default: False]

--label-mapping [hungarian|greedy]
Choose label mapping algorithm to use
[default: greedy]

--random-seed INTEGER
-c, --channel INTEGER Use this value for output channel IDs
[default: 1]

-u, --uem-file PATH UEM file path
--help Show this message and exit.
```

**Note:** If `--weight-type custom` is used, then `--custom-weight` must be provided.
For example:
**Note:**

1. If `--weight-type custom` is used, then `--custom-weight` must be provided. For example:

```shell
dover-lap egs/ami/rttm_dl_test egs/ami/rttm_test_* --weight-type custom --custom-weight '[0.4,0.3,0.3]'
```

2. `--label-mapping` can be set to `greedy` (default) or `hungarian`; the latter is the
mapping technique originally proposed in [DOVER](https://arxiv.org/abs/1909.08090).

## Results

We provide a sample result on the AMI mix-headset test set. The results can be
obtained as follows:
obtained using [`spyder`](https://github.com/desh2608/spyder), which is automatically
installed with `dover-lap`:

```shell
dover-lap egs/ami/rttm_dl_test egs/ami/rttm_test_*
md-eval.pl -r egs/ami/ref_rttm_test -s egs/ami/rttm_dl_test
spyder egs/ami/ref_rttm_test egs/ami/rttm_dl_test
```

and similarly for each of the input hypotheses. The DER results are shown below.
Expand All @@ -75,9 +88,10 @@ and similarly for the input hypothesis. The DER results are shown below.
| Overlap-aware VB resegmentation | 9.84 | **2.06** | 9.60 | 21.50 |
| Overlap-aware spectral clustering | 11.48 | 2.27 | 9.81 | 23.56 |
| Region Proposal Network | **9.49** | 7.68 | 8.25 | 25.43 |
| DOVER-Lap | 9.71 | 3.00 | **7.59** | **20.30** |
| DOVER-Lap (Hungarian mapping) | 9.81 | 2.80 | 8.10 | 20.70 |
| DOVER-Lap (Greedy mapping) | 9.71 | 3.02 | **7.68** | **20.40** |


**Note:** A version of md-eval.pl can be found in `dover_lap/libs`.

## Running time

Expand Down
134 changes: 94 additions & 40 deletions dover_lap/dover_lap.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,27 @@
"""
import sys
import click
import random
import numpy as np

from typing import List

from dover_lap.libs.rttm import load_rttm, write_rttm
from dover_lap.libs.turn import merge_turns, trim_turns, Turn
from dover_lap.libs.uem import load_uem
from dover_lap.libs.utils import error, info, warn, groupby, \
command_required_option, PythonLiteralOption
from dover_lap.libs.utils import (
error,
info,
warn,
groupby,
command_required_option,
PythonLiteralOption,
)

from dover_lap.src.doverlap import DOVERLap


def load_rttms(
rttm_list: List[str]
) -> List[List[Turn]]:
def load_rttms(rttm_list: List[str]) -> List[List[Turn]]:
"""Loads speaker turns from input RTTMs in a list of turns."""
turns_list = []
file_ids = []
Expand All @@ -40,31 +46,84 @@ def load_rttms(
return turns_list


@click.argument('input_rttms', nargs=-1, type=click.Path(exists=True))
@click.argument('output_rttm', nargs=1, type=click.Path())
@click.option('-u', '--uem-file', type=click.Path(), help='UEM file path')
@click.option('-c', '--channel', type=int, default=1, show_default=True,
help='Use this value for output channel IDs')
@click.option('--second-maximal', is_flag=True, default=False, show_default=True,
help='If this flag is set, run a second iteration of the maximal matching for'
' label mapping. It may give better results sometimes.')
@click.option('--tie-breaking', type=click.Choice(['uniform','all']), default='all',
help='Specify whether to assign tied regions to all speakers or divide uniformly', show_default=True)
@click.option('--weight-type', type=click.Choice(['rank','custom']), default='rank',
help='Specify whether to use rank weighting or provide custom weights', show_default=True)
@click.option('--dover-weight', type=float, default=0.1, help='DOVER weighting factor', show_default=True)
@click.option('--custom-weight', cls=PythonLiteralOption, help='Weights for input RTTMs')
@click.command(cls=command_required_option('weight_type',
{'custom':'custom_weight','rank':'dover_weight'}))
@click.argument("input_rttms", nargs=-1, type=click.Path(exists=True))
@click.argument("output_rttm", nargs=1, type=click.Path())
@click.option("-u", "--uem-file", type=click.Path(), help="UEM file path")
@click.option(
"-c",
"--channel",
type=int,
default=1,
show_default=True,
help="Use this value for output channel IDs",
)
@click.option("--random-seed", type=int, default=0)
@click.option(
"--label-mapping",
type=click.Choice(["hungarian", "greedy"]),
default="greedy",
show_default=True,
help="Choose label mapping algorithm to use",
)
@click.option(
"--sort-first",
is_flag=True,
default=False,
show_default=True,
help="If this flag is set, sort inputs by DER first before label mapping "
"(only applicable when label mapping type is hungarian)",
)
@click.option(
"--second-maximal",
is_flag=True,
default=False,
show_default=True,
help="If this flag is set, run a second iteration of the maximal matching for"
" greedy label mapping",
)
@click.option(
"--tie-breaking",
type=click.Choice(["uniform", "all"]),
default="all",
help="Specify whether to assign tied regions to all speakers or divide uniformly",
show_default=True,
)
@click.option(
"--weight-type",
type=click.Choice(["rank", "custom"]),
default="rank",
help="Specify whether to use rank weighting or provide custom weights",
show_default=True,
)
@click.option(
"--dover-weight",
type=float,
default=0.1,
help="DOVER weighting factor",
show_default=True,
)
@click.option(
"--custom-weight", cls=PythonLiteralOption, help="Weights for input RTTMs"
)
@click.command(
cls=command_required_option(
"weight_type", {"custom": "custom_weight", "rank": "dover_weight"}
)
)
def main(
input_rttms: List[click.Path],
output_rttm: click.Path,
uem_file: click.Path,
channel: int,
**kwargs # these are passed directly to combine_turns_list() method
random_seed: int,
**kwargs, # these are passed directly to combine_turns_list() method
) -> None:
"""Apply the DOVER-Lap algorithm on the input RTTM files."""


# Set random seeds globally
random.seed(random_seed)
np.random.seed(random_seed)

# Load hypothesis speaker turns.
info("Loading speaker turns from input RTTMs...", file=sys.stderr)
turns_list = load_rttms(input_rttms)
Expand All @@ -74,16 +133,14 @@ def main(
uem = load_uem(uem_file)

# Trim turns to UEM scoring regions and merge any that overlap.
info("Trimming reference speaker turns to UEM scoring regions...", file=sys.stderr)
turns_list = [
trim_turns(turns, uem) for turns in turns_list
]
info(
"Trimming reference speaker turns to UEM scoring regions...",
file=sys.stderr,
)
turns_list = [trim_turns(turns, uem) for turns in turns_list]

info("Merging overlapping speaker turns...", file=sys.stderr)
turns_list = [
merge_turns(turns) for turns in turns_list
]
turns_list = [merge_turns(turns) for turns in turns_list]

file_to_turns_list = dict()
for turns in turns_list:
Expand All @@ -98,15 +155,12 @@ def main(
for file_id in file_to_turns_list:
info("Processing file {}..".format(file_id), file=sys.stderr)
turns_list = file_to_turns_list[file_id]
random.shuffle(
turns_list
) # We shuffle so that the hypothesis order is randomized
file_to_out_turns[file_id] = DOVERLap.combine_turns_list(
turns_list,
file_id,
**kwargs
turns_list, file_id, **kwargs
)

# Write output RTTM file
write_rttm(
output_rttm,
sum(list(file_to_out_turns.values()), []),
channel=channel
)
write_rttm(output_rttm, sum(list(file_to_out_turns.values()), []), channel=channel)
Loading

0 comments on commit 6896e97

Please sign in to comment.