Feat/logging cif support (#402)
* added cif support

* add cif tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove prints

* make cath metadata download optional

* modify cif tests

* pdb download formats

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* bump biopandas pinned version to latest

* update torch index url

* add deprecation decorator to regnetwork

* ignore GRN tutorial notebook in tests due to RegNetwork going offline

* configure torch==1.13.0 install index url

* rm +cpu flag for torch >=2.0 install

* try removing pyg lib from CI

* use latest miniconda setup actions

* rm unused and deprecated 'U' flag in file read

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Arian Jamasb <[email protected]>
4 people authored Aug 2, 2024
1 parent f6d9d72 commit 848a3f8
Showing 17 changed files with 11,229 additions and 50 deletions.
15 changes: 9 additions & 6 deletions .github/workflows/build.yaml
@@ -33,7 +33,7 @@ jobs:
uses: actions/checkout@v3
# See: https://github.com/marketplace/actions/setup-miniconda
- name: Setup miniconda
uses: conda-incubator/setup-miniconda@v2
uses: conda-incubator/setup-miniconda@v3
with:
auto-update-conda: true
miniforge-variant: Mambaforge
@@ -48,15 +48,18 @@ jobs:
run: conda install dssp -c salilab
- name: Install mmseqs
run: mamba install -c conda-forge -c bioconda mmseqs2
- name: Install PyTorch
#run: mamba install -c pytorch pytorch==${{matrix.torch}} cpuonly
run: pip install torch==${{matrix.torch}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
- name: Install PyTorch (1.13.0)
if: matrix.torch == '1.13.0'
run: pip install torch==${{matrix.torch}}+cpu --extra-index-url https://download.pytorch.org/whl/cpu
- name: Install PyTorch (2.0+)
if: matrix.torch != '1.13.0'
run: pip install torch==${{matrix.torch}} -f https://download.pytorch.org/whl/cpu
- name: Install PyG
#run: mamba install -c pyg pyg
run: pip install torch_geometric
- name: Install torch-cluster
#run: mamba install pytorch-cluster -c pyg
run: pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-${{matrix.torch}}+cpu.html
run: pip install torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-${{matrix.torch}}+cpu.html
- name: Install BLAST
run: sudo apt install ncbi-blast+
- name: Install Graphein
@@ -70,4 +73,4 @@ jobs:
- name: Run unit tests and generate coverage report
run: pytest .
- name: Test notebook execution
run: pytest --nbval-lax notebooks/ --current-env --ignore-glob="notebooks/dataloader_tutorial.ipynb" --ignore-glob="notebooks/datasets_and_dataloaders.ipynb" --ignore-glob="notebooks/foldcomp.ipynb"
run: pytest --nbval-lax notebooks/ --current-env --ignore-glob="notebooks/dataloader_tutorial.ipynb" --ignore-glob="notebooks/datasets_and_dataloaders.ipynb" --ignore-glob="notebooks/foldcomp.ipynb" --ignore-glob="notebooks/grn_tutorial.ipynb"
4 changes: 2 additions & 2 deletions .github/workflows/minimal__install.yaml
@@ -34,7 +34,7 @@ jobs:
uses: actions/checkout@v3
# See: https://github.com/marketplace/actions/setup-miniconda
- name: Setup miniconda
uses: conda-incubator/setup-miniconda@v2
uses: conda-incubator/setup-miniconda@v3
with:
auto-update-conda: true
miniforge-variant: Mambaforge
@@ -50,4 +50,4 @@
- name: Run unit tests and generate coverage report
run: pytest . --ignore-glob="tests/protein/tensor" --ignore="tests/ml/test_conversion.py" --ignore="tests/ml/test_torch_geometric_dataset.py"
- name: Test notebook execution
run: pytest --nbval-lax notebooks/ --current-env --ignore-glob="notebooks/dataloader_tutorial.ipynb" --ignore-glob="notebooks/higher_order_graphs.ipynb" --ignore-glob="notebooks/protein_graph_analytics.ipynb" --ignore-glob="notebooks/subgraphing_tutorial.ipynb" --ignore-glob="notebooks/splitting_a_dataset.ipynb" --ignore-glob="notebooks/protein_tensors.ipynb" --ignore-glob="notebooks/datasets_and_dataloaders.ipynb" --ignore-glob="notebooks/foldcomp.ipynb" --ignore-glob="notebooks/creating_datasets_from_the_pdb.ipynb"
run: pytest --nbval-lax notebooks/ --current-env --ignore-glob="notebooks/dataloader_tutorial.ipynb" --ignore-glob="notebooks/higher_order_graphs.ipynb" --ignore-glob="notebooks/protein_graph_analytics.ipynb" --ignore-glob="notebooks/subgraphing_tutorial.ipynb" --ignore-glob="notebooks/splitting_a_dataset.ipynb" --ignore-glob="notebooks/protein_tensors.ipynb" --ignore-glob="notebooks/datasets_and_dataloaders.ipynb" --ignore-glob="notebooks/foldcomp.ipynb" --ignore-glob="notebooks/creating_datasets_from_the_pdb.ipynb" --ignore-glob="notebooks/grn_tutorial.ipynb"
2 changes: 1 addition & 1 deletion .requirements/base.in
@@ -1,5 +1,5 @@
pandas<2.0.0
biopandas>=0.5.0.dev0
biopandas>=0.5.1
biopython
bioservices>=1.10.0
deepdiff
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -8,6 +8,7 @@
* Fix cluster file loading bug in `pdb_data.py` [#396](https://github.com/a-r-j/graphein/pull/396)

#### Misc
* disable logging by default and add mmCIF support [#402](https://github.com/a-r-j/graphein/pull/402)
* add metadata options for uniprot, ecnumber and CATH code to pdb manager [#398](https://github.com/a-r-j/graphein/pull/398)
* bumped logging level down from `INFO` to `DEBUG` at several places to reduce output length [#391](https://github.com/a-r-j/graphein/pull/391)
* exposed `fill_value` and `bfactor` option to `protein_to_pyg` function. [#385](https://github.com/a-r-j/graphein/pull/385) and [#388](https://github.com/a-r-j/graphein/pull/388)
2 changes: 2 additions & 0 deletions graphein/__init__.py
@@ -21,6 +21,8 @@
]
)

logger.disable("graphein")


def verbose(enabled: bool = False):
"""Enable/Disable logging.
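Graphein's logger is now disabled on import (`logger.disable("graphein")`), so library output is silent by default. A minimal sketch of opting back in with the `verbose` helper shown above:

```python
import graphein

# Logging is silenced at import time by `logger.disable("graphein")`.
# Re-enable it when debugging:
graphein.verbose(enabled=True)

# Return to the new silent default:
graphein.verbose(enabled=False)
```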
20 changes: 19 additions & 1 deletion graphein/grn/parse_regnetwork.py
@@ -18,9 +18,12 @@
import wget
from loguru import logger as log

from graphein.utils.utils import filter_dataframe, ping
from graphein.utils.utils import deprecated, filter_dataframe, ping


@deprecated(
"RegNetwork appears to be down. This warning will be removed in a future release if the service is restored."
)
def _download_RegNetwork(
root_dir: Optional[Path] = None, network_type: str = "human"
) -> str:
@@ -86,6 +89,9 @@ def _download_RegNetwork(
return file


@deprecated(
"RegNetwork appears to be down. This warning will be removed in a future release if the service is restored."
)
def _download_RegNetwork_regtypes(root_dir: Optional[Path] = None) -> str:
"""
Downloads RegNetwork regulatory interactions types to the root directory.
@@ -124,6 +130,9 @@ def _download_RegNetwork_regtypes(root_dir: Optional[Path] = None) -> str:
return file


@deprecated(
"RegNetwork appears to be down. This warning will be removed in a future release if the service is restored."
)
@functools.lru_cache()
def load_RegNetwork_interactions(
root_dir: Optional[Path] = None,
@@ -144,6 +153,9 @@ def load_RegNetwork_interactions(
)


@deprecated(
"RegNetwork appears to be down. This warning will be removed in a future release if the service is restored."
)
@functools.lru_cache()
def load_RegNetwork_regulation_types(
root_dir: Optional[Path] = None,
@@ -168,6 +180,9 @@ def load_RegNetwork_regulation_types(
)


@deprecated(
"RegNetwork appears to be down. This warning will be removed in a future release if the service is restored."
)
def parse_RegNetwork(
gene_list: List[str], root_dir: Optional[Path] = None
) -> pd.DataFrame:
@@ -244,6 +259,9 @@ def standardise_RegNetwork(df: pd.DataFrame) -> pd.DataFrame:
return df


@deprecated(
"RegNetwork appears to be down. This warning will be removed in a future release if the service is restored."
)
def RegNetwork_df(
gene_list: List[str],
root_dir: Optional[Path] = None,
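The RegNetwork loaders above are now wrapped in a `deprecated` decorator because the upstream service is offline. For context, a minimal sketch of what such a decorator typically looks like — the actual implementation in `graphein.utils.utils` may differ:

```python
import functools
import warnings


def deprecated(reason: str):
    """Illustrative deprecation decorator; Graphein's version in
    graphein.utils.utils may differ in its details."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Emit a warning pointing at the caller's frame, then proceed.
            warnings.warn(
                f"{func.__name__} is deprecated: {reason}",
                DeprecationWarning,
                stacklevel=2,
            )
            return func(*args, **kwargs)

        return wrapper

    return decorator
```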
2 changes: 1 addition & 1 deletion graphein/ml/clustering.py
@@ -85,7 +85,7 @@ def get_seq_records(
"Alphabet given. Only checking for terminating *!\n"
)
check_sequences = False
with open(filename, "rU") as handle:
with open(filename, "r") as handle:
records = list(SeqIO.parse(handle, file_format, alphabet=alphabet))
del handle
if check_sequences:
4 changes: 0 additions & 4 deletions graphein/ml/datasets/foldcomp_dataset.py
@@ -236,10 +236,6 @@ def download(self):
asyncio.run(_)
os.chdir(curr_dir)
log.info("Download complete.")
# log.info("Moving files to raw directory...")

# for f in self._database_files:
# shutil.move(f, self.root)
else:
log.info(f"FoldComp database already downloaded: {self.root}.")

11 changes: 5 additions & 6 deletions graphein/ml/datasets/pdb_data.py
@@ -120,6 +120,7 @@ def __init__(
).name

self.list_columns = ["ligands"]
self.labels = labels

# Data
self.download_metadata()
@@ -165,9 +166,10 @@ def download_metadata(self):
self._download_entry_metadata()
self._download_exp_type()
self._download_pdb_availability()
self._download_pdb_chain_cath_uniprot_map()
self._download_cath_id_cath_code_map()
self._download_pdb_chain_ec_number_map()
if self.labels:
self._download_pdb_chain_cath_uniprot_map()
self._download_cath_id_cath_code_map()
self._download_pdb_chain_ec_number_map()

def get_unavailable_pdb_files(
self, splits: Optional[List[str]] = None
@@ -643,15 +645,12 @@ def _parse_cath_code(self) -> Dict[str, str]:
with gzip.open(
self.root_dir / self.cath_id_cath_code_filename, "rt"
) as f:
print(f)
for line in f:
print(line)
try:
cath_id, cath_version, cath_code, cath_segment = (
line.strip().split()
)
cath_mapping[cath_id] = cath_code
print(cath_id, cath_code)
except ValueError:
continue
return cath_mapping
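With `labels` stored on the manager, the CATH/UniProt/EC metadata maps are only downloaded when label metadata is requested. A hedged usage sketch — the import path and the accepted `labels` values are assumptions based on this diff, not confirmed API:

```python
from graphein.ml.datasets import PDBManager  # import path assumed

# A falsy `labels` now skips the CATH/UniProt/EC metadata downloads entirely:
manager = PDBManager(root_dir=".", labels=None)

# A truthy value triggers the three extra metadata downloads
# (label names shown here are illustrative):
manager = PDBManager(root_dir=".", labels=["uniprot", "cath", "ec"])
```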
14 changes: 11 additions & 3 deletions graphein/protein/graphs.py
@@ -18,6 +18,7 @@
import networkx as nx
import numpy as np
import pandas as pd
from biopandas.mmcif import PandasMmcif
from biopandas.mmtf import PandasMmtf
from biopandas.pdb import PandasPdb
from loguru import logger as log
@@ -111,21 +112,28 @@ def read_pdb_to_dataframe(
atomic_df = PandasPdb().read_pdb(path)
elif path.endswith(".mmtf") or path.endswith(".mmtf.gz"):
atomic_df = PandasMmtf().read_mmtf(path)
elif (
path.endswith(".cif")
or path.endswith(".cif.gz")
or path.endswith(".mmcif")
or path.endswith(".mmcif.gz")
):
atomic_df = PandasMmcif().read_mmcif(path)
else:
raise ValueError(
f"File {path} must be either .pdb(.gz), .mmtf(.gz) or .ent, not {path.split('.')[-1]}"
f"File {path} must be either .pdb(.gz), .mmtf(.gz), .(mm)cif(.gz) or .ent, not {path.split('.')[-1]}"
)
elif uniprot_id is not None:
atomic_df = PandasPdb().fetch_pdb(
uniprot_id=uniprot_id, source="alphafold2-v3"
)
else:
atomic_df = PandasPdb().fetch_pdb(pdb_code)

atomic_df = atomic_df.get_model(model_index)
if len(atomic_df.df["ATOM"]) == 0:
raise ValueError(f"No model found for index: {model_index}")

if isinstance(atomic_df, PandasMmcif):
atomic_df = atomic_df.convert_to_pandas_pdb()
return pd.concat([atomic_df.df["ATOM"], atomic_df.df["HETATM"]])


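`read_pdb_to_dataframe` now dispatches `.cif(.gz)`/`.mmcif(.gz)` files to `PandasMmcif` and converts the result back to PDB-style records before concatenating `ATOM` and `HETATM`, so downstream graph construction is unchanged. A usage sketch (the file path is illustrative):

```python
from graphein.protein.graphs import read_pdb_to_dataframe

# Accepted inputs now include .pdb(.gz), .mmtf(.gz), .cif(.gz)/.mmcif(.gz) and .ent:
df = read_pdb_to_dataframe(path="4hhb.cif")

# The same PDB-style dataframe as before:
print(df[["atom_name", "x_coord", "y_coord", "z_coord"]].head())
```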
10 changes: 7 additions & 3 deletions graphein/protein/tensor/io.py
@@ -49,7 +49,7 @@
conda_channel="pyg",
pip_install=True,
)
log.debug(message)
log.warning(message)

try:
import torch
@@ -60,7 +60,7 @@
conda_channel="pytorch",
pip_install=True,
)
log.debug(message)
log.warning(message)


def get_protein_length(df: pd.DataFrame, insertions: bool = True) -> int:
Expand Down Expand Up @@ -246,7 +246,9 @@ def protein_to_pyg(

out = Data(
coords=protein_df_to_tensor(
df, atoms_to_keep=atom_types, fill_value=fill_value_coords
df,
atoms_to_keep=atom_types,
fill_value=fill_value_coords,
),
residues=get_sequence(
df,
@@ -259,6 +261,7 @@
residue_type=residue_type_tensor(df),
chains=protein_df_to_chain_tensor(df),
)

if store_het:
out.hetatms = [het_coords]

@@ -360,6 +363,7 @@ def protein_df_to_tensor(
positions[residue_indices, atom_indices] = torch.tensor(
df[["x_coord", "y_coord", "z_coord"]].values
).float()

return positions


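Bumping these messages from `DEBUG` to `WARNING` makes missing optional dependencies visible by default. The import-guard pattern in question, sketched minimally:

```python
from loguru import logger as log

try:
    import torch
except ImportError:
    # Formerly logged at DEBUG (hidden by default); now surfaced as a warning.
    log.warning(
        "torch is not installed. Install it via pip or conda (channel: pytorch)."
    )
```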
14 changes: 7 additions & 7 deletions graphein/protein/utils.py
@@ -11,7 +11,7 @@
from functools import lru_cache, partial
from multiprocessing import Pool
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union
from urllib.error import HTTPError
from urllib.request import urlopen

@@ -96,7 +96,7 @@ def read_fasta(file_path: str) -> Dict[str, str]:
def download_pdb_multiprocessing(
pdb_codes: List[str],
out_dir: Union[str, Path], # type: ignore
format: str = "pdb",
format: Literal["pdb", "mmtf", "mmcif", "cif", "bcif"] = "pdb",
overwrite: bool = False,
strict: bool = False,
max_workers: int = 16,
@@ -108,7 +108,7 @@ def download_pdb_multiprocessing(
:type pdb_codes: List[str]
:param out_dir: Path to directory to download PDB structures to.
:type out_dir: Union[str, Path]
:param format: Filetype to download. ``pdb``, ``mmtf``, ``mmcif`` or ``bcif``.
:param format: Filetype to download. ``pdb``, ``mmtf``, ``mmcif``/``cif`` or ``bcif``.
:type format: str
:param overwrite: Whether to overwrite existing files, defaults to
``False``.
@@ -146,7 +146,7 @@ def download_pdb_multiprocessing(
def download_pdb(
pdb_code: str,
out_dir: Optional[Union[str, Path]] = None,
format: str = "pdb",
format: Literal["pdb", "mmtf", "mmcif", "cif", "bcif"] = "pdb",
check_obsolete: bool = False,
overwrite: bool = False,
strict: bool = True,
@@ -162,7 +162,7 @@ def download_pdb(
:param out_dir: Path to directory to download PDB structure to. If ``None``,
will download to a temporary directory.
:type out_dir: Optional[Union[str, Path]]
:param format: Filetype to download. ``pdb``, ``mmtf``, ``mmcif`` or ``bcif``.
:param format: Filetype to download. ``pdb``, ``mmtf``, ``mmcif``/``cif`` or ``bcif``.
:type format: str
:param check_obsolete: Whether to check for obsolete PDB codes,
defaults to ``False``. If an obsolete PDB code is found, the updated PDB
@@ -183,15 +183,15 @@ def download_pdb(
elif format == "mmtf":
BASE_URL = "https://mmtf.rcsb.org/v1.0/full/"
extension = ".mmtf.gz"
elif format == "mmcif":
elif format == "cif" or format == "mmcif":
BASE_URL = "https://files.rcsb.org/download/"
extension = ".cif.gz"
elif format == "bcif":
BASE_URL = "https://models.rcsb.org/"
extension = ".bcif.gz"
else:
raise ValueError(
f"Invalid format: {format}. Must be 'pdb', 'mmtf', 'mmcif' or 'bcif'."
f"Invalid format: {format}. Must be 'pdb', 'mmtf', '(mm)cif' or 'bcif'."
)

# Make output directory if it doesn't exist or set it to tempdir if None
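`cif` is now accepted as an alias for `mmcif`, and both resolve to the same RCSB download URL. A usage sketch (PDB code and output directory are illustrative):

```python
from graphein.protein.utils import download_pdb

# format="cif" and format="mmcif" both fetch
# https://files.rcsb.org/download/<pdb_code>.cif.gz
path = download_pdb("4hhb", out_dir="/tmp", format="cif")
print(path)  # local path of the downloaded structure
```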
(Diff truncated: the remaining 5 changed files are not shown.)