Problem loading cache in UnifiedDataset load_or_create_cache() function. #32

Open
aallaire91 opened this issue Oct 3, 2023 · 0 comments


aallaire91 commented Oct 3, 2023

Hello! I noticed an issue with loading a cache using the load_or_create_cache() function in UnifiedDataset. In the following code snippet, you can see that keep_ids is not defined when cache_path exists, which is the case when a cache has already been created and you are just trying to load it. After the if/else statement, keep_ids is expected to exist so that the undesired entries can be removed from the data index (see the line self.remove_elements(keep_ids=keep_ids)). It seems that keep_mask, which is one of the outputs of dill.load(f, encoding="latin1"), should be renamed to keep_ids to fix this issue; a sketch of that rename follows the snippet below.

def load_or_create_cache(
        self, cache_path: str, num_workers=0, filter_fn=None
    ) -> None:
        if isfile(cache_path):
            print(f"Loading cache from {cache_path} ...", end="")
            t = time.time()
            with open(cache_path, "rb") as f:
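                # NOTE: dill.load's second tuple element is bound to keep_mask
                # here, but remove_elements() below expects keep_ids.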
                self._cached_batch_elements, keep_mask = dill.load(f, encoding="latin1")
            print(f" done in {time.time() - t:.1f}s.")

        else:
            # Build cache
            cached_batch_elements = []
            keep_ids = []

            if num_workers <= 0:
                cache_data_iterator = self
            else:
                # Use DataLoader as a generic multiprocessing framework.
                # We set batchsize=1 and a custom collate function.
                # In effect this will just call self.__getitem__ in parallel.
                cache_data_iterator = DataLoader(
                    self,
                    batch_size=1,
                    num_workers=num_workers,
                    shuffle=False,
                    collate_fn=lambda xlist: xlist[0],
                )

            for element in tqdm(
                cache_data_iterator,
                desc=f"Caching batch elements ({num_workers} CPUs): ",
                disable=False,
            ):
                if filter_fn is None or filter_fn(element):
                    cached_batch_elements.append(element)
                    keep_ids.append(element.data_index)

            # Just deletes the variable cache_data_iterator,
            # not self (in case it is set to that)!
            del cache_data_iterator

            print(f"Saving cache to {cache_path} ....", end="")
            t = time.time()
            with open(cache_path, "wb") as f:
                dill.dump((cached_batch_elements, keep_ids), f)
            print(f" done in {time.time() - t:.1f}s.")

            self._cached_batch_elements = cached_batch_elements

        # Remove unwanted elements
        self.remove_elements(keep_ids=keep_ids)

        # Verify
        if len(self._cached_batch_elements) != self._data_len:
            raise ValueError("Current data and cached data lengths do not match!")
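For concreteness, here is a minimal sketch of the suggested rename, assuming the cache file was written by the save branch above (so the second tuple element really is the list of kept data indices); only the load branch changes:

        if isfile(cache_path):
            print(f"Loading cache from {cache_path} ...", end="")
            t = time.time()
            with open(cache_path, "rb") as f:
                # Bind the second tuple element to keep_ids, matching the
                # name it was saved under in the else branch, so that
                # remove_elements() below receives a defined variable.
                self._cached_batch_elements, keep_ids = dill.load(f, encoding="latin1")
            print(f" done in {time.time() - t:.1f}s.")

With that rename, the load path reaches self.remove_elements(keep_ids=keep_ids) with keep_ids bound. As currently written, that call raises an UnboundLocalError when loading an existing cache, because keep_ids is only assigned in the else branch.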