Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for .pvar.zst input and output #9

Merged
merged 5 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ dependencies = [
"kaleido",
"nbformat",
"pong",
"adjustText"
"adjustText",
"zstandard",
]

[project.optional-dependencies]
Expand Down
7 changes: 4 additions & 3 deletions snputils/snp/io/read/__test__/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,20 +63,21 @@ def data_path():
)

# Generate bed and pgen formats
for fmt in ["bed", "pgen"]:
for fmt in ["bed", "pgen", "pgen_zst"]:
fmt_path = data_path / fmt
os.makedirs(fmt_path, exist_ok=True)
fmt_file = fmt_path / "subset"
if not fmt_file.exists():
print(f"Generating {fmt} format...")
make_fmt = "--make-pgen vzs" if fmt == "pgen_zst" else f"--make-{fmt.split('_')[0]}"
subprocess.run(
[
"./plink2",
"--vcf",
subset_vcf,
f"--make-{fmt}",
*make_fmt.split(),
"--out",
fmt + "/subset",
f"{fmt}/subset",
],
cwd=str(data_path),
)
Expand Down
19 changes: 19 additions & 0 deletions snputils/snp/io/read/__test__/test_formats.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
from snputils import PGENReader

# # VCF - BED

Expand Down Expand Up @@ -122,3 +123,21 @@ def test_vcf_pgen_variants_pos(snpobj_vcf, snpobj_pgen):
assert snpobj_vcf.variants_pos is not None
assert snpobj_pgen.variants_pos is not None
assert np.array_equal(snpobj_vcf.variants_pos, snpobj_pgen.variants_pos)


# Compressed VCF
def test_vcf_gz(data_path):
    """Placeholder for the compressed-VCF reading test (not implemented yet).

    Args:
        data_path: Test-data fixture; unused until the test body is written.
    """
    pass  # TODO


# PGEN with compressed pvar
def test_pgen_pvar_zst(data_path, snpobj_pgen):
    """Check that a PGEN fileset with a zstd-compressed .pvar reads like the reference.

    Reads ``pgen_zst/subset`` under ``data_path`` with ``PGENReader``
    (``phased=True``) and compares each genotype/variant field against
    ``snpobj_pgen``.

    Args:
        data_path: Root of the generated test data (str or path-like).
        snpobj_pgen: Reference fixture; presumably the object read from the
            uncompressed ``pgen/subset`` fileset -- confirm against conftest.
    """
    # An f-string yields the same path whether the fixture provides a str or a
    # pathlib.Path; `data_path + "/..."` raises TypeError for a Path.
    snpobj = PGENReader(f"{data_path}/pgen_zst/subset").read(phased=True)

    compared_fields = (
        "calldata_gt",
        "variants_ref",
        "variants_alt",
        "variants_chrom",
        "variants_id",
        "variants_pos",
        "variants_filter_pass",
        "variants_qual",
    )
    for field in compared_fields:
        # Name the field in the failure message so a mismatch is easy to locate.
        assert np.array_equal(getattr(snpobj_pgen, field), getattr(snpobj, field)), (
            f"{field} differs between the reference and the .pvar.zst fileset"
        )
