Skip to content

Commit

Permalink
Add text analyzer to skip text extraction from image (#199)
Browse files Browse the repository at this point in the history
* read in text from csv

* add tests for csv reading

* run textanalyzer in demo notebook

* add text analyser in doc and demo

* improve init TextDetector testing

* more init tests

* add csv encoding keyword

* add utf16-csv file

* skip csv reading on windows
  • Loading branch information
iulusoy authored Jun 5, 2024
1 parent 9202f51 commit 4ac760e
Show file tree
Hide file tree
Showing 9 changed files with 328 additions and 14 deletions.
3 changes: 2 additions & 1 deletion ammico/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ammico.faces import EmotionDetector
from ammico.multimodal_search import MultimodalSearch
from ammico.summary import SummaryDetector
from ammico.text import TextDetector, PostprocessText
from ammico.text import TextDetector, TextAnalyzer, PostprocessText
from ammico.utils import find_files, get_dataframe

# Export the version defined in project metadata
Expand All @@ -23,6 +23,7 @@
"MultimodalSearch",
"SummaryDetector",
"TextDetector",
"TextAnalyzer",
"PostprocessText",
"find_files",
"get_dataframe",
Expand Down
8 changes: 8 additions & 0 deletions ammico/data/ref/test.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
text, date
this is a test, 05/31/24
bu bir denemedir, 05/31/24
dies ist ein Test, 05/31/24
c'est un test, 05/31/24
esto es una prueba, 05/31/24
detta är ett test, 05/31/24

88 changes: 88 additions & 0 deletions ammico/notebooks/DemoNotebook_ammico.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,94 @@
"image_df.to_csv(\"/content/drive/MyDrive/misinformation-data/data_out.csv\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Read in a csv file containing text and translating/analysing the text\n",
"\n",
"Instead of extracting text from an image, or to re-process text that was already extracted, it is also possible to provide a `csv` file containing text in its rows.\n",
"Provide the path and name of the csv file with the keyword `csv_path`. The keyword `column_key` tells the Analyzer which column key in the csv file holds the text that should be analyzed. This defaults to \"text\"."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ta = ammico.TextAnalyzer(csv_path=\"../data/ref/test.csv\", column_key=\"text\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# read the csv file\n",
"ta.read_csv()\n",
"# set up the dict containing all text entries\n",
"text_dict = ta.mydict"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# set the dump file\n",
"# dump file name\n",
"dump_file = \"dump_file.csv\"\n",
"# dump every N images \n",
"dump_every = 10"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# analyze the csv file\n",
"for num, key in tqdm(enumerate(text_dict.keys()), total=len(text_dict)): # loop through all text entries\n",
" ammico.TextDetector(text_dict[key], analyse_text=True, skip_extraction=True).analyse_image() # analyse text with TextDetector and update dict\n",
" if num % dump_every == 0 | num == len(text_dict) - 1: # save results every dump_every to dump_file\n",
" image_df = ammico.get_dataframe(text_dict)\n",
" image_df.to_csv(dump_file)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# save the results to a csv file\n",
"text_df = ammico.get_dataframe(text_dict)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# inspect\n",
"text_df.head(3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# write to csv\n",
"text_df.to_csv(\"data_out.csv\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
Binary file added ammico/test/data/test-utf16.csv
Binary file not shown.
8 changes: 8 additions & 0 deletions ammico/test/data/test.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
text, date
this is a test, 05/31/24
bu bir denemedir, 05/31/24
dies ist ein Test, 05/31/24
c'est un test, 05/31/24
esto es una prueba, 05/31/24
detta är ett test, 05/31/24

32 changes: 32 additions & 0 deletions ammico/test/data/test_read_csv_ref.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
{
"test.csvrow-1":
{
"filename": "test.csv",
"text": "this is a test"
},
"test.csvrow-2":
{
"filename": "test.csv",
"text": "bu bir denemedir"
},
"test.csvrow-3":
{
"filename": "test.csv",
"text": "dies ist ein Test"
},
"test.csvrow-4":
{
"filename": "test.csv",
"text": "c'est un test"
},
"test.csvrow-5":
{
"filename": "test.csv",
"text": "esto es una prueba"
},
"test.csvrow-6":
{
"filename": "test.csv",
"text": "detta är ett test"
}
}
71 changes: 65 additions & 6 deletions ammico/test/test_text.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pytest
import ammico.text as tt
import spacy
import json
import sys


@pytest.fixture
Expand All @@ -25,10 +27,25 @@ def set_testdict(get_path):
def test_TextDetector(set_testdict):
    # Check TextDetector default initialization for each entry of the test dict.
    for item in set_testdict:
        test_obj = tt.TextDetector(set_testdict[item])
        # NOTE(review): this page is a scraped diff without +/- markers; the three
        # subdict-key asserts below may be stale removed lines, since the same
        # commit disables set_keys() in TextDetector.__init__ — confirm against repo.
        assert test_obj.subdict["text"] is None
        assert test_obj.subdict["text_language"] is None
        assert test_obj.subdict["text_english"] is None
        # analysis and extraction-skipping are off by default
        assert not test_obj.analyse_text
        assert not test_obj.skip_extraction
        assert test_obj.subdict["filename"] == set_testdict[item]["filename"]
        # default transformers model names and pinned revisions
        assert test_obj.model_summary == "sshleifer/distilbart-cnn-12-6"
        assert (
            test_obj.model_sentiment
            == "distilbert-base-uncased-finetuned-sst-2-english"
        )
        assert test_obj.model_ner == "dbmdz/bert-large-cased-finetuned-conll03-english"
        assert test_obj.revision_summary == "a4f8f3e"
        assert test_obj.revision_sentiment == "af0f99b"
        assert test_obj.revision_ner == "f2482bf"
    # explicit keyword flags are stored as given
    test_obj = tt.TextDetector({}, analyse_text=True, skip_extraction=True)
    assert test_obj.analyse_text
    assert test_obj.skip_extraction
    # non-boolean flags are rejected
    with pytest.raises(ValueError):
        tt.TextDetector({}, analyse_text=1.0)
    with pytest.raises(ValueError):
        tt.TextDetector({}, skip_extraction=1.0)


def test_run_spacy(set_testdict, get_path):
Expand Down Expand Up @@ -140,7 +157,6 @@ def test_remove_linebreaks():
assert test_obj.subdict["text_english"] == "This is another test."


@pytest.mark.win_skip
def test_text_summary(get_path):
mydict = {}
test_obj = tt.TextDetector(mydict, analyse_text=True)
Expand All @@ -162,7 +178,6 @@ def test_text_sentiment_transformers():
assert mydict["sentiment_score"] == pytest.approx(0.99, 0.02)


@pytest.mark.win_skip
def test_text_ner():
mydict = {}
test_obj = tt.TextDetector(mydict, analyse_text=True)
Expand All @@ -172,7 +187,51 @@ def test_text_ner():
assert mydict["entity_type"] == ["PER", "LOC"]


@pytest.mark.win_skip
def test_init_csv_option(get_path):
    """Check TextAnalyzer defaults, explicit keyword overrides, and input validation."""
    csv_file = get_path + "test.csv"
    # defaults: column key "text", encoding "utf-8"
    analyzer = tt.TextAnalyzer(csv_path=csv_file)
    assert analyzer.csv_path == csv_file
    assert analyzer.column_key == "text"
    assert analyzer.csv_encoding == "utf-8"
    # explicit keyword arguments override the defaults
    analyzer = tt.TextAnalyzer(
        csv_path=csv_file, column_key="mytext", csv_encoding="utf-16"
    )
    assert analyzer.column_key == "mytext"
    assert analyzer.csv_encoding == "utf-16"
    # a missing file raises FileNotFoundError
    with pytest.raises(FileNotFoundError):
        tt.TextAnalyzer(csv_path=get_path + "test_no.csv")
    # invalid argument types / non-csv paths raise ValueError
    for bad_kwargs in (
        {"csv_path": 1.0},
        {"csv_path": "something"},
        {"csv_path": csv_file, "column_key": 1.0},
        {"csv_path": csv_file, "csv_encoding": 1.0},
    ):
        with pytest.raises(ValueError):
            tt.TextAnalyzer(**bad_kwargs)


@pytest.mark.skipif(sys.platform == "win32", reason="Encoding different on Windows")
def test_read_csv(get_path):
    """Check that read_csv fills mydict with the text entries from the csv file.

    Runs once with the default utf-8 file and once with a utf-16 encoded file;
    both must match the reference entries stored in test_read_csv_ref.json.
    """

    def assert_matches_reference(mydict, ref_dict):
        # guard against zip() silently truncating a short result dict
        assert len(mydict) == len(ref_dict)
        # we are assuming the order did not get jumbled up
        for (_, value_test), (_, value_ref) in zip(mydict.items(), ref_dict.items()):
            assert value_test["text"] == value_ref["text"]

    with open(get_path + "test_read_csv_ref.json", "r") as file:
        ref_dict = json.load(file)
    # default utf-8 encoding
    test_obj = tt.TextAnalyzer(csv_path=get_path + "test.csv")
    test_obj.read_csv()
    assert_matches_reference(test_obj.mydict, ref_dict)
    # explicit utf-16 encoding
    test_obj = tt.TextAnalyzer(
        csv_path=get_path + "test-utf16.csv", csv_encoding="utf-16"
    )
    test_obj.read_csv()
    assert_matches_reference(test_obj.mydict, ref_dict)


def test_PostprocessText(set_testdict, get_path):
reference_dict = "THE\nALGEBRAIC\nEIGENVALUE\nPROBLEM\nDOM\nNVS TIO\nMINA\nMonographs\non Numerical Analysis\nJ.. H. WILKINSON"
reference_df = "Mathematische Formelsammlung\nfür Ingenieure und Naturwissenschaftler\nMit zahlreichen Abbildungen und Rechenbeispielen\nund einer ausführlichen Integraltafel\n3., verbesserte Auflage"
Expand Down
38 changes: 32 additions & 6 deletions ammico/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def __init__(
self,
subdict: dict,
analyse_text: bool = False,
skip_extraction: bool = False,
model_names: list = None,
revision_numbers: list = None,
) -> None:
Expand All @@ -25,6 +26,8 @@ def __init__(
analysis results from other modules.
analyse_text (bool, optional): Decide if extracted text will be further subject
to analysis. Defaults to False.
skip_extraction (bool, optional): Decide if text will be extracted from images or
is already provided via a csv. Defaults to False.
model_names (list, optional): Provide model names for summary, sentiment and ner
analysis. Defaults to None, in which case the default model from transformers
are used (as of 03/2023): "sshleifer/distilbart-cnn-12-6" (summary),
Expand All @@ -40,11 +43,21 @@ def __init__(
"f2482bf" (NER, bert).
"""
super().__init__(subdict)
self.subdict.update(self.set_keys())
# disable this for now
# maybe it would be better to initialize the keys differently
# the reason is that they are inconsistent depending on the selected
# options, and also this may not be really necessary and rather restrictive
# self.subdict.update(self.set_keys())
self.translator = Translator()
if not isinstance(analyse_text, bool):
raise ValueError("analyse_text needs to be set to true or false")
self.analyse_text = analyse_text
self.skip_extraction = skip_extraction
if not isinstance(skip_extraction, bool):
raise ValueError("skip_extraction needs to be set to true or false")
if self.skip_extraction:
print("Skipping text extraction from image.")
print("Reading text directly from provided dictionary.")
if self.analyse_text:
self._initialize_spacy()
if model_names:
Expand Down Expand Up @@ -155,7 +168,8 @@ def analyse_image(self) -> dict:
Returns:
dict: The updated dictionary with text analysis results.
"""
self.get_text_from_image()
if not self.skip_extraction:
self.get_text_from_image()
self.translate_text()
self.remove_linebreaks()
if self.analyse_text:
Expand Down Expand Up @@ -287,18 +301,32 @@ def text_ner(self):
class TextAnalyzer:
"""Used to get text from a csv and then run the TextDetector on it."""

def __init__(self, csv_path: str, column_key: str = None) -> None:
def __init__(
    self, csv_path: str, column_key: str = None, csv_encoding: str = "utf-8"
) -> None:
    """Init the TextAnalyzer class.

    Args:
        csv_path (str): Path to the CSV file containing the text entries.
        column_key (str): Key for the column containing the text entries.
            Defaults to None, in which case "text" is used.
        csv_encoding (str): Encoding of the CSV file. Defaults to "utf-8".

    Raises:
        ValueError: If csv_path is not a string ending in ".csv", or if a
            provided column_key/csv_encoding is not a string.
        FileNotFoundError: If csv_path does not point to an existing file.
    """
    self.csv_path = csv_path
    self.column_key = column_key
    self.csv_encoding = csv_encoding
    self._check_valid_csv_path()
    self._check_file_exists()
    # Validate types BEFORE falling back to defaults, so that invalid falsy
    # values (e.g. column_key=0) raise instead of being silently replaced.
    if self.column_key is not None and not isinstance(self.column_key, str):
        raise ValueError("The provided column key is not a string.")
    if self.csv_encoding is not None and not isinstance(self.csv_encoding, str):
        raise ValueError("The provided encoding is not a string.")
    if not self.column_key:
        print("No column key provided - using 'text' as default.")
        self.column_key = "text"
    if not self.csv_encoding:
        print("No encoding provided - using 'utf-8' as default.")
        self.csv_encoding = "utf-8"

def _check_valid_csv_path(self):
if not isinstance(self.csv_path, str):
Expand All @@ -319,9 +347,7 @@ def read_csv(self) -> dict:
Returns:
dict: The dictionary with the text entries.
"""
df = pd.read_csv(self.csv_path, encoding="utf8")
if not self.column_key:
self.column_key = "text"
df = pd.read_csv(self.csv_path, encoding=self.csv_encoding)

if self.column_key not in df:
raise ValueError(
Expand Down
Loading

0 comments on commit 4ac760e

Please sign in to comment.