
ConcatDataset support #37

Open
jasonbian97 opened this issue Aug 5, 2021 · 1 comment

Comments

@jasonbian97

Thanks for the great work!

I tried to combine two datasets with `dataset = dataset1 + dataset2`, and it gave me this error:
AttributeError: 'ConcatDataset' object has no attribute 'get_labels'

Is there any workaround?

@jasonbian97

jasonbian97 commented Aug 5, 2021

Never mind, I found a workaround myself, and it's pretty simple.

Add two lines to `_get_labels` and one helper function:

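    # Methods of the class that defines `callback_get_label`; requires `import torch` and `import torchvision`.
    def _get_labels(self, dataset):
        if self.callback_get_label:
            return self.callback_get_label(dataset)
        elif isinstance(dataset, torchvision.datasets.MNIST):
            return dataset.targets.tolist()  # `train_labels` is deprecated in favor of `targets`
        elif isinstance(dataset, torchvision.datasets.ImageFolder):
            return [x[1] for x in dataset.imgs]
        elif isinstance(dataset, torchvision.datasets.DatasetFolder):
            return [x[1] for x in dataset.samples]  # each sample is a (path, class_index) pair
        elif isinstance(dataset, torch.utils.data.Subset):
            return [dataset.dataset.imgs[i][1] for i in dataset.indices]  # labels of the subset's samples only
        elif isinstance(dataset, torch.utils.data.ConcatDataset):  # added; must come before the next `elif` because ConcatDataset is a subclass of torch.utils.data.Dataset
            return self._get_concat_labels(dataset)  # added
        elif isinstance(dataset, torch.utils.data.Dataset):
            return dataset.get_labels()
        else:
            raise NotImplementedError

    def _get_concat_labels(self, concatdataset):  # added
        # ConcatDataset serves its parts in order, so concatenating each part's
        # labels in the same order keeps them aligned with the combined indices.
        concat_labels = []
        for ds in concatdataset.datasets:
            # Recurse so nested ConcatDatasets (d1 + d2 + d3) and torchvision parts also work.
            concat_labels.extend(self._get_labels(ds))
        return concat_labels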

Let me know if you have a better solution!
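A torch-free sketch of why collecting labels part by part works, and why the helper should recurse: in PyTorch, `Dataset.__add__` returns `ConcatDataset([self, other])`, so `d1 + d2 + d3` builds a ConcatDataset whose first part is itself a ConcatDataset. The classes `LabeledPart` and `ConcatLike` and the function `collect_labels` below are hypothetical stand-ins, not torch APIs; `ConcatLike` only mirrors the `datasets` / `cumulative_sizes` bookkeeping that `torch.utils.data.ConcatDataset` uses to map a global index to a part.

```python
import bisect
from typing import List, Sequence


class LabeledPart:
    """Hypothetical stand-in for a dataset that exposes get_labels()."""

    def __init__(self, labels: Sequence[int]):
        self.labels = list(labels)

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx: int) -> int:
        # Return the label itself so alignment is easy to check.
        return self.labels[idx]

    def get_labels(self) -> List[int]:
        return list(self.labels)


class ConcatLike:
    """Hypothetical stand-in for ConcatDataset: keeps parts in `.datasets`
    and running totals of their lengths in `.cumulative_sizes`."""

    def __init__(self, datasets):
        self.datasets = list(datasets)
        self.cumulative_sizes = []
        total = 0
        for d in self.datasets:
            total += len(d)
            self.cumulative_sizes.append(total)

    def __len__(self) -> int:
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx: int):
        # Find the part holding idx, then the offset inside that part
        # (non-negative indices only in this sketch).
        part = bisect.bisect_right(self.cumulative_sizes, idx)
        offset = idx if part == 0 else idx - self.cumulative_sizes[part - 1]
        return self.datasets[part][offset]


def collect_labels(dataset) -> List[int]:
    """Gather labels in index order, descending into nested concatenations."""
    if isinstance(dataset, ConcatLike):
        labels: List[int] = []
        for part in dataset.datasets:
            labels.extend(collect_labels(part))
        return labels
    return dataset.get_labels()


# Same nesting shape that d1 + d2 + d3 produces.
a, b, c = LabeledPart([0, 0, 1]), LabeledPart([1]), LabeledPart([2, 0])
nested = ConcatLike([ConcatLike([a, b]), c])
labels = collect_labels(nested)
print(labels)  # [0, 0, 1, 1, 2, 0]
# Each collected label matches the item at the same global index.
print(all(labels[i] == nested[i] for i in range(len(nested))))  # True
```

Calling `ds.get_labels()` directly on each part would fail on the inner `ConcatLike` here, which is the same `AttributeError` as in the original report, one level down.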
