Skip to content

Commit

Permalink
working on subclusters
Browse files Browse the repository at this point in the history
  • Loading branch information
frapercan committed Mar 21, 2024
1 parent 1611800 commit 5eaee67
Show file tree
Hide file tree
Showing 7 changed files with 196 additions and 52 deletions.
8 changes: 4 additions & 4 deletions protein_metamorphisms_is/config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@ limit: 100
# PDB Extraction
resolution_threshold: 2.5
server: "https://files.wwpdb.org/"
pdb_path: "/home/bioxaxi/data/pdb"
pdb_chains_path: "/home/bioxaxi/data/chains"
pdb_path: "/home/frapercan/data/pdb"
pdb_chains_path: "/home/frapercan/data/chains"
file_format: "mmCif"
allow_multiple_chain_models: False # NMR

## Operations
constants: "config/constants.yaml"

# Sequence Clustering
fasta_path: "/home/bioxaxi/data/complete.fasta"
cdhit_out_path: "/home/bioxaxi/data/out"
fasta_path: "/home/frapercan/data/complete.fasta"
cdhit_out_path: "/home/frapercan/data/out"
sequence_identity_threshold: 0.95
alignment_coverage: 0.95
memory_usage: 25000
Expand Down
10 changes: 7 additions & 3 deletions protein_metamorphisms_is/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,19 @@
from protein_metamorphisms_is.information_system.uniprot import UniProtExtractor
from protein_metamorphisms_is.operations.cdhit import CDHit
from protein_metamorphisms_is.operations.embeddings import EmbeddingManager
from protein_metamorphisms_is.operations.optics import OpticsClustering
from protein_metamorphisms_is.operations.structural_alignment import StructuralAlignmentManager


def main(config_path="config/config.yaml"):
conf = read_yaml_config(config_path)
# UniProtExtractor(conf).start()
# PDBExtractor(conf).start()
# CDHit(conf).start()
UniProtExtractor(conf).start()
PDBExtractor(conf).start()
EmbeddingManager(conf).start()
CDHit(conf).start()
OpticsClustering(conf).start()


# StructuralAlignmentManager(conf).start()


Expand Down
28 changes: 15 additions & 13 deletions protein_metamorphisms_is/operations/cdhit.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from protein_metamorphisms_is.operations.base.operator import OperatorBase
from protein_metamorphisms_is.sql.model import PDBChains, Cluster
from protein_metamorphisms_is.sql.model import PDBChains, Cluster, ClusterEntry
from pycdhit import cd_hit, read_clstr


Expand Down Expand Up @@ -94,13 +94,6 @@ def create_fasta(self, chains):
fasta_file.write(f"{header}\n{sequence}\n")

def cluster(self):
"""
Execute the CD-HIT algorithm for sequence clustering.
Runs the CD-HIT algorithm on the prepared FASTA file, then reads the output cluster file to store the clustering
results in the database. Configuration parameters such as sequence identity threshold, alignment coverage, accurate mode and
memory usage are used to control the CD-HIT execution.
"""
fasta_file_path = self.conf.get('fasta_path', './complete.fasta')
cdhit_out_path = self.conf.get('cdhit_out_path', './out.clstr')

Expand All @@ -125,15 +118,24 @@ def cluster(self):

self.logger.info(f"Reading CD-HIT output from {cdhit_out_path}.clstr")
df_clstr = read_clstr(f"{cdhit_out_path}.clstr")

# Asociar cada cadena con su cluster correspondiente y marcar si es representativa
clusters_dict = {}
for _, row in df_clstr.iterrows():
chain_id = row["identifier"]
cluster = Cluster(
cluster_id=row['cluster'],
pdb_chain_id=chain_id,
cluster_id = row['cluster']
if cluster_id not in clusters_dict:
cluster = Cluster()
self.session.add(cluster)
self.session.flush() # Esto es para obtener el id generado para el cluster
clusters_dict[cluster_id] = cluster.id

cluster_entry = ClusterEntry(
cluster_id=clusters_dict[cluster_id],
pdb_chain_id=row["identifier"],
is_representative=row['is_representative'],
sequence_length=row['size'],
identity=row['identity']
)
self.session.add(cluster)
self.session.add(cluster_entry)
self.session.commit()
self.logger.info("CD-HIT clustering data stored in the database")
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,12 @@ def embedding_task(session,chains,model_name, embedding_type_id):
sequence_processed = "<AA2fold> " + sequence_processed if sequence_processed.isupper() else "<fold2AA> " + sequence_processed
inputs = tokenizer(sequence_processed, return_tensors="pt", padding=True, truncation=True, max_length=512, add_special_tokens=True).to(device)

# Generación de embeddings
outputs = model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask)
embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy().tolist()[0] # Convertir a lista para almacenamiento
embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy().tolist()[0]

