Skip to content

Add support for multiple postprocessing requests #759

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 76 additions & 32 deletions silnlp/common/postprocess_draft.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def get_paths_from_exp(config: Config) -> Tuple[Path, Path]:
translate_config = yaml.safe_load(file)["translate"][0]
src_project = translate_config.get("src_project", next(iter(config.src_projects)))
books = translate_config["books"]
book = books[0] if isinstance(books, list) else books.split(";")[0] # TODO: handle partial book translation
book = books[0][:3] if isinstance(books, list) else books.split(";")[0][:3]
book_num = book_id_to_number(book)

ckpt = translate_config.get("checkpoint", "last")
Expand Down Expand Up @@ -89,9 +89,7 @@ def get_sentences(


def main() -> None:
parser = argparse.ArgumentParser(
description="Applies draft postprocessing steps to a draft. Can be used with no postprocessing options to create a base draft."
)
parser = argparse.ArgumentParser(description="Applies draft postprocessing steps to a draft.")
parser.add_argument(
"--experiment",
default=None,
Expand Down Expand Up @@ -173,6 +171,27 @@ def main() -> None:
src_path = Path(args.source.replace("\\", "/"))
draft_path = Path(args.draft.replace("\\", "/"))

# If no postprocessing options are used, use any postprocessing requests in the experiment's translate config
if args.include_paragraph_markers or args.include_style_markers or args.include_embeds:
postprocess_configs = [
{
"include_paragraph_markers": args.include_paragraph_markers,
"include_style_markers": args.include_style_markers,
"include_embeds": args.include_embeds,
}
]
else:
if args.experiment:
LOGGER.info("No postprocessing options used. Applying postprocessing requests from translate config.")
with (config.exp_dir / "translate_config.yml").open("r", encoding="utf-8") as file:
postprocess_configs = yaml.safe_load(file).get("postprocess", [])
if len(postprocess_configs) == 0:
LOGGER.info("No postprocessing requests found.")
exit()
else:
LOGGER.info("Please use at least one postprocessing option.")
exit()

if str(src_path).startswith(str(get_project_dir(""))):
settings = FileParatextProjectSettingsParser(src_path.parent).parse()
stylesheet = settings.stylesheet
Expand All @@ -198,36 +217,61 @@ def main() -> None:
f"'source' and 'draft' must have the exact same USFM structure. Mismatched ref: {src_ref} {draft_ref}"
)

paragraph_behavior = (
UpdateUsfmMarkerBehavior.PRESERVE if args.include_paragraph_markers else UpdateUsfmMarkerBehavior.STRIP
)
style_behavior = UpdateUsfmMarkerBehavior.PRESERVE if args.include_style_markers else UpdateUsfmMarkerBehavior.STRIP
embed_behavior = UpdateUsfmMarkerBehavior.PRESERVE if args.include_embeds else UpdateUsfmMarkerBehavior.STRIP

update_block_handlers = []
if args.include_paragraph_markers or args.include_style_markers:
update_block_handlers.append(construct_place_markers_handler(src_refs, src_sents, draft_sents))

with src_path.open(encoding=encoding) as f:
usfm = f.read()
handler = UpdateUsfmParserHandler(
rows=[([ref], sent) for ref, sent in zip(src_refs, draft_sents)],
id_text=book,
text_behavior=UpdateUsfmTextBehavior.STRIP_EXISTING,
paragraph_behavior=paragraph_behavior,
embed_behavior=embed_behavior,
style_behavior=style_behavior,
update_block_handlers=update_block_handlers,
)
parse_usfm(usfm, handler)
usfm_out = handler.get_usfm()
if any(
ppc.get("include_paragraph_markers", False) or ppc.get("include_style_markers", False)
for ppc in postprocess_configs
):
place_markers_handler = construct_place_markers_handler(src_refs, src_sents, draft_sents)

for postprocess_config in postprocess_configs:
update_block_handlers = []
if postprocess_config.get("include_paragraph_markers", False) or postprocess_config.get(
"include_style_markers", False
):
update_block_handlers.append(place_markers_handler)

paragraph_behavior = (
UpdateUsfmMarkerBehavior.PRESERVE
if postprocess_config.get("include_paragraph_markers", False)
else UpdateUsfmMarkerBehavior.STRIP
)
style_behavior = (
UpdateUsfmMarkerBehavior.PRESERVE
if postprocess_config.get("include_style_markers", False)
else UpdateUsfmMarkerBehavior.STRIP
)
embed_behavior = (
UpdateUsfmMarkerBehavior.PRESERVE
if postprocess_config.get("include_embeds", False)
else UpdateUsfmMarkerBehavior.STRIP
)
marker_placement_suffix = (
"_"
+ ("p" if postprocess_config.get("include_paragraph_markers", False) else "")
+ ("s" if postprocess_config.get("include_style_markers", False) else "")
+ ("e" if postprocess_config.get("include_embeds", False) else "")
)

with src_path.open(encoding=encoding) as f:
usfm = f.read()
handler = UpdateUsfmParserHandler(
rows=[([ref], sent) for ref, sent in zip(src_refs, draft_sents)],
id_text=book,
text_behavior=UpdateUsfmTextBehavior.STRIP_EXISTING,
paragraph_behavior=paragraph_behavior,
embed_behavior=embed_behavior,
style_behavior=style_behavior,
update_block_handlers=update_block_handlers,
)
parse_usfm(usfm, handler)
usfm_out = handler.get_usfm()

usfm_out = insert_draft_remarks(usfm_out, draft_remarks)
usfm_out = insert_draft_remarks(usfm_out, draft_remarks)

out_dir = Path(args.output_folder.replace("\\", "/")) if args.output_folder else draft_path.parent
out_path = out_dir / f"{draft_path.stem}_postprocessed{draft_path.suffix}"
with out_path.open("w", encoding="utf-8" if encoding == "utf-8-sig" else encoding) as f:
f.write(usfm_out)
out_dir = Path(args.output_folder.replace("\\", "/")) if args.output_folder else draft_path.parent
out_path = out_dir / f"{draft_path.stem}{marker_placement_suffix}{draft_path.suffix}"
with out_path.open("w", encoding="utf-8" if encoding == "utf-8-sig" else encoding) as f:
f.write(usfm_out)


if __name__ == "__main__":
Expand Down
153 changes: 93 additions & 60 deletions silnlp/common/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from itertools import groupby
from math import exp
from pathlib import Path
from typing import Iterable, List, Optional
from typing import Dict, Iterable, List, Optional

import docx
import nltk
Expand Down Expand Up @@ -137,9 +137,7 @@ def translate_book(
produce_multiple_translations: bool = False,
chapters: List[int] = [],
trg_project: Optional[str] = None,
include_paragraph_markers: bool = False,
include_style_markers: bool = False,
include_embeds: bool = False,
postprocess_configs: List[Dict[str, bool]] = [],
experiment_ckpt_str: str = "",
) -> None:
book_path = get_book_path(src_project, book)
Expand All @@ -156,9 +154,7 @@ def translate_book(
produce_multiple_translations,
chapters,
trg_project,
include_paragraph_markers,
include_style_markers,
include_embeds,
postprocess_configs,
experiment_ckpt_str,
)

Expand All @@ -171,9 +167,7 @@ def translate_usfm(
produce_multiple_translations: bool = False,
chapters: List[int] = [],
trg_project: Optional[str] = None,
include_paragraph_markers: bool = False,
include_style_markers: bool = False,
include_embeds: bool = False,
postprocess_configs: List[Dict[str, bool]] = [],
experiment_ckpt_str: str = "",
) -> None:
# Create UsfmFileText object for source
Expand Down Expand Up @@ -226,72 +220,111 @@ def translate_usfm(
vrefs.insert(idx, vref)
output.insert(idx, [None, None, None, None])

# Update behaviors
text_behavior = (
UpdateUsfmTextBehavior.PREFER_NEW if trg_project is not None else UpdateUsfmTextBehavior.STRIP_EXISTING
)
paragraph_behavior = (
UpdateUsfmMarkerBehavior.PRESERVE if include_paragraph_markers else UpdateUsfmMarkerBehavior.STRIP
)
style_behavior = UpdateUsfmMarkerBehavior.PRESERVE if include_style_markers else UpdateUsfmMarkerBehavior.STRIP
embed_behavior = UpdateUsfmMarkerBehavior.PRESERVE if include_embeds else UpdateUsfmMarkerBehavior.STRIP
# Base draft
postprocess_configs = [
{"include_paragraph_markers": False, "include_style_markers": False, "include_embeds": False}
] + postprocess_configs

draft_set: DraftGroup = DraftGroup(translations)
for draft_index, translated_draft in enumerate(draft_set.get_drafts(), 1):
rows = [([ref], translation) for ref, translation in zip(vrefs, translated_draft)]

update_block_handlers = []
if include_paragraph_markers or include_style_markers:
update_block_handlers.append(construct_place_markers_handler(vrefs, sentences, translated_draft))

# Insert translation into the USFM structure of an existing project
# If the target project is not the same as the translated file's original project,
# no verses outside of the ones translated will be overwritten
if trg_project is not None or src_from_project:
dest_updater = FileParatextProjectTextUpdater(
get_project_dir(trg_project if trg_project is not None else src_file_path.parent.name)
if any(
ppc.get("include_paragraph_markers", False) or ppc.get("include_style_markers", False)
for ppc in postprocess_configs
):
place_markers_handler = construct_place_markers_handler(vrefs, sentences, translated_draft)

for postprocess_config in postprocess_configs:
# Update behaviors
text_behavior = (
UpdateUsfmTextBehavior.PREFER_NEW
if trg_project is not None
else UpdateUsfmTextBehavior.STRIP_EXISTING
)
usfm_out = dest_updater.update_usfm(
book_id=src_file_text.id,
rows=rows,
text_behavior=text_behavior,
paragraph_behavior=paragraph_behavior,
embed_behavior=embed_behavior,
style_behavior=style_behavior,
update_block_handlers=update_block_handlers,
paragraph_behavior = (
UpdateUsfmMarkerBehavior.PRESERVE
if postprocess_config.get("include_paragraph_markers", False)
else UpdateUsfmMarkerBehavior.STRIP
)

if usfm_out is None:
raise FileNotFoundError(f"Book {src_file_text.id} does not exist in target project {trg_project}")
else: # Slightly more manual version for updating an individual file
with open(src_file_path, encoding="utf-8-sig") as f:
usfm = f.read()
handler = UpdateUsfmParserHandler(
rows=rows,
id_text=vrefs[0].book,
text_behavior=text_behavior,
paragraph_behavior=paragraph_behavior,
embed_behavior=embed_behavior,
style_behavior=style_behavior,
update_block_handlers=update_block_handlers,
style_behavior = (
UpdateUsfmMarkerBehavior.PRESERVE
if postprocess_config.get("include_style_markers", False)
else UpdateUsfmMarkerBehavior.STRIP
)
embed_behavior = (
UpdateUsfmMarkerBehavior.PRESERVE
if postprocess_config.get("include_embeds", False)
else UpdateUsfmMarkerBehavior.STRIP
)
marker_placement_suffix = (
"_"
+ ("p" if postprocess_config.get("include_paragraph_markers", False) else "")
+ ("s" if postprocess_config.get("include_style_markers", False) else "")
+ ("e" if postprocess_config.get("include_embeds", False) else "")
)
parse_usfm(usfm, handler)
usfm_out = handler.get_usfm()
marker_placement_suffix = "" if len(marker_placement_suffix) == 1 else marker_placement_suffix

update_block_handlers = []
if postprocess_config.get("include_paragraph_markers", False) or postprocess_config.get(
"include_style_markers", False
):
update_block_handlers.append(place_markers_handler)

# Insert translation into the USFM structure of an existing project
# If the target project is not the same as the translated file's original project,
# no verses outside of the ones translated will be overwritten
if trg_project is not None or src_from_project:
dest_updater = FileParatextProjectTextUpdater(
get_project_dir(trg_project if trg_project is not None else src_file_path.parent.name)
)
usfm_out = dest_updater.update_usfm(
book_id=src_file_text.id,
rows=rows,
text_behavior=text_behavior,
paragraph_behavior=paragraph_behavior,
embed_behavior=embed_behavior,
style_behavior=style_behavior,
update_block_handlers=update_block_handlers,
)

if usfm_out is None:
raise FileNotFoundError(
f"Book {src_file_text.id} does not exist in target project {trg_project}"
)
else: # Slightly more manual version for updating an individual file
with open(src_file_path, encoding="utf-8-sig") as f:
usfm = f.read()
handler = UpdateUsfmParserHandler(
rows=rows,
id_text=vrefs[0].book,
text_behavior=text_behavior,
paragraph_behavior=paragraph_behavior,
embed_behavior=embed_behavior,
style_behavior=style_behavior,
update_block_handlers=update_block_handlers,
)
parse_usfm(usfm, handler)
usfm_out = handler.get_usfm()

# Insert draft remark and write to output path
description = f"project {src_file_text.project}" if src_from_project else f"file {src_file_path.name}"
usfm_out = insert_draft_remark(usfm_out, vrefs[0].book, description, experiment_ckpt_str)
trg_draft_file_path = trg_file_path.with_stem(trg_file_path.stem + marker_placement_suffix)
if produce_multiple_translations:
trg_draft_file_path = trg_draft_file_path.with_suffix(f".{draft_index}{trg_file_path.suffix}")
with trg_draft_file_path.open(
"w", encoding=src_settings.encoding if src_from_project else "utf-8"
) as f:
f.write(usfm_out)

# Insert draft remark and write to output path
description = f"project {src_file_text.project}" if src_from_project else f"file {src_file_path.name}"
usfm_out = insert_draft_remark(usfm_out, vrefs[0].book, description, experiment_ckpt_str)
confidence_scores_suffix = ".confidences.tsv"
if produce_multiple_translations:
trg_draft_file_path = trg_file_path.with_suffix(f".{draft_index}{trg_file_path.suffix}")
confidences_path = trg_file_path.with_suffix(
f".{draft_index}{trg_file_path.suffix}{confidence_scores_suffix}"
)
else:
trg_draft_file_path = trg_file_path
confidences_path = trg_file_path.with_suffix(f"{trg_file_path.suffix}{confidence_scores_suffix}")
with trg_draft_file_path.open("w", encoding=src_settings.encoding if src_from_project else "utf-8") as f:
f.write(usfm_out)
with confidences_path.open("w", encoding="utf-8", newline="\n") as confidences_file:
confidences_file.write("\t".join(["VRef"] + [f"Token {i}" for i in range(200)]) + "\n")
confidences_file.write("\t".join(["Sequence Score"] + [f"Token Score {i}" for i in range(200)]) + "\n")
Expand Down
8 changes: 2 additions & 6 deletions silnlp/nmt/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,7 @@ def translate(self):
config.get("trg_project"),
config.get("trg_iso"),
self.produce_multiple_translations,
config.get("include_paragraph_markers", False) or config.get("preserve_usfm_markers", False),
config.get("include_style_markers", False) or config.get("preserve_usfm_markers", False),
config.get("include_embeds", False) or config.get("include_inline_elements", False),
translate_configs.get("postprocess", []),
)
elif config.get("src_prefix"):
translator.translate_text_files(
Expand All @@ -116,9 +114,7 @@ def translate(self):
config.get("src_iso"),
config.get("trg_iso"),
self.produce_multiple_translations,
config.get("include_paragraph_markers", False) or config.get("preserve_usfm_markers", False),
config.get("include_style_markers", False) or config.get("preserve_usfm_markers", False),
config.get("include_embeds", False) or config.get("include_inline_elements", False),
translate_configs.get("postprocess", []),
)
else:
raise RuntimeError("A Scripture book, file, or file prefix must be specified for translation.")
Expand Down
Loading