diff --git a/pastastore/store.py b/pastastore/store.py index b2a2954..20eabde 100644 --- a/pastastore/store.py +++ b/pastastore/store.py @@ -1496,7 +1496,7 @@ def search( case_sensitive: bool = True, sort=True, ): - """Search for names of time series or models starting with `s`. + """Search for names of time series or models containing string `s`. Parameters ---------- @@ -1515,30 +1515,45 @@ def search( list of names that match search result """ if libname == "models": - lib_names = self.model_names + lib_names = {"models": self.model_names} elif libname == "stresses": - lib_names = self.stresses_names + lib_names = {"stresses": self.stresses_names} elif libname == "oseries": - lib_names = self.oseries_names + lib_names = {"oseries": self.oseries_names} + elif libname is None: + lib_names = { + "oseries": self.oseries_names, + "stresses": self.stresses_names, + "models": self.model_names, + } else: raise ValueError("Provide valid libname: 'models', 'stresses' or 'oseries'") - if isinstance(s, str): - if case_sensitive: - matches = [n for n in lib_names if s in n] - else: - matches = [n for n in lib_names if s.lower() in n.lower()] - if isinstance(s, list): - m = np.array([]) - for sub in s: + result = {} + for lib, names in lib_names.items(): + if isinstance(s, str): if case_sensitive: - m = np.append(m, [n for n in lib_names if sub in n]) + matches = [n for n in names if s in n] else: - m = np.append(m, [n for n in lib_names if sub.lower() in n.lower()]) - matches = list(np.unique(m)) - if sort: - matches.sort() - return matches + matches = [n for n in names if s.lower() in n.lower()] + elif isinstance(s, list): + m = np.array([]) + for sub in s: + if case_sensitive: + m = np.append(m, [n for n in names if sub in n]) + else: + m = np.append(m, [n for n in names if sub.lower() in n.lower()]) + matches = list(np.unique(m)) + else: + raise TypeError("s must be str or list of str!") + if sort: + matches.sort() + result[lib] = matches + + if len(result) == 1: + return result[lib] + else: + return result def get_model_timeseries_names( self,