Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add helper to convert Bio.Phylo trees to PyTorch #2557

Merged
merged 5 commits into from
Jul 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@
'funsor': ('http://funsor.pyro.ai/en/stable/', None),
'opt_einsum': ('https://optimized-einsum.readthedocs.io/en/stable/', None),
'scipy': ('https://docs.scipy.org/doc/scipy/reference/', None),
'Bio': ('https://biopython.readthedocs.io/en/latest/', None),
}

# document class constructors (__init__ methods):
Expand Down
2 changes: 2 additions & 0 deletions docs/source/contrib.epidemiology.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,5 @@ Distributions
:show-inheritance:
:member-order: bysource
:special-members: __call__

.. autofunction:: pyro.distributions.coalescent.bio_phylo_to_times
3 changes: 3 additions & 0 deletions pyro/contrib/epidemiology/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from pyro.distributions.coalescent import bio_phylo_to_times

from .compartmental import CompartmentalModel
from .distributions import beta_binomial_dist, binomial_dist, infection_dist

__all__ = [
"CompartmentalModel",
"beta_binomial_dist",
"binomial_dist",
"bio_phylo_to_times",
"infection_dist",
]
47 changes: 47 additions & 0 deletions pyro/distributions/coalescent.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,53 @@ def __call__(self, rate_grid, t=slice(None)):
return const + linear + log


def bio_phylo_to_times(tree, *, get_time=None):
"""
Extracts coalescent summary statistics from a phylogeny, suitable for use
with :class:`~pyro.distributions.CoalescentRateLikelihood`.

:param Bio.Phylo.BaseTree.Clade tree: A phylogenetic tree.
:param callable get_time: Optional function to extract the time point of
each sub-:class:`~Bio.Phylo.BaseTree.Clade`. If absent, times will be
computed by cumulative `.branch_length`.
:returns: A pair of :class:`~torch.Tensor` s ``(leaf_times, coal_times)``
where ``leaf_times`` are times of sampling events (leaf nodes in the
phylogenetic tree) and ``coal_times`` are times of coalescences (leaf
nodes in the phylogenetic binary tree).
:rtype: tuple
"""
if get_time is None:
# Compute time as cumulative branch length.
def get_branch_length(clade):
branch_length = clade.branch_length
return 1.0 if branch_length is None else branch_length
times = {tree.root: get_branch_length(tree.root)}

leaf_times = []
coal_times = []
for clade in tree.find_clades():
if get_time is None:
time = times[clade]
for child in clade:
times[child] = time + get_branch_length(child)
else:
time = get_time(clade)

num_children = len(clade)
if num_children == 0:
leaf_times.append(time)
else:
# Pyro expects binary coalescent events, so we split n-ary events
# into n-1 separate binary events.
for _ in range(num_children - 1):
coal_times.append(time)
assert len(leaf_times) == 1 + len(coal_times)

leaf_times = torch.tensor(leaf_times)
coal_times = torch.tensor(coal_times)
return leaf_times, coal_times


