Skip to content

Commit

Permalink
Updated degradations to use logging. #132
Browse files Browse the repository at this point in the history
  • Loading branch information
apmcleod committed Aug 30, 2020
1 parent 21fdef5 commit 6dfe6e7
Show file tree
Hide file tree
Showing 2 changed files with 270 additions and 353 deletions.
71 changes: 22 additions & 49 deletions mdtk/degradations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Code to perform the degradations i.e. edits to the midi data"""
import logging
import sys
import warnings

import numpy as np
import pandas as pd
Expand All @@ -22,7 +22,7 @@
TRIES_DEFAULT = 10

TRIES_WARN_MSG = (
"WARNING: Generated invalid (overlapping) degraded excerpt "
"Generated invalid (overlapping) degraded excerpt "
"too many times. Try raising tries parameter (default 10). "
"Returning None."
)
Expand Down Expand Up @@ -231,9 +231,7 @@ def pitch_shift(
or None if the degradation cannot be performed.
"""
if len(excerpt) == 0:
warnings.warn(
"WARNING: No notes to pitch shift. Returning None.", category=UserWarning
)
logging.warning("No notes to pitch shift. Returning None.")
return None

excerpt = pre_process(excerpt)
Expand All @@ -255,8 +253,8 @@ def pitch_shift(
distribution[zero_idx] = 0

if np.sum(distribution) == 0:
warnings.warn(
"WARNING: distribution contains only 0s after "
logging.warning(
"distribution contains only 0s after "
"setting distribution[zero_idx] value to 0. "
"Returning None."
)
Expand All @@ -278,8 +276,8 @@ def pitch_shift(
].tolist()

if not valid_notes:
warnings.warn(
"WARNING: No valid pitches to shift given "
logging.warning(
"No valid pitches to shift given "
f"min_pitch {min_pitch}, max_pitch {max_pitch}, "
f"and distribution {distribution} (after setting "
"distribution[zero_idx] to 0). Returning None."
Expand Down Expand Up @@ -312,7 +310,7 @@ def pitch_shift(
# Check if overlaps
if overlaps(degraded, note_index) or degraded.loc[note_index, "pitch"] == pitch:
if tries == 1:
warnings.warn(TRIES_WARN_MSG)
logging.warning(TRIES_WARN_MSG)
return None
return pitch_shift(
excerpt,
Expand Down Expand Up @@ -415,10 +413,7 @@ def time_shift(
valid_notes = list(valid.index[valid])

if not valid_notes:
warnings.warn(
"WARNING: No valid notes to time shift. Returning " "None.",
category=UserWarning,
)
logging.warning("No valid notes to time shift. Returning None.")
return None

# Sample a random note
Expand All @@ -443,7 +438,7 @@ def time_shift(
# Check if overlaps
if overlaps(degraded, index):
if tries == 1:
warnings.warn(TRIES_WARN_MSG)
logging.warning(TRIES_WARN_MSG)
return None
return time_shift(
excerpt,
Expand Down Expand Up @@ -600,10 +595,7 @@ def onset_shift(
valid_notes = list(valid.index[valid])

if not valid_notes:
warnings.warn(
"WARNING: No valid notes to onset shift. Returning " "None.",
category=UserWarning,
)
logging.warning("No valid notes to onset shift. Returning None.")
return None

# Sample a random note
Expand Down Expand Up @@ -645,7 +637,7 @@ def onset_shift(
# Check if overlaps
if overlaps(degraded, index):
if tries == 1:
warnings.warn(TRIES_WARN_MSG)
logging.warning(TRIES_WARN_MSG)
return None
return onset_shift(
excerpt,
Expand Down Expand Up @@ -765,10 +757,7 @@ def offset_shift(
valid_notes = list(valid.index[valid])

if not valid_notes:
warnings.warn(
"WARNING: No valid notes to offset shift. Returning " "None.",
category=UserWarning,
)
logging.warning("No valid notes to offset shift. Returning None.")
return None

# Sample a random note
Expand All @@ -794,7 +783,7 @@ def offset_shift(
# Check if overlaps
if overlaps(degraded, index):
if tries == 1:
warnings.warn(TRIES_WARN_MSG)
logging.warning(TRIES_WARN_MSG)
return None
return offset_shift(
excerpt,
Expand Down Expand Up @@ -836,9 +825,7 @@ def remove_note(excerpt, tries=TRIES_DEFAULT):
the degradations cannot be performed.
"""
if excerpt.shape[0] == 0:
warnings.warn(
"WARNING: No notes to remove. Returning None.", category=UserWarning
)
logging.warning("No notes to remove. Returning None.")
return None

degraded = pre_process(excerpt)
Expand Down Expand Up @@ -929,10 +916,7 @@ def add_note(
pitch = excerpt["pitch"].between(min_pitch, max_pitch, inclusive=True)
pitch = excerpt["pitch"][pitch].unique()
if len(pitch) == 0:
warnings.warn(
"WARNING: No valid aligned pitch in given " "range.",
category=UserWarning,
)
logging.warning("No valid aligned pitch in given range.")
return None
pitch = choice(pitch)
else:
Expand All @@ -941,10 +925,7 @@ def add_note(
# Find onset and duration
if align_time:
if min_duration > excerpt["dur"].max() or max_duration < excerpt["dur"].min():
warnings.warn(
"WARNING: No valid aligned duration in " "given range.",
category=UserWarning,
)
logging.warning("No valid aligned duration in given range.")
return None

durations = excerpt["dur"].between(min_duration, max_duration, inclusive=True)
Expand Down Expand Up @@ -983,7 +964,7 @@ def add_note(
# Check if overlaps
if overlaps(degraded, degraded.index[-1]):
if tries == 1:
warnings.warn(TRIES_WARN_MSG)
logging.warning(TRIES_WARN_MSG)
return None
return add_note(
excerpt,
Expand Down Expand Up @@ -1036,9 +1017,7 @@ def split_note(
the degradation cannot be performed.
"""
if excerpt.shape[0] == 0:
warnings.warn(
"WARNING: No notes to split. Returning None.", category=UserWarning
)
logging.warning("No notes to split. Returning None.")
return None

excerpt = pre_process(excerpt)
Expand All @@ -1048,9 +1027,7 @@ def split_note(
valid_notes = list(long_enough.index[long_enough])

if not valid_notes:
warnings.warn(
"WARNING: No valid notes to split. Returning " "None.", category=UserWarning
)
logging.warning("No valid notes to split. Returning None.")
return None

note_index = choice(valid_notes)
Expand Down Expand Up @@ -1132,9 +1109,7 @@ def join_notes(
the degradation cannot be performed.
"""
if excerpt.shape[0] < 2:
warnings.warn(
"WARNING: No notes to join. Returning None.", category=UserWarning
)
logging.warning("No notes to join. Returning None.")
return None

excerpt = pre_process(excerpt, sort=True)
Expand Down Expand Up @@ -1175,9 +1150,7 @@ def join_notes(
valid_starts.extend(valid_starts_this)

if not valid_starts:
warnings.warn(
"WARNING: No valid notes to join. Returning " "None.", category=UserWarning
)
logging.warning("No valid notes to join. Returning None.")
return None

index = randint(len(valid_starts))
Expand Down
Loading

0 comments on commit 6dfe6e7

Please sign in to comment.