From 7efd244e99c617bee43139575761cc13244fe866 Mon Sep 17 00:00:00 2001 From: Pengzhi Gao Date: Fri, 28 Feb 2020 14:33:41 -0500 Subject: [PATCH] Add sequence_tagging example --- examples/README.md | 4 + examples/sequence_tagging/.gitignore | 2 + examples/sequence_tagging/README.md | 29 ++ examples/sequence_tagging/config.py | 68 +++++ examples/sequence_tagging/conll_reader.py | 266 ++++++++++++++++++ examples/sequence_tagging/conll_writer.py | 41 +++ examples/sequence_tagging/conlleval | 315 ++++++++++++++++++++++ examples/sequence_tagging/ner.py | 256 ++++++++++++++++++ examples/sequence_tagging/scores.py | 28 ++ 9 files changed, 1009 insertions(+) create mode 100644 examples/sequence_tagging/.gitignore create mode 100644 examples/sequence_tagging/README.md create mode 100644 examples/sequence_tagging/config.py create mode 100644 examples/sequence_tagging/conll_reader.py create mode 100644 examples/sequence_tagging/conll_writer.py create mode 100644 examples/sequence_tagging/conlleval create mode 100644 examples/sequence_tagging/ner.py create mode 100644 examples/sequence_tagging/scores.py diff --git a/examples/README.md b/examples/README.md index 72fe20c39..9e81744ed 100644 --- a/examples/README.md +++ b/examples/README.md @@ -26,6 +26,7 @@ More examples are continuously added... * [bert](./bert): Pre-trained BERT model for text representation * [sentence_classifier](./sentence_classifier): Basic CNN-based sentence classifier +* [sequence_tagging](./sequence_tagging): BiLSTM-CNN model for Named Entity Recognition (NER) * [xlnet](./xlnet): Pre-trained XLNet model for text classification/regression --- @@ -49,3 +50,6 @@ More examples are continuously added... * [sentence_classifier](./sentence_classifier): Basic CNN-based sentence classifier * [xlnet](./xlnet): Pre-trained XLNet model for text classification/regression +### Sequence Tagging ### + +* [sequence_tagging](./sequence_tagging): BiLSTM-CNN model for Named Entity Recognition (NER) diff --git a/examples/sequence_tagging/.gitignore b/examples/sequence_tagging/.gitignore new file mode 100644 index 000000000..fe8904f0f --- /dev/null +++ b/examples/sequence_tagging/.gitignore @@ -0,0 +1,2 @@ +/data/ +/tmp/ diff --git a/examples/sequence_tagging/README.md b/examples/sequence_tagging/README.md new file mode 100644 index 000000000..146468d7e --- /dev/null +++ b/examples/sequence_tagging/README.md @@ -0,0 +1,29 @@ +# Sequence tagging on CoNLL-2003 # + +This example builds a bi-directional LSTM-CNN model for Named Entity Recognition (NER) task and trains on CoNLL-2003 data. Model and training are described in +>[(Ma et al.) End-to-end Sequence Labeling via Bi-directional LSTM-CNNs-CRF](http://www.cs.cmu.edu/~xuezhem/publications/P16-1101.pdf) + +The top CRF layer is not used here. + +## Dataset ## + +The code uses [CoNLL-2003 NER dataset](https://www.clips.uantwerpen.be/conll2003/ner/) (English). Please put data files (e.g., `eng.train.bio.conll`) under `./data` folder. Pretrained Glove word embeddings can also be used (set `load_glove=True` in [config.py](./config.py)). The Glove file should also be under `./data`. + +## Run ## + +To train a NER model, + +```bash +python ner.py +``` + +The model will begin training, and will evaluate on the validation data periodically, and evaluate on the test data after the training is done. + +## Results ## + +The results on validation and test data is: + +| | precision | recall | F1 | +|-------|----------|----------|----------| +| valid | 91.98 | 93.30 | 92.63 | +| test | 87.39 | 89.78 | 88.57 | diff --git a/examples/sequence_tagging/config.py b/examples/sequence_tagging/config.py new file mode 100644 index 000000000..41912830c --- /dev/null +++ b/examples/sequence_tagging/config.py @@ -0,0 +1,68 @@ +# Copyright 2019 The Texar Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +NER config. +""" + +num_epochs = 200 +char_dim = 30 +embed_dim = 100 +hidden_size = 256 +tag_space = 128 +keep_prob = 0.5 +batch_size = 16 +encoder = None +load_glove = True + +emb = { + "name": "embedding", + "dim": embed_dim, + "dropout_rate": 0.33, + "dropout_strategy": 'item' +} + +char_emb = { + "name": "char_embedding", + "dim": char_dim +} + +conv = { + "out_channels": 30, + "kernel_size": [3], + "conv_activation": "Tanh", + "num_dense_layers": 0, + "dropout_rate": 0. +} + +cell = { + "type": "LSTMCell", + "kwargs": { + "hidden_size": hidden_size, + }, + "dropout": { + "output_keep_prob": keep_prob + }, + "num_layers": 1 +} + +opt = { + "optimizer": { + "type": "SGD", + "kwargs": { + "lr": 0.1, + "momentum": 0.9, + "nesterov": True + } + } +} diff --git a/examples/sequence_tagging/conll_reader.py b/examples/sequence_tagging/conll_reader.py new file mode 100644 index 000000000..f94058b84 --- /dev/null +++ b/examples/sequence_tagging/conll_reader.py @@ -0,0 +1,266 @@ +# Copyright 2019 The Texar Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Utilities for pre-processing and iterating over the CoNLL 2003 data. +""" + +import re + +from collections import defaultdict + +import numpy as np + +import texar.torch as tx + +# pylint: disable=redefined-outer-name, unused-variable + +MAX_CHAR_LENGTH = 45 +NUM_CHAR_PAD = 2 + +UNK_WORD, UNK_CHAR, UNK_NER = 0, 0, 0 +PAD_WORD, PAD_CHAR, PAD_NER = 1, 1, 1 + +# Regular expressions used to normalize digits. +DIGIT_RE = re.compile(r"\d") + + +def create_vocabs(train_path, dev_path, test_path, normalize_digits=True, + min_occur=1, glove_dict=None): + word_vocab = defaultdict(lambda: len(word_vocab)) + word_count = defaultdict(lambda: 0) + char_vocab = defaultdict(lambda: len(char_vocab)) + ner_vocab = defaultdict(lambda: len(ner_vocab)) + + UNK_WORD = word_vocab[""] + PAD_WORD = word_vocab[""] + UNK_CHAR = char_vocab[""] + PAD_CHAR = char_vocab[""] + UNK_NER = ner_vocab[""] + PAD_NER = ner_vocab[""] + + print("Creating Vocabularies:") + + for file_path in [train_path, dev_path, test_path]: + with open(file_path, 'r') as file: + for line in file: + line = line.strip() + if len(line) == 0: + continue + + tokens = line.split(' ') + for char in tokens[1]: + cid = char_vocab[char] + + word = DIGIT_RE.sub("0", tokens[1]) if normalize_digits \ + else tokens[1] + ner = tokens[4] + + if glove_dict is not None and (word in glove_dict or + word.lower() in glove_dict): + word_count[word] += min_occur + 1 + elif file_path == train_path: + word_count[word] += 1 + + nid = ner_vocab[ner] + + print("Total Vocabulary Size: %d" % len(word_count)) + for word in word_count: + if word_count[word] > min_occur: + wid = word_vocab[word] + + print("Word Vocabulary Size: %d" % len(word_vocab)) + print("Character Alphabet Size: %d" % len(char_vocab)) + print("NER Alphabet Size: %d" % len(ner_vocab)) + + word_vocab = defaultdict(lambda: UNK_WORD, word_vocab) + char_vocab = defaultdict(lambda: UNK_CHAR, char_vocab) + ner_vocab = defaultdict(lambda: UNK_NER, ner_vocab) + + i2w = {v: k for k, v in word_vocab.items()} + i2n = {v: k for k, v in ner_vocab.items()} + return (word_vocab, char_vocab, ner_vocab), (i2w, i2n) + + +def read_data(source_path, word_vocab, char_vocab, ner_vocab, + normalize_digits=True): + data = [] + print('Reading data from %s' % source_path) + counter = 0 + reader = CoNLLReader(source_path, word_vocab, char_vocab, ner_vocab) + inst = reader.getNext(normalize_digits) + while inst is not None: + counter += 1 + sent = inst.sentence + data.append([sent.word_ids, sent.char_id_seqs, inst.ner_ids]) + inst = reader.getNext(normalize_digits) + + reader.close() + print("Total number of data: %d" % counter) + return data + + +def iterate_batch(data, batch_size, shuffle=False): + if shuffle: + np.random.shuffle(data) + + for start_idx in range(0, len(data), batch_size): + excerpt = slice(start_idx, start_idx + batch_size) + batch = data[excerpt] + + batch_length = max([len(batch[i][0]) for i in range(len(batch))]) + + wid_inputs = np.empty([len(batch), batch_length], dtype=np.int64) + cid_inputs = np.empty([len(batch), batch_length, MAX_CHAR_LENGTH], + dtype=np.int64) + nid_inputs = np.empty([len(batch), batch_length], dtype=np.int64) + masks = np.zeros([len(batch), batch_length], dtype=np.float32) + lengths = np.empty(len(batch), dtype=np.int64) + + for i, inst in enumerate(batch): + wids, cid_seqs, nids = inst + + inst_size = len(wids) + lengths[i] = inst_size + # word ids + wid_inputs[i, :inst_size] = wids + wid_inputs[i, inst_size:] = PAD_WORD + for c, cids in enumerate(cid_seqs): + cid_inputs[i, c, :len(cids)] = cids + cid_inputs[i, c, len(cids):] = PAD_CHAR + cid_inputs[i, inst_size:, :] = PAD_CHAR + nid_inputs[i, :inst_size] = nids + nid_inputs[i, inst_size:] = PAD_NER + masks[i, :inst_size] = 1.0 + + yield wid_inputs, cid_inputs, nid_inputs, masks, lengths + + +def load_glove(filename, emb_dim, normalize_digits=True): + r"""Loads embeddings in the glove text format in which each line is + ' '. Dimensions of the embedding vector + are separated with whitespace characters. + """ + glove_dict = dict() + with open(filename) as fin: + for line in fin: + vec = line.strip().split() + if len(vec) == 0: + continue + word, vec = vec[0], vec[1:] + word = tx.utils.compat_as_text(word) + word = DIGIT_RE.sub("0", word) if normalize_digits else word + glove_dict[word] = np.array([float(v) for v in vec]) + if len(vec) != emb_dim: + raise ValueError("Inconsistent word vector sizes: %d vs %d" % + (len(vec), emb_dim)) + return glove_dict + + +def construct_init_word_vecs(vocab, word_vecs, glove_dict): + for word, index in vocab.items(): + if word in glove_dict: + embedding = glove_dict[word] + elif word.lower() in glove_dict: + embedding = glove_dict[word.lower()] + else: + embedding = None + + if embedding is not None: + word_vecs[index] = embedding + return word_vecs + + +class CoNLLReader: + + def __init__(self, file_path, word_vocab, char_vocab, ner_vocab): + self.__source_file = open(file_path, 'r', encoding='utf-8') + self.__word_vocab = word_vocab + self.__char_vocab = char_vocab + self.__ner_vocab = ner_vocab + + def close(self): + self.__source_file.close() + + def getNext(self, normalize_digits=True): + line = self.__source_file.readline() + # skip multiple blank lines. + while len(line) > 0 and len(line.strip()) == 0: + line = self.__source_file.readline() + if len(line) == 0: + return None + + lines = [] + while len(line.strip()) > 0: + line = line.strip() + lines.append(line.split(' ')) + line = self.__source_file.readline() + + length = len(lines) + if length == 0: + return None + + words = [] + word_ids = [] + char_seqs = [] + char_id_seqs = [] + ner_tags = [] + ner_ids = [] + + for tokens in lines: + chars = [] + char_ids = [] + for char in tokens[1]: + chars.append(char) + char_ids.append(self.__char_vocab[char]) + if len(chars) > MAX_CHAR_LENGTH: + chars = chars[:MAX_CHAR_LENGTH] + char_ids = char_ids[:MAX_CHAR_LENGTH] + char_seqs.append(chars) + char_id_seqs.append(char_ids) + + word = DIGIT_RE.sub("0", tokens[1]) \ + if normalize_digits else tokens[1] + ner = tokens[4] + + words.append(word) + word_ids.append(self.__word_vocab[word]) + + ner_tags.append(ner) + ner_ids.append(self.__ner_vocab[ner]) + + return NERInstance(Sentence(words, word_ids, char_seqs, char_id_seqs), + ner_tags, ner_ids) + + +class NERInstance: + + def __init__(self, sentence, ner_tags, ner_ids): + self.sentence = sentence + self.ner_tags = ner_tags + self.ner_ids = ner_ids + + def length(self): + return self.sentence.length() + + +class Sentence: + + def __init__(self, words, word_ids, char_seqs, char_id_seqs): + self.words = words + self.word_ids = word_ids + self.char_seqs = char_seqs + self.char_id_seqs = char_id_seqs + + def length(self): + return len(self.words) diff --git a/examples/sequence_tagging/conll_writer.py b/examples/sequence_tagging/conll_writer.py new file mode 100644 index 000000000..0d5f1d2b8 --- /dev/null +++ b/examples/sequence_tagging/conll_writer.py @@ -0,0 +1,41 @@ +# Copyright 2019 The Texar Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Writer module. +""" + + +class CoNLLWriter: + + def __init__(self, i2w, i2n): + self.__source_file = None + self.__i2w = i2w + self.__i2n = i2n + + def start(self, file_path): + self.__source_file = open(file_path, 'w', encoding='utf-8') + + def close(self): + self.__source_file.close() + + def write(self, word, predictions, targets, lengths): + batch_size, _ = word.shape + for i in range(batch_size): + for j in range(lengths[i]): + w = self.__i2w[word[i, j]] + tgt = self.__i2n[targets[i, j]] + pred = self.__i2n[predictions[i, j]] + self.__source_file.write( + '%d %s %s %s %s %s\n' % (j + 1, w, "_", "_", tgt, pred)) + self.__source_file.write('\n') diff --git a/examples/sequence_tagging/conlleval b/examples/sequence_tagging/conlleval new file mode 100644 index 000000000..12341bae5 --- /dev/null +++ b/examples/sequence_tagging/conlleval @@ -0,0 +1,315 @@ +#!/usr/bin/perl -w +# conlleval: evaluate result of processing CoNLL-2000 shared task +# usage: conlleval [-l] [-r] [-d delimiterTag] [-o oTag] < file +# README: http://cnts.uia.ac.be/conll2000/chunking/output.html +# options: l: generate LaTeX output for tables like in +# http://cnts.uia.ac.be/conll2003/ner/example.tex +# r: accept raw result tags (without B- and I- prefix; +# assumes one word per chunk) +# d: alternative delimiter tag (default is single space) +# o: alternative outside tag (default is O) +# note: the file should contain lines with items separated +# by $delimiter characters (default space). The final +# two items should contain the correct tag and the +# guessed tag in that order. Sentences should be +# separated from each other by empty lines or lines +# with $boundary fields (default -X-). +# url: http://lcg-www.uia.ac.be/conll2000/chunking/ +# started: 1998-09-25 +# version: 2004-01-26 +# author: Erik Tjong Kim Sang + +use strict; + +my $false = 0; +my $true = 42; + +my $boundary = "-X-"; # sentence boundary +my $correct; # current corpus chunk tag (I,O,B) +my $correctChunk = 0; # number of correctly identified chunks +my $correctTags = 0; # number of correct chunk tags +my $correctType; # type of current corpus chunk tag (NP,VP,etc.) +my $delimiter = " "; # field delimiter +my $FB1 = 0.0; # FB1 score (Van Rijsbergen 1979) +my $firstItem; # first feature (for sentence boundary checks) +my $foundCorrect = 0; # number of chunks in corpus +my $foundGuessed = 0; # number of identified chunks +my $guessed; # current guessed chunk tag +my $guessedType; # type of current guessed chunk tag +my $i; # miscellaneous counter +my $inCorrect = $false; # currently processed chunk is correct until now +my $lastCorrect = "O"; # previous chunk tag in corpus +my $latex = 0; # generate LaTeX formatted output +my $lastCorrectType = ""; # type of previously identified chunk tag +my $lastGuessed = "O"; # previously identified chunk tag +my $lastGuessedType = ""; # type of previous chunk tag in corpus +my $lastType; # temporary storage for detecting duplicates +my $line; # line +my $nbrOfFeatures = -1; # number of features per line +my $precision = 0.0; # precision score +my $oTag = "O"; # outside tag, default O +my $raw = 0; # raw input: add B to every token +my $recall = 0.0; # recall score +my $tokenCounter = 0; # token counter (ignores sentence breaks) + +my %correctChunk = (); # number of correctly identified chunks per type +my %foundCorrect = (); # number of chunks in corpus per type +my %foundGuessed = (); # number of identified chunks per type + +my @features; # features on line +my @sortedTypes; # sorted list of chunk type names + +# sanity check +while (@ARGV and $ARGV[0] =~ /^-/) { + if ($ARGV[0] eq "-l") { $latex = 1; shift(@ARGV); } + elsif ($ARGV[0] eq "-r") { $raw = 1; shift(@ARGV); } + elsif ($ARGV[0] eq "-d") { + shift(@ARGV); + if (not defined $ARGV[0]) { + die "conlleval: -d requires delimiter character"; + } + $delimiter = shift(@ARGV); + } elsif ($ARGV[0] eq "-o") { + shift(@ARGV); + if (not defined $ARGV[0]) { + die "conlleval: -o requires delimiter character"; + } + $oTag = shift(@ARGV); + } else { die "conlleval: unknown argument $ARGV[0]\n"; } +} +if (@ARGV) { die "conlleval: unexpected command line argument\n"; } +# process input +while () { + chomp($line = $_); + @features = split(/$delimiter/,$line); + if ($nbrOfFeatures < 0) { $nbrOfFeatures = $#features; } + elsif ($nbrOfFeatures != $#features and @features != 0) { + printf STDERR "unexpected number of features: %d (%d)\n", + $#features+1,$nbrOfFeatures+1; + exit(1); + } + if (@features == 0 or + $features[0] eq $boundary) { @features = ($boundary,"O","O"); } + if (@features < 2) { + die "conlleval: unexpected number of features in line $line\n"; + } + if ($raw) { + if ($features[$#features] eq $oTag) { $features[$#features] = "O"; } + if ($features[$#features-1] eq $oTag) { $features[$#features-1] = "O"; } + if ($features[$#features] ne "O") { + $features[$#features] = "B-$features[$#features]"; + } + if ($features[$#features-1] ne "O") { + $features[$#features-1] = "B-$features[$#features-1]"; + } + } + # 20040126 ET code which allows hyphens in the types + if ($features[$#features] =~ /^([^-]*)-(.*)$/) { + $guessed = $1; + $guessedType = $2; + } else { + $guessed = $features[$#features]; + $guessedType = ""; + } + pop(@features); + if ($features[$#features] =~ /^([^-]*)-(.*)$/) { + $correct = $1; + $correctType = $2; + } else { + $correct = $features[$#features]; + $correctType = ""; + } + pop(@features); +# ($guessed,$guessedType) = split(/-/,pop(@features)); +# ($correct,$correctType) = split(/-/,pop(@features)); + $guessedType = $guessedType ? $guessedType : ""; + $correctType = $correctType ? $correctType : ""; + $firstItem = shift(@features); + + # 1999-06-26 sentence breaks should always be counted as out of chunk + if ( $firstItem eq $boundary ) { $guessed = "O"; } + + if ($inCorrect) { + if ( &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and + &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and + $lastGuessedType eq $lastCorrectType) { + $inCorrect=$false; + $correctChunk++; + $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ? + $correctChunk{$lastCorrectType}+1 : 1; + } elsif ( + &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) != + &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) or + $guessedType ne $correctType ) { + $inCorrect=$false; + } + } + + if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and + &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and + $guessedType eq $correctType) { $inCorrect = $true; } + + if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) ) { + $foundCorrect++; + $foundCorrect{$correctType} = $foundCorrect{$correctType} ? + $foundCorrect{$correctType}+1 : 1; + } + if ( &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) ) { + $foundGuessed++; + $foundGuessed{$guessedType} = $foundGuessed{$guessedType} ? + $foundGuessed{$guessedType}+1 : 1; + } + if ( $firstItem ne $boundary ) { + if ( $correct eq $guessed and $guessedType eq $correctType ) { + $correctTags++; + } + $tokenCounter++; + } + + $lastGuessed = $guessed; + $lastCorrect = $correct; + $lastGuessedType = $guessedType; + $lastCorrectType = $correctType; +} +if ($inCorrect) { + $correctChunk++; + $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ? + $correctChunk{$lastCorrectType}+1 : 1; +} + +if (not $latex) { + # compute overall precision, recall and FB1 (default values are 0.0) + $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0); + $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0); + $FB1 = 2*$precision*$recall/($precision+$recall) + if ($precision+$recall > 0); + + # print overall performance + printf "processed $tokenCounter tokens with $foundCorrect phrases; "; + printf "found: $foundGuessed phrases; correct: $correctChunk.\n"; + if ($tokenCounter>0) { + printf "accuracy: %6.2f%%; ",100*$correctTags/$tokenCounter; + printf "precision: %6.2f%%; ",$precision; + printf "recall: %6.2f%%; ",$recall; + printf "FB1: %6.2f\n",$FB1; + } +} + +# sort chunk type names +undef($lastType); +@sortedTypes = (); +foreach $i (sort (keys %foundCorrect,keys %foundGuessed)) { + if (not($lastType) or $lastType ne $i) { + push(@sortedTypes,($i)); + } + $lastType = $i; +} +# print performance per chunk type +if (not $latex) { + for $i (@sortedTypes) { + $correctChunk{$i} = $correctChunk{$i} ? $correctChunk{$i} : 0; + if (not($foundGuessed{$i})) { $foundGuessed{$i} = 0; $precision = 0.0; } + else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; } + if (not($foundCorrect{$i})) { $recall = 0.0; } + else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; } + if ($precision+$recall == 0.0) { $FB1 = 0.0; } + else { $FB1 = 2*$precision*$recall/($precision+$recall); } + printf "%17s: ",$i; + printf "precision: %6.2f%%; ",$precision; + printf "recall: %6.2f%%; ",$recall; + printf "FB1: %6.2f %d\n",$FB1,$foundGuessed{$i}; + } +} else { + print " & Precision & Recall & F\$_{\\beta=1} \\\\\\hline"; + for $i (@sortedTypes) { + $correctChunk{$i} = $correctChunk{$i} ? $correctChunk{$i} : 0; + if (not($foundGuessed{$i})) { $precision = 0.0; } + else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; } + if (not($foundCorrect{$i})) { $recall = 0.0; } + else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; } + if ($precision+$recall == 0.0) { $FB1 = 0.0; } + else { $FB1 = 2*$precision*$recall/($precision+$recall); } + printf "\n%-7s & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\", + $i,$precision,$recall,$FB1; + } + print "\\hline\n"; + $precision = 0.0; + $recall = 0; + $FB1 = 0.0; + $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0); + $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0); + $FB1 = 2*$precision*$recall/($precision+$recall) + if ($precision+$recall > 0); + printf STDOUT "Overall & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\\\hline\n", + $precision,$recall,$FB1; +} + +exit 0; + +# endOfChunk: checks if a chunk ended between the previous and current word +# arguments: previous and current chunk tags, previous and current types +# note: this code is capable of handling other chunk representations +# than the default CoNLL-2000 ones, see EACL'99 paper of Tjong +# Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006 + +sub endOfChunk { + my $prevTag = shift(@_); + my $tag = shift(@_); + my $prevType = shift(@_); + my $type = shift(@_); + my $chunkEnd = $false; + + if ( $prevTag eq "B" and $tag eq "B" ) { $chunkEnd = $true; } + if ( $prevTag eq "B" and $tag eq "O" ) { $chunkEnd = $true; } + if ( $prevTag eq "I" and $tag eq "B" ) { $chunkEnd = $true; } + if ( $prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; } + + if ( $prevTag eq "E" and $tag eq "E" ) { $chunkEnd = $true; } + if ( $prevTag eq "E" and $tag eq "I" ) { $chunkEnd = $true; } + if ( $prevTag eq "E" and $tag eq "O" ) { $chunkEnd = $true; } + if ( $prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; } + + if ($prevTag ne "O" and $prevTag ne "." and $prevType ne $type) { + $chunkEnd = $true; + } + + # corrected 1998-12-22: these chunks are assumed to have length 1 + if ( $prevTag eq "]" ) { $chunkEnd = $true; } + if ( $prevTag eq "[" ) { $chunkEnd = $true; } + + return($chunkEnd); +} + +# startOfChunk: checks if a chunk started between the previous and current word +# arguments: previous and current chunk tags, previous and current types +# note: this code is capable of handling other chunk representations +# than the default CoNLL-2000 ones, see EACL'99 paper of Tjong +# Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006 + +sub startOfChunk { + my $prevTag = shift(@_); + my $tag = shift(@_); + my $prevType = shift(@_); + my $type = shift(@_); + my $chunkStart = $false; + + if ( $prevTag eq "B" and $tag eq "B" ) { $chunkStart = $true; } + if ( $prevTag eq "I" and $tag eq "B" ) { $chunkStart = $true; } + if ( $prevTag eq "O" and $tag eq "B" ) { $chunkStart = $true; } + if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; } + + if ( $prevTag eq "E" and $tag eq "E" ) { $chunkStart = $true; } + if ( $prevTag eq "E" and $tag eq "I" ) { $chunkStart = $true; } + if ( $prevTag eq "O" and $tag eq "E" ) { $chunkStart = $true; } + if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; } + + if ($tag ne "O" and $tag ne "." and $prevType ne $type) { + $chunkStart = $true; + } + + # corrected 1998-12-22: these chunks are assumed to have length 1 + if ( $tag eq "[" ) { $chunkStart = $true; } + if ( $tag eq "]" ) { $chunkStart = $true; } + + return($chunkStart); +} diff --git a/examples/sequence_tagging/ner.py b/examples/sequence_tagging/ner.py new file mode 100644 index 000000000..8c1379b21 --- /dev/null +++ b/examples/sequence_tagging/ner.py @@ -0,0 +1,256 @@ +# Copyright 2019 The Texar Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Sequence tagging. +""" + +from typing import Any + +import argparse +import importlib +import os +import time + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +import texar.torch as tx + +from conll_reader import (create_vocabs, construct_init_word_vecs, + iterate_batch, load_glove, read_data, MAX_CHAR_LENGTH) +from conll_writer import CoNLLWriter +from scores import scores + + +parser = argparse.ArgumentParser() +parser.add_argument( + "--data-path", default="./data", + help="Directory containing NER data (e.g., eng.train.bio.conll).") +parser.add_argument( + "--train", default="eng.train.bio.conll", + help="The file name of the training data.") +parser.add_argument( + "--dev", default="eng.dev.bio.conll", + help="The file name of the dev data.") +parser.add_argument( + "--test", default="eng.test.bio.conll", + help="The file name of the testing data.") +parser.add_argument( + "--embedding", default="glove.6B.100d.txt", + help="The file name of the GloVe embedding.") +parser.add_argument( + "--config", default="config", help="The configurations to use.") +args = parser.parse_args() + +config: Any = importlib.import_module(args.config) + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +train_path = os.path.join(args.data_path, args.train) +dev_path = os.path.join(args.data_path, args.dev) +test_path = os.path.join(args.data_path, args.test) +embedding_path = os.path.join(args.data_path, args.embedding) + +EMBEDD_DIM = config.embed_dim +CHAR_DIM = config.char_dim + +# Prepares/loads data +if config.load_glove: + print('loading GloVe embedding...') + glove_dict = load_glove(embedding_path, EMBEDD_DIM) +else: + glove_dict = None + +(word_vocab, char_vocab, ner_vocab), (i2w, i2n) = create_vocabs( + train_path, dev_path, test_path, glove_dict=glove_dict) + +data_train = read_data(train_path, word_vocab, char_vocab, ner_vocab) +data_dev = read_data(dev_path, word_vocab, char_vocab, ner_vocab) +data_test = read_data(test_path, word_vocab, char_vocab, ner_vocab) + +scale = np.sqrt(3.0 / EMBEDD_DIM) +word_vecs = np.random.uniform( + -scale, scale, [len(word_vocab), EMBEDD_DIM]).astype(np.float32) +if config.load_glove: + word_vecs = construct_init_word_vecs(word_vocab, word_vecs, glove_dict) + +scale = np.sqrt(3.0 / CHAR_DIM) +char_vecs = np.random.uniform( + -scale, scale, [len(char_vocab), CHAR_DIM]).astype(np.float32) + + +class NER(nn.Module): + + def __init__(self): + super().__init__() + self.embedder = tx.modules.WordEmbedder( + vocab_size=len(word_vecs), init_value=torch.tensor(word_vecs), + hparams=config.emb) + self.char_embedder = tx.modules.WordEmbedder( + vocab_size=len(char_vecs), init_value=torch.tensor(char_vecs), + hparams=config.char_emb) + self.char_encoder = tx.modules.Conv1DEncoder( + in_channels=MAX_CHAR_LENGTH, in_features=CHAR_DIM, + hparams=config.conv) + self.encoder = tx.modules.BidirectionalRNNEncoder( + input_size=(EMBEDD_DIM + CHAR_DIM), + hparams={"rnn_cell_fw": config.cell, "rnn_cell_bw": config.cell}) + + self.dropout_1 = nn.Dropout(p=0.33) + self.dense_1 = nn.Linear(in_features=2 * config.hidden_size, + out_features=config.tag_space) + self.dropout_2 = nn.Dropout(p=(1 - config.keep_prob)) + self.dense_2 = nn.Linear(in_features=config.tag_space, + out_features=len(ner_vocab)) + + def forward(self, inputs, chars, targets, masks, seq_lengths, mode): + emb_inputs = self.embedder(inputs) + emb_chars = self.char_embedder(chars) + char_shape = emb_chars.shape + emb_chars = torch.reshape(emb_chars, (-1, char_shape[2], CHAR_DIM)) + + char_outputs = self.char_encoder(emb_chars) + char_outputs = torch.reshape(char_outputs, ( + char_shape[0], char_shape[1], CHAR_DIM)) + emb_inputs = torch.cat((emb_inputs, char_outputs), dim=2) + + emb_inputs = self.dropout_1(emb_inputs) + outputs, _ = self.encoder(emb_inputs, sequence_length=seq_lengths) + outputs = torch.cat(outputs, dim=2) + rnn_shape = outputs.shape + outputs = torch.reshape(outputs, (-1, 2 * config.hidden_size)) + outputs = F.elu(self.dense_1(outputs)) + outputs = self.dropout_2(outputs) + logits = self.dense_2(outputs) + logits = torch.reshape( + logits, (rnn_shape[0], rnn_shape[1], len(ner_vocab))) + predicts = torch.argmax(logits, dim=2) + corrects = torch.sum(torch.eq(predicts, targets) * masks) + + if mode == 'train': + mle_loss = tx.losses.sequence_sparse_softmax_cross_entropy( + labels=targets, + logits=logits, + sequence_length=seq_lengths, + average_across_batch=True, + average_across_timesteps=True, + sum_over_timesteps=False) + return mle_loss, corrects + else: + return predicts + + +def main() -> None: + model = NER() + model.to(device) + train_op = tx.core.get_train_op(params=model.parameters(), + hparams=config.opt) + + def _train_epoch(epoch_): + model.train() + + start_time = time.time() + loss = 0. + corr = 0. + num_tokens = 0. + + num_inst = 0 + for batch in iterate_batch(data_train, config.batch_size, shuffle=True): + word, char, ner, mask, length = batch + mle_loss, correct = model(torch.tensor(word, device=device), + torch.tensor(char, device=device), + torch.tensor(ner, device=device), + torch.tensor(mask, device=device), + torch.tensor(length, device=device), + 'train') + mle_loss.backward() + train_op() + + nums = np.sum(length) + num_inst += len(word) + loss += mle_loss * nums + corr += correct + num_tokens += nums + + print("train: %d (%d/%d) loss: %.4f, acc: %.2f%%" % ( + epoch_, num_inst, len(data_train), loss / num_tokens, + corr / num_tokens * 100)) + + print("train: %d loss: %.4f, acc: %.2f%%, time: %.2fs" % ( + epoch_, loss / num_tokens, corr / num_tokens * 100, + time.time() - start_time)) + + @torch.no_grad() + def _eval_epoch(epoch_, mode): + model.eval() + + file_name = 'tmp/%s%d' % (mode, epoch_) + writer = CoNLLWriter(i2w, i2n) + writer.start(file_name) + data = data_dev if mode == 'dev' else data_test + + for batch in iterate_batch(data, config.batch_size, shuffle=False): + word, char, ner, mask, length = batch + predictions = model(torch.tensor(word, device=device), + torch.tensor(char, device=device), + torch.tensor(ner, device=device), + torch.tensor(mask, device=device), + torch.tensor(length, device=device), + mode) + + writer.write(word, predictions.numpy(), ner, length) + writer.close() + acc_, precision_, recall_, f1_ = scores(file_name) + print( + '%s acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%%' % ( + mode, acc_, precision_, recall_, f1_)) + return acc_, precision_, recall_, f1_ + + dev_f1 = 0.0 + dev_acc = 0.0 + dev_precision = 0.0 + dev_recall = 0.0 + best_epoch = 0 + + test_f1 = 0.0 + test_acc = 0.0 + test_prec = 0.0 + test_recall = 0.0 + + tx.utils.maybe_create_dir('./tmp') + + for epoch in range(config.num_epochs): + _train_epoch(epoch) + acc, precision, recall, f1 = _eval_epoch(epoch, 'dev') + if dev_f1 < f1: + dev_f1 = f1 + dev_acc = acc + dev_precision = precision + dev_recall = recall + best_epoch = epoch + test_acc, test_prec, test_recall, test_f1 = \ + _eval_epoch(epoch, 'test') + print('best acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, ' + 'F1: %.2f%%, epoch: %d' % (dev_acc, dev_precision, dev_recall, + dev_f1, best_epoch)) + print('test acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, ' + 'F1: %.2f%%, epoch: %d' % (test_acc, test_prec, test_recall, + test_f1, best_epoch)) + print('---------------------------------------------------') + + +if __name__ == '__main__': + main() diff --git a/examples/sequence_tagging/scores.py b/examples/sequence_tagging/scores.py new file mode 100644 index 000000000..708c1fcfe --- /dev/null +++ b/examples/sequence_tagging/scores.py @@ -0,0 +1,28 @@ +# Copyright 2019 The Texar Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Score function. +""" +import subprocess + + +def scores(path): + bashCommand = 'perl conlleval' + process = subprocess.Popen( + bashCommand.split(), stdout=subprocess.PIPE, stdin=open(path)) + output, _ = process.communicate() + output = output.decode().split('\n')[1].split('%; ') + output = [out.split(' ')[-1] for out in output] + acc, prec, recall, fb1 = tuple(output) + return float(acc), float(prec), float(recall), float(fb1)