From b7ee06c785c927bda2860521f85a2fb5c40083ec Mon Sep 17 00:00:00 2001
From: Kyle Lo <kyleclo@uw.edu>
Date: Fri, 23 Sep 2022 15:03:23 -0700
Subject: [PATCH] quality of life improvements related to importing mmda types,
 predictors, etc. (#141)

* add easy import for types

* predictor import

* import parsers

* github; setup

* update tests

* update tests

* use necessary to manage deps

* remove bad test

* rm bad test

* move api to ai2_internal

* increment
---
 .github/workflows/mmda-ci.yml                  |  2 +-
 {mmda/types => ai2_internal}/api.py            |  0
 .../integration_test.py                        |  5 ++---
 .../bibentry_detection_predictor/interface.py  |  3 ++-
 .../integration_test.py                        |  6 ++----
 .../bibentry_predictor_mmda/interface.py       |  7 +++----
 .../citation_links/integration_test.py         |  5 ++---
 ai2_internal/citation_links/interface.py       |  3 ++-
 .../citation_mentions/integration_test.py      |  5 +----
 ai2_internal/citation_mentions/interface.py    |  2 +-
 ai2_internal/layout_parser/interface.py        |  9 ++++-----
 ai2_internal/vila/integration_test.py          |  6 ++----
 ai2_internal/vila/interface.py                 |  5 ++---
 mmda/parsers/__init__.py                       |  5 +++++
 mmda/predictors/__init__.py                    | 18 ++++++++++++++++++
 mmda/types/__init__.py                         | 16 ++++++++++++++++
 setup.py                                       |  4 +++-
 tests/test_parsers/test_pdf_plumber_parser.py  |  5 ++---
 tests/test_predictors/test.json.py             |  0
 .../test_dictionary_word_predictor.py          |  7 ++-----
 tests/test_types/test_indexers.py              |  3 +--
 tests/test_types/test_json_conversion.py       |  6 ++----
 tests/test_types/test_metadata.py              |  2 +-
 tests/test_types/test_span_group.py            |  4 +---
 24 files changed, 75 insertions(+), 53 deletions(-)
 rename {mmda/types => ai2_internal}/api.py (100%)
 create mode 100644 tests/test_predictors/test.json.py

diff --git a/.github/workflows/mmda-ci.yml b/.github/workflows/mmda-ci.yml
index ea3e2336..62e48ae7 100644
--- a/.github/workflows/mmda-ci.yml
+++ b/.github/workflows/mmda-ci.yml
@@ -23,7 +23,7 @@ jobs:
 
       - name: Test with Python ${{ matrix.python-version }}
         run: | 
-          pip install -e .[dev,hf_predictors]
+          pip install -e .[dev,pysbd_predictors,hf_predictors]
           pytest tests --ignore=tests/test_predictors/test_vila_predictors.py
 
   test_vila_predictors:
diff --git a/mmda/types/api.py b/ai2_internal/api.py
similarity index 100%
rename from mmda/types/api.py
rename to ai2_internal/api.py
diff --git a/ai2_internal/bibentry_detection_predictor/integration_test.py b/ai2_internal/bibentry_detection_predictor/integration_test.py
index 7245740d..b1a015fd 100644
--- a/ai2_internal/bibentry_detection_predictor/integration_test.py
+++ b/ai2_internal/bibentry_detection_predictor/integration_test.py
@@ -33,12 +33,11 @@ def test_prediction(self, container):
 import sys
 import unittest
 
-from .interface import Instance
-
+from ai2_internal import api
 from mmda.parsers.pdfplumber_parser import PDFPlumberParser
 from mmda.rasterizers.rasterizer import PDF2ImageRasterizer
-from mmda.types import api
 from mmda.types.image import tobase64
+from .interface import Instance
 
 try:
     from timo_interface import with_timo_container
diff --git a/ai2_internal/bibentry_detection_predictor/interface.py b/ai2_internal/bibentry_detection_predictor/interface.py
index 173ffa81..02d9c6d2 100644
--- a/ai2_internal/bibentry_detection_predictor/interface.py
+++ b/ai2_internal/bibentry_detection_predictor/interface.py
@@ -10,8 +10,9 @@
 
 from pydantic import BaseModel, BaseSettings, Field
 
+from ai2_internal import api
 from mmda.predictors.d2_predictors.bibentry_detection_predictor import BibEntryDetectionPredictor
-from mmda.types import api, image
+from mmda.types import image
 from mmda.types.document import Document
 
 
diff --git a/ai2_internal/bibentry_predictor_mmda/integration_test.py b/ai2_internal/bibentry_predictor_mmda/integration_test.py
index 8e4cdcd4..af90b386 100644
--- a/ai2_internal/bibentry_predictor_mmda/integration_test.py
+++ b/ai2_internal/bibentry_predictor_mmda/integration_test.py
@@ -35,11 +35,9 @@ def test_prediction(self, container):
 import sys
 import unittest
 
-from .interface import Instance, Prediction
-
-from mmda.types import api
+from ai2_internal import api
 from mmda.types.document import Document
-
+from .interface import Instance
 
 try:
     from timo_interface import with_timo_container
diff --git a/ai2_internal/bibentry_predictor_mmda/interface.py b/ai2_internal/bibentry_predictor_mmda/interface.py
index 52e710f8..d185e424 100644
--- a/ai2_internal/bibentry_predictor_mmda/interface.py
+++ b/ai2_internal/bibentry_predictor_mmda/interface.py
@@ -8,13 +8,12 @@
 
 from typing import List
 
-from pydantic import BaseModel, BaseSettings, Field
-
-from mmda.types import api
-from mmda.types.document import Document
+from pydantic import BaseModel, BaseSettings
 
+from ai2_internal import api
 from mmda.predictors.hf_predictors.bibentry_predictor.predictor import BibEntryPredictor
 from mmda.predictors.hf_predictors.bibentry_predictor.types import BibEntryStructureSpanGroups
+from mmda.types.document import Document
 
 
 class Instance(BaseModel):
diff --git a/ai2_internal/citation_links/integration_test.py b/ai2_internal/citation_links/integration_test.py
index 96029938..ff05d628 100644
--- a/ai2_internal/citation_links/integration_test.py
+++ b/ai2_internal/citation_links/integration_test.py
@@ -31,9 +31,8 @@ def test_prediction(self, container):
 import sys
 import unittest
 
-from ai2_internal.citation_links.interface import Instance, Prediction
-from mmda.types import api
-
+from ai2_internal import api
+from ai2_internal.citation_links.interface import Instance
 
 try:
     from timo_interface import with_timo_container
diff --git a/ai2_internal/citation_links/interface.py b/ai2_internal/citation_links/interface.py
index cc005b97..4bc948f5 100644
--- a/ai2_internal/citation_links/interface.py
+++ b/ai2_internal/citation_links/interface.py
@@ -10,10 +10,11 @@
 
 from pydantic import BaseModel, BaseSettings
 
+from ai2_internal import api
 from mmda.predictors.xgb_predictors.citation_link_predictor import CitationLinkPredictor
-from mmda.types import api
 from mmda.types.document import Document
 
+
 # these should represent the extracted citation mentions and bibliography entries for a paper
 class Instance(BaseModel):
     """
diff --git a/ai2_internal/citation_mentions/integration_test.py b/ai2_internal/citation_mentions/integration_test.py
index ad843aae..19fd66bd 100644
--- a/ai2_internal/citation_mentions/integration_test.py
+++ b/ai2_internal/citation_mentions/integration_test.py
@@ -7,17 +7,14 @@
 
 predict_batch(instances: List[Instance]) -> List[Prediction]
 """
-import json
 import logging
 import pathlib
 import sys
 import unittest
 
-
+from ai2_internal import api
 from ai2_internal.citation_mentions.interface import Instance
 from mmda.parsers.pdfplumber_parser import PDFPlumberParser
-from mmda.types import api
-
 
 try:
     from timo_interface import with_timo_container
diff --git a/ai2_internal/citation_mentions/interface.py b/ai2_internal/citation_mentions/interface.py
index 27ec1d31..58f25c3c 100644
--- a/ai2_internal/citation_mentions/interface.py
+++ b/ai2_internal/citation_mentions/interface.py
@@ -10,8 +10,8 @@
 
 from pydantic import BaseModel, BaseSettings
 
+from ai2_internal import api
 from mmda.predictors.hf_predictors.mention_predictor import MentionPredictor
-from mmda.types import api
 from mmda.types.document import Document
 
 
diff --git a/ai2_internal/layout_parser/interface.py b/ai2_internal/layout_parser/interface.py
index 9c6a13c9..14d188ac 100644
--- a/ai2_internal/layout_parser/interface.py
+++ b/ai2_internal/layout_parser/interface.py
@@ -9,15 +9,14 @@
 import logging
 from typing import List
 
+import torch
+from pydantic import BaseModel, BaseSettings, Field
+
+from ai2_internal.api import BoxGroup
 from mmda.predictors.lp_predictors import LayoutParserPredictor
 from mmda.types import image
-from mmda.types.api import BoxGroup
 from mmda.types.document import Document
 
-from pydantic import BaseModel, BaseSettings, Field
-import torch
-
-
 logger = logging.getLogger(__name__)
 
 
diff --git a/ai2_internal/vila/integration_test.py b/ai2_internal/vila/integration_test.py
index c417a012..59fa3ae2 100644
--- a/ai2_internal/vila/integration_test.py
+++ b/ai2_internal/vila/integration_test.py
@@ -31,17 +31,15 @@ def test_prediction(self, container):
 import os
 import sys
 import unittest
+from pathlib import Path
 
 from PIL import Image
-from pathlib import Path
 
-from mmda.types import api
+from ai2_internal import api
 from mmda.types.document import Document
 from mmda.types.image import tobase64
-
 from .interface import Instance
 
-
 FIXTURE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_fixtures")
 
 
diff --git a/ai2_internal/vila/interface.py b/ai2_internal/vila/interface.py
index b320a862..688d27ae 100644
--- a/ai2_internal/vila/interface.py
+++ b/ai2_internal/vila/interface.py
@@ -9,17 +9,16 @@
 import logging
 from typing import List
 
-from pydantic import BaseModel, BaseSettings, Field
 import torch
+from pydantic import BaseModel, BaseSettings, Field
 
+from ai2_internal import api
 from mmda.predictors.hf_predictors.token_classification_predictor import (
     IVILATokenClassificationPredictor,
 )
-from mmda.types import api
 from mmda.types.document import Document, SpanGroup
 from mmda.types.image import frombase64
 
-
 logger = logging.getLogger(__name__)
 
 
diff --git a/mmda/parsers/__init__.py b/mmda/parsers/__init__.py
index e69de29b..a880b9ce 100644
--- a/mmda/parsers/__init__.py
+++ b/mmda/parsers/__init__.py
@@ -0,0 +1,5 @@
+from mmda.parsers.pdfplumber_parser import PDFPlumberParser
+
+__all__ = [
+    'PDFPlumberParser'
+]
\ No newline at end of file
diff --git a/mmda/predictors/__init__.py b/mmda/predictors/__init__.py
index e69de29b..42292041 100644
--- a/mmda/predictors/__init__.py
+++ b/mmda/predictors/__init__.py
@@ -0,0 +1,18 @@
+# flake8: noqa
+from necessary import necessary
+
+from mmda.predictors.heuristic_predictors.dictionary_word_predictor import DictionaryWordPredictor
+
+__all__ = ['DictionaryWordPredictor']
+
+with necessary('pysbd', soft=True) as PYSBD_AVAILABLE:
+    if PYSBD_AVAILABLE:
+        from mmda.predictors.heuristic_predictors.sentence_boundary_predictor \
+            import PysbdSentenceBoundaryPredictor
+        __all__.append('PysbdSentenceBoundaryPredictor')
+
+with necessary(["layoutparser", "torch", "torchvision", "effdet"], soft=True) as PYTORCH_AVAILABLE:
+    if PYTORCH_AVAILABLE:
+        from mmda.predictors.lp_predictors import LayoutParserPredictor
+        __all__.append('LayoutParserPredictor')
+
diff --git a/mmda/types/__init__.py b/mmda/types/__init__.py
index e69de29b..d0f3929c 100644
--- a/mmda/types/__init__.py
+++ b/mmda/types/__init__.py
@@ -0,0 +1,16 @@
+from mmda.types.document import Document
+from mmda.types.annotation import SpanGroup, BoxGroup
+from mmda.types.span import Span
+from mmda.types.box import Box
+from mmda.types.image import PILImage
+from mmda.types.metadata import Metadata
+
+__all__ = [
+    'Document',
+    'SpanGroup',
+    'BoxGroup',
+    'Span',
+    'Box',
+    'PILImage',
+    'Metadata'
+]
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 1037e5d1..9b69e5e3 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
     name="mmda",
     description="mmda",
-    version="0.0.36",
+    version="0.0.37",
     url="https://www.github.com/allenai/mmda",
     python_requires=">= 3.7",
     packages=find_namespace_packages(include=["mmda*", "ai2_internal*"]),
@@ -15,10 +15,12 @@
         "pandas",
         "pydantic",
         "ncls",
+        "necessary"
     ],
     extras_require={
         "dev": ["pytest"],
         "spacy_predictors": ["spacy"],
+        "pysbd_predictors": ["pysbd"],
         "lp_predictors": ["layoutparser", "torch", "torchvision", "effdet"],
         "hf_predictors": ["torch", "transformers", "smashed==0.1.10"],
         "vila_predictors": ["vila>=0.4.2,<0.5", "transformers"],
diff --git a/tests/test_parsers/test_pdf_plumber_parser.py b/tests/test_parsers/test_pdf_plumber_parser.py
index 20e59df7..41bde235 100644
--- a/tests/test_parsers/test_pdf_plumber_parser.py
+++ b/tests/test_parsers/test_pdf_plumber_parser.py
@@ -2,10 +2,9 @@
 import pathlib
 import unittest
 
-from mmda.types.document import Document
-from mmda.parsers.pdfplumber_parser import PDFPlumberParser
+from mmda.types import Document
+from mmda.parsers import PDFPlumberParser
 
-import string
 import re
 
 os.chdir(pathlib.Path(__file__).parent)
diff --git a/tests/test_predictors/test.json.py b/tests/test_predictors/test.json.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/test_predictors/test_dictionary_word_predictor.py b/tests/test_predictors/test_dictionary_word_predictor.py
index dc9521d8..d8085362 100644
--- a/tests/test_predictors/test_dictionary_word_predictor.py
+++ b/tests/test_predictors/test_dictionary_word_predictor.py
@@ -8,11 +8,8 @@
 import unittest
 from typing import List, Optional, Set
 
-from mmda.predictors.heuristic_predictors.dictionary_word_predictor import (
-    DictionaryWordPredictor,
-)
-from mmda.types.document import Document, SpanGroup
-from mmda.types.span import Span
+from mmda.predictors import DictionaryWordPredictor
+from mmda.types import Document, SpanGroup, Span
 
 
 def mock_document(symbols: str, spans: List[Span], rows: List[SpanGroup]) -> Document:
diff --git a/tests/test_types/test_indexers.py b/tests/test_types/test_indexers.py
index c208ef43..05ab6885 100644
--- a/tests/test_types/test_indexers.py
+++ b/tests/test_types/test_indexers.py
@@ -1,8 +1,7 @@
 import unittest
 
-from mmda.types.annotation import SpanGroup
+from mmda.types import SpanGroup, Span
 from mmda.types.indexers import SpanGroupIndexer
-from mmda.types.span import Span
 
 
 class TestSpanGroupIndexer(unittest.TestCase):
diff --git a/tests/test_types/test_json_conversion.py b/tests/test_types/test_json_conversion.py
index c8b78627..416358b4 100644
--- a/tests/test_types/test_json_conversion.py
+++ b/tests/test_types/test_json_conversion.py
@@ -8,10 +8,8 @@
 import json
 from pathlib import Path
 
-from mmda.types.annotation import BoxGroup, SpanGroup
-from mmda.types.document import Document
-from mmda.parsers.pdfplumber_parser import PDFPlumberParser
-from mmda.types.metadata import Metadata
+from mmda.types import BoxGroup, SpanGroup, Document, Metadata
+from mmda.parsers import PDFPlumberParser
 
 
 PDFFILEPATH = Path(__file__).parent / "../fixtures/1903.10676.pdf"
diff --git a/tests/test_types/test_metadata.py b/tests/test_types/test_metadata.py
index de1543c0..53602a47 100644
--- a/tests/test_types/test_metadata.py
+++ b/tests/test_types/test_metadata.py
@@ -8,7 +8,7 @@
 import unittest
 
 
-from mmda.types.metadata import Metadata
+from mmda.types import Metadata
 
 
 class TestSpanGroup(unittest.TestCase):
diff --git a/tests/test_types/test_span_group.py b/tests/test_types/test_span_group.py
index 3e24dfad..9c63db5d 100644
--- a/tests/test_types/test_span_group.py
+++ b/tests/test_types/test_span_group.py
@@ -7,9 +7,7 @@
 import json
 import unittest
 
-from mmda.types.annotation import SpanGroup
-from mmda.types.document import Document
-from mmda.types.span import Span
+from mmda.types import SpanGroup, Document, Span
 
 
 class TestSpanGroup(unittest.TestCase):