🚧 Use tsv-utils for --output-metadata
tsv-join is much faster than the other implementation here (18x faster: 12s vs. 3m43s on the current SARS-CoV-2 GISAID dataset containing 16 million rows).
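
For context, the fast path streams the metadata through tsv-utils' tsv-join instead of filtering rows in Python. A minimal standalone sketch of the same pipeline, using hypothetical file names (metadata.tsv.gz, include.tsv, filtered.tsv) and assuming the metadata's ID column is named strain:

from subprocess import PIPE, Popen

# Stream the compressed metadata as text; gzcat is what get_cat() selects for .gz inputs.
cat_process = Popen(["gzcat", "metadata.tsv.gz"], stdout=PIPE)

# include.tsv is a single-column TSV: a "strain" header line, then one ID per line.
with open("filtered.tsv", "w") as out:
    tsv_join_process = Popen(
        ["tsv-join", "-H", "--filter-file", "include.tsv", "--key-fields", "strain"],
        stdin=cat_process.stdout,
        stdout=out,
    )
    cat_process.stdout.close()  # allow gzcat to receive SIGPIPE if tsv-join exits early
    tsv_join_process.wait()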
victorlin committed Jul 17, 2024
1 parent 98517cd commit 31d70c3
Showing 1 changed file with 60 additions and 10 deletions.
augur/filter/io.py
@@ -3,6 +3,9 @@
 from argparse import Namespace
 import os
 import re
+from shutil import which
+from subprocess import Popen, PIPE
+from tempfile import NamedTemporaryFile
 from textwrap import dedent
 from typing import Iterable, Sequence, Set
 import numpy as np
@@ -96,6 +99,15 @@ def constant_factory(value):
         raise AugurError(f"missing or malformed priority scores file {fname}")
 
 
+def get_cat(file):
+    if file.endswith(".gz"):
+        return which("gzcat")
+    if file.endswith(".xz"):
+        return which("xzcat")
+    else:
+        return which("cat")
+
+
 def write_metadata(input_metadata_path: str, delimiters: Sequence[str],
                    id_columns: Sequence[str], output_metadata_path: str,
                    ids_to_write: Set[str]):
@@ -105,16 +117,54 @@ def write_metadata(input_metadata_path: str, delimiters: Sequence[str],
     """
     input_metadata = Metadata(input_metadata_path, delimiters, id_columns)
 
-    with xopen(output_metadata_path, "w") as output_metadata_handle:
-        output_metadata = csv.DictWriter(output_metadata_handle, fieldnames=input_metadata.columns,
-                                         delimiter="\t", lineterminator=os.linesep)
-        output_metadata.writeheader()
-
-        # Write outputs based on rows in the original metadata.
-        for row in input_metadata.rows():
-            row_id = row[input_metadata.id_column]
-            if row_id in ids_to_write:
-                output_metadata.writerow(row)
+    output_is_tsv = output_metadata_path.endswith(".tsv")
+    tsv_join = which("tsv-join")
+    cat = get_cat(input_metadata_path)
+
+    # TODO: support compressed outputs when xopen supports them
+    # FIXME: create an issue in the xopen repo
+    if output_is_tsv and tsv_join and cat:
+        with NamedTemporaryFile(delete=False) as include_file:
+            # 1. Write the IDs to a single-column TSV file for tsv-join
+            with open(include_file.name, "w") as f:
+                f.write(input_metadata.id_column + '\n')
+                for strain in ids_to_write:
+                    f.write(strain + '\n')
+
+            # 2. Open a process to stream the input metadata as text (handling compression)
+            cat_args = [cat, input_metadata_path]
+            cat_process = Popen(cat_args, stdout=PIPE)
+
+            # 3. Use tsv-join to subset the input metadata by the IDs in the file created in (1)
+            tsv_join_args = [
+                tsv_join,
+                '-H',
+                '--filter-file', include_file.name,
+                '--key-fields', input_metadata.id_column,
+            ]
+
+            with open(output_metadata_path, "w") as output_metadata_handle:
+                tsv_join_process = Popen(tsv_join_args, stdin=cat_process.stdout, stdout=output_metadata_handle)
+                if cat_process.stdout:
+                    cat_process.stdout.close()  # Allow cat_process to receive a SIGPIPE if tsv-join exits.
+                tsv_join_process.wait()
+                stdout, stderr = tsv_join_process.communicate()
+                # FIXME: check for errors from subprocesses
+
+        # TODO: use NamedTemporaryFile(delete=True, delete_on_close=False)
+        # once Python 3.12 is the minimum supported version.
+        os.unlink(include_file.name)
+    else:
+        with xopen(output_metadata_path, "w") as output_metadata_handle:
+            output_metadata = csv.DictWriter(output_metadata_handle, fieldnames=input_metadata.columns,
+                                             delimiter="\t", lineterminator=os.linesep)
+            output_metadata.writeheader()
+
+            # Write outputs based on rows in the original metadata.
+            for row in input_metadata.rows():
+                row_id = row[input_metadata.id_column]
+                if row_id in ids_to_write:
+                    output_metadata.writerow(row)
 
 
 def write_strains(output_strains_path: str, ids_to_write: Iterable[str]):
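
The function's interface is unchanged; a hypothetical call (paths, delimiters, columns, and IDs below are illustrative) that would take the tsv-join fast path, since the output ends in .tsv:

write_metadata(
    input_metadata_path="metadata.tsv.xz",  # streamed through xzcat by get_cat()
    delimiters=("\t", ","),
    id_columns=("strain", "name"),
    output_metadata_path="filtered.tsv",    # a .tsv output enables the tsv-join path
    ids_to_write={"sample-a", "sample-b"},
)

If tsv-join or a suitable cat is not on PATH, or the output path does not end in .tsv, the same call falls through to the original csv.DictWriter loop.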
