diff --git a/ginkgo_ai_client/__init__.py b/ginkgo_ai_client/__init__.py index e7880dd..6b54845 100644 --- a/ginkgo_ai_client/__init__.py +++ b/ginkgo_ai_client/__init__.py @@ -8,6 +8,7 @@ PromoterActivityQuery, DiffusionMaskedQuery, DiffusionMaskedResponse, + BoltzStructurePredictionQuery, ) __all__ = [ diff --git a/ginkgo_ai_client/queries.py b/ginkgo_ai_client/queries.py index be4d5cc..2f94a98 100644 --- a/ginkgo_ai_client/queries.py +++ b/ginkgo_ai_client/queries.py @@ -1,12 +1,22 @@ """Classes to define queries to the Ginkgo AI API.""" -from typing import Dict, Optional, Any, List +from typing import Dict, Optional, Any, List, Literal, Union from abc import ABC, abstractmethod +from pathlib import Path import json +import yaml +import tempfile import pydantic +import requests -from ginkgo_ai_client.utils import fasta_sequence_iterator, IteratorWithLength +from ginkgo_ai_client.utils import ( + fasta_sequence_iterator, + IteratorWithLength, + cif_to_pdb, +) + +## ---- Base classes -------------------------------------------------------------- class QueryBase(pydantic.BaseModel, ABC): @@ -42,6 +52,8 @@ def write_to_jsonl(self, path: str): f.write(self.model_dump_json() + "\n") +## ---- MASKEDLM AND EMBEDDINGS ------------------------------------------------------ + _maskedlm_models_properties = { "ginkgo-aa0-650M": "protein", "esm2-650M": "protein", @@ -71,9 +83,10 @@ def _validate_model_and_sequence(model, sequence: str, allow_masks=False): f"Model {model} requires the sequence to only contain ATGC characters" ) elif sequence_type == "nucleotide": - if not set(sequence.lower()).issubset({"a", "t", "g", "c", "r", "y", "s", "w", "k", "m", "b", "d", "h", "v", "n"}): + if not set(sequence.lower()).issubset(set("atgcrsywkmdbhvn")): raise ValueError( - f"Model {model} requires the sequence to only contain valid IUPAC nucleotide characters" + f"Model {model} requires the sequence to only contain valid ATGC or " + f"IUPAC nucleotide characters" ) elif sequence_type == "protein": if not set(sequence).issubset(set("ACDEFGHIKLMNPQRSTVWY")): @@ -228,7 +241,7 @@ def check_model_and_sequence_compatibility(cls, query): cls.__doc__ += auto_doc_str[:1] -### PROMOTER ACTIVITY QUERIES ### +## ---- PROMOTER ACTIVITY QUERIES --------------------------------------------------- class PromoterActivityResponse(ResponseBase): @@ -384,6 +397,9 @@ def list_with_promoter_from_fasta( return list(iterator) +## ---- DIFFUSION QUERIES --------------------------------------------------------- + + class DiffusionMaskedResponse(ResponseBase): """A response to a DiffusionMaskedQuery, with attributes `sequence` (the predicted sequence) and `query_name` (the original query's name). @@ -470,10 +486,131 @@ def validate_query(cls, query): raise ValueError("temperature must be between 0 and 1") # Validate decoding_order_strategy if query.decoding_order_strategy not in ["max_prob", "entropy"]: - raise ValueError( - "decoding_order_strategy must be 'max_prob' or 'entropy'" - ) + raise ValueError("decoding_order_strategy must be 'max_prob' or 'entropy'") # Validate unmaskings_per_step if not 1 <= query.unmaskings_per_step <= 1000: raise ValueError("unmaskings_per_step must be between 1 and 1000") return query + + +## ---- STRUCTURE PREDICTION QUERIES ------------------------------------------------ + + +class _Protein(pydantic.BaseModel): + id: Union[List[str], str] + sequence: str + + @pydantic.validator("sequence") + def validate_sequence(cls, sequence): + sequence = sequence.upper() + invalid_chars = [c for c in sequence if c not in "LAGVSERTIDPKQNFYMHWCXBUZO"] + if len(invalid_chars) > 0: + invalid_chars_str = ", ".join(sorted(set(invalid_chars))) + raise ValueError( + f"Sequence contains invalid characters: {invalid_chars_str}" + ) + return sequence + + +class _CCD(pydantic.BaseModel): + id: Union[List[str], str] + ccd: str + + +class _Smiles(pydantic.BaseModel): + id: Union[List[str], str] + smiles: str + + +class BoltzStructurePredictionResponse(ResponseBase): + """A response to a BoltzStructurePredictionQuery + + Attributes + ---------- + cif_file_url: str + The URL of the cif file. + confidence_data: Dict[str, Any] + The confidence data. + query_name: Optional[str] = None + The name of the query. It will appear in the API response and can be used to + handle exceptions. + + Examples + -------- + + .. code:: python + + response = BoltzStructurePredictionResponse( + cif_file_url="https://example.com/structure.cif", + confidence_data={"confidence": 0.95}, + query_name="my_query", + ) + response.download_structure("structure.cif") # or... + response.download_structure("structure.pdb") + """ + + cif_file_url: str + confidence_data: Dict[str, Any] + query_name: Optional[str] = None + + def download_structure(self, path: str): + """Download the structure from the URL and save it to a file.""" + path = Path(path) + if str(path).endswith(".pdb"): + with tempfile.TemporaryDirectory() as temp_dir: + cif_path = Path(temp_dir) / "temp.cif" + self.download_structure(cif_path) + cif_to_pdb(cif_path, path) + else: + response = requests.get(self.cif_file_url) + with open(path, "w") as f: + f.write(response.text) + + +class BoltzStructurePredictionQuery(QueryBase): + """A query to predict the structure of a protein using the Boltz model. + + Parameters + ---------- + sequences: List[Dict[Literal["protein", "ligand"], Union[_Protein, _CCD, _Smiles]]] + The sequences to predict the structure for + model: Literal["boltz"] = "boltz" + The model to use for the inference (only Boltz(1) is supported for now). + query_name: Optional[str] = None + The name of the query. It will appear in the API response and can be used to + handle exceptions. + """ + + sequences: List[Dict[Literal["protein", "ligand"], Union[_Protein, _CCD, _Smiles]]] + model: Literal["boltz"] = "boltz" + query_name: Optional[str] = None + + def to_request_params(self) -> Dict: + return { + "model": "boltz", + "transforms": [{"type": "INFER_STRUCTURE"}], + "text": self.model_dump(exclude=["model", "query_name"], mode="json"), + } + + def parse_response(self, results: Dict) -> BoltzStructurePredictionResponse: + return BoltzStructurePredictionResponse( + cif_file_url=results["cif_file_url"], + confidence_data=results["confidence_data"], + query_name=self.query_name, + ) + + @classmethod + def from_yaml_file(cls, path, query_name: Optional[str] = "auto"): + path = Path(path) + if query_name == "auto": + query_name = path.name + with open(path, "r") as f: + data = yaml.load(f, yaml.SafeLoader) + return cls(sequences=data["sequences"], query_name=query_name) + + @classmethod + def from_protein_sequence(cls, sequence: str, query_name: Optional[str] = None): + return cls( + sequences=[{"protein": {"id": "A", "sequence": sequence}}], + query_name=query_name, + ) diff --git a/ginkgo_ai_client/utils.py b/ginkgo_ai_client/utils.py index e2915e2..2794bb3 100644 --- a/ginkgo_ai_client/utils.py +++ b/ginkgo_ai_client/utils.py @@ -1,11 +1,14 @@ """Utility functions for building queries, batches, etc.""" -from Bio import SeqIO from pathlib import Path import gzip from typing import Iterator, Union +from Bio import SeqIO +from Bio.PDB import MMCIFParser, PDBIO + + class IteratorWithLength(Iterator): """An iterator that also has a length attribute, which will produce useful progress bars with % progress and estimated time of arrival. @@ -50,3 +53,12 @@ def fasta_sequence_iterator(fasta_path: str): # compute the number of sequences in the fasta file by counting ">" length = _fast_fasta_sequence_count(fasta_path) return IteratorWithLength(SeqIO.parse(fasta_path, "fasta"), length) + + +def cif_to_pdb(cif_path: Union[str, Path], pdb_path: Union[str, Path]): + """Convert a cif file to a pdb file.""" + parser = MMCIFParser(QUIET=True) + structure = parser.get_structure(cif_path.stem, cif_path) + io = PDBIO() + io.set_structure(structure) + io.save(str(pdb_path))