diff --git a/app.py b/app.py index 608f63a..8544bff 100644 --- a/app.py +++ b/app.py @@ -116,8 +116,10 @@ def _extract_images(self, video): def _classify(self, extracted: list, positions: list, total_ms: int): t = time.perf_counter() - self.logger.info(f"Initiating classifier with {self.configs['modelName']}") - classifier = classify.Classifier(default_model_storage / self.configs['modelName'], + model_checkpoint_name = next(default_model_storage.glob( + f"*.{self.configs['modelName']}.pos{'T' if self.configs['usePosModel'] else 'F'}.pt")) + self.logger.info(f"Initiating classifier with {model_checkpoint_name.stem}") + classifier = classify.Classifier(model_checkpoint_name, self.logger.name if self.logger.isEnabledFor(logging.DEBUG) else None) if self.logger.isEnabledFor(logging.DEBUG): self.logger.debug(f"Classifier initiation took {time.perf_counter() - t:.2f} seconds") diff --git a/metadata.py b/metadata.py index 43c808a..303aa46 100644 --- a/metadata.py +++ b/metadata.py @@ -70,8 +70,11 @@ def appmetadata() -> AppMetadata: metadata.add_parameter( name='modelName', type='string', default='convnext_lg', - choices=[m.stem[16:] for m in available_models], + choices=[m.stem.split('.')[1] for m in available_models], description='model name to use for classification') + metadata.add_parameter( + name='usePosModel', type='boolean', default=True, + description='Use the model trained with positional features') metadata.add_parameter( name='useStitcher', type='boolean', default=True, description='Use the stitcher after classifying the TimePoints')