diff --git a/utils/datasets/datasets.py b/utils/datasets/datasets.py index 8ad8c68..2971b93 100644 --- a/utils/datasets/datasets.py +++ b/utils/datasets/datasets.py @@ -55,7 +55,8 @@ def name(self): @property def n_categories(self): if self._n_categories is None: - assert self.training_set.target.dtype in (int, str), \ + target_dtype = self.training_set.target.dtype + assert (np.issubdtype(target_dtype, int) or np.issubdtype(target_dtype, str)), \ 'n_categories is only a valid attribute when target data is int or str. It is %s' \ % (self.training_set.target.dtype, ) self._n_categories = len(np.unique(self.training_set.target))