
Merge pull request #36 from GoekeLab/release_1.1
Release 1.1
chrishendra93 authored Apr 29, 2022
2 parents be3148a + d2219b2 commit 530d451
Showing 15 changed files with 606 additions and 70 deletions.
29 changes: 16 additions & 13 deletions docs/source/cmd.rst
@@ -18,11 +18,12 @@ Argument name Required Default value Description
--eventalign=FILE Yes NA Eventalign filepath, the output from nanopolish.
--out_dir=DIR Yes NA Output directory.
--n_processes=NUM No 1 Number of processes to run.
--chunk_size No 1000000 chunksize argument for pandas read csv function on the eventalign input
--chunk_size=NUM No 1000000 chunksize argument for pandas read csv function on the eventalign input
--readcount_max=NUM No 1000 Maximum read counts per gene.
--readcount_min=NUM No 1 Minimum read counts per gene.
--index No True To skip indexing the eventalign nanopolish output, can only be used if the index has been created before
--n_neighbors No 1 The number of flanking positions to process
--n_neighbors=NUM No 1 The number of flanking positions to process
--min_segment_count=NUM No 1 Minimum read counts over each candidate m6A segment
================================= ========== =================== ============================================================================================================
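
For example, a typical invocation combining these options might look like the following sketch (paths are placeholders and the option values are illustrative, not recommendations)::

    m6anet-dataprep --eventalign /path/to/eventalign.txt \
        --out_dir /path/to/output \
        --n_processes 4 \
        --chunk_size 1000000 \
        --n_neighbors 1 \
        --min_segment_count 20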

* Output
@@ -43,17 +44,19 @@ data.readcount csv Summary of readcounts per gene.

Output files from ``m6anet-dataprep``.

======================= ========== ========================= ==============================================================================
Argument name Required Default value Description
======================= ========== ========================= ==============================================================================
--input_dir=DIR Yes NA Input directory that contains data.json, data.index, and data.readcount from m6anet-dataprep
--out_dir=DIR Yes NA Output directory for the inference results from m6anet
--model_config=FILE No prod_pooling.toml Model architecture specifications. Please see examples in m6anet/model/configs/model_configs/prod_pooling.toml
--model_state_dict=FILE No prod_pooling_pr_auc.pt Model weights to be used for inference. Please see examples in m6anet/model/model_states/
--batch_size No 64 Number of sites to be loaded each time for inference
--n_processes=NUM No 1 Number of processes to run.
--num_iterations=NUM No 5 Number of times m6anet iterates through each potential m6a sites.
======================= ========== ========================= ==============================================================================
========================== ========== ========================= ==============================================================================
Argument name Required Default value Description
========================== ========== ========================= ==============================================================================
--input_dir=DIR Yes NA Input directory that contains data.json, data.index, and data.readcount from m6anet-dataprep
--out_dir=DIR Yes NA Output directory for the inference results from m6anet
--model_config=FILE No prod_pooling.toml Model architecture specifications. Please see examples in m6anet/model/configs/model_configs/prod_pooling.toml
--model_state_dict=FILE No prod_pooling_pr_auc.pt Model weights to be used for inference. Please see examples in m6anet/model/model_states/
--batch_size=NUM No 64 Number of sites to be loaded each time for inference
--n_processes=NUM No 1 Number of processes to run.
--num_iterations=NUM No 5 Number of times m6anet iterates through each potential m6A site.
--infer_mod_rate No False Whether to output m6A modification stoichiometry for each candidate site
--read_proba_threshold=NUM No 0.033379376 Threshold for each individual read to be considered modified during stoichiometry calculation
========================== ========== ========================= ==============================================================================
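
For example, a minimal sketch of an inference call using the new options (paths are placeholders; ``--read_proba_threshold`` is shown at its default value)::

    m6anet-run_inference --input_dir /path/to/output \
        --out_dir /path/to/inference \
        --n_processes 4 \
        --num_iterations 5 \
        --infer_mod_rate \
        --read_proba_threshold 0.033379376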

* Output

2 changes: 2 additions & 0 deletions docs/source/index.rst
@@ -14,6 +14,8 @@ m6anet requires Python version 3.8 or higher. To install the latest release with
See our :ref:`Installation page <installation>` for details.

To detect m6A modifications from your direct RNA sequencing sample, you can follow the instructions in our :ref:`Quickstart page <quickstart>`.
m6Anet is trained on a dataset sequenced with the SQK-RNA002 kit and has been validated on a dataset from the SQK-RNA001 kit.
Newer pore versions might alter the raw squiggle and affect segmentation and classification results; in such cases m6Anet might need to be retrained.

Contents
--------------------------
16 changes: 13 additions & 3 deletions docs/source/quickstart.rst
@@ -2,7 +2,10 @@

Quick Start
==================================
The demo dataset is provided with the repository under /path/to/m6anet/demo/eventalign.txt
m6Anet requires eventalign.txt from nanopolish::

    nanopolish eventalign --reads reads.fastq --bam reads.sorted.bam --genome transcript.fa --scale-events --signal-index --summary /path/to/summary.txt --threads 50 > /path/to/eventalign.txt

We have also provided a demo dataset in the repository under /path/to/m6anet/demo/eventalign.txt.

First, we need to preprocess the segmented raw signal, in the form of the nanopolish eventalign file, using ``m6anet-dataprep``::
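
    # Hypothetical reconstruction of the command hidden by the collapsed diff,
    # based on the argument table in cmd.rst; paths are placeholders.
    m6anet-dataprep --eventalign /path/to/m6anet/demo/eventalign.txt --out_dir /path/to/output --n_processes 4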

@@ -19,13 +22,20 @@ The output files are stored in ``/path/to/output``:

Now we can run m6anet over our data using m6anet-run_inference::

m6anet-run_inference --input_dir demo_data --out_dir demo_data ---n_processes 4
m6anet-run_inference --input_dir demo_data --out_dir demo_data --infer_mod_rate --n_processes 4

The output file ``demo_data/data.result.csv.gz`` contains the probability of modification at each individual position for each transcript. The file has the following columns:

* ``transcript_id``: The transcript id of the predicted position
* ``transcript_position``: The transcript position of the predicted position
* ``n_reads``: The number of reads for that particular position
* ``probability_modified``: The probability that a given site is modified
* ``kmer``: The 5-mer motif of a given site
* ``mod_ratio``: The estimated percentage of reads at a given site that are modified

The total run time should not exceed 10 minutes on a normal laptop. We also recommend a threshold of 0.9 for selecting m6A sites
based on the ``probability_modified`` column, which can be relaxed at the expense of having lower model precision.
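
For instance, a quick way to apply this cutoff from the command line (a sketch that assumes ``probability_modified`` is the fourth column, as listed above) could be::

    zcat demo_data/data.result.csv.gz | awk -F',' 'NR==1 || $4 >= 0.9' > m6a_sites.csv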

m6Anet also supports pooling over multiple replicates. To do this, simply input multiple folders containing m6anet-dataprep outputs::

    m6anet-run_inference --input_dir demo_data_1 demo_data_2 ... --out_dir demo_data --infer_mod_rate --n_processes 4
4 changes: 2 additions & 2 deletions m6anet/model/configs/training_configs/oversampled.toml
@@ -2,9 +2,9 @@
loss_function_type = "binary_cross_entropy_loss"

[dataset]
root_dir = "/home/christopherhendra/m6anet_test/demo2/"
root_dir = "/home/christopher/hct116_dataprep_5_neighbors/"
min_reads = 20
norm_path = "/home/christopherhendra/m6anet/m6anet/model/norm_factors/norm_dict.joblib"
norm_path = "/home/christopher/m6anet/m6anet/model/norm_factors/norm_dict.joblib"
num_neighboring_features = 1

[dataloader]
5 changes: 3 additions & 2 deletions m6anet/model/model_blocks/__init__.py
@@ -1,10 +1,11 @@
from .blocks import ConcatenateFeatures, DeaggregateNanopolish, Flatten, KmerMultipleEmbedding, Linear
from .blocks import ConcatenateFeatures, DeaggregateNanopolish, Flatten, KmerMultipleEmbedding, Linear, ExtractSignal
from .pooling_blocks import SummaryStatsAggregator, ProbabilityAttention, SummaryStatsProbability, MeanAggregator
from .pooling_blocks import Attention, GatedAttention, KDELayer, KDEAttentionLayer, KDEGatedAttentionLayer
from .pooling_blocks import SigmoidMaxPooling, SigmoidMeanPooling, SigmoidProdPooling


__all__ = [
'ConcatenateFeatures', 'DeaggregateNanopolish', 'Flatten', 'KmerMultipleEmbedding', 'Linear', 'SummaryStatsAggregator', 'ProbabilityAttention', 'SummaryStatsProbability',

'ConcatenateFeatures', 'DeaggregateNanopolish', 'Flatten', 'ExtractSignal', 'KmerMultipleEmbedding', 'Linear', 'SummaryStatsAggregator', 'ProbabilityAttention', 'SummaryStatsProbability',
'Attention', 'GatedAttention', 'KDELayer', 'KDEAttentionLayer', 'KDEGatedAttentionLayer', 'SigmoidMaxPooling', 'SigmoidMeanPooling', 'SigmoidProdPooling', 'MeanAggregator'
]
9 changes: 9 additions & 0 deletions m6anet/model/model_blocks/blocks.py
@@ -46,6 +46,15 @@ def forward(self, x):
return x


class ExtractSignal(Block):

def __init__(self):
super(ExtractSignal, self).__init__()

def forward(self, x):
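# Pass through only the raw signal features stored under the 'X' key of the input dict.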
return x['X']


class DeaggregateNanopolish(Block):

def __init__(self, num_neighboring_features, n_features=3):
28 changes: 27 additions & 1 deletion m6anet/scripts/compute_normalization_factors.py
@@ -3,6 +3,7 @@
import numpy as np
import pandas as pd
from multiprocessing import Pool
from tqdm import tqdm
import joblib
import json

@@ -27,6 +28,29 @@ def get_mean_std(task):
return kmer, mean, stds


def get_mean_std_replicates(task):
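# Accumulate running sums and squared sums of the signal features over every read
# in every replicate file, then derive the per-kmer mean and standard deviation
# from these sufficient statistics.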
kmer, site_df = task
n_reads = 0
sum_X = []
sum_X2 = []
for _, row in site_df.iterrows():
tx_id, tx_pos = row["transcript_id"], row["transcript_position"]
coords, fpaths, segment_number = row["coords"], row["fpath"], row["segment_number"]
for coord, fpath in zip(coords, fpaths):
start_pos, end_pos = coord
features = read_features(os.path.join(fpath, "data.json"), tx_id, tx_pos, start_pos, end_pos)
indices = np.arange(3 * segment_number, 3 * (segment_number + 1))
signals = features[:, indices]
sum_X.append(np.sum(signals, axis=0))
sum_X2.append(np.sum(np.square(signals), axis=0))

n_reads += row["n_reads"]

mean = np.sum(sum_X, axis=0) / n_reads
stds = np.sqrt((np.sum(sum_X2, axis=0) / n_reads) - mean ** 2)
return kmer, mean, stds


def read_kmer(task):
data_fpath, tx_id, tx_pos, start_pos, end_pos = task
with open(data_fpath) as f:
@@ -100,10 +124,12 @@ def main():
info_df = info_df[info_df["set_type"] == 'Train']
index_df = pd.read_csv(os.path.join(input_dir, "data.index"))

index_df["transcript_position"] = index_df["transcript_position"].astype('int')
info_df["transcript_position"] = info_df["transcript_position"].astype('int')

merged_df = info_df.merge(index_df, on=["transcript_id", "transcript_position"])
merged_df = annotate_kmer_information(data_fpath, merged_df, n_processes)


if not os.path.exists(save_dir):
os.mkdir(save_dir)

156 changes: 156 additions & 0 deletions m6anet/scripts/cross_validate.py
@@ -0,0 +1,156 @@


import pandas as pd
import numpy as np
import os
import torch
import datetime
import joblib
import toml
import json
from argparse import ArgumentParser
from argparse import ArgumentDefaultsHelpFormatter
from ..utils.training_utils import train, validate
from ..utils.builder import build_dataloader, build_loss_function
from sklearn.model_selection import GroupKFold, GroupShuffleSplit
from tqdm import tqdm
from ..model.model import MILModel
from copy import deepcopy


def argparser():
parser = ArgumentParser(
formatter_class=ArgumentDefaultsHelpFormatter,
add_help=False
)
parser.add_argument("--cv_dir", default=None, required=True)
parser.add_argument("--model_config", default=None, required=True)
parser.add_argument("--train_config", default=None, required=True)
parser.add_argument("--cv", dest='cv', default=5, type=int)
parser.add_argument("--save_dir", default=None, required=True)
parser.add_argument("--device", default="cuda:2")
parser.add_argument("--lr", default=4e-4, type=float)
parser.add_argument("--lr_scheduler", dest='lr_scheduler', default=None, action='store_true')
parser.add_argument("--seed", default=25, type=int)
parser.add_argument("--epochs", default=100, type=int)
parser.add_argument("--num_workers", default=25, type=int)
parser.add_argument("--save_per_epoch", default=10, type=int)
parser.add_argument("--weight_decay", dest="weight_decay", default=0, type=float)
parser.add_argument("--num_iterations", default=1, type=int)
return parser


def cross_validate(args):

seed = args.seed
np.random.seed(seed)
torch.manual_seed(seed)

cv_dir = args.cv_dir

device = args.device
num_workers = args.num_workers
n_epoch = args.epochs
lr = args.lr
lr_scheduler = args.lr_scheduler
save_per_epoch = args.save_per_epoch
save_dir = args.save_dir
cv = args.cv
weight_decay = args.weight_decay
n_iterations = args.num_iterations

model_config = toml.load(args.model_config)
train_config = toml.load(args.train_config)

if not os.path.exists(save_dir):
os.makedirs(save_dir)

print("Saving training information to {}".format(save_dir))

cv_info = dict()
cv_info["model_config"] = model_config
cv_info["train_config"] = train_config
cv_info["train_config"]["learning_rate"] = lr
cv_info["train_config"]["epochs"] = n_epoch
cv_info["train_config"]["save_per_epoch"] = save_per_epoch
cv_info["train_config"]["weight_decay"] = weight_decay
cv_info["train_config"]["number_of_validation_iterations"] = n_iterations
cv_info["train_config"]["lr_scheduler"] = lr_scheduler
cv_info["train_config"]["cv_folds"] = cv
cv_info["train_config"]["seed"] = seed

with open(os.path.join(save_dir, "cv_info.toml"), 'w') as f:
toml.dump(cv_info, f)

final_test_df = {}
selection_criterions = ['avg_loss', 'roc_auc', 'pr_auc']
columns = ["transcript_id", "transcript_position", "n_reads", "chr", "gene_id", "genomic_position", "kmer", "modification_status", "probability_modified"]

for fold_num in range(1, cv + 1):
fold_dir_save = os.path.join(save_dir, str(fold_num))

if not os.path.exists(fold_dir_save):
os.mkdir(fold_dir_save)

print("Begin running cross validation for fold number {} for a total of {} folds".format( fold_num, cv))

model_config_copy, train_config_copy = deepcopy(model_config), deepcopy(train_config)
fold_dir = os.path.join(cv_dir, str(fold_num))
train_config_copy["dataset"]["site_info"] = fold_dir
train_config_copy["dataset"]["norm_path"] = os.path.join(fold_dir, "norm_dict.joblib")

model = MILModel(model_config_copy).to(device)
train_dl, val_dl, test_dl = build_dataloader(train_config_copy, num_workers)  # use the fold-specific config copy

optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

criterion = build_loss_function(train_config_copy['loss_function'])

train_results, val_results = train(model, train_dl, val_dl, optimizer, n_epoch, device,
criterion, save_dir=fold_dir_save,
save_per_epoch=save_per_epoch, n_iterations=n_iterations)

joblib.dump(train_results, os.path.join(fold_dir_save, "train_results.joblib"))
joblib.dump(val_results, os.path.join(fold_dir_save, "val_results.joblib"))

for selection_criterion in selection_criterions:
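# Select the checkpoint (saved every `save_per_epoch` epochs) with the best
# validation metric: lowest for avg_loss, highest for roc_auc and pr_auc.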
val_loss = [val_results[selection_criterion][i] for i in range (0, len(val_results[selection_criterion]), save_per_epoch)]

if selection_criterion == 'avg_loss':
best_model = (np.argmin(val_loss) + 1) * save_per_epoch
else:
best_model = (np.argmax(val_loss) + 1) * save_per_epoch

model.load_state_dict(torch.load(os.path.join(fold_dir_save, "model_states", str(best_model), "model_states.pt")))
test_results = validate(model, test_dl, device, criterion, n_iterations)
print("Criteria: {criteria} \t"
"Compute time: {compute_time:.3f}".format(criteria=selection_criterion, compute_time=test_results["compute_time"]))
print("Test Loss: {loss:.3f} \t"
"Test ROC AUC: {roc_auc:.3f} \t "
"Test PR AUC: {pr_auc:.3f}".format(loss=test_results["avg_loss"],
roc_auc=test_results["roc_auc"],
pr_auc=test_results["pr_auc"]))
print("=====================================")

joblib.dump(test_results, os.path.join(save_dir, "test_results_{}.joblib".format(selection_criterion)))
test_df = deepcopy(test_dl.dataset.data_info)
test_df.loc[:, "probability_modified"] = np.mean(test_results["y_pred"], axis=0)
test_df = test_df[columns]

if selection_criterion not in final_test_df:
final_test_df[selection_criterion] = [test_df]
else:
final_test_df[selection_criterion].append(test_df)

for selection_criterion in selection_criterions:
pd.concat(final_test_df[selection_criterion]).reset_index(drop=True).to_csv(os.path.join(save_dir, "test_results_{}.csv.gz".format(selection_criterion)),
index=False)


def main():
args = argparser().parse_args()
cross_validate(args)


if __name__ == '__main__':
main()
