Commit

update test cases
hanyangii committed Aug 15, 2024
1 parent abbc783 commit 89f4d21
Showing 4 changed files with 137 additions and 12 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -39,6 +39,7 @@ pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
test/data/*
htmlcov/
.tox/
.nox/
28 changes: 28 additions & 0 deletions test/test_deconvolute.py
@@ -17,13 +17,24 @@ def test_adjustment(trainer, tokenizer, data_loader, output_path, df_train):

assert pd.read_csv(os.path.join(output_path, "FI.csv"), sep="\t").shape[0] == len(df_train["dmr_label"].unique())

def test_multi_cell_type(trainer, tokenizer, data_loader, output_path, df_train):
deconvolute(trainer = trainer,
tokenizer = tokenizer,
data_loader = data_loader,
output_path = output_path,
df_train = df_train,
adjustment = False)

assert pd.read_csv(os.path.join(output_path, "deconvolution.csv"), sep="\t").shape[0] == 3

if __name__=="__main__":
f_bulk = "data/processed/test_seq.csv"
f_train = "data/processed/train_seq.csv"
model_dir = "res/bert.model/"
out_dir = "res/deconvolution/"

tokenizer = MethylVocab(k=3)

dataset = MethylBertFinetuneDataset(f_bulk,
tokenizer,
seq_len=150)
@@ -37,4 +48,21 @@ def test_adjustment(trainer, tokenizer, data_loader, output_path, df_train):
trainer.load(model_dir)

test_adjustment(trainer, tokenizer, data_loader, out_dir, df_train)
# multiple cell type
model_dir = "data/multi_cell_type/res/bert.model/"
f_bulk = "data/multi_cell_type/test_seq.csv"
f_train = "data/multi_cell_type/train_seq.csv"
dataset = MethylBertFinetuneDataset(f_bulk,
tokenizer,
seq_len=150)
data_loader = DataLoader(dataset, batch_size=50, num_workers=20)
df_train = pd.read_csv(f_train, sep="\t")

trainer = MethylBertFinetuneTrainer(len(tokenizer),
train_dataloader=data_loader,
test_dataloader=data_loader,
)
trainer.load(model_dir)
test_multi_cell_type(trainer, tokenizer, data_loader, out_dir, df_train)

print("Everything passed!")
89 changes: 77 additions & 12 deletions test/test_finetune.py
@@ -5,7 +5,7 @@

from torch.utils.data import DataLoader
import pandas as pd
import os, shutil, json

def load_data(train_dataset: str, test_dataset: str, batch_size: int = 64, num_workers: int = 40):
tokenizer=MethylVocab(k=3)
@@ -14,6 +14,9 @@ def load_data(train_dataset: str, test_dataset: str, batch_size: int = 64, num_w
train_dataset = MethylBertFinetuneDataset(train_dataset, tokenizer, seq_len=150)
test_dataset = MethylBertFinetuneDataset(test_dataset, tokenizer, seq_len=150)

if len(test_dataset) > 500:
test_dataset.subset_data(500)

# Create a data loader
print("Creating Dataloader")
local_step_batch_size = int(batch_size/4)
@@ -53,27 +56,73 @@ def test_finetune_no_pretrain(tokenizer : MethylVocab,
assert os.path.exists(os.path.join(save_path, "train.csv"))
assert steps == pd.read_csv(os.path.join(save_path, "train.csv")).shape[0]

def test_finetune_no_pretrain_focal(tokenizer : MethylVocab,
save_path : str,
train_data_loader : DataLoader,
test_data_loader : DataLoader,
pretrain_model : str,
steps : int=10):

trainer = MethylBertFinetuneTrainer(vocab_size = len(tokenizer),
save_path=save_path,
train_dataloader=train_data_loader,
test_dataloader=test_data_loader,
with_cuda=False,
loss="focal_bce")

trainer.create_model(config_file=os.path.join(pretrain_model, "config.json"))

trainer.train(steps)
assert os.path.exists(os.path.join(save_path, "config.json"))

with open(os.path.join(save_path, "config.json")) as fp:
config = json.load(fp)
assert config["loss"] == "focal_bce"


def test_finetune_focal_multicelltype(tokenizer : MethylVocab,
save_path : str,
train_data_loader : DataLoader,
test_data_loader : DataLoader,
pretrain_model : str,
steps : int=10):

trainer = MethylBertFinetuneTrainer(vocab_size = len(tokenizer),
save_path=save_path+"bert.model/",
train_dataloader=train_data_loader,
test_dataloader=test_data_loader,
with_cuda=False,
loss="focal_bce")
trainer.load(pretrain_model)
trainer.train(steps)

assert os.path.exists(os.path.join(save_path, "bert.model_step0/config.json"))
assert os.path.exists(os.path.join(save_path, "bert.model_step0/dmr_encoder.pickle"))
assert os.path.exists(os.path.join(save_path, "bert.model_step0/pytorch_model.bin"))
assert os.path.exists(os.path.join(save_path, "bert.model_step0/read_classification_model.pickle"))
assert os.path.exists(os.path.join(save_path, "bert.model/config.json"))
assert os.path.exists(os.path.join(save_path, "bert.model/dmr_encoder.pickle"))
assert os.path.exists(os.path.join(save_path, "bert.model/pytorch_model.bin"))
assert os.path.exists(os.path.join(save_path, "bert.model/read_classification_model.pickle"))

def test_finetune(tokenizer : MethylVocab,
save_path : str,
@@ -104,9 +153,9 @@ def reset_dir(dirname):
# For data processing
f_bam_list = "data/bam_list.txt"
f_dmr = "data/dmrs.csv"
f_ref = "data/genome/hg19.fa"
f_ref = "data/genome.fa"
out_dir = "data/processed/"

# Process data for fine-tuning
fdg.finetune_data_generate(
sc_dataset = f_bam_list,
Expand Down Expand Up @@ -137,7 +186,23 @@ def reset_dir(dirname):
reset_dir(save_path)
test_finetune_no_pretrain(tokenizer, save_path, train_data_loader, test_data_loader, model_dir, train_step)

reset_dir(save_path)
test_finetune_no_pretrain_focal(tokenizer, save_path, train_data_loader, test_data_loader, model_dir, train_step)


# Multiple cell type
out_dir="data/multi_cell_type/"
tokenizer, train_data_loader, test_data_loader = \
load_data(train_dataset = os.path.join(out_dir, "train_seq.csv"),
test_dataset = os.path.join(out_dir, "test_seq.csv"))

# For fine-tuning
model_dir="data/pretrained_model/"
save_path="data/multi_cell_type/res/"

reset_dir(save_path)
test_finetune_focal_multicelltype(tokenizer, save_path, train_data_loader, test_data_loader, model_dir)

print("Everything passed!")
31 changes: 31 additions & 0 deletions test/test_finetune_preprocess.py
@@ -102,19 +102,50 @@ def test_dorado_aligned_file(bam_file: str, f_dmr: str, f_ref: str, out_dir = "t

print("test_dorado_aligned_file passed!")


def test_multi_cell_type(f_bam_file_list: str, f_dmr: str, f_ref: str, out_dir = "tmp/"):
fdg.finetune_data_generate(
sc_dataset = f_bam_file_list,
f_dmr = f_dmr,
f_ref = f_ref,
output_dir=out_dir,
split_ratio = 0.8,
n_cores=1
)

assert os.path.exists(out_dir+"train_seq.csv")
assert os.path.exists(out_dir+"test_seq.csv")
assert os.path.exists(out_dir+"dmrs.csv")

res = pd.read_csv(out_dir+"train_seq.csv", sep="\t")
assert "T" in res["ctype"].tolist()
assert "N" in res["ctype"].tolist()
assert "P" in res["ctype"].tolist()

print("test_multi_cell_type passed!")


if __name__=="__main__":
f_bam = "data/T_sample.bam"
f_bam_list = "data/bam_list.txt"
f_dmr = "data/dmrs.csv"
f_ref = "data/genome.fa"


test_single_bam_file(bam_file = f_bam, f_dmr=f_dmr, f_ref=f_ref)
test_list_bam_file(f_bam_file_list = f_bam_list, f_dmr=f_dmr, f_ref=f_ref)
test_dmr_subset(bam_file = f_bam, f_dmr=f_dmr, f_ref=f_ref, n_dmrs=10)
test_multi_cores(bam_file = f_bam, f_dmr=f_dmr, f_ref=f_ref, n_cores=4)
test_split_ratio(bam_file = f_bam, f_dmr=f_dmr, f_ref=f_ref, split_ratio=0.7)


f_bam_list = "data/multi_cell_type/bam_list.txt"
f_dmr = "data/multi_cell_type/dmrs.csv"
out_dir = "data/multi_cell_type/"
test_multi_cell_type(f_bam_file_list = f_bam_list, f_dmr=f_dmr, f_ref=f_ref, out_dir=out_dir)

f_dorado = "data/dorado_aligned.bam"
f_ref_hg38="data/hg38_genome.fa"
test_dorado_aligned_file(bam_file = f_dorado, f_dmr=f_dmr, f_ref=f_ref_hg38)
print("Everything passed!")
