diff --git a/delft/applications/citationClassifier.py b/delft/applications/citationClassifier.py index 4d954c9b..7ced0bba 100644 --- a/delft/applications/citationClassifier.py +++ b/delft/applications/citationClassifier.py @@ -122,7 +122,7 @@ def classify(texts, output_format, architecture="gru", embeddings_name=None, tra args = parser.parse_args() if args.action not in ('train', 'train_eval', 'classify'): - print('action not specifed, must be one of [train,train_eval,classify]') + print('action not specified, must be one of [train,train_eval,classify]') embeddings_name = args.embedding transformer = args.transformer diff --git a/delft/applications/textClassifier.py b/delft/applications/textClassifier.py new file mode 100644 index 00000000..979b5427 --- /dev/null +++ b/delft/applications/textClassifier.py @@ -0,0 +1,294 @@ +import argparse +import sys +import time + +import numpy as np +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import LabelEncoder, OneHotEncoder + +from delft.textClassification import Classifier +from delft.textClassification.models import architectures +from delft.textClassification.reader import load_texts_and_classes_generic +from delft.utilities.Utilities import t_or_f + +pretrained_transformers_examples = ['bert-base-cased', 'bert-large-cased', 'allenai/scibert_scivocab_cased'] + +actions = ['train', 'train_eval', 'eval', 'classify'] + + +def get_one_hot(y): + label_encoder = LabelEncoder() + integer_encoded = label_encoder.fit_transform(y) + onehot_encoder = OneHotEncoder(sparse=False) + integer_encoded = integer_encoded.reshape(len(integer_encoded), 1) + y2 = onehot_encoder.fit_transform(integer_encoded) + return y2 + + + +def configure(architecture, max_sequence_length_=-1, batch_size_=-1, max_epoch_=-1, patience_=-1, early_stop=True): + batch_size = 256 + maxlen = 150 if max_sequence_length_ == -1 else max_sequence_length_ + patience = 5 if patience_ == -1 else patience_ + max_epoch = 60 if max_epoch_ == -1 else max_epoch_ + + # default bert model parameters + if architecture == "bert": + batch_size = 32 + # early_stop = False + # max_epoch = 3 + + batch_size = batch_size_ if batch_size_ != -1 else batch_size + + return batch_size, maxlen, patience, early_stop, max_epoch + + +def train(model_name, + architecture, + input_file, + embeddings_name, + fold_count, + transformer=None, + x_index=0, + y_indexes=[1], + batch_size=-1, + max_sequence_length=-1, + patience=-1, + incremental=False, + learning_rate=None, + multi_gpu=False, + max_epoch=50, + early_stop=True + ): + + batch_size, maxlen, patience, early_stop, max_epoch = configure(architecture, + max_sequence_length, + batch_size, + max_epoch, + patience, + early_stop=early_stop) + + print('loading ' + model_name + ' training corpus...') + xtr, y = load_texts_and_classes_generic(input_file, x_index, y_indexes) + + list_classes = list(set([y_[0] for y_ in y])) + + model = Classifier(model_name, + architecture=architecture, + list_classes=list_classes, + max_epoch=max_epoch, + fold_number=fold_count, + patience=patience, + transformer_name=transformer, + use_roc_auc=True, + embeddings_name=embeddings_name, + early_stop=early_stop, + batch_size=batch_size, + maxlen=maxlen, + class_weights=None, + learning_rate=learning_rate) + + y_ = get_one_hot(y) + + if fold_count == 1: + model.train(xtr, y_, incremental=incremental, multi_gpu=multi_gpu) + else: + model.train_nfold(xtr, y_, multi_gpu=multi_gpu) + # saving the model + model.save() + + +def eval(model_name, architecture, 
input_file, x_index=0, y_indexes=[1]): + # model_name += model_name + '-' + architecture + + print('loading ' + model_name + ' evaluation corpus...') + + xtr, y = load_texts_and_classes_generic(input_file, x_index, y_indexes) + print(len(xtr), 'evaluation sequences') + + model = Classifier(model_name, architecture=architecture) + model.load() + + y_ = get_one_hot(y) + + model.eval(xtr, y_) + + +def train_and_eval(model_name, architecture, input_file, embeddings_name, fold_count, transformer=None, + x_index=0, y_indexes=[1], batch_size=-1, + max_sequence_length=-1, patience=-1, multi_gpu=False): + batch_size, maxlen, patience, early_stop, max_epoch = configure(architecture, max_sequence_length, batch_size, + patience_=patience) + + print('loading ' + model_name + ' corpus...') + xtr, y = load_texts_and_classes_generic(input_file, x_index, y_indexes) + + list_classes = list(set([y_[0] for y_ in y])) + + y_one_hot = get_one_hot(y) + + model = Classifier(model_name, architecture=architecture, list_classes=list_classes, max_epoch=max_epoch, + fold_number=fold_count, patience=patience, transformer_name=transformer, + use_roc_auc=True, embeddings_name=embeddings_name, early_stop=early_stop, + batch_size=batch_size, maxlen=maxlen, class_weights=None) + + # segment train and eval sets + x_train, x_test, y_train, y_test = train_test_split(xtr, y_one_hot, test_size=0.1) + + if fold_count == 1: + model.train(x_train, y_train, multi_gpu=multi_gpu) + else: + model.train_nfold(x_train, y_train, multi_gpu=multi_gpu) + + model.eval(x_test, y_test) + + # saving the model + model.save() + + +# classify a list of texts +def classify(model_name, architecture, texts, output_format='json'): + model = Classifier(model_name, architecture=architecture) + model.load() + + start_time = time.time() + result = model.predict(texts, output_format) + runtime = round(time.time() - start_time, 3) + + if output_format == 'json': + result["runtime"] = runtime + else: + print("runtime: %s seconds " % (runtime)) + + return result + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="General classification of text") + + parser.add_argument("action", help="the action", choices=actions) + parser.add_argument("model", help="The name of the model") + parser.add_argument("--fold-count", type=int, default=1) + parser.add_argument("--input", type=str, required=True, + help="The file to be used for train, train_eval, eval and classify") + parser.add_argument("--x-index", type=int, required=True, help="Index of the column for the X value " + "(assuming a TSV file)") + parser.add_argument("--y-indexes", type=str, required=False, help="Index(es) of the columns for the Y (classes), " + "separated by commas, without spaces (assuming " + "a TSV file)") + parser.add_argument("--architecture", default='gru', choices=architectures, + help="type of model architecture to be used, one of " + str(architectures)) + parser.add_argument( + "--embedding", default='word2vec', + help=( + "The desired pre-trained word embeddings using their descriptions in the file" + " embedding-registry.json." + " Be sure to use here the same name as in the registry ('glove-840B', 'fasttext-crawl', 'word2vec')," + " and that the path in the registry to the embedding file is correct on your system.")) + + parser.add_argument( + "--transformer", + default=None, + help="The desired pre-trained transformer to be used in the selected architecture. 
" + \ + "For local loading, use delft/resources-registry.json, and be sure to use here the " + "same name as in the registry, e.g. " + \ + str(pretrained_transformers_examples) + \ + " and that the path in the registry to the model is correct on your system. " + \ + "Otherwise the HuggingFace transformers hub will be used to fetch the model, " + "see https://huggingface.co/models " + \ + "for model names" + ) + + parser.add_argument("--max-sequence-length", type=int, default=-1, help="max-sequence-length parameter to be used.") + parser.add_argument("--batch-size", type=int, default=-1, help="batch-size parameter to be used.") + parser.add_argument("--patience", type=int, default=-1, help="patience, number of extra epochs to perform after " + "the best epoch before stopping the training.") + parser.add_argument("--learning-rate", type=float, default=None, help="Initial learning rate") + parser.add_argument("--incremental", action="store_true", help="training is incremental, starting from an existing model if present") + parser.add_argument("--max-epoch", type=int, default=-1, + help="Maximum number of epochs for training.") + parser.add_argument("--early-stop", type=t_or_f, default=None, + help="Force early training termination when metric scores are not improving " + + "after a number of epochs equal to the patience parameter.") + + parser.add_argument("--multi-gpu", default=False, + help="Enable support for distributed computing (the batch size needs to be set accordingly using --batch-size)", + action="store_true") + + args = parser.parse_args() + + embeddings_name = args.embedding + input_file = args.input + model_name = args.model + transformer = args.transformer + architecture = args.architecture + x_index = args.x_index + patience = args.patience + batch_size = args.batch_size + incremental = args.incremental + max_sequence_length = args.max_sequence_length + learning_rate = args.learning_rate + max_epoch = args.max_epoch + early_stop = args.early_stop + multi_gpu = args.multi_gpu + + if args.action != "classify": + if args.y_indexes is None: + print("--y-indexes is mandatory") + sys.exit(-1) + y_indexes = [int(index) for index in args.y_indexes.split(",")] + + if len(y_indexes) > 1: + print("At the moment we support just one value per class. Taking the first value only. 
") + y_indexes = [y_indexes[0]] # keep it as a list, the reader iterates over the class indexes + + if transformer is None and embeddings_name is None: + # default word embeddings + embeddings_name = "glove-840B" + + if args.action == 'train': + train(model_name, architecture, input_file, embeddings_name, args.fold_count, + transformer=transformer, + x_index=x_index, + y_indexes=y_indexes, + batch_size=batch_size, + incremental=incremental, + max_sequence_length=max_sequence_length, + patience=patience, + learning_rate=learning_rate, + max_epoch=max_epoch, + early_stop=early_stop, + multi_gpu=multi_gpu) + + elif args.action == 'eval': + eval(model_name, architecture, input_file, x_index=x_index, y_indexes=y_indexes) + + elif args.action == 'train_eval': + if args.fold_count < 1: + raise ValueError("fold-count should be 1 or greater") + + train_and_eval(model_name, + architecture, + input_file, + embeddings_name, + args.fold_count, + transformer=transformer, + x_index=x_index, + y_indexes=y_indexes, + batch_size=batch_size, + max_sequence_length=max_sequence_length, + patience=patience, + multi_gpu=multi_gpu) + + elif args.action == 'classify': + lines, _ = load_texts_and_classes_generic(input_file, x_index, None) + + result = classify(model_name, architecture, lines, "csv") + + result_binary = [np.argmax(line) for line in result] + + for x in result_binary: + print(x) + # See https://github.com/tensorflow/tensorflow/issues/3388 + # K.clear_session() diff --git a/delft/sequenceLabelling/data_generator.py b/delft/sequenceLabelling/data_generator.py index 3d079dcc..15d661b4 100644 --- a/delft/sequenceLabelling/data_generator.py +++ b/delft/sequenceLabelling/data_generator.py @@ -332,7 +332,7 @@ def __data_generation(self, index): # to have input as sentence piece token index for transformer layer input_ids, token_type_ids, attention_mask, input_chars, input_features, input_labels, input_offsets = self.bert_preprocessor.tokenize_and_align_features_and_labels( - x_tokenized, + x_tokenized, batch_c, sub_f, batch_y, diff --git a/delft/sequenceLabelling/preprocess.py b/delft/sequenceLabelling/preprocess.py index 2f33e521..4245dc38 100644 --- a/delft/sequenceLabelling/preprocess.py +++ b/delft/sequenceLabelling/preprocess.py @@ -203,7 +203,7 @@ def transform(self, X, extend=False): out.append([0] * features_count) features_vector_padded, _ = pad_sequences(features_vector, [0] * features_count) - output = np.asarray(features_vector_padded) + output = np.asarray(features_vector_padded, dtype='object') return output @@ -757,13 +757,13 @@ def pad_sequence(self, char_ids, labels=None, label_indices=False): labels_final = None if labels: labels_padded, _ = pad_sequences(labels, 0) - labels_final = np.asarray(labels_padded) + labels_final = np.asarray(labels_padded, dtype='object') if not label_indices: labels_final = dense_to_one_hot(labels_final, len(self.vocab_tag), nlevels=2) #if self.return_chars: char_ids, word_lengths = pad_sequences(char_ids, pad_tok=0, nlevels=2, max_char_length=self.max_char_length) - char_ids = np.asarray(char_ids) + char_ids = np.asarray(char_ids, dtype='object') return [char_ids], labels_final #else: # return labels_final diff --git a/delft/sequenceLabelling/reader.py b/delft/sequenceLabelling/reader.py index 32c106d2..aa8fe7a9 100644 --- a/delft/sequenceLabelling/reader.py +++ b/delft/sequenceLabelling/reader.py @@ -499,7 +499,10 @@ def load_data_crf_string(crfString): #print('sents:', len(sents)) #print('featureSets:', len(featureSets)) - return sents, featureSets + return ( + np.asarray(sents, dtype='object'), + np.asarray(featureSets, 
dtype='object') + ) def _translate_tags_grobid_to_IOB(tag): @@ -735,8 +738,8 @@ def load_data_and_labels_ontonotes(ontonotesRoot, lang='en'): total_tokens += len(sentence) print('nb total tokens:', total_tokens) - final_tokens = np.asarray(tokens) - final_label = np.asarray(labels) + final_tokens = np.asarray(tokens, dtype=object) + final_label = np.asarray(labels, dtype=object) return final_tokens, final_label diff --git a/delft/sequenceLabelling/wrapper.py b/delft/sequenceLabelling/wrapper.py index f2b2d802..538e483d 100644 --- a/delft/sequenceLabelling/wrapper.py +++ b/delft/sequenceLabelling/wrapper.py @@ -151,8 +151,9 @@ def train(self, x_train, y_train, f_train=None, x_valid=None, y_valid=None, f_va # This trick avoid an exception being through when the --multi-gpu approach is used on a single GPU system. # It might be removed with TF 2.10 https://github.com/tensorflow/tensorflow/issues/50487 - import atexit - atexit.register(strategy._extended._collective_ops._pool.close) # type: ignore + if version.parse(tf.__version__) < version.parse('2.10.0'): + import atexit + atexit.register(strategy._extended._collective_ops._pool.close) # type: ignore with strategy.scope(): self.train_(x_train, y_train, f_train, x_valid, y_valid, f_valid, incremental, callbacks) @@ -219,8 +220,9 @@ def train_nfold(self, x_train, y_train, x_valid=None, y_valid=None, f_train=None # This trick avoid an exception being through when the --multi-gpu approach is used on a single GPU system. # It might be removed with TF 2.10 https://github.com/tensorflow/tensorflow/issues/50487 - import atexit - atexit.register(strategy._extended._collective_ops._pool.close) # type: ignore + if version.parse(tf.__version__) < version.parse('2.10.0'): + import atexit + atexit.register(strategy._extended._collective_ops._pool.close) # type: ignore with strategy.scope(): self.train_nfold_(x_train, y_train, x_valid, y_valid, f_train, f_valid, incremental, callbacks) diff --git a/delft/textClassification/reader.py b/delft/textClassification/reader.py index 46b1621e..22938428 100644 --- a/delft/textClassification/reader.py +++ b/delft/textClassification/reader.py @@ -1,12 +1,53 @@ +import csv + import numpy as np -import xml import gzip import json -from xml.sax import make_parser, handler import pandas as pd from delft.utilities.numpy import shuffle_triple_with_view +def load_texts_and_classes_generic(filepath, text_index: int, classes_indexes: list): + """ + Load texts and classes from a file in the following simple tab-separated format: + + id_0 text_0 class_00 ... class_n0 + id_1 text_1 class_01 ... class_n1 + ... + id_m text_m class_0m ... 
class_nm + + text has no newline (EOL) and no tab + + Returns: + tuple(numpy array, numpy array): texts and classes + + """ + x = [] + y = [] + + delimiter = "\t" + if filepath.endswith(".csv"): + delimiter = "," + + with open(filepath) as f: + first = True + tsvreader = csv.reader(f, delimiter=delimiter) + for line in tsvreader: + if len(line) == 0: + continue + + classes = [line[i] for i in classes_indexes] if classes_indexes is not None else None + if first: + print("Sample input", "x: ", line[text_index], "y: ", classes) + first = False + x.append(line[text_index]) + + if classes_indexes is not None: + y.append(classes) + + return np.asarray(x, dtype=object), np.asarray(y, dtype=object) + + def load_texts_and_classes(filepath): """ Load texts and classes from a file in the following simple tab-separated format: @@ -66,7 +107,7 @@ def load_texts_and_classes_pandas(filepath): classes = df.iloc[:,2:] classes_list = classes.values.tolist() - return np.asarray(texts_list), np.asarray(classes_list) + return np.asarray(texts_list, dtype=object), np.asarray(classes_list, dtype=object) def load_texts_pandas(filepath): @@ -93,7 +134,7 @@ def load_texts_pandas(filepath): for j in range(0, df.shape[0]): texts_list.append(df.iloc[j,1]) - return np.asarray(texts_list) + return np.asarray(texts_list, dtype=object) def load_citation_sentiment_corpus(filepath): @@ -144,7 +185,7 @@ def load_citation_sentiment_corpus(filepath): polarity.append(0) polarities.append(polarity) - return np.asarray(texts), np.asarray(polarities) + return np.asarray(texts, dtype=object), np.asarray(polarities, dtype=object) def load_citation_intent_corpus(filepath): """ @@ -242,7 +283,7 @@ def map_boolean(x): # otherwise we have the list of datatypes, and optionally subtypes and leaf datatypes datatypes = df.iloc[:,2] datatypes_list = datatypes.values.tolist() - datatypes_list = np.asarray(datatypes_list) + datatypes_list = np.asarray(datatypes_list, dtype=object) datatypes_list_lower = np.char.lower(datatypes_list) list_classes_datatypes = np.unique(datatypes_list_lower) datatypes_final = normalize_classes(datatypes_list_lower, list_classes_datatypes) @@ -254,7 +295,7 @@ def map_boolean(x): df = df[~df.datatype.str.contains("no_dataset")] datasubtypes = df.iloc[:,3] datasubtypes_list = datasubtypes.values.tolist() - datasubtypes_list = np.asarray(datasubtypes_list) + datasubtypes_list = np.asarray(datasubtypes_list, dtype=object) datasubtypes_list_lower = np.char.lower(datasubtypes_list) list_classes_datasubtypes = np.unique(datasubtypes_list_lower) datasubtypes_final = normalize_classes(datasubtypes_list_lower, list_classes_datasubtypes) @@ -272,10 +313,10 @@ def map_boolean(x): ''' if df.shape[1] == 3: - return np.asarray(texts_list), datatypes_final, None, None, list_classes_datatypes.tolist(), None, None + return np.asarray(texts_list, dtype=object), datatypes_final, None, None, list_classes_datatypes.tolist(), None, None #elif df.shape[1] == 4: else: - return np.asarray(texts_list), datatypes_final, datasubtypes_final, None, list_classes_datatypes.tolist(), list_classes_datasubtypes.tolist(), None + return np.asarray(texts_list, dtype=object), datatypes_final, datasubtypes_final, None, list_classes_datatypes.tolist(), list_classes_datasubtypes.tolist(), None ''' else: return np.asarray(texts_list), datatypes_final, datasubtypes_final, leafdatatypes_final, list_classes_datatypes.tolist(), list_classes_datasubtypes.tolist(), list_classes_leafdatatypes.tolist() diff --git a/delft/textClassification/wrapper.py 
b/delft/textClassification/wrapper.py index abcce915..b44b633d 100644 --- a/delft/textClassification/wrapper.py +++ b/delft/textClassification/wrapper.py @@ -1,5 +1,7 @@ import os +from packaging import version + from delft.sequenceLabelling.trainer import LogLearningRateCallback # ask tensorflow to be quiet and not print hundred lines of logs from delft.utilities.misc import print_parameters @@ -35,7 +37,7 @@ from delft.textClassification.models import predict_folds from delft.textClassification.data_generator import DataGenerator -from delft.utilities.Transformer import Transformer, TRANSFORMER_CONFIG_FILE_NAME, DEFAULT_TRANSFORMER_TOKENIZER_DIR +from delft.utilities.Transformer import TRANSFORMER_CONFIG_FILE_NAME, DEFAULT_TRANSFORMER_TOKENIZER_DIR from delft.utilities.Embeddings import Embeddings, load_resource_registry @@ -43,9 +45,7 @@ from sklearn.model_selection import train_test_split import transformers -transformers.logging.set_verbosity(transformers.logging.ERROR) - -from tensorflow.keras.utils import plot_model +transformers.logging.set_verbosity(transformers.logging.ERROR) class Classifier(object): @@ -138,10 +138,26 @@ def __init__(self, class_weights=class_weights, multiprocessing=multiprocessing) - def train(self, x_train, y_train, vocab_init=None, incremental=False, callbacks=None): + def train(self, x_train, y_train, vocab_init=None, incremental=False, callbacks=None, multi_gpu=False): + if multi_gpu: + strategy = tf.distribute.MirroredStrategy() + print('Running with multi-gpu. Number of devices: {}'.format(strategy.num_replicas_in_sync)) + + # This trick avoids an exception being thrown when the --multi-gpu approach is used on a single GPU system. + # It might be removed with TF 2.10 https://github.com/tensorflow/tensorflow/issues/50487 + if version.parse(tf.__version__) < version.parse('2.10.0'): + import atexit + atexit.register(strategy._extended._collective_ops._pool.close) # type: ignore + + with strategy.scope(): + self.train_(x_train, y_train, vocab_init, incremental, callbacks) + else: + self.train_(x_train, y_train, vocab_init, incremental, callbacks) + + def train_(self, x_train, y_train, vocab_init=None, incremental=False, callbacks=None): if incremental: - if self.model == None and self.models == None: + if self.model is None and self.models is None: print("error: you must load a model first for an incremental training") return print("Incremental training from loaded model", self.model_config.model_name) @@ -178,6 +194,7 @@ def train(self, x_train, y_train, vocab_init=None, incremental=False, callbacks= callbacks_.append(LogLearningRateCallback(self.model)) # uncomment to plot graph + # from tensorflow.keras.utils import plot_model #plot_model(self.model, # to_file='data/models/textClassification/'+self.model_config.model_name+'_'+self.model_config.architecture+'.png') self.model.train_model( @@ -194,7 +211,23 @@ def train(self, x_train, y_train, vocab_init=None, incremental=False, callbacks= callbacks=callbacks) - def train_nfold(self, x_train, y_train, vocab_init=None, incremental=False, callbacks=None): + def train_nfold(self, x_train, y_train, vocab_init=None, incremental=False, callbacks=None, multi_gpu=False): + if multi_gpu: + strategy = tf.distribute.MirroredStrategy() + print('Running with multi-gpu. Number of devices: {}'.format(strategy.num_replicas_in_sync)) + + # This trick avoids an exception being thrown when the --multi-gpu approach is used on a single GPU system. 
+ # It might be removed with TF 2.10 https://github.com/tensorflow/tensorflow/issues/50487 + if version.parse(tf.__version__) < version.parse('2.10.0'): + import atexit + atexit.register(strategy._extended._collective_ops._pool.close) # type: ignore + + with strategy.scope(): + self.train_nfold_(x_train, y_train, vocab_init, incremental, callbacks) + else: + self.train_nfold_(x_train, y_train, vocab_init, incremental, callbacks) + + def train_nfold_(self, x_train, y_train, vocab_init=None, incremental=False, callbacks=None): if incremental: if self.models == None: print("error: you must load a model first for an incremental training") @@ -204,8 +237,23 @@ def train_nfold(self, x_train, y_train, vocab_init=None, incremental=False, call else: self.models = train_folds(x_train, y_train, self.model_config, self.training_config, self.embeddings, None, callbacks=callbacks) + def predict(self, texts, output_format='json', use_main_thread_only=False, batch_size=None, multi_gpu=False): + if multi_gpu: + strategy = tf.distribute.MirroredStrategy() + print('Running with multi-gpu. Number of devices: {}'.format(strategy.num_replicas_in_sync)) - def predict(self, texts, output_format='json', use_main_thread_only=False, batch_size=None): + # This trick avoids an exception being thrown when the --multi-gpu approach is used on a single GPU system. + # It might be removed with TF 2.10 https://github.com/tensorflow/tensorflow/issues/50487 + if version.parse(tf.__version__) < version.parse('2.10.0'): + import atexit + atexit.register(strategy._extended._collective_ops._pool.close) # type: ignore + + with strategy.scope(): + return self.predict_(texts, output_format, use_main_thread_only, batch_size) + else: + return self.predict_(texts, output_format, use_main_thread_only, batch_size) + + def predict_(self, texts, output_format='json', use_main_thread_only=False, batch_size=None): bert_data = False if self.transformer_name != None: bert_data = True @@ -217,7 +265,7 @@ def predict(self, texts, output_format='json', use_main_thread_only=False, batch print("---") if self.model_config.fold_number == 1: - if self.model != None: + if self.model is not None: predict_generator = DataGenerator(texts, None, batch_size=self.model_config.batch_size, maxlen=self.model_config.maxlen, list_classes=self.model_config.list_classes, @@ -227,7 +275,7 @@ def predict(self, texts, output_format='json', use_main_thread_only=False, batch else: raise (OSError('Could not find a model.')) else: - if self.models != None: + if self.models is not None: # just a warning: n classifiers using BERT layer for prediction might be heavy in term of model sizes predict_generator = DataGenerator(texts, None, batch_size=self.model_config.batch_size, @@ -272,7 +320,7 @@ def eval(self, x_test, y_test, use_main_thread_only=False): bert_data = True if self.model_config.fold_number == 1: - if self.model != None: + if self.model is not None: self.model.print_summary() test_generator = DataGenerator(x_test, None, batch_size=self.model_config.batch_size, maxlen=self.model_config.maxlen, list_classes=self.model_config.list_classes, @@ -423,7 +471,7 @@ def vectorize(index, size): ''' def save(self, dir_path='data/models/textClassification/'): - # create subfolder for the model if not already exists + # create a sub-folder for the model if it does not already exist directory = os.path.join(dir_path, self.model_config.model_name) if not os.path.exists(directory): os.makedirs(directory) @@ -432,20 +480,20 @@ 
print('model config file saved') if self.model_config.fold_number == 1: - if self.model != None: + if self.model is not None: self.model.save(os.path.join(directory, self.weight_file)) print('model saved') else: print('Error: model has not been built') else: - if self.models == None: - print('Error: nfolds models have not been built') + if self.models is None: + print('Error: n-folds models have not been built') else: # fold models having a transformer layers are already saved if self.model_config.transformer_name is None: for i in range(0, self.model_config.fold_number): self.models[i].save(os.path.join(directory, "model{0}_weights.hdf5".format(i))) - print('nfolds model saved') + print('n-folds model saved') # save pretrained transformer config and tokenizer if used in the model and if single fold (otherwise it is saved in the nfold process) if self.transformer_name is not None and self.model_config.fold_number == 1: diff --git a/delft/utilities/Utilities.py b/delft/utilities/Utilities.py index 18214134..95949b00 100644 --- a/delft/utilities/Utilities.py +++ b/delft/utilities/Utilities.py @@ -78,7 +78,10 @@ def split_data_and_labels(x, y, ratio): else: x2.append(x[i]) y2.append(y[i]) - return np.asarray(x1),np.asarray(y1),np.asarray(x2),np.asarray(y2) + return np.asarray(x1, dtype='object'),\ + np.asarray(y1, dtype='object'),\ + np.asarray(x2, dtype='object'),\ + np.asarray(y2, dtype='object') url_regex = re.compile(r"https?:\/\/[a-zA-Z0-9_\-\.]+(?:com|org|fr|de|uk|se|net|edu|gov|int|mil|biz|info|br|ca|cn|in|jp|ru|au|us|ch|it|nl|no|es|pl|ir|cz|kr|co|gr|za|tw|hu|vn|be|mx|at|tr|dk|me|ar|fi|nz)\/?\b") diff --git a/delft/utilities/misc.py b/delft/utilities/misc.py index 22a30298..8af19883 100644 --- a/delft/utilities/misc.py +++ b/delft/utilities/misc.py @@ -67,7 +67,7 @@ def print_parameters(model_config, training_config): if hasattr(model_config, 'use_ELMo'): print("use_ELMo: ", model_config.use_ELMo) - if hasattr(training_config, 'class_weights') and training_config.class_weights != None and hasattr(model_config, 'list_classes'): + if hasattr(training_config, 'class_weights') and training_config.class_weights is not None and hasattr(model_config, 'list_classes'): list_classes = model_config.list_classes weight_summary = "" for indx, class_name in enumerate(model_config.list_classes): diff --git a/tests/sequence_labelling/preprocess_test.py b/tests/sequence_labelling/preprocess_test.py index bb4d5f04..85eae2dd 100644 --- a/tests/sequence_labelling/preprocess_test.py +++ b/tests/sequence_labelling/preprocess_test.py @@ -78,7 +78,7 @@ def _to_dense(a: np.array): def all_close(a: np.array, b: np.array): - return np.allclose(_to_dense(a), _to_dense(b)) + return np.allclose(_to_dense(a).astype("float"), _to_dense(b).astype("float")) class TestFeaturesPreprocessor: @@ -95,7 +95,7 @@ def test_should_fit_single_value_feature(self): features_transformed = preprocessor.fit_transform(features_batch) features_length = len(preprocessor.features_indices) assert features_length == 1 - assert all_close(features_transformed, [[[1]]]) + assert all_close(features_transformed, np.array([[[1]]], dtype=object)) def test_should_fit_single_multiple_value_features(self): preprocessor = FeaturesPreprocessor() @@ -112,14 +112,14 @@ def test_should_fit_multiple_single_value_features(self): features_transformed = preprocessor.fit_transform(features_batch) features_length = len(preprocessor.features_indices) assert features_length == 2 - assert all_close(features_transformed, [[[1, 13]]]) + assert 
all_close(features_transformed, np.asarray([[[1, 13]]], dtype=object)) def test_should_transform_unseen_to_zero(self): preprocessor = FeaturesPreprocessor() features_batch = [[[FEATURE_VALUE_1]]] preprocessor.fit(features_batch) features_transformed = preprocessor.transform([[[FEATURE_VALUE_2]]]) - assert all_close(features_transformed, [[[0]]]) + assert all_close(features_transformed, np.asarray([[[0]]], dtype=object)) def test_should_select_features(self): preprocessor = FeaturesPreprocessor(features_indices=[1]) @@ -131,7 +131,7 @@ def test_should_select_features(self): features_transformed = preprocessor.fit_transform(features_batch) features_length = len(preprocessor.features_indices) assert features_length == 1 - assert all_close(features_transformed, [[[1], [2], [3]]]) + assert all_close(features_transformed, np.asarray([[[1], [2], [3]]], dtype=object)) def test_serialize_to_json(self, tmp_path): preprocessor = FeaturesPreprocessor(features_indices=[1])