-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Biological NER predictor pack() missing context parameter (#85)
- Loading branch information
Showing
10 changed files
with
253 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
row_id,icd9_code,short_title,long_title | ||
1,01716,Erythem nod tb-oth test,"Erythema nodosum with hypersensitivity reaction in tuberculosis, tubercle bacilli not found by bacteriological or histological examination, but tuberculosis confirmed by other methods [inoculation of animals]" | ||
378,0879,Relapsing fever NOS,"Relapsing fever, unspecified" | ||
379,0880,Bartonellosis,Bartonellosis | ||
380,08881,Lyme disease,Lyme Disease | ||
392,0905,Late congen syph symptom,"Other late congenital syphilis, symptomatic" | ||
420,09324,Syphil pulmonary valve,Syphilitic endocarditis of pulmonary valve | ||
434,09486,Syphil acoustic neuritis,Syphilitic acoustic neuritis | ||
463,09830,Chr gc upper gu NOS,"Chronic gonococcal infection of upper genitourinary tract, site unspecified" | ||
523,04521,Nonparalyt polio-type 1,"Acute nonparalytic poliomyelitis, poliovirus type I" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
BERTTokenizer: | ||
model_path: "resources/NCBI-disease" | ||
|
||
BioBERTNERPredictor: | ||
model_path: "resources/NCBI-disease" | ||
ner_type: "DISEASE" | ||
ignore_labels: ["O"] |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# ***automatically_generated*** |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# ***automatically_generated*** | ||
# ***source json:examples/clinical_pipeline/clinical_onto.json*** | ||
# flake8: noqa | ||
# mypy: ignore-errors | ||
# pylint: skip-file | ||
""" | ||
Automatically generated ontology clinical. Do not change manually. | ||
""" | ||
|
||
from dataclasses import dataclass | ||
from forte.data.data_pack import DataPack | ||
from forte.data.ontology.top import Annotation | ||
from ft.onto.base_ontology import EntityMention | ||
|
||
__all__ = [ | ||
"ClinicalEntityMention", | ||
"Description", | ||
"Body", | ||
] | ||
|
||
|
||
@dataclass | ||
class ClinicalEntityMention(EntityMention): | ||
""" | ||
A span based annotation `ClinicalEntityMention`, normally used to represent an Entity Mention in a piece of clinical text. | ||
""" | ||
|
||
def __init__(self, pack: DataPack, begin: int, end: int): | ||
super().__init__(pack, begin, end) | ||
|
||
|
||
@dataclass | ||
class Description(Annotation): | ||
""" | ||
A span based annotation `Description`, used to represent the description in a piece of clinical note. | ||
""" | ||
|
||
def __init__(self, pack: DataPack, begin: int, end: int): | ||
super().__init__(pack, begin, end) | ||
|
||
|
||
@dataclass | ||
class Body(Annotation): | ||
""" | ||
A span based annotation `Body`, used to represent the actual content in a piece of clinical note. | ||
""" | ||
|
||
def __init__(self, pack: DataPack, begin: int, end: int): | ||
super().__init__(pack, begin, end) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
# Copyright 2021 The Forte Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import csv | ||
import logging | ||
from pathlib import Path | ||
from typing import Any, Iterator, Union, List | ||
|
||
from smart_open import open | ||
|
||
from bio_ner_predictor.demo.clinical import Description, Body | ||
from forte.data.data_pack import DataPack | ||
from forte.data.base_reader import PackReader | ||
|
||
|
||
class Mimic3DischargeNoteReader(PackReader): | ||
"""This class is designed to read the discharge notes from MIMIC3 dataset | ||
as plain text packs. | ||
For more information for the dataset, visit: | ||
https://mimic.physionet.org/ | ||
""" | ||
|
||
def __init__(self): | ||
super().__init__() | ||
self.headers: List[str] = [] | ||
self.text_col = -1 # Default to be last column. | ||
self.description_col = 0 # Default to be first column. | ||
self.__note_count = 0 # Count number of notes processed. | ||
|
||
def _collect( # type: ignore | ||
self, mimic3_path: Union[Path, str] | ||
) -> Iterator[Any]: | ||
with open(mimic3_path) as f: | ||
for r in csv.reader(f): | ||
if 0 < self.configs.max_num_notes <= self.__note_count: | ||
break | ||
yield r | ||
|
||
def _parse_pack(self, row: List[str]) -> Iterator[DataPack]: | ||
if len(self.headers) == 0: | ||
self.headers.extend(row) | ||
for i, h in enumerate(self.headers): | ||
if h == "TEXT": | ||
self.text_col = i | ||
logging.info("Text Column is %d", i) | ||
if h == "DESCRIPTION": | ||
self.description_col = i | ||
logging.info("Description Column is %d", i) | ||
else: | ||
pack: DataPack = DataPack() | ||
description: str = row[self.description_col] | ||
text: str = row[self.text_col] | ||
delimiter = "\n-----------------\n" | ||
full_text = description + delimiter + text | ||
pack.set_text(full_text) | ||
|
||
Description(pack, 0, len(description)) | ||
Body(pack, len(description) + len(delimiter), len(full_text)) | ||
self.__note_count += 1 | ||
yield pack | ||
|
||
@classmethod | ||
def default_configs(cls): | ||
config = super().default_configs() | ||
# If this is set (>0), the reader will only read up to | ||
# the number specified. | ||
config["max_num_notes"] = -1 | ||
return config |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import sys | ||
import time | ||
import os | ||
import yaml | ||
from bio_ner_predictor.mimic3_note_reader import Mimic3DischargeNoteReader | ||
|
||
from fortex.elastic import ElasticSearchPackIndexProcessor | ||
from fortex.huggingface.bio_ner_predictor import BioBERTNERPredictor | ||
from fortex.huggingface.transformers_processor import BERTTokenizer | ||
|
||
from forte.common.configuration import Config | ||
from forte.data.data_pack import DataPack | ||
from forte.pipeline import Pipeline | ||
from forte.processors.writers import PackIdJsonPackWriter | ||
from fortex.nltk import NLTKSentenceSegmenter | ||
import unittest | ||
from ddt import ddt, data, unpack | ||
from forte.data.data_utils import maybe_download | ||
from ft.onto.base_ontology import EntityMention | ||
|
||
@ddt | ||
class TestBioNerPredictor(unittest.TestCase): | ||
r"""Tests Elastic Indexer.""" | ||
|
||
def setUp(self): | ||
self.pl = Pipeline[DataPack]() | ||
|
||
script_dir_path = os.path.dirname(os.path.abspath(__file__)) | ||
data_folder = "bio_ner_predictor" | ||
self.output_path = os.path.join(script_dir_path,data_folder, "test_case_output/") | ||
config_path = os.path.join(script_dir_path,data_folder,"bio_ner_config.yml") | ||
self.input_path = os.path.join(script_dir_path,data_folder, "D_ICD_DIAGNOSES.csv") | ||
self.num_packs = 5 | ||
|
||
# download resources | ||
urls = [ | ||
"https://drive.google.com/file/d/15RSfFkW9syQKtx-_fQ9KshN3BJ27Jf8t/" | ||
"view?usp=sharing", | ||
"https://drive.google.com/file/d/1Nh7D6Xam5JefdoSXRoL7S0DZK1d4i2UK/" | ||
"view?usp=sharing", | ||
"https://drive.google.com/file/d/1YWcI60lGKtTFH01Ai1HnwOKBsrFf2r29/" | ||
"view?usp=sharing", | ||
"https://drive.google.com/file/d/1ElHUEMPQIuWmV0GimroqFphbCvFKskYj/" | ||
"view?usp=sharing", | ||
"https://drive.google.com/file/d/1EhMXlieoEg-bGUbbQ2vN-iyNJvC4Dajl/" | ||
"view?usp=sharing", | ||
] | ||
|
||
filenames = [ | ||
"config.json", | ||
"pytorch_model.bin", | ||
"special_tokens_map.json", | ||
"tokenizer_config.json", | ||
"vocab.txt", | ||
] | ||
model_path = os.path.abspath("resources/NCBI-disease") | ||
config = yaml.safe_load(open(config_path, "r")) | ||
config = Config(config, default_hparams=None) | ||
config.BERTTokenizer.model_path = model_path | ||
config.BioBERTNERPredictor.model_path = model_path | ||
maybe_download(urls=urls, path=model_path, filenames=filenames) | ||
self.assertTrue(os.path.exists(os.path.join(model_path, "pytorch_model.bin"))) | ||
self.pl.set_reader( | ||
Mimic3DischargeNoteReader(), config={"max_num_notes": self.num_packs} | ||
) | ||
self.pl.add(NLTKSentenceSegmenter()) | ||
|
||
|
||
|
||
self.pl.add(BERTTokenizer(), config=config.BERTTokenizer) | ||
self.pl.add(BioBERTNERPredictor(), config=config.BioBERTNERPredictor) | ||
self.pl.add(ElasticSearchPackIndexProcessor()) | ||
self.pl.add( | ||
PackIdJsonPackWriter(), | ||
{ | ||
"output_dir": self.output_path, | ||
"indent": 2, | ||
"overwrite": True, | ||
"drop_record": True, | ||
"zip_pack": True, | ||
}, | ||
) | ||
self.pl.initialize() | ||
|
||
def test_predict(self): | ||
for idx, data_pack in enumerate(self.pl.process_dataset(self.input_path)): | ||
ems = list(data_pack.get_data(EntityMention)) | ||
self.assertTrue(len(ems) > 0) | ||
|
||
self.assertEqual(len(os.listdir(self.output_path)), self.num_packs) | ||
for f_name in os.listdir(self.output_path): | ||
os.remove(os.path.join(self.output_path, f_name)) | ||
os.removedirs(self.output_path) |