#131 related update: sync with AREkit updates
nicolay-r committed Nov 23, 2024
1 parent b6b61f4 commit 539880c
Showing 11 changed files with 86 additions and 8 deletions.
2 changes: 1 addition & 1 deletion arelight/arekit/custom_sqlite_reader.py
@@ -1,4 +1,4 @@
-from arekit.contrib.utils.data.readers.base import BaseReader
+from arelight.readers.base import BaseReader


class CustomSQliteReader(BaseReader):
3 changes: 2 additions & 1 deletion arelight/arekit/samples_io.py
@@ -1,7 +1,8 @@
from arekit.common.experiment.api.base_samples_io import BaseSamplesIO
-from arekit.contrib.utils.data.readers.base import BaseReader
from arekit.contrib.utils.data.writers.base import BaseWriter

+from arelight.readers.base import BaseReader


class CustomSamplesIO(BaseSamplesIO):
""" Samples default IO utils for samples.
Empty file added arelight/readers/__init__.py
Empty file.
7 changes: 7 additions & 0 deletions arelight/readers/base.py
@@ -0,0 +1,7 @@
class BaseReader(object):

    def extension(self):
        raise NotImplementedError()

    def read(self, target):
        raise NotImplementedError()
39 changes: 39 additions & 0 deletions arelight/readers/csv_pd.py
@@ -0,0 +1,39 @@
import importlib

from arekit.contrib.utils.data.storages.pandas_based import PandasBasedRowsStorage

from arelight.readers.base import BaseReader


class PandasCsvReader(BaseReader):
    """ Represents a CSV-based reader, implemented via the pandas API.
    """

    def __init__(self, sep='\t', header='infer', compression='infer', encoding='utf-8', col_types=None,
                 custom_extension=None):
        self.__sep = sep
        self.__compression = compression
        self.__encoding = encoding
        self.__header = header
        self.__custom_extension = custom_extension

        # Special assignment of types for certain columns.
        self.__col_types = col_types
        if self.__col_types is None:
            self.__col_types = dict()

    def extension(self):
        return ".tsv.gz" if self.__custom_extension is None else self.__custom_extension

    def __from_csv(self, filepath):
        pd = importlib.import_module("pandas")
        return pd.read_csv(filepath,
                           sep=self.__sep,
                           encoding=self.__encoding,
                           compression=self.__compression,
                           dtype=self.__col_types,
                           header=self.__header)

    def read(self, target):
        df = self.__from_csv(filepath=target)
        return PandasBasedRowsStorage(df)
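
A minimal usage sketch for the relocated PandasCsvReader, assuming pandas is installed and a tab-separated, gzip-compressed samples file exists at the (hypothetical) path shown; read() returns a PandasBasedRowsStorage wrapping the parsed DataFrame.

    from arelight.readers.csv_pd import PandasCsvReader

    # Defaults: tab separator, compression and header inferred by pandas.
    reader = PandasCsvReader()
    print(reader.extension())  # ".tsv.gz"

    # Hypothetical path; any file matching the reader settings would do.
    storage = reader.read(target="sample-test.tsv.gz")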
16 changes: 16 additions & 0 deletions arelight/readers/jsonl.py
@@ -0,0 +1,16 @@
from arekit.contrib.utils.data.storages.jsonl_based import JsonlBasedRowsStorage

from arelight.readers.base import BaseReader


class JsonlReader(BaseReader):

    def extension(self):
        return ".jsonl"

    def read(self, target):
        rows = []
        with open(target, "r") as f:
            for line in f.readlines():
                rows.append(line)
        return JsonlBasedRowsStorage(rows)
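
A self-contained sketch of JsonlReader: it writes a throwaway two-row .jsonl file (hypothetical content) and reads it back; the raw lines end up wrapped in a JsonlBasedRowsStorage.

    import json
    import tempfile

    from arelight.readers.jsonl import JsonlReader

    # Write two JSON rows into a temporary .jsonl file.
    with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False) as f:
        for row in [{"id": 0, "text": "a"}, {"id": 1, "text": "b"}]:
            f.write(json.dumps(row) + "\n")

    reader = JsonlReader()
    assert reader.extension() == ".jsonl"

    # Each raw line is kept as-is and passed to the storage.
    storage = reader.read(target=f.name)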
15 changes: 15 additions & 0 deletions arelight/readers/sqlite.py
@@ -0,0 +1,15 @@
from arekit.contrib.utils.data.storages.sqlite_based import SQliteBasedRowsStorage

from arelight.readers.base import BaseReader


class SQliteReader(BaseReader):

    def __init__(self, table_name):
        self.__table_name = table_name

    def extension(self):
        return ".sqlite"

    def read(self, target):
        return SQliteBasedRowsStorage(path=target, table_name=self.__table_name)
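
The sketch below exercises SQliteReader against a throwaway database; the "contents" table name and its columns are hypothetical, chosen only to show that the caller supplies the table name while the reader binds the storage to the path.

    import os
    import sqlite3
    import tempfile

    from arelight.readers.sqlite import SQliteReader

    # Prepare a tiny SQLite database with a single (hypothetical) table.
    db_path = os.path.join(tempfile.mkdtemp(), "data.sqlite")
    with sqlite3.connect(db_path) as con:
        con.execute("CREATE TABLE contents(id INTEGER, text TEXT)")
        con.execute("INSERT INTO contents VALUES (0, 'a')")

    reader = SQliteReader(table_name="contents")
    assert reader.extension() == ".sqlite"

    # The returned storage is bound to the given path and table.
    storage = reader.read(target=db_path)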
4 changes: 2 additions & 2 deletions arelight/run/infer.py
@@ -12,8 +12,6 @@
from arekit.common.synonyms.grouping import SynonymsCollectionValuesGroupingProviders
from arekit.common.utils import split_by_whitespaces
from arekit.contrib.bert.input.providers.text_pair import PairTextProvider
-from arekit.contrib.utils.data.readers.csv_pd import PandasCsvReader
-from arekit.contrib.utils.data.readers.sqlite import SQliteReader
from arekit.contrib.utils.data.storages.row_cache import RowCacheStorage
from arekit.contrib.utils.data.writers.sqlite_native import SQliteWriter
from arekit.contrib.utils.entities.formatters.str_simple_sharp_prefixed_fmt import SharpPrefixedEntitiesSimpleFormatter
@@ -34,6 +32,8 @@
from arelight.pipelines.items.entities_ner_transformers import TransformersNERPipelineItem
from arelight.predict.writer_csv import TsvPredictWriter
from arelight.predict.writer_sqlite3 import SQLite3PredictWriter
+from arelight.readers.csv_pd import PandasCsvReader
+from arelight.readers.sqlite import SQliteReader
from arelight.run.utils import merge_dictionaries, iter_group_values, create_sentence_parser, \
create_translate_model, iter_content, OPENNRE_CHECKPOINTS, NER_TYPES
from arelight.run.utils_logger import setup_custom_logger, TqdmToLogger
2 changes: 1 addition & 1 deletion test/test_arekit_iter_data.py
@@ -4,7 +4,6 @@

from arekit.common.data.storages.base import BaseRowsStorage
from arekit.common.experiment.data_type import DataType
-from arekit.contrib.utils.data.readers.csv_pd import PandasCsvReader
from arekit.common.data.rows_fmt import create_base_column_fmt
from arekit.common.data.rows_parser import ParsedSampleRow
from arekit.common.pipeline.base import BasePipelineLauncher
@@ -15,6 +14,7 @@
from arelight.pipelines.demo.labels.scalers import CustomLabelScaler
from arelight.pipelines.demo.result import PipelineResult
from arelight.pipelines.items.backend_d3js_graphs import D3jsGraphsBackendPipelineItem
+from arelight.readers.csv_pd import PandasCsvReader


class TestAREkitIterData(unittest.TestCase):
4 changes: 2 additions & 2 deletions test/test_backend_d3js_graphs.py
@@ -5,8 +5,6 @@
import pandas as pd

from arekit.common.pipeline.base import BasePipelineLauncher
-from arekit.contrib.utils.data.readers.jsonl import JsonlReader
-from arekit.contrib.utils.data.readers.csv_pd import PandasCsvReader

from arelight.arekit.samples_io import CustomSamplesIO
from arelight.backend.d3js.relations_graph_builder import make_graph_from_relations_array
@@ -16,6 +14,8 @@
from arelight.pipelines.demo.labels.formatter import CustomLabelsFormatter
from arelight.pipelines.demo.labels.scalers import CustomLabelScaler
from arelight.pipelines.demo.result import PipelineResult
+from arelight.readers.csv_pd import PandasCsvReader
+from arelight.readers.jsonl import JsonlReader


class TestBackendD3JS(unittest.TestCase):
2 changes: 1 addition & 1 deletion test/test_pipeline_infer.py
@@ -9,7 +9,6 @@
from arekit.common.labels.base import NoLabel
from arekit.common.labels.scaler.single import SingleLabelScaler
from arekit.contrib.utils.data.writers.sqlite_native import SQliteWriter
-from arekit.contrib.utils.data.readers.jsonl import JsonlReader
from arekit.contrib.utils.entities.formatters.str_simple_sharp_prefixed_fmt import SharpPrefixedEntitiesSimpleFormatter
from arekit.common.data import const
from arekit.common.pipeline.context import PipelineContext
@@ -26,6 +25,7 @@
from arelight.pipelines.demo.labels.scalers import CustomLabelScaler
from arelight.pipelines.items.entities_ner_dp import DeepPavlovNERPipelineItem
from arelight.predict.writer_csv import TsvPredictWriter
+from arelight.readers.jsonl import JsonlReader
from arelight.samplers.bert import create_bert_sample_provider
from arelight.samplers.types import BertSampleProviderTypes
from arelight.synonyms import iter_synonym_groups
