Commit aacb112: Create bert.py
escrogar authored Nov 1, 2021
1 parent f925bec

bert.py (215 additions, 0 deletions)
"""
Obtain BERT contextual embeddings for sentences.
"""

import argparse

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import torch
from transformers import AutoTokenizer, AutoModel

import pandas as pd
import numpy as np
import os

np.set_printoptions(threshold=np.inf)


def main():
"""Main."""
args = get_args()

INPUT = args.input # 'etl_nothree'

DIVISOR = args.divisor # 3 or 200 :)

# different setups might need different solutions here
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained('SZTAKI-HLT/hubert-base-cc')
model = AutoModel.from_pretrained('SZTAKI-HLT/hubert-base-cc')


# read corpus from tsv
corpus = pd.read_csv(f"corpora_and_labels/{INPUT}.tsv", sep='\t')
print(corpus)

# tokenize corpus
    # note: sequence length is not checked here;
    # BERT Base models such as hubert-base-cc accept at most 512 tokens
tokenized = corpus["text"].apply((lambda x: tokenizer.encode(x, add_special_tokens=True)))
print(tokenized)

# create padding and attention masks
max_len = 0
for i in tokenized.values:
if len(i) > max_len:
max_len = len(i)
padded = np.array([i + [0]*(max_len-len(i)) for i in tokenized.values])
print(padded)

attention_mask = np.where(padded != 0, 1, 0)
print(attention_mask)
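    # e.g. with max_len 5, a tokenized sentence [101, 5055, 102]
    # (ids illustrative; 101/102 stand for [CLS]/[SEP]) yields
    #   padded row:     [101, 5055, 102, 0, 0]
    #   attention mask: [  1,    1,   1, 0, 0]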

    # for computationally weaker setups, batch execution is the only way to process the texts
    # manipulate the floor divisor if a different batch size is needed
    n_batches = (len(corpus) // DIVISOR) + 1
    print('Number of batches:', n_batches)
    splitpadded = np.array_split(padded, n_batches)
    splitmask = np.array_split(attention_mask, n_batches)


last_hidden_states = []
model = model.to(device)

DIMS = 768 # 768 at most (because of using BERT Base)
if args.first_word_pooling:
featuresfinal = np.empty((0, DIMS), dtype='float32')
elif args.all_word_pooling:
featuresfinal = np.empty((0, max_len, DIMS), dtype='float32')
elif args.mean_pooling:
featuresfinal = np.empty((0, DIMS), dtype='float32')
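    # resulting shapes: (n_sents, DIMS) for first-word and mean pooling,
    # (n_sents, max_len, DIMS) for all-word pooling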

# take batches of tokenized texts
# to extract BERT's last hidden states, i.e. contextual word embeddings
#
# XXX handling attention_mask was erroneous here,
# because array_split() gives variable length!
# now: zip() ensures that text and attention data is taken strictly in parallel
for count, (batch, mask) in enumerate(zip(splitpadded, splitmask)):
batch_cnt = count + 1
print(f'Batch #{batch_cnt}')

input_batch = torch.tensor(batch).to(torch.long)
mask_batch = torch.tensor(mask)
print(input_batch)
print(mask_batch)
print('Batches established!')

# put data onto GPU
input_batch = input_batch.to(device)
mask_batch = mask_batch.to(device)
print('Lengths', input_batch.size(0), mask_batch.size(0))

        # no_grad disables gradient tracking:
        # we only run inference here, there is no training involved
with torch.no_grad():
print('Model is running on', model.device)
last_hidden_states = model(input_batch, attention_mask=mask_batch)
print('Hidden states created for batch', batch_cnt)

# tensor dimensions: 0=sents, 1=words, 2=coords
if args.first_word_pooling:
            # we take the vector of the _first_ token, seriously :)
            # (with add_special_tokens this is the [CLS] token)
            features = last_hidden_states[0][:, 0, 0:DIMS].cpu().numpy()
elif args.all_word_pooling:
features = last_hidden_states[0][:, :, 0:DIMS].cpu().numpy()
elif args.mean_pooling:
lhs = last_hidden_states[0][:, :, 0:DIMS].cpu().numpy()
            features = np.mean(lhs, axis=1)  # average over words
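            # note: padding positions are included in this mean;
            # masking them out first would average real tokens only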

if args.verbose:
print(features.shape)
print(features)

featuresfinal = np.append(featuresfinal, features, axis=0)

print('Finished with batch', batch_cnt)

# output + labels are saved as separate files
labels = corpus["topik"]

np.save(f"featuresfinal_{INPUT}", featuresfinal)
np.save(f"labels_{INPUT}", labels)

if not args.verbose:

print(list(featuresfinal))
print(labels)

else:

print()
print('Vectors')

for padd, feat, label in zip(padded, featuresfinal, labels):
print()
for p, f in zip(padd, feat):
# XXX why do we have additional whitespaces?
token = ''.join(filter(lambda x: x != " ", tokenizer.decode(int(p))))
print(f'{label}\t{p}\t"{token}"\t{f}')

print()
print('Distances')

import itertools as it
from scipy.spatial import distance

SAMPLEWORD = 2765 # "vár"
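        # note: this assumes SAMPLEWORD occurs exactly once per sentence;
        # with zero or multiple hits, a/b are not single 768-dim vectors
        # and the cosine distance is meaningless or raises an error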

for i, j in it.combinations(range(len(padded)), 2):
ap, af, al = padded[i], featuresfinal[i], labels[i]
bp, bf, bl = padded[j], featuresfinal[j], labels[j]
a = [x[1] for x in zip(ap, af) if x[0] == SAMPLEWORD]
b = [x[1] for x in zip(bp, bf) if x[0] == SAMPLEWORD]
dist = distance.cosine(a, b)
issame = al == bl
issamemark = '==' if issame else '<>'
anom = '!ERR!' if dist >= 0.08 and issame else ''
print(f'#{i} L{al} {issamemark} #{j} L{bl} = {dist} {anom}')

# XXX add clustering at the end -> polysemy will be solved :)
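        # a minimal sketch of that clustering idea (hypothetical, not part
        # of this commit), assuming scikit-learn and all-word pooling:
        # cluster the SAMPLEWORD vectors so that each cluster ~ one sense
        #
        #     from sklearn.cluster import KMeans
        #     occ = np.array([f for padd, feat in zip(padded, featuresfinal)
        #                     for p, f in zip(padd, feat) if p == SAMPLEWORD])
        #     senses = KMeans(n_clusters=2, n_init=10).fit_predict(occ)
        #     print(senses)  # one sense id per occurrence of "vár"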


def get_args():
"""Parse command line arguments."""
parser = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
'-i', '--input',
help='input corpus (without ".tsv")',
required=True,
type=str,
default=argparse.SUPPRESS
)
parser.add_argument(
'-d', '--divisor',
        help='floor divisor for batching: corpus is split into (len(corpus) // divisor) + 1 batches',
required=True,
type=int,
default=argparse.SUPPRESS
)

group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
'-a', '--all-word-pooling',
help='all words = sentence representation',
action='store_true'
)
group.add_argument(
'-f', '--first-word-pooling',
help='first word = sentence representation',
action='store_true'
)
group.add_argument(
'-m', '--mean-pooling',
help='average of all words = sentence representation',
action='store_true'
)

parser.add_argument(
'-v', '--verbose',
help='verbose output for investigation',
action='store_true'
)

return parser.parse_args()


if __name__ == '__main__':
main()
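
# example invocation (assuming corpora_and_labels/etl_nothree.tsv exists):
#   python bert.py --input etl_nothree --divisor 200 --mean-pooling --verbose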
