Skip to content

Commit

Permalink
embeddings storage
Browse files Browse the repository at this point in the history
  • Loading branch information
frapercan committed Mar 11, 2024
1 parent 4f8948b commit 8853b25
Show file tree
Hide file tree
Showing 11 changed files with 208 additions and 4 deletions.
7 changes: 6 additions & 1 deletion protein_metamorphisms_is/config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ DB_NAME: BioData

## Information System
# Uniprot Extraction
search_criteria: '(structure_3d:true)'
search_criteria: '(organism_id:9615) AND (structure_3d:true)'
limit: 100

# PDB Extraction
Expand Down Expand Up @@ -45,3 +45,8 @@ structural_alignment:
batch_size: 1000
task_timeout: 20

# Embedding
embedding:
types:
# - 1
- 2
9 changes: 9 additions & 0 deletions protein_metamorphisms_is/config/constants.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,12 @@ structural_complexity_levels:
- name: "Secondary Structures"
description: Secondary structures refer to the local spatial arrangement of the protein's backbone, excluding the side chains of the amino acids. The two most common types of secondary structures are alpha-helices and beta-sheets. Alpha-helices are right-handed coils stabilized by hydrogen bonds between the backbone atoms, while beta-sheets consist of two or more strands aligned next to each other, forming a sheet-like structure also stabilized by hydrogen bonding. These structures are fundamental components of a protein's overall three-dimensional conformation and play critical roles in defining its function. Secondary structures are formed as a result of hydrogen bonds between the amide hydrogen and carbonyl oxygen atoms in the peptide backbone, and their formation is driven by the protein's primary sequence.

embedding_types:
- name: "ESM"
description: "Evolutionary Scale Modeling (ESM) embeddings are designed to capture the evolutionary information of protein sequences, utilizing deep learning to generate representations that enhance sequence analysis and prediction tasks."
task_name: "esm"
model_name: facebook/esm2_t6_8M_UR50D
- name: "Prost-T5"
description: "Prost-T5 embeddings leverage the capabilities of the T5 (Text-to-Text Transfer Transformer) model adapted for protein sequences, offering advanced sequence representation by considering both local and global sequence features."
task_name: prost_t5
model_name: Rostlab/ProstT5
2 changes: 2 additions & 0 deletions protein_metamorphisms_is/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from protein_metamorphisms_is.information_system.pdb import PDBExtractor
from protein_metamorphisms_is.information_system.uniprot import UniProtExtractor
from protein_metamorphisms_is.operations.cdhit import CDHit
from protein_metamorphisms_is.operations.embeddings import EmbeddingManager
from protein_metamorphisms_is.operations.structural_alignment import StructuralAlignmentManager


Expand All @@ -10,6 +11,7 @@ def main(config_path="config/config.yaml"):
UniProtExtractor(conf).start()
PDBExtractor(conf).start()
CDHit(conf).start()
EmbeddingManager(conf).start()
StructuralAlignmentManager(conf).start()


Expand Down
23 changes: 22 additions & 1 deletion protein_metamorphisms_is/operations/base/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from protein_metamorphisms_is.helpers.logger.logger import setup_logger
from protein_metamorphisms_is.sql.base.database_manager import DatabaseManager
from protein_metamorphisms_is.sql.constants import handle_structural_complexity_levels, \
handle_structural_alignment_types
handle_structural_alignment_types, handle_embedding_types
from protein_metamorphisms_is.sql.model import PDBChains


class OperatorBase(ABC):
Expand All @@ -19,6 +20,8 @@ def __init__(self, conf):
constants = yaml.safe_load(open(conf['constants']))
handle_structural_complexity_levels(self.session, constants)
handle_structural_alignment_types(self.session, constants)
handle_embedding_types(self.session, constants)


