diff --git a/rankers/modelling/cat.py b/rankers/modelling/cat.py index f0d8958..3747fd2 100644 --- a/rankers/modelling/cat.py +++ b/rankers/modelling/cat.py @@ -1,6 +1,4 @@ import pyterrier as pt -if not pt.started(): - pt.init() from transformers import PreTrainedModel, PretrainedConfig, PreTrainedTokenizer, AutoModelForSequenceClassification, AutoTokenizer, AutoConfig from typing import Union import torch diff --git a/rankers/modelling/dot.py b/rankers/modelling/dot.py index 60d58fc..58f5c32 100644 --- a/rankers/modelling/dot.py +++ b/rankers/modelling/dot.py @@ -3,8 +3,6 @@ import torch from torch import nn import pyterrier as pt -if not pt.started(): - pt.init() from transformers import PreTrainedModel, PreTrainedTokenizer, PretrainedConfig, AutoModel, AutoTokenizer, AutoConfig from typing import Union import pandas as pd diff --git a/rankers/modelling/seq2seq.py b/rankers/modelling/seq2seq.py index 2c98c78..c9e0ca5 100644 --- a/rankers/modelling/seq2seq.py +++ b/rankers/modelling/seq2seq.py @@ -1,6 +1,4 @@ import pyterrier as pt -if not pt.started(): - pt.init() from transformers import PreTrainedModel, PreTrainedConfig, PreTrainedTokenizer, AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM, AutoModelForCausalLM from typing import Union import torch