Add textclassification application and few other things #169

Open

lfoppiano wants to merge 32 commits into master from features/text-classification

Changes from all commits (32 commits)
6fce63f
adding a new generic text classifier and fixing some compatibility is…
lfoppiano Mar 31, 2020
db99fdf
use the right shuffle method
lfoppiano Apr 10, 2020
b077909
avoid clean operations on cache when it's disable
lfoppiano Apr 10, 2020
cbd5f14
commented too much code
lfoppiano Apr 13, 2020
80fa82a
Merge branch 'master' into classification-fixes
lfoppiano Jun 23, 2021
86d0583
read tsv with a parser
lfoppiano Jul 9, 2021
1aa8026
use an advanced splitter
lfoppiano Jul 9, 2021
56cdf08
minor updates
lfoppiano Jul 12, 2021
e02bc6a
improve memory usage
lfoppiano Jul 16, 2021
6943f1d
enforce the use of quotes (temporary)
lfoppiano Aug 24, 2021
0992a29
Merge branch 'master' into classification-fixes
lfoppiano Sep 20, 2022
47c9b83
Merge branch 'master' into classification-fixes
lfoppiano Sep 20, 2022
b0cf986
cleanup after merging
lfoppiano Sep 20, 2022
19561a6
more cleanup
lfoppiano Sep 20, 2022
7ba819e
add column x and column y
lfoppiano Sep 21, 2022
e6bf93b
improve generalisation with a binary / multiclass classification
lfoppiano Sep 26, 2022
39332c0
remove early stop for bert
lfoppiano Sep 28, 2022
7014a08
limit to single class and some cosmetics
lfoppiano Sep 28, 2022
5e223ea
add output in csv when in not in json, y-indexes mandatory
lfoppiano Oct 24, 2022
f34eae3
Merge branch 'master' into classification-fixes
lfoppiano Dec 19, 2022
9da4fde
make reader selecting tsv or csv
lfoppiano Dec 19, 2022
371bff8
add patience in parameters
lfoppiano Dec 19, 2022
77a6be0
cleanup useless removed spaces
lfoppiano Dec 19, 2022
830e092
cleanup useless removed spaces 2
lfoppiano Dec 19, 2022
03bd7dd
fix tests
lfoppiano Dec 19, 2022
198ba69
missing parameters
lfoppiano Dec 20, 2022
175fe89
add classification from input file, cleanup
lfoppiano Dec 22, 2022
ee6eca0
add consistency in the application script
lfoppiano Jan 5, 2023
f6f37f2
Merge branch 'master' into features/text-classification
lfoppiano Jan 16, 2024
a4d13b4
enable multi-gpu and other parameters
lfoppiano Jan 17, 2024
4e6ed0d
put some order with the parameters
lfoppiano Jan 18, 2024
92d9573
fix multigpu tricks for tf < 2.10
lfoppiano Jan 18, 2024
2 changes: 1 addition & 1 deletion delft/applications/citationClassifier.py
@@ -122,7 +122,7 @@ def classify(texts, output_format, architecture="gru", embeddings_name=None, tra
    args = parser.parse_args()

    if args.action not in ('train', 'train_eval', 'classify'):
-        print('action not specifed, must be one of [train,train_eval,classify]')
+        print('action not specified, must be one of [train,train_eval,classify]')

    embeddings_name = args.embedding
    transformer = args.transformer
294 changes: 294 additions & 0 deletions delft/applications/textClassifier.py
@@ -0,0 +1,294 @@
import argparse
import sys
import time

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, OneHotEncoder

from delft.textClassification import Classifier
from delft.textClassification.models import architectures
from delft.textClassification.reader import load_texts_and_classes_generic
from delft.utilities.Utilities import t_or_f

pretrained_transformers_examples = ['bert-base-cased', 'bert-large-cased', 'allenai/scibert_scivocab_cased']

actions = ['train', 'train_eval', 'eval', 'classify']


def get_one_hot(y):
    label_encoder = LabelEncoder()
    integer_encoded = label_encoder.fit_transform(y)
    onehot_encoder = OneHotEncoder(sparse=False)
    integer_encoded = integer_encoded.reshape(len(integer_encoded), 1)
    y2 = onehot_encoder.fit_transform(integer_encoded)
    return y2



def configure(architecture, max_sequence_length_=-1, batch_size_=-1, max_epoch_=-1, patience_=-1, early_stop=True):
    batch_size = 256
    maxlen = 150 if max_sequence_length_ == -1 else max_sequence_length_
    patience = 5 if patience_ == -1 else patience_
    max_epoch = 60 if max_epoch_ == -1 else max_epoch_

    # default bert model parameters
    if architecture == "bert":
        batch_size = 32
        # early_stop = False
        # max_epoch = 3

    batch_size = batch_size_ if batch_size_ != -1 else batch_size

    return batch_size, maxlen, patience, early_stop, max_epoch


def train(model_name,
          architecture,
          input_file,
          embeddings_name,
          fold_count,
          transformer=None,
          x_index=0,
          y_indexes=[1],
          batch_size=-1,
          max_sequence_length=-1,
          patience=-1,
          incremental=False,
          learning_rate=None,
          multi_gpu=False,
          max_epoch=50,
          early_stop=True
          ):

    batch_size, maxlen, patience, early_stop, max_epoch = configure(
        architecture, max_sequence_length, batch_size, max_epoch, patience, early_stop=early_stop)

    print('loading ' + model_name + ' training corpus...')
    xtr, y = load_texts_and_classes_generic(input_file, x_index, y_indexes)

    list_classes = list(set([y_[0] for y_ in y]))

    model = Classifier(model_name,
                       architecture=architecture,
                       list_classes=list_classes,
                       max_epoch=max_epoch,
                       fold_number=fold_count,
                       patience=patience,
                       transformer_name=transformer,
                       use_roc_auc=True,
                       embeddings_name=embeddings_name,
                       early_stop=early_stop,
                       batch_size=batch_size,
                       maxlen=maxlen,
                       class_weights=None,
                       learning_rate=learning_rate)

    y_ = get_one_hot(y)

    if fold_count == 1:
        model.train(xtr, y_, incremental=incremental, multi_gpu=multi_gpu)
    else:
        model.train_nfold(xtr, y_, multi_gpu=multi_gpu)
    # saving the model
    model.save()


def eval(model_name, architecture, input_file, x_index=0, y_indexes=[1]):
    # model_name += model_name + '-' + architecture

    print('loading ' + model_name + ' evaluation corpus...')

    xtr, y = load_texts_and_classes_generic(input_file, x_index, y_indexes)
    print(len(xtr), 'evaluation sequences')

    model = Classifier(model_name, architecture=architecture)
    model.load()

    y_ = get_one_hot(y)

    model.eval(xtr, y_)


def train_and_eval(model_name, architecture, input_file, embeddings_name, fold_count, transformer=None,
                   x_index=0, y_indexes=[1], batch_size=-1,
                   max_sequence_length=-1, patience=-1, multi_gpu=False):
    # keyword arguments to match configure()'s parameter order (max_sequence_length_ comes before batch_size_)
    batch_size, maxlen, patience, early_stop, max_epoch = configure(
        architecture, max_sequence_length_=max_sequence_length, batch_size_=batch_size, patience_=patience)

    print('loading ' + model_name + ' corpus...')
    xtr, y = load_texts_and_classes_generic(input_file, x_index, y_indexes)

    list_classes = list(set([y_[0] for y_ in y]))

    y_one_hot = get_one_hot(y)

    model = Classifier(model_name, architecture=architecture, list_classes=list_classes, max_epoch=max_epoch,
                       fold_number=fold_count, patience=patience, transformer_name=transformer,
                       use_roc_auc=True, embeddings_name=embeddings_name, early_stop=early_stop,
                       batch_size=batch_size, maxlen=maxlen, class_weights=None)

    # segment train and eval sets
    x_train, x_test, y_train, y_test = train_test_split(xtr, y_one_hot, test_size=0.1)

    if fold_count == 1:
        model.train(x_train, y_train, multi_gpu=multi_gpu)
    else:
        model.train_nfold(x_train, y_train, multi_gpu=multi_gpu)

    model.eval(x_test, y_test)

    # saving the model
    model.save()


# classify a list of texts
def classify(model_name, architecture, texts, output_format='json'):
    model = Classifier(model_name, architecture=architecture)
    model.load()

    start_time = time.time()
    result = model.predict(texts, output_format)
    runtime = round(time.time() - start_time, 3)

    if output_format == 'json':
        result["runtime"] = runtime
    else:
        print("runtime: %s seconds " % (runtime))

    return result


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="General classification of text")

    parser.add_argument("action", help="the action", choices=actions)
    parser.add_argument("model", help="The name of the model")
    parser.add_argument("--fold-count", type=int, default=1)
    parser.add_argument("--input", type=str, required=True,
                        help="The file to be used for train, train_eval, eval and classify")
    parser.add_argument("--x-index", type=int, required=True,
                        help="Index of the column for the X value (assuming a TSV file)")
    parser.add_argument("--y-indexes", type=str, required=False,
                        help="Index(es) of the columns for the Y (classes), separated by commas, "
                             "without spaces (assuming a TSV file)")
    parser.add_argument("--architecture", default='gru', choices=architectures,
                        help="type of model architecture to be used, one of " + str(architectures))
    parser.add_argument(
        "--embedding", default='word2vec',
        help=(
            "The desired pre-trained word embeddings using their descriptions in the file"
            " embedding-registry.json."
            " Be sure to use here the same name as in the registry ('glove-840B', 'fasttext-crawl', 'word2vec'),"
            " and that the path in the registry to the embedding file is correct on your system."))

    parser.add_argument(
        "--transformer",
        default=None,
        help="The desired pre-trained transformer to be used in the selected architecture. " +
             "For local loading, use delft/resources-registry.json and be sure to use here the " +
             "same name as in the registry, e.g. " +
             str(pretrained_transformers_examples) +
             ", and that the path in the registry to the model is correct on your system. " +
             "Otherwise the HuggingFace transformers hub will be used to fetch the model, "
             "see https://huggingface.co/models for model names"
    )

    parser.add_argument("--max-sequence-length", type=int, default=-1,
                        help="max-sequence-length parameter to be used.")
    parser.add_argument("--batch-size", type=int, default=-1,
                        help="batch-size parameter to be used.")
    parser.add_argument("--patience", type=int, default=-1,
                        help="patience: number of extra epochs to perform after the best epoch "
                             "before stopping the training.")
    parser.add_argument("--learning-rate", type=float, default=None,
                        help="Initial learning rate")
    parser.add_argument("--incremental", action="store_true",
                        help="training is incremental, starting from an existing model if present")
    parser.add_argument("--max-epoch", type=int, default=-1,
                        help="Maximum number of epochs for training.")
    parser.add_argument("--early-stop", type=t_or_f, default=None,
                        help="Force early training termination when metric scores are not improving "
                             "after a number of epochs equal to the patience parameter.")

    parser.add_argument("--multi-gpu", default=False,
                        help="Enable the support for distributed computing (the batch size needs to be "
                             "set accordingly using --batch-size)",
                        action="store_true")

    args = parser.parse_args()

    embeddings_name = args.embedding
    input_file = args.input
    model_name = args.model
    transformer = args.transformer
    architecture = args.architecture
    x_index = args.x_index
    patience = args.patience
    batch_size = args.batch_size
    incremental = args.incremental
    max_sequence_length = args.max_sequence_length
    learning_rate = args.learning_rate
    max_epoch = args.max_epoch
    early_stop = args.early_stop
    multi_gpu = args.multi_gpu

    if args.action != "classify":
        if args.y_indexes is None:
            print("--y-indexes is mandatory")
            sys.exit(-1)
        y_indexes = [int(index) for index in args.y_indexes.split(",")]

        if len(y_indexes) > 1:
            print("At the moment only one class column is supported; using the first index only.")
            y_indexes = y_indexes[:1]

    if transformer is None and embeddings_name is None:
        # default word embeddings
        embeddings_name = "glove-840B"

    if args.action == 'train':
        train(model_name, architecture, input_file, embeddings_name, args.fold_count,
              transformer=transformer,
              x_index=x_index,
              y_indexes=y_indexes,
              batch_size=batch_size,
              incremental=incremental,
              max_sequence_length=max_sequence_length,
              patience=patience,
              learning_rate=learning_rate,
              max_epoch=max_epoch,
              early_stop=early_stop,
              multi_gpu=multi_gpu)

    elif args.action == 'eval':
        eval(model_name, architecture, input_file, x_index=x_index, y_indexes=y_indexes)

    elif args.action == 'train_eval':
        if args.fold_count < 1:
            raise ValueError("fold-count should be equal or more than 1")

        train_and_eval(model_name,
                       architecture,
                       input_file,
                       embeddings_name,
                       args.fold_count,
                       transformer=transformer,
                       x_index=x_index,
                       y_indexes=y_indexes,
                       batch_size=batch_size,
                       max_sequence_length=max_sequence_length,
                       patience=patience,
                       multi_gpu=multi_gpu)

    elif args.action == 'classify':
        lines, _ = load_texts_and_classes_generic(input_file, x_index, None)

        # classify() expects the architecture as its second argument
        result = classify(model_name, architecture, lines, "csv")

        result_binary = [np.argmax(line) for line in result]

        for x in result_binary:
            print(x)
# See https://github.com/tensorflow/tensorflow/issues/3388
# K.clear_session()
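For context, a minimal invocation of the new application could look like the following, assuming a tab-separated corpus toy-corpus.tsv with the text in column 0 and the class label in column 1 (the file names and column layout here are illustrative, not part of the PR):

    # train and evaluate on a 90/10 split of the input file
    python3 delft/applications/textClassifier.py train_eval toy-model --architecture gru \
        --input toy-corpus.tsv --x-index 0 --y-indexes 1

    # classify an unlabelled file with the saved model; one argmax class index is printed per input line
    python3 delft/applications/textClassifier.py classify toy-model --architecture gru \
        --input unlabelled.tsv --x-index 0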
2 changes: 1 addition & 1 deletion delft/sequenceLabelling/data_generator.py
@@ -332,7 +332,7 @@ def __data_generation(self, index):

        # to have input as sentence piece token index for transformer layer
        input_ids, token_type_ids, attention_mask, input_chars, input_features, input_labels, input_offsets = self.bert_preprocessor.tokenize_and_align_features_and_labels(
-            x_tokenized,
+            x_tokenized,
            batch_c,
            sub_f,
            batch_y,
6 changes: 3 additions & 3 deletions delft/sequenceLabelling/preprocess.py
@@ -203,7 +203,7 @@ def transform(self, X, extend=False):
                out.append([0] * features_count)

        features_vector_padded, _ = pad_sequences(features_vector, [0] * features_count)
-        output = np.asarray(features_vector_padded)
+        output = np.asarray(features_vector_padded, dtype='object')

        return output

@@ -757,13 +757,13 @@ def pad_sequence(self, char_ids, labels=None, label_indices=False):
        labels_final = None
        if labels:
            labels_padded, _ = pad_sequences(labels, 0)
-            labels_final = np.asarray(labels_padded)
+            labels_final = np.asarray(labels_padded, dtype='object')
            if not label_indices:
                labels_final = dense_to_one_hot(labels_final, len(self.vocab_tag), nlevels=2)

        #if self.return_chars:
        char_ids, word_lengths = pad_sequences(char_ids, pad_tok=0, nlevels=2, max_char_length=self.max_char_length)
-        char_ids = np.asarray(char_ids)
+        char_ids = np.asarray(char_ids, dtype='object')
        return [char_ids], labels_final
        #else:
        #    return labels_final
9 changes: 6 additions & 3 deletions delft/sequenceLabelling/reader.py
@@ -499,7 +499,10 @@ def load_data_crf_string(crfString):

    #print('sents:', len(sents))
    #print('featureSets:', len(featureSets))
-    return sents, featureSets
+    return (
+        np.asarray(sents, dtype='object'),
+        np.asarray(featureSets, dtype='object')
+    )


def _translate_tags_grobid_to_IOB(tag):
@@ -735,8 +738,8 @@ def load_data_and_labels_ontonotes(ontonotesRoot, lang='en'):
        total_tokens += len(sentence)
    print('nb total tokens:', total_tokens)

-    final_tokens = np.asarray(tokens)
-    final_label = np.asarray(labels)
+    final_tokens = np.asarray(tokens, dtype=object)
+    final_label = np.asarray(labels, dtype=object)

    return final_tokens, final_label

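As a side note, the dtype='object' / dtype=object additions in preprocess.py and reader.py line up with the stricter handling of ragged (variable-length) sequences in recent NumPy releases: building an array from rows of unequal length without an explicit object dtype was deprecated around NumPy 1.20 and raises an error from 1.24 on. A minimal sketch of the behaviour being guarded against (illustrative only, not code from the PR):

    import numpy as np

    ragged = [[1, 2, 3], [4, 5]]             # token sequences of different lengths
    # np.asarray(ragged)                     # ValueError on NumPy >= 1.24 (deprecation warning before)
    safe = np.asarray(ragged, dtype=object)  # 1-D array of Python lists, shape (2,)
    print(safe.shape)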
10 changes: 6 additions & 4 deletions delft/sequenceLabelling/wrapper.py
@@ -151,8 +151,9 @@ def train(self, x_train, y_train, f_train=None, x_valid=None, y_valid=None, f_va

            # This trick avoids an exception being thrown when the --multi-gpu approach is used on a single GPU system.
            # It might be removed with TF 2.10 https://github.com/tensorflow/tensorflow/issues/50487
-            import atexit
-            atexit.register(strategy._extended._collective_ops._pool.close)  # type: ignore
+            if version.parse(tf.__version__) < version.parse('2.10.0'):
+                import atexit
+                atexit.register(strategy._extended._collective_ops._pool.close)  # type: ignore

            with strategy.scope():
                self.train_(x_train, y_train, f_train, x_valid, y_valid, f_valid, incremental, callbacks)
@@ -219,8 +220,9 @@ def train_nfold(self, x_train, y_train, x_valid=None, y_valid=None, f_train=None

            # This trick avoids an exception being thrown when the --multi-gpu approach is used on a single GPU system.
            # It might be removed with TF 2.10 https://github.com/tensorflow/tensorflow/issues/50487
-            import atexit
-            atexit.register(strategy._extended._collective_ops._pool.close)  # type: ignore
+            if version.parse(tf.__version__) < version.parse('2.10.0'):
+                import atexit
+                atexit.register(strategy._extended._collective_ops._pool.close)  # type: ignore

            with strategy.scope():
                self.train_nfold_(x_train, y_train, x_valid, y_valid, f_train, f_valid, incremental, callbacks)
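The wrapper.py change gates the collective-ops cleanup hack behind a TensorFlow version check, on the assumption that TF 2.10 and later no longer need it (see the linked issue). A standalone sketch of the guarded pattern, assuming the wrapper already imports tensorflow as tf and version from packaging; the strategy construction below is illustrative and may differ from the strategy the wrapper actually builds for --multi-gpu, and the attribute chain is the private-API workaround copied from the diff above:

    from packaging import version
    import tensorflow as tf

    strategy = tf.distribute.MultiWorkerMirroredStrategy()

    # Only needed on TF < 2.10: close the collective-ops thread pool at exit
    # to avoid an exception when --multi-gpu runs on a single-GPU machine.
    # https://github.com/tensorflow/tensorflow/issues/50487
    if version.parse(tf.__version__) < version.parse('2.10.0'):
        import atexit
        atexit.register(strategy._extended._collective_ops._pool.close)  # type: ignore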