refactor multi test set datasets; add seq id test splits to GO
Jamasb committed Feb 5, 2024
1 parent 81e7100 commit 0a05401
Showing 4 changed files with 144 additions and 63 deletions.
26 changes: 19 additions & 7 deletions proteinworkshop/datasets/base.py
@@ -1,4 +1,5 @@
 """Base classes for protein structure datamodules and datasets."""
+import copy
 import os
 import pathlib
 from abc import ABC, abstractmethod
@@ -81,12 +82,23 @@ def download(self):
 
     def setup(self, stage: Optional[str] = None):
         self.download()
-        logger.info("Preprocessing training data")
-        self.train_ds = self.train_dataset()
-        logger.info("Preprocessing validation data")
-        self.val_ds = self.val_dataset()
-        logger.info("Preprocessing test data")
-        self.test_ds = self.test_dataset()
+
+        if stage == "fit" or stage is None:
+            logger.info("Preprocessing training data")
+            self.train_ds = self.train_dataset()
+            logger.info("Preprocessing validation data")
+            self.val_ds = self.val_dataset()
+        elif stage == "test":
+            logger.info("Preprocessing test data")
+            if hasattr(self, "test_dataset_names"):
+                for split in self.test_dataset_names:
+                    setattr(self, f"{split}_ds", self.test_dataset(split))
+            else:
+                self.test_ds = self.test_dataset()
+        elif stage == "lazy_init":
+            logger.info("Preprocessing validation data")
+            self.val_ds = self.val_dataset()
 
         # self.class_weights = self.get_class_weights()
 
     @property
@@ -518,7 +530,7 @@ def get(self, idx: int) -> Data:
         :return: PyTorch Geometric Data object.
         """
         if self.in_memory:
-            return self._batch_format(self.data[idx])
+            return self._batch_format(copy.deepcopy(self.data[idx]))
 
         if self.out_names is not None:
             fname = f"{self.out_names[idx]}.pt"
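Note on the base-class changes above: setup() is now stage-aware, so only the datasets a given Lightning stage needs are preprocessed, and any datamodule exposing a test_dataset_names property gets one <split>_ds attribute per named test split. The copy.deepcopy in get() guards in-memory datasets against transforms that mutate the cached Data objects in place. A minimal sketch of the stage/attribute contract (ToyDataModule and its method bodies are illustrative stand-ins, not code from this commit):

    # Illustrative sketch of the stage-aware setup contract (names hypothetical).
    from typing import List, Optional

    class ToyDataModule:
        @property
        def test_dataset_names(self) -> List[str]:
            return ["fold", "family", "superfamily"]

        def test_dataset(self, split: str) -> list:
            return [f"{split}-example"]  # stand-in for a ProteinDataset

        def setup(self, stage: Optional[str] = None) -> None:
            if stage == "test":
                # One attribute per named test split, as in the diff above.
                for split in self.test_dataset_names:
                    setattr(self, f"{split}_ds", self.test_dataset(split))

    dm = ToyDataModule()
    dm.setup(stage="test")
    print(dm.fold_ds, dm.family_ds, dm.superfamily_ds)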
53 changes: 30 additions & 23 deletions proteinworkshop/datasets/fold_classification.py
@@ -1,7 +1,7 @@
 import os
 import pathlib
 import tarfile
-from typing import Callable, Dict, Iterable, Optional
+from typing import Callable, Dict, Iterable, List, Literal, Optional
 
 import omegaconf
 import pandas as pd
@@ -72,6 +72,11 @@ def __init__(
         else:
             self.transform = None
 
+    @property
+    def test_dataset_names(self) -> List[str]:
+        """Provides a list of test set split names."""
+        return ["fold", "family", "superfamily"]
+
     def download(self):
         self.download_data_files()
         self.download_structures()
@@ -152,16 +157,12 @@ def parse_class_map(self) -> Dict[str, str]:
         )
         return dict(class_map.values)
 
-    def setup(self, stage: Optional[str] = None):
-        self.download_data_files()
-        self.download_structures()
-        self.train_ds = self.train_dataset()
-        self.val_ds = self.val_dataset()
-        self.test_ds = self.test_dataset()
-
     def _get_dataset(self, split: str) -> ProteinDataset:
+        if hasattr(self, f"{split}_ds"):
+            return getattr(self, f"{split}_ds")
+
         df = self.parse_dataset(split)
-        return ProteinDataset(
+        ds = ProteinDataset(
             root=str(self.data_dir),
             pdb_dir=str(self.structure_dir),
             pdb_codes=list(df.id),
@@ -171,15 +172,19 @@ def _get_dataset(self, split: str) -> ProteinDataset:
             transform=self.transform,
             in_memory=self.in_memory,
         )
+        setattr(self, f"{split}_ds", ds)
+        return ds
 
     def train_dataset(self) -> ProteinDataset:
         return self._get_dataset("training")
 
     def val_dataset(self) -> ProteinDataset:
         return self._get_dataset("validation")
 
-    def test_dataset(self) -> ProteinDataset:
-        return self._get_dataset(f"test_{self.split}")
+    def test_dataset(
+        self, split: Literal["fold", "family", "superfamily"]
+    ) -> ProteinDataset:
+        return self._get_dataset(f"test_{split}")
 
     def train_dataloader(self) -> ProteinDataLoader:
         self.train_ds = self.train_dataset()
@@ -201,8 +206,10 @@ def val_dataloader(self) -> ProteinDataLoader:
             num_workers=self.num_workers,
         )
 
-    def test_dataloader(self) -> ProteinDataLoader:
-        self.test_ds = self.test_dataset()
+    def test_dataloader(
+        self, split: Literal["fold", "family", "superfamily"]
+    ) -> ProteinDataLoader:
+        self.test_ds = self.test_dataset(split)
         return ProteinDataLoader(
             self.test_ds,
             batch_size=self.batch_size,
@@ -211,16 +218,16 @@ def test_dataloader(self) -> ProteinDataLoader:
             num_workers=self.num_workers,
         )
 
-    def get_test_loader(self, split: str) -> ProteinDataLoader:
-        log.info(f"Getting test loader: {split}")
-        test_ds = self._get_dataset(f"test_{split}")
-        return ProteinDataLoader(
-            test_ds,
-            batch_size=self.batch_size,
-            shuffle=False,
-            pin_memory=self.pin_memory,
-            num_workers=self.num_workers,
-        )
+    # def get_test_loader(self, split: str) -> ProteinDataLoader:
+    #     log.info(f"Getting test loader: {split}")
+    #     test_ds = self._get_dataset(f"test_{split}")
+    #     return ProteinDataLoader(
+    #         test_ds,
+    #         batch_size=self.batch_size,
+    #         shuffle=False,
+    #         pin_memory=self.pin_memory,
+    #         num_workers=self.num_workers,
+    #     )
 
     def parse_dataset(self, split: str) -> pd.DataFrame:
         """
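Note: _get_dataset() now memoises each split on the instance (the hasattr/setattr pair), and test_dataset()/test_dataloader() take the split name explicitly, superseding the commented-out get_test_loader helper. A standalone sketch of the caching pattern (SplitCache and its stand-in dataset are illustrative):

    # Illustrative sketch of the per-split memoisation used by _get_dataset.
    class SplitCache:
        def _get_dataset(self, split: str) -> list:
            if hasattr(self, f"{split}_ds"):      # reuse a previously built split
                return getattr(self, f"{split}_ds")
            ds = [f"protein-{split}"]             # stand-in for ProteinDataset(...)
            setattr(self, f"{split}_ds", ds)      # cache for later calls
            return ds

    cache = SplitCache()
    assert cache._get_dataset("test_fold") is cache._get_dataset("test_fold")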
113 changes: 89 additions & 24 deletions proteinworkshop/datasets/go.py
@@ -2,7 +2,7 @@
 import zipfile
 from functools import lru_cache
 from pathlib import Path
-from typing import Callable, Dict, Iterable, Literal, Optional
+from typing import Callable, Dict, Iterable, List, Literal, Optional
 
 import omegaconf
 import pandas as pd
@@ -70,6 +70,14 @@ def __init__(
 
         self.shuffle_labels = shuffle_labels
 
+        self.test_seq_similarity_cutoffs: List[float] = [
+            0.3,
+            0.4,
+            0.5,
+            0.7,
+            0.95,
+        ]
+
         if transforms is not None:
             self.transform = self.compose_transforms(
                 omegaconf.OmegaConf.to_container(transforms, resolve=True)
@@ -79,14 +87,19 @@ def __init__(
 
         self.train_fname = self.data_dir / "nrPDB-GO_train.txt"
         self.val_fname = self.data_dir / "nrPDB-GO_valid.txt"
-        self.test_fname = self.data_dir / "nrPDB-GO_test.txt"
+        self.test_fname = self.data_dir / "nrPDB-GO_test.csv"
         self.label_fname = self.data_dir / "nrPDB-GO_annot.tsv"
         self.url = "https://zenodo.org/record/6622158/files/GeneOntology.zip"
 
         log.info(
             f"Setting up Gene Ontology dataset. Fraction {self.dataset_fraction}"
         )
 
+    @property
+    def test_dataset_names(self) -> List[str]:
+        """Provides a list of test set split names."""
+        return ["test_0.3", "test_0.4", "test_0.5", "test_0.7", "test_0.95"]
+
     @lru_cache
     def parse_labels(self) -> Dict[str, torch.Tensor]:
         """
@@ -129,12 +142,38 @@ def parse_labels(self) -> Dict[str, torch.Tensor]:
         log.info(f"Encoded {len(labels)} labels for task {self.split}.")
         return labels
 
+    # def get_test_loader(
+    #     self,
+    #     split: Literal["test_0.3", "test_0.4", "test_0.5", "test_0.7", "test_0.95"],
+    # ) -> ProteinDataLoader:
+    #     log.info(f"Getting test loader: {split}")
+    #     test_ds = self._get_dataset(split)
+    #     return ProteinDataLoader(
+    #         test_ds,
+    #         batch_size=self.batch_size,
+    #         shuffle=False,
+    #         pin_memory=self.pin_memory,
+    #         num_workers=self.num_workers,
+    #     )
+
     def _get_dataset(
-        self, split: Literal["training", "validation", "testing"]
+        self,
+        split: Literal[
+            "training",
+            "validation",
+            "test_0.3",
+            "test_0.4",
+            "test_0.5",
+            "test_0.7",
+            "test_0.95",
+        ],
     ) -> ProteinDataset:
+        if hasattr(self, f"{split}_ds"):
+            return getattr(self, f"{split}_ds")
+
         df = self.parse_dataset(split)
         log.info("Initialising Graphein dataset...")
-        return ProteinDataset(
+        ds = ProteinDataset(
             root=str(self.data_dir),
             pdb_dir=str(self.pdb_dir),
             pdb_codes=list(df.pdb),
@@ -147,15 +186,22 @@ def _get_dataset(
             format=self.format,
             in_memory=self.in_memory,
         )
+        setattr(self, f"{split}_ds", ds)
+        return ds
 
     def train_dataset(self) -> ProteinDataset:
         return self._get_dataset("training")
 
     def val_dataset(self) -> ProteinDataset:
         return self._get_dataset("validation")
 
-    def test_dataset(self) -> ProteinDataset:
-        return self._get_dataset("testing")
+    def test_dataset(
+        self,
+        split: Literal[
+            "test_0.3", "test_0.4", "test_0.5", "test_0.7", "test_0.95"
+        ],
+    ) -> ProteinDataset:
+        return self._get_dataset(split)
 
     def train_dataloader(self) -> ProteinDataLoader:
         return ProteinDataLoader(
@@ -175,9 +221,14 @@ def val_dataloader(self) -> ProteinDataLoader:
             num_workers=self.num_workers,
         )
 
-    def test_dataloader(self) -> ProteinDataLoader:
+    def test_dataloader(
+        self,
+        split: Literal[
+            "test_0.3", "test_0.4", "test_0.5", "test_0.7", "test_0.95"
+        ],
+    ) -> ProteinDataLoader:
         return ProteinDataLoader(
-            self.test_dataset(),
+            self.test_dataset(split),
             batch_size=self.batch_size,
             shuffle=False,
             pin_memory=self.pin_memory,
@@ -205,7 +256,16 @@ def exclude_pdbs(self):
         pass
 
     def parse_dataset(
-        self, split: Literal["training", "validation", "testing"]
+        self,
+        split: Literal[
+            "training",
+            "validation",
+            "test_0.3",
+            "test_0.4",
+            "test_0.5",
+            "test_0.7",
+            "test_0.95",
+        ],
     ) -> pd.DataFrame:
         # sourcery skip: remove-unnecessary-else, swap-if-else-branches, switch
         """
@@ -221,8 +281,11 @@
             data = data.sample(frac=self.dataset_fraction)
         elif split == "validation":
             data = pd.read_csv(self.val_fname, sep="\t", header=None)
-        elif split == "testing":
-            data = pd.read_csv(self.test_fname, sep="\t", header=None)
+        elif split.startswith("test_"):
+            cutoff = int(float(split.split("_")[1]) * 100)
+            data = pd.read_csv(self.test_fname, sep=",")
+            data = data.loc[data[f"<{cutoff}%"] == 1]
+            data = pd.DataFrame(data["PDB-chain"].values)
         else:
             raise ValueError(f"Unknown split: {split}")
 
@@ -304,16 +367,18 @@ def __call__(self, data: Protein) -> Protein:
     cfg.datamodule.transforms = []
     log.info("Loaded config")
 
-    ds = hydra.utils.instantiate(cfg)
-    print(ds)
-    # labels = ds["datamodule"].parse_labels()
-    ds.datamodule.setup()
-    dl = ds["datamodule"].train_dataloader()
-    for batch in dl:
-        print(batch)
-    dl = ds["datamodule"].val_dataloader()
-    for batch in dl:
-        print(batch)
-    dl = ds["datamodule"].test_dataloader()
-    for batch in dl:
-        print(batch)
+    ds = hydra.utils.instantiate(cfg)["datamodule"]
+    ds.parse_dataset("test_0.3")
+    ds.parse_dataset("test_0.95")
+    # print(ds)
+    ## labels = ds["datamodule"].parse_labels()
+    # ds.datamodule.setup()
+    # dl = ds["datamodule"].train_dataloader()
+    # for batch in dl:
+    #     print(batch)
+    # dl = ds["datamodule"].val_dataloader()
+    # for batch in dl:
+    #     print(batch)
+    # dl = ds["datamodule"].test_dataloader()
+    # for batch in dl:
+    #     print(batch)
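Note: the GO test file moves from a tab-separated .txt to nrPDB-GO_test.csv, which carries one indicator column per sequence-identity cutoff; parse_dataset("test_0.3") keeps only chains sharing less than 30% sequence identity with the training set. A self-contained sketch of that filtering with a toy frame (the column names follow the diff; the rows and values are invented for illustration):

    import pandas as pd

    # Toy frame mirroring the nrPDB-GO_test.csv columns used above (values invented).
    test_df = pd.DataFrame(
        {"PDB-chain": ["1ABC-A", "2XYZ-B", "3DEF-C"], "<30%": [1, 0, 1], "<95%": [1, 1, 1]}
    )
    split = "test_0.3"
    cutoff = int(float(split.split("_")[1]) * 100)  # "test_0.3" -> 30
    subset = test_df.loc[test_df[f"<{cutoff}%"] == 1]
    ids = pd.DataFrame(subset["PDB-chain"].values)  # single-column frame of chain ids
    print(ids)  # keeps rows 1ABC-A and 3DEF-C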
15 changes: 6 additions & 9 deletions proteinworkshop/train.py
@@ -144,7 +144,7 @@ def train_model(
 
     log.info("Initializing lazy layers...")
     with torch.no_grad():
-        datamodule.setup()  # type: ignore
+        datamodule.setup(stage="lazy_init")  # type: ignore
         batch = next(iter(datamodule.val_dataloader()))
         log.info(f"Unfeaturized batch: {batch}")
         batch = model.featurise(batch)
@@ -185,16 +185,13 @@ def train_model(
 
     if cfg.get("test"):
         log.info("Starting testing!")
-        # Run test on all splits if using fold_classification dataset
-        if (
-            cfg.dataset.datamodule._target_
-            == "proteinworkshop.datasets.fold_classification.FoldClassificationDataModule"
-        ):
-            splits = ["fold", "family", "superfamily"]
+        if hasattr(datamodule, "test_dataset_names"):
+            splits = datamodule.test_dataset_names
             wandb_logger = copy.deepcopy(trainer.logger)
-            for split in splits:
-                dataloader = datamodule.get_test_loader(split)
+            for i, split in enumerate(splits):
+                dataloader = datamodule.test_dataloader(split)
                 trainer.logger = False
+                log.info(f"Testing on {split} ({i+1} / {len(splits)})...")
                 results = trainer.test(
                     model=model, dataloaders=dataloader, ckpt_path="best"
                 )[0]
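Note: the test phase now dispatches on the test_dataset_names protocol instead of hard-coding the fold-classification target, so any datamodule exposing the property is evaluated once per split. A simplified sketch of the dispatch (run_tests, test_fn, and the single-split fallback are illustrative; the real loop drives trainer.test with per-split dataloaders):

    # Simplified sketch of the per-split test dispatch (names hypothetical).
    def run_tests(datamodule, test_fn) -> dict:
        # Fall back to a single unnamed test set when the protocol is absent.
        splits = getattr(datamodule, "test_dataset_names", ["test"])
        results = {}
        for i, split in enumerate(splits):
            print(f"Testing on {split} ({i + 1} / {len(splits)})...")
            results[split] = test_fn(split)
        return results

    print(run_tests(object(), lambda s: {"acc": 0.0}))  # {'test': {'acc': 0.0}}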
