From b7ee06c785c927bda2860521f85a2fb5c40083ec Mon Sep 17 00:00:00 2001 From: Kyle Lo <kyleclo@uw.edu> Date: Fri, 23 Sep 2022 15:03:23 -0700 Subject: [PATCH] quality of life improvements related to importing mmda types, predictors, etc. (#141) * add easy import for types * predictor import * import parsers * github; setup * update tests * update tests * use necessary to manage deps * remove bad test * rm bad test * move api to ai2_internal * increment --- .github/workflows/mmda-ci.yml | 2 +- {mmda/types => ai2_internal}/api.py | 0 .../integration_test.py | 5 ++--- .../bibentry_detection_predictor/interface.py | 3 ++- .../integration_test.py | 6 ++---- .../bibentry_predictor_mmda/interface.py | 7 +++---- .../citation_links/integration_test.py | 5 ++--- ai2_internal/citation_links/interface.py | 3 ++- .../citation_mentions/integration_test.py | 5 +---- ai2_internal/citation_mentions/interface.py | 2 +- ai2_internal/layout_parser/interface.py | 9 ++++----- ai2_internal/vila/integration_test.py | 6 ++---- ai2_internal/vila/interface.py | 5 ++--- mmda/parsers/__init__.py | 5 +++++ mmda/predictors/__init__.py | 18 ++++++++++++++++++ mmda/types/__init__.py | 16 ++++++++++++++++ setup.py | 4 +++- tests/test_parsers/test_pdf_plumber_parser.py | 5 ++--- tests/test_predictors/test.json.py | 0 .../test_dictionary_word_predictor.py | 7 ++----- tests/test_types/test_indexers.py | 3 +-- tests/test_types/test_json_conversion.py | 6 ++---- tests/test_types/test_metadata.py | 2 +- tests/test_types/test_span_group.py | 4 +--- 24 files changed, 75 insertions(+), 53 deletions(-) rename {mmda/types => ai2_internal}/api.py (100%) create mode 100644 tests/test_predictors/test.json.py diff --git a/.github/workflows/mmda-ci.yml b/.github/workflows/mmda-ci.yml index ea3e2336..62e48ae7 100644 --- a/.github/workflows/mmda-ci.yml +++ b/.github/workflows/mmda-ci.yml @@ -23,7 +23,7 @@ jobs: - name: Test with Python ${{ matrix.python-version }} run: | - pip install -e .[dev,hf_predictors] + pip install -e .[dev,pysbd_predictors,hf_predictors] pytest tests --ignore=tests/test_predictors/test_vila_predictors.py test_vila_predictors: diff --git a/mmda/types/api.py b/ai2_internal/api.py similarity index 100% rename from mmda/types/api.py rename to ai2_internal/api.py diff --git a/ai2_internal/bibentry_detection_predictor/integration_test.py b/ai2_internal/bibentry_detection_predictor/integration_test.py index 7245740d..b1a015fd 100644 --- a/ai2_internal/bibentry_detection_predictor/integration_test.py +++ b/ai2_internal/bibentry_detection_predictor/integration_test.py @@ -33,12 +33,11 @@ def test_prediction(self, container): import sys import unittest -from .interface import Instance - +from ai2_internal import api from mmda.parsers.pdfplumber_parser import PDFPlumberParser from mmda.rasterizers.rasterizer import PDF2ImageRasterizer -from mmda.types import api from mmda.types.image import tobase64 +from .interface import Instance try: from timo_interface import with_timo_container diff --git a/ai2_internal/bibentry_detection_predictor/interface.py b/ai2_internal/bibentry_detection_predictor/interface.py index 173ffa81..02d9c6d2 100644 --- a/ai2_internal/bibentry_detection_predictor/interface.py +++ b/ai2_internal/bibentry_detection_predictor/interface.py @@ -10,8 +10,9 @@ from pydantic import BaseModel, BaseSettings, Field +from ai2_internal import api from mmda.predictors.d2_predictors.bibentry_detection_predictor import BibEntryDetectionPredictor -from mmda.types import api, image +from mmda.types import image from mmda.types.document import Document diff --git a/ai2_internal/bibentry_predictor_mmda/integration_test.py b/ai2_internal/bibentry_predictor_mmda/integration_test.py index 8e4cdcd4..af90b386 100644 --- a/ai2_internal/bibentry_predictor_mmda/integration_test.py +++ b/ai2_internal/bibentry_predictor_mmda/integration_test.py @@ -35,11 +35,9 @@ def test_prediction(self, container): import sys import unittest -from .interface import Instance, Prediction - -from mmda.types import api +from ai2_internal import api from mmda.types.document import Document - +from .interface import Instance try: from timo_interface import with_timo_container diff --git a/ai2_internal/bibentry_predictor_mmda/interface.py b/ai2_internal/bibentry_predictor_mmda/interface.py index 52e710f8..d185e424 100644 --- a/ai2_internal/bibentry_predictor_mmda/interface.py +++ b/ai2_internal/bibentry_predictor_mmda/interface.py @@ -8,13 +8,12 @@ from typing import List -from pydantic import BaseModel, BaseSettings, Field - -from mmda.types import api -from mmda.types.document import Document +from pydantic import BaseModel, BaseSettings +from ai2_internal import api from mmda.predictors.hf_predictors.bibentry_predictor.predictor import BibEntryPredictor from mmda.predictors.hf_predictors.bibentry_predictor.types import BibEntryStructureSpanGroups +from mmda.types.document import Document class Instance(BaseModel): diff --git a/ai2_internal/citation_links/integration_test.py b/ai2_internal/citation_links/integration_test.py index 96029938..ff05d628 100644 --- a/ai2_internal/citation_links/integration_test.py +++ b/ai2_internal/citation_links/integration_test.py @@ -31,9 +31,8 @@ def test_prediction(self, container): import sys import unittest -from ai2_internal.citation_links.interface import Instance, Prediction -from mmda.types import api - +from ai2_internal import api +from ai2_internal.citation_links.interface import Instance try: from timo_interface import with_timo_container diff --git a/ai2_internal/citation_links/interface.py b/ai2_internal/citation_links/interface.py index cc005b97..4bc948f5 100644 --- a/ai2_internal/citation_links/interface.py +++ b/ai2_internal/citation_links/interface.py @@ -10,10 +10,11 @@ from pydantic import BaseModel, BaseSettings +from ai2_internal import api from mmda.predictors.xgb_predictors.citation_link_predictor import CitationLinkPredictor -from mmda.types import api from mmda.types.document import Document + # these should represent the extracted citation mentions and bibliography entries for a paper class Instance(BaseModel): """ diff --git a/ai2_internal/citation_mentions/integration_test.py b/ai2_internal/citation_mentions/integration_test.py index ad843aae..19fd66bd 100644 --- a/ai2_internal/citation_mentions/integration_test.py +++ b/ai2_internal/citation_mentions/integration_test.py @@ -7,17 +7,14 @@ predict_batch(instances: List[Instance]) -> List[Prediction] """ -import json import logging import pathlib import sys import unittest - +from ai2_internal import api from ai2_internal.citation_mentions.interface import Instance from mmda.parsers.pdfplumber_parser import PDFPlumberParser -from mmda.types import api - try: from timo_interface import with_timo_container diff --git a/ai2_internal/citation_mentions/interface.py b/ai2_internal/citation_mentions/interface.py index 27ec1d31..58f25c3c 100644 --- a/ai2_internal/citation_mentions/interface.py +++ b/ai2_internal/citation_mentions/interface.py @@ -10,8 +10,8 @@ from pydantic import BaseModel, BaseSettings +from ai2_internal import api from mmda.predictors.hf_predictors.mention_predictor import MentionPredictor -from mmda.types import api from mmda.types.document import Document diff --git a/ai2_internal/layout_parser/interface.py b/ai2_internal/layout_parser/interface.py index 9c6a13c9..14d188ac 100644 --- a/ai2_internal/layout_parser/interface.py +++ b/ai2_internal/layout_parser/interface.py @@ -9,15 +9,14 @@ import logging from typing import List +import torch +from pydantic import BaseModel, BaseSettings, Field + +from ai2_internal.api import BoxGroup from mmda.predictors.lp_predictors import LayoutParserPredictor from mmda.types import image -from mmda.types.api import BoxGroup from mmda.types.document import Document -from pydantic import BaseModel, BaseSettings, Field -import torch - - logger = logging.getLogger(__name__) diff --git a/ai2_internal/vila/integration_test.py b/ai2_internal/vila/integration_test.py index c417a012..59fa3ae2 100644 --- a/ai2_internal/vila/integration_test.py +++ b/ai2_internal/vila/integration_test.py @@ -31,17 +31,15 @@ def test_prediction(self, container): import os import sys import unittest +from pathlib import Path from PIL import Image -from pathlib import Path -from mmda.types import api +from ai2_internal import api from mmda.types.document import Document from mmda.types.image import tobase64 - from .interface import Instance - FIXTURE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_fixtures") diff --git a/ai2_internal/vila/interface.py b/ai2_internal/vila/interface.py index b320a862..688d27ae 100644 --- a/ai2_internal/vila/interface.py +++ b/ai2_internal/vila/interface.py @@ -9,17 +9,16 @@ import logging from typing import List -from pydantic import BaseModel, BaseSettings, Field import torch +from pydantic import BaseModel, BaseSettings, Field +from ai2_internal import api from mmda.predictors.hf_predictors.token_classification_predictor import ( IVILATokenClassificationPredictor, ) -from mmda.types import api from mmda.types.document import Document, SpanGroup from mmda.types.image import frombase64 - logger = logging.getLogger(__name__) diff --git a/mmda/parsers/__init__.py b/mmda/parsers/__init__.py index e69de29b..a880b9ce 100644 --- a/mmda/parsers/__init__.py +++ b/mmda/parsers/__init__.py @@ -0,0 +1,5 @@ +from mmda.parsers.pdfplumber_parser import PDFPlumberParser + +__all__ = [ + 'PDFPlumberParser' +] \ No newline at end of file diff --git a/mmda/predictors/__init__.py b/mmda/predictors/__init__.py index e69de29b..42292041 100644 --- a/mmda/predictors/__init__.py +++ b/mmda/predictors/__init__.py @@ -0,0 +1,18 @@ +# flake8: noqa +from necessary import necessary + +from mmda.predictors.heuristic_predictors.dictionary_word_predictor import DictionaryWordPredictor + +__all__ = ['DictionaryWordPredictor'] + +with necessary('pysbd', soft=True) as PYSBD_AVAILABLE: + if PYSBD_AVAILABLE: + from mmda.predictors.heuristic_predictors.sentence_boundary_predictor \ + import PysbdSentenceBoundaryPredictor + __all__.append('PysbdSentenceBoundaryPredictor') + +with necessary(["layoutparser", "torch", "torchvision", "effdet"], soft=True) as PYTORCH_AVAILABLE: + if PYTORCH_AVAILABLE: + from mmda.predictors.lp_predictors import LayoutParserPredictor + __all__.append('LayoutParserPredictor') + diff --git a/mmda/types/__init__.py b/mmda/types/__init__.py index e69de29b..d0f3929c 100644 --- a/mmda/types/__init__.py +++ b/mmda/types/__init__.py @@ -0,0 +1,16 @@ +from mmda.types.document import Document +from mmda.types.annotation import SpanGroup, BoxGroup +from mmda.types.span import Span +from mmda.types.box import Box +from mmda.types.image import PILImage +from mmda.types.metadata import Metadata + +__all__ = [ + 'Document', + 'SpanGroup', + 'BoxGroup', + 'Span', + 'Box', + 'PILImage', + 'Metadata' +] \ No newline at end of file diff --git a/setup.py b/setup.py index 1037e5d1..9b69e5e3 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="mmda", description="mmda", - version="0.0.36", + version="0.0.37", url="https://www.github.com/allenai/mmda", python_requires=">= 3.7", packages=find_namespace_packages(include=["mmda*", "ai2_internal*"]), @@ -15,10 +15,12 @@ "pandas", "pydantic", "ncls", + "necessary" ], extras_require={ "dev": ["pytest"], "spacy_predictors": ["spacy"], + "pysbd_predictors": ["pysbd"], "lp_predictors": ["layoutparser", "torch", "torchvision", "effdet"], "hf_predictors": ["torch", "transformers", "smashed==0.1.10"], "vila_predictors": ["vila>=0.4.2,<0.5", "transformers"], diff --git a/tests/test_parsers/test_pdf_plumber_parser.py b/tests/test_parsers/test_pdf_plumber_parser.py index 20e59df7..41bde235 100644 --- a/tests/test_parsers/test_pdf_plumber_parser.py +++ b/tests/test_parsers/test_pdf_plumber_parser.py @@ -2,10 +2,9 @@ import pathlib import unittest -from mmda.types.document import Document -from mmda.parsers.pdfplumber_parser import PDFPlumberParser +from mmda.types import Document +from mmda.parsers import PDFPlumberParser -import string import re os.chdir(pathlib.Path(__file__).parent) diff --git a/tests/test_predictors/test.json.py b/tests/test_predictors/test.json.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_predictors/test_dictionary_word_predictor.py b/tests/test_predictors/test_dictionary_word_predictor.py index dc9521d8..d8085362 100644 --- a/tests/test_predictors/test_dictionary_word_predictor.py +++ b/tests/test_predictors/test_dictionary_word_predictor.py @@ -8,11 +8,8 @@ import unittest from typing import List, Optional, Set -from mmda.predictors.heuristic_predictors.dictionary_word_predictor import ( - DictionaryWordPredictor, -) -from mmda.types.document import Document, SpanGroup -from mmda.types.span import Span +from mmda.predictors import DictionaryWordPredictor +from mmda.types import Document, SpanGroup, Span def mock_document(symbols: str, spans: List[Span], rows: List[SpanGroup]) -> Document: diff --git a/tests/test_types/test_indexers.py b/tests/test_types/test_indexers.py index c208ef43..05ab6885 100644 --- a/tests/test_types/test_indexers.py +++ b/tests/test_types/test_indexers.py @@ -1,8 +1,7 @@ import unittest -from mmda.types.annotation import SpanGroup +from mmda.types import SpanGroup, Span from mmda.types.indexers import SpanGroupIndexer -from mmda.types.span import Span class TestSpanGroupIndexer(unittest.TestCase): diff --git a/tests/test_types/test_json_conversion.py b/tests/test_types/test_json_conversion.py index c8b78627..416358b4 100644 --- a/tests/test_types/test_json_conversion.py +++ b/tests/test_types/test_json_conversion.py @@ -8,10 +8,8 @@ import json from pathlib import Path -from mmda.types.annotation import BoxGroup, SpanGroup -from mmda.types.document import Document -from mmda.parsers.pdfplumber_parser import PDFPlumberParser -from mmda.types.metadata import Metadata +from mmda.types import BoxGroup, SpanGroup, Document, Metadata +from mmda.parsers import PDFPlumberParser PDFFILEPATH = Path(__file__).parent / "../fixtures/1903.10676.pdf" diff --git a/tests/test_types/test_metadata.py b/tests/test_types/test_metadata.py index de1543c0..53602a47 100644 --- a/tests/test_types/test_metadata.py +++ b/tests/test_types/test_metadata.py @@ -8,7 +8,7 @@ import unittest -from mmda.types.metadata import Metadata +from mmda.types import Metadata class TestSpanGroup(unittest.TestCase): diff --git a/tests/test_types/test_span_group.py b/tests/test_types/test_span_group.py index 3e24dfad..9c63db5d 100644 --- a/tests/test_types/test_span_group.py +++ b/tests/test_types/test_span_group.py @@ -7,9 +7,7 @@ import json import unittest -from mmda.types.annotation import SpanGroup -from mmda.types.document import Document -from mmda.types.span import Span +from mmda.types import SpanGroup, Document, Span class TestSpanGroup(unittest.TestCase):