Skip to content

Commit

Permalink
Added possibilty to add the common variant genotype vector during ass…
Browse files Browse the repository at this point in the history
…ociation testing. Added new internal flag to dense_gt to enable this behaviour. Changed make_dataset functions to export additional zarr file and load only needed genotype vector columns for regression.
  • Loading branch information
ThibaultBechtler committed Dec 18, 2023
1 parent 995e5bb commit b7a2300
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 32 deletions.
30 changes: 21 additions & 9 deletions deeprvat/data/dense_gt.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
x_phenotypes: List[str] = [],
grouping_level: Optional[str] = "gene",
group_common: bool = False,
in_association: bool = False,
return_sparse: bool = False,
annotations: List[str] = [],
annotation_file: Optional[str] = None,
Expand Down Expand Up @@ -127,6 +128,14 @@ def __init__(
self.variant_matrix = f["variant_matrix"][:]
self.genotype_matrix = f["genotype_matrix"][:]

# check if we need to standardize the common genotype data
# if so cache the whole thing in memory to perform the computation
# if config["training_data"]["dataset_config"]["std_common"]:
# self.genotype_matrix = f["genotype_matrix"][:]
# cvar_mean = np.mean(self.genotype_matrix, dim=0)
# cvar_std = np.std(self.genotype_matrix, dim=0)
# self.genotype_matrix = (self.genotype_matrix - cvar_mean) / cvar_std

logger.info(
f"Using phenotype file {phenotype_file} and genotype file {self.gt_filename}"
)
Expand Down Expand Up @@ -165,6 +174,7 @@ def __init__(
)

self.group_common = group_common
self.in_association = in_association
self.return_sparse = return_sparse

self.annotations = annotations
Expand Down Expand Up @@ -837,15 +847,17 @@ def get_common_variants(
common_variants = torch.tensor(common_variants, dtype=torch.float)

if self.group_common:
# import pdb; pdb.set_trace()
# common_variants = [
# common_variants[vmap] for vmap in self.group_matrix_maps
# ]

# repurposing the group_common flag here
# subset it down to the indices already collected in self.group_matrix_maps
common_geno_idx = np.unique(np.concatenate(self.group_matrix_maps, axis=None))
common_variants = common_variants[common_geno_idx]
if self.in_association:
# check if this instance is used for association testing
# if so return the follow genotype vector while also
# creating self.group_matrix_maps
return common_variants, masked_sparse_variants, masked_sparse_genotype
else:
# repurposing the group_common flag here
# subset genotype vector down to the indices collected in self.group_matrix_maps
# this subsetting is only needed if this instance is used for training
common_geno_idx = np.unique(np.concatenate(self.group_matrix_maps, axis=None))
common_variants = common_variants[common_geno_idx]

return common_variants, masked_sparse_variants, masked_sparse_genotype

Expand Down
131 changes: 115 additions & 16 deletions deeprvat/deeprvat/associate.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def get_burden(
batch: Dict,
agg_models: Dict[str, List[nn.Module]],
device: torch.device = torch.device("cpu"),
skip_burdens=False,
skip_burdens=False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute burden scores for rare variants.
Expand All @@ -63,6 +63,8 @@ def get_burden(
:type device: torch.device
:param skip_burdens: Flag to skip burden computation, defaults to False.
:type skip_burdens: bool
:param use_common: Flag to add common variant genotype information to x_pheno, defaults to False.
:type use_common: bool
:return: Tuple containing burden scores, target y phenotype values, and x phenotypes.
:rtype: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Expand All @@ -87,13 +89,9 @@ def get_burden(

y = batch["y"]
x = batch["x_phenotypes"] # containes other covariates e.g. age, genetic PCs
cvar = batch["common_variants"].numpy()

# get common geno on forward pass
# cvar = batch["common_variants"].to_numpy()
# glue common variant genotype to other covariates
# x = np.stack(x, cvar, axis=1)

return burden, y, x
return burden, y, x, cvar


def separate_parallel_results(results: List) -> Tuple[List, ...]:
Expand Down Expand Up @@ -146,6 +144,7 @@ def make_dataset_(
variant_file=data_config["variant_file"],
split="",
skip_y_na=False,
in_association=True, # enable alternative logic if group_common is activated
**copy.deepcopy(data_config["dataset_config"]),
)

Expand Down Expand Up @@ -279,22 +278,27 @@ def compute_burdens_(

logger.info("Computing burden scores")
batch_size = data_config["dataloader_config"]["batch_size"]
use_common = data_config["dataset_config"]["use_common_variants"]
with torch.no_grad():
for i, batch in tqdm(
enumerate(dl),
file=sys.stdout,
total=(n_samples // batch_size + (n_samples % batch_size != 0)),
):
# run forward pass on all repeats to get gene burden
this_burdens, this_y, this_x = get_burden(
this_burdens, this_y, this_x, this_cvar = get_burden(
batch, agg_models, device=device, skip_burdens=skip_burdens
)

if i == 0:
if not skip_burdens:
chunk_burden = np.zeros(shape=(n_samples,) + this_burdens.shape[1:])
chunk_y = np.zeros(shape=(n_samples,) + this_y.shape[1:])
chunk_x = np.zeros(shape=(n_samples,) + this_x.shape[1:])

if use_common:
chunk_cvar = np.zeros(shape=(n_samples,) + this_cvar.shape[1:])

logger.info(f"Batch size: {batch['rare_variant_annotations'].shape}")

if not skip_burdens:
Expand All @@ -310,6 +314,17 @@ def compute_burdens_(
else:
burdens = None

if use_common:
cvar = zarr.open(
Path(cache_dir) / "common_variants.zarr",
mode="a",
shape=(n_total_samples,) + this_cvar.shape[1:],
chunks=(1000, 1000),
dtype=np.float32,
compressor=Blosc(clevel=compression_level),
)
logger.info(f"common genotype shape: {cvar.shape}")

y = zarr.open(
Path(cache_dir) / "y.zarr",
mode="a",
Expand All @@ -333,6 +348,9 @@ def compute_burdens_(
if not skip_burdens:
chunk_burden[start_idx:end_idx] = this_burdens

if use_common:
chunk_cvar[start_idx:end_idx] = this_cvar

chunk_y[start_idx:end_idx] = this_y
chunk_x[start_idx:end_idx] = this_x

Expand All @@ -350,11 +368,24 @@ def compute_burdens_(
y[chunk_start:chunk_end] = chunk_y
x[chunk_start:chunk_end] = chunk_x

if use_common:
cvar[chunk_start:chunk_end] = chunk_cvar

if torch.cuda.is_available():
logger.info(
"Max GPU memory allocated: " f"{torch.cuda.max_memory_allocated(0)} bytes"
)

if use_common:
# build dict for gene group mapping on common genotype vector
# and store as pickle in burden dir
cvar_group_dict = {
gene: group_indices for gene, group_indices in zip(ds_full.group_names, ds_full.group_matrix_maps)
}

with open(Path(cache_dir) / "common_variants_group_map.pkl", 'wb') as file:
pickle.dump(cvar_group_dict, file)

return ds_full.rare_embedding.genes, burdens, y, x


Expand Down Expand Up @@ -695,6 +726,7 @@ def regress_on_gene(
x_pheno: np.ndarray,
use_bias: bool,
use_x_pheno: bool,
common_var_genotype: np.ndarray = None,
) -> Tuple[List[str], List[float], List[float]]:
"""
Perform regression on a gene using Ordinary Least Squares (OLS).
Expand All @@ -711,6 +743,9 @@ def regress_on_gene(
:type use_bias: bool
:param use_x_pheno: Flag to include x phenotype data in regression.
:type use_x_pheno: bool
:param common_var_genotype: common variant genotype vector
: type common_var_genotype: np.ndarray
:return: Tuple containing gene name, beta, and p-value.
:rtype: Tuple[List[str], List[float], List[float]]
"""
Expand All @@ -734,7 +769,9 @@ def regress_on_gene(
x_pheno = np.expand_dims(x_pheno, axis=1)
X = np.concatenate((X, x_pheno), axis=1)

# TODO: add something similar for common genotype?
if common_var_genotype is not None:
# add common variant genotype vector to X
X = np.concatenate((X, common_var_genotype.T), axis=1)

genes_params_pvalues = ([], [], [])
for this_y in np.split(y, y.shape[1], axis=1):
Expand All @@ -758,6 +795,7 @@ def regress_(
x_pheno: np.ndarray,
use_x_pheno: bool = True,
do_scoretest: bool = True,
cvar_vector_dict: Dict[int, np.array] = None,
) -> pd.DataFrame:
"""
Perform regression on multiple genes.
Expand All @@ -776,6 +814,8 @@ def regress_(
:type genes: pd.Series
:param x_pheno: X phenotype data.
:type x_pheno: np.ndarray
:param cvar_vector_dict: dictionary of common variant genotype data per gene.
:type cvar_vector_dict: Dict[int, np.array]
:param use_x_pheno: Flag to include x phenotype data when performing OLS regression, defaults to True.
:type use_x_pheno: bool
:param do_scoretest: Flag to use the scoretest from SEAK, defaults to True.
Expand Down Expand Up @@ -814,13 +854,39 @@ def regress_(
)
]
else:
logger.info("Running regression on each gene using OLS")
genes_betas_pvals = [
regress_on_gene(gene, burdens[:, i], y, x_pheno, use_bias, use_x_pheno)
for i, gene in tqdm(
zip(gene_indices, genes), total=genes.shape[0], file=sys.stdout
)
]
if cvar_vector_dict is not None:
# regression with common variant genotype
logger.info("Running regression on each gene using OLS")
genes_betas_pvals = [
regress_on_gene(
gene,
burdens[:, i],
y,
x_pheno,
use_bias,
use_x_pheno,
common_var_genotype=cvar_vector_dict[i],
)
for i, gene in tqdm(
zip(gene_indices, genes), total=genes.shape[0], file=sys.stdout
)
]
else:
# regression with rare only
logger.info("Running regression on each gene using OLS")
genes_betas_pvals = [
regress_on_gene(
gene,
burdens[:, i],
y,
x_pheno,
use_bias,
use_x_pheno,
)
for i, gene in tqdm(
zip(gene_indices, genes), total=genes.shape[0], file=sys.stdout
)
]

genes_betas_pvals = [x for x in genes_betas_pvals if x is not None]
regressed_genes, betas, pvals = separate_parallel_results(genes_betas_pvals)
Expand Down Expand Up @@ -895,6 +961,8 @@ def regress(
x_pheno = zarr.open(Path(burden_dir) / "x.zarr")[:]
genes = pd.Series(np.load(Path(burden_dir) / "genes.npy"))

# debug = True

if sample_file is not None:
with open(sample_file, "rb") as f:
samples = pickle.load(f)["association_samples"]
Expand All @@ -917,6 +985,8 @@ def regress(
with open(config_file) as f:
config = yaml.safe_load(f)

use_common = config["data"]["dataset_config"]["use_common_variants"]

if gene_file is not None:
logger.info("Loading gene names")
gene_df = pd.read_parquet(gene_file, engine="pyarrow")
Expand All @@ -931,6 +1001,34 @@ def regress(
gene_indices = np.arange(chunk_start, chunk_end)
genes = genes.iloc[chunk_start:chunk_end]

cvar_vector_dict = None

if use_common:
# load additional files if common variant data should be added
cvar = zarr.open(Path(burden_dir) / "common_variants.zarr")

with open(Path(burden_dir) / "common_variants_group_map.pkl", 'rb') as file:
cvar_group_map = pickle.load(file)

# build dict containing needed genotype positions per gene in chunk
cvar_vector_dict = {}

# build index array for samples
# apply nan mask to slice coordinates
sample_indices = np.arange(0, len(cvar))[nan_mask]

for g_i in gene_indices:
g_i_cvar = None
# load common variant information for required genes only
if g_i in cvar_group_map.keys():
# load common variant genotype data for gene i
var_indices = np.expand_dims(cvar_group_map[g_i], axis=1)
g_i_cvar = cvar.get_coordinate_selection(
(sample_indices, var_indices)
)

cvar_vector_dict[g_i] = g_i_cvar

associations = regress_(
config,
use_bias,
Expand All @@ -939,6 +1037,7 @@ def regress(
gene_indices,
genes,
x_pheno,
cvar_vector_dict=cvar_vector_dict,
do_scoretest=do_scoretest,
)

Expand Down
12 changes: 5 additions & 7 deletions deeprvat/deeprvat/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,13 +242,11 @@ def make_dataset_(
input_tensor, covariates, y, common_variants, config["training"]["min_variant_count"]
)

# TODO: add standardization of common_var genotype here ?
# if config["data"]["std_common"]:
# cvar_mean = torch.mean(common_variants, dim=0)
# cvar_std = torch.std(common_variants, dim=0)

if config["data"]["dataset_config"]["std_common"]:
cvar_mean = torch.mean(common_variants, dim=0)
cvar_std = torch.std(common_variants, dim=0)

common_variants = (common_variants - cvar_mean) / cvar_std
# common_variants = (common_variants - cvar_mean) / cvar_std

return input_tensor, covariates, y, common_variants

Expand Down Expand Up @@ -317,7 +315,7 @@ def make_dataset(
del input_tensor
zarr.save_array(covariates_out_file, covariates.numpy())
zarr.save_array(y_out_file, y.numpy())
zarr.save_array(common_vars_out_file,
zarr.save_array(common_vars_out_file,
common_variants.numpy(),
chunks=(1000, None),
compressor=Blosc(clevel=compression_level),
Expand Down

0 comments on commit b7a2300

Please sign in to comment.