def _gather(tensor, dim, index):
"""
Like :func:`torch.gather` but broadcasts.
Expand Down
2 changes: 1 addition & 1 deletion scripts/update_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import sys

root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
blacklist = ["/build/", "/dist/"]
blacklist = ["/build/", "/dist/", "/pyro/_version.py"]
file_types = [
("*.py", "# {}"),
("*.cpp", "// {}"),
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
'matplotlib>=1.3',
'torchvision>=0.6.0',
'visdom>=0.1.4',
# 'biopython>=1.54', # requires Python 3.6
'pandas',
'seaborn',
'wget',
Expand Down
87 changes: 86 additions & 1 deletion tests/distributions/test_coalescent.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import io
import re

import pytest
import torch

import pyro
from pyro.distributions import CoalescentTimes, CoalescentTimesWithRate
from pyro.distributions.coalescent import CoalescentRateLikelihood, CoalescentTimesConstraint, _sample_coalescent_times
from pyro.distributions.coalescent import (CoalescentRateLikelihood, CoalescentTimesConstraint,
_sample_coalescent_times, bio_phylo_to_times)
from pyro.distributions.util import broadcast_shape
from tests.common import assert_close

Expand Down Expand Up @@ -144,3 +148,84 @@ def test_likelihood_sequential(num_leaves, num_steps, batch_shape, clamped):
for t in range(num_steps))

assert_close(actual, expected)


TREE_NEXUS = """
#NEXUS
Begin Trees;
Tree tree1=((EPI_ISL_408009:0.00000[&date=2020.08],
EPI_ISL_408008:0.00000[&date=2020.08]) NODE_0000004:0.17430[&date=2020.08],
(EPI_ISL_417931:0.28554[&date=2020.21],
(EPI_ISL_417332:0.11102[&date=2020.20], EPI_ISL_413931:0.08643[&date=2020.18])
NODE_0000005:0.16360[&date=2020.09], ((EPI_ISL_413558:0.11909[&date=2020.16],
(EPI_ISL_413559:0.07179[&date=2020.16],
(EPI_ISL_412862:0.00000[&date=2020.15],
EPI_ISL_413561:0.01093[&date=2020.16])
NODE_0000011:0.06086[&date=2020.15])
NODE_0000012:0.04730[&date=2020.09]) NODE_0000007:0.06603[&date=2020.04],
(EPI_ISL_411955:0.09393[&date=2020.11],
(EPI_ISL_417325:0.08372[&date=2020.17],
(EPI_ISL_417318:0.02411[&date=2020.16],
EPI_ISL_417320:0.03504[&date=2020.17])
NODE_0000009:0.05141[&date=2020.14])
NODE_0000006:0.07032[&date=2020.09])
NODE_0000014:0.04474[&date=2020.02]) NODE_0000010:0.04578[&date=2019.97],
(EPI_ISL_417933:0.15496[&date=2020.21], EPI_ISL_414648:0.13583[&date=2020.19],
(EPI_ISL_417932:0.09490[&date=2020.21],
(EPI_ISL_417937:0.05785[&date=2020.21],
EPI_ISL_417331:0.04419[&date=2020.20])
NODE_0000001:0.03705[&date=2020.15],
(EPI_ISL_417938:0.06860[&date=2020.21],
(EPI_ISL_417939:0.04394[&date=2020.21],
(EPI_ISL_417330:0.00314[&date=2020.20],
(EPI_ISL_416457:0.00000[&date=2020.21],
EPI_ISL_417935:0.00000[&date=2020.21])
NODE_0000018:0.01680[&date=2020.21])
NODE_0000017:0.02714[&date=2020.19])
NODE_0000016:0.02466[&date=2020.17])
NODE_0000015:0.02630[&date=2020.14])
NODE_0000000:0.06006[&date=2020.12])
NODE_0000002:0.13059[&date=2020.06])
NODE_0000003:0.02264[&date=2019.93]) NODE_0000008:0.10000[&date=2019.90];
End;
"""


@pytest.fixture
def tree():
Phylo = pytest.importorskip("Bio.Phylo")
tree_file = io.StringIO(TREE_NEXUS)
trees = list(Phylo.parse(tree_file, "nexus"))
assert len(trees) == 1
return trees[0]


def test_bio_phylo_to_times(tree):
leaf_times, coal_times = bio_phylo_to_times(tree)
assert len(coal_times) + 1 == len(leaf_times)

# Check positivity.
times = torch.cat([coal_times, leaf_times])
signs = torch.cat([-torch.ones_like(coal_times), torch.ones_like(leaf_times)])
times, index = times.sort(0)
signs = signs[index]
lineages = signs.flip([0]).cumsum(0).flip([0])
assert (lineages >= 0).all()


def test_bio_phylo_to_times_custom(tree):
# Test a custom time parser.
def get_time(clade):
date_string = re.search(r"date=(\d\d\d\d\.\d\d)", clade.comment).group(1)
return (float(date_string) - 2020) * 365.25

leaf_times, coal_times = bio_phylo_to_times(tree, get_time=get_time)
assert len(coal_times) + 1 == len(leaf_times)

# Check positivity.
times = torch.cat([coal_times, leaf_times])
signs = torch.cat([-torch.ones_like(coal_times), torch.ones_like(leaf_times)])
times, index = times.sort(0)
signs = signs[index]
lineages = signs.flip([0]).cumsum(0).flip([0])
assert (lineages >= 0).all()