From 2197ad8b22a8a44b03287ff3e4ee5c66f4ed2587 Mon Sep 17 00:00:00 2001 From: Danilo Burbano Date: Wed, 7 Aug 2024 09:10:58 -0500 Subject: [PATCH] [SPARKNLP-856] Adding CamemBertForZeroShotClassification to ResourceDownloader --- .../camembert_for_zero_shot_classification.py | 6 +++--- .../CamemBertForZeroShotClassification.scala | 6 ++---- .../nlp/pretrained/ResourceDownloader.scala | 3 ++- .../nlp/pretrained/ResourceMetadata.scala | 21 ------------------- 4 files changed, 7 insertions(+), 29 deletions(-) diff --git a/python/sparknlp/annotator/classifier_dl/camembert_for_zero_shot_classification.py b/python/sparknlp/annotator/classifier_dl/camembert_for_zero_shot_classification.py index f32e5777f96e96..7b16c4475e5511 100644 --- a/python/sparknlp/annotator/classifier_dl/camembert_for_zero_shot_classification.py +++ b/python/sparknlp/annotator/classifier_dl/camembert_for_zero_shot_classification.py @@ -34,7 +34,7 @@ class CamemBertForZeroShotClassification(AnnotatorModel, >>> sequenceClassifier = CamemBertForZeroShotClassification.pretrained() \\ ... .setInputCols(["token", "document"]) \\ ... .setOutputCol("label") - The default model is ``"deberta_base_zero_shot_classifier_mnli_anli_v3"``, if no name is + The default model is ``"camembert_zero_shot_classifier_xnli_onnx"``, if no name is provided. For available pretrained models please see the `Models Hub `__. @@ -179,14 +179,14 @@ def loadSavedModel(folder, spark_session): return CamemBertForZeroShotClassification(java_model=jModel) @staticmethod - def pretrained(name="camembert-base-xnli", lang="fr", remote_loc=None): + def pretrained(name="camembert_zero_shot_classifier_xnli_onnx", lang="fr", remote_loc=None): """Downloads and loads a pretrained model. Parameters ---------- name : str, optional Name of the pretrained model, by default - "camembert_base_sequence_classifier_allocine" + "camembert_zero_shot_classifier_xnli_onnx" lang : str, optional Language of the pretrained model, by default "fr" remote_loc : str, optional diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForZeroShotClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForZeroShotClassification.scala index 0c70ff6d21ad50..4a5bcde0e87ef1 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForZeroShotClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForZeroShotClassification.scala @@ -284,10 +284,8 @@ class CamemBertForZeroShotClassification(override val uid: String) trait ReadPretrainedCamemBertForZeroShotClassification extends ParamsAndFeaturesReadable[CamemBertForZeroShotClassification] with HasPretrained[CamemBertForZeroShotClassification] { - override val defaultModelName: Some[String] = Some( - "camembert-zero-shot-classifier-xnli-onnx" - ) - override val defaultLang: String = "en" + override val defaultModelName: Some[String] = Some("camembert_zero_shot_classifier_xnli_onnx") + override val defaultLang: String = "fr" override def pretrained(): CamemBertForZeroShotClassification = super.pretrained() diff --git a/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala b/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala index e8f797e56e3238..9460f64b13c007 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala @@ -690,7 +690,8 @@ object PythonResourceDownloader { "MPNetForQuestionAnswering" -> MPNetForQuestionAnswering, "LLAMA2Transformer" -> LLAMA2Transformer, "M2M100Transformer" -> M2M100Transformer, - "UAEEmbeddings" -> UAEEmbeddings) + "UAEEmbeddings" -> UAEEmbeddings, + "CamemBertForZeroShotClassification" -> CamemBertForZeroShotClassification) // List pairs of types such as the one with key type can load a pretrained model from the value type val typeMapper: Map[String, String] = Map("ZeroShotNerModel" -> "RoBertaForQuestionAnswering") diff --git a/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceMetadata.scala b/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceMetadata.scala index e0905f92c6176f..992708e86c0992 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceMetadata.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceMetadata.scala @@ -109,27 +109,6 @@ object ResourceMetadata { candidates: List[ResourceMetadata], request: ResourceRequest): Option[ResourceMetadata] = { - val compatibleCandidatesName = candidates - .filter(item => - item.readyToUse && item.libVersion.isDefined && item.sparkVersion.isDefined - && item.name == request.name) - - val compatibleCandidatesLanguage = candidates - .filter(item => - item.readyToUse && item.libVersion.isDefined && item.sparkVersion.isDefined - && item.name == request.name - && (request.language.isEmpty || item.language.isEmpty || request.language.get == item.language.get) - ) - - val compatibleCandidatesSparkNLPVersion = - candidates - .filter(item => - item.readyToUse && item.libVersion.isDefined && item.sparkVersion.isDefined - && item.name == request.name - && (request.language.isEmpty || item.language.isEmpty || request.language.get == item.language.get) - && Version.isCompatible(request.libVersion, item.libVersion)) - - println("") val compatibleCandidates = candidates .filter(item => item.readyToUse && item.libVersion.isDefined && item.sparkVersion.isDefined