Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

quality of life improvements related to importing mmda types, predictors, etc. #141

Merged
merged 12 commits into from
Sep 23, 2022
2 changes: 1 addition & 1 deletion .github/workflows/mmda-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:

- name: Test with Python ${{ matrix.python-version }}
run: |
pip install -e .[dev,hf_predictors]
pip install -e .[dev,pysbd_predictors,hf_predictors]
pytest tests --ignore=tests/test_predictors/test_vila_predictors.py

test_vila_predictors:
Expand Down
File renamed without changes.
5 changes: 2 additions & 3 deletions ai2_internal/bibentry_detection_predictor/integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,11 @@ def test_prediction(self, container):
import sys
import unittest

from .interface import Instance

from ai2_internal import api
from mmda.parsers.pdfplumber_parser import PDFPlumberParser
from mmda.rasterizers.rasterizer import PDF2ImageRasterizer
from mmda.types import api
from mmda.types.image import tobase64
from .interface import Instance

try:
from timo_interface import with_timo_container
Expand Down
3 changes: 2 additions & 1 deletion ai2_internal/bibentry_detection_predictor/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@

from pydantic import BaseModel, BaseSettings, Field

from ai2_internal import api
from mmda.predictors.d2_predictors.bibentry_detection_predictor import BibEntryDetectionPredictor
from mmda.types import api, image
from mmda.types import image
from mmda.types.document import Document


Expand Down
6 changes: 2 additions & 4 deletions ai2_internal/bibentry_predictor_mmda/integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,9 @@ def test_prediction(self, container):
import sys
import unittest

from .interface import Instance, Prediction

from mmda.types import api
from ai2_internal import api
from mmda.types.document import Document

from .interface import Instance

try:
from timo_interface import with_timo_container
Expand Down
7 changes: 3 additions & 4 deletions ai2_internal/bibentry_predictor_mmda/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@

from typing import List

from pydantic import BaseModel, BaseSettings, Field

from mmda.types import api
from mmda.types.document import Document
from pydantic import BaseModel, BaseSettings

from ai2_internal import api
from mmda.predictors.hf_predictors.bibentry_predictor.predictor import BibEntryPredictor
from mmda.predictors.hf_predictors.bibentry_predictor.types import BibEntryStructureSpanGroups
from mmda.types.document import Document


class Instance(BaseModel):
Expand Down
5 changes: 2 additions & 3 deletions ai2_internal/citation_links/integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,8 @@ def test_prediction(self, container):
import sys
import unittest

from ai2_internal.citation_links.interface import Instance, Prediction
from mmda.types import api

from ai2_internal import api
from ai2_internal.citation_links.interface import Instance

try:
from timo_interface import with_timo_container
Expand Down
3 changes: 2 additions & 1 deletion ai2_internal/citation_links/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@

from pydantic import BaseModel, BaseSettings

from ai2_internal import api
from mmda.predictors.xgb_predictors.citation_link_predictor import CitationLinkPredictor
from mmda.types import api
from mmda.types.document import Document


# these should represent the extracted citation mentions and bibliography entries for a paper
class Instance(BaseModel):
"""
Expand Down
4 changes: 1 addition & 3 deletions ai2_internal/citation_mentions/integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,14 @@

predict_batch(instances: List[Instance]) -> List[Prediction]
"""
import json
import logging
import pathlib
import sys
import unittest

from ai2_internal import api
from ai2_internal.citation_mentions.interface import Instance
from mmda.parsers.pdfplumber_parser import PDFPlumberParser
from mmda.types import api


try:
from timo_interface import with_timo_container
Expand Down
2 changes: 1 addition & 1 deletion ai2_internal/citation_mentions/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

from pydantic import BaseModel, BaseSettings

from ai2_internal import api
from mmda.predictors.hf_predictors.mention_predictor import MentionPredictor
from mmda.types import api
from mmda.types.document import Document


Expand Down
9 changes: 4 additions & 5 deletions ai2_internal/layout_parser/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,14 @@
import logging
from typing import List

import torch
from pydantic import BaseModel, BaseSettings, Field

from ai2_internal.api import BoxGroup
from mmda.predictors.lp_predictors import LayoutParserPredictor
from mmda.types import image
from mmda.types.api import BoxGroup
from mmda.types.document import Document

from pydantic import BaseModel, BaseSettings, Field
import torch


logger = logging.getLogger(__name__)


Expand Down
6 changes: 2 additions & 4 deletions ai2_internal/vila/integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,15 @@ def test_prediction(self, container):
import os
import sys
import unittest
from pathlib import Path

from PIL import Image
from pathlib import Path

from mmda.types import api
from ai2_internal import api
from mmda.types.document import Document
from mmda.types.image import tobase64

from .interface import Instance


FIXTURE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_fixtures")


Expand Down
5 changes: 2 additions & 3 deletions ai2_internal/vila/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,16 @@
import logging
from typing import List

from pydantic import BaseModel, BaseSettings, Field
import torch
from pydantic import BaseModel, BaseSettings, Field

