diff --git a/test/test_training.py b/test/test_training.py index 8a6b3a8..3ea9f32 100644 --- a/test/test_training.py +++ b/test/test_training.py @@ -80,7 +80,7 @@ def test_training(): log_cuda_info() log_model_parameters(model) - dataset = MIDIDataset(Path('test_files'), max_seq_len=128, min_seq_len=64, padding_token=0) + dataset = MIDIDataset(Path('test', 'test_files'), max_seq_len=128, min_seq_len=64, padding_token=0) subset_train, subset_valid = create_subsets(dataset, [0.4]) dataloader_train = DataLoader(subset_train, batch_size=8, collate_fn=collate_ar) dataloader_valid = DataLoader(subset_valid, batch_size=8, collate_fn=collate_ar)