Move TODOs to Issues
Issues filed to address later: #98, #99, #100, #101, #102, #103, #104, #105, #106, #107, and #108
JamesOwers committed Nov 26, 2019
1 parent 7038857 commit 1d9255c
Showing 8 changed files with 3 additions and 73 deletions.
7 changes: 0 additions & 7 deletions make_dataset.py
@@ -99,7 +99,6 @@ def parse_args(args_input=None):
parser.add_argument('--local-midi-dirs', metavar='midi_dir', type=str,
nargs='*', help='directories containing midi files to '
'include in the dataset', default=[])
-# TODO: check this works!
parser.add_argument('--local-csv-dirs', metavar='csv_dir', type=str,
nargs='*', help='directories containing csv files to '
'include in the dataset', default=[])
@@ -114,12 +113,10 @@ def parse_args(args_input=None):
'--download-cache-dir and --clear-download-cache. To '
'download no data, provide an input of "None"',
)
-# TODO: check this works!
parser.add_argument('--download-cache-dir', type=str,
default=downloaders.DEFAULT_CACHE_PATH, help='The '
'directory to use for storing intermediate downloaded '
'data e.g. zip files, and prior to preprocessing.')
-# TODO: check this works!
parser.add_argument('--clear-download-cache', action='store_true',
help='clear downloaded data cache')
parser.add_argument('--degradations', metavar='deg_name', nargs='*',
@@ -136,7 +133,6 @@ def parse_args(args_input=None):
parser.add_argument('--min-notes', metavar='N', type=int, default=10,
help='The minimum number of notes required for an '
'excerpt to be valid.')
-# TODO: check this works!
parser.add_argument('--degradation-kwargs', metavar='json_string',
help='json with keyword arguments for the '
'degradation functions. First provide the degradation '
@@ -145,7 +141,6 @@ def parse_args(args_input=None):
'kwarg. e.g. {"pitch_shift__distribution": "poisson", '
'"pitch_shift__min_pitch: 5"}',
type=json.loads, default=None)
-# TODO: check this works!
parser.add_argument('--degradation-kwarg-json', metavar='json_file',
help='A file containing parameters as described in '
'--degradation-kwargs. If this file is given, '
@@ -242,7 +237,6 @@ def parse_args(args_input=None):


# Instantiate downloaders =================================================
-# TODO: make OVERWRITE this an arg for the script
OVERWRITE = None
ds_names = ARGS.datasets
if len(ds_names) == 1 and ds_names[0].lower() == 'none':
@@ -568,5 +562,4 @@ def parse_args(args_input=None):

print('\nTo reproduce this dataset again, run the script with argument '
f'--seed {seed}')
-#TODO: print('see the examples directory for baseline models using this data')
print(LOGO)
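
The --degradation-kwargs help text above defines its namespacing scheme: degradation name, double underscore, then the keyword argument name. A minimal sketch of that parsing, assuming the same scheme (illustrative only; the script's actual splitting logic is not shown in this diff):

import json

# The flag is declared with type=json.loads, so the CLI string arrives as a
# dict like this one (example adapted from the help text, with valid JSON):
kwargs = json.loads(
    '{"pitch_shift__distribution": "poisson", "pitch_shift__min_pitch": 5}'
)

# Hypothetical split of <degradation>__<kwarg> keys into per-degradation dicts
parsed = {}
for key, value in kwargs.items():
    deg_name, kwarg = key.split('__', 1)
    parsed.setdefault(deg_name, {})[kwarg] = value

print(parsed)  # {'pitch_shift': {'distribution': 'poisson', 'min_pitch': 5}}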
6 changes: 0 additions & 6 deletions mdtk/data_structures.py
@@ -230,7 +230,6 @@ def fix_overlapping_notes(df):
bad_note = np.append(bad_note, False) # last note always fine
df.loc[bad_note, 'dur'] = (next_note_on[bad_note[:-1]]
- df.loc[bad_note, 'onset'].values)
-# TODO: add an assertion to catch dur==0 and add a test
return df


@@ -246,7 +245,6 @@ def fix_overlapping_pitches(df):
bad_note = np.append(bad_note, False) # last note always fine
df.loc[bad_note, 'dur'] = (next_note_on[bad_note[:-1]]
- df.loc[bad_note, 'onset'].values)
-# TODO: add an assertion to catch dur==0 and add a test
return df


@@ -930,10 +928,6 @@ def __init__(self, note_df=None, csv_path=None,
self.note_df = note_df
# We do not assume that the supplied note_df is correctly formed,
# and simply bomb out if it is not
-# TODO: implement df methods to fix issues instead e.g. overlaps.
-# Copy code from read_note_csv. e.g.:
-# * reorder columns
-# * if all columns but track and no extra cols, assume 1 trk
if self.monophonic_tracks is not None:
make_monophonic(self.note_df, tracks=monophonic_tracks)
if self.max_note_len is not None:
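
Both removed comments in data_structures.py ask for an assertion catching dur == 0 after overlap truncation. A minimal sketch of that guard, assuming notes are sorted by onset and next_note_on holds each note's successor onset (the surrounding code defines next_note_on outside this hunk, so that definition is an assumption):

import numpy as np
import pandas as pd

def truncate_overlaps(df):
    # Truncate any note that sounds past the next onset, as in the diff
    # above, then assert the truncation never produced a zero-length note
    # (the check the removed TODO asks for).
    next_note_on = df['onset'].values[1:]  # assumption: successor onsets
    bad_note = df['onset'].values[:-1] + df['dur'].values[:-1] > next_note_on
    bad_note = np.append(bad_note, False)  # last note always fine
    df.loc[bad_note, 'dur'] = (next_note_on[bad_note[:-1]]
                               - df.loc[bad_note, 'onset'].values)
    assert (df['dur'] > 0).all(), 'overlap fix produced a zero-length note'
    return df

With two notes sharing an onset, the truncated duration is zero and the assertion fires, which is exactly the case the removed TODO wants a test for.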
13 changes: 1 addition & 12 deletions mdtk/downloaders.py
@@ -19,9 +19,6 @@


# Classes =====================================================================
-# TODO: make attributes useful to users standard e.g. beat-aligned=True/False
-# TODO: some things are likely to be important re preprocessing e.g. the unit
-# for the onset and duration of notes. Add these as attributes too.
class DataDownloader:
"""Base class for data downloaders"""
def __init__(self, cache_path=DEFAULT_CACHE_PATH):
@@ -60,20 +57,14 @@ def download_csv(self, output_path, cache_path=None, overwrite=None,
'implement the download_csv method.')


-# TODO: since these datasets already have CSV in the right format, we should
-# implement download_csv() methods to use in favour of the download_midi
-# this method would reformat csv to be correct (cols need renaming etc.)
-# TODO: handle conversion from quarters to ms - use tempo data, but use min/max
-# tempo values as some were a bit spurious

class PPDDSep2018Monophonic(DataDownloader):
"""Patterns for Preditction Development Dataset. Monophonic data only.
References
----------
https://www.music-ir.org/mirex/wiki/2019:Patterns_for_Prediction
"""
-# TODO: add 'sample_size', to allow only a small random sample of the
-# total midi files to be copied to the output
def __init__(self, cache_path=DEFAULT_CACHE_PATH,
sizes=['small', 'medium', 'large'], clean=False):
super().__init__(cache_path = cache_path)
@@ -136,8 +127,6 @@ class PPDDSep2018Polyphonic(PPDDSep2018Monophonic):
----------
https://www.music-ir.org/mirex/wiki/2019:Patterns_for_Prediction
"""
-# TODO: add 'sample_size', to allow only a small random sample of the
-# total midi files to be copied to the output
def __init__(self, cache_path=DEFAULT_CACHE_PATH,
sizes=['small', 'medium', 'large'], clean=False):
super().__init__(cache_path=cache_path, sizes=sizes, clean=clean)
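
Among the comments removed from downloaders.py is the quarters-to-milliseconds conversion with a tempo clamp. One quarter note at tempo_bpm lasts 60000 / tempo_bpm ms; a sketch of that arithmetic follows (the clamp bounds are illustrative guesses, not repository values):

def quarters_to_ms(quarters, tempo_bpm, min_bpm=40, max_bpm=250):
    # Clamp the tempo first, since the removed TODO notes that some of the
    # dataset's tempo values were spurious; these bounds are assumptions.
    tempo_bpm = min(max(tempo_bpm, min_bpm), max_bpm)
    return quarters * 60000.0 / tempo_bpm

print(quarters_to_ms(4, 120))  # 2000.0 -- one 4/4 bar at 120 bpm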
3 changes: 1 addition & 2 deletions mdtk/filesystem_utils.py
@@ -67,8 +67,7 @@ def make_directory(path, overwrite=None, verbose=False):


def extract_zip(zip_path, out_path, overwrite=None, verbose=False):
"""Convenience function to extract zip file to out_path.
TODO: make work for all types of zip files."""
"""Convenience function to extract zip file to out_path."""
if verbose:
print(f'Extracting {zip_path} to {out_path}')
dirname = os.path.splitext(os.path.basename(zip_path))[0]
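
The deleted half of extract_zip's docstring wanted support for more archive types. One possible generalisation (a sketch under that assumption, not the repository's implementation) sniffs the format instead of trusting the extension:

import tarfile
import zipfile

def extract_archive(archive_path, out_path):
    # Hypothetical drop-in for extract_zip: dispatch on the detected
    # format; tarfile covers .tar, .tar.gz, .tgz and .tar.bz2 archives.
    if zipfile.is_zipfile(archive_path):
        with zipfile.ZipFile(archive_path) as zf:
            zf.extractall(out_path)
    elif tarfile.is_tarfile(archive_path):
        with tarfile.open(archive_path) as tf:
            tf.extractall(out_path)
    else:
        raise ValueError(f'unrecognised archive format: {archive_path}')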
5 changes: 1 addition & 4 deletions mdtk/formatters.py
@@ -31,9 +31,7 @@ def diff_pd(df1, df2):
return pd.DataFrame({'from': changed_from, 'to': changed_to},
index=changed.index)

-# TODO: later can auto detect vocab from corpus if necessary
-# I'm doing things this way just for ability to change things
-# later with ease

class CommandVocab(object):
def __init__(self, min_pitch=MIN_PITCH_DEFAULT,
max_pitch=MAX_PITCH_DEFAULT,
@@ -351,7 +349,6 @@ def df_to_command_str(df, min_pitch=MIN_PITCH_DEFAULT, max_pitch=MAX_PITCH_DEFAULT
assert time_increment > 0, "time_increment must be positive."
assert max_time_shift > 0, "max_time_shift must be positive."

-# TODO: This rounding may result in notes of length 0.
note_off = df.loc[:, ['onset', 'pitch']]
note_off['onset'] = note_off['onset'] + df['dur']
note_off['cmd'] = note_off['pitch'].apply(lambda x: f'f{x}')
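
The comment removed from df_to_command_str warns that rounding onsets and durations to time_increment can collapse a note to zero length. A minimal sketch of one possible guard, assuming durations are in the same units as time_increment (illustrative; the actual fix is tracked in the issues listed in the commit message):

import pandas as pd

def clamp_short_notes(df, time_increment=40):
    # Quantise durations to the nearest time_increment, then force every
    # note to last at least one step so a note-off command never lands on
    # the same tick as its own note-on.
    df = df.copy()
    df['dur'] = (df['dur'] / time_increment).round() * time_increment
    df.loc[df['dur'] < time_increment, 'dur'] = time_increment
    return df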
2 changes: 0 additions & 2 deletions mdtk/pytorch_models.py
@@ -29,7 +29,6 @@ def __init__(self, vocab_size, embedding_dim, hidden_dim, output_size=2,
self.vocab_size = vocab_size

self.embedding = nn.Embedding(vocab_size, embedding_dim)
-# TODO: try getting batch_first to work on this model
self.lstm = nn.LSTM(embedding_dim, hidden_dim,
num_layers=num_lstm_layers)

@@ -49,7 +48,6 @@ def forward(self, batch, input_lengths=None):
device = batch.device
self.hidden = self.init_hidden(batch_size, device=device)
embeds = self.embedding(batch).permute(1, 0, 2)
-# TODO: try getting batch_first to work on this model
# embeds = self.embedding(batch)
outputs, (ht, ct) = self.lstm(embeds, self.hidden)
# ht is the last hidden state of the sequences
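
Both TODOs removed from pytorch_models.py concern batch_first. The change they propose is mechanical, sketched below with assumed dimensions: construct the LSTM with batch_first=True and drop the permute.

import torch
import torch.nn as nn

batch_size, seq_len, embedding_dim, hidden_dim = 4, 10, 8, 16
embeds = torch.randn(batch_size, seq_len, embedding_dim)

# As in the diff: nn.LSTM defaults to (seq_len, batch, features) input,
# hence the permute(1, 0, 2) on the embeddings.
lstm = nn.LSTM(embedding_dim, hidden_dim)
out, _ = lstm(embeds.permute(1, 0, 2))

# What the removed TODO proposes: accept (batch, seq_len, features) directly.
lstm_bf = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
out_bf, _ = lstm_bf(embeds)  # outputs come back batch-first as well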
11 changes: 0 additions & 11 deletions mdtk/pytorch_trainers.py
@@ -10,9 +10,6 @@



-# TODO: I don't like the fomatter being passed in here - would prefer these
-# Trainers to be more general except for the iteration method for which you
-# hardcode how to do the train/test iteration.
class BaseTrainer:
"""Provides methods to train pytorch models. Adapted from:
https://github.com/codertimo/BERT-pytorch/blob/master/bert_pytorch/trainer/pretrain.py"""
@@ -137,7 +134,6 @@ def save(self, file_path=None, epoch=None):
print(f"Model saved {output_path}")
return output_path

-# TODO: implement load method (for use with load from checkpoint)


class ErrorDetectionTrainer(BaseTrainer):
@@ -706,13 +702,6 @@ def iteration(self, epoch, data_loader, train=True, evaluate=False):
total_data_points += len(input_data)
for in_data, out_data, clean_data in \
zip(input_data, model_output, labels):
-# TODO: Only 1 of these calls is necessary. deg and clean
-# could conceivably be returned by the data loader.
-# N.B. Currently, the precise min and max pitch don't
-# matter here. The converter just treats them all the same,
-# corrects and warns if the range doesn't make sense.
-# However, if loading deg and clean from the original df,
-# using the correct min and max pitch will be important.
with warnings.catch_warnings():
warnings.simplefilter("ignore")
deg_df = self.formatter['model_to_df'](
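
One comment removed from pytorch_trainers.py asks for a load method to pair with BaseTrainer.save. A hedged sketch of that counterpart, assuming save writes the model's state_dict via torch.save (load_checkpoint is a hypothetical helper name, not the repository's API):

import torch

def load_checkpoint(model, file_path, device='cpu'):
    # Restore weights from a checkpoint written with
    # torch.save(model.state_dict(), path); map_location lets a
    # GPU-trained checkpoint load onto a CPU-only machine.
    state_dict = torch.load(file_path, map_location=device)
    model.load_state_dict(state_dict)
    print(f'Model loaded from {file_path}')
    return model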
29 changes: 0 additions & 29 deletions mdtk/tests/test_data_structures.py
@@ -445,18 +445,7 @@ def test_pianoroll_all_pitches():
assert (pianoroll == np.ones((1, 2, 128, 1), dtype='uint8')).all()


-# TODO: test all note_on occur with sounding

-# TODO: test all note_off occur with sounding

-# TODO: test all sounding begin note_on and end_note_off

-# TODO: test all methods in pianoroll and all attributes


# Composition class tests =====================================================
-# TODO: write import from csv tests

def test_composition_df_assertions():
"""Essentially the same tests as test_check_note_df"""
assertion = False
@@ -504,16 +493,6 @@ def test_composition_all_pitches():



-# TODO: reimplement this if and when we implement auto fix of note_df
-#def test_auto_sort_onset_and_pitch():
-# comp = Composition(note_df=note_df_2pitch_aligned, fix_note_df=True)
-# assert comp.note_df.equals(
-# note_df_2pitch_aligned
-# .sort_values(['onset', 'pitch'])
-# .reset_index(drop=True)
-# )


def test_not_ending_in_silence():
for df in ALL_VALID_DF.values():
comp = Composition(note_df=df)
@@ -555,19 +534,11 @@ def test_composition_read_csv():
comp.plot()
comp.synthesize()

-def test_csv_and_df_imports_same():
-# TODO: write test that imports from all csvs and checks same as
-# importing from df
-pass


-# TODO: Check if anything alters input data - loop over all functions and
-# methods


# Cleanup =====================================================================
-# TODO: This isn't technichally a test...should probably be some other function
-# look up the proper way to do this.
def test_remove_csvs():
for csv in ALL_CSV:
os.remove(csv)
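
The final removed TODO observes that test_remove_csvs is cleanup dressed up as a test. The idiomatic pytest replacement is an autouse fixture with teardown, sketched here (ALL_CSV as defined in the test module above):

import os
import pytest

@pytest.fixture(scope='module', autouse=True)
def cleanup_csvs():
    # Yield first so the module's tests run, then delete the generated
    # CSV files in teardown instead of inside a trailing "test".
    yield
    for csv in ALL_CSV:
        os.remove(csv)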
