diff --git a/rankers/modelling/pyterrier/sparse.py b/rankers/modelling/pyterrier/sparse.py index 468dd79..1bb0527 100644 --- a/rankers/modelling/pyterrier/sparse.py +++ b/rankers/modelling/pyterrier/sparse.py @@ -32,6 +32,10 @@ def __init__(self, model_name_or_path, device=None, batch_size=32, text_field='t self.text_field = text_field self.topk = topk + @classmethod + def from_pretrained(cls, model_name_or_path, device=None, batch_size=32, text_field='text', fp16=False, topk=None): + return cls(model_name_or_path, device, batch_size, text_field, fp16, topk) + def encode_queries(self, texts, out_fmt='dict', topk=None): outputs = [] if out_fmt != 'dict':