Skip to content

Commit

Permalink
quality of life improvements related to importing mmda types, predict…
Browse files Browse the repository at this point in the history
…ors, etc. (#141)

* add easy import for types

* predictor import

* import parsers

* github; setup

* update tests

* update tests

* use necessary to manage deps

* remove bad test

* rm bad test

* move api to ai2_internal

* increment
  • Loading branch information
kyleclo authored Sep 23, 2022
1 parent 4f003b2 commit b7ee06c
Show file tree
Hide file tree
Showing 24 changed files with 75 additions and 53 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/mmda-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:

- name: Test with Python ${{ matrix.python-version }}
run: |
pip install -e .[dev,hf_predictors]
pip install -e .[dev,pysbd_predictors,hf_predictors]
pytest tests --ignore=tests/test_predictors/test_vila_predictors.py
test_vila_predictors:
Expand Down
File renamed without changes.
5 changes: 2 additions & 3 deletions ai2_internal/bibentry_detection_predictor/integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,11 @@ def test_prediction(self, container):
import sys
import unittest

from .interface import Instance

from ai2_internal import api
from mmda.parsers.pdfplumber_parser import PDFPlumberParser
from mmda.rasterizers.rasterizer import PDF2ImageRasterizer
from mmda.types import api
from mmda.types.image import tobase64
from .interface import Instance

try:
from timo_interface import with_timo_container
Expand Down
3 changes: 2 additions & 1 deletion ai2_internal/bibentry_detection_predictor/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@

from pydantic import BaseModel, BaseSettings, Field

from ai2_internal import api
from mmda.predictors.d2_predictors.bibentry_detection_predictor import BibEntryDetectionPredictor
from mmda.types import api, image
from mmda.types import image
from mmda.types.document import Document


Expand Down
6 changes: 2 additions & 4 deletions ai2_internal/bibentry_predictor_mmda/integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,9 @@ def test_prediction(self, container):
import sys
import unittest

from .interface import Instance, Prediction

from mmda.types import api
from ai2_internal import api
from mmda.types.document import Document

from .interface import Instance

try:
from timo_interface import with_timo_container
Expand Down
7 changes: 3 additions & 4 deletions ai2_internal/bibentry_predictor_mmda/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@

from typing import List

from pydantic import BaseModel, BaseSettings, Field

from mmda.types import api
from mmda.types.document import Document
from pydantic import BaseModel, BaseSettings

from ai2_internal import api
from mmda.predictors.hf_predictors.bibentry_predictor.predictor import BibEntryPredictor
from mmda.predictors.hf_predictors.bibentry_predictor.types import BibEntryStructureSpanGroups
from mmda.types.document import Document


class Instance(BaseModel):
Expand Down
5 changes: 2 additions & 3 deletions ai2_internal/citation_links/integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,8 @@ def test_prediction(self, container):
import sys
import unittest

from ai2_internal.citation_links.interface import Instance, Prediction
from mmda.types import api

from ai2_internal import api
from ai2_internal.citation_links.interface import Instance

try:
from timo_interface import with_timo_container
Expand Down
3 changes: 2 additions & 1 deletion ai2_internal/citation_links/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@

from pydantic import BaseModel, BaseSettings

from ai2_internal import api
from mmda.predictors.xgb_predictors.citation_link_predictor import CitationLinkPredictor
from mmda.types import api
from mmda.types.document import Document


# these should represent the extracted citation mentions and bibliography entries for a paper
class Instance(BaseModel):
"""
Expand Down
5 changes: 1 addition & 4 deletions ai2_internal/citation_mentions/integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,14 @@
predict_batch(instances: List[Instance]) -> List[Prediction]
"""
import json
import logging
import pathlib
import sys
import unittest


from ai2_internal import api
from ai2_internal.citation_mentions.interface import Instance
from mmda.parsers.pdfplumber_parser import PDFPlumberParser
from mmda.types import api


try:
from timo_interface import with_timo_container
Expand Down
2 changes: 1 addition & 1 deletion ai2_internal/citation_mentions/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

from pydantic import BaseModel, BaseSettings

from ai2_internal import api
from mmda.predictors.hf_predictors.mention_predictor import MentionPredictor
from mmda.types import api
from mmda.types.document import Document


Expand Down
9 changes: 4 additions & 5 deletions ai2_internal/layout_parser/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,14 @@
import logging
from typing import List

import torch
from pydantic import BaseModel, BaseSettings, Field

from ai2_internal.api import BoxGroup
from mmda.predictors.lp_predictors import LayoutParserPredictor
from mmda.types import image
from mmda.types.api import BoxGroup
from mmda.types.document import Document

from pydantic import BaseModel, BaseSettings, Field
import torch


logger = logging.getLogger(__name__)


Expand Down
6 changes: 2 additions & 4 deletions ai2_internal/vila/integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,15 @@ def test_prediction(self, container):
import os
import sys
import unittest
from pathlib import Path

from PIL import Image
from pathlib import Path

from mmda.types import api
from ai2_internal import api
from mmda.types.document import Document
from mmda.types.image import tobase64

from .interface import Instance


FIXTURE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_fixtures")


Expand Down
5 changes: 2 additions & 3 deletions ai2_internal/vila/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,16 @@
import logging
from typing import List

from pydantic import BaseModel, BaseSettings, Field
import torch
from pydantic import BaseModel, BaseSettings, Field

from ai2_internal import api
from mmda.predictors.hf_predictors.token_classification_predictor import (
IVILATokenClassificationPredictor,
)
from mmda.types import api
from mmda.types.document import Document, SpanGroup
from mmda.types.image import frombase64


logger = logging.getLogger(__name__)


Expand Down
5 changes: 5 additions & 0 deletions mmda/parsers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from mmda.parsers.pdfplumber_parser import PDFPlumberParser

__all__ = [
'PDFPlumberParser'
]
18 changes: 18 additions & 0 deletions mmda/predictors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# flake8: noqa
from necessary import necessary

from mmda.predictors.heuristic_predictors.dictionary_word_predictor import DictionaryWordPredictor

__all__ = ['DictionaryWordPredictor']

with necessary('pysbd', soft=True) as PYSBD_AVAILABLE:
if PYSBD_AVAILABLE:
from mmda.predictors.heuristic_predictors.sentence_boundary_predictor \
import PysbdSentenceBoundaryPredictor
__all__.append('PysbdSentenceBoundaryPredictor')

with necessary(["layoutparser", "torch", "torchvision", "effdet"], soft=True) as PYTORCH_AVAILABLE:
if PYTORCH_AVAILABLE:
from mmda.predictors.lp_predictors import LayoutParserPredictor
__all__.append('LayoutParserPredictor')

16 changes: 16 additions & 0 deletions mmda/types/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from mmda.types.document import Document
from mmda.types.annotation import SpanGroup, BoxGroup
from mmda.types.span import Span
from mmda.types.box import Box
from mmda.types.image import PILImage
from mmda.types.metadata import Metadata

__all__ = [
'Document',
'SpanGroup',
'BoxGroup',
'Span',
'Box',
'PILImage',
'Metadata'
]
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="mmda",
description="mmda",
version="0.0.36",
version="0.0.37",
url="https://www.github.com/allenai/mmda",
python_requires=">= 3.7",
packages=find_namespace_packages(include=["mmda*", "ai2_internal*"]),
Expand All @@ -15,10 +15,12 @@
"pandas",
"pydantic",
"ncls",
"necessary"
],
extras_require={
"dev": ["pytest"],
"spacy_predictors": ["spacy"],
"pysbd_predictors": ["pysbd"],
"lp_predictors": ["layoutparser", "torch", "torchvision", "effdet"],
"hf_predictors": ["torch", "transformers", "smashed==0.1.10"],
"vila_predictors": ["vila>=0.4.2,<0.5", "transformers"],
Expand Down
5 changes: 2 additions & 3 deletions tests/test_parsers/test_pdf_plumber_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
import pathlib
import unittest

from mmda.types.document import Document
from mmda.parsers.pdfplumber_parser import PDFPlumberParser
from mmda.types import Document
from mmda.parsers import PDFPlumberParser

import string
import re

os.chdir(pathlib.Path(__file__).parent)
Expand Down
Empty file.
7 changes: 2 additions & 5 deletions tests/test_predictors/test_dictionary_word_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,8 @@
import unittest
from typing import List, Optional, Set

from mmda.predictors.heuristic_predictors.dictionary_word_predictor import (
DictionaryWordPredictor,
)
from mmda.types.document import Document, SpanGroup
from mmda.types.span import Span
from mmda.predictors import DictionaryWordPredictor
from mmda.types import Document, SpanGroup, Span


def mock_document(symbols: str, spans: List[Span], rows: List[SpanGroup]) -> Document:
Expand Down
3 changes: 1 addition & 2 deletions tests/test_types/test_indexers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import unittest

from mmda.types.annotation import SpanGroup
from mmda.types import SpanGroup, Span
from mmda.types.indexers import SpanGroupIndexer
from mmda.types.span import Span


class TestSpanGroupIndexer(unittest.TestCase):
Expand Down
6 changes: 2 additions & 4 deletions tests/test_types/test_json_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@
import json
from pathlib import Path

from mmda.types.annotation import BoxGroup, SpanGroup
from mmda.types.document import Document
from mmda.parsers.pdfplumber_parser import PDFPlumberParser
from mmda.types.metadata import Metadata
from mmda.types import BoxGroup, SpanGroup, Document, Metadata
from mmda.parsers import PDFPlumberParser


PDFFILEPATH = Path(__file__).parent / "../fixtures/1903.10676.pdf"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_types/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import unittest


from mmda.types.metadata import Metadata
from mmda.types import Metadata


class TestSpanGroup(unittest.TestCase):
Expand Down
4 changes: 1 addition & 3 deletions tests/test_types/test_span_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
import json
import unittest

from mmda.types.annotation import SpanGroup
from mmda.types.document import Document
from mmda.types.span import Span
from mmda.types import SpanGroup, Document, Span


class TestSpanGroup(unittest.TestCase):
Expand Down

0 comments on commit b7ee06c

Please sign in to comment.