From 967338ad49ff663fd4d99d38263d9a1380d69dcd Mon Sep 17 00:00:00 2001 From: Arnab Date: Tue, 18 Jul 2023 12:21:30 -0700 Subject: [PATCH] ENH: some more enhancemenst to load_from_cache --- pylids/select_frames.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/pylids/select_frames.py b/pylids/select_frames.py index c57eb0e..326a610 100644 --- a/pylids/select_frames.py +++ b/pylids/select_frames.py @@ -52,7 +52,8 @@ def select_augmentations(trn_fls, tst_fls, aug_fls, kmeans_type='batch', return_min_rand_frames = False, num_augs_per_cnd = 2080, - guess_n_clusters=None): + guess_n_clusters=None, + cache_loc = './pylids_cache/'): """Summary Args: @@ -73,12 +74,13 @@ def select_augmentations(trn_fls, tst_fls, aug_fls, print('Run time and RAM required scales with dataset size! \n') print('Make sure you have enough RAM for this process, else, downsample your data using k-means clustering... \n') model = ResNet50(weights="imagenet", include_top=False) - cache_loc = './cache/' if not os.path.exists(cache_loc): os.makedirs(cache_loc) # extracting resnet features for training dataset images - trn_data = input("Enter the name of the train dataset: giw_edited, giw_eyelids, giw_eyelids_pupils \n") + print('Files in cache: \n') + print(glob(cache_loc+'*.npy')) + trn_data = input("Enter the name of the train dataset: \n") if os.path.isfile(os.path.join(cache_loc, trn_data+'.npy')): print('Loading train features from cache') trn_rnfs = np.load(os.path.join(cache_loc, trn_data+'.npy')) @@ -89,7 +91,9 @@ def select_augmentations(trn_fls, tst_fls, aug_fls, av_trn_fs = np.mean(trn_rnfs,axis=0) # extracting resnet features for test dataset images - tst_data = input("Enter the name of the test dataset: santini, vedb, lpw \n") + print('Files in cache: \n') + print(glob(cache_loc+'*.npy')) + tst_data = input("Enter the name of the test dataset: \n") if os.path.isfile(os.path.join(cache_loc,tst_data+'.npy')): print('Loading test features from cache') tst_rnfs = np.load(os.path.join(cache_loc,tst_data+'.npy')) @@ -249,7 +253,7 @@ def select_frames_to_label(trn_fls=None, tst_fls=None, assert trn_fls is not None, 'Test files should be provided' else: # extracting resnet features for test dataset images - tst_data = input("Enter the name of the test dataset: santini, vedb, lpw, fMRI \n") + tst_data = input("Enter the name of the test dataset:\n") if os.path.isfile(os.path.join(cache_loc,tst_data+'.npy')): print('Loading test features from cache') tst_rnfs = np.load(os.path.join(cache_loc,tst_data+'.npy'))