From db62343f3665487c8a3d2011fb5590c9f9af9211 Mon Sep 17 00:00:00 2001 From: Poko18 Date: Sat, 16 Mar 2024 22:13:58 +0100 Subject: [PATCH 1/3] Refactor model.py: Add initial guess for all_atom_positions --- alphafold/model/model.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/alphafold/model/model.py b/alphafold/model/model.py index 88e90f1f4..4d641168d 100644 --- a/alphafold/model/model.py +++ b/alphafold/model/model.py @@ -147,12 +147,16 @@ def predict(self, num_ensemble = self.config.data.eval.num_ensemble L = aatype.shape[1] - # initialize - zeros = lambda shape: np.zeros(shape, dtype=np.float16) prev = {'prev_msa_first_row': zeros([L,256]), - 'prev_pair': zeros([L,L,128]), - 'prev_pos': zeros([L,37,3])} + 'prev_pair': zeros([L,L,128])} + + # initial guess + if "all_atom_positions" in feat: + logging.info("INFO: using provided all_atom_positions as initial guess") + prev["prev_pos"] = feat["all_atom_positions"] + else: + prev["prev_pos"] = np.zeros([L,37,3]) def run(key, feat, prev): def _jnp_to_np(x): @@ -199,4 +203,4 @@ def _jnp_to_np(x): break logging.info('Output shape was %s', tree.map_structure(lambda x: x.shape, result)) - return result, r \ No newline at end of file + return result, r From ce488f4d903ad7d117ddf15aa1665a6321098cb3 Mon Sep 17 00:00:00 2001 From: Poko18 Date: Mon, 18 Mar 2024 16:50:25 +0100 Subject: [PATCH 2/3] Add saving Protein to mmCIF file and reading Protein from mmCIF (from AF2 repo - 6c4d833fbd1c6b8e7c9a21dae5d4ada2ce777e10) --- alphafold/common/mmcif_metadata.py | 213 +++++++++++++++++ alphafold/common/protein.py | 327 +++++++++++++++++++++++++- alphafold/common/residue_constants.py | 32 ++- run_alphafold.py | 96 +++++++- 4 files changed, 652 insertions(+), 16 deletions(-) create mode 100644 alphafold/common/mmcif_metadata.py diff --git a/alphafold/common/mmcif_metadata.py b/alphafold/common/mmcif_metadata.py new file mode 100644 index 000000000..b5b725d5f --- /dev/null +++ 
b/alphafold/common/mmcif_metadata.py @@ -0,0 +1,213 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""mmCIF metadata.""" + +from typing import Mapping, Sequence +from alphafold import version +import numpy as np + + +_DISCLAIMER = """ALPHAFOLD DATA, COPYRIGHT (2021) DEEPMIND TECHNOLOGIES LIMITED. +THE INFORMATION PROVIDED IS THEORETICAL MODELLING ONLY AND CAUTION SHOULD BE +EXERCISED IN ITS USE. IT IS PROVIDED "AS-IS" WITHOUT ANY WARRANTY OF ANY KIND, +WHETHER EXPRESSED OR IMPLIED. NO WARRANTY IS GIVEN THAT USE OF THE INFORMATION +SHALL NOT INFRINGE THE RIGHTS OF ANY THIRD PARTY. DISCLAIMER: THE INFORMATION IS +NOT INTENDED TO BE A SUBSTITUTE FOR PROFESSIONAL MEDICAL ADVICE, DIAGNOSIS, OR +TREATMENT, AND DOES NOT CONSTITUTE MEDICAL OR OTHER PROFESSIONAL ADVICE. IT IS +AVAILABLE FOR ACADEMIC AND COMMERCIAL PURPOSES, UNDER CC-BY 4.0 LICENCE.""" + +# Authors of the Nature methods paper we reference in the mmCIF. +_MMCIF_PAPER_AUTHORS = ( + 'Jumper, John', + 'Evans, Richard', + 'Pritzel, Alexander', + 'Green, Tim', + 'Figurnov, Michael', + 'Ronneberger, Olaf', + 'Tunyasuvunakool, Kathryn', + 'Bates, Russ', + 'Zidek, Augustin', + 'Potapenko, Anna', + 'Bridgland, Alex', + 'Meyer, Clemens', + 'Kohl, Simon A. 
A.', + 'Ballard, Andrew J.', + 'Cowie, Andrew', + 'Romera-Paredes, Bernardino', + 'Nikolov, Stanislav', + 'Jain, Rishub', + 'Adler, Jonas', + 'Back, Trevor', + 'Petersen, Stig', + 'Reiman, David', + 'Clancy, Ellen', + 'Zielinski, Michal', + 'Steinegger, Martin', + 'Pacholska, Michalina', + 'Berghammer, Tamas', + 'Silver, David', + 'Vinyals, Oriol', + 'Senior, Andrew W.', + 'Kavukcuoglu, Koray', + 'Kohli, Pushmeet', + 'Hassabis, Demis', +) + +# Authors of the mmCIF - we set them to be equal to the authors of the paper. +_MMCIF_AUTHORS = _MMCIF_PAPER_AUTHORS + + +def add_metadata_to_mmcif( + old_cif: Mapping[str, Sequence[str]], model_type: str +) -> Mapping[str, Sequence[str]]: + """Adds AlphaFold metadata in the given mmCIF.""" + cif = {} + + # ModelCIF conformation dictionary. + cif['_audit_conform.dict_name'] = ['mmcif_ma.dic'] + cif['_audit_conform.dict_version'] = ['1.3.9'] + cif['_audit_conform.dict_location'] = [ + 'https://raw.githubusercontent.com/ihmwg/ModelCIF/master/dist/' + 'mmcif_ma.dic' + ] + + # License and disclaimer. + cif['_pdbx_data_usage.id'] = ['1', '2'] + cif['_pdbx_data_usage.type'] = ['license', 'disclaimer'] + cif['_pdbx_data_usage.details'] = [ + 'Data in this file is available under a CC-BY-4.0 license.', + _DISCLAIMER, + ] + cif['_pdbx_data_usage.url'] = [ + 'https://creativecommons.org/licenses/by/4.0/', + '?', + ] + cif['_pdbx_data_usage.name'] = ['CC-BY-4.0', '?'] + + # Structure author details. + cif['_audit_author.name'] = [] + cif['_audit_author.pdbx_ordinal'] = [] + for author_index, author_name in enumerate(_MMCIF_AUTHORS, start=1): + cif['_audit_author.name'].append(author_name) + cif['_audit_author.pdbx_ordinal'].append(str(author_index)) + + # Paper author details. 
+ cif['_citation_author.citation_id'] = [] + cif['_citation_author.name'] = [] + cif['_citation_author.ordinal'] = [] + for author_index, author_name in enumerate(_MMCIF_PAPER_AUTHORS, start=1): + cif['_citation_author.citation_id'].append('primary') + cif['_citation_author.name'].append(author_name) + cif['_citation_author.ordinal'].append(str(author_index)) + + # Paper citation details. + cif['_citation.id'] = ['primary'] + cif['_citation.title'] = [ + 'Highly accurate protein structure prediction with AlphaFold' + ] + cif['_citation.journal_full'] = ['Nature'] + cif['_citation.journal_volume'] = ['596'] + cif['_citation.page_first'] = ['583'] + cif['_citation.page_last'] = ['589'] + cif['_citation.year'] = ['2021'] + cif['_citation.journal_id_ASTM'] = ['NATUAS'] + cif['_citation.country'] = ['UK'] + cif['_citation.journal_id_ISSN'] = ['0028-0836'] + cif['_citation.journal_id_CSD'] = ['0006'] + cif['_citation.book_publisher'] = ['?'] + cif['_citation.pdbx_database_id_PubMed'] = ['34265844'] + cif['_citation.pdbx_database_id_DOI'] = ['10.1038/s41586-021-03819-2'] + + # Type of data in the dataset including data used in the model generation. + cif['_ma_data.id'] = ['1'] + cif['_ma_data.name'] = ['Model'] + cif['_ma_data.content_type'] = ['model coordinates'] + + # Description of number of instances for each entity. + cif['_ma_target_entity_instance.asym_id'] = old_cif['_struct_asym.id'] + cif['_ma_target_entity_instance.entity_id'] = old_cif[ + '_struct_asym.entity_id' + ] + cif['_ma_target_entity_instance.details'] = ['.'] * len( + cif['_ma_target_entity_instance.entity_id'] + ) + + # Details about the target entities. + cif['_ma_target_entity.entity_id'] = cif[ + '_ma_target_entity_instance.entity_id' + ] + cif['_ma_target_entity.data_id'] = ['1'] * len( + cif['_ma_target_entity.entity_id'] + ) + cif['_ma_target_entity.origin'] = ['.'] * len( + cif['_ma_target_entity.entity_id'] + ) + + # Details of the models being deposited. 
+ cif['_ma_model_list.ordinal_id'] = ['1'] + cif['_ma_model_list.model_id'] = ['1'] + cif['_ma_model_list.model_group_id'] = ['1'] + cif['_ma_model_list.model_name'] = ['Top ranked model'] + + cif['_ma_model_list.model_group_name'] = [ + f'AlphaFold {model_type} v{version.__version__} model' + ] + cif['_ma_model_list.data_id'] = ['1'] + cif['_ma_model_list.model_type'] = ['Ab initio model'] + + # Software used. + cif['_software.pdbx_ordinal'] = ['1'] + cif['_software.name'] = ['AlphaFold'] + cif['_software.version'] = [f'v{version.__version__}'] + cif['_software.type'] = ['package'] + cif['_software.description'] = ['Structure prediction'] + cif['_software.classification'] = ['other'] + cif['_software.date'] = ['?'] + + # Collection of software into groups. + cif['_ma_software_group.ordinal_id'] = ['1'] + cif['_ma_software_group.group_id'] = ['1'] + cif['_ma_software_group.software_id'] = ['1'] + + # Method description to conform with ModelCIF. + cif['_ma_protocol_step.ordinal_id'] = ['1', '2', '3'] + cif['_ma_protocol_step.protocol_id'] = ['1', '1', '1'] + cif['_ma_protocol_step.step_id'] = ['1', '2', '3'] + cif['_ma_protocol_step.method_type'] = [ + 'coevolution MSA', + 'template search', + 'modeling', + ] + + # Details of the metrics used to assess model confidence. + cif['_ma_qa_metric.id'] = ['1', '2'] + cif['_ma_qa_metric.name'] = ['pLDDT', 'pLDDT'] + # Accepted values are distance, energy, normalised score, other, zscore. + cif['_ma_qa_metric.type'] = ['pLDDT', 'pLDDT'] + cif['_ma_qa_metric.mode'] = ['global', 'local'] + cif['_ma_qa_metric.software_group_id'] = ['1', '1'] + + # Global model confidence metric value. 
+ cif['_ma_qa_metric_global.ordinal_id'] = ['1'] + cif['_ma_qa_metric_global.model_id'] = ['1'] + cif['_ma_qa_metric_global.metric_id'] = ['1'] + global_plddt = np.mean( + [float(v) for v in old_cif['_atom_site.B_iso_or_equiv']] + ) + cif['_ma_qa_metric_global.metric_value'] = [f'{global_plddt:.2f}'] + + cif['_atom_type.symbol'] = sorted(set(old_cif['_atom_site.type_symbol'])) + + return cif diff --git a/alphafold/common/protein.py b/alphafold/common/protein.py index 6ea153584..7b7003ec3 100644 --- a/alphafold/common/protein.py +++ b/alphafold/common/protein.py @@ -13,11 +13,18 @@ # limitations under the License. """Protein data type.""" + +import collections import dataclasses +import functools import io -from typing import Any, Mapping, Optional +from typing import Any, Dict, List, Mapping, Optional, Tuple +from alphafold.common import mmcif_metadata from alphafold.common import residue_constants +from Bio.PDB import MMCIFParser from Bio.PDB import PDBParser +from Bio.PDB.mmcifio import MMCIFIO +from Bio.PDB.Structure import Structure import numpy as np from string import ascii_uppercase,ascii_lowercase @@ -30,6 +37,32 @@ PDB_CHAIN_IDS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' PDB_MAX_CHAINS = len(PDB_CHAIN_IDS) # := 62. +# Data to fill the _chem_comp table when writing mmCIFs. 
+_CHEM_COMP: Mapping[str, Tuple[Tuple[str, str], ...]] = { + 'L-peptide linking': ( + ('ALA', 'ALANINE'), + ('ARG', 'ARGININE'), + ('ASN', 'ASPARAGINE'), + ('ASP', 'ASPARTIC ACID'), + ('CYS', 'CYSTEINE'), + ('GLN', 'GLUTAMINE'), + ('GLU', 'GLUTAMIC ACID'), + ('HIS', 'HISTIDINE'), + ('ILE', 'ISOLEUCINE'), + ('LEU', 'LEUCINE'), + ('LYS', 'LYSINE'), + ('MET', 'METHIONINE'), + ('PHE', 'PHENYLALANINE'), + ('PRO', 'PROLINE'), + ('SER', 'SERINE'), + ('THR', 'THREONINE'), + ('TRP', 'TRYPTOPHAN'), + ('TYR', 'TYROSINE'), + ('VAL', 'VALINE'), + ), + 'peptide linking': (('GLY', 'GLYCINE'),), +} + @dataclasses.dataclass(frozen=True) class Protein: @@ -66,27 +99,32 @@ def __post_init__(self): 'because these cannot be written to PDB format.') -def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein: - """Takes a PDB string and constructs a Protein object. +def _from_bio_structure( + structure: Structure, chain_id: Optional[str] = None +) -> Protein: + """Takes a Biopython structure and creates a `Protein` instance. WARNING: All non-standard residue types will be converted into UNK. All non-standard atoms will be ignored. Args: - pdb_str: The contents of the pdb file - chain_id: If chain_id is specified (e.g. A), then only that chain - is parsed. Otherwise all chains are parsed. + structure: Structure from the Biopython library. + chain_id: If chain_id is specified (e.g. A), then only that chain is parsed. + Otherwise all chains are parsed. Returns: - A new `Protein` parsed from the pdb contents. + A new `Protein` created from the structure contents. + + Raises: + ValueError: If the number of models included in the structure is not 1. + ValueError: If insertion code is detected at a residue. """ - pdb_fh = io.StringIO(pdb_str) - parser = PDBParser(QUIET=True) - structure = parser.get_structure('none', pdb_fh) models = list(structure.get_models()) if len(models) != 1: raise ValueError( - f'Only single model PDBs are supported. 
Found {len(models)} models.') + 'Only single model PDBs/mmCIFs are supported. Found' + f' {len(models)} models.' + ) model = models[0] atom_positions = [] @@ -102,8 +140,9 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein: for res in chain: if res.id[2] != ' ': raise ValueError( - f'PDB contains an insertion code at chain {chain.id} and residue ' - f'index {res.id[1]}. These are not supported.') + f'PDB/mmCIF contains an insertion code at chain {chain.id} and' + f' residue index {res.id[1]}. These are not supported.' + ) res_shortname = residue_constants.restype_3to1.get(res.resname, 'X') restype_idx = residue_constants.restype_order.get( res_shortname, residue_constants.restype_num) @@ -140,6 +179,48 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein: b_factors=np.array(b_factors)) +def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein: + """Takes a PDB string and constructs a `Protein` object. + + WARNING: All non-standard residue types will be converted into UNK. All + non-standard atoms will be ignored. + + Args: + pdb_str: The contents of the pdb file + chain_id: If chain_id is specified (e.g. A), then only that chain is parsed. + Otherwise all chains are parsed. + + Returns: + A new `Protein` parsed from the pdb contents. + """ + with io.StringIO(pdb_str) as pdb_fh: + parser = PDBParser(QUIET=True) + structure = parser.get_structure(id='none', file=pdb_fh) + return _from_bio_structure(structure, chain_id) + + +def from_mmcif_string( + mmcif_str: str, chain_id: Optional[str] = None +) -> Protein: + """Takes a mmCIF string and constructs a `Protein` object. + + WARNING: All non-standard residue types will be converted into UNK. All + non-standard atoms will be ignored. + + Args: + mmcif_str: The contents of the mmCIF file + chain_id: If chain_id is specified (e.g. A), then only that chain is parsed. + Otherwise all chains are parsed. 
+ + Returns: + A new `Protein` parsed from the mmCIF contents. + """ + with io.StringIO(mmcif_str) as mmcif_fh: + parser = MMCIFParser(QUIET=True) + structure = parser.get_structure(structure_id='none', filename=mmcif_fh) + return _from_bio_structure(structure, chain_id) + + def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str: chain_end = 'TER' return (f'{chain_end:<6}{atom_index:>5} {end_resname:>3} ' @@ -279,3 +360,223 @@ def _maybe_remove_leading_dim(arr: np.ndarray) -> np.ndarray: residue_index=_maybe_remove_leading_dim(features['residue_index']) + 1, chain_index=chain_index, b_factors=b_factors) + + +def to_mmcif( + prot: Protein, + file_id: str, + model_type: str, +) -> str: + """Converts a `Protein` instance to an mmCIF string. + + WARNING 1: The _entity_poly_seq is filled with unknown (UNK) residues for any + missing residue indices in the range from min(1, min(residue_index)) to + max(residue_index). E.g. for a protein object with positions for residues + 2 (MET), 3 (LYS), 6 (GLY), this method would set the _entity_poly_seq to: + 1 UNK + 2 MET + 3 LYS + 4 UNK + 5 UNK + 6 GLY + This is done to preserve the residue numbering. + + WARNING 2: Converting ground truth mmCIF file to Protein and then back to + mmCIF using this method will convert all non-standard residue types to UNK. + If you need this behaviour, you need to store more mmCIF metadata in the + Protein object (e.g. all fields except for the _atom_site loop). + + WARNING 3: Converting ground truth mmCIF file to Protein and then back to + mmCIF using this method will not retain the original chain indices. + + WARNING 4: In case of multiple identical chains, they are assigned different + `_atom_site.label_entity_id` values. + + Args: + prot: A protein to convert to mmCIF string. + file_id: The file ID (usually the PDB ID) to be used in the mmCIF. + model_type: 'Multimer' or 'Monomer'. + + Returns: + A valid mmCIF string. 
+ + Raises: + ValueError: If amino acid types array contains entries with too many protein + types. + """ + atom_mask = prot.atom_mask + aatype = prot.aatype + atom_positions = prot.atom_positions + residue_index = prot.residue_index.astype(np.int32) + chain_index = prot.chain_index.astype(np.int32) + b_factors = prot.b_factors + + # Construct a mapping from chain integer indices to chain ID strings. + chain_ids = {} + # We count unknown residues as protein residues. + for entity_id in np.unique(chain_index): # np.unique gives sorted output. + chain_ids[entity_id] = _int_id_to_str_id(entity_id + 1) + + mmcif_dict = collections.defaultdict(list) + + mmcif_dict['data_'] = file_id.upper() + mmcif_dict['_entry.id'] = file_id.upper() + + label_asym_id_to_entity_id = {} + # Entity and chain information. + for entity_id, chain_id in chain_ids.items(): + # Add all chain information to the _struct_asym table. + label_asym_id_to_entity_id[str(chain_id)] = str(entity_id) + mmcif_dict['_struct_asym.id'].append(chain_id) + mmcif_dict['_struct_asym.entity_id'].append(str(entity_id)) + # Add information about the entity to the _entity_poly table. + mmcif_dict['_entity_poly.entity_id'].append(str(entity_id)) + mmcif_dict['_entity_poly.type'].append(residue_constants.PROTEIN_CHAIN) + mmcif_dict['_entity_poly.pdbx_strand_id'].append(chain_id) + # Generate the _entity table. + mmcif_dict['_entity.id'].append(str(entity_id)) + mmcif_dict['_entity.type'].append(residue_constants.POLYMER_CHAIN) + + # Add the residues to the _entity_poly_seq table. + for entity_id, (res_ids, aas) in _get_entity_poly_seq( + aatype, residue_index, chain_index + ).items(): + for res_id, aa in zip(res_ids, aas): + mmcif_dict['_entity_poly_seq.entity_id'].append(str(entity_id)) + mmcif_dict['_entity_poly_seq.num'].append(str(res_id)) + mmcif_dict['_entity_poly_seq.mon_id'].append( + residue_constants.resnames[aa] + ) + + # Populate the chem comp table. 
+ for chem_type, chem_comp in _CHEM_COMP.items(): + for chem_id, chem_name in chem_comp: + mmcif_dict['_chem_comp.id'].append(chem_id) + mmcif_dict['_chem_comp.type'].append(chem_type) + mmcif_dict['_chem_comp.name'].append(chem_name) + + # Add all atom sites. + atom_index = 1 + for i in range(aatype.shape[0]): + res_name_3 = residue_constants.resnames[aatype[i]] + if aatype[i] <= len(residue_constants.restypes): + atom_names = residue_constants.atom_types + else: + raise ValueError( + 'Amino acid types array contains entries with too many protein types.' + ) + for atom_name, pos, mask, b_factor in zip( + atom_names, atom_positions[i], atom_mask[i], b_factors[i] + ): + if mask < 0.5: + continue + type_symbol = residue_constants.atom_id_to_type(atom_name) + + mmcif_dict['_atom_site.group_PDB'].append('ATOM') + mmcif_dict['_atom_site.id'].append(str(atom_index)) + mmcif_dict['_atom_site.type_symbol'].append(type_symbol) + mmcif_dict['_atom_site.label_atom_id'].append(atom_name) + mmcif_dict['_atom_site.label_alt_id'].append('.') + mmcif_dict['_atom_site.label_comp_id'].append(res_name_3) + mmcif_dict['_atom_site.label_asym_id'].append(chain_ids[chain_index[i]]) + mmcif_dict['_atom_site.label_entity_id'].append( + label_asym_id_to_entity_id[chain_ids[chain_index[i]]] + ) + mmcif_dict['_atom_site.label_seq_id'].append(str(residue_index[i])) + mmcif_dict['_atom_site.pdbx_PDB_ins_code'].append('.') + mmcif_dict['_atom_site.Cartn_x'].append(f'{pos[0]:.3f}') + mmcif_dict['_atom_site.Cartn_y'].append(f'{pos[1]:.3f}') + mmcif_dict['_atom_site.Cartn_z'].append(f'{pos[2]:.3f}') + mmcif_dict['_atom_site.occupancy'].append('1.00') + mmcif_dict['_atom_site.B_iso_or_equiv'].append(f'{b_factor:.2f}') + mmcif_dict['_atom_site.auth_seq_id'].append(str(residue_index[i])) + mmcif_dict['_atom_site.auth_asym_id'].append(chain_ids[chain_index[i]]) + mmcif_dict['_atom_site.pdbx_PDB_model_num'].append('1') + + atom_index += 1 + + metadata_dict = 
mmcif_metadata.add_metadata_to_mmcif(mmcif_dict, model_type) + mmcif_dict.update(metadata_dict) + + return _create_mmcif_string(mmcif_dict) + + +@functools.lru_cache(maxsize=256) +def _int_id_to_str_id(num: int) -> str: + """Encodes a number as a string, using reverse spreadsheet style naming. + + Args: + num: A positive integer. + + Returns: + A string that encodes the positive integer using reverse spreadsheet style, + naming e.g. 1 = A, 2 = B, ..., 27 = AA, 28 = BA, 29 = CA, ... This is the + usual way to encode chain IDs in mmCIF files. + """ + if num <= 0: + raise ValueError(f'Only positive integers allowed, got {num}.') + + num = num - 1 # 1-based indexing. + output = [] + while num >= 0: + output.append(chr(num % 26 + ord('A'))) + num = num // 26 - 1 + return ''.join(output) + + +def _get_entity_poly_seq( + aatypes: np.ndarray, residue_indices: np.ndarray, chain_indices: np.ndarray +) -> Dict[int, Tuple[List[int], List[int]]]: + """Constructs gapless residue index and aatype lists for each chain. + + Args: + aatypes: A numpy array with aatypes. + residue_indices: A numpy array with residue indices. + chain_indices: A numpy array with chain indices. + + Returns: + A dictionary mapping chain indices to a tuple with list of residue indices + and a list of aatypes. Missing residues are filled with UNK residue type. + """ + if ( + aatypes.shape[0] != residue_indices.shape[0] + or aatypes.shape[0] != chain_indices.shape[0] + ): + raise ValueError( + 'aatypes, residue_indices, chain_indices must have the same length.' + ) + + # Group the present residues by chain index. + present = collections.defaultdict(list) + for chain_id, res_id, aa in zip(chain_indices, residue_indices, aatypes): + present[chain_id].append((res_id, aa)) + + # Add any missing residues (from 1 to the first residue and for any gaps). 
+ entity_poly_seq = {} + for chain_id, present_residues in present.items(): + present_residue_indices = set([x[0] for x in present_residues]) + min_res_id = min(present_residue_indices) # Could be negative. + max_res_id = max(present_residue_indices) + + new_residue_indices = [] + new_aatypes = [] + present_index = 0 + for i in range(min(1, min_res_id), max_res_id + 1): + new_residue_indices.append(i) + if i in present_residue_indices: + new_aatypes.append(present_residues[present_index][1]) + present_index += 1 + else: + new_aatypes.append(20) # Unknown amino acid type. + entity_poly_seq[chain_id] = (new_residue_indices, new_aatypes) + return entity_poly_seq + + +def _create_mmcif_string(mmcif_dict: Dict[str, Any]) -> str: + """Converts mmCIF dictionary into mmCIF string.""" + mmcifio = MMCIFIO() + mmcifio.set_dict(mmcif_dict) + + with io.StringIO() as file_handle: + mmcifio.save(file_handle) + return file_handle.getvalue() diff --git a/alphafold/common/residue_constants.py b/alphafold/common/residue_constants.py index 46ff9a6bc..604a33bd2 100644 --- a/alphafold/common/residue_constants.py +++ b/alphafold/common/residue_constants.py @@ -17,7 +17,7 @@ import collections import functools import os -from typing import List, Mapping, Tuple +from typing import Final, List, Mapping, Tuple import numpy as np import tree @@ -497,6 +497,7 @@ def make_bond_key(atom1_name, atom2_name): atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)} atom_type_num = len(atom_types) # := 37. + # A compact atom encoding with 14 columns # pylint: disable=line-too-long # pylint: disable=bad-whitespace @@ -608,6 +609,35 @@ def sequence_to_onehot( 'V': 'VAL', } +PROTEIN_CHAIN: Final[str] = 'polypeptide(L)' +POLYMER_CHAIN: Final[str] = 'polymer' + + +def atom_id_to_type(atom_id: str) -> str: + """Convert atom ID to atom type, works only for standard protein residues. + + Args: + atom_id: Atom ID to be converted. + + Returns: + String corresponding to atom type. 
+ + Raises: + ValueError: If atom ID not recognized. + """ + + if atom_id.startswith('C'): + return 'C' + elif atom_id.startswith('N'): + return 'N' + elif atom_id.startswith('O'): + return 'O' + elif atom_id.startswith('H'): + return 'H' + elif atom_id.startswith('S'): + return 'S' + raise ValueError('Atom ID not recognized.') + # NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple # 1-to-1 mapping of 3 letter names to one letter names. The latter contains diff --git a/run_alphafold.py b/run_alphafold.py index 0d89bfb47..d77b2ea97 100644 --- a/run_alphafold.py +++ b/run_alphafold.py @@ -171,6 +171,63 @@ def _jnp_to_np(output: Dict[str, Any]) -> Dict[str, Any]: return output +def _save_confidence_json_file( + plddt: np.ndarray, output_dir: str, model_name: str +) -> None: + confidence_json = confidence.confidence_json(plddt) + + # Save the confidence json. + confidence_json_output_path = os.path.join( + output_dir, f'confidence_{model_name}.json' + ) + with open(confidence_json_output_path, 'w') as f: + f.write(confidence_json) + + +def _save_mmcif_file( + prot: protein.Protein, + output_dir: str, + model_name: str, + file_id: str, + model_type: str, +) -> None: + """Create mmCIF string and save to a file. + + Args: + prot: Protein object. + output_dir: Directory to which files are saved. + model_name: Name of a model. + file_id: The file ID (usually the PDB ID) to be used in the mmCIF. + model_type: Monomer or multimer. + """ + + mmcif_string = protein.to_mmcif(prot, file_id, model_type) + + # Save the MMCIF. + mmcif_output_path = os.path.join(output_dir, f'{model_name}.cif') + with open(mmcif_output_path, 'w') as f: + f.write(mmcif_string) + + +def _save_pae_json_file( + pae: np.ndarray, max_pae: float, output_dir: str, model_name: str +) -> None: + """Check prediction result for PAE data and save to a JSON file if present. + + Args: + pae: The n_res x n_res PAE array. + max_pae: The maximum possible PAE value. 
+ output_dir: Directory to which files are saved. + model_name: Name of a model. + """ + pae_json = confidence.pae_json(pae, max_pae) + + # Save the PAE json. + pae_json_output_path = os.path.join(output_dir, f'pae_{model_name}.json') + with open(pae_json_output_path, 'w') as f: + f.write(pae_json) + + def predict_structure( fasta_path: str, fasta_name: str, @@ -180,7 +237,9 @@ def predict_structure( amber_relaxer: relax.AmberRelaxation, benchmark: bool, random_seed: int, - models_to_relax: ModelsToRelax): + models_to_relax: ModelsToRelax, + model_type: str, +): """Predicts structure using AlphaFold for the given sequence.""" logging.info('Predicting %s', fasta_name) timings = {} @@ -266,6 +325,14 @@ def predict_structure( with open(unrelaxed_pdb_path, 'w') as f: f.write(unrelaxed_pdbs[model_name]) + _save_mmcif_file( + prot=unrelaxed_protein, + output_dir=output_dir, + model_name=f'unrelaxed_{model_name}', + file_id=str(model_index), + model_type=model_type, + ) + # Rank by model confidence. ranked_order = [ model_name for model_name, confidence in @@ -297,6 +364,15 @@ def predict_structure( with open(relaxed_output_path, 'w') as f: f.write(relaxed_pdb_str) + relaxed_protein = protein.from_pdb_string(relaxed_pdb_str) + _save_mmcif_file( + prot=relaxed_protein, + output_dir=output_dir, + model_name=f'relaxed_{model_name}', + file_id='0', + model_type=model_type, + ) + # Write out relaxed PDBs in rank order. 
for idx, model_name in enumerate(ranked_order): ranked_output_path = os.path.join(output_dir, f'ranked_{idx}.pdb') @@ -306,6 +382,19 @@ def predict_structure( else: f.write(unrelaxed_pdbs[model_name]) + if model_name in relaxed_pdbs: + protein_instance = protein.from_pdb_string(relaxed_pdbs[model_name]) + else: + protein_instance = protein.from_pdb_string(unrelaxed_pdbs[model_name]) + + _save_mmcif_file( + prot=protein_instance, + output_dir=output_dir, + model_name=f'ranked_{idx}', + file_id=str(idx), + model_type=model_type, + ) + ranking_output_path = os.path.join(output_dir, 'ranking_debug.json') with open(ranking_output_path, 'w') as f: label = 'iptm+ptm' if 'iptm' in prediction_result else 'plddts' @@ -342,6 +431,7 @@ def main(argv): should_be_set=not use_small_bfd) run_multimer_system = 'multimer' in FLAGS.model_preset + model_type = 'Multimer' if run_multimer_system else 'Monomer' _check_flag('pdb70_database_path', 'model_preset', should_be_set=not run_multimer_system) _check_flag('pdb_seqres_database_path', 'model_preset', @@ -449,7 +539,9 @@ def main(argv): amber_relaxer=amber_relaxer, benchmark=FLAGS.benchmark, random_seed=random_seed, - models_to_relax=FLAGS.models_to_relax) + models_to_relax=FLAGS.models_to_relax, + model_type=model_type, + ) if __name__ == '__main__': From c0cb62859b99ae6368195f3839cb6e2c6bcdb594 Mon Sep 17 00:00:00 2001 From: Poko18 Date: Tue, 19 Mar 2024 10:50:30 +0100 Subject: [PATCH 3/3] Update AlphaFold version in metadata --- alphafold/common/mmcif_metadata.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/alphafold/common/mmcif_metadata.py b/alphafold/common/mmcif_metadata.py index b5b725d5f..d3bbbcccd 100644 --- a/alphafold/common/mmcif_metadata.py +++ b/alphafold/common/mmcif_metadata.py @@ -15,7 +15,6 @@ """mmCIF metadata.""" from typing import Mapping, Sequence -from alphafold import version import numpy as np @@ -162,7 +161,7 @@ def add_metadata_to_mmcif( cif['_ma_model_list.model_name'] = ['Top 
ranked model'] cif['_ma_model_list.model_group_name'] = [ - f'AlphaFold {model_type} v{version.__version__} model' + f'AlphaFold {model_type} v2.3.2 model' ] cif['_ma_model_list.data_id'] = ['1'] cif['_ma_model_list.model_type'] = ['Ab initio model'] @@ -170,7 +169,7 @@ def add_metadata_to_mmcif( # Software used. cif['_software.pdbx_ordinal'] = ['1'] cif['_software.name'] = ['AlphaFold'] - cif['_software.version'] = [f'v{version.__version__}'] + cif['_software.version'] = [f'v2.3.2'] cif['_software.type'] = ['package'] cif['_software.description'] = ['Structure prediction'] cif['_software.classification'] = ['other']