Skip to content

Commit

Permalink
parameterized model selections for positional encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
keighrim committed Jul 10, 2024
1 parent e819e77 commit 85d18ee
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
6 changes: 4 additions & 2 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,10 @@ def _extract_images(self, video):

def _classify(self, extracted: list, positions: list, total_ms: int):
t = time.perf_counter()
self.logger.info(f"Initiating classifier with {self.configs['modelName']}")
classifier = classify.Classifier(default_model_storage / self.configs['modelName'],
model_checkpoint_name = next(default_model_storage.glob(
f"*.{self.configs['modelName']}.pos{'T' if self.configs['usePosModel'] else 'F'}.pt"))
self.logger.info(f"Initiating classifier with {model_checkpoint_name.stem}")
classifier = classify.Classifier(model_checkpoint_name,
self.logger.name if self.logger.isEnabledFor(logging.DEBUG) else None)
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(f"Classifier initiation took {time.perf_counter() - t:.2f} seconds")
Expand Down
5 changes: 4 additions & 1 deletion metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,11 @@ def appmetadata() -> AppMetadata:
metadata.add_parameter(
name='modelName', type='string',
default='convnext_lg',
choices=[m.stem[16:] for m in available_models],
choices=[m.stem.split('.')[1] for m in available_models],
description='model name to use for classification')
metadata.add_parameter(
name='usePosModel', type='boolean', default=True,
description='Use the model trained with positional features')
metadata.add_parameter(
name='useStitcher', type='boolean', default=True,
description='Use the stitcher after classifying the TimePoints')
Expand Down

0 comments on commit 85d18ee

Please sign in to comment.