new: Add multitask split_parse command
ivyleavedtoadflax committed Mar 31, 2020
1 parent 3e48684 commit f392f9f
Showing 4 changed files with 203 additions and 1 deletion.
2 changes: 2 additions & 0 deletions deep_reference_parser/__main__.py
@@ -12,11 +12,13 @@
from .train import train
from .split import split
from .parse import parse
+from .split_parse import split_parse

commands = {
    "split": split,
    "parse": parse,
    "train": train,
+    "split_parse": split_parse,
}

if len(sys.argv) == 1:
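
With split_parse registered in the commands dict, the new model should be callable through the package's plac-based command line in the same way as the existing split and parse commands. An illustrative invocation (the input text here is invented):

    python -m deep_reference_parser split_parse "Upson, M. (2020). An example reference."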
1 change: 1 addition & 0 deletions deep_reference_parser/__version__.py
@@ -7,3 +7,4 @@
__license__ = "MIT"
__splitter_model_version__ = "2020.3.6_splitting"
__parser_model_version__ = "2020.3.8_parsing"
+__multitask_model_version__ = "2020.3.18_multitask"
3 changes: 2 additions & 1 deletion deep_reference_parser/common.py
@@ -6,7 +6,7 @@
from urllib import parse, request

from .logger import logger
-from .__version__ import __splitter_model_version__, __parser_model_version__
+from .__version__ import __splitter_model_version__, __parser_model_version__, __multitask_model_version__


def get_path(path):
@@ -15,6 +15,7 @@ def get_path(path):

SPLITTER_CFG = get_path(f"configs/{__splitter_model_version__}.ini")
PARSER_CFG = get_path(f"configs/{__parser_model_version__}.ini")
+MULTITASK_CFG = get_path(f"configs/{__multitask_model_version__}.ini")


def download_model_artefact(artefact, s3_slug):
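
Since get_path resolves paths relative to the installed package, the new constant should point at a config matching the version string added in __version__.py, i.e. something like (path prefix elided):

    .../deep_reference_parser/configs/2020.3.18_multitask.ini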
198 changes: 198 additions & 0 deletions deep_reference_parser/split_parse.py
@@ -0,0 +1,198 @@
#!/usr/bin/env python3
# coding: utf-8
"""
Run predictions from a pre-trained model
"""

import itertools
import json
import os

import en_core_web_sm
import plac
import spacy
import wasabi

import warnings

with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=DeprecationWarning)

    from deep_reference_parser import __file__
    from deep_reference_parser.__version__ import __splitter_model_version__
    from deep_reference_parser.common import MULTITASK_CFG, download_model_artefact
    from deep_reference_parser.deep_reference_parser import DeepReferenceParser
    from deep_reference_parser.logger import logger
    from deep_reference_parser.model_utils import get_config
    from deep_reference_parser.reference_utils import break_into_chunks
    from deep_reference_parser.tokens_to_references import tokens_to_references

msg = wasabi.Printer(icons={"check": "\u2023"})


class SplitParser:
    def __init__(self, config_file):

        msg.info(f"Using config file: {config_file}")

        cfg = get_config(config_file)

        msg.info(
            f"Attempting to download model artefacts if they are not found locally in {cfg['build']['output_path']}. This may take some time..."
        )

        # Build config

        OUTPUT_PATH = cfg["build"]["output_path"]
        S3_SLUG = cfg["data"]["s3_slug"]

        # Check whether the necessary artefacts exist and download them if
        # not.

        artefacts = [
            "indices.pickle",
            "weights.h5",
        ]

        for artefact in artefacts:
            with msg.loading(f"Could not find {artefact} locally, downloading..."):
                try:
                    artefact = os.path.join(OUTPUT_PATH, artefact)
                    download_model_artefact(artefact, S3_SLUG)
                    msg.good(f"Found {artefact}")
                except Exception:
                    msg.fail(f"Could not download {S3_SLUG}{artefact}")
                    logger.exception("Could not download %s%s", S3_SLUG, artefact)

        # Check for the word embedding and download it if it does not exist

        WORD_EMBEDDINGS = cfg["build"]["word_embeddings"]

        with msg.loading(f"Could not find {WORD_EMBEDDINGS} locally, downloading..."):
            try:
                download_model_artefact(WORD_EMBEDDINGS, S3_SLUG)
                msg.good(f"Found {WORD_EMBEDDINGS}")
            except Exception:
                msg.fail(f"Could not download {S3_SLUG}{WORD_EMBEDDINGS}")
                logger.exception("Could not download %s", WORD_EMBEDDINGS)

        OUTPUT = cfg["build"]["output"]
        PRETRAINED_EMBEDDING = cfg["build"]["pretrained_embedding"]
        DROPOUT = float(cfg["build"]["dropout"])
        LSTM_HIDDEN = int(cfg["build"]["lstm_hidden"])
        WORD_EMBEDDING_SIZE = int(cfg["build"]["word_embedding_size"])
        CHAR_EMBEDDING_SIZE = int(cfg["build"]["char_embedding_size"])

        self.MAX_WORDS = int(cfg["data"]["line_limit"])

        # Evaluate config

        self.drp = DeepReferenceParser(output_path=OUTPUT_PATH)

        # Encode data and load required mapping dicts. Note that the max word
        # and max char lengths will be loaded in this step.

        self.drp.load_data(OUTPUT_PATH)

        # Build the model architecture

        self.drp.build_model(
            output=OUTPUT,
            word_embeddings=WORD_EMBEDDINGS,
            pretrained_embedding=PRETRAINED_EMBEDDING,
            dropout=DROPOUT,
            lstm_hidden=LSTM_HIDDEN,
            word_embedding_size=WORD_EMBEDDING_SIZE,
            char_embedding_size=CHAR_EMBEDDING_SIZE,
        )

    def split_parse(self, text, return_tokens=False, verbose=False):

        nlp = en_core_web_sm.load()
        doc = nlp(text)
        chunks = break_into_chunks(doc, max_words=self.MAX_WORDS)
        tokens = [[token.text for token in chunk] for chunk in chunks]

        preds = self.drp.predict(tokens, load_weights=True)

        # If the tokens argument is passed, return the labelled tokens

        if return_tokens:

            flat_predictions = list(itertools.chain.from_iterable(preds))
            flat_X = list(itertools.chain.from_iterable(tokens))
            rows = list(zip(flat_X, flat_predictions))

            if verbose:

                msg.divider("Token Results")

                header = ("token", "label")
                aligns = ("r", "l")
                formatted = wasabi.table(
                    rows, header=header, divider=True, aligns=aligns
                )
                print(formatted)

            out = rows

        else:

            # Otherwise convert the tokens into references and return

            refs = tokens_to_references(tokens, preds)

            if verbose:

                msg.divider("Results")

                if refs:

                    msg.good(f"Found {len(refs)} references.")
                    msg.info("Printing found references:")

                    for ref in refs:
                        msg.text(ref, icon="check", spaced=True)

                else:

                    msg.fail("Failed to find any references.")

            out = refs

        return out


@plac.annotations(
    text=("Plaintext from which to extract references", "positional", None, str),
    config_file=("Path to config file", "option", "c", str),
    tokens=("Output tokens instead of complete references", "flag", "t", bool),
    outfile=("Path to json file to which results will be written", "option", "o", str),
)
def split_parse(text, config_file=MULTITASK_CFG, tokens=False, outfile=None):
    """
    Runs the default multitask model and pretty prints the results to the
    console, unless --outfile is passed a path, in which case the output
    written there will be valid JSON. Can output either tokens (with
    -t|--tokens) or text split naively into references based on the b-r tag
    (default).

    NOTE: this function is provided as an example only and should not be used
    in production, as the model is instantiated each time the command is run.
    In a production setting, a more sensible approach would be to replicate
    the split or parse functions within your own logic.
    """
    mt = SplitParser(config_file)

    if outfile:
        out = mt.split_parse(text, return_tokens=tokens, verbose=False)

        try:
            with open(outfile, "w") as fb:
                json.dump(out, fb)
            msg.good(f"Wrote model output to {outfile}")
        except Exception:
            msg.fail(f"Failed to write output to {outfile}")

    else:
        mt.split_parse(text, return_tokens=tokens, verbose=True)
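
As with split and parse, the new class can also be used directly from Python, which avoids re-instantiating the model on every command invocation. A minimal sketch (the reference string is invented for illustration):

    from deep_reference_parser.common import MULTITASK_CFG
    from deep_reference_parser.split_parse import SplitParser

    mt = SplitParser(MULTITASK_CFG)

    # Returns (token, label) pairs; omit return_tokens to get whole references.
    rows = mt.split_parse("Upson, M. (2020). An example reference.", return_tokens=True)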