@abstractmethod
def start(self):
Expand All @@ -29,3 +32,21 @@ def start(self):
the specific data operation logic for each bioinformatics data source.
"""
pass


def load_chains(self):
    """
    Load protein chains from the database.

    When the 'allow_multiple_chain_models' configuration entry is missing or
    falsy, only the PDBChains rows whose ``model`` column equals 0 are
    returned; otherwise every PDBChains row is returned.
    NOTE(review): the previous docstring associated this option with NMR
    samples -- confirm against the data model.

    Returns:
        list: PDBChains objects produced by the query.
    """
    self.logger.info("Loading protein chains from the database")
    allow_all_models = self.conf.get("allow_multiple_chain_models")
    query = self.session.query(PDBChains)
    if allow_all_models:
        return query.all()
    # Multiple models are not allowed: keep only rows stored with model == 0.
    return query.filter(PDBChains.model == 0).all()
Empty file.
31 changes: 31 additions & 0 deletions protein_metamorphisms_is/operations/embedding_tasks/esm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from transformers import AutoTokenizer, EsmModel
import torch


def embedding_task(session, chains, module, model_name):
    """
    Run the ESM model identified by ``model_name`` over every chain sequence.

    Each chain's sequence is tokenized, moved to the CUDA device and passed
    through the model; the shape of the resulting ``last_hidden_state`` tensor
    is printed. Nothing is returned and nothing is written through ``session``.

    Args:
        session: Not used by this function.
        chains: Iterable of objects exposing a ``sequence`` attribute.
        module: Not used by this function.
        model_name: Identifier handed to ``from_pretrained`` for both the
            tokenizer and the model.

    Raises:
        Exception: If CUDA is not available.
    """
    # This task is GPU-only: stop before loading anything when CUDA is missing.
    if not torch.cuda.is_available():
        raise Exception("CUDA is not available. This script requires a GPU with CUDA.")

    gpu = torch.device("cuda")

    # Load the tokenizer and the model, placing the model on the GPU.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = EsmModel.from_pretrained(model_name).to(gpu)

    def _tokenize_on_gpu(sequence):
        # Tokenize one sequence and move every resulting tensor to the GPU.
        batch = tokenizer(sequence, return_tensors="pt", truncation=True, padding=True)
        return {key: tensor.to(gpu) for key, tensor in batch.items()}

    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        for chain in chains:
            model_inputs = _tokenize_on_gpu(chain.sequence)
            # last_hidden_state has shape (batch_size, sequence_length, hidden_size).
            hidden_states = model(**model_inputs).last_hidden_state
            print(hidden_states.shape)

25 changes: 25 additions & 0 deletions protein_metamorphisms_is/operations/embedding_tasks/prost_t5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from transformers import T5Tokenizer, T5EncoderModel
import re
import torch


def embedding_task(session, chains, module, model_name):
    """
    Run the T5 encoder identified by ``model_name`` over every chain sequence.

    Each sequence is normalised (U, Z, O and B replaced by X, residues
    separated by spaces), prefixed with a direction token, tokenized with a
    maximum length of 512 and encoded on the CUDA device. The mean of the last
    hidden state over dimension 1 is computed for each chain but is currently
    neither stored nor returned.

    Args:
        session: Not used by this function.
        chains: Iterable of objects exposing a ``sequence`` attribute.
        module: Not used by this function.
        model_name: Identifier handed to ``from_pretrained`` for both the
            tokenizer and the encoder model.

    Raises:
        Exception: If CUDA is not available.
    """
    # This task is GPU-only: stop before loading anything when CUDA is missing.
    if not torch.cuda.is_available():
        raise Exception("CUDA is not available. This script requires a GPU with CUDA.")

    gpu = torch.device("cuda")
    tokenizer = T5Tokenizer.from_pretrained(model_name, do_lower_case=False)
    model = T5EncoderModel.from_pretrained(model_name).to(gpu)
    model.eval()

    with torch.no_grad():
        for chain in chains:
            # Map the residues U, Z, O and B to X, then space-separate the residues.
            spaced = " ".join(re.sub(r"[UZOB]", "X", chain.sequence))
            # Upper-case input gets the <AA2fold> prefix, anything else <fold2AA>.
            if spaced.isupper():
                prompt = "<AA2fold> " + spaced
            else:
                prompt = "<fold2AA> " + spaced

            encoded = tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
                add_special_tokens=True,
            ).to(gpu)

            # Generate the embedding: average the hidden states over dimension 1.
            encoder_output = model(input_ids=encoded.input_ids, attention_mask=encoded.attention_mask)
            embeddings = encoder_output.last_hidden_state.mean(dim=1)
79 changes: 79 additions & 0 deletions protein_metamorphisms_is/operations/embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import importlib
import multiprocessing
from datetime import datetime, timedelta
from multiprocessing import Pool

from sqlalchemy.orm import aliased


from protein_metamorphisms_is.operations.base.operator import OperatorBase
from protein_metamorphisms_is.sql.model import PDBChains, Cluster, PDBReference, StructuralAlignmentQueue, \
StructuralAlignmentType, StructuralAlignmentResults, EmbeddingType


class EmbeddingManager(OperatorBase):
    """
    Runs the configured sequence-embedding tasks over the stored protein chains.

    Embedding types are read from the ``EmbeddingType`` table. Those whose id is
    listed under ``conf['embedding']['types']`` have their task module imported
    from ``protein_metamorphisms_is.operations.embedding_tasks`` and executed
    against the chains returned by ``load_chains``.

    Attributes:
        conf (dict): Configuration of the instance, including database connections and operational settings.
        types (dict): Filled by `fetch_models_info`; maps an embedding type id to a dict holding the
            imported task 'module' and its 'model_name'.
    """

    def __init__(self, conf):
        """
        Initializes an instance of `EmbeddingManager` with configuration settings.

        Args:
            conf (dict): Configuration parameters, including database connections and operational settings.
        """
        super().__init__(conf)
        self.logger.info("Sequence Embedding Manager instance created.")

    def fetch_models_info(self):
        """
        Import the embedding task modules enabled in the configuration.

        Every ``EmbeddingType`` row whose id appears in ``conf['embedding']['types']`` gets its task
        module imported dynamically (module name taken from ``task_name``). The module and the row's
        ``model_name`` are stored in ``self.types`` keyed by the type id.

        Raises:
            KeyError: If the configuration has no ``embedding`` / ``types`` entry.
            ImportError: If a task module cannot be imported.
        """
        embedding_types = self.session.query(EmbeddingType).all()
        # A `types:` key whose entries are all commented out is parsed by YAML as None;
        # treat that as "no embedding type enabled" instead of failing on `in None`.
        enabled_ids = set(self.conf['embedding']['types'] or ())
        base_module_path = 'protein_metamorphisms_is.operations.embedding_tasks'

        self.types = {}
        for type_obj in embedding_types:
            if type_obj.id not in enabled_ids:
                continue
            # Build the fully qualified module name and import it dynamically.
            module_name = f"{base_module_path}.{type_obj.task_name}"
            module = importlib.import_module(module_name)
            # Keep a reference to the module together with the model it should load.
            self.types[type_obj.id] = {'module': module, 'model_name': type_obj.model_name}

        self.logger.info(f"Loaded {len(self.types)} embedding task module(s).")

    def start(self):
        """
        Begin the embedding process.

        Loads the protein chains, resolves the enabled embedding task modules and calls each module's
        ``embedding_task`` with the session, the chains, the module itself and the model name.
        Any exception is logged and re-raised.
        """
        try:
            self.logger.info("Starting embedding process.")
            chains = self.load_chains()
            self.fetch_models_info()

            for type_info in self.types.values():
                module, model_name = type_info['module'], type_info['model_name']
                module.embedding_task(self.session, chains, module, model_name)

        except Exception as e:
            self.logger.error(f"Error during embedding process: {e}")
            raise




Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def get_update_queue(self):

clusters_not_queued = self.session.query(Cluster).filter(
Cluster.id.notin_(queued_cluster_ids),
not Cluster.is_representative
Cluster.is_representative == False
).all()

self.logger.info(f"Found {len(clusters_not_queued)} clusters not in queue, adding to queue.")
Expand Down
14 changes: 13 additions & 1 deletion protein_metamorphisms_is/sql/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from protein_metamorphisms_is.sql.model import StructuralComplexityLevel, StructuralAlignmentType
from protein_metamorphisms_is.sql.model import StructuralComplexityLevel, StructuralAlignmentType, EmbeddingType


def handle_structural_complexity_levels(session, constants):
Expand Down Expand Up @@ -29,3 +29,15 @@ def handle_structural_alignment_types(session, constants):

# Commit the changes to the database
session.commit()


def handle_embedding_types(session, constants):
    """
    Insert the embedding types listed in the constants that are not stored yet.

    For every entry under ``constants['embedding_types']`` the ``EmbeddingType``
    table is searched by ``name``; entries without a match are added to the
    session. The session is committed once after all entries were processed.

    Args:
        session: Database session exposing ``query``, ``add`` and ``commit``.
        constants (dict): Loaded constants; must contain the key 'embedding_types'
            holding a list of dicts usable as ``EmbeddingType`` keyword arguments.
    """
    for entry in constants['embedding_types']:
        already_stored = session.query(EmbeddingType).filter_by(name=entry['name']).first()
        if already_stored:
            # An embedding type with this name is present; leave it untouched.
            continue
        session.add(EmbeddingType(**entry))

    session.commit()
20 changes: 20 additions & 0 deletions protein_metamorphisms_is/sql/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,26 @@ class StructuralAlignmentType(Base):
task_name = Column(String)


class EmbeddingType(Base):
    """
    Represents a type of protein analysis embedding.

    This class is designed to manage different embedding techniques used in protein sequence analysis, offering a
    structured way to categorize and store information about various embedding methods such as ESM and Prost-T5.

    Attributes:
        id (Integer): Unique identifier for each embedding type (primary key).
        name (String): Unique, required name of the embedding type.
        description (String): Detailed description of the embedding technique.
        task_name (String): Name of the specific task associated with this embedding type, if applicable.
        model_name (String): Name of the model associated with this embedding type
            (presumably a pretrained-model identifier -- confirm against the task modules).
    """
    __tablename__ = 'embedding_types'
    id = Column(Integer, primary_key=True)
    name = Column(String, nullable=False, unique=True)
    description = Column(String)
    task_name = Column(String)
    model_name = Column(String)


class StructuralAlignmentQueue(Base):
"""
Manages a queue of pending structural alignment tasks, overseeing their execution and monitoring.
Expand Down

0 comments on commit 8853b25

Please sign in to comment.