Add ability to classify bins (#40)
* Set up code for bins

- Add logic for loading contigs vs bins confidence models
- Start logic for taking weighted average of sequences within a bin (see the sketch after this commit list)
- Fix bug in batching of chunks within a sequence

* Update placeholders for confidence model files

* Update TOML with new package data

* Add NN model to TOML

* Update PR template links, fix #33

* Update OSF location and checksum

* Fix check for downloaded deployment package

* Drop sparse tensor use for arm64 compatibility

* Remove use of sparse tensors for collapsing sequences

* Add support for passing in multiple fasta files

* Fix ruff issues

* Fix ruff issues

* Infer bins/contigs from header

* Update tests

* Fix sequence tests

* Update test file

* Update filtered classification for bins

* Remove bad flag
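The weighted averaging of sequences within a bin works as sketched below. This is a minimal illustration, not the shipped implementation: the function name and signature are invented, but the arithmetic mirrors the predict.py diff in this commit, where each sequence contributes the sum of its chunk-level network outputs and the bin-level mean divides by the bin's total chunk count.

```python
import torch

def bin_mean_output(seq_outputs: list, chunk_counts: list) -> torch.Tensor:
    """Chunk-weighted mean of per-sequence summed network outputs for one bin."""
    total = torch.row_stack(seq_outputs).sum(dim=0)  # sum outputs over all sequences in the bin
    return total / sum(chunk_counts)                 # normalize by the bin's total chunk count
```

Weighting by chunk count means longer contigs (more chunks) dominate the bin-level prediction, which matches the intent of classifying a bin by its aggregate sequence content.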
ajtritt authored Jun 24, 2023
1 parent b7762f7 commit 19aacbc
Showing 39 changed files with 127 additions and 51 deletions.
4 changes: 2 additions & 2 deletions .github/pull_request_template.md
@@ -10,8 +10,8 @@ Show how to reproduce the new behavior (can be a bug fix or a new feature)
## Checklist

- [ ] Did you update CHANGELOG.md with your changes?
-- [ ] Have you checked our [Contributing](https://github.com/hdmf-dev/hdmf-ml/blob/dev/docs/CONTRIBUTING.rst) document?
+- [ ] Have you checked our [Contributing](https://github.com/exabiome/gtnet/blob/dev/docs/CONTRIBUTING.rst) document?
- [ ] Have you ensured the PR clearly describes the problem and the solution?
- [ ] Is your contribution compliant with our coding style? This can be checked running `ruff` from the source directory.
-- [ ] Have you checked to ensure that there aren't other open [Pull Requests](https://github.com/hdmf-dev/hdmf-ml/pulls) for the same change?
+- [ ] Have you checked to ensure that there aren't other open [Pull Requests](https://github.com/exabiome/gtnet/pulls) for the same change?
- [ ] Have you included the relevant issue number using "Fix #XXX" notation where XXX is the issue number? By including "Fix #XXX" you allow GitHub to close issue #XXX when the PR is merged.
22 changes: 19 additions & 3 deletions .github/workflows/run_tests.yml
@@ -47,17 +47,33 @@ jobs:
         run: |
           pytest
-      - name: Run GTNet Predict
+      - name: Run GTNet Predict on Bins
         run: |
           gtnet predict data/small.fna > data/small.raw.test.csv
           python -c 'import pandas as pd; pd.testing.assert_frame_equal(pd.read_csv("data/small.raw.csv"), pd.read_csv("data/small.raw.test.csv"), check_exact=False, atol=1e-4)'
-      - name: Run GTNet Filter
+      - name: Run GTNet Filter on Bins
         run: |
           gtnet filter --fpr 0.05 data/small.raw.csv > data/small.tax.test.csv
           python -c 'import pandas as pd; pd.testing.assert_frame_equal(pd.read_csv("data/small.tax.0.05.csv"), pd.read_csv("data/small.tax.test.csv"), check_exact=False, atol=1e-4)'
-      - name: Run GTNet Classify
+      - name: Run GTNet Classify on Bins
         run: |
           gtnet classify --fpr 0.05 data/small.fna > data/small.tax.test.csv
           python -c 'import pandas as pd; pd.testing.assert_frame_equal(pd.read_csv("data/small.tax.0.05.csv"), pd.read_csv("data/small.tax.test.csv"), check_exact=False, atol=1e-4)'
+      - name: Run GTNet Predict on Contigs
+        run: |
+          gtnet predict --seqs data/small.fna > data/small.seqs.raw.test.csv
+          python -c 'import pandas as pd; pd.testing.assert_frame_equal(pd.read_csv("data/small.seqs.raw.csv"), pd.read_csv("data/small.seqs.raw.test.csv"), check_exact=False, atol=1e-4)'
+      - name: Run GTNet Filter on Contigs
+        run: |
+          gtnet filter --fpr 0.05 data/small.seqs.raw.csv > data/small.seqs.tax.test.csv
+          python -c 'import pandas as pd; pd.testing.assert_frame_equal(pd.read_csv("data/small.seqs.tax.0.05.csv"), pd.read_csv("data/small.seqs.tax.test.csv"), check_exact=False, atol=1e-4)'
+      - name: Run GTNet Classify on Contigs
+        run: |
+          gtnet classify --seqs --fpr 0.05 data/small.fna > data/small.seqs.tax.test.csv
+          python -c 'import pandas as pd; pd.testing.assert_frame_equal(pd.read_csv("data/small.seqs.tax.0.05.csv"), pd.read_csv("data/small.seqs.tax.test.csv"), check_exact=False, atol=1e-4)'
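For reference, the two modes these steps exercise (flags verbatim from the diff): `gtnet classify --fpr 0.05 data/small.fna` emits one classification per input file (bin mode, now the default), while `gtnet classify --seqs --fpr 0.05 data/small.fna` emits one classification per sequence (the previous per-contig behavior). Note that the last step's name was "Run GTNet Classify on Bins" in the original commit; it uses `--seqs`, so it is rendered above as "on Contigs".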
6 changes: 2 additions & 4 deletions data/small.raw.csv
@@ -1,4 +1,2 @@
-ID,domain,domain_prob,phylum,phylum_prob,class,class_prob,order,order_prob,family,family_prob,genus,genus_prob,species,species_prob
-AAAC01000001.1,d__Bacteria,0.92493016,p__Firmicutes,0.9686205,c__Bacilli,0.96250373,o__Bacillales,0.97850364,f__Bacillaceae_G,0.97221136,g__Bacillus_A,0.985249,s__Bacillus_A thuringiensis,0.45970166
-AE011190.1,d__Bacteria,0.7773866,p__Firmicutes,0.7861459,c__Bacilli,0.7775143,o__Bacillales,0.31985036,f__Bacillaceae_G,0.2585766,g__Bacillus_A,0.45360962,s__Bacillus_A thuringiensis,0.19787389
-AE011191.1,d__Bacteria,0.7986983,p__Firmicutes,0.8659906,c__Bacilli,0.86254656,o__Bacillales,0.681745,f__Bacillaceae_G,0.6001527,g__Bacillus_A,0.78373003,s__Bacillus_A thuringiensis,0.3619854
+file,domain,domain_prob,phylum,phylum_prob,class,class_prob,order,order_prob,family,family_prob,genus,genus_prob,species,species_prob
+data/small.fna,d__Bacteria,0.9684392,p__Firmicutes,0.9986811,c__Bacilli,0.96831495,o__Bacillales,0.98408806,f__Bacillaceae_G,0.99721485,g__Bacillus_A,0.9942883,s__Bacillus_A thuringiensis,0.4345817
4 changes: 4 additions & 0 deletions data/small.seqs.raw.csv
@@ -0,0 +1,4 @@
+file,ID,domain,domain_prob,phylum,phylum_prob,class,class_prob,order,order_prob,family,family_prob,genus,genus_prob,species,species_prob
+data/small.fna,AAAC01000001.1,d__Bacteria,0.9678016,p__Firmicutes,0.97006994,c__Bacilli,0.9555465,o__Bacillales,0.9729916,f__Bacillaceae_G,0.969557,g__Bacillus_A,0.9896825,s__Bacillus_A thuringiensis,0.45554155
+data/small.fna,AE011190.1,d__Bacteria,0.86854714,p__Firmicutes,0.79800624,c__Bacilli,0.7794117,o__Bacillales,0.3070697,f__Bacillaceae_G,0.27435768,g__Bacillus_A,0.42017764,s__Bacillus_A thuringiensis,0.18835706
+data/small.fna,AE011191.1,d__Bacteria,0.8864407,p__Firmicutes,0.8749427,c__Bacilli,0.85708725,o__Bacillales,0.6319907,f__Bacillaceae_G,0.6114752,g__Bacillus_A,0.7830288,s__Bacillus_A thuringiensis,0.35550776
4 changes: 4 additions & 0 deletions data/small.seqs.tax.0.05.csv
@@ -0,0 +1,4 @@
+file,ID,domain,phylum,class,order,family,genus,species
+data/small.fna,AAAC01000001.1,d__Bacteria,p__Firmicutes,c__Bacilli,o__Bacillales,f__Bacillaceae_G,g__Bacillus_A,
+data/small.fna,AE011190.1,,,,,,,
+data/small.fna,AE011191.1,d__Bacteria,,,,,,
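Reading this fixture next to small.seqs.raw.csv above: the filter step appears to blank out every rank whose confidence probability falls below the cutoff derived from the requested false-positive rate, so at `--fpr 0.05` AE011190.1 retains no calls at all while AE011191.1 retains only its domain call.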
6 changes: 2 additions & 4 deletions data/small.tax.0.05.csv
@@ -1,4 +1,2 @@
-ID,domain,phylum,class,order,family,genus,species
-AAAC01000001.1,d__Bacteria,p__Firmicutes,c__Bacilli,o__Bacillales,f__Bacillaceae_G,g__Bacillus_A,
-AE011190.1,,,,,,,
-AE011191.1,d__Bacteria,,,,,,
+file,domain,phylum,class,order,family,genus,species
+data/small.fna,d__Bacteria,p__Firmicutes,c__Bacilli,o__Bacillales,f__Bacillaceae_G,g__Bacillus_A,
5 changes: 3 additions & 2 deletions pyproject.toml
@@ -44,8 +44,9 @@ dependencies = [
dynamic = ["version"]

[tool.setuptools.package-data]
gtnet = ["deploy_pkg/*.npz",
"deploy_pkg/*.pt",
gtnet = ["deploy_pkg/{bins,contigs}/*.npz",
"deploy_pkg/{bins,contigs}/*.pt",
"deploy_pkg/last.pt",
"deploy_pkg/manifest.json"]

[project.urls]
9 changes: 5 additions & 4 deletions src/gtnet/classify.py
@@ -25,7 +25,8 @@ def classify(argv=None):
"for each sequence")

parser = argparse.ArgumentParser(description=desc, epilog=epi)
parser.add_argument('fasta', type=str, help='the Fasta files to do taxonomic classification on')
parser.add_argument('fastas', type=str, nargs='+', help='the Fasta files to do taxonomic classification on')
parser.add_argument('-s', '--seqs', action='store_true', help='provide classification for sequences')
parser.add_argument('-c', '--n_chunks', type=int, default=DEFAULT_N_CHUNKS,
help='the number of sequence chunks to process at a time')
parser.add_argument('-o', '--output', type=str, default=None, help='the output file to save classifications to')
@@ -44,13 +45,13 @@ def classify(argv=None):

     device = check_device(args)

-    model, conf_models, train_conf, vocab, rocs = load_deploy_pkg(for_predict=True, for_filter=True)
+    model, conf_models, train_conf, vocab, rocs = load_deploy_pkg(for_predict=True, for_filter=True, contigs=args.seqs)

     window = train_conf['window']
     step = train_conf['step']

-    logger.info(f'Getting class predictions for each contig in {args.fasta}')
-    output = run_torchscript_inference(args.fasta, model, conf_models, window, step, vocab,
+    logger.info(f'Getting class predictions for each contig in {",".join(args.fastas)}')
+    output = run_torchscript_inference(args.fastas, model, conf_models, window, step, vocab, seqs=args.seqs,
                                        device=device, logger=logger)

     logger.info(f'Getting probability cutoffs for target false-positive rate of {args.fpr}')
14 files renamed without changes.
14 empty files.
13 changes: 10 additions & 3 deletions src/gtnet/filter.py
@@ -53,11 +53,18 @@ def filter(argv=None):

     logger = get_logger()

-    rocs = load_deploy_pkg(for_filter=True)
+    df = pd.read_csv(args.csv)

-    cutoffs = get_cutoffs(rocs, args.fpr)
+    if 'ID' in df.columns:
+        seqs = True
+        df = df.set_index(['file', 'ID'])
+    else:
+        df = df.set_index('file')
+        seqs = False
+
+    rocs = load_deploy_pkg(for_filter=True, contigs=seqs)

-    df = pd.read_csv(args.csv, index_col='ID')
+    cutoffs = get_cutoffs(rocs, args.fpr)

     output = filter_predictions(df, cutoffs)
     write_csv(output, args)
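A minimal illustration of the header-based dispatch above, using one of the fixture CSVs added in this commit (the column logic mirrors the diff; only the file choice is an example):

```python
import pandas as pd

# Per-sequence output carries both 'file' and 'ID' columns; bin-level output carries only 'file'.
df = pd.read_csv("data/small.seqs.raw.csv")
seqs = "ID" in df.columns
df = df.set_index(["file", "ID"] if seqs else "file")
```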
75 changes: 60 additions & 15 deletions src/gtnet/predict.py
@@ -1,4 +1,5 @@
 import argparse
+from collections import Counter
 import logging
 from time import time

@@ -28,7 +29,8 @@ def predict(argv=None):
"'filter' command. See the 'classify' command for getting pre-filtered classifications")

parser = argparse.ArgumentParser(description=desc, epilog=epi)
parser.add_argument('fasta', type=str, help='the Fasta files to do taxonomic classification on')
parser.add_argument('fastas', type=str, nargs='+', help='the Fasta files to do taxonomic classification on')
parser.add_argument('-s', '--seqs', action='store_true', help='provide classification for sequences')
parser.add_argument('-c', '--n_chunks', type=int, default=DEFAULT_N_CHUNKS,
help='the number of sequence chunks to process at a time')
parser.add_argument('-o', '--output', type=str, default=None, help='the output file to save classifications to')
@@ -46,12 +48,13 @@ def predict(argv=None):

     device = check_device(args)

-    model, conf_models, train_conf, vocab = load_deploy_pkg(for_predict=True)
+    model, conf_models, train_conf, vocab = load_deploy_pkg(for_predict=True, contigs=args.seqs)

     window = train_conf['window']
     step = train_conf['step']

-    output = run_torchscript_inference(args.fasta, model, conf_models, window, step, vocab, device=device)
+    output = run_torchscript_inference(args.fastas, model, conf_models, window, step, vocab, seqs=args.seqs,
+                                       device=device, logger=logger)

     # write out data
     write_csv(output, args)
@@ -60,14 +63,14 @@ def predict(argv=None):
     logger.info(f'Took {after - before:.1f} seconds')


-def run_torchscript_inference(fasta, model, conf_models, window, step, vocab, n_chunks=DEFAULT_N_CHUNKS,
+def run_torchscript_inference(fastas, model, conf_models, window, step, vocab, seqs=False, n_chunks=DEFAULT_N_CHUNKS,
                               device=torch.device('cpu'), logger=None):
     """Run Torchscript inference

     Parameters
     ----------
-    fasta : str
+    fastas : str
         The path to the Fasta file with sequences to do inference on
     model : RecursiveScriptModule
@@ -101,20 +104,22 @@ def run_torchscript_inference(fasta, model, conf_models, window, step, vocab, n_
         logger.setLevel(logging.CRITICAL)

     encoder = FastaSequenceEncoder(window, step, vocab=vocab, device=device)
-    reader = FastaReader(encoder, fasta)
+    reader = FastaReader(encoder, *fastas)

     model = model.to(device)

     output_size = sum(len(lvl['taxa']) for lvl in conf_models.values())

     seqnames = list()
     lengths = list()
+    total_chunks = list()
+    filepaths = list()
     aggregated = list()

     torch.set_grad_enabled(False)

-    logger.info(f'Calculating classifications for all sequences in {fasta}')
-
+    logger.info(f'Calculating classifications for all sequences in {", ".join(fastas)}')
     for file_path, seq_name, seq_len, seq_chunks in reader:
         seqnames.append(seq_name)
         lengths.append(seq_len)
@@ -124,24 +129,64 @@ def run_torchscript_inference(fasta, model, conf_models, window, step, vocab, n_
                                  f'{seq_chunks.shape[1] * 2} chunks, {lengths[-1]} bases'))
         outputs = torch.zeros(output_size, device=device)  # the output from the network for a single sequence
         # sum network outputs from all chunks
-        for s in range(0, len(seq_chunks), n_chunks):
+        for s in range(0, seq_chunks.shape[1], n_chunks):
             e = s + n_chunks
             outputs += model(seq_chunks[0, s:e]).sum(dim=0)
             outputs += model(seq_chunks[1, s:e]).sum(dim=0)
-        # divide by the number of seq_chunks we processed to get a mean output
-        outputs /= (seq_chunks.shape[1] * 2)
-        del seq_chunks

         aggregated.append(outputs)
+        total_chunks.append(seq_chunks.shape[1] * 2.)

-    lengths = torch.tensor(lengths, device=device)
+        del seq_chunks

+    total_chunks = torch.tensor(total_chunks, device=device)

     # aggregate everything we just pulled from the fasta file
     all_levels_aggregated = torch.row_stack(aggregated)
     del aggregated

-    output_data = {'ID': seqnames}
+    if not seqs:
+        logger.info('Calculating classifications for bins')
+
+        ctr = Counter(filepaths)
+        n_ctgs = list(ctr.values())
+        filepaths = list(ctr.keys())
+
+        max_len = list()
+        l50 = list()
+        lengths = torch.tensor(lengths)
+
+        tmp_chunks = list()
+        tmp_aggregated = list()
+
+        s = 0
+        for n in n_ctgs:
+            e = s + n
+            tmp_lens = lengths[s:e].sort(descending=True).values
+            max_len.append(tmp_lens[0])
+            csum = torch.cumsum(tmp_lens, 0)
+            l50.append(tmp_lens[torch.where(csum > (csum[-1] * 0.5))[0][0]])
+            tmp_aggregated.append(all_levels_aggregated[s:e].sum(axis=0))
+            tmp_chunks.append(total_chunks[s:e].sum())
+            s = e
+
+        del total_chunks
+        total_chunks = torch.tensor(tmp_chunks)
+        del tmp_chunks
+
+        del all_levels_aggregated
+        all_levels_aggregated = torch.row_stack(tmp_aggregated)
+        del tmp_aggregated
+
+        features = torch.tensor([n_ctgs, l50, max_len], device=device).T
+        output_data = {'file': filepaths}
+    else:
+        features = torch.tensor(lengths, device=device)[:, None]
+        output_data = {'file': filepaths, 'ID': seqnames}

+    all_levels_aggregated = ((all_levels_aggregated.T) / total_chunks).T
+    indices = list(output_data.keys())

     s = 0
     for lvl, e in zip(model.levels, model.parse):
@@ -163,13 +208,13 @@ def run_torchscript_inference(fasta, model, conf_models, window, step, vocab, n_

         # build input matrix for confidence model
         logger.debug('Calculating confidence probabilities')
-        conf_input = torch.column_stack([lengths, maxprobs])
+        conf_input = torch.column_stack([features, maxprobs])

         output_data[f'{lvl}_prob'] = conf_model(conf_input).cpu().numpy().squeeze()

         # set next left bound for all_levels_aggregated
         s = e

-    output = pd.DataFrame(output_data).set_index('ID')
+    output = pd.DataFrame(output_data).set_index(indices)

     return output
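Two details of this hunk are worth noting. First, the chunk-batching fix from the commit message: `seq_chunks` evidently carries the two strand orientations on axis 0 and the chunks on axis 1 (hence `seq_chunks[0, s:e]`, `seq_chunks[1, s:e]`, and the `* 2` factors), so the old loop bound `len(seq_chunks)` walked the strand axis (always 2) rather than the chunk axis; `seq_chunks.shape[1]` batches over chunks as intended. Second, the per-bin features fed to the confidence model. A minimal sketch of that computation follows, with an invented function name and return convention; the tensor operations themselves are taken from the loop above:

```python
import torch

def bin_features(contig_lengths: torch.Tensor):
    """Per-bin confidence-model features: contig count, an N50-style length, and max length."""
    lens = contig_lengths.sort(descending=True).values
    csum = torch.cumsum(lens, 0)
    # length of the first contig strictly past 50% of the cumulative total (strict >, as in the diff)
    n50_style = lens[torch.where(csum > (csum[-1] * 0.5))[0][0]]
    return len(lens), int(n50_style), int(lens[0])

# e.g. bin_features(torch.tensor([100, 300, 200])) returns (3, 200, 300)
```

Note the diff names this feature `l50`, but it is a length at the 50% cumulative mark (an N50-style statistic) rather than a contig count.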
30 changes: 16 additions & 14 deletions src/gtnet/utils.py
@@ -1,4 +1,3 @@
-import glob
 import hashlib
 from importlib.resources import files
 import json
@@ -34,16 +33,16 @@ def get_logger():
 class DeployPkg:
     """A class to handle loading and manipulating the deployment package"""

-    _deploy_pkg_url = "https://osf.io/download/mwgb9/"
+    _deploy_pkg_url = "https://osf.io/download/qf46x/"

-    _checksum = "0245fcf825bfe4de0770fcb46798ca90"
+    _checksum = "623aa991fb0d74e874b7d0da25496c26"

+    _manifest_name = 'manifest.json'
+
     @classmethod
     def check_pkg(cls):
         deploy_dir = files(__package__).joinpath('deploy_pkg')
-        total = 0
-        for path in glob.glob(f"{deploy_dir}/*"):
-            total += os.path.getsize(path)
+        total = os.path.getsize(os.path.join(deploy_dir, cls._manifest_name))
         if total == 0:
             msg = ("Downloading GTNet deployment package. This will only happen on the first invocation "
                    "of gtnet predict or gtnet classify")
@@ -71,7 +70,7 @@ def path(self, path):
     @property
     def manifest(self):
         if self._manifest is None:
-            with open(self.path('manifest.json'), 'r') as f:
+            with open(self.path(self._manifest_name), 'r') as f:
                 self._manifest = json.load(f)
         return self._manifest
@@ -82,22 +81,25 @@ def __setitem__(self, key, val):
         self.manifest[key] = val


-def load_deploy_pkg(for_predict=False, for_filter=False):
+def load_deploy_pkg(for_predict=False, for_filter=False, contigs=False):
     if not (for_predict or for_filter):
         for_predict = True
         for_filter = True

     pkg = DeployPkg()
+    key = 'contigs' if contigs else 'bins'

     ret = list()
     if for_predict:
         tmp_conf_model = dict()
-        for lvl_dat in pkg['conf_model']:
-            lvl_dat['taxa'] = np.array(lvl_dat['taxa'])
+        for cm_data, taxa_data in zip(pkg['conf_model'][key], pkg['taxa']):
+            if cm_data['level'] != taxa_data['level']:
+                raise ValueError("Taxonomic levels are out of order in manifest file")
+            cm_data['taxa'] = np.array(taxa_data['taxa'])

-            lvl_dat['model'] = torch.jit.load(pkg.path(lvl_dat.pop('model')))
+            cm_data['model'] = torch.jit.load(pkg.path(cm_data.pop('model')))

-            tmp_conf_model[lvl_dat['level']] = lvl_dat
+            tmp_conf_model[cm_data['level']] = cm_data

         ret.append(torch.jit.load(pkg.path(pkg['nn_model'])))
         ret.append(tmp_conf_model)
@@ -106,8 +108,8 @@ def load_deploy_pkg(for_predict=False, for_filter=False):

     if for_filter:
         tmp_roc = dict()
-        for lvl_dat in pkg['conf_model']:
-            tmp_roc[lvl_dat['level']] = np.load(pkg.path(lvl_dat['roc']))
+        for cm_data in pkg['conf_model'][key]:
+            tmp_roc[cm_data['level']] = np.load(pkg.path(cm_data['roc']))
         ret.append(tmp_roc)

     return tuple(ret) if len(ret) > 1 else ret[0]
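The new loader implies a manifest layout roughly like the sketch below. The structure is inferred from the access patterns above (`pkg['nn_model']`, `pkg['taxa']`, `pkg['conf_model'][key]`); every file name and taxon shown is a hypothetical placeholder, not the real package contents:

```python
# Hypothetical manifest.json layout, inferred from load_deploy_pkg() above.
manifest = {
    "nn_model": "last.pt",                # TorchScript classifier shared by both modes
    "taxa": [                             # one entry per taxonomic level, in rank order
        {"level": "domain", "taxa": ["d__Archaea", "d__Bacteria"]},
    ],
    "conf_model": {                       # confidence models keyed by input mode
        "bins":    [{"level": "domain", "model": "bins/domain.pt",    "roc": "bins/domain.roc.npz"}],
        "contigs": [{"level": "domain", "model": "contigs/domain.pt", "roc": "contigs/domain.roc.npz"}],
    },
}
```

Keying the confidence models by mode while sharing a single `taxa` list also explains the added level-order check: the two lists must stay aligned by rank.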