4 changes: 2 additions & 2 deletions snputils/snp/io/read/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __new__(cls,
if not suffixes:
raise ValueError("The filename should have an extension when using SNPReader.")

extension = suffixes[-2] if suffixes[-1].lower() == ".gz" else suffixes[-1]
extension = suffixes[-2] if suffixes[-1].lower() in (".zst", ".gz") else suffixes[-1]
extension = extension.lower()

if extension == ".vcf":
Expand All @@ -41,7 +41,7 @@ def __new__(cls,
from snputils.snp.io.read.bed import BEDReader

return BEDReader(filename)
elif extension in (".pgen", ".pvar", ".psam"):
elif extension in (".pgen", ".pvar", ".psam", ".pvar.zst"):
from snputils.snp.io.read.pgen import PGENReader

return PGENReader(filename)
Expand Down
33 changes: 27 additions & 6 deletions snputils/snp/io/read/pgen.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
from typing import List, Optional

import numpy as np
Expand Down Expand Up @@ -69,26 +70,46 @@ def read(
only_read_pgen = fields == ["GT"] and variant_idxs is None and sample_idxs is None

filename_noext = str(self.filename)
if filename_noext[-5:].lower() in (".pgen", ".pvar", ".psam"):
filename_noext = filename_noext[:-5]
for ext in [".pgen", ".pvar", ".pvar.zst", ".psam"]:
if filename_noext.endswith(ext):
filename_noext = filename_noext[:-len(ext)]
break

if only_read_pgen:
file_num_samples = None # Not needed for pgen
file_num_variants = None # Not needed
else:
log.info(f"Reading {filename_noext}.pvar")
pvar_extensions = [".pvar", ".pvar.zst"]
pvar_filename = None
for ext in pvar_extensions:
possible_pvar = filename_noext + ext
if os.path.exists(possible_pvar):
pvar_filename = possible_pvar
break
if pvar_filename is None:
raise FileNotFoundError(f"No .pvar or .pvar.zst file found for {filename_noext}")

log.info(f"Reading {pvar_filename}")

def open_textfile(filename):
    """Open ``filename`` for text reading, decompressing zstd when needed.

    Args:
        filename (str): Path to a plain-text file, or to a zstd-compressed
            one whose name ends in ``.zst``.

    Returns:
        A text-mode file object usable as a context manager.
    """
    # Decode explicitly as UTF-8 so the header scan does not depend on the
    # platform's locale encoding.
    if filename.endswith('.zst'):
        # Imported lazily so plain-text inputs do not require zstandard.
        import zstandard as zstd
        return zstd.open(filename, 'rt', encoding='utf-8')
    return open(filename, 'rt', encoding='utf-8')

pvar_has_header = True
pvar_header_line_num = 0
with open(filename_noext + ".pvar") as file:
with open_textfile(pvar_filename) as file:
for line_num, line in enumerate(file):
if line.startswith("#CHROM"):
pvar_header_line_num = line_num
break
else: # if no break
pvar_has_header = False

pvar = pl.scan_csv(
filename_noext + ".pvar",
pvar_filename,
separator='\t',
skip_rows=pvar_header_line_num,
has_header=pvar_has_header,
Expand All @@ -101,7 +122,7 @@ def read(
"ALT": pl.String,
},
).with_row_index()

# Because pvar is a lazy scan, skip_rows has not been applied at this point,
# so pl.len() returns the number of variant rows plus the skipped header lines.
file_num_variants = pvar.select(pl.len()).collect().item() - pvar_header_line_num
Expand Down
52 changes: 36 additions & 16 deletions snputils/snp/io/write/pgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import polars as pl
import pgenlib as pg
from pathlib import Path
import zstandard as zstd

from snputils.snp.genobj.snpobj import SNPObject

Expand All @@ -18,35 +19,42 @@ def __init__(self, snpobj: SNPObject, filename: str):
"""
Initializes the PGENWriter instance.

Parameters
----------
snpobj : SNPObject
The SNPObject containing genotype data to be written.
file : str
Base path for the output files (excluding extension).
TODO: add support for parallel writing by chromosome.
Args:
snpobj (SNPObject): The SNPObject containing genotype data to be written.
filename (str): Base path for the output files (excluding extension).
"""
self.__snpobj = snpobj
self.__filename = Path(filename)

def write(self):
def write(self, vzs: bool = False):
"""
Writes the SNPObject data to .pgen, .psam, and .pvar files.

Args:
vzs (bool, optional): If True, compresses the .pvar file using zstd and saves it as .pvar.zst. Defaults to False.
"""
file_extensions = (".pgen", ".psam", ".pvar")
file_extensions = (".pgen", ".psam", ".pvar", ".pvar.zst")
if self.__filename.suffix in file_extensions:
self.__filename = self.__filename.with_suffix('')
self.__file_extension = ".pgen"

self.write_pvar()
self.write_pvar(vzs=vzs)
self.write_psam()
self.write_pgen()

def write_pvar(self):
def write_pvar(self, vzs: bool = False):
"""
Writes variant data to the .pvar file.

Args:
vzs (bool, optional): If True, compresses the .pvar file using zstd and saves it as .pvar.zst. Defaults to False.
"""
log.info(f"Writing to {self.__filename}.pvar")
output_filename = f"{self.__filename}.pvar"
if vzs:
output_filename += ".zst"
log.info(f"Writing to {output_filename} (compressed)")
else:
log.info(f"Writing to {output_filename}")

df = pl.DataFrame(
{
"#CHROM": self.__snpobj.variants_chrom,
Expand All @@ -59,7 +67,19 @@ def write_pvar(self):
}
)
# TODO: add the header to the .pvar file; otherwise it is lost
df.write_csv(f"{self.__filename}.pvar", separator="\t")

# Write the DataFrame to a CSV string
csv_data = df.write_csv(None, separator="\t")

if vzs:
# Compress the CSV data using zstd
cctx = zstd.ZstdCompressor()
compressed_data = cctx.compress(csv_data.encode('utf-8'))
with open(output_filename, 'wb') as f:
f.write(compressed_data)
else:
with open(output_filename, 'w') as f:
f.write(csv_data)

def write_psam(self):
"""
Expand Down Expand Up @@ -89,10 +109,10 @@ def write_pgen(self):
)
else:
num_variants, num_samples = self.__snpobj.calldata_gt.shape
flat_genotypes = self.__snpobj.__calldata_gt
flat_genotypes = self.__snpobj.calldata_gt

with pg.PgenWriter(
filename=str(self.__filename).encode('utf-8'),
filename=f"{self.__filename}.pgen".encode('utf-8'),
sample_ct=num_samples,
variant_ct=num_variants,
hardcall_phase_present=phased,
Expand Down