diff --git a/pyproject.toml b/pyproject.toml index 96b4afb2..d6cbd50a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = 'mmda' -version = '0.9.5' +version = '0.9.6' description = 'MMDA - multimodal document analysis' authors = [ {name = 'Allen Institute for Artificial Intelligence', email = 'contact@allenai.org'}, diff --git a/src/mmda/predictors/hf_predictors/mention_predictor.py b/src/mmda/predictors/hf_predictors/mention_predictor.py index bade051b..ba77b14f 100644 --- a/src/mmda/predictors/hf_predictors/mention_predictor.py +++ b/src/mmda/predictors/hf_predictors/mention_predictor.py @@ -107,7 +107,7 @@ def predict_page(self, page: Annotation, counter: Iterator[int], print_warnings: ) batch.to(self.model.device) batch_outputs = self.model(**batch) - batch_prediction_label_ids = torch.argmax(batch_outputs.logits, dim=-1)[0] + batch_prediction_label_ids = torch.argmax(batch_outputs.logits, dim=-1).tolist()[0] prediction_label_ids.append(batch_prediction_label_ids) def has_label_id(lbls: List[int], want_label_id: int) -> bool: