Skip to content

Commit

Permalink
ENH: some more enhancemenst to load_from_cache
Browse files Browse the repository at this point in the history
  • Loading branch information
arnabiswas committed Jul 18, 2023
1 parent 94aa983 commit 967338a
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions pylids/select_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def select_augmentations(trn_fls, tst_fls, aug_fls,
kmeans_type='batch',
return_min_rand_frames = False,
num_augs_per_cnd = 2080,
guess_n_clusters=None):
guess_n_clusters=None,
cache_loc = './pylids_cache/'):
"""Summary
Args:
Expand All @@ -73,12 +74,13 @@ def select_augmentations(trn_fls, tst_fls, aug_fls,
print('Run time and RAM required scales with dataset size! \n')
print('Make sure you have enough RAM for this process, else, downsample your data using k-means clustering... \n')
model = ResNet50(weights="imagenet", include_top=False)
cache_loc = './cache/'
if not os.path.exists(cache_loc):
os.makedirs(cache_loc)

# extracting resnet features for training dataset images
trn_data = input("Enter the name of the train dataset: giw_edited, giw_eyelids, giw_eyelids_pupils \n")
print('Files in cache: \n')
print(glob(cache_loc+'*.npy'))
trn_data = input("Enter the name of the train dataset: \n")
if os.path.isfile(os.path.join(cache_loc, trn_data+'.npy')):
print('Loading train features from cache')
trn_rnfs = np.load(os.path.join(cache_loc, trn_data+'.npy'))
Expand All @@ -89,7 +91,9 @@ def select_augmentations(trn_fls, tst_fls, aug_fls,
av_trn_fs = np.mean(trn_rnfs,axis=0)

# extracting resnet features for test dataset images
tst_data = input("Enter the name of the test dataset: santini, vedb, lpw \n")
print('Files in cache: \n')
print(glob(cache_loc+'*.npy'))
tst_data = input("Enter the name of the test dataset: \n")
if os.path.isfile(os.path.join(cache_loc,tst_data+'.npy')):
print('Loading test features from cache')
tst_rnfs = np.load(os.path.join(cache_loc,tst_data+'.npy'))
Expand Down Expand Up @@ -249,7 +253,7 @@ def select_frames_to_label(trn_fls=None, tst_fls=None,
assert trn_fls is not None, 'Test files should be provided'
else:
# extracting resnet features for test dataset images
tst_data = input("Enter the name of the test dataset: santini, vedb, lpw, fMRI \n")
tst_data = input("Enter the name of the test dataset:\n")
if os.path.isfile(os.path.join(cache_loc,tst_data+'.npy')):
print('Loading test features from cache')
tst_rnfs = np.load(os.path.join(cache_loc,tst_data+'.npy'))
Expand Down

0 comments on commit 967338a

Please sign in to comment.