Skip to content

Commit

Permalink
Audio PR - rocAL Python changes for Audio Decoder (ROCm#120)
Browse files Browse the repository at this point in the history
* Audio Decoder Python Changes

---------

Co-authored-by: root <[email protected]>
Co-authored-by: Swetha B S <[email protected]>
Co-authored-by: swetha097 <[email protected]>
Co-authored-by: swetha097 <[email protected]>
Co-authored-by: Swetha B S <>
Co-authored-by: SundarRajan28 <[email protected]>
  • Loading branch information
6 people authored Apr 23, 2024
1 parent 9a94efe commit 770e3d0
Show file tree
Hide file tree
Showing 10 changed files with 316 additions and 15 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
* Packages - dev & tests
* Support for audio loader and decoder, which uses libsndfile library to decode wav files
* C++ rocAL audio unit test and python script to run and compare the outputs
* Python support for audio decoders
* PyTorch iterator for audio
* Python audio unit test, and support to verify outputs

### Optimizations

Expand Down
6 changes: 3 additions & 3 deletions rocAL/include/loaders/circular_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ struct DecodedDataInfo {
std::vector<uint32_t> _roi_height;
std::vector<uint32_t> _original_width;
std::vector<uint32_t> _original_height;
std::vector<uint32_t> _original_audio_samples; //! Amplitude of an audio signal at a specific point in time
std::vector<uint32_t> _original_audio_channels; //! Number of audio channels in an audio signal
std::vector<float> _original_audio_sample_rates; //! The number of samples of audio carried per second
std::vector<uint32_t> _audio_samples; //! Amplitude of an audio signal at a specific point in time
std::vector<uint32_t> _audio_channels; //! Number of audio channels in an audio signal
std::vector<float> _audio_sample_rates; //! The number of samples of audio carried per second
};

struct CropImageInfo {
Expand Down
10 changes: 5 additions & 5 deletions rocAL/source/loaders/audio/audio_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ void AudioLoader::initialize(ReaderConfig reader_cfg, DecoderConfig decoder_cfg,
_max_decoded_samples = _output_tensor->info().max_shape().at(0);
_max_decoded_channels = _output_tensor->info().max_shape().at(1);
_decoded_audio_info._data_names.resize(_batch_size);
_decoded_audio_info._original_audio_samples.resize(_batch_size);
_decoded_audio_info._original_audio_channels.resize(_batch_size);
_decoded_audio_info._original_audio_sample_rates.resize(_batch_size);
_decoded_audio_info._audio_samples.resize(_batch_size);
_decoded_audio_info._audio_channels.resize(_batch_size);
_decoded_audio_info._audio_sample_rates.resize(_batch_size);
_circ_buff.init(_mem_type, _output_mem_size, _prefetch_queue_depth);
_is_initialized = true;
LOG("Loader module initialized");
Expand Down Expand Up @@ -229,8 +229,8 @@ AudioLoader::update_output_audio() {
return LoaderModuleStatus::OK;
_output_decoded_audio_info = _circ_buff.get_decoded_data_info();
_output_names = _output_decoded_audio_info._data_names;
_output_tensor->update_tensor_roi(_output_decoded_audio_info._original_audio_samples, _output_decoded_audio_info._original_audio_channels);
_output_tensor->update_audio_tensor_sample_rate(_output_decoded_audio_info._original_audio_sample_rates);
_output_tensor->update_tensor_roi(_output_decoded_audio_info._audio_samples, _output_decoded_audio_info._audio_channels);
_output_tensor->update_audio_tensor_sample_rate(_output_decoded_audio_info._audio_sample_rates);
_circ_buff.pop();
if (!_loop)
_remaining_audio_count -= _batch_size;
Expand Down
6 changes: 3 additions & 3 deletions rocAL/source/loaders/audio/audio_read_and_decode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ AudioReadAndDecode::Load(float *audio_buffer,
}
for (size_t i = 0; i < _batch_size; i++) {
audio_info._data_names[i] = _audio_meta_info[i].file_name;
audio_info._original_audio_samples[i] = _audio_meta_info[i].samples;
audio_info._original_audio_channels[i] = _audio_meta_info[i].channels;
audio_info._original_audio_sample_rates[i] = _audio_meta_info[i].sample_rate;
audio_info._audio_samples[i] = _audio_meta_info[i].samples;
audio_info._audio_channels[i] = _audio_meta_info[i].channels;
audio_info._audio_sample_rates[i] = _audio_meta_info[i].sample_rate;
}
}
_decode_time.end(); // Debug timing
Expand Down
25 changes: 25 additions & 0 deletions rocAL_pybind/amd/rocal/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,3 +402,28 @@ def image_slice(*inputs, file_root='', path='', annotations_file='', shard_id=0,
image_decoder_slice = b.fusedDecoderCropShard(
Pipeline._current_pipeline._handle, *(kwargs_pybind.values()))
return (image_decoder_slice)

def audio(*inputs, file_root='', file_list_path='', shard_id=0, num_shards=1, random_shuffle=False, downmix=False, stick_to_shard=False, shard_size=-1):
    """!Adds a single-shard wav audio decoder to the current pipeline.

        @param inputs            Positional inputs; accepted but not forwarded to the decoder binding.
        @param file_root         Folder path to the audio data; forwarded as the decoder's source path.
        @param file_list_path    (for future use) Path to the text file containing the list of files and the labels. Not forwarded yet.
        @param shard_id          Shard ID for parallel processing.
        @param num_shards        Total number of shards for parallel processing.
        @param random_shuffle    Whether to shuffle audio samples randomly.
        @param downmix           Converts the audio data to a single channel when enabled.
        @param stick_to_shard    The reader sticks to the data for its corresponding shard when enabled. Accepted but not forwarded yet.
        @param shard_size        Number of files in an epoch of a particular shard. Accepted but not forwarded yet.

        @return    Decoded audio, as returned by the audioDecoderSingleShard binding.
    """
    # The binding takes these values positionally, so the insertion order below
    # is significant. `is_output` and `loop` are always passed as False.
    decoder_args = dict(
        source_path=file_root,
        shard_id=shard_id,
        num_shards=num_shards,
        is_output=False,
        shuffle=random_shuffle,
        loop=False,
        downmix=downmix,
    )
    positional_values = tuple(decoder_args.values())
    return b.audioDecoderSingleShard(Pipeline._current_pipeline._handle, *positional_values)
70 changes: 66 additions & 4 deletions rocAL_pybind/amd/rocal/plugin/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def __del__(self):


class ROCALClassificationIterator(ROCALGenericIterator):
"""!ROCAL iterator for classification tasks for PyTorch. It returns 2 outputs
"""!rocAL iterator for classification tasks for PyTorch. It returns 2 outputs
(data and label) in the form of PyTorch's Tensors.
Calling
Expand All @@ -237,15 +237,15 @@ class ROCALClassificationIterator(ROCALGenericIterator):
ROCALGenericIterator(pipelines, ["data", "label"], size)
Please keep in mind that Tensors returned by the iterator are
still owned by ROCAL. They are valid till the next iterator call.
still owned by rocAL. They are valid till the next iterator call.
If the content needs to be preserved please copy it to another tensor.
pipelines (list of amd.rocal.pipeline.Pipeline) List of pipelines to use
size (int) Number of samples in the epoch (Usually the size of the dataset).
auto_reset (bool, optional, default = False) Whether the iterator resets itself for the next epoch or it requires reset() to be called separately.
fill_last_batch (bool, optional, default = True) Whether to fill the last batch with data up to 'self.batch_size'. The iterator would return the first integer multiple of self._num_gpus * self.batch_size entries which exceeds 'size'. Setting this flag to False will cause the iterator to return exactly 'size' entries.
dynamic_shape (bool, optional, default = False) Whether the shape of the output of the ROCAL pipeline can change during execution. If True, the pytorch tensor will be resized accordingly if the shape of ROCAL returned tensors changes during execution. If False, the iterator will fail in case of change.
last_batch_padded (bool, optional, default = False) Whether the last batch provided by ROCAL is padded with the last sample or it just wraps up. In the conjunction with fill_last_batch it tells if the iterator returning last batch with data only partially filled with data from the current epoch is dropping padding samples or samples from the next epoch. If set to False next epoch will end sooner as data from it was consumed but dropped. If set to True next epoch would be the same length as the first one.
dynamic_shape (bool, optional, default = False) Whether the shape of the output of the rocAL pipeline can change during execution. If True, the pytorch tensor will be resized accordingly if the shape of rocAL returned tensors changes during execution. If False, the iterator will fail in case of change.
last_batch_padded (bool, optional, default = False) Whether the last batch provided by rocAL is padded with the last sample or it just wraps up. In the conjunction with fill_last_batch it tells if the iterator returning last batch with data only partially filled with data from the current epoch is dropping padding samples or samples from the next epoch. If set to False next epoch will end sooner as data from it was consumed but dropped. If set to True next epoch would be the same length as the first one.
Example
-------
Expand All @@ -270,6 +270,68 @@ def __init__(self,
super(ROCALClassificationIterator, self).__init__(pipe, tensor_layout=pipe._tensor_layout, tensor_dtype=pipe._tensor_dtype,
multiplier=pipe._multiplier, offset=pipe._offset, display=display, device=device, device_id=device_id)

class ROCALAudioIterator(object):
    """! rocAL iterator for audio tasks for PyTorch.

    Each call to next() runs the pipeline once and returns a tuple
    (outputs, labels, roi):
      - outputs: list with one torch tensor per pipeline output tensor
      - labels:  torch tensor with `batch_size` elements
      - roi:     torch tensor holding, for every sample, the second half of the
                 ROI values reported for the last pipeline output

    NOTE(review): the labels tensor is allocated with torch.empty() and is never
    written by this class, so its contents are uninitialized. Populate it from
    the pipeline's label API before relying on it.

        @param pipeline        The rocAL pipeline to use for processing data.
        @param tensor_dtype    Data type of the output tensors. Accepted but currently unused; the dtype is taken from each pipeline output tensor.
        @param size            Number of samples in the epoch (Usually the size of the dataset). Accepted but currently unused.
        @param auto_reset      Whether the iterator resets itself for the next epoch or it requires reset() to be called separately. Accepted but currently unused; call reset() explicitly.
        @param device          The device to use for processing - CPU / GPU. Any value other than "cpu" allocates on torch's 'cuda' device.
        @param device_id       The ID of the device to use
    """
    def __init__(self, pipeline, tensor_dtype = types.FLOAT, size = -1, auto_reset = False, device = "cpu", device_id = 0):
        self.loader = pipeline
        self.device = device
        self.device_id = device_id
        self.output = None
        # Captured once at construction; __len__ keeps returning this value.
        self.iterator_length = b.getRemainingImages(self.loader._handle)  # To change the name of getRemainingImages to getRemainingSamples in upcoming PRs
        self.max_shape = None
        self.batch_size = self.loader._batch_size
        self.output_list = None
        # One label slot per sample in the batch.
        self.labels_size = self.batch_size
        self.output_memory_type = self.loader._output_memory_type

    def next(self):
        """!Python 2 style alias for __next__()."""
        return self.__next__()

    def __next__(self):
        """!Runs the pipeline once and returns (outputs, labels, roi).

        Raises StopIteration when rocal_run() returns a non-zero status.
        """
        if self.loader.rocal_run() != 0:
            raise StopIteration
        self.output_tensor_list = self.loader.get_output_tensors()
        # Output list used to store pipeline outputs - can support multiple augmentation outputs
        self.output_list = []
        for output_tensor in self.output_tensor_list:
            dimensions = output_tensor.dimensions()
            self.num_roi_dims = output_tensor.roi_dims_size()
            # Two values are stored per ROI dimension for every sample.
            self.roi_array = np.zeros(self.batch_size * self.num_roi_dims * 2, dtype=np.int32)
            output_tensor.copy_roi(self.roi_array)
            torch_dtype = getattr(torch, output_tensor.dtype())
            if self.device == "cpu":
                output = torch.empty(dimensions, dtype=torch_dtype)
                self.labels_tensor = torch.empty(self.labels_size, dtype=torch_dtype)
            else:
                torch_gpu_device = torch.device('cuda', self.device_id)
                output = torch.empty(dimensions, dtype=torch_dtype, device=torch_gpu_device)
                self.labels_tensor = torch.empty(self.labels_size, dtype=torch_dtype, device=torch_gpu_device)

            # Fill the freshly allocated torch tensor through its raw data pointer.
            output_tensor.copy_data(ctypes.c_void_p(output.data_ptr()), self.output_memory_type)
            self.output_list.append(output)

        # Derive the row width from num_roi_dims instead of hard-coding 4, so the
        # reshape matches the buffer allocated above for any ROI dimensionality.
        # For 2-D ROIs this is exactly reshape(batch_size, 4)[..., 2:4].
        # Presumably the second half holds the per-dimension extents - TODO confirm.
        roi_values_per_sample = self.num_roi_dims * 2
        roi_per_sample = self.roi_array.reshape(self.batch_size, roi_values_per_sample)
        return self.output_list, self.labels_tensor, torch.tensor(roi_per_sample[..., self.num_roi_dims:roi_values_per_sample])

    def reset(self):
        """!Resets the pipeline's loaders."""
        b.rocalResetLoaders(self.loader._handle)

    def __iter__(self):
        return self

    def __len__(self):
        """!Returns the remaining-sample count recorded when the iterator was created."""
        return self.iterator_length

    def __del__(self):
        # Releases the underlying rocAL pipeline handle when the iterator is destroyed.
        b.rocalRelease(self.loader._handle)

def draw_patches(img, idx, bboxes):
"""!Writes images to disk as a PNG file.
Expand Down
4 changes: 4 additions & 0 deletions rocAL_pybind/rocal_pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,10 @@ PYBIND11_MODULE(rocal_pybind, m) {
py::return_value_policy::reference);
m.def("externalSourceFeedInput", &wrapperRocalExternalSourceFeedInput,
py::return_value_policy::reference);
m.def("audioDecoderSingleShard", &rocalAudioFileSourceSingleShard, "Reads file from the source given and decodes it",
py::return_value_policy::reference);
m.def("audioDecoder", &rocalAudioFileSource, "Reads file from the source given and decodes it",
py::return_value_policy::reference);
m.def("rocalResetLoaders", &rocalResetLoaders);
m.def("videoMetaDataReader", &rocalCreateVideoLabelReader, py::return_value_policy::reference);
// rocal_api_augmentation.h
Expand Down
21 changes: 21 additions & 0 deletions tests/python_api/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,24 @@ This test runs a pipeline making use of the external source reader in 3 differen
```bash
python3 external_source_reader.py
```

## Audio Unit Test

To run the audio unit test with all test cases, follow the steps below:

```bash
export ROCAL_DATA_PATH=<Absolute_path_to_MIVisionX-data>
```
To run the audio unit test and verify the correctness of the outputs:

```bash
python3 audio_unit_test.py
```
To pass the audio data path and batch size, and to run a particular test case, use the following command:

```bash
python3 audio_unit_test.py --audio_path=<path_to_data> --test_case <case(0-1)> --batch-size <batch_size>
```

**Available Test Cases**
* Case 0 - Audio Decoder
Loading

0 comments on commit 770e3d0

Please sign in to comment.