From 19aacbccd172c4d2fc5a69c0e33a44cc8d5b767f Mon Sep 17 00:00:00 2001
From: Andrew Tritt
Date: Fri, 23 Jun 2023 20:40:07 -0700
Subject: [PATCH] Add ability to classify bins (#40)

* Set up code for bins
  - Add logic for loading contigs vs bins confidence models
  - Start logic for taking weighted average of sequences within a bin
  - Fix bug in batching of chunks within a sequence
* Update placeholders for confidence model files
* Update TOML with new package data
* Add NN model to TOML
* Update PR template links, fix #33
* Update OSF location and checksum
* Fix check for downloaded deployment package
* Drop sparse tensor use for arm64 compatibility
* Remove use of sparse tensors for collapsing sequences
* Add support for passing in multiple fasta files
* Fix ruff issues
* Infer bins/contigs from header
* Update tests
* Fix sequence tests
* Update test file
* Update filtered classification for bins
* Remove bad flag
---
 .github/pull_request_template.md                 |  4 +-
 .github/workflows/run_tests.yml                  | 22 +++++-
 data/small.raw.csv                               |  6 +-
 data/small.seqs.raw.csv                          |  4 +
 data/small.seqs.tax.0.05.csv                     |  4 +
 data/small.tax.0.05.csv                          |  6 +-
 pyproject.toml                                   |  5 +-
 src/gtnet/classify.py                            |  9 ++-
 src/gtnet/deploy_pkg/{ => bins}/class.roc.npz    |  0
 src/gtnet/deploy_pkg/{ => bins}/class.score.pt   |  0
 src/gtnet/deploy_pkg/{ => bins}/domain.roc.npz   |  0
 src/gtnet/deploy_pkg/{ => bins}/domain.score.pt  |  0
 src/gtnet/deploy_pkg/{ => bins}/family.roc.npz   |  0
 src/gtnet/deploy_pkg/{ => bins}/family.score.pt  |  0
 src/gtnet/deploy_pkg/{ => bins}/genus.roc.npz    |  0
 src/gtnet/deploy_pkg/{ => bins}/genus.score.pt   |  0
 src/gtnet/deploy_pkg/{ => bins}/order.roc.npz    |  0
 src/gtnet/deploy_pkg/{ => bins}/order.score.pt   |  0
 src/gtnet/deploy_pkg/{ => bins}/phylum.roc.npz   |  0
 src/gtnet/deploy_pkg/{ => bins}/phylum.score.pt  |  0
 src/gtnet/deploy_pkg/{ => bins}/species.roc.npz  |  0
 src/gtnet/deploy_pkg/{ => bins}/species.score.pt |  0
 src/gtnet/deploy_pkg/contigs/class.roc.npz       |  0
 src/gtnet/deploy_pkg/contigs/class.score.pt      |  0
 src/gtnet/deploy_pkg/contigs/domain.roc.npz      |  0
 src/gtnet/deploy_pkg/contigs/domain.score.pt     |  0
 src/gtnet/deploy_pkg/contigs/family.roc.npz      |  0
 src/gtnet/deploy_pkg/contigs/family.score.pt     |  0
 src/gtnet/deploy_pkg/contigs/genus.roc.npz       |  0
 src/gtnet/deploy_pkg/contigs/genus.score.pt      |  0
 src/gtnet/deploy_pkg/contigs/order.roc.npz       |  0
 src/gtnet/deploy_pkg/contigs/order.score.pt      |  0
 src/gtnet/deploy_pkg/contigs/phylum.roc.npz      |  0
 src/gtnet/deploy_pkg/contigs/phylum.score.pt     |  0
 src/gtnet/deploy_pkg/contigs/species.roc.npz     |  0
 src/gtnet/deploy_pkg/contigs/species.score.pt    |  0
 src/gtnet/filter.py                              | 13 +++-
 src/gtnet/predict.py                             | 75 +++++++++++++++----
 src/gtnet/utils.py                               | 30 ++++----
 39 files changed, 127 insertions(+), 51 deletions(-)
 create mode 100644 data/small.seqs.raw.csv
 create mode 100644 data/small.seqs.tax.0.05.csv
 rename src/gtnet/deploy_pkg/{ => bins}/class.roc.npz (100%)
 rename src/gtnet/deploy_pkg/{ => bins}/class.score.pt (100%)
 rename src/gtnet/deploy_pkg/{ => bins}/domain.roc.npz (100%)
 rename src/gtnet/deploy_pkg/{ => bins}/domain.score.pt (100%)
 rename src/gtnet/deploy_pkg/{ => bins}/family.roc.npz (100%)
 rename src/gtnet/deploy_pkg/{ => bins}/family.score.pt (100%)
 rename src/gtnet/deploy_pkg/{ => bins}/genus.roc.npz (100%)
 rename src/gtnet/deploy_pkg/{ => bins}/genus.score.pt (100%)
 rename src/gtnet/deploy_pkg/{ => bins}/order.roc.npz (100%)
 rename src/gtnet/deploy_pkg/{ => bins}/order.score.pt (100%)
 rename src/gtnet/deploy_pkg/{ => bins}/phylum.roc.npz (100%)
 rename src/gtnet/deploy_pkg/{ => bins}/phylum.score.pt (100%)
 rename src/gtnet/deploy_pkg/{ => bins}/species.roc.npz (100%)
 rename src/gtnet/deploy_pkg/{ => bins}/species.score.pt (100%)
 create mode 100644 src/gtnet/deploy_pkg/contigs/class.roc.npz
 create mode 100644 src/gtnet/deploy_pkg/contigs/class.score.pt
 create mode 100644 src/gtnet/deploy_pkg/contigs/domain.roc.npz
 create mode 100644 src/gtnet/deploy_pkg/contigs/domain.score.pt
 create mode 100644 src/gtnet/deploy_pkg/contigs/family.roc.npz
 create mode 100644 src/gtnet/deploy_pkg/contigs/family.score.pt
 create mode 100644 src/gtnet/deploy_pkg/contigs/genus.roc.npz
 create mode 100644 src/gtnet/deploy_pkg/contigs/genus.score.pt
 create mode 100644 src/gtnet/deploy_pkg/contigs/order.roc.npz
 create mode 100644 src/gtnet/deploy_pkg/contigs/order.score.pt
 create mode 100644 src/gtnet/deploy_pkg/contigs/phylum.roc.npz
 create mode 100644 src/gtnet/deploy_pkg/contigs/phylum.score.pt
 create mode 100644 src/gtnet/deploy_pkg/contigs/species.roc.npz
 create mode 100644 src/gtnet/deploy_pkg/contigs/species.score.pt
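The heart of this change is how per-sequence network outputs are combined: instead of averaging each sequence's chunk outputs immediately, predict.py now carries raw output sums together with chunk counts, so a bin's prediction is the chunk-weighted mean over every sequence in the bin (see the predict.py hunks below). A minimal sketch of that arithmetic, with made-up tensors standing in for real network outputs:

    import torch

    # Hypothetical per-sequence data: summed (not yet averaged) network
    # outputs, one row per sequence, plus the chunk count behind each sum.
    seq_sums = torch.tensor([[2.0, 6.0],
                             [1.0, 4.0]])
    seq_chunks = torch.tensor([4.0, 2.0])

    # Per-sequence mean -- the --seqs (contig) path:
    per_seq_mean = (seq_sums.T / seq_chunks).T           # [[0.5, 1.5], [0.5, 2.0]]

    # Bin-level mean -- pool sums and counts first, then divide once, which
    # weights each sequence by how many chunks it contributed:
    bin_mean = seq_sums.sum(dim=0) / seq_chunks.sum()    # [0.5, 1.6667]

This is why the old in-loop division (`outputs /= seq_chunks.shape[1] * 2`) is removed below: dividing per sequence first would give every sequence equal weight in the bin regardless of its length.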
diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
index 0f6a52a..9e1c9f1 100644
--- a/.github/pull_request_template.md
+++ b/.github/pull_request_template.md
@@ -10,8 +10,8 @@ Show how to reproduce the new behavior (can be a bug fix or a new feature)
 
 ## Checklist
 - [ ] Did you update CHANGELOG.md with your changes?
-- [ ] Have you checked our [Contributing](https://github.com/hdmf-dev/hdmf-ml/blob/dev/docs/CONTRIBUTING.rst) document?
+- [ ] Have you checked our [Contributing](https://github.com/exabiome/gtnet/blob/dev/docs/CONTRIBUTING.rst) document?
 - [ ] Have you ensured the PR clearly describes the problem and the solution?
 - [ ] Is your contribution compliant with our coding style? This can be checked running `ruff` from the source directory.
-- [ ] Have you checked to ensure that there aren't other open [Pull Requests](https://github.com/hdmf-dev/hdmf-ml/pulls) for the same change?
+- [ ] Have you checked to ensure that there aren't other open [Pull Requests](https://github.com/exabiome/gtnet/pulls) for the same change?
 - [ ] Have you included the relevant issue number using "Fix #XXX" notation where XXX is the issue number? By including "Fix #XXX" you allow GitHub to close issue #XXX when the PR is merged.
diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml
index 716e1c1..76226e9 100644
--- a/.github/workflows/run_tests.yml
+++ b/.github/workflows/run_tests.yml
@@ -47,17 +47,33 @@ jobs:
       run: |
         pytest
 
-    - name: Run GTNet Predict
+    - name: Run GTNet Predict on Bins
       run: |
        gtnet predict data/small.fna > data/small.raw.test.csv
        python -c 'import pandas as pd; pd.testing.assert_frame_equal(pd.read_csv("data/small.raw.csv"), pd.read_csv("data/small.raw.test.csv"), check_exact=False, atol=1e-4)'
 
-    - name: Run GTNet Filter
+    - name: Run GTNet Filter on Bins
      run: |
        gtnet filter --fpr 0.05 data/small.raw.csv > data/small.tax.test.csv
        python -c 'import pandas as pd; pd.testing.assert_frame_equal(pd.read_csv("data/small.tax.0.05.csv"), pd.read_csv("data/small.tax.test.csv"), check_exact=False, atol=1e-4)'
 
-    - name: Run GTNet Classify
+    - name: Run GTNet Classify on Bins
      run: |
        gtnet classify --fpr 0.05 data/small.fna > data/small.tax.test.csv
        python -c 'import pandas as pd; pd.testing.assert_frame_equal(pd.read_csv("data/small.tax.0.05.csv"), pd.read_csv("data/small.tax.test.csv"), check_exact=False, atol=1e-4)'
+
+    - name: Run GTNet Predict on Contigs
+      run: |
+        gtnet predict --seqs data/small.fna > data/small.seqs.raw.test.csv
+        python -c 'import pandas as pd; pd.testing.assert_frame_equal(pd.read_csv("data/small.seqs.raw.csv"), pd.read_csv("data/small.seqs.raw.test.csv"), check_exact=False, atol=1e-4)'
+
+    - name: Run GTNet Filter on Contigs
+      run: |
+        gtnet filter --fpr 0.05 data/small.seqs.raw.csv > data/small.seqs.tax.test.csv
+        python -c 'import pandas as pd; pd.testing.assert_frame_equal(pd.read_csv("data/small.seqs.tax.0.05.csv"), pd.read_csv("data/small.seqs.tax.test.csv"), check_exact=False, atol=1e-4)'
+
+    - name: Run GTNet Classify on Contigs
+      run: |
+        gtnet classify --seqs --fpr 0.05 data/small.fna > data/small.seqs.tax.test.csv
+        python -c 'import pandas as pd; pd.testing.assert_frame_equal(pd.read_csv("data/small.seqs.tax.0.05.csv"), pd.read_csv("data/small.seqs.tax.test.csv"), check_exact=False, atol=1e-4)'
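Each step above compares CI output against a checked-in fixture with a small absolute tolerance rather than exact equality, presumably because predicted probabilities can drift slightly across platforms. The inline `python -c` one-liners, unrolled for readability (same calls, same files as the first step):

    import pandas as pd

    # Equivalent to the assertion in "Run GTNet Predict on Bins" above.
    expected = pd.read_csv("data/small.raw.csv")
    actual = pd.read_csv("data/small.raw.test.csv")
    pd.testing.assert_frame_equal(expected, actual, check_exact=False, atol=1e-4)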
diff --git a/data/small.raw.csv b/data/small.raw.csv
index f8fb75d..b8da834 100644
--- a/data/small.raw.csv
+++ b/data/small.raw.csv
@@ -1,4 +1,2 @@
-ID,domain,domain_prob,phylum,phylum_prob,class,class_prob,order,order_prob,family,family_prob,genus,genus_prob,species,species_prob
-AAAC01000001.1,d__Bacteria,0.92493016,p__Firmicutes,0.9686205,c__Bacilli,0.96250373,o__Bacillales,0.97850364,f__Bacillaceae_G,0.97221136,g__Bacillus_A,0.985249,s__Bacillus_A thuringiensis,0.45970166
-AE011190.1,d__Bacteria,0.7773866,p__Firmicutes,0.7861459,c__Bacilli,0.7775143,o__Bacillales,0.31985036,f__Bacillaceae_G,0.2585766,g__Bacillus_A,0.45360962,s__Bacillus_A thuringiensis,0.19787389
-AE011191.1,d__Bacteria,0.7986983,p__Firmicutes,0.8659906,c__Bacilli,0.86254656,o__Bacillales,0.681745,f__Bacillaceae_G,0.6001527,g__Bacillus_A,0.78373003,s__Bacillus_A thuringiensis,0.3619854
+file,domain,domain_prob,phylum,phylum_prob,class,class_prob,order,order_prob,family,family_prob,genus,genus_prob,species,species_prob
+data/small.fna,d__Bacteria,0.9684392,p__Firmicutes,0.9986811,c__Bacilli,0.96831495,o__Bacillales,0.98408806,f__Bacillaceae_G,0.99721485,g__Bacillus_A,0.9942883,s__Bacillus_A thuringiensis,0.4345817
diff --git a/data/small.seqs.raw.csv b/data/small.seqs.raw.csv
new file mode 100644
index 0000000..baf3674
--- /dev/null
+++ b/data/small.seqs.raw.csv
@@ -0,0 +1,4 @@
+file,ID,domain,domain_prob,phylum,phylum_prob,class,class_prob,order,order_prob,family,family_prob,genus,genus_prob,species,species_prob
+data/small.fna,AAAC01000001.1,d__Bacteria,0.9678016,p__Firmicutes,0.97006994,c__Bacilli,0.9555465,o__Bacillales,0.9729916,f__Bacillaceae_G,0.969557,g__Bacillus_A,0.9896825,s__Bacillus_A thuringiensis,0.45554155
+data/small.fna,AE011190.1,d__Bacteria,0.86854714,p__Firmicutes,0.79800624,c__Bacilli,0.7794117,o__Bacillales,0.3070697,f__Bacillaceae_G,0.27435768,g__Bacillus_A,0.42017764,s__Bacillus_A thuringiensis,0.18835706
+data/small.fna,AE011191.1,d__Bacteria,0.8864407,p__Firmicutes,0.8749427,c__Bacilli,0.85708725,o__Bacillales,0.6319907,f__Bacillaceae_G,0.6114752,g__Bacillus_A,0.7830288,s__Bacillus_A thuringiensis,0.35550776
diff --git a/data/small.seqs.tax.0.05.csv b/data/small.seqs.tax.0.05.csv
new file mode 100644
index 0000000..aad6f73
--- /dev/null
+++ b/data/small.seqs.tax.0.05.csv
@@ -0,0 +1,4 @@
+file,ID,domain,phylum,class,order,family,genus,species
+data/small.fna,AAAC01000001.1,d__Bacteria,p__Firmicutes,c__Bacilli,o__Bacillales,f__Bacillaceae_G,g__Bacillus_A,
+data/small.fna,AE011190.1,,,,,,,
+data/small.fna,AE011191.1,d__Bacteria,,,,,,
diff --git a/data/small.tax.0.05.csv b/data/small.tax.0.05.csv
index 4784669..7e15bb3 100644
--- a/data/small.tax.0.05.csv
+++ b/data/small.tax.0.05.csv
@@ -1,4 +1,2 @@
-ID,domain,phylum,class,order,family,genus,species
-AAAC01000001.1,d__Bacteria,p__Firmicutes,c__Bacilli,o__Bacillales,f__Bacillaceae_G,g__Bacillus_A,
-AE011190.1,,,,,,,
-AE011191.1,d__Bacteria,,,,,,
+file,domain,phylum,class,order,family,genus,species
+data/small.fna,d__Bacteria,p__Firmicutes,c__Bacilli,o__Bacillales,f__Bacillaceae_G,g__Bacillus_A,
diff --git a/pyproject.toml b/pyproject.toml
index c75ba0c..f729963 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,8 +44,9 @@ dependencies = [
 dynamic = ["version"]
 
 [tool.setuptools.package-data]
-gtnet = ["deploy_pkg/*.npz",
-         "deploy_pkg/*.pt",
+gtnet = ["deploy_pkg/{bins,contigs}/*.npz",
+         "deploy_pkg/{bins,contigs}/*.pt",
+         "deploy_pkg/last.pt",
          "deploy_pkg/manifest.json"]
 
 [project.urls]
diff --git a/src/gtnet/classify.py b/src/gtnet/classify.py
index 81f4d92..193789f 100644
--- a/src/gtnet/classify.py
+++ b/src/gtnet/classify.py
@@ -25,7 +25,8 @@ def classify(argv=None):
            "for each sequence")
 
     parser = argparse.ArgumentParser(description=desc, epilog=epi)
-    parser.add_argument('fasta', type=str, help='the Fasta files to do taxonomic classification on')
+    parser.add_argument('fastas', type=str, nargs='+', help='the Fasta files to do taxonomic classification on')
+    parser.add_argument('-s', '--seqs', action='store_true', help='provide classification for sequences')
     parser.add_argument('-c', '--n_chunks', type=int, default=DEFAULT_N_CHUNKS,
                         help='the number of sequence chunks to process at a time')
     parser.add_argument('-o', '--output', type=str, default=None, help='the output file to save classifications to')
@@ -44,13 +45,13 @@ def classify(argv=None):
 
     device = check_device(args)
 
-    model, conf_models, train_conf, vocab, rocs = load_deploy_pkg(for_predict=True, for_filter=True)
+    model, conf_models, train_conf, vocab, rocs = load_deploy_pkg(for_predict=True, for_filter=True, contigs=args.seqs)
 
     window = train_conf['window']
     step = train_conf['step']
 
-    logger.info(f'Getting class predictions for each contig in {args.fasta}')
-    output = run_torchscript_inference(args.fasta, model, conf_models, window, step, vocab,
+    logger.info(f'Getting class predictions for each contig in {",".join(args.fastas)}')
+    output = run_torchscript_inference(args.fastas, model, conf_models, window, step, vocab, seqs=args.seqs,
                                        device=device, logger=logger)
 
     logger.info(f'Getting probability cutoffs for target false-positive rate of {args.fpr}')
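Besides the new `--seqs` flag, both commands now accept multiple FASTA files (`nargs='+'`). A sketch of the equivalent programmatic call, using only the functions shown in this patch; the file names here are placeholders, not files shipped with the package:

    from gtnet.predict import run_torchscript_inference
    from gtnet.utils import load_deploy_pkg

    # Per-sequence (contig) predictions for two placeholder assemblies,
    # mirroring what `gtnet predict --seqs a.fna b.fna` does above.
    model, conf_models, train_conf, vocab = load_deploy_pkg(for_predict=True, contigs=True)
    output = run_torchscript_inference(['a.fna', 'b.fna'], model, conf_models,
                                       train_conf['window'], train_conf['step'],
                                       vocab, seqs=True)
    output.to_csv('predictions.csv')  # a DataFrame indexed by file (and ID for --seqs)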
{",".join(args.fastas)}') + output = run_torchscript_inference(args.fastas, model, conf_models, window, step, vocab, seqs=args.seqs, device=device, logger=logger) logger.info(f'Getting probability cutoffs for target false-positive rate of {args.fpr}') diff --git a/src/gtnet/deploy_pkg/class.roc.npz b/src/gtnet/deploy_pkg/bins/class.roc.npz similarity index 100% rename from src/gtnet/deploy_pkg/class.roc.npz rename to src/gtnet/deploy_pkg/bins/class.roc.npz diff --git a/src/gtnet/deploy_pkg/class.score.pt b/src/gtnet/deploy_pkg/bins/class.score.pt similarity index 100% rename from src/gtnet/deploy_pkg/class.score.pt rename to src/gtnet/deploy_pkg/bins/class.score.pt diff --git a/src/gtnet/deploy_pkg/domain.roc.npz b/src/gtnet/deploy_pkg/bins/domain.roc.npz similarity index 100% rename from src/gtnet/deploy_pkg/domain.roc.npz rename to src/gtnet/deploy_pkg/bins/domain.roc.npz diff --git a/src/gtnet/deploy_pkg/domain.score.pt b/src/gtnet/deploy_pkg/bins/domain.score.pt similarity index 100% rename from src/gtnet/deploy_pkg/domain.score.pt rename to src/gtnet/deploy_pkg/bins/domain.score.pt diff --git a/src/gtnet/deploy_pkg/family.roc.npz b/src/gtnet/deploy_pkg/bins/family.roc.npz similarity index 100% rename from src/gtnet/deploy_pkg/family.roc.npz rename to src/gtnet/deploy_pkg/bins/family.roc.npz diff --git a/src/gtnet/deploy_pkg/family.score.pt b/src/gtnet/deploy_pkg/bins/family.score.pt similarity index 100% rename from src/gtnet/deploy_pkg/family.score.pt rename to src/gtnet/deploy_pkg/bins/family.score.pt diff --git a/src/gtnet/deploy_pkg/genus.roc.npz b/src/gtnet/deploy_pkg/bins/genus.roc.npz similarity index 100% rename from src/gtnet/deploy_pkg/genus.roc.npz rename to src/gtnet/deploy_pkg/bins/genus.roc.npz diff --git a/src/gtnet/deploy_pkg/genus.score.pt b/src/gtnet/deploy_pkg/bins/genus.score.pt similarity index 100% rename from src/gtnet/deploy_pkg/genus.score.pt rename to src/gtnet/deploy_pkg/bins/genus.score.pt diff --git a/src/gtnet/deploy_pkg/order.roc.npz b/src/gtnet/deploy_pkg/bins/order.roc.npz similarity index 100% rename from src/gtnet/deploy_pkg/order.roc.npz rename to src/gtnet/deploy_pkg/bins/order.roc.npz diff --git a/src/gtnet/deploy_pkg/order.score.pt b/src/gtnet/deploy_pkg/bins/order.score.pt similarity index 100% rename from src/gtnet/deploy_pkg/order.score.pt rename to src/gtnet/deploy_pkg/bins/order.score.pt diff --git a/src/gtnet/deploy_pkg/phylum.roc.npz b/src/gtnet/deploy_pkg/bins/phylum.roc.npz similarity index 100% rename from src/gtnet/deploy_pkg/phylum.roc.npz rename to src/gtnet/deploy_pkg/bins/phylum.roc.npz diff --git a/src/gtnet/deploy_pkg/phylum.score.pt b/src/gtnet/deploy_pkg/bins/phylum.score.pt similarity index 100% rename from src/gtnet/deploy_pkg/phylum.score.pt rename to src/gtnet/deploy_pkg/bins/phylum.score.pt diff --git a/src/gtnet/deploy_pkg/species.roc.npz b/src/gtnet/deploy_pkg/bins/species.roc.npz similarity index 100% rename from src/gtnet/deploy_pkg/species.roc.npz rename to src/gtnet/deploy_pkg/bins/species.roc.npz diff --git a/src/gtnet/deploy_pkg/species.score.pt b/src/gtnet/deploy_pkg/bins/species.score.pt similarity index 100% rename from src/gtnet/deploy_pkg/species.score.pt rename to src/gtnet/deploy_pkg/bins/species.score.pt diff --git a/src/gtnet/deploy_pkg/contigs/class.roc.npz b/src/gtnet/deploy_pkg/contigs/class.roc.npz new file mode 100644 index 0000000..e69de29 diff --git a/src/gtnet/deploy_pkg/contigs/class.score.pt b/src/gtnet/deploy_pkg/contigs/class.score.pt new file mode 100644 index 0000000..e69de29 diff --git 
diff --git a/src/gtnet/deploy_pkg/contigs/domain.roc.npz b/src/gtnet/deploy_pkg/contigs/domain.roc.npz
new file mode 100644
index 0000000..e69de29
diff --git a/src/gtnet/deploy_pkg/contigs/domain.score.pt b/src/gtnet/deploy_pkg/contigs/domain.score.pt
new file mode 100644
index 0000000..e69de29
diff --git a/src/gtnet/deploy_pkg/contigs/family.roc.npz b/src/gtnet/deploy_pkg/contigs/family.roc.npz
new file mode 100644
index 0000000..e69de29
diff --git a/src/gtnet/deploy_pkg/contigs/family.score.pt b/src/gtnet/deploy_pkg/contigs/family.score.pt
new file mode 100644
index 0000000..e69de29
diff --git a/src/gtnet/deploy_pkg/contigs/genus.roc.npz b/src/gtnet/deploy_pkg/contigs/genus.roc.npz
new file mode 100644
index 0000000..e69de29
diff --git a/src/gtnet/deploy_pkg/contigs/genus.score.pt b/src/gtnet/deploy_pkg/contigs/genus.score.pt
new file mode 100644
index 0000000..e69de29
diff --git a/src/gtnet/deploy_pkg/contigs/order.roc.npz b/src/gtnet/deploy_pkg/contigs/order.roc.npz
new file mode 100644
index 0000000..e69de29
diff --git a/src/gtnet/deploy_pkg/contigs/order.score.pt b/src/gtnet/deploy_pkg/contigs/order.score.pt
new file mode 100644
index 0000000..e69de29
diff --git a/src/gtnet/deploy_pkg/contigs/phylum.roc.npz b/src/gtnet/deploy_pkg/contigs/phylum.roc.npz
new file mode 100644
index 0000000..e69de29
diff --git a/src/gtnet/deploy_pkg/contigs/phylum.score.pt b/src/gtnet/deploy_pkg/contigs/phylum.score.pt
new file mode 100644
index 0000000..e69de29
diff --git a/src/gtnet/deploy_pkg/contigs/species.roc.npz b/src/gtnet/deploy_pkg/contigs/species.roc.npz
new file mode 100644
index 0000000..e69de29
diff --git a/src/gtnet/deploy_pkg/contigs/species.score.pt b/src/gtnet/deploy_pkg/contigs/species.score.pt
new file mode 100644
index 0000000..e69de29
diff --git a/src/gtnet/filter.py b/src/gtnet/filter.py
index 0d50410..47d95e1 100644
--- a/src/gtnet/filter.py
+++ b/src/gtnet/filter.py
@@ -53,11 +53,18 @@ def filter(argv=None):
 
     logger = get_logger()
 
-    rocs = load_deploy_pkg(for_filter=True)
+    df = pd.read_csv(args.csv)
 
-    cutoffs = get_cutoffs(rocs, args.fpr)
+    if 'ID' in df.columns:
+        seqs = True
+        df = df.set_index(['file', 'ID'])
+    else:
+        df = df.set_index('file')
+        seqs = False
+
+    rocs = load_deploy_pkg(for_filter=True, contigs=seqs)
 
-    df = pd.read_csv(args.csv, index_col='ID')
+    cutoffs = get_cutoffs(rocs, args.fpr)
 
     output = filter_predictions(df, cutoffs)
     write_csv(output, args)
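This is the "Infer bins/contigs from header" change: filter needs no mode flag because the two raw-CSV schemas differ (per-bin output is indexed by file alone; per-contig output by file and ID, as in the data fixtures above). A sketch of that dispatch, reading one of the fixture files from this patch:

    import pandas as pd

    df = pd.read_csv("data/small.seqs.raw.csv")  # any `gtnet predict` output

    # Presence of an 'ID' column distinguishes per-contig from per-bin output.
    if 'ID' in df.columns:
        df = df.set_index(['file', 'ID'])   # contig mode: one row per sequence
        seqs = True                         # -> load the contigs confidence ROCs
    else:
        df = df.set_index('file')           # bin mode: one row per input file
        seqs = False                        # -> load the bins confidence ROCs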
diff --git a/src/gtnet/predict.py b/src/gtnet/predict.py
index abbd497..92630e1 100644
--- a/src/gtnet/predict.py
+++ b/src/gtnet/predict.py
@@ -1,4 +1,5 @@
 import argparse
+from collections import Counter
 import logging
 from time import time
 
@@ -28,7 +29,8 @@ def predict(argv=None):
            "'filter' command. See the 'classify' command for getting pre-filtered classifications")
 
     parser = argparse.ArgumentParser(description=desc, epilog=epi)
-    parser.add_argument('fasta', type=str, help='the Fasta files to do taxonomic classification on')
+    parser.add_argument('fastas', type=str, nargs='+', help='the Fasta files to do taxonomic classification on')
+    parser.add_argument('-s', '--seqs', action='store_true', help='provide classification for sequences')
     parser.add_argument('-c', '--n_chunks', type=int, default=DEFAULT_N_CHUNKS,
                         help='the number of sequence chunks to process at a time')
     parser.add_argument('-o', '--output', type=str, default=None, help='the output file to save classifications to')
@@ -46,12 +48,13 @@ def predict(argv=None):
 
     device = check_device(args)
 
-    model, conf_models, train_conf, vocab = load_deploy_pkg(for_predict=True)
+    model, conf_models, train_conf, vocab = load_deploy_pkg(for_predict=True, contigs=args.seqs)
 
     window = train_conf['window']
     step = train_conf['step']
 
-    output = run_torchscript_inference(args.fasta, model, conf_models, window, step, vocab, device=device)
+    output = run_torchscript_inference(args.fastas, model, conf_models, window, step, vocab, seqs=args.seqs,
+                                       device=device, logger=logger)
 
     # write out data
     write_csv(output, args)
@@ -60,14 +63,14 @@ def predict(argv=None):
     logger.info(f'Took {after - before:.1f} seconds')
 
 
-def run_torchscript_inference(fasta, model, conf_models, window, step, vocab, n_chunks=DEFAULT_N_CHUNKS,
+def run_torchscript_inference(fastas, model, conf_models, window, step, vocab, seqs=False, n_chunks=DEFAULT_N_CHUNKS,
                               device=torch.device('cpu'), logger=None):
     """Run Torchscript inference
 
     Parameters
     ----------
-    fasta : str
+    fastas : str
         The path to the Fasta file with sequences to do inference on
 
     model : RecursiveScriptModule
@@ -101,7 +104,7 @@ def run_torchscript_inference(fasta, model, conf_models, window, step, vocab, n_
         logger.setLevel(logging.CRITICAL)
 
     encoder = FastaSequenceEncoder(window, step, vocab=vocab, device=device)
-    reader = FastaReader(encoder, fasta)
+    reader = FastaReader(encoder, *fastas)
 
     model = model.to(device)
 
@@ -109,12 +112,14 @@ def run_torchscript_inference(fasta, model, conf_models, window, step, vocab, n_
 
     seqnames = list()
     lengths = list()
+    total_chunks = list()
     filepaths = list()
     aggregated = list()
 
     torch.set_grad_enabled(False)
-    logger.info(f'Calculating classifications for all sequences in {fasta}')
+
+    logger.info(f'Calculating classifications for all sequences in {", ".join(fastas)}')
     for file_path, seq_name, seq_len, seq_chunks in reader:
         seqnames.append(seq_name)
         lengths.append(seq_len)
@@ -124,24 +129,64 @@ def run_torchscript_inference(fasta, model, conf_models, window, step, vocab, n_
                      f'{seq_chunks.shape[1] * 2} chunks, {lengths[-1]} bases'))
         outputs = torch.zeros(output_size, device=device)  # the output from the network for a single sequence
         # sum network outputs from all chunks
-        for s in range(0, len(seq_chunks), n_chunks):
+        for s in range(0, seq_chunks.shape[1], n_chunks):
             e = s + n_chunks
             outputs += model(seq_chunks[0, s:e]).sum(dim=0)
             outputs += model(seq_chunks[1, s:e]).sum(dim=0)
-        # divide by the number of seq_chunks we processed to get a mean output
-        outputs /= (seq_chunks.shape[1] * 2)
-        del seq_chunks
         aggregated.append(outputs)
+        total_chunks.append(seq_chunks.shape[1] * 2.)
 
-    lengths = torch.tensor(lengths, device=device)
+    del seq_chunks
+    total_chunks = torch.tensor(total_chunks, device=device)
 
     # aggregate everything we just pulled from the fasta file
     all_levels_aggregated = torch.row_stack(aggregated)
     del aggregated
 
-    output_data = {'ID': seqnames}
+    if not seqs:
+        logger.info('Calculating classifications for bins')
+
+        ctr = Counter(filepaths)
+        n_ctgs = list(ctr.values())
+        filepaths = list(ctr.keys())
+
+        max_len = list()
+        l50 = list()
+        lengths = torch.tensor(lengths)
+
+        tmp_chunks = list()
+        tmp_aggregated = list()
+
+        s = 0
+        for n in n_ctgs:
+            e = s + n
+            tmp_lens = lengths[s:e].sort(descending=True).values
+            max_len.append(tmp_lens[0])
+            csum = torch.cumsum(tmp_lens, 0)
+            l50.append(tmp_lens[torch.where(csum > (csum[-1] * 0.5))[0][0]])
+            tmp_aggregated.append(all_levels_aggregated[s:e].sum(axis=0))
+            tmp_chunks.append(total_chunks[s:e].sum())
+            s = e
+
+        del total_chunks
+        total_chunks = torch.tensor(tmp_chunks)
+        del tmp_chunks
+
+        del all_levels_aggregated
+        all_levels_aggregated = torch.row_stack(tmp_aggregated)
+        del tmp_aggregated
+
+        features = torch.tensor([n_ctgs, l50, max_len], device=device).T
+        output_data = {'file': filepaths}
+    else:
+        features = torch.tensor(lengths, device=device)[:, None]
+        output_data = {'file': filepaths, 'ID': seqnames}
+
+    all_levels_aggregated = ((all_levels_aggregated.T) / total_chunks).T
+
+    indices = list(output_data.keys())
 
     s = 0
     for lvl, e in zip(model.levels, model.parse):
@@ -163,13 +208,13 @@ def run_torchscript_inference(fasta, model, conf_models, window, step, vocab, n_
 
         # build input matrix for confidence model
         logger.debug('Calculating confidence probabilities')
-        conf_input = torch.column_stack([lengths, maxprobs])
+        conf_input = torch.column_stack([features, maxprobs])
         output_data[f'{lvl}_prob'] = conf_model(conf_input).cpu().numpy().squeeze()
 
         # set next left bound for all_levels_aggregated
         s = e
 
-    output = pd.DataFrame(output_data).set_index('ID')
+    output = pd.DataFrame(output_data).set_index(indices)
 
     return output
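For bins, the confidence model no longer receives a single sequence length. Each bin is instead described by three assembly statistics: the number of contigs, the diff's `l50` (the contig length at which the descending cumulative length first exceeds half the bin total, an N50-style statistic), and the longest contig length. A standalone sketch of that feature computation, with made-up contig lengths:

    import torch

    # Hypothetical contig lengths for one bin.
    lengths = torch.tensor([900., 500., 300., 100.])

    tmp_lens = lengths.sort(descending=True).values
    csum = torch.cumsum(tmp_lens, 0)              # 900, 1400, 1700, 1800
    # First length at which the running total passes half the bin size.
    l50 = tmp_lens[torch.where(csum > (csum[-1] * 0.5))[0][0]]   # tensor(500.)
    max_len = tmp_lens[0]                         # tensor(900.)
    n_ctgs = len(tmp_lens)                        # 4

    # One feature row per bin, matching `features` in the diff above.
    features = torch.tensor([[n_ctgs, l50.item(), max_len.item()]])

Passed through `torch.column_stack([features, maxprobs])`, this gives the confidence model one row per bin, just as a single length did per sequence before.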
This will only happen on the first invocation " "of gtnet predict or gtnet classify") @@ -71,7 +70,7 @@ def path(self, path): @property def manifest(self): if self._manifest is None: - with open(self.path('manifest.json'), 'r') as f: + with open(self.path(self._manifest_name), 'r') as f: self._manifest = json.load(f) return self._manifest @@ -82,22 +81,25 @@ def __setitem__(self, key, val): self.manifest[key] = val -def load_deploy_pkg(for_predict=False, for_filter=False): +def load_deploy_pkg(for_predict=False, for_filter=False, contigs=False): if not (for_predict or for_filter): for_predict = True for_filter = True pkg = DeployPkg() + key = 'contigs' if contigs else 'bins' ret = list() if for_predict: tmp_conf_model = dict() - for lvl_dat in pkg['conf_model']: - lvl_dat['taxa'] = np.array(lvl_dat['taxa']) + for cm_data, taxa_data in zip(pkg['conf_model'][key], pkg['taxa']): + if cm_data['level'] != taxa_data['level']: + raise ValueError("Taxonomic levels are out of order in manifest file") + cm_data['taxa'] = np.array(taxa_data['taxa']) - lvl_dat['model'] = torch.jit.load(pkg.path(lvl_dat.pop('model'))) + cm_data['model'] = torch.jit.load(pkg.path(cm_data.pop('model'))) - tmp_conf_model[lvl_dat['level']] = lvl_dat + tmp_conf_model[cm_data['level']] = cm_data ret.append(torch.jit.load(pkg.path(pkg['nn_model']))) ret.append(tmp_conf_model) @@ -106,8 +108,8 @@ def load_deploy_pkg(for_predict=False, for_filter=False): if for_filter: tmp_roc = dict() - for lvl_dat in pkg['conf_model']: - tmp_roc[lvl_dat['level']] = np.load(pkg.path(lvl_dat['roc'])) + for cm_data in pkg['conf_model'][key]: + tmp_roc[cm_data['level']] = np.load(pkg.path(cm_data['roc'])) ret.append(tmp_roc) return tuple(ret) if len(ret) > 1 else ret[0]