Mean imputation for missing data
AlbertDominguez committed Feb 8, 2024
1 parent 41fedeb commit d67098e
Showing 5 changed files with 48 additions and 27 deletions.
6 changes: 6 additions & 0 deletions .gitignore
@@ -142,3 +142,9 @@ outputs/
 
 # figures
 **figures**
+
+# DS_Store
+**/.DS_Store
+
+# Version file
+neural_admixture/_version.py
1 change: 1 addition & 0 deletions neural_admixture/entry.py
@@ -7,6 +7,7 @@
 
 def main():
     log.info(f"Neural ADMIXTURE - Version {__version__}")
+    log.info("[CHANGELOG] Mean imputation for missing data was added in version 1.4.0. To reproduce old behaviour, please use `--imputation zero` when invoking the software.")
     log.info("[CHANGELOG] Default P initialization was changed to 'pckmeans' in version 1.3.0.")
     log.info("[CHANGELOG] Warmup training for initialization of Q was added in version 1.3.0 to improve training stability (only for `pckmeans`).")
     log.info("[CHANGELOG] Convergence check changed so it is performed after 15 epochs in version 1.3.0 to improve training stability.")
45 changes: 29 additions & 16 deletions neural_admixture/src/snp_reader.py
@@ -4,6 +4,7 @@
 import sys
 
 from pathlib import Path
+from typing import Literal
 
 logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 log = logging.getLogger(__name__)
@@ -24,9 +24,6 @@ def _read_vcf(self, file: str) -> np.ndarray:
         import allel
         f_tr = allel.read_vcf(file)
         calldata = f_tr["calldata/GT"]
-        if np.isnan(calldata).any():
-            log.warning("Data contains missing values. Will perform zero-imputation.")
-            calldata = np.nan_to_num(calldata, nan=0.)
         return np.sum(calldata, axis=2).T/2
 
     def _read_hdf5(self, file: str) -> np.ndarray:
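Note: the `.T/2` pattern shared by these readers converts per-allele calls into averaged genotypes in {0, 0.5, 1}. A toy sketch of the VCF path (values are illustrative, not from the repository):

```python
import numpy as np

# Toy "calldata/GT" tensor of shape (n_snps, n_samples, ploidy=2):
# sample 0 carries 0/0 and 1/1; sample 1 carries 0/1 and 1/0.
calldata = np.array([[[0, 0], [0, 1]],
                     [[1, 1], [1, 0]]])

G = np.sum(calldata, axis=2).T / 2
print(G)
# [[0.  1. ]
#  [0.5 0.5]]  -> (n_samples, n_snps), values in {0, 0.5, 1}
```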
@@ -55,9 +53,6 @@ def _read_bed(self, file: str) -> da.core.Array:
         log.info('Input format is BED.')
         from pandas_plink import read_plink
         _, _, G = read_plink(str(Path(file).with_suffix("")))
-        if da.isnan(G).any().compute():
-            log.warning("Data contains missing values. Will perform zero-imputation.")
-            G = da.nan_to_num(G, 0.)
         return (G.T/2)
 
     def _read_pgen(self, file: str) -> np.ndarray:
@@ -81,9 +76,6 @@ def _read_pgen(self, file: str) -> np.ndarray:
         pgen_reader = pg.PgenReader(pgen)
         calldata = np.ascontiguousarray(np.empty((pgen_reader.get_variant_ct(), 2*pgen_reader.get_raw_sample_ct())).astype(np.int32))
         pgen_reader.read_alleles_range(0, pgen_reader.get_variant_ct(), calldata)
-        if np.isnan(calldata).any():
-            log.warning("Data contains missing values. Will perform zero-imputation.")
-            calldata = np.nan_to_num(calldata, nan=0.)
         return (calldata[:,::2]+calldata[:,1::2]).T/2
 
     def _read_npy(self, file: str) -> np.ndarray:
@@ -98,18 +90,16 @@ def _read_npy(self, file: str) -> np.ndarray:
         log.info('Input format is NPY.')
         calldata = np.load(file)
         assert calldata.ndim in [2, 3]
-        if np.isnan(calldata).any():
-            log.warning("Data contains missing values. Will perform zero-imputation.")
-            calldata = np.nan_to_num(calldata, nan=0.)
         if calldata.ndim == 2:
             return calldata/2
         return np.nan_to_num(calldata, nan=0.).sum(axis=2)/2
 
-    def read_data(self, file: str) -> da.core.Array:
+    def read_data(self, file: str, imputation: str) -> da.core.Array:
         """Wrapper of readers
 
         Args:
             file (str): path to file
+            imputation (str): imputation method. Should be either 'zero' or 'mean'
 
         Returns:
             da.core.Array: averaged genotype Dask array of shape (n_samples, n_snps)
@@ -128,6 +118,29 @@
         else:
             log.error('Invalid format. Unrecognized file format. Make sure file ends with .vcf | .vcf.gz | .bed | .pgen | .h5 | .hdf5 | .npy')
             sys.exit(1)
-        assert int(G.min()) == 0 and int(G.max()) == 1, 'Only biallelic SNPs are supported. Please make sure multiallelic sites have been removed.'
-        G_corr = G if np.mean(G) < 0.5 else 1-G
-        return G_corr if isinstance(G_corr, da.core.Array) else da.from_array(G_corr)
+        if isinstance(G, np.ndarray):
+            G = da.from_array(G)
+        G = self._impute(G, method=imputation)
+        assert int(G.min().compute()) == 0 and int(G.max().compute()) == 1, 'Only biallelic SNPs are supported. Please make sure multiallelic sites have been removed.'
+        return G if G.mean().compute() < 0.5 else 1-G
+
+    @staticmethod
+    def _impute(G: da.core.Array, method: Literal["zero", "mean"]="mean") -> da.core.Array:
+        """Impute missing values
+
+        Args:
+            G (da.core.Array): genotype array
+
+        Returns:
+            da.core.Array: imputed genotype array
+        """
+        if da.isnan(G).any().compute():
+            log.warning(f"Data contains missing values. Will perform {method}-imputation.")
+            if method == "zero":
+                return da.nan_to_num(G, 0.)
+            elif method == "mean":
+                snp_means = da.nanmean(G, axis=0)[None, :]
+                return da.where(da.isnan(G), snp_means, G)
+            else:
+                raise ValueError("Invalid imputation method. Only 'zero' and 'mean' are supported.")
+        return G
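Note: the heart of the new `mean` strategy is the `da.nanmean`/`da.where` pair above, which replaces each missing call with its SNP's average over non-missing samples rather than a flat zero (a homozygous-reference call). A standalone sketch of the same logic with a toy matrix (illustrative values):

```python
import dask.array as da
import numpy as np

# Toy (n_samples, n_snps) genotype matrix with two missing calls.
G = da.from_array(np.array([[0.0,    0.5],
                            [np.nan, 1.0],
                            [1.0,    np.nan]]))

# Per-SNP means ignoring NaNs ([0.5, 0.75]), broadcast across samples.
snp_means = da.nanmean(G, axis=0)[None, :]
imputed = da.where(da.isnan(G), snp_means, G)

print(imputed.compute())
# [[0.   0.5 ]
#  [0.5  1.  ]
#  [1.   0.75]]
```

With `--imputation zero`, both gaps would instead be filled with 0, biasing imputed genotypes toward the reference allele at polymorphic sites.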
4 changes: 2 additions & 2 deletions neural_admixture/src/train.py
@@ -134,9 +134,9 @@ def main(argv: List[str]):
     """
     args = utils.parse_train_args(argv)
     tr_file, val_file = args.data_path, args.validation_data_path
-    assert not (val_file and args.cv is not None), 'Cross-validation not available when validation data path is provided.'
+    # assert not (val_file and args.cv is not None), 'Cross-validation not available when validation data path is provided.'
     tr_pops_f, val_pops_f = args.populations_path, args.validation_populations_path
-    trX, trY, valX, valY = utils.read_data(tr_file, val_file, tr_pops_f, val_pops_f)
+    trX, trY, valX, valY = utils.read_data(tr_file, val_file, tr_pops_f, val_pops_f, imputation=args.imputation)
     """
     if args.cv is not None:
         log.info(f'Performing {args.cv}-fold cross-validation...')
19 changes: 10 additions & 9 deletions neural_admixture/src/utils.py
@@ -55,6 +55,7 @@ def parse_train_args(argv: List[str]):
     parser.add_argument('--batch_size', required=False, default=400, type=int, help='Batch size')
     parser.add_argument('--supervised_loss_weight', required=False, default=0.05, type=float, help='Weight given to the supervised loss')
     parser.add_argument('--warmup_epochs', required=False, default=10, type=int, help='Number of warmup epochs to bring Q to a good initialization. Set to 0 to skip warmup.')
+    parser.add_argument('--imputation', type=str, default='mean', choices=['mean', 'zero'], help='Imputation method for missing data (zero or mean)')
     # parser.add_argument('--cv', required=False, default=None, type=int, help='Number of folds for cross-validation')
     return parser.parse_args(argv)
 
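Note: a stripped-down parser containing just this flag shows the new default and the opt-out (standalone sketch, not the package's full `parse_train_args`):

```python
import argparse

# Minimal parser mirroring only the new flag.
parser = argparse.ArgumentParser()
parser.add_argument('--imputation', type=str, default='mean', choices=['mean', 'zero'],
                    help='Imputation method for missing data (zero or mean)')

print(parser.parse_args([]).imputation)                        # 'mean' (new default)
print(parser.parse_args(['--imputation', 'zero']).imputation)  # 'zero' (pre-1.4.0 behaviour)
```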
@@ -101,7 +102,7 @@ def initialize_wandb(run_name: str, trX: da.core.Array, valX: da.core.Array, arg
                         'out_path': out_path})
     return run_name
 
-def read_data(tr_file: str, val_file: str=None, tr_pops_f: str=None, val_pops_f: str=None) -> Tuple[da.core.Array, Union[None, da.core.Array], Union[None, List[str]], Union[None, List[str]]]:
+def read_data(tr_file: str, val_file: str=None, tr_pops_f: str=None, val_pops_f: str=None, imputation: str="mean") -> Tuple[da.core.Array, Union[None, da.core.Array], Union[None, List[str]], Union[None, List[str]]]:
     """Read data in any compatible format
 
     Args:
@@ -122,8 +123,8 @@ def read_data(tr_file: str, val_file: str=None, tr_pops_f: str=None, val_pops_f:
     tr_pops, val_pops = None, None
     log.info('Reading data...')
     snp_reader = SNPReader()
-    tr_snps = snp_reader.read_data(tr_file)
-    val_snps = snp_reader.read_data(val_file) if val_file else None
+    tr_snps = snp_reader.read_data(tr_file, imputation)
+    val_snps = snp_reader.read_data(val_file, imputation) if val_file else None
     if tr_pops_f:
         with open(tr_pops_f, 'r') as fb:
             tr_pops = [p.strip() for p in fb.readlines()]
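Note: a hedged usage sketch of the updated helper, assuming the package layout shown in this diff (`neural_admixture/src/utils.py`); file names are placeholders, and with no populations files the label outputs are None:

```python
from neural_admixture.src import utils

# Both splits are read and imputed with the same strategy before training.
trX, trY, valX, valY = utils.read_data('train.bed', val_file='val.bed', imputation='mean')
print(trX.shape)  # (n_samples, n_snps) dask array with genotypes in [0, 1]
```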
@@ -234,16 +235,16 @@ def write_outputs(model: NeuralAdmixture, trX: da.core.Array, valX: Union[da.cor
     return 0
 
 def compute_deviances(model: NeuralAdmixture, data: da.core.Array, bsize: int, device: torch.device) -> Dict[int, float]:
-    """_summary_
+    """Compute deviances for CV error
 
     Args:
-        model (NeuralAdmixture): _description_
-        trX (np.ndarray): _description_
-        bsize (int): _description_
-        device (torch.device): _description_
+        model (NeuralAdmixture): trained model object.
+        trX (np.ndarray): training data matrix.
+        bsize (int): batch size to retrieve the predictions.
+        device (torch.device): computing device.
 
     Returns:
-        Dict[int, float]: _description_
+        Dict[int, float]: dictionary containing deviance value for each K
     """
     eps = 1e-7
     reconstructions = get_model_reconstructions(model, data, bsize, device)
