Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Add audformat.Scheme.replace_labels() #62

Merged
merged 11 commits into from
Apr 28, 2021
17 changes: 15 additions & 2 deletions audformat/core/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,16 @@ def __init__(
r"""Dictionary of media information"""
self.raters = HeaderDict(value_type=Rater)
r"""Dictionary of raters"""
self.schemes = HeaderDict(value_type=Scheme)
self.schemes = HeaderDict(
value_type=Scheme,
set_callback=self._set_scheme,
)
r"""Dictionary of schemes"""
self.splits = HeaderDict(value_type=Split)
r"""Dictionary of splits"""
self.tables = HeaderDict(
value_type=Table, set_callback=self._set_table,
value_type=Table,
set_callback=self._set_table,
)
r"""Dictionary of tables"""

Expand Down Expand Up @@ -733,6 +737,15 @@ def load_header_from_yaml(header: dict) -> 'Database':

return db

def _set_scheme(
        self,
        scheme_id: str,
        scheme: Scheme,
) -> Scheme:
    r"""Attach a scheme to this database under the given ID.

    Passed as ``set_callback`` when the ``schemes`` dictionary is created,
    so it is presumably invoked each time a scheme is assigned there.

    Args:
        scheme_id: ID under which the scheme is stored
        scheme: scheme object that is being stored

    Returns:
        the same scheme object,
        now carrying a back-reference to the database and its own ID

    """
    # Back-reference and ID are what ``Scheme.replace_labels`` checks
    # before it walks the tables of the database.
    scheme._db, scheme._id = self, scheme_id
    return scheme

def _set_table(
self,
table_id: str,
Expand Down
131 changes: 110 additions & 21 deletions audformat/core/scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,28 +81,16 @@ def __init__(
if dtype is None and labels is None:
dtype = define.DataType.STRING

if labels is not None and len(labels) > 0:
if not isinstance(labels, (dict, list)):
if labels is not None:
dtype_labels = self._dtype_from_labels(labels)
if dtype is not None and dtype != dtype_labels:
raise ValueError(
'Labels must be passed as a dictionary or a list.'
"Data type is set to "
f"'{dtype}', "
"but data type of labels is "
f"'{dtype_labels}'."
)
derived_dtype = type(list(labels)[0])
if not all(isinstance(x, derived_dtype) for x in list(labels)):
raise ValueError(
'All labels must be of the same data type.'
)
if derived_dtype in self._dtypes:
derived_dtype = self._dtypes[derived_dtype]
define.DataType.assert_has_attribute_value(derived_dtype)
if dtype is not None:
if dtype != derived_dtype:
raise ValueError(
"Data type is set to "
f"'{dtype}', "
"but data type of labels is "
f"'{derived_dtype}'."
)
dtype = derived_dtype
dtype = dtype_labels

self.dtype = dtype
r"""Data type"""
Expand All @@ -113,6 +101,9 @@ def __init__(
self.maximum = maximum if self.is_numeric else None
r"""Maximum value"""

self._db = None
self._id = None

@property
def is_numeric(self) -> bool:
r"""Check if data type is numeric.
Expand Down Expand Up @@ -199,7 +190,8 @@ def to_pandas_dtype(self) -> typing.Union[
# allow nullable
labels = pd.array(labels, dtype=pd.Int64Dtype())
return pd.api.types.CategoricalDtype(
categories=labels, ordered=False,
categories=labels,
ordered=False,
)
elif self.dtype == define.DataType.BOOL:
return 'boolean'
Expand All @@ -211,6 +203,103 @@ def to_pandas_dtype(self) -> typing.Union[
return 'timedelta64[ns]'
return self.dtype

def replace_labels(
        self,
        labels: typing.Union[dict, list],
):
    r"""Swap the labels of the scheme for a new set.

    When the scheme has been assigned to a :class:`audformat.Database`,
    the categories of every :class:`audformat.Column`
    that refers to the scheme are updated as well.
    Values whose label is no longer present become ``NaN``.

    Args:
        labels: new labels

    Raises:
        ValueError: if scheme does not define labels
        ValueError: if dtype of new labels does not match dtype of
            scheme

    Example:
        >>> speaker = Scheme(
        ...     labels={
        ...         0: {'gender': 'female'},
        ...         1: {'gender': 'male'},
        ...     }
        ... )
        >>> speaker
        dtype: int
        labels:
          0: {gender: female}
          1: {gender: male}
        >>> speaker.replace_labels(
        ...     {
        ...         1: {'gender': 'male', 'age': 33},
        ...         2: {'gender': 'female', 'age': 44},
        ...     }
        ... )
        >>> speaker
        dtype: int
        labels:
          1: {gender: male, age: 33}
          2: {gender: female, age: 44}

    """
    if self.labels is None:
        raise ValueError(
            'Cannot replace labels when '
            'scheme does not define labels.'
        )

    # The new labels must resolve to the data type the scheme already has
    new_dtype = self._dtype_from_labels(labels)
    if new_dtype != self.dtype:
        raise ValueError(
            "Data type of labels must not change: \n"
            f"'{self.dtype}' \n"
            f"!=\n"
            f"'{new_dtype}'"
        )

    self.labels = labels

    # Without a database back-reference and an ID
    # there are no columns to bring in line with the new labels
    if self._db is None or self._id is None:
        return

    for table in self._db.tables.values():
        linked_columns = (
            column for column in table.columns.values()
            if column.scheme_id == self._id
        )
        for column in linked_columns:
            # ``copy=False`` so the categories are changed
            # on the object the column hands out, not on a copy
            series = column.get(copy=False)
            # NOTE(review): ``inplace=True`` appears to be deprecated
            # since pandas 1.3 and removed in pandas 2.0 --
            # confirm the supported pandas range
            series.cat.set_categories(
                new_categories=self.labels,
                ordered=False,
                inplace=True,
            )

def _dtype_from_labels(
        self,
        labels: typing.Union[dict, list],
) -> str:
    r"""Infer the data type shared by all labels.

    The type of the first label serves as reference
    (iterating a dictionary yields its keys,
    so for a dictionary the keys are checked).
    An empty collection is treated as ``'str'``.
    If the reference type has an entry in ``self._dtypes``
    that entry is used instead,
    and the result is checked with
    ``define.DataType.assert_has_attribute_value``.

    Args:
        labels: labels as dictionary or list

    Returns:
        data type derived from the labels

    Raises:
        ValueError: if labels are neither a dictionary nor a list
        ValueError: if labels are not all of the same type

    """
    if not isinstance(labels, (dict, list)):
        raise ValueError(
            'Labels must be passed as a dictionary or a list.'
        )

    values = list(labels)
    # With no labels the loop below never runs,
    # so the string placeholder is never used as a type
    dtype = type(values[0]) if values else 'str'
    for value in values:
        if not isinstance(value, dtype):
            raise ValueError(
                'All labels must be of the same data type.'
            )

    dtype = self._dtypes[dtype] if dtype in self._dtypes else dtype
    define.DataType.assert_has_attribute_value(dtype)

    return dtype

def __contains__(self, item: typing.Any) -> bool:
r"""Check if scheme contains data type of item.

Expand Down
24 changes: 22 additions & 2 deletions tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,12 +312,16 @@ def test_update():
db.meta['meta'] = 'meta'
db.raters['rater'] = audformat.Rater()
db.schemes['float'] = audformat.Scheme(float)
db.schemes['labels'] = audformat.Scheme(labels=['a', 'b'])
audformat.testing.add_table(
db,
'table',
audformat.define.IndexType.FILEWISE,
num_files=[0, 1],
columns={'float': ('float', 'rater')},
columns={
'float': ('float', 'rater'),
'labels': ('labels', None),
},
)

assert db.update(db) == db
Expand All @@ -329,12 +333,17 @@ def test_update():
other1.raters['rater2'] = audformat.Rater()
other1.schemes['int'] = audformat.Scheme(int)
other1.schemes['float'] = audformat.Scheme(float)
other1.schemes['labels'] = audformat.Scheme(labels=['b', 'c'])
audformat.testing.add_table(
other1,
'table',
audformat.define.IndexType.FILEWISE,
num_files=[1, 2],
columns={'int': ('int', 'rater'), 'float': ('float', 'rater2')},
columns={
'int': ('int', 'rater'),
'float': ('float', 'rater2'),
'labels': ('labels', None),
},
)

# database with new table
Expand All @@ -349,6 +358,17 @@ def test_update():
columns={'str': ('str', 'rater2')},
)

# raises error because schemes do not match

with pytest.raises(ValueError):
audformat.utils.concat(
[db['table'].df, other1['table'].df],
overwrite=True,
)

# replace labels to avoid error

db.schemes['labels'].replace_labels(other1.schemes['labels'].labels)
df = audformat.utils.concat(
[db['table'].df, other1['table'].df],
overwrite=True,
Expand Down
Loading