From aacb11214f4bcc029d307382832aab0199d81d4e Mon Sep 17 00:00:00 2001
From: György Márk Kis
Date: Mon, 1 Nov 2021 17:01:01 +0100
Subject: [PATCH] Create bert.py

---
 bert.py | 215 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 215 insertions(+)
 create mode 100644 bert.py

diff --git a/bert.py b/bert.py
new file mode 100644
index 0000000..a765c96
--- /dev/null
+++ b/bert.py
@@ -0,0 +1,215 @@
+"""
+Obtain BERT contextual embeddings for sentences.
+"""
+
+import argparse
+
+import warnings
+warnings.simplefilter(action='ignore', category=FutureWarning)
+
+import torch
+from transformers import AutoTokenizer, AutoModel
+
+import pandas as pd
+import numpy as np
+import os
+
+np.set_printoptions(threshold=np.inf)
+
+
+def main():
+    """Main."""
+    args = get_args()
+
+    INPUT = args.input  # 'etl_nothree'
+
+    DIVISOR = args.divisor  # 3 or 200 :)
+
+    # different setups might need different solutions here
+    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+    os.environ["TOKENIZERS_PARALLELISM"] = "false"
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    tokenizer = AutoTokenizer.from_pretrained('SZTAKI-HLT/hubert-base-cc')
+    model = AutoModel.from_pretrained('SZTAKI-HLT/hubert-base-cc')
+
+    # read the corpus from a TSV file
+    corpus = pd.read_csv(f"corpora_and_labels/{INPUT}.tsv", sep='\t')
+    print(corpus)
+
+    # tokenize the corpus
+    # note: it does not check BERT's maximum sequence length (512 tokens) at this time
+    tokenized = corpus["text"].apply(lambda x: tokenizer.encode(x, add_special_tokens=True))
+    print(tokenized)
+
+    # create padding and attention masks
+    max_len = max(len(i) for i in tokenized.values)
+    padded = np.array([i + [0] * (max_len - len(i)) for i in tokenized.values])
+    print(padded)
+
+    attention_mask = np.where(padded != 0, 1, 0)
+    print(attention_mask)
+
+    # for computationally weaker setups, batch execution is the only way to process the texts
+    # adjust DIVISOR if a different (approximate) batch size is needed
+    n_batches = (len(corpus) // DIVISOR) + 1
+    print('Number of batches:', n_batches)
+    splitpadded = np.array_split(padded, n_batches)
+    splitmask = np.array_split(attention_mask, n_batches)
+
+    model = model.to(device)
+
+    DIMS = 768  # 768 at most (since this is a BERT-Base model)
+    if args.first_word_pooling:
+        featuresfinal = np.empty((0, DIMS), dtype='float32')
+    elif args.all_word_pooling:
+        featuresfinal = np.empty((0, max_len, DIMS), dtype='float32')
+    elif args.mean_pooling:
+        featuresfinal = np.empty((0, DIMS), dtype='float32')
+
+    # take batches of tokenized texts
+    # to extract BERT's last hidden states, i.e. contextual word embeddings
+    #
+    # XXX handling attention_mask was erroneous here,
+    # because array_split() gives variable-length chunks!
+    # now: zip() ensures that text and attention data are taken strictly in parallel
+    for count, (batch, mask) in enumerate(zip(splitpadded, splitmask)):
+        batch_cnt = count + 1
+        print(f'Batch #{batch_cnt}')
+
+        input_batch = torch.tensor(batch).to(torch.long)
+        mask_batch = torch.tensor(mask)
+        print(input_batch)
+        print(mask_batch)
+        print('Batches established!')
+
+        # put the data onto the GPU (or leave it on the CPU)
+        input_batch = input_batch.to(device)
+        mask_batch = mask_batch.to(device)
+        print('Lengths', input_batch.size(0), mask_batch.size(0))
+
+        # no_grad ensures there is no gradient update in the model,
+        # as we are not training it here
+        with torch.no_grad():
+            print('Model is running on', model.device)
+            last_hidden_states = model(input_batch, attention_mask=mask_batch)
+            print('Hidden states created for batch', batch_cnt)
+
+        # tensor dimensions: 0=sentences, 1=tokens, 2=coordinates
+        if args.first_word_pooling:
+            # we take the vector of the _first_ token (i.e. [CLS]), seriously :)
+            features = last_hidden_states[0][:, 0, 0:DIMS].cpu().numpy()
+        elif args.all_word_pooling:
+            features = last_hidden_states[0][:, :, 0:DIMS].cpu().numpy()
+        elif args.mean_pooling:
+            lhs = last_hidden_states[0][:, :, 0:DIMS].cpu().numpy()
+            # note: this averages over all max_len positions, padding included
+            features = np.mean(lhs, axis=1)
+
+        if args.verbose:
+            print(features.shape)
+            print(features)
+
+        featuresfinal = np.append(featuresfinal, features, axis=0)
+
+        print('Finished with batch', batch_cnt)
+
+    # output + labels are saved as separate files
+    labels = corpus["topik"]
+
+    np.save(f"featuresfinal_{INPUT}", featuresfinal)
+    np.save(f"labels_{INPUT}", labels)
+
+    if not args.verbose:
+
+        print(list(featuresfinal))
+        print(labels)
+
+    else:
+        # note: the per-token printout and the distance check below
+        # assume --all-word-pooling, i.e. one vector per token position
+
+        print()
+        print('Vectors')
+
+        for padd, feat, label in zip(padded, featuresfinal, labels):
+            print()
+            for p, f in zip(padd, feat):
+                # XXX why do we have additional whitespaces?
+                token = tokenizer.decode(int(p)).replace(' ', '')
+                print(f'{label}\t{p}\t"{token}"\t{f}')
+
+        print()
+        print('Distances')
+
+        import itertools as it
+        from scipy.spatial import distance
+
+        SAMPLEWORD = 2765  # "vár"
+
+        # assumes that SAMPLEWORD occurs exactly once in each sentence
+        for i, j in it.combinations(range(len(padded)), 2):
+            ap, af, al = padded[i], featuresfinal[i], labels[i]
+            bp, bf, bl = padded[j], featuresfinal[j], labels[j]
+            a = [x[1] for x in zip(ap, af) if x[0] == SAMPLEWORD]
+            b = [x[1] for x in zip(bp, bf) if x[0] == SAMPLEWORD]
+            dist = distance.cosine(a, b)
+            issame = al == bl
+            issamemark = '==' if issame else '<>'
+            anom = '!ERR!' if dist >= 0.08 and issame else ''
+            print(f'#{i} L{al} {issamemark} #{j} L{bl} = {dist} {anom}')
+
+    # XXX add clustering at the end -> polysemy will be solved :)
+
+
+def get_args():
+    """Parse command line arguments."""
+    parser = argparse.ArgumentParser(
+        description=__doc__,
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument(
+        '-i', '--input',
+        help='input corpus (without ".tsv")',
+        required=True,
+        type=str,
+        default=argparse.SUPPRESS
+    )
+    parser.add_argument(
+        '-d', '--divisor',
+        help='approximate batch size: the corpus is split into len(corpus) // DIVISOR + 1 batches',
+        required=True,
+        type=int,
+        default=argparse.SUPPRESS
+    )
+
+    group = parser.add_mutually_exclusive_group(required=True)
+    group.add_argument(
+        '-a', '--all-word-pooling',
+        help='all words = sentence representation',
+        action='store_true'
+    )
+    group.add_argument(
+        '-f', '--first-word-pooling',
+        help='first word = sentence representation',
+        action='store_true'
+    )
+    group.add_argument(
+        '-m', '--mean-pooling',
+        help='average of all words = sentence representation',
+        action='store_true'
+    )
+
+    parser.add_argument(
+        '-v', '--verbose',
+        help='verbose output for investigation',
+        action='store_true'
+    )
+
+    return parser.parse_args()
+
+
+if __name__ == '__main__':
+    main()
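
Usage note: the script writes two NumPy files per corpus, named after the np.save() calls above. A minimal sketch of loading them downstream, assuming the example corpus name 'etl_nothree' mentioned in the code (the real name is whatever was passed to -i/--input):

    # e.g. after running: python bert.py -i etl_nothree -d 200 -m
    import numpy as np

    features = np.load("featuresfinal_etl_nothree.npy")
    # shape: (n_sentences, 768) with -f or -m, (n_sentences, max_len, 768) with -a
    labels = np.load("labels_etl_nothree.npy", allow_pickle=True)  # the "topik" column
    print(features.shape, labels.shape)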
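
One caveat on -m/--mean-pooling: the average in the patch runs over all max_len positions, so padding vectors are included. A possible follow-up (not part of this patch) is to weight the average by the attention mask; a sketch of such a masked mean, under the same array shapes the script already uses (masked_mean is a hypothetical helper):

    import numpy as np

    def masked_mean(last_hidden, attention_mask):
        """last_hidden: (batch, max_len, dims) floats; attention_mask: (batch, max_len) of 0/1."""
        mask = attention_mask[:, :, None].astype(last_hidden.dtype)  # broadcast mask over dims
        summed = (last_hidden * mask).sum(axis=1)                    # sum only the real tokens
        counts = np.clip(mask.sum(axis=1), 1.0, None)                # avoid division by zero
        return summed / counts                                       # (batch, dims)

Inside the batch loop this would replace np.mean(lhs, axis=1) with masked_mean(lhs, mask).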