From 287f34eddfd010836ff268f2c955f6c2665d8c10 Mon Sep 17 00:00:00 2001
From: Sebastian Walter <swalter@cs.uni-freiburg.de>
Date: Thu, 2 Nov 2023 17:13:22 +0100
Subject: [PATCH] release 0.2.0, now depending on text-utils>=0.2.1

---
 .gitmodules                                |   6 +-
 README.md                                  |   2 +-
 configs/eo_lstm_char.yaml                  |   4 +-
 configs/eo_transformer_byt5.yaml           |   4 +-
 configs/eo_transformer_byt5_scratch.yaml   |   4 +-
 configs/eo_transformer_byte_v1_like.yaml   |   4 +-
 configs/eo_transformer_char_v1_like.yaml   |   4 +-
 configs/server.yaml                        |   5 +-
 pyproject.toml                             |   4 +-
 scripts/generate_visualizations.py         |   2 +-
 src/whitespace_correction/api/cli.py       |  22 +--
 src/whitespace_correction/api/corrector.py | 190 ++++++++++-----------
 src/whitespace_correction/api/server.py    |  19 +--
 src/whitespace_correction/api/train.py     |   4 +-
 src/whitespace_correction/model.py         |  16 +-
 src/whitespace_correction/version.py       |   2 +-
 16 files changed, 144 insertions(+), 148 deletions(-)

diff --git a/.gitmodules b/.gitmodules
index f463a0d..86d988b 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,3 +1,3 @@
-[submodule "third_party/text-correction-utils"]
-	path = third_party/text-correction-utils
-	url = https://github.com/bastiscode/text-correction-utils
+[submodule "third_party/text-utils"]
+	path = third_party/text-utils
+	url = https://github.com/ad-freiburg/text-utils
diff --git a/README.md b/README.md
index a42cca6..d7c9993 100644
--- a/README.md
+++ b/README.md
@@ -80,7 +80,7 @@ echo "splitthissentenceforme" | wsc
 cat "path/to/input/file.txt" | wsc > output.txt
 
 # correct a string using
-wsc -c "splitthissentenceforme"
+wsc -p "splitthissentenceforme"
 
 # correct a text file line by line and print the corrected lines
 wsc -f path/to/input/file.txt
diff --git a/configs/eo_lstm_char.yaml b/configs/eo_lstm_char.yaml
index fc0144a..cdca382 100644
--- a/configs/eo_lstm_char.yaml
+++ b/configs/eo_lstm_char.yaml
@@ -26,8 +26,7 @@ model:
   num_classes: 3
 
 train:
-  mixed_precision: true
-  mixed_precision_dtype: env(MIXED_PRECISION_DTYPE:fp16)
+  precision: fp16
   clip_grad_norm: env(CLIP_GRAD_NORM:1.0)
   num_epochs: env(NUM_EPOCHS:3)
   eval_interval: eval(1 / env(EVAL_PER_EPOCH:10))
@@ -47,6 +46,7 @@ train:
     strategy: weighted
     shuffle: true
     sort: true
+    max_length: env(MAX_LENGTH:512)
     buffer_size: env(BATCH_LIMIT:32)
     prefetch_factor: env(PREFETCH_FACTOR:2048)
     num_threads: eval(env(THREADS:None) or len(os.sched_getaffinity(0)) // 2)
diff --git a/configs/eo_transformer_byt5.yaml b/configs/eo_transformer_byt5.yaml
index 15e5ef3..211c610 100644
--- a/configs/eo_transformer_byt5.yaml
+++ b/configs/eo_transformer_byt5.yaml
@@ -24,8 +24,7 @@ model:
   num_classes: 3
 
 train:
-  mixed_precision: env(MIXED_PRECISION:true)
-  mixed_precision_dtype: env(MIXED_PRECISION_DTYPE:bfp16)
+  precision: env(PRECISION:fp16)
   clip_grad_norm: env(CLIP_GRAD_NORM:1.0)
   num_epochs: env(NUM_EPOCHS:1)
   eval_interval: eval(1 / env(EVAL_PER_EPOCH:10))
@@ -50,6 +49,7 @@ train:
     strategy: weighted
     shuffle: true
     sort: true
+    max_length: env(MAX_LENGTH:512)
     buffer_size: env(BATCH_LIMIT:32)
     prefetch_factor: env(PREFETCH_FACTOR:2048)
     num_threads: eval(env(THREADS:None) or len(os.sched_getaffinity(0)) // 2)
diff --git a/configs/eo_transformer_byt5_scratch.yaml b/configs/eo_transformer_byt5_scratch.yaml
index 15e3ed3..61d8d74 100644
--- a/configs/eo_transformer_byt5_scratch.yaml
+++ b/configs/eo_transformer_byt5_scratch.yaml
@@ -25,8 +25,7 @@ model:
   num_classes: 3
 
 train:
-  mixed_precision: env(MIXED_PRECISION:true)
-  mixed_precision_dtype: env(MIXED_PRECISION_DTYPE:bfp16)
+  precision: env(PRECISION:fp16)
   clip_grad_norm: env(CLIP_GRAD_NORM:1.0)
   num_epochs: env(NUM_EPOCHS:1)
   eval_interval: eval(1 / env(EVAL_PER_EPOCH:10))
@@ -47,6 +46,7 @@ train:
     strategy: weighted
     shuffle: true
     sort: true
+    max_length: env(MAX_LENGTH:512)
     buffer_size: env(BATCH_LIMIT:32)
     prefetch_factor: env(PREFETCH_FACTOR:2048)
     num_threads: eval(env(THREADS:None) or len(os.sched_getaffinity(0)) // 2)
diff --git a/configs/eo_transformer_byte_v1_like.yaml b/configs/eo_transformer_byte_v1_like.yaml
index a7f52d5..34220c1 100644
--- a/configs/eo_transformer_byte_v1_like.yaml
+++ b/configs/eo_transformer_byte_v1_like.yaml
@@ -35,8 +35,7 @@ model:
   num_classes: 3
 
 train:
-  mixed_precision: env(MIXED_PRECISION:true)
-  mixed_precision_dtype: env(MIXED_PRECISION_DTYPE:fp16)
+  precision: env(PRECISION:fp16)
   clip_grad_norm: env(CLIP_GRAD_NORM:1.0)
   num_epochs: env(NUM_EPOCHS:3)
   eval_interval: eval(1 / env(EVAL_PER_EPOCH:10))
@@ -56,6 +55,7 @@ train:
     strategy: weighted
     shuffle: true
     sort: true
+    max_length: env(MAX_LENGTH:512)
     buffer_size: env(BATCH_LIMIT:32)
     prefetch_factor: env(PREFETCH_FACTOR:2048)
     num_threads: eval(env(THREADS:None) or len(os.sched_getaffinity(0)) // 2)
diff --git a/configs/eo_transformer_char_v1_like.yaml b/configs/eo_transformer_char_v1_like.yaml
index 4395c9d..646a267 100644
--- a/configs/eo_transformer_char_v1_like.yaml
+++ b/configs/eo_transformer_char_v1_like.yaml
@@ -29,8 +29,7 @@ model:
   num_classes: 3
 
 train:
-  mixed_precision: true
-  mixed_precision_dtype: env(MIXED_PRECISION_DTYPE:fp16)
+  precision: env(PRECISION:fp16)
   clip_grad_norm: env(CLIP_GRAD_NORM:1.0)
   num_epochs: env(NUM_EPOCHS:3)
   eval_interval: eval(1 / env(EVAL_PER_EPOCH:10))
@@ -50,6 +49,7 @@ train:
     strategy: weighted
     shuffle: true
     sort: true
+    max_length: env(MAX_LENGTH:512)
     buffer_size: env(BATCH_LIMIT:32)
     prefetch_factor: env(PREFETCH_FACTOR:2048)
     num_threads: eval(env(THREADS:None) or len(os.sched_getaffinity(0)) // 2)
diff --git a/configs/server.yaml b/configs/server.yaml
index 7ceff18..95d50bc 100644
--- a/configs/server.yaml
+++ b/configs/server.yaml
@@ -1,14 +1,13 @@
 port: 40000
 timeout: 10
-# precision: fp16
 # allow_origin: test.mydomain.com
 base_url: env(BASE_URL:/api)
 models:
   # load a pretrained model by specifying the name
-  # - eo_large_arxiv
+  # - name: eo_large_arxiv
   # load a model from a local experiment by specifying the
   # directory path (you can use special configuration operators,
   # e.g. env(ENV_VAR) to load env variables)
-  - env(EXPERIMENT)
+  - path: env(EXPERIMENT)
 batch_size: env(BATCH_SIZE:16)
 # batch_max_tokens: env(BATCH_MAX_TOKENS:8192)
diff --git a/pyproject.toml b/pyproject.toml
index a423da2..9b5080b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "whitespace_correction"
-version = "0.1.5"
+version = "0.2.0"
 description = "Correct missing or spurious whitespaces in text."
 authors = [
     { name = "Sebastian Walter", email = "swalter@cs.uni-freiburg.de" }
@@ -20,7 +20,7 @@ classifiers = [
 ]
 
 dependencies = [
-    "text-correction-utils==0.1.4",
+    "dtpu>=0.2.1",
     "transformers>=4.26.0"
 ]
diff --git a/scripts/generate_visualizations.py b/scripts/generate_visualizations.py
index 074902b..214199e 100644
--- a/scripts/generate_visualizations.py
+++ b/scripts/generate_visualizations.py
@@ -6,7 +6,7 @@
 
 from whitespace_correction.api import WhitespaceCorrector
 
-from text_correction_utils import hook, logging
+from text_utils import hook, logging
 
 from torch import nn
 import numpy as np
diff --git a/src/whitespace_correction/api/cli.py b/src/whitespace_correction/api/cli.py
index 87df605..1616211 100644
--- a/src/whitespace_correction/api/cli.py
+++ b/src/whitespace_correction/api/cli.py
@@ -1,27 +1,27 @@
 from io import TextIOWrapper
 from typing import Iterator, Optional, Union
 
-from text_correction_utils.api.cli import TextCorrectionCli
-from text_correction_utils import data
+from text_utils.api.cli import TextProcessingCli
+from text_utils import data
 
 from whitespace_correction import version
 from whitespace_correction.api.corrector import WhitespaceCorrector
 from whitespace_correction.api.server import WhitespaceCorrectionServer
 
 
-class WhitespaceCorrectionCli(TextCorrectionCli):
-    text_corrector_cls = WhitespaceCorrector
-    text_correction_server_cls = WhitespaceCorrectionServer
+class WhitespaceCorrectionCli(TextProcessingCli):
+    text_processor_cls = WhitespaceCorrector
+    text_processing_server_cls = WhitespaceCorrectionServer
 
     def version(self) -> str:
         return version.__version__
 
-    def correct_iter(
+    def process_iter(
         self,
-        corrector: WhitespaceCorrector,
+        processor: WhitespaceCorrector,
         iter: Iterator[data.InferenceData]
     ) -> Iterator[data.InferenceData]:
-        yield from corrector.correct_iter(
+        yield from processor.correct_iter(
             ((data.text, data.language) for data in iter),
             self.args.batch_size,
             self.args.batch_max_tokens,
@@ -31,14 +31,14 @@ def correct_iter(
             show_progress=self.args.progress
         )
 
-    def correct_file(
+    def process_file(
         self,
-        corrector: WhitespaceCorrector,
+        processor: WhitespaceCorrector,
         path: str,
         lang: Optional[str],
         out_file: Union[str, TextIOWrapper]
     ):
-        corrector.correct_file(
+        processor.correct_file(
             path,
             self.args.input_format,
             out_file,
diff --git a/src/whitespace_correction/api/corrector.py b/src/whitespace_correction/api/corrector.py
index a97a059..ce4621f 100644
--- a/src/whitespace_correction/api/corrector.py
+++ b/src/whitespace_correction/api/corrector.py
@@ -9,20 +9,19 @@
 
 from whitespace_correction.model import model_from_config, EncoderDecoderWithHead
 
-from text_correction_utils import data, whitespace, tokenization
-from text_correction_utils.api.corrector import ModelInfo
-from text_correction_utils.api import corrector
-from text_correction_utils.api.utils import device_info, to
-from text_correction_utils.inference import IdxSelectFn, eos_stop_fn, search
+from text_utils import data, whitespace, tokenization
+from text_utils.api.processor import ModelInfo, TextProcessor
+from text_utils.api.utils import device_info, to, Device
+from text_utils.inference import IdxSelectFn, eos_stop_fn, search
 
 _BASE_URL = "https://ad-publications.informatik.uni-freiburg.de/" \
     "ACL_whitespace_correction_transformer_BHW_2023.materials"
 _NAME_TO_ZIP = {
-    "eo_large_char_v1": "eo_large_char_v1.zip",
+    "eo_large_char_v1": "eo_large_v1.zip",
     "eo_large_char": "eo_large_char_v2.zip",
     "eo_large_byte": "eo_large_byte_v2.zip",
     "eo_larger_byte": "eo_huge_byte_v2.zip",
-    "eo_medium_char_v1": "eo_medium_char_v1.zip",
+    "eo_medium_char_v1": "eo_medium_v1.zip",
     "eo_medium_char": "eo_medium_char_v2.zip",
     "eo_medium_byte": "eo_medium_byte_v2.zip",
     "ed_large_char": "ed_large_v1.zip",
@@ -30,7 +29,7 @@
 }
 
 
-class WhitespaceCorrector(corrector.TextCorrector):
+class WhitespaceCorrector(TextProcessor):
     task = "whitespace correction"
 
     @classmethod
@@ -94,7 +93,11 @@ def name(self) -> str:
         return self.cfg["experiment"]["name"]
 
     @classmethod
-    def _model_from_config(cls, cfg: Dict[str, Any]) -> nn.Module:
+    def _model_from_config(
+        cls,
+        cfg: Dict[str, Any],
+        device: Device
+    ) -> nn.Module:
         input_tokenizer = tokenization.Tokenizer.from_config(cfg["input_tokenizer"])
         if "output_tokenizer" in cfg:
             output_tokenizer = tokenization.Tokenizer.from_config(cfg["output_tokenizer"])
@@ -108,14 +111,7 @@ def _model_from_config(cls, cfg: Dict[str, Any]) -> nn.Module:
 
     @property
     def max_length(self) -> int:
-        if self.cfg["model"]["type"] == "pretrained_encoder_with_head":
-            return 512
-        elif self.cfg["model"]["type"] == "encoder_with_head":
-            return self.cfg["model"]["embedding"].get("max_length", 512)
-        elif self.cfg["model"]["type"] == "encoder_decoder_with_head":
-            return self.cfg["model"]["encoder_embedding"].get("max_length", 512)
-        else:
-            raise ValueError(f"unknown model type: {self.cfg['model']['type']}")
+        return self.cfg["train"]["data"].get("max_length", 512)
 
     @property
     def context_length(self) -> int:
@@ -131,10 +127,14 @@ def supported_languages(self) -> Optional[List[str]]:
 
     def __init__(
         self,
-        model_dir: str,
-        device: Union[str, int]
+        model: nn.Module,
+        cfg: Dict[str, Any],
+        device: Device = "cuda"
     ) -> None:
-        super().__init__(model_dir, device)
+        super().__init__(model, cfg, device)
+        assert len(self.devices) == 1, \
+            "whitespace correction is only supported on single devices for now"
+        self.device = self.devices[0]
         self.logger.debug(f"loaded model config:\n{self.cfg['model']}")
         self.logger.info(f"running {self.name} whitespace corrector on device {device_info(self.device)}")
         self.input_tokenizer = tokenization.Tokenizer.from_config(self.cfg["input_tokenizer"])
@@ -147,9 +147,6 @@ def __init__(
         self._pfx = self.input_tokenizer.num_prefix_tokens()
         self._sfx = self.input_tokenizer.num_suffix_tokens()
 
-        precision = self.cfg["train"].get("mixed_precision_dtype", "fp32")
-        self.set_precision(precision)
-
     def _build_inference_loader_config(self) -> Dict[str, Any]:
         input_tokenizer = tokenization.Tokenizer.from_config(self.cfg["input_tokenizer"])
         pfx = input_tokenizer.num_prefix_tokens()
@@ -173,7 +170,7 @@
         }
 
     def _prepare_batch(self, batch: data.InferenceBatch) -> Dict[str, Any]:
-        token_ids_np, pad_mask_np, lengths, info = batch.tensors
+        token_ids_np, pad_mask_np, lengths, info = batch.tensors()
        inputs = {
             "token_ids": torch.from_numpy(token_ids_np).to(non_blocking=True, device=self.device),
             "padding_mask": torch.from_numpy(pad_mask_np).to(non_blocking=True, device=self.device),
@@ -195,17 +192,17 @@ def _inference(self, inputs: Dict[str, Any]) -> Any:
         def _decode_fn(
             token_ids: torch.Tensor,
             **kwargs: Any
-        ) -> torch.Tensor:
+        ) -> Tuple[torch.Tensor, Dict[str, Any]]:
             decoded = self.model.decode(
                 token_ids,
                 **kwargs
             )
-            return decoded
+            return decoded, {}
 
-        def _kwargs_sub_select_fn(kwargs: Dict[str, Any], mask: torch.Tensor) -> Dict[str, Any]:
+        def _kwargs_select_fn(kwargs: Dict[str, Any], mask: torch.Tensor) -> Dict[str, Any]:
             return {
                 "memories": {"encoder": kwargs["memory"][mask]},
-                "memory_padding_masks": {"encoder": kwargs["padding_mask"][mask]}
+                "memory_padding_masks": {"encoder": kwargs["memory_padding_mask"][mask]}
             }
 
         max_output_length = self.cfg["model"]["decoder_embedding"].get("max_length", 2 * self.max_length)
@@ -214,7 +211,10 @@ def _kwargs_sub_select_fn(kwargs: Dict[str, Any], mask: torch.Tensor) -> Dict[st
         # the whitespace token or copy the corresponding token from
         # the input
         assert self.output_tokenizer is not None
-        eos_token_id = self.output_tokenizer.eos_token_id()
+        eos_token = "<eos>"
+        eos_token_id = self.output_tokenizer.special_token_to_id(eos_token)
+        bos_token = "<bos>"
+        bos_token_id = self.output_tokenizer.special_token_to_id(bos_token)
         ws_token_id = self.output_tokenizer.tokenize(" ").token_ids[self._pfx]
         lengths = inputs["lengths"]
         non_ws_token_ids = [
@@ -228,36 +228,44 @@ def _kwargs_sub_select_fn(kwargs: Dict[str, Any], mask: torch.Tensor) -> Dict[st
         # either copy the correct input char or add a whitespace,
         # also makes sure that no doubled whitespaces are returned
         def _custom_select_fn() -> IdxSelectFn:
-            def _select(scores: torch.Tensor, idx: int) -> Tuple[int, float]:
-                token_id_idx = token_id_indices[idx]
-                if token_id_idx >= len(non_ws_token_ids[idx]):
-                    # we are at the end of the input, select eos
-                    return eos_token_id, 0
-
-                input_token_id = non_ws_token_ids[idx][token_id_idx]
-                ws_score = scores[ws_token_id]
-                input_token_score = scores[input_token_id]
-                if ws_score > input_token_score and not last_was_ws[idx]:
-                    last_was_ws[idx] = True
-                    return ws_token_id, float(ws_score)
-                else:
-                    token_id_indices[idx] += 1
-                    last_was_ws[idx] = False
-                    return input_token_id, float(input_token_score)
+            def _select(scores: torch.Tensor, indices: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
+                token_ids = []
+                log_probs = []
+                for dist, idx in zip(scores, indices):
+                    token_id_idx = token_id_indices[idx]
+                    if token_id_idx >= len(non_ws_token_ids[idx]):
+                        # we are at the end of the input, select eos
+                        token_ids.append(eos_token_id)
+                        log_probs.append(0.0)
+                        continue
+
+                    input_token_id = non_ws_token_ids[idx][token_id_idx]
+                    ws_score = dist[ws_token_id]
+                    input_token_score = dist[input_token_id]
+                    if ws_score > input_token_score and not last_was_ws[idx]:
+                        last_was_ws[idx] = True
+                        token_ids.append(ws_token_id)
+                        log_probs.append(float(ws_score))
+                    else:
+                        token_id_indices[idx] += 1
+                        last_was_ws[idx] = False
+                        token_ids.append(input_token_id)
+                        log_probs.append(float(input_token_score))
+                return torch.tensor(token_ids, device=scores.device), torch.tensor(log_probs, device=scores.device)
 
             return _select
 
         output = search(
             decode_fn=_decode_fn,
-            initial_token_ids=[[self.output_tokenizer.bos_token_id()]] * encoded.shape[0],
+            initial_token_ids=[[bos_token_id]] * encoded.shape[0],
             pad_token_id=self.output_tokenizer.pad_token_id(),
             max_length=max_output_length,
             select_fn=_custom_select_fn(),
-            stop_fn=eos_stop_fn(self.output_tokenizer.eos_token_id()),
+            stop_fn=eos_stop_fn(eos_token_id),
             device=self.device,
-            kwargs_sub_select_fn=_kwargs_sub_select_fn,
+            kwargs_select_fn=_kwargs_select_fn,
             memory=encoded,
-            **kwargs
+            memory_padding_mask=kwargs["padding_mask"],
         )
 
         return output
@@ -276,13 +284,16 @@ def _process_results(
             prediction = torch.argmax(output[self._pfx + window_start:self._pfx + window_end], dim=-1)
             merged_predictions.extend(prediction.tolist())
             repaired = whitespace.repair(items[0].data.text, merged_predictions)
-            return data.InferenceData(repaired, language=items[0].data.language)
+            return data.InferenceData(
+                repaired.strip(),
+                language=items[0].data.language
+            )
 
         # only thing left to do here is swap back the unknown tokens
         # with the original ones
         assert self.output_tokenizer is not None
-        unk_token_id = self.output_tokenizer.unk_token_id()
-        out_pfx = self.output_tokenizer.num_prefix_tokens()
+        unk_token = "<unk>"
+        unk_token_id = self.output_tokenizer.special_token_to_id(unk_token)
         out_sfx = self.output_tokenizer.num_suffix_tokens()
         merged = ""
         for item, output in zip(items, outputs):
@@ -298,7 +309,10 @@ def _process_results(
                 if tok_id == unk_token_id
             ]
             input_unk_chars = [input_chars[i] for i in input_unk_indices]
-            output_token_ids = output[out_pfx:-out_sfx]
+            # we only need to exclude the suffix here because by default the search
+            # function returns only the newly decoded tokens, which means the prefix is
+            # already excluded
+            output_token_ids = output[:-out_sfx]
             output_unk_indices = [
                 i
                 for i, tok_id in enumerate(output_token_ids)
@@ -308,7 +322,9 @@ def _process_results(
             output_str = ""
             start_idx = 0
             for output_unk_idx, input_unk_char in zip(output_unk_indices, input_unk_chars):
-                output_str += self.output_tokenizer.de_tokenize(output_token_ids[start_idx:output_unk_idx])
+                output_str += self.output_tokenizer.de_tokenize(
+                    output_token_ids[start_idx:output_unk_idx]
+                )
                 start_idx = output_unk_idx + 1
                 output_str += input_unk_char
             output_str += self.output_tokenizer.de_tokenize(output_token_ids[start_idx:])
@@ -320,14 +336,14 @@ def _process_results(
         return data.InferenceData(merged.rstrip(), language=items[0].data.language)
 
     def correct_text(
-            self,
-            inputs: Union[str, List[str]],
-            languages: Optional[List[str]] = None,
-            batch_size: int = 16,
-            batch_max_tokens: Optional[int] = None,
-            sort: bool = True,
-            num_threads: Optional[int] = None,
-            show_progress: bool = False
+        self,
+        inputs: Union[str, List[str]],
+        languages: Optional[List[str]] = None,
+        batch_size: int = 16,
+        batch_max_tokens: Optional[int] = None,
+        sort: bool = True,
+        num_threads: Optional[int] = None,
+        show_progress: bool = False
     ) -> Union[str, List[str]]:
         input_is_string = isinstance(inputs, str)
         assert (
@@ -365,7 +381,7 @@ def correct_text(
         progress_unit = "seq"
 
         if sort:
-            outputs = self._correct_sorted(
+            outputs = self._process_sorted(
                 loader,
                 progress_desc,
                 progress_total,
@@ -373,7 +389,7 @@ def correct_text(
                 show_progress
             )
         else:
-            outputs = self._correct_unsorted(
+            outputs = self._process_unsorted(
                 loader,
                 progress_desc,
                 progress_total,
@@ -406,7 +422,7 @@ def correct_iter(
         progress_unit = "byte"
 
         if sort:
-            output = self._correct_sorted(
+            output = self._process_sorted(
                 loader,
                 progress_desc,
                 progress_total,
@@ -414,7 +430,7 @@ def correct_iter(
                 show_progress
             )
         else:
-            output = self._correct_unsorted(
+            output = self._process_unsorted(
                 loader,
                 progress_desc,
                 progress_total,
@@ -428,17 +444,17 @@ def correct_iter(
         yield from (data.text for data in output)
 
     def correct_file(
-            self,
-            input_file: str,
-            input_file_format: str = "text",
-            output_file: Optional[Union[TextIOWrapper, str]] = None,
-            output_file_format: str = "text",
-            language: Optional[str] = None,
-            batch_size: int = 16,
-            batch_max_tokens: Optional[int] = None,
-            sort: bool = True,
-            num_threads: Optional[int] = None,
-            show_progress: bool = False
+        self,
+        input_file: str,
+        input_file_format: str = "text",
+        output_file: Optional[Union[TextIOWrapper, str]] = None,
+        output_file_format: str = "text",
+        language: Optional[str] = None,
+        batch_size: int = 16,
+        batch_max_tokens: Optional[int] = None,
+        sort: bool = True,
+        num_threads: Optional[int] = None,
+        show_progress: bool = False
     ) -> Optional[Iterator[str]]:
         assert input_file_format in self.supported_input_formats(), f"unsupported input file format {input_file_format}, \
             must be one of {self.supported_input_formats()}"
@@ -459,7 +475,7 @@ def correct_file(
         progress_unit = "byte"
 
         if sort:
-            outputs = iter(self._correct_sorted(
+            outputs = iter(self._process_sorted(
                 loader,
                 progress_desc,
                 progress_total,
@@ -467,7 +483,7 @@ def correct_file(
                 show_progress
             ))
         else:
-            outputs = self._correct_unsorted(
+            outputs = self._process_unsorted(
                 loader,
                 progress_desc,
                 progress_total,
@@ -491,19 +507,3 @@ def correct_file(
             else:
                 return (output.text for output in outputs)
-
-    def set_precision(self, precision: str) -> None:
-        if self.device.type == "cpu" and self._encoder_only and precision != "fp32":
-            self.logger.warning(
-                f"got {precision} precision, but "
-                "encoder-only models only support fp32 precision on CPU"
-            )
-            precision = "fp32"
-        else:
-            training_precision = self.cfg["train"].get("mixed_precision_dtype", "fp32")
-            if precision != "fp32" and precision != training_precision:
-                self.logger.warning(
-                    f"this model was trained with {training_precision} precision, "
-                    f"inference with {precision} might give unexpected results"
-                )
-        return super().set_precision(precision)
diff --git a/src/whitespace_correction/api/server.py b/src/whitespace_correction/api/server.py
index ecd4d8e..e224093 100644
--- a/src/whitespace_correction/api/server.py
+++ b/src/whitespace_correction/api/server.py
@@ -3,27 +3,24 @@
 
 from flask import Response, jsonify, request, abort
 
-from text_correction_utils.api.server import TextCorrectionServer, Error
-from text_correction_utils.api.utils import ProgressIterator
-from text_correction_utils import metrics
+from text_utils.api.server import TextProcessingServer, Error
+from text_utils.api.utils import ProgressIterator
+from text_utils import metrics
 
 from whitespace_correction.api.corrector import WhitespaceCorrector
 
 
-class WhitespaceCorrectionServer(TextCorrectionServer):
-    text_corrector_cls = WhitespaceCorrector
+class WhitespaceCorrectionServer(TextProcessingServer):
+    text_processor_cls = WhitespaceCorrector
 
     def __init__(self, config: Dict[str, Any]):
         super().__init__(config)
-        self.batch_size = int(self.config.get("batch_size", 16))
+        self.batch_size = self.config.get("batch_size", 1)
         if "batch_max_tokens" in self.config:
-            self.batch_max_tokens = int(self.config["batch_max_tokens"])
+            self.batch_max_tokens = self.config["batch_max_tokens"]
         else:
             self.batch_max_tokens = None
 
-        for cor, _ in self.text_correctors.values():
-            cor.set_precision(self.precision)
-
         @self.server.route(f"{self.base_url}/correct", methods=["POST"])
         def _correct() -> Response:
             json = request.get_json()
@@ -35,7 +32,7 @@ def _correct() -> Response:
                 return abort(Response("missing text in json", status=400))
 
             try:
-                with self.text_corrector(json["model"]) as cor:
+                with self.text_processor(json["model"]) as cor:
                     if isinstance(cor, Error):
                         return abort(cor.to_response())
                     assert isinstance(cor, WhitespaceCorrector)
diff --git a/src/whitespace_correction/api/train.py b/src/whitespace_correction/api/train.py
index c90857f..b2bb459 100644
--- a/src/whitespace_correction/api/train.py
+++ b/src/whitespace_correction/api/train.py
@@ -3,8 +3,8 @@
 
 from torch import nn
 
-from text_correction_utils.api.trainer import Trainer
-from text_correction_utils import tokenization
+from text_utils.api.trainer import Trainer
+from text_utils import tokenization
 
 from whitespace_correction.model import model_from_config
diff --git a/src/whitespace_correction/model.py b/src/whitespace_correction/model.py
index 939fb69..92fe589 100644
--- a/src/whitespace_correction/model.py
+++ b/src/whitespace_correction/model.py
@@ -4,11 +4,11 @@
 import torch
 from torch import nn
 
-from text_correction_utils import tokenization
-from text_correction_utils.modules.embedding import Embedding, embedding_from_config
-from text_correction_utils.modules.encoder import Encoder, encoder_from_config
-from text_correction_utils.modules.decoder import Decoder, decoder_from_config
-from text_correction_utils.modules.head import Head, head_from_config
+from text_utils import tokenization
+from text_utils.modules.embedding import Embedding, embedding_from_config
+from text_utils.modules.encoder import Encoder, encoder_from_config
+from text_utils.modules.decoder import Decoder, decoder_from_config
+from text_utils.modules.head import Head, head_from_config
 
 from transformers import T5EncoderModel, T5Config
 
@@ -59,7 +59,7 @@ def forward(
         token_ids: torch.Tensor,
         **kwargs: Any
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
-        emb, pos_emb = self.embedding(token_ids, **kwargs)
+        emb, pos_emb, kwargs = self.embedding(token_ids, **kwargs)
         enc, kwargs = self.encoder(emb, pos=pos_emb, **kwargs)
         output = self.head(enc, **kwargs)
         return output, self.encoder.additional_losses()
@@ -106,7 +106,7 @@ def encode(
         token_ids: torch.Tensor,
         **kwargs: Any
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
-        emb, pos_emb = self.encoder_embedding(token_ids, **kwargs)
+        emb, pos_emb, kwargs = self.encoder_embedding(token_ids, **kwargs)
         enc, kwargs = self.encoder(emb, pos_emb, **kwargs)
         return enc, kwargs
 
@@ -115,7 +115,7 @@ def decode(
         token_ids: torch.Tensor,
         **kwargs: Any
     ) -> torch.Tensor:
-        emb, pos_emb = self.decoder_embedding(token_ids, **kwargs)
+        emb, pos_emb, kwargs = self.decoder_embedding(token_ids, **kwargs)
         dec, kwargs = self.decoder(
             emb,
             pos_emb,
diff --git a/src/whitespace_correction/version.py b/src/whitespace_correction/version.py
index 1276d02..d3ec452 100644
--- a/src/whitespace_correction/version.py
+++ b/src/whitespace_correction/version.py
@@ -1 +1 @@
-__version__ = "0.1.5"
+__version__ = "0.2.0"
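
Migration note for downstream users: the switch from text-correction-utils to text-utils renames the base classes and class attributes used by the CLI and server subclasses above. A minimal sketch of the rename, using only names that appear in this diff (the subclass MyCli is hypothetical):

    # before (0.1.x, depending on text-correction-utils)
    from text_correction_utils.api.cli import TextCorrectionCli

    class MyCli(TextCorrectionCli):
        text_corrector_cls = ...           # old attribute names
        text_correction_server_cls = ...

    # after (0.2.0, depending on text-utils>=0.2.1)
    from text_utils.api.cli import TextProcessingCli

    class MyCli(TextProcessingCli):
        text_processor_cls = ...           # new attribute names
        text_processing_server_cls = ...

Subclasses also rename their hook methods: correct_iter/correct_file become process_iter/process_file, and the private helpers _correct_sorted/_correct_unsorted become _process_sorted/_process_unsorted, as shown in the cli.py and corrector.py hunks.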
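The end-user correction API is unchanged by this release: WhitespaceCorrector keeps correct_text, correct_iter and correct_file with the signatures shown in the corrector.py hunks. A minimal usage sketch; note that from_pretrained is assumed to be inherited from the text_utils TextProcessor base class and is not part of this diff, and "eo_large_char" is taken from the _NAME_TO_ZIP table above:

    from whitespace_correction.api import WhitespaceCorrector

    # assumption: from_pretrained comes from the text_utils TextProcessor
    # base class; "eo_large_char" is one of the keys in _NAME_TO_ZIP
    cor = WhitespaceCorrector.from_pretrained("eo_large_char")

    # correct_text accepts a single string (and returns a string) or a
    # list of strings (and returns a list); see its signature above
    corrected = cor.correct_text("splitthissentenceforme")
    print(corrected)  # expected: "split this sentence for me"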