from ai2_internal import api
from mmda.predictors.hf_predictors.token_classification_predictor import (
IVILATokenClassificationPredictor,
)
from mmda.types import api
from mmda.types.document import Document, SpanGroup
from mmda.types.image import frombase64


logger = logging.getLogger(__name__)


Expand Down
5 changes: 5 additions & 0 deletions mmda/parsers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from mmda.parsers.pdfplumber_parser import PDFPlumberParser

__all__ = [
'PDFPlumberParser'
]
18 changes: 18 additions & 0 deletions mmda/predictors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# flake8: noqa
from necessary import necessary

from mmda.predictors.heuristic_predictors.dictionary_word_predictor import DictionaryWordPredictor

__all__ = ['DictionaryWordPredictor']

with necessary('pysbd', soft=True) as PYSBD_AVAILABLE:
if PYSBD_AVAILABLE:
from mmda.predictors.heuristic_predictors.sentence_boundary_predictor \
import PysbdSentenceBoundaryPredictor
__all__.append('PysbdSentenceBoundaryPredictor')

with necessary(["layoutparser", "torch", "torchvision", "effdet"], soft=True) as PYTORCH_AVAILABLE:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to not have to define this in three places:

  1. setup extras_require
  2. required_backends on the predictor itself
  3. here

But I don't have a good suggestion.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah :/ ive created an issue here: #144

i need to think a bit more how this would be resolved. a reasonable stop-gap could be defining a top-level config that all 3 locations imports from; reserve for next PR

if PYTORCH_AVAILABLE:
from mmda.predictors.lp_predictors import LayoutParserPredictor
__all__.append('LayoutParserPredictor')

16 changes: 16 additions & 0 deletions mmda/types/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from mmda.types.document import Document
from mmda.types.annotation import SpanGroup, BoxGroup
from mmda.types.span import Span
from mmda.types.box import Box
from mmda.types.image import PILImage
from mmda.types.metadata import Metadata

__all__ = [
'Document',
'SpanGroup',
'BoxGroup',
'Span',
'Box',
'PILImage',
'Metadata'
]
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="mmda",
description="mmda",
version="0.0.35",
version="0.0.36",
url="https://www.github.com/allenai/mmda",
python_requires=">= 3.7",
packages=find_namespace_packages(include=["mmda*", "ai2_internal*"]),
Expand All @@ -15,10 +15,12 @@
"pandas",
"pydantic",
"ncls",
"necessary"
],
extras_require={
"dev": ["pytest"],
"spacy_predictors": ["spacy"],
"pysbd_predictors": ["pysbd"],
"lp_predictors": ["layoutparser", "torch", "torchvision", "effdet"],
"hf_predictors": ["torch", "transformers", "smashed==0.1.10"],
"vila_predictors": ["vila>=0.4.2,<0.5", "transformers"],
Expand Down
5 changes: 2 additions & 3 deletions tests/test_parsers/test_pdf_plumber_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
import pathlib
import unittest

from mmda.types.document import Document
from mmda.parsers.pdfplumber_parser import PDFPlumberParser
from mmda.types import Document
from mmda.parsers import PDFPlumberParser

import string
import re

os.chdir(pathlib.Path(__file__).parent)
Expand Down
Empty file.
7 changes: 2 additions & 5 deletions tests/test_predictors/test_dictionary_word_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,8 @@
import unittest
from typing import List, Optional, Set

from mmda.predictors.heuristic_predictors.dictionary_word_predictor import (
DictionaryWordPredictor,
)
from mmda.types.document import Document, SpanGroup
from mmda.types.span import Span
from mmda.predictors import DictionaryWordPredictor
from mmda.types import Document, SpanGroup, Span


def mock_document(symbols: str, spans: List[Span], rows: List[SpanGroup]) -> Document:
Expand Down
3 changes: 1 addition & 2 deletions tests/test_types/test_indexers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import unittest

from mmda.types.annotation import SpanGroup
from mmda.types import SpanGroup, Span
from mmda.types.indexers import SpanGroupIndexer
from mmda.types.span import Span


class TestSpanGroupIndexer(unittest.TestCase):
Expand Down
6 changes: 2 additions & 4 deletions tests/test_types/test_json_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@
import json
from pathlib import Path

from mmda.types.annotation import BoxGroup, SpanGroup
from mmda.types.document import Document
from mmda.parsers.pdfplumber_parser import PDFPlumberParser
from mmda.types.metadata import Metadata
from mmda.types import BoxGroup, SpanGroup, Document, Metadata
from mmda.parsers import PDFPlumberParser


PDFFILEPATH = Path(__file__).parent / "../fixtures/1903.10676.pdf"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_types/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import unittest


from mmda.types.metadata import Metadata
from mmda.types import Metadata


class TestSpanGroup(unittest.TestCase):
Expand Down
4 changes: 1 addition & 3 deletions tests/test_types/test_span_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
import json
import unittest

from mmda.types.annotation import SpanGroup
from mmda.types.document import Document
from mmda.types.span import Span
from mmda.types import SpanGroup, Document, Span


class TestSpanGroup(unittest.TestCase):
Expand Down