# Crear y guardar el embedding en la base de datos
embedding_entry = ChainEmbedding(
pdb_chain_id=chain.id, # Asume que `chain` tiene un atributo `id` referenciando a `PDBChains.id`
embedding_type_id=embedding_type_id, # Asegúrate de tener este id desde tu lógica de negocio
pdb_chain_id=chain.id,
embedding_type_id=embedding_type_id,
embedding=embeddings
)
session.add(embedding_entry)
Expand Down
2 changes: 0 additions & 2 deletions protein_metamorphisms_is/operations/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,9 @@ def start(self):
try:
self.logger.info("Starting structural alignment process.")
chains = self.load_chains()
print(len(chains))
self.fetch_models_info()

for type in self.types.values():
print(type)
module, model_name, embedding_type_id = type['module'], type['model_name'], type['id']
module.embedding_task(self.session, chains, model_name, embedding_type_id)

Expand Down
117 changes: 117 additions & 0 deletions protein_metamorphisms_is/operations/optics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from sklearn.cluster import OPTICS
import numpy as np

from protein_metamorphisms_is.operations.base.operator import OperatorBase
from protein_metamorphisms_is.sql.model import Subcluster, SubclusterEntry, ClusterEntry, PDBChains, Cluster, \
ChainEmbedding


class OpticsClustering(OperatorBase):
def __init__(self, conf):
super().__init__(conf)
self.logger.info("OpticsClustering instance created")

def start(self):
try:
self.logger.info("Starting OPTICS clustering process")
cluster_ids = self.get_cluster_ids()
for cluster_id in cluster_ids:
embeddings, pdb_chain_ids = self.load_embeddings(cluster_id)
if embeddings.size == 0: # Verificar si no hay embeddings
continue
cluster_labels = self.cluster_embeddings(embeddings)
self.store_subclusters(cluster_id, cluster_labels, pdb_chain_ids)
self.logger.info("Clustering process completed successfully")
except Exception as e:
self.logger.error(f"Error during clustering process: {e}")
raise

def get_cluster_ids(self):
return [cluster.id for cluster in self.session.query(Cluster.id).all()]

def load_embeddings(self, cluster_id):
# Obtiene los embeddings y pdb_chain_ids para un cluster_id específico
entries = self.session.query(
ClusterEntry.pdb_chain_id,
ChainEmbedding.embedding
).join(
PDBChains, ClusterEntry.pdb_chain_id == PDBChains.id
).join(
ChainEmbedding, PDBChains.id == ChainEmbedding.pdb_chain_id
).filter(
ClusterEntry.cluster_id == cluster_id
).all()

if not entries:
return np.array([]), []

# Asumiendo que 'embedding' es una lista o un numpy array ya compatible
embeddings = np.array([entry.embedding for entry in entries])
pdb_chain_ids = [entry.pdb_chain_id for entry in entries]

return embeddings, pdb_chain_ids

def cluster_embeddings(self, embeddings):
# Ajusta min_samples basado en el número de muestras disponibles
min_samples = min(5, len(embeddings) - 1) # Asegura que min_samples nunca sea mayor que el número de muestras
if min_samples < 2: # OPTICS requiere al menos dos muestras para funcionar
return np.array(
[-1] * len(embeddings)) # Considera todos los puntos como ruido si no hay suficientes para clustering

optics = OPTICS(min_samples=min_samples, xi=0.05, min_cluster_size=0.05)
optics.fit(embeddings)
return optics.labels_

def store_subclusters(self, cluster_id, cluster_labels, pdb_chain_ids):
# Diccionario para almacenar los subclusters y sus entradas
subclusters_dict = {}

# Iterar sobre cada etiqueta y pdb_chain_id juntos
for label, pdb_chain_id in zip(cluster_labels, pdb_chain_ids):
if label == -1: # OPTICS puede marcar algunos puntos como ruido, los ignoramos
continue

if label not in subclusters_dict:
subclusters_dict[label] = {
"entries": [],
"max_length": 0,
"representative_id": None
}

# Consulta la secuencia para el pdb_chain_id actual y calcula su longitud
sequence = self.session.query(PDBChains.sequence).filter_by(id=pdb_chain_id).scalar()
sequence_length = len(sequence)

# Añadir la entrada al subcluster
subclusters_dict[label]["entries"].append((pdb_chain_id, sequence_length))

# Verificar si esta entrada tiene la secuencia de mayor longitud
if sequence_length > subclusters_dict[label]["max_length"]:
subclusters_dict[label]["max_length"] = sequence_length
subclusters_dict[label]["representative_id"] = pdb_chain_id

# Ahora, almacenar los subclusters y sus entradas en la base de datos
for label, subcluster_info in subclusters_dict.items():
subcluster = Subcluster(
cluster_id=cluster_id,
# Añadir otros campos necesarios aquí
)
self.session.add(subcluster)
self.session.flush() # Obtener el id del subcluster insertado

# Marcar la entrada con la secuencia de mayor longitud como representativa
for pdb_chain_id, sequence_length in subcluster_info["entries"]:
is_representative = (pdb_chain_id == subcluster_info["representative_id"])
subcluster_entry = SubclusterEntry(
subcluster_id=subcluster.id,
pdb_chain_id=pdb_chain_id,
is_representative=is_representative,
sequence_length=sequence_length,
# Añadir otros campos necesarios aquí
)
self.session.add(subcluster_entry)

self.session.commit()



75 changes: 50 additions & 25 deletions protein_metamorphisms_is/sql/model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from datetime import datetime

from sqlalchemy import (Column, Integer, String, Date, ForeignKey, DateTime,
func, Float, Boolean)
from sqlalchemy.dialects.postgresql import ARRAY
Expand Down Expand Up @@ -192,6 +194,8 @@ class PDBChains(Base):

pdb_reference = relationship("PDBReference", back_populates="pdb_chains")
embeddings = relationship("ChainEmbedding", back_populates="pdb_chain")
# Añade esta línea para definir la relación con SubclusterEntry
subcluster_entries = relationship("SubclusterEntry", back_populates="pdb_chain")


class ChainEmbedding(Base):
Expand All @@ -207,41 +211,62 @@ class ChainEmbedding(Base):
embedding_type = relationship("EmbeddingType", back_populates="chain_embeddings")



class Cluster(Base):
"""
Represents a cluster of protein chains, where each cluster is formed by chains with significant similarity,
determined using the cd-hit tool.
__tablename__ = 'clusters'
id = Column(Integer, primary_key=True)
description = Column(String) # Ejemplo de campo adicional
created_at = Column(DateTime, default=datetime.now)

This class is instrumental in grouping protein chains that are highly similar to each other, aiding in the
identification of common structures and functions.
# Relación con ClusterEntries
entries = relationship("ClusterEntry", back_populates="cluster")

# Relación con Subcluster
subclusters = relationship("Subcluster", back_populates="cluster")

Attributes:
id (int): Unique identifier for each cluster.
pdb_chain_id (int): Foreign key referencing the 'PDBChains' entity. It is used to identify the specific protein
chain in the PDB database associated with this cluster.
cluster_id (int): Identifier of the cluster, typically a unique string representing this specific group of
protein chains.
is_representative (Boolean): Indicates whether the cluster is representative of a larger set of similar chains.
'True' for yes, 'False' for no.
sequence_length (int): Average length of the sequences of the chains in the cluster.
identity (Float): Value representing the average sequence identity within the cluster, usually a percentage
indicating how similar the chains are within the group.
The relationship with 'PDBChains' allows each cluster to be connected to its specific chain in the PDB database,
providing a direct link to detailed structural information.
"""
__tablename__ = 'clusters'

class ClusterEntry(Base):
__tablename__ = 'cluster_entries'
id = Column(Integer, primary_key=True)
cluster_id = Column(Integer, ForeignKey('clusters.id'))
pdb_chain_id = Column(Integer, ForeignKey('pdb_chains.id'))
cluster_id = Column(Integer)
is_representative = Column(Boolean)
sequence_length = Column(Integer)
identity = Column(Float)
complexity_level_id = Column(Integer, ForeignKey('structural_complexity_levels.id'))
created_at = Column(DateTime, default=datetime.now)

# Relaciones con Cluster y PDBChains
cluster = relationship("Cluster", back_populates="entries")
pdb_chain = relationship("PDBChains") # Asegúrate de definir esta relación en PDBChains si aún no existe


class Subcluster(Base):
__tablename__ = 'subclusters'
id = Column(Integer, primary_key=True)
cluster_id = Column(Integer, ForeignKey('clusters.id'))
description = Column(String) # Una descripción opcional del subcluster
created_at = Column(DateTime, default=datetime.now)

# Relación de vuelta a Cluster
cluster = relationship("Cluster", back_populates="subclusters")

# Relación con SubclusterEntry
entries = relationship("SubclusterEntry", back_populates="subcluster")


class SubclusterEntry(Base):
__tablename__ = 'subcluster_entries'
id = Column(Integer, primary_key=True)
subcluster_id = Column(Integer, ForeignKey('subclusters.id'))
pdb_chain_id = Column(Integer, ForeignKey('pdb_chains.id'))
is_representative = Column(Boolean)
sequence_length = Column(Integer)
identity = Column(Float)
created_at = Column(DateTime, default=datetime.now)

# Relaciones con Subcluster y PDBChains
subcluster = relationship("Subcluster", back_populates="entries")
pdb_chain = relationship("PDBChains", back_populates="subcluster_entries")

complexity_level = relationship("StructuralComplexityLevel", backref="clusters")


class StructuralComplexityLevel(Base):
Expand Down

0 comments on commit 5eaee67

Please sign in to comment.