From dfad63fe5dde2b4e6666020d970b9267e2bf426d Mon Sep 17 00:00:00 2001 From: Han Wang Date: Tue, 11 Jun 2024 23:43:52 +0200 Subject: [PATCH] Datasets updated. --- dpmhm/cli.py | 4 +- dpmhm/datasets/__init__.py | 27 +++--- dpmhm/datasets/cwru/cwru.py | 4 +- dpmhm/datasets/dirg/dirg.py | 7 +- dpmhm/datasets/femto/femto.py | 4 - dpmhm/datasets/fraunhofer151/fraunhofer151.py | 4 +- dpmhm/datasets/fraunhofer205/fraunhofer205.py | 6 +- dpmhm/datasets/ims/ims.py | 13 ++- dpmhm/datasets/paderborn/paderborn.py | 4 +- .../{untested => }/phmap2021/__init__.py | 0 .../{untested => }/phmap2021/phmap2021.py | 84 ++++++++++++------- .../datasets/{untested => }/seuc/__init__.py | 0 dpmhm/datasets/{untested => }/seuc/secu.bib | 0 dpmhm/datasets/{untested => }/seuc/seuc.py | 48 ++++++----- dpmhm/datasets/transformer.py | 34 +++++--- .../untested/phm2022/Exported Items.bib | 0 dpmhm/datasets/untested/phm2022/__init__.py | 1 + dpmhm/datasets/utils.py | 24 ++++++ dpmhm/datasets/xjtu/xjtu.py | 13 ++- dpmhm/models/sl/clae.py | 2 +- dpmhm/models/sl/vggish.py | 2 +- 21 files changed, 182 insertions(+), 99 deletions(-) rename dpmhm/datasets/{untested => }/phmap2021/__init__.py (100%) rename dpmhm/datasets/{untested => }/phmap2021/phmap2021.py (67%) rename dpmhm/datasets/{untested => }/seuc/__init__.py (100%) rename dpmhm/datasets/{untested => }/seuc/secu.bib (100%) rename dpmhm/datasets/{untested => }/seuc/seuc.py (85%) delete mode 100644 dpmhm/datasets/untested/phm2022/Exported Items.bib create mode 100644 dpmhm/datasets/untested/phm2022/__init__.py diff --git a/dpmhm/cli.py b/dpmhm/cli.py index 9c7f5d6..53d1e72 100644 --- a/dpmhm/cli.py +++ b/dpmhm/cli.py @@ -21,8 +21,8 @@ 'mafaulda': 'Mafaulda', 'ottawa': 'Ottawa', 'paderborn': 'Paderborn', - # 'phmap2021': 'Phmap2021', - # 'seuc': 'SEUC', + 'phmap2021': 'Phmap2021', + 'seuc': 'SEUC', 'xjtu': 'XJTU' } diff --git a/dpmhm/datasets/__init__.py b/dpmhm/datasets/__init__.py index a430870..4afad8f 100644 --- a/dpmhm/datasets/__init__.py +++ b/dpmhm/datasets/__init__.py @@ -18,19 +18,6 @@ from .. 
import cli -# from .cwru import CWRU -# from .dcase import DCASE2021 -# from .seuc import SEUC -# from .mfpt import MFPT -# from .dirg import DIRG -# from .mafaulda import MAFAULDA -# from .ims import IMS -# from .ottawa import Ottawa -# from .paderborn import Paderborn -# from .femto import FEMTO -# from .fraunhofer import Fraunhofer205, Fraunhofer151 -# from .phmdc import Phmap2021 - # Data type _FLOAT16 = np.float16 _FLOAT32 = np.float32 @@ -163,4 +150,16 @@ def extract_zenodo_urls(url:str) -> list: urls.append(header+'/'+s.split('?download=1')[0]) except: pass - return urls \ No newline at end of file + return urls + + +def load_compact(ds_name:str, split:str|list, **kwargs): + from .transformer import DatasetCompactor + + ds0 = tfds.load(ds_name, split=split) + + compactor = DatasetCompactor( + ds0, **kwargs + ) + + return compactor \ No newline at end of file diff --git a/dpmhm/datasets/cwru/cwru.py b/dpmhm/datasets/cwru/cwru.py index c307be1..482e140 100644 --- a/dpmhm/datasets/cwru/cwru.py +++ b/dpmhm/datasets/cwru/cwru.py @@ -100,10 +100,8 @@ # URL to the zip file # _DATA_URLS = ('https://engineering.case.edu/sites/default/files/'+_METAINFO['FileName']).tolist() -# _DATA_URLS = extract_zenodo_urls('https://sandbox.zenodo.org/record/1183527/) _DATA_URLS = [ - 'https://sandbox.zenodo.org/record/1183527/files/cwru.zip' - # 'https://zenodo.org/api/records/7457149/draft/files/cwru.zip/content' + 'https://zenodo.org/records/7457149/files/cwru.zip?download=1' ] diff --git a/dpmhm/datasets/dirg/dirg.py b/dpmhm/datasets/dirg/dirg.py index b541881..cfd8844 100644 --- a/dpmhm/datasets/dirg/dirg.py +++ b/dpmhm/datasets/dirg/dirg.py @@ -19,7 +19,7 @@ - Format: Matlab - Channels: 6, for two accelerometers in the x-y-z axis - Split: 'Variable speed and load' test, 'Endurance' test -- Sampling rate: 51200 Hz for `Variable speed and load` test and 102400 Hz for `Endurance` test +- Sampling rate: 51200 Hz (51.2 kHz) for `Variable speed and load` test and 102400 Hz (102.4 kHz) for `Endurance` test - Recording duration: 10 seconds for `Variable speed and load` test and 8 seconds for `Endurance` test - Label: normal and faulty - Size: ~ 3Gb unzipped @@ -49,6 +49,7 @@ Notes ===== - Conversion: load is converted from mV to N using the sensitivity factor 0.499 mV/N +- Only the bearing `B1` contains faults so `B2` and `B3` are not used. - The endurance test was originally with the fault type 4A but in the processed data we marked its label as "unknown". 
""" @@ -97,7 +98,9 @@ '6A': ('Roller', 150), } -_DATA_URLS = ['https://sandbox.zenodo.org/record/1183545/files/dirg.zip'] +_DATA_URLS = [ + 'https://zenodo.org/records/11394418/files/dirg.zip?download=1' + ] class DIRG(tfds.core.GeneratorBasedBuilder): diff --git a/dpmhm/datasets/femto/femto.py b/dpmhm/datasets/femto/femto.py index 941b765..68d99b2 100644 --- a/dpmhm/datasets/femto/femto.py +++ b/dpmhm/datasets/femto/femto.py @@ -117,10 +117,6 @@ 'https://github.com/Lucky-Loek/ieee-phm-2012-data-challenge-dataset/archive/refs/heads/master.zip' ] -# _DATA_URLS = [ -# 'https://sandbox.zenodo.org/record/1183585/files/femto.zip' -# ] - # Date of experiment _DATE = { 'Bearing1_1': datetime(2010,12,1), diff --git a/dpmhm/datasets/fraunhofer151/fraunhofer151.py b/dpmhm/datasets/fraunhofer151/fraunhofer151.py index 7fdf39e..71bc6f6 100644 --- a/dpmhm/datasets/fraunhofer151/fraunhofer151.py +++ b/dpmhm/datasets/fraunhofer151/fraunhofer151.py @@ -88,7 +88,9 @@ } """ -_DATA_URLS = ['https://fordatis.fraunhofer.de/bitstream/fordatis/151.2/1/fraunhofer_eas_dataset_for_unbalance_detection_v1.zip'] +_DATA_URLS = [ + 'https://fordatis.fraunhofer.de/bitstream/fordatis/151.2/1/fraunhofer_eas_dataset_for_unbalance_detection_v1.zip' + ] _RADIUS = {'0': 0., '1': 14., '2': 18.5, '3':23., '4':23.} diff --git a/dpmhm/datasets/fraunhofer205/fraunhofer205.py b/dpmhm/datasets/fraunhofer205/fraunhofer205.py index 7c88574..4232133 100644 --- a/dpmhm/datasets/fraunhofer205/fraunhofer205.py +++ b/dpmhm/datasets/fraunhofer205/fraunhofer205.py @@ -102,9 +102,11 @@ _COMPONENT = ['Ball', 'InnerRace', 'OuterRace', 'None'] -_METAINFO = pd.read_csv(Path(__file__).parent/'metainfo.csv', index_col=0) +_METAINFO = pd.read_csv(Path(__file__).parent/'metainfo.csv', index_col=0, keep_default_na=False) -_DATA_URLS = ['https://fordatis.fraunhofer.de/bitstream/fordatis/205/1/fraunhofer_iis_eas_dataset_vibrations_acoustic_emissions_of_drive_train_v1.zip'] +_DATA_URLS = [ + 'https://fordatis.fraunhofer.de/bitstream/fordatis/205/1/fraunhofer_iis_eas_dataset_vibrations_acoustic_emissions_of_drive_train_v1.zip' + ] class Fraunhofer205(tfds.core.GeneratorBasedBuilder): diff --git a/dpmhm/datasets/ims/ims.py b/dpmhm/datasets/ims/ims.py index 9204add..9d197e2 100644 --- a/dpmhm/datasets/ims/ims.py +++ b/dpmhm/datasets/ims/ims.py @@ -66,7 +66,10 @@ } # _DATA_URLS = 'https://phm-datasets.s3.amazonaws.com/NASA/4.+Bearings.zip' -_DATA_URLS = ['https://sandbox.zenodo.org/record/1184320/files/ims.zip'] +_DATA_URLS = [ + # 'https://sandbox.zenodo.org/record/1184320/files/ims.zip' + 'https://zenodo.org/records/11545355/files/ims.zip?download=1' + ] _CITATION = """ - Hai Qiu, Jay Lee, Jing Lin. 
“Wavelet Filter-based Weak Signature Detection Method and its Application on Roller Bearing Prognostics.” Journal of Sound and Vibration 289 (2006) 1066-1090 @@ -115,9 +118,11 @@ def _info(self) -> tfds.core.DatasetInfo: def _split_generators(self, dl_manager: tfds.download.DownloadManager): def _get_split_dict(datadir): return { - 'dataset1': (datadir/'1st_test').glob('*'), - 'dataset2': (datadir/'2nd_test').glob('*'), - 'dataset3': (datadir/'3rd_test').glob('*'), + 'dataset1': next(datadir.rglob('1st_test')).glob('*'), + 'dataset2': next(datadir.rglob('2nd_test')).glob('*'), + 'dataset3': next(datadir.rglob('3rd_test')).glob('*'), + # 'dataset2': (datadir/'2nd_test').glob('*'), + # 'dataset3': (datadir/'3rd_test').glob('*'), } if dl_manager._manual_dir.exists(): # prefer to use manually downloaded data diff --git a/dpmhm/datasets/paderborn/paderborn.py b/dpmhm/datasets/paderborn/paderborn.py index 7c6a2ac..43f2130 100644 --- a/dpmhm/datasets/paderborn/paderborn.py +++ b/dpmhm/datasets/paderborn/paderborn.py @@ -102,7 +102,9 @@ Christian Lessmeier et al., KAt-DataCenter: mb.uni-paderborn.de/kat/datacenter, Chair of Design and Drive Technology, Paderborn University. """ -_METAINFO = pd.read_csv(Path(__file__).parent / 'metainfo.csv', index_col=0) # use 'Bearing Code' as index +# use 'Bearing Code' as index +# same as CWRU, use `keep_default_na` to preserve `None` as string +_METAINFO = pd.read_csv(Path(__file__).parent / 'metainfo.csv', index_col=0, keep_default_na=False) # _DATA_URLS = ('http://groups.uni-paderborn.de/kat/BearingDataCenter/' + _METAINFO.index+'.rar').tolist() diff --git a/dpmhm/datasets/untested/phmap2021/__init__.py b/dpmhm/datasets/phmap2021/__init__.py similarity index 100% rename from dpmhm/datasets/untested/phmap2021/__init__.py rename to dpmhm/datasets/phmap2021/__init__.py diff --git a/dpmhm/datasets/untested/phmap2021/phmap2021.py b/dpmhm/datasets/phmap2021/phmap2021.py similarity index 67% rename from dpmhm/datasets/untested/phmap2021/phmap2021.py rename to dpmhm/datasets/phmap2021/phmap2021.py index 0e6277f..06bec92 100644 --- a/dpmhm/datasets/untested/phmap2021/phmap2021.py +++ b/dpmhm/datasets/phmap2021/phmap2021.py @@ -73,7 +73,10 @@ 'Bearing': ['train_1st_Bearing.csv', 'train_2nd_Bearing.csv'] } -_DATA_URLS = ['https://sandbox.zenodo.org/record/1184362/files/phmap.zip'] +_DATA_URLS = [ + # 'https://sandbox.zenodo.org/record/1184362/files/phmap.zip' + 'https://zenodo.org/records/11546285/files/phmap.zip?download=1' + ] class Phmap2021(tfds.core.GeneratorBasedBuilder): @@ -119,31 +122,54 @@ def _split_generators(self, dl_manager: tfds.download.DownloadManager): 'train': self._generate_examples(datadir), } - def _generate_examples(self, path): - for sp, fnames in _SPLIT_PATH_MATCH.items(): - for fn in fnames: - fp = path / fn - - _signal = pd.read_csv(fp, index_col=0).T.values.astype(_DTYPE.as_numpy_dtype) - - metadata = { - 'Label': sp, - # 'OriginalSplit': sp, - 'FileName': fp.name, - 'Dataset': 'PHMAP2021', - } - - yield hash(frozenset(metadata.items())), { - 'signal': {'vibration': _signal}, - # 'label': sp, - 'sampling_rate': 10544, - 'metadata': metadata - } - - @staticmethod - def get_references(): - try: - with open(Path(__file__).parent / 'Exported Items.bib') as fp: - return fp.read() - except: - pass + def _split_generators(self, dl_manager: tfds.download.DownloadManager): + def _get_split_dict(datadir): + # This doesn't work: + # return {sp: (datadir/fn).rglob('*.csv') for sp, fn in _SPLIT_PATH_MATCH.items()} + return { + 'train': 
datadir.rglob('*.csv'), + } + + if dl_manager._manual_dir.exists(): # prefer to use manually downloaded data + datadir = Path(dl_manager._manual_dir) + elif dl_manager._extract_dir.exists(): # automatically downloaded & extracted data + datadir = Path(dl_manager._extract_dir) + # elif dl_manager._download_dir.exists(): # automatically downloaded data + # datadir = Path(dl_manager._download_dir) + # tfds.download.iter_archive(fp, tfds.download.ExtractMethod.ZIP) + else: + raise FileNotFoundError() + + return {sp: self._generate_examples(files) for sp, files in _get_split_dict(datadir).items()} + + # def _generate_examples(self, path): + # for sp, fnames in _SPLIT_PATH_MATCH.items(): + def _generate_examples(self, files): + for fp in files: + # for fn in fnames: + # fp = path / fn + + _signal = pd.read_csv(fp, index_col=0).T.values.astype(_DTYPE) + sp = fp.stem.split('_')[-1].capitalize() + + metadata = { + 'Label': sp, + # 'OriginalSplit': sp, + 'FileName': fp.name, + 'Dataset': 'PHMAP2021', + } + + yield hash(frozenset(metadata.items())), { + 'signal': {'vibration': _signal}, + # 'label': sp, + 'sampling_rate': 10544, + 'metadata': metadata + } + + # @staticmethod + # def get_references(): + # try: + # with open(Path(__file__).parent / 'Exported Items.bib') as fp: + # return fp.read() + # except: + # pass diff --git a/dpmhm/datasets/untested/seuc/__init__.py b/dpmhm/datasets/seuc/__init__.py similarity index 100% rename from dpmhm/datasets/untested/seuc/__init__.py rename to dpmhm/datasets/seuc/__init__.py diff --git a/dpmhm/datasets/untested/seuc/secu.bib b/dpmhm/datasets/seuc/secu.bib similarity index 100% rename from dpmhm/datasets/untested/seuc/secu.bib rename to dpmhm/datasets/seuc/secu.bib diff --git a/dpmhm/datasets/untested/seuc/seuc.py b/dpmhm/datasets/seuc/seuc.py similarity index 85% rename from dpmhm/datasets/untested/seuc/seuc.py rename to dpmhm/datasets/seuc/seuc.py index a82fad0..02d1b6d 100644 --- a/dpmhm/datasets/untested/seuc/seuc.py +++ b/dpmhm/datasets/seuc/seuc.py @@ -67,10 +67,7 @@ ``` """ -# import os -# import pathlib -# import itertools -# import json +from pathlib import Path import numpy as np import tensorflow as tf import tensorflow_datasets as tfds @@ -89,7 +86,9 @@ doi={10.1109/TII.2018.2864759}} """ -_DATA_URLs = 'https://github.com/cathysiyu/Mechanical-datasets/archive/refs/heads/master.zip' +_DATA_URLS = [ + 'https://github.com/cathysiyu/Mechanical-datasets/archive/refs/heads/master.zip' +] # Components of fault _FAULT_GEARBOX = ['Chipped', 'Missing', 'Root', 'Surface'] @@ -162,22 +161,27 @@ def _fname_parser(cls, fname): return _component, _load def _split_generators(self, dl_manager: tfds.download.DownloadManager): + def _get_split_dict(datadir): + return { + 'gearbox': next(datadir.rglob('gearset')).rglob('*.csv'), + 'bearing': next(datadir.rglob('bearingset')).rglob('*.csv'), + } + if dl_manager._manual_dir.exists(): # prefer to use manually downloaded data - datadir = dl_manager._manual_dir / 'gearbox' - else: # automatically download data - datadir = list(dl_manager.download_and_extract(_DATA_URLs).iterdir())[0] / 'gearbox' - # print(datadir) - - return { - # Use the original splits - 'gearbox': self._generate_examples(datadir/'gearset'), - 'bearing': self._generate_examples(datadir/'bearingset'), - # 'train': self._generate_examples(datadir) # this will rewrite on precedent splits - } - - def _generate_examples(self, path): - # !! 
Recursive glob `path.rglob` may not behave as expected - for fp in path.glob('*.csv'): + datadir = Path(dl_manager._manual_dir) + elif dl_manager._extract_dir.exists(): # automatically downloaded & extracted data + datadir = Path(dl_manager._extract_dir) + # elif dl_manager._download_dir.exists(): # automatically downloaded data + # datadir = Path(dl_manager._download_dir) + # tfds.download.iter_archive(fp, tfds.download.ExtractMethod.ZIP) + else: + raise FileNotFoundError() + + return {sp: self._generate_examples(files) for sp, files in _get_split_dict(datadir).items()} + + + def _generate_examples(self, files): + for fp in files: _component, _load = self._fname_parser(fp.name) # try: # df = pd.read_csv(fp,skiprows=15, sep='\t').iloc[:,:-1] @@ -187,11 +191,11 @@ def _generate_examples(self, path): # df = pd.read_csv(fp,skiprows=15, sep=',').iloc[:,:-1] # if df.shape[1] != 8: # raise Exception - df = pd.read_csv(fp,skiprows=15, sep=None, engine='python').iloc[:,:-1] + df = pd.read_csv(fp, skiprows=15, sep=None, engine='python').iloc[:,:-1] if df.shape[1] != 8: raise Exception() - _signal = df.T.values.astype(_DTYPE.as_numpy_dtype) # strangely, df.values.T will give a tuple + _signal = df.T.values.astype(_DTYPE) # strangely, df.values.T will give a tuple metadata = { 'LoadForce': _load, diff --git a/dpmhm/datasets/transformer.py b/dpmhm/datasets/transformer.py index 51c974a..5d456c5 100644 --- a/dpmhm/datasets/transformer.py +++ b/dpmhm/datasets/transformer.py @@ -8,7 +8,7 @@ """ from typing import Union # List, Dict -from abc import ABC, abstractmethod, abstractproperty, abstractclassmethod +from abc import ABC, abstractmethod import itertools import numpy as np @@ -27,7 +27,8 @@ from dpmhm.datasets import utils, _DTYPE, _ENCLEN from dpmhm.datasets.augment import randomly, random_crop, fade -from . import Logger +import logging +Logger = logging.getLogger(__name__) class AbstractDatasetTransformer(ABC): @@ -99,27 +100,29 @@ class DatasetCompactor(AbstractDatasetTransformer): - Data of the subfield 'signal' must be either 1D tensor or 2D tensor of shape `(channel, time)`. """ - def __init__(self, dataset:Dataset, *, channels:list=[], keys:list=[], filters:dict={},resampling_rate:int=None, window_size:int=None, hop_size:int=None): + def __init__(self, dataset:Dataset, *, channels:list=[], keys:list=[], filters:dict={}, resampling_rate:int=None, window_size:int=None, hop_size:int=None, separate_dims:bool=False): """ Args ---- dataset: original dataset channels: - channels for extraction of data, subset of 'signal'. If empty all channels are simultaneously extracted. + channels for extraction of data, subset of 'signal', if not given all channels are extracted. keys: - keys for extraction of new labels, subset of 'metadata'. If empty the original labels are used (no effect). + keys for extraction of new labels, subset of 'metadata', if not given no label is extracted. filters: - filters on the field 'metadata', a dictionary of keys and admissible values. By default no filter is applied. + filters on the field 'metadata', a dictionary of keys and admissible value(s). By default no filter is applied. resampling_rate: rate for resampling, if None use the original sampling rate. window_size: size of the sliding window on time axis, if None no window is applied. hop_size: hop size for the sliding window. No hop if None or `hop_size=1` (no downsampling). Effective only when `window_size` is given. 
+ separate_dims: + if True dimensions of channels are separated and the final dataset consists of 1d signals. """ - self._channels = channels - self._channels_dim = get_number_of_channels(dataset.element_spec['signal'], channels) + self._channels = channels if channels else list(dataset.element_spec['signal'].keys()) + self._channels_dim = get_number_of_channels(dataset.element_spec['signal'], self._channels) self._keys = keys # self._n_chunk = n_chunk self._resampling_rate = resampling_rate @@ -127,6 +130,7 @@ def __init__(self, dataset:Dataset, *, channels:list=[], keys:list=[], filters:d self._window_size = window_size self._hop_size = hop_size + self._separate_dims = separate_dims # dictionary for extracted labels, will be populated only after scanning the compacted dataset self._label_dict = {} @@ -153,6 +157,14 @@ def build(self): utils.sliding_window_generator(ds, 'signal', self._window_size, self._hop_size), output_signature=ds.element_spec, ) + if self._separate_dims: + foo = ds.element_spec.copy() # must use `.copy()` + foo['signal'] = tf.TensorSpec((1,None,)) + ds = Dataset.from_generator( + utils.separate_dims_generator(ds, 'signal'), + output_signature=foo, + ) + return ds @property @@ -302,7 +314,7 @@ def _compact(X): # `self.encode_labels(d)` doesn't work # `ensure_shape()` recover the lost shape due to `py_function()` 'label': tf.ensure_shape(tf.py_function(func=self.encode_labels, inp=d, Tout=tf.string), ()), - # 'metadata': X['metadata'], + 'metadata': X['metadata'], 'sampling_rate': X['sampling_rate'], # 'signal': tf.squeeze(x), 'signal': tf.reshape(x, (self._channels_dim, -1)) @@ -343,7 +355,7 @@ def build(self): def to_feature(cls, ds:Dataset, extractor:callable) -> Dataset: """Feature transform of a compacted dataset of signal. - The transformed database has a dictionary structure which contains + The transformed database has a dictionary structure which contains the fields {'label', 'feature'} """ n_channels = ds.element_spec['signal'].shape[0] @@ -357,7 +369,7 @@ def _feature_map(X): Xf.set_shape((n_channels, None, None)) return { 'label': X['label'], # string label - # 'metadata': X['metadata'], + 'metadata': X['metadata'], # 'feature': tf.reshape(Xf, tf.shape(Xf)) # has no effect 'feature': Xf } diff --git a/dpmhm/datasets/untested/phm2022/Exported Items.bib b/dpmhm/datasets/untested/phm2022/Exported Items.bib deleted file mode 100644 index e69de29..0000000 diff --git a/dpmhm/datasets/untested/phm2022/__init__.py b/dpmhm/datasets/untested/phm2022/__init__.py new file mode 100644 index 0000000..f6487b6 --- /dev/null +++ b/dpmhm/datasets/untested/phm2022/__init__.py @@ -0,0 +1 @@ +from .phm2022 import * \ No newline at end of file diff --git a/dpmhm/datasets/utils.py b/dpmhm/datasets/utils.py index 4718939..4637ca0 100644 --- a/dpmhm/datasets/utils.py +++ b/dpmhm/datasets/utils.py @@ -60,6 +60,30 @@ def extract_by_category(ds:Dataset, labels:list) -> dict: return dp +def separate_dims_generator(ds:Dataset, key:str, *, axis:int=0) -> callable: + """Generator function for separating dimensions. + + This generator create 1d samples from a nd dataset. For example, samples of shape `(3, None)` will be separated to 3 samples of shape `(1, None)`. + + Args + ---- + ds: + input dataset with dictionary structure. + key: str + ds[key] is the signal to be divided. 
+ """ + assert ds.element_spec[key].shape.ndims > 1 + + def _get_generator(): + for X in ds: + Y = X.copy() + for x in X[key]: + Y[key] = tf.reshape(x, [1,-1]) + # Y[key] = x # drop the channel dimension + yield Y + return _get_generator + + def split_signal_generator(ds:Dataset, key:str, n_chunk:int, *, axis:int=-1) -> callable: """Generator function for splitting a signal into chunks. diff --git a/dpmhm/datasets/xjtu/xjtu.py b/dpmhm/datasets/xjtu/xjtu.py index 3ea7c82..586adbb 100644 --- a/dpmhm/datasets/xjtu/xjtu.py +++ b/dpmhm/datasets/xjtu/xjtu.py @@ -90,7 +90,10 @@ # # 'file name pattern': # } -_DATA_URLS = ['https://sandbox.zenodo.org/record/1184368/files/xjtu.zip'] +_DATA_URLS = [ + # 'https://sandbox.zenodo.org/record/1184368/files/xjtu.zip' + 'https://zenodo.org/records/11545558/files/xjtu.zip?download=1' + ] class XJTU(tfds.core.GeneratorBasedBuilder): @@ -129,7 +132,13 @@ def _info(self) -> tfds.core.DatasetInfo: def _split_generators(self, dl_manager: tfds.download.DownloadManager): def _get_split_dict(datadir): - return {sp: (datadir/fn).rglob('*.csv') for sp, fn in _SPLIT_PATH_MATCH.items()} + # This doesn't work: + # return {sp: (datadir/fn).rglob('*.csv') for sp, fn in _SPLIT_PATH_MATCH.items()} + return { + 'condition1': next(datadir.rglob('35Hz12kN')).rglob('*.csv'), + 'condition2': next(datadir.rglob('37.5Hz11kN')).rglob('*.csv'), + 'condition3': next(datadir.rglob('40Hz10kN')).rglob('*.csv'), + } if dl_manager._manual_dir.exists(): # prefer to use manually downloaded data datadir = Path(dl_manager._manual_dir) diff --git a/dpmhm/models/sl/clae.py b/dpmhm/models/sl/clae.py index 3a6a59b..fbef004 100644 --- a/dpmhm/models/sl/clae.py +++ b/dpmhm/models/sl/clae.py @@ -6,7 +6,7 @@ CNN on the raw wavform. Input: fixed-length trunks of wavform -The original method can also be implemented in a self-supervised fashion by considering each wavform record a class. +The original method can also be implemented in a self-supervised fashion by considering each waveform record a class. """ diff --git a/dpmhm/models/sl/vggish.py b/dpmhm/models/sl/vggish.py index a4128e2..4e0f12b 100644 --- a/dpmhm/models/sl/vggish.py +++ b/dpmhm/models/sl/vggish.py @@ -18,7 +18,7 @@ - Hershey, S. et al. (2017) ‘CNN architectures for large-scale audio classification’, in 2017 ieee international conference on acoustics, speech and signal processing (icassp). IEEE, pp. 131–135. """ -import tensorflow as tf +# import tensorflow as tf from keras import layers, models #, regularizers from dataclasses import dataclass from .. import AbstractConfig