Hello! I noticed an issue with loading a cache using the `load_or_create_cache()` function in `UnifiedDataset`. In the following code snippet, you can see that `keep_ids` is not defined when `cache_path` exists, as is the case when a cache has already been created and you are just trying to load it. After the `if` statement, `keep_ids` is expected to exist so that the undesired entries can be removed from the data index; see the line `self.remove_elements(keep_ids=keep_ids)`. It seems like `keep_mask`, which is one of the outputs of `dill.load(f, encoding="latin1")`, should be renamed to `keep_ids` to fix this issue.
```python
def load_or_create_cache(
    self, cache_path: str, num_workers=0, filter_fn=None
) -> None:
    if isfile(cache_path):
        print(f"Loading cache from {cache_path} ...", end="")
        t = time.time()
        with open(cache_path, "rb") as f:
            self._cached_batch_elements, keep_mask = dill.load(f, encoding="latin1")
        print(f" done in {time.time() - t:.1f}s.")
    else:
        # Build cache
        cached_batch_elements = []
        keep_ids = []
        if num_workers <= 0:
            cache_data_iterator = self
        else:
            # Use DataLoader as a generic multiprocessing framework.
            # We set batchsize=1 and a custom collate function.
            # In effect this will just call self.__getitem__ in parallel.
            cache_data_iterator = DataLoader(
                self,
                batch_size=1,
                num_workers=num_workers,
                shuffle=False,
                collate_fn=lambda xlist: xlist[0],
            )

        for element in tqdm(
            cache_data_iterator,
            desc=f"Caching batch elements ({num_workers} CPUs): ",
            disable=False,
        ):
            if filter_fn is None or filter_fn(element):
                cached_batch_elements.append(element)
                keep_ids.append(element.data_index)

        # Just deletes the variable cache_data_iterator,
        # not self (in case it is set to that)!
        del cache_data_iterator

        print(f"Saving cache to {cache_path} ....", end="")
        t = time.time()
        with open(cache_path, "wb") as f:
            dill.dump((cached_batch_elements, keep_ids), f)
        print(f" done in {time.time() - t:.1f}s.")

        self._cached_batch_elements = cached_batch_elements

    # Remove unwanted elements
    self.remove_elements(keep_ids=keep_ids)

    # Verify
    if len(self._cached_batch_elements) != self._data_len:
        raise ValueError("Current data and cached data lengths do not match!")
```
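For reference, here is a minimal sketch of the rename suggested above: the load branch unpacks into `keep_ids` instead of `keep_mask`, so that the `self.remove_elements(keep_ids=keep_ids)` call after the `if`/`else` sees a defined name. The rest of the function is unchanged.

```python
if isfile(cache_path):
    print(f"Loading cache from {cache_path} ...", end="")
    t = time.time()
    with open(cache_path, "rb") as f:
        # Rename keep_mask -> keep_ids so the name matches the
        # post-branch call self.remove_elements(keep_ids=keep_ids).
        self._cached_batch_elements, keep_ids = dill.load(f, encoding="latin1")
    print(f" done in {time.time() - t:.1f}s.")
```

Without this rename, calling `load_or_create_cache()` when the cache file already exists (e.g., on a second run) raises an `UnboundLocalError` at `self.remove_elements(keep_ids=keep_ids)`, since `keep_ids` is only ever assigned in the `else` branch.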