Skip to content

Commit

Permalink
added boltz query classes
Browse files Browse the repository at this point in the history
  • Loading branch information
Valentin Zulkower committed Dec 16, 2024
1 parent b3f6519 commit fc6b476
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 9 deletions.
1 change: 1 addition & 0 deletions ginkgo_ai_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
PromoterActivityQuery,
DiffusionMaskedQuery,
DiffusionMaskedResponse,
BoltzStructurePredictionQuery,
)

__all__ = [
Expand Down
153 changes: 145 additions & 8 deletions ginkgo_ai_client/queries.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
"""Classes to define queries to the Ginkgo AI API."""

from typing import Dict, Optional, Any, List
from typing import Dict, Optional, Any, List, Literal, Union
from abc import ABC, abstractmethod
from pathlib import Path
import json
import yaml
import tempfile

import pydantic
import requests

from ginkgo_ai_client.utils import fasta_sequence_iterator, IteratorWithLength
from ginkgo_ai_client.utils import (
fasta_sequence_iterator,
IteratorWithLength,
cif_to_pdb,
)

## ---- Base classes --------------------------------------------------------------


class QueryBase(pydantic.BaseModel, ABC):
Expand Down Expand Up @@ -42,6 +52,8 @@ def write_to_jsonl(self, path: str):
f.write(self.model_dump_json() + "\n")


## ---- MASKEDLM AND EMBEDDINGS ------------------------------------------------------

_maskedlm_models_properties = {
"ginkgo-aa0-650M": "protein",
"esm2-650M": "protein",
Expand Down Expand Up @@ -71,9 +83,10 @@ def _validate_model_and_sequence(model, sequence: str, allow_masks=False):
f"Model {model} requires the sequence to only contain ATGC characters"
)
elif sequence_type == "nucleotide":
if not set(sequence.lower()).issubset({"a", "t", "g", "c", "r", "y", "s", "w", "k", "m", "b", "d", "h", "v", "n"}):
if not set(sequence.lower()).issubset(set("atgcrsywkmdbhvn")):
raise ValueError(
f"Model {model} requires the sequence to only contain valid IUPAC nucleotide characters"
f"Model {model} requires the sequence to only contain valid ATGC or "
f"IUPAC nucleotide characters"
)
elif sequence_type == "protein":
if not set(sequence).issubset(set("ACDEFGHIKLMNPQRSTVWY")):
Expand Down Expand Up @@ -228,7 +241,7 @@ def check_model_and_sequence_compatibility(cls, query):
cls.__doc__ += auto_doc_str[:1]


### PROMOTER ACTIVITY QUERIES ###
## ---- PROMOTER ACTIVITY QUERIES ---------------------------------------------------


class PromoterActivityResponse(ResponseBase):
Expand Down Expand Up @@ -384,6 +397,9 @@ def list_with_promoter_from_fasta(
return list(iterator)


## ---- DIFFUSION QUERIES ---------------------------------------------------------


class DiffusionMaskedResponse(ResponseBase):
"""A response to a DiffusionMaskedQuery, with attributes `sequence` (the predicted
sequence) and `query_name` (the original query's name).
Expand Down Expand Up @@ -470,10 +486,131 @@ def validate_query(cls, query):
raise ValueError("temperature must be between 0 and 1")
# Validate decoding_order_strategy
if query.decoding_order_strategy not in ["max_prob", "entropy"]:
raise ValueError(
"decoding_order_strategy must be 'max_prob' or 'entropy'"
)
raise ValueError("decoding_order_strategy must be 'max_prob' or 'entropy'")
# Validate unmaskings_per_step
if not 1 <= query.unmaskings_per_step <= 1000:
raise ValueError("unmaskings_per_step must be between 1 and 1000")
return query


## ---- STRUCTURE PREDICTION QUERIES ------------------------------------------------


class _Protein(pydantic.BaseModel):
    """A protein entry of a Boltz structure prediction query.

    Attributes
    ----------
    id: Union[List[str], str]
        Identifier of the entry, either a single string (e.g. ``"A"``) or a list
        of strings.
    sequence: str
        The protein sequence. It is upper-cased during validation and may only
        contain the letters ``LAGVSERTIDPKQNFYMHWCXBUZO``.
    """

    id: Union[List[str], str]
    sequence: str

    # `pydantic.validator` is the deprecated pydantic-v1 decorator; the rest of this
    # module relies on the v2 API (`model_dump`, `model_dump_json`), so use the v2
    # `field_validator`, which must wrap an explicit classmethod.
    @pydantic.field_validator("sequence")
    @classmethod
    def validate_sequence(cls, sequence):
        """Upper-case the sequence and reject characters outside the alphabet.

        Raises
        ------
        ValueError
            If the sequence contains any character that is not one of
            ``LAGVSERTIDPKQNFYMHWCXBUZO`` (after upper-casing). The message lists
            the offending characters, sorted and de-duplicated.
        """
        sequence = sequence.upper()
        invalid_chars = set(sequence).difference("LAGVSERTIDPKQNFYMHWCXBUZO")
        if invalid_chars:
            invalid_chars_str = ", ".join(sorted(invalid_chars))
            raise ValueError(
                f"Sequence contains invalid characters: {invalid_chars_str}"
            )
        return sequence


class _CCD(pydantic.BaseModel):
    """An entry of a Boltz structure prediction query described by a ``ccd`` string.

    The model declares no validators of its own: both fields are only type-checked
    by pydantic.
    """

    # Identifier of the entry: a single string or a list of strings.
    id: Union[List[str], str]
    # NOTE(review): presumably a Chemical Component Dictionary code for a ligand
    # -- confirm against the Boltz input specification.
    ccd: str


class _Smiles(pydantic.BaseModel):
    """An entry of a Boltz structure prediction query described by a ``smiles`` string.

    The model declares no validators of its own: both fields are only type-checked
    by pydantic (the SMILES syntax itself is not verified here).
    """

    # Identifier of the entry: a single string or a list of strings.
    id: Union[List[str], str]
    # NOTE(review): presumably the SMILES representation of a ligand -- confirm
    # against the Boltz input specification.
    smiles: str


class BoltzStructurePredictionResponse(ResponseBase):
    """A response to a BoltzStructurePredictionQuery.

    Attributes
    ----------
    cif_file_url: str
        The URL of the cif file.
    confidence_data: Dict[str, Any]
        The confidence data.
    query_name: Optional[str] = None
        The name of the query. It will appear in the API response and can be used to
        handle exceptions.

    Examples
    --------
    .. code:: python

        response = BoltzStructurePredictionResponse(
            cif_file_url="https://example.com/structure.cif",
            confidence_data={"confidence": 0.95},
            query_name="my_query",
        )
        response.download_structure("structure.cif")  # or...
        response.download_structure("structure.pdb")
    """

    cif_file_url: str
    confidence_data: Dict[str, Any]
    query_name: Optional[str] = None

    def download_structure(self, path: str):
        """Download the structure from the URL and save it to a file.

        Parameters
        ----------
        path: str
            Destination file. If it ends with ``.pdb`` the cif file is first
            downloaded to a temporary directory and then converted to PDB with
            ``cif_to_pdb``; otherwise the downloaded cif content is written as-is.

        Raises
        ------
        requests.HTTPError
            If the server answers with an error status. Without this check the
            body of the error page would be saved as if it were a structure.
        """
        path = Path(path)
        if str(path).endswith(".pdb"):
            with tempfile.TemporaryDirectory() as temp_dir:
                cif_path = Path(temp_dir) / "temp.cif"
                self.download_structure(cif_path)
                cif_to_pdb(cif_path, path)
            return

        response = requests.get(self.cif_file_url)
        response.raise_for_status()
        # Write the raw bytes: this avoids guessing the response encoding and then
        # re-encoding it with the platform's locale (and newline translation).
        path.write_bytes(response.content)


class BoltzStructurePredictionQuery(QueryBase):
    """A query to predict the structure of a protein using the Boltz model.

    Parameters
    ----------
    sequences: List[Dict[Literal["protein", "ligand"], Union[_Protein, _CCD, _Smiles]]]
        The sequences to predict the structure for
    model: Literal["boltz"] = "boltz"
        The model to use for the inference (only Boltz(1) is supported for now).
    query_name: Optional[str] = None
        The name of the query. It will appear in the API response and can be used to
        handle exceptions.
    """

    sequences: List[Dict[Literal["protein", "ligand"], Union[_Protein, _CCD, _Smiles]]]
    model: Literal["boltz"] = "boltz"
    query_name: Optional[str] = None

    def to_request_params(self) -> Dict:
        """Return the request payload: model name, transform and serialized input.

        The ``text`` entry is the JSON-mode dump of the query without the ``model``
        and ``query_name`` fields, i.e. a dict holding only ``sequences``.
        """
        return {
            # Read the validated field rather than repeating the literal.
            "model": self.model,
            "transforms": [{"type": "INFER_STRUCTURE"}],
            # pydantic documents `exclude` as a set (or mapping), not a list.
            "text": self.model_dump(exclude={"model", "query_name"}, mode="json"),
        }

    def parse_response(self, results: Dict) -> BoltzStructurePredictionResponse:
        """Build the response object from the ``cif_file_url`` and
        ``confidence_data`` entries of ``results``, carrying over the query name.
        """
        return BoltzStructurePredictionResponse(
            cif_file_url=results["cif_file_url"],
            confidence_data=results["confidence_data"],
            query_name=self.query_name,
        )

    @classmethod
    def from_yaml_file(cls, path, query_name: Optional[str] = "auto"):
        """Create a query from a YAML file that has a top-level ``sequences`` key.

        Parameters
        ----------
        path
            Path to the YAML file.
        query_name: Optional[str] = "auto"
            Name given to the query. With the default ``"auto"`` the file name
            (``Path(path).name``, extension included) is used.
        """
        path = Path(path)
        if query_name == "auto":
            query_name = path.name
        # Open in binary mode so PyYAML applies its own encoding detection (UTF-8
        # unless a BOM says otherwise) instead of the platform's locale encoding.
        with open(path, "rb") as f:
            data = yaml.safe_load(f)
        return cls(sequences=data["sequences"], query_name=query_name)

    @classmethod
    def from_protein_sequence(cls, sequence: str, query_name: Optional[str] = None):
        """Create a query for a single protein sequence, registered with id ``"A"``."""
        return cls(
            sequences=[{"protein": {"id": "A", "sequence": sequence}}],
            query_name=query_name,
        )
14 changes: 13 additions & 1 deletion ginkgo_ai_client/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""Utility functions for building queries, batches, etc."""

from Bio import SeqIO
from pathlib import Path
import gzip
from typing import Iterator, Union


from Bio import SeqIO
from Bio.PDB import MMCIFParser, PDBIO


class IteratorWithLength(Iterator):
"""An iterator that also has a length attribute, which will produce useful
progress bars with % progress and estimated time of arrival.
Expand Down Expand Up @@ -50,3 +53,12 @@ def fasta_sequence_iterator(fasta_path: str):
# compute the number of sequences in the fasta file by counting ">"
length = _fast_fasta_sequence_count(fasta_path)
return IteratorWithLength(SeqIO.parse(fasta_path, "fasta"), length)


def cif_to_pdb(cif_path: Union[str, Path], pdb_path: Union[str, Path]):
    """Convert a cif file to a pdb file.

    Parameters
    ----------
    cif_path: Union[str, Path]
        Path of the mmCIF file to read. The file stem is used as the structure id.
    pdb_path: Union[str, Path]
        Path of the PDB file to write.
    """
    # Normalize first: the signature accepts plain strings, which have no `.stem`.
    cif_path = Path(cif_path)
    parser = MMCIFParser(QUIET=True)
    structure = parser.get_structure(cif_path.stem, str(cif_path))
    writer = PDBIO()
    writer.set_structure(structure)
    writer.save(str(pdb_path))

0 comments on commit fc6b476

Please sign in to comment.