Skip to content

Commit

Permalink
MAINT: Typos and misc bugs fixed for load_from_cache
Browse files Browse the repository at this point in the history
  • Loading branch information
arnabiswas committed Jul 18, 2023
1 parent 967338a commit 2ea523c
Showing 1 changed file with 23 additions and 19 deletions.
42 changes: 23 additions & 19 deletions pylids/select_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,11 @@ def select_augmentations(trn_fls, tst_fls, aug_fls,
Returns:
TYPE: Description
"""
# expects .png folders for test and train files
# returns a list with path to frames to label

print('This code has a run time of about 40 mins, go make yourself some tea! \n')
print('Run time and RAM required scales with dataset size! \n')
print('Make sure you have enough RAM for this process, else, downsample your data using k-means clustering... \n')

print('This can take some time to run, go make yourself some tea!')
print('\n Run time and RAM required scales with dataset size!')
print('\n Track memory usage and make sure you have enough RAM for this process.')
print('\n else, downsample your data using k-means clustering within participants...')
model = ResNet50(weights="imagenet", include_top=False)
if not os.path.exists(cache_loc):
os.makedirs(cache_loc)
Expand Down Expand Up @@ -215,17 +214,16 @@ def select_frames_to_label(trn_fls=None, tst_fls=None,
num_frames (int, optional): Description
kmeans_batch_size (int, optional): Description
kmeans_type (str, optional): default, batch
load_from_cache (bool, optional): set to True if you want to load trn_fls or tst_fls from cache
load_from_cache (bool, optional): set to True if you want to load/save
trn_fls and tst_fls from cache
Returns:
TYPE: Description
"""
# expects .png folders for test and train files
# returns a list with path to frames to label

print('This code has a run time of about 40 mins, go make yourself some tea!')
print('This can take some time to run, go make yourself some tea!')
print('\n Run time and RAM required scales with dataset size!')
print('\n Make sure you have enough RAM for this process else, downsample your data using k-means clustering within participants...')
print('\n Track memory usage and make sure you have enough RAM for this process.')
print('\n else, downsample your data using k-means clustering within participants...')

model = ResNet50(weights="imagenet", include_top=False)
if not os.path.exists(cache_loc) and load_from_cache:
Expand All @@ -245,22 +243,28 @@ def select_frames_to_label(trn_fls=None, tst_fls=None,
np.save(os.path.join(cache_loc, trn_data+'.npy'), trn_rnfs)
else:
trn_rnfs = get_rnfs_from_list(trn_fls, model)
np.save(os.path.join(cache_loc, trn_data+'.npy'), trn_rnfs)

if return_min_rand_frames:
av_trn_fs = np.mean(trn_rnfs,axis=0)

if tst_fls is None:
assert trn_fls is not None, 'Test files should be provided'
else:
# extracting resnet features for test dataset images
tst_data = input("Enter the name of the test dataset:\n")
if os.path.isfile(os.path.join(cache_loc,tst_data+'.npy')):
print('Loading test features from cache')
tst_rnfs = np.load(os.path.join(cache_loc,tst_data+'.npy'))
if load_from_cache:
print('Files in cache: \n')
print(glob(cache_loc+'*.npy'))

tst_data = input("Enter the name of the test dataset:\n")
if os.path.isfile(os.path.join(cache_loc,tst_data+'.npy')):
print('Loading test features from cache')
tst_rnfs = np.load(os.path.join(cache_loc,tst_data+'.npy'))
else:
tst_rnfs = get_rnfs_from_list(tst_fls, model)
np.save(os.path.join(cache_loc,tst_data+'.npy'), tst_rnfs)
else:
tst_rnfs = get_rnfs_from_list(tst_fls, model)
np.save(os.path.join(cache_loc,tst_data+'.npy'), tst_rnfs)


# Iterative k means which keeps running till we find a given number of frames to label
# from the test dataset / set of augmented images
n_clusters = n_frames
Expand Down

0 comments on commit 2ea523c

Please sign in to comment.