Skip to content

Commit

Permalink
[SPARKNLP-856] Adding CamemBertForZeroShotClassification to ResourceD…
Browse files Browse the repository at this point in the history
…ownloader
  • Loading branch information
danilojsl committed Aug 7, 2024
1 parent 79643e9 commit 2197ad8
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class CamemBertForZeroShotClassification(AnnotatorModel,
>>> sequenceClassifier = CamemBertForZeroShotClassification.pretrained() \\
... .setInputCols(["token", "document"]) \\
... .setOutputCol("label")
The default model is ``"deberta_base_zero_shot_classifier_mnli_anli_v3"``, if no name is
The default model is ``"camembert_zero_shot_classifier_xnli_onnx"``, if no name is
provided.
For available pretrained models please see the `Models Hub
<https://sparknlp.orgtask=Text+Classification>`__.
Expand Down Expand Up @@ -179,14 +179,14 @@ def loadSavedModel(folder, spark_session):
return CamemBertForZeroShotClassification(java_model=jModel)

@staticmethod
def pretrained(name="camembert-base-xnli", lang="fr", remote_loc=None):
def pretrained(name="camembert_zero_shot_classifier_xnli_onnx", lang="fr", remote_loc=None):
"""Downloads and loads a pretrained model.
Parameters
----------
name : str, optional
Name of the pretrained model, by default
"camembert_base_sequence_classifier_allocine"
"camembert_zero_shot_classifier_xnli_onnx"
lang : str, optional
Language of the pretrained model, by default "fr"
remote_loc : str, optional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,10 +284,8 @@ class CamemBertForZeroShotClassification(override val uid: String)
trait ReadPretrainedCamemBertForZeroShotClassification
extends ParamsAndFeaturesReadable[CamemBertForZeroShotClassification]
with HasPretrained[CamemBertForZeroShotClassification] {
override val defaultModelName: Some[String] = Some(
"camembert-zero-shot-classifier-xnli-onnx"
)
override val defaultLang: String = "en"
override val defaultModelName: Some[String] = Some("camembert_zero_shot_classifier_xnli_onnx")
override val defaultLang: String = "fr"

override def pretrained(): CamemBertForZeroShotClassification = super.pretrained()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,8 @@ object PythonResourceDownloader {
"MPNetForQuestionAnswering" -> MPNetForQuestionAnswering,
"LLAMA2Transformer" -> LLAMA2Transformer,
"M2M100Transformer" -> M2M100Transformer,
"UAEEmbeddings" -> UAEEmbeddings)
"UAEEmbeddings" -> UAEEmbeddings,
"CamemBertForZeroShotClassification" -> CamemBertForZeroShotClassification)

// List pairs of types such as the one with key type can load a pretrained model from the value type
val typeMapper: Map[String, String] = Map("ZeroShotNerModel" -> "RoBertaForQuestionAnswering")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,27 +109,6 @@ object ResourceMetadata {
candidates: List[ResourceMetadata],
request: ResourceRequest): Option[ResourceMetadata] = {

val compatibleCandidatesName = candidates
.filter(item =>
item.readyToUse && item.libVersion.isDefined && item.sparkVersion.isDefined
&& item.name == request.name)

val compatibleCandidatesLanguage = candidates
.filter(item =>
item.readyToUse && item.libVersion.isDefined && item.sparkVersion.isDefined
&& item.name == request.name
&& (request.language.isEmpty || item.language.isEmpty || request.language.get == item.language.get)
)

val compatibleCandidatesSparkNLPVersion =
candidates
.filter(item =>
item.readyToUse && item.libVersion.isDefined && item.sparkVersion.isDefined
&& item.name == request.name
&& (request.language.isEmpty || item.language.isEmpty || request.language.get == item.language.get)
&& Version.isCompatible(request.libVersion, item.libVersion))

println("")
val compatibleCandidates = candidates
.filter(item =>
item.readyToUse && item.libVersion.isDefined && item.sparkVersion.isDefined
Expand Down

0 comments on commit 2197ad8

Please sign in to comment.