Remove bio_embeddings dependency
mheinzinger authored Aug 5, 2022
1 parent 228df80 commit b10ce20
Showing 1 changed file with 165 additions and 33 deletions.
198 changes: 165 additions & 33 deletions eat.py
@@ -95,12 +95,14 @@ def compute_err(self, n_bootstrap=1000):
n_skipped, n_bootstrap))
return {"ACCs": (acc, accs_btrap), "bACCs": (bAcc, bAccs_btrap), "F1": (f1, f1s_btrap)}


class Embedder():
# The following exemplifies how bio_embeddings compatibility can easily be included
# However, this is a bit slower which is why the Embedder() class below should be used for speedy inference
class BioEmbEmbedder():
def __init__(self):
from bio_embeddings.embed import ProtTransT5XLU50Embedder
self.embedder = ProtTransT5XLU50Embedder(half_model=True)

self.bs = 1000 # batch-size defines max-number of AAs per batch; lower this value if you run OOM

def write_embeddings(self, emb_p, embds):
with h5py.File(str(emb_p), "w") as hf:
for sequence_id, embedding in embds.items():
@@ -111,16 +113,95 @@ def write_embeddings(self, emb_p, embds):
def get_embeddings(self, id2seq):
fasta_ids, seqs = zip(*[(fasta_id, seq)
for fasta_id, seq in id2seq.items()])
print("Start generating embeddings. This process might take a few minutes.")
print("Start generating embeddings for {} proteins.".format(len(fasta_ids)) +
"This process might take a few minutes.")
start = time.time()
per_residue_embeddings = list(self.embedder.embed_many(list(seqs)))
id2embd = { fasta_id: per_residue_embeddings[idx].mean(axis=0)
for idx, fasta_id in enumerate(list(fasta_ids))
}
print("Creating embeddings took: {:.4f}[s]".format(time.time()-start))
print("Creating per-protein embeddings took: {:.4f}[s]".format(time.time()-start))
return id2embd
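
For reference, a minimal sketch of how this slower bio_embeddings path could be wired into EAT.read_inputs() in place of the new Embedder (hypothetical, not part of this commit; it assumes bio_embeddings is installed and mirrors the removed read_inputs() lines that turn the per-protein dict into the (ids, tensor) pair used downstream):

    # hypothetical swap-in for EAT.read_inputs(); requires the optional bio_embeddings package
    embedder = BioEmbEmbedder()
    id2emb = embedder.get_embeddings(id2seq)          # {fasta_id: mean-pooled per-protein embedding}
    embedder.write_embeddings(emb_p, id2emb)          # cache embeddings to H5
    ids, embeddings = zip(*id2emb.items())
    embeddings = torch.tensor(np.vstack(embeddings)).to(device).float()  # n_proteins x 1024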


class Embedder():
def __init__(self):
self.embedder, self.tokenizer = self.get_prott5()

def get_prott5(self):
start=time.time()
# Load your checkpoint here
# Currently, only the encoder-part of ProtT5 is loaded in half-precision
from transformers import T5EncoderModel, T5Tokenizer
print("Start loading ProtT5...")
transformer_name = "Rostlab/prot_t5_xl_half_uniref50-enc"
model = T5EncoderModel.from_pretrained(transformer_name, torch_dtype=torch.float16)
model = model.to(device)
model = model.eval()
tokenizer = T5Tokenizer.from_pretrained(transformer_name, do_lower_case=False)
print("Finished loading {} in {:.1f}[s]".format(transformer_name,time.time()-start))
return model, tokenizer

def write_embedding_list(self,emb_p, ids,embeddings):
embeddings=embeddings.detach().cpu().numpy().squeeze()
with h5py.File(str(emb_p),"w") as hf:
for idx, seq_id in enumerate(ids):
hf.create_dataset(seq_id,data=embeddings[idx])
return None

def write_embeddings(self, emb_p, embds):
with h5py.File(str(emb_p), "w") as hf:
for sequence_id, embedding in embds.items():
# noinspection PyUnboundLocalVariable
hf.create_dataset(sequence_id, data=embedding)
return None

def get_embeddings_batch(self, id2seq, max_residues=4000, max_seq_len=1000, max_batch=100):
print("Start generating embeddings for {} proteins.".format(len(id2seq)) +
"This process might take a few minutes." +
"Using batch-processing! If you run OOM/RuntimeError, you should use single-sequence embedding by setting max_batch=1.")
start = time.time()
ids = list()
embeddings = list()
batch = list()

id2seq = sorted( id2seq.items(), key=lambda kv: len( id2seq[kv[0]] ), reverse=True )
for seq_idx, (protein_id, original_seq) in enumerate(id2seq, 1):
seq = original_seq.replace('U','X').replace('Z','X').replace('O','X')
seq_len = len(seq)
seq = ' '.join(list(seq))
batch.append((protein_id,seq,seq_len))


n_res_batch = sum([ s_len for _, _, s_len in batch ]) + seq_len
if len(batch) >= max_batch or n_res_batch>=max_residues or seq_idx==len(id2seq) or seq_len>max_seq_len:
protein_ids, seqs, seq_lens = zip(*batch)
batch = list()

token_encoding = self.tokenizer.batch_encode_plus(seqs, add_special_tokens=True, padding="longest")
input_ids = torch.tensor(token_encoding['input_ids']).to(device)
attention_mask = torch.tensor(token_encoding['attention_mask']).to(device)

try:
with torch.no_grad():
# get embeddings extracted from last hidden state
batch_emb = self.embedder(input_ids, attention_mask=attention_mask).last_hidden_state # [B, L, 1024]
except RuntimeError as e :
print(e)
print("RuntimeError during embedding for {} (L={})".format(protein_id, seq_len))
continue

for batch_idx, identifier in enumerate(protein_ids):
s_len = seq_lens[batch_idx]
emb = batch_emb[batch_idx,:s_len].mean(dim=0,keepdims=True)
ids.append(protein_ids[batch_idx])
embeddings.append(emb.detach())

print("Creating per-protein embeddings took: {:.1f}[s]".format(time.time()-start))
embeddings = torch.vstack(embeddings)
return ids, embeddings
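
A minimal usage sketch of the new transformers-based Embedder (illustrative only, not part of the commit; the IDs and sequences are placeholders, and it assumes the imports already present in eat.py plus a configured device and access to the Rostlab/prot_t5_xl_half_uniref50-enc checkpoint):

    # illustrative only: embed two placeholder sequences and cache them to H5
    embedder = Embedder()                                                 # loads the ProtT5 encoder once
    id2seq = {"prot_A": "MKTAYIAKQR", "prot_B": "GSHMSLFDFFK"}            # placeholder IDs and sequences
    ids, embeddings = embedder.get_embeddings_batch(id2seq, max_batch=1)  # max_batch=1: single-sequence embedding, as recommended on OOM
    embedder.write_embedding_list("proteins.h5", ids, embeddings)         # one mean-pooled 1024-d vector per protein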


# EAT: Embedding-based Annotation Transfer
class EAT():
def __init__(self, lookup_p, query_p, output_d, use_tucker, num_NN,
@@ -158,7 +239,7 @@ def tucker_embeddings(self, dataset):
state = torch.load(weights_p)['state_dict']
model = Tucker().to(device)
model.load_state_dict(state)
model.eval()
model=model.eval()

start = time.time()
dataset = model.single_pass(dataset)
@@ -167,25 +248,25 @@ def tucker_embeddings(self, dataset):

def read_inputs(self, input_p):
# define path for storing embeddings
emb_p = self.output_d / input_p.name.replace(".fasta", ".h5")
if not (input_p.is_file() or emb_p.is_file()):

if not input_p.is_file():
print("Neither input fasta, nor embedding H5 could be found for: {}".format(input_p))
print("Files are expected to either end with .fasta or .h5")
raise FileNotFoundError

if emb_p.is_file(): # if the embedding file already exists
return self.read_embeddings(emb_p)
if input_p.name.endswith(".h5"): # if the embedding file already exists
return self.read_embeddings(input_p)

elif input_p.name.endswith(".fasta"): # compute new embeddings if only FASTA available
if self.Embedder is None: # avoid re-loading the pLM
self.Embedder = Embedder()
id2seq = self.read_fasta(input_p)
id2emb = self.Embedder.get_embeddings(id2seq)
self.Embedder.write_embeddings(emb_p, id2emb)
keys, embeddings = zip(*id2emb.items())
# matrix of values (protein-embeddings); n_proteins x embedding_dim
embeddings = np.vstack(embeddings)
return list(keys), torch.tensor(embeddings).to(device).float()


ids, embeddings = self.Embedder.get_embeddings_batch(id2seq)

emb_p = self.output_d / input_p.name.replace(".fasta", ".h5")
self.Embedder.write_embedding_list(emb_p, ids,embeddings)
return ids, embeddings
else:
print("The file you passed neither ended with .fasta nor .h5. " +
"Only those file formats are currently supported.")
@@ -202,16 +283,20 @@ def read_fasta(self, fasta_path):
sequences = dict()
with open(fasta_path, 'r') as fasta_f:
for line in fasta_f:
line=line.strip()
# get uniprot ID from header and create new entry
if line.startswith('>'):
uniprot_id = line.replace('>', '').strip()
if '|' in line and (line.startswith(">tr") or line.startswith(">sp")):
seq_id = line.split("|")[1]
else:
seq_id = line.replace(">","")
# replace tokens that are mis-interpreted when loading h5
uniprot_id = uniprot_id.replace("/", "_").replace(".", "_")
sequences[uniprot_id] = ''
seq_id = seq_id.replace("/", "_").replace(".", "_")
sequences[seq_id] = ''
else:
# replace all white-space chars and join seqs spanning multiple lines
# drop gaps and cast to upper-case
sequences[uniprot_id] += ''.join(
sequences[seq_id] += ''.join(
line.split()).upper().replace("-", "")
return sequences

@@ -220,6 +305,8 @@ def read_embeddings(self, emb_p):
h5_f = h5py.File(emb_p, 'r')
dataset = {pdb_id: np.array(embd) for pdb_id, embd in h5_f.items()}
keys, embeddings = zip(*dataset.items())
if keys[0].startswith("cath"):
keys = [key.split("|")[2].split("_")[0] for key in keys ]
# matrix of values (protein-embeddings); n_proteins x embedding_dim
embeddings = np.vstack(embeddings)
print("Loading embeddings from {} took: {:.4f}[s]".format(
@@ -248,15 +335,43 @@ def write_predictions(self, predictions):
]))
return None

def pdist(self, lookup, queries, norm=2):
return torch.cdist(lookup.unsqueeze(dim=0).double(), queries.unsqueeze(dim=0).double(), p=norm).squeeze(dim=0)

def get_NNs(self, random=False):
def pdist(self, lookup, queries, norm=2, use_double=False):
lookup=lookup.unsqueeze(dim=0)
queries=queries.unsqueeze(dim=0)
# double precision can slightly improve results but can be dropped for speedier predictions (no significant difference in performance)
if use_double:
lookup=lookup.double()
queries=queries.double()

try: # try to batch-compute pairwise-distance on GPU
pdist = torch.cdist(lookup, queries, p=norm).squeeze(dim=0)
except RuntimeError as e:
print("Encountered RuntimeError: {}".format(e))
print("Trying single query inference on GPU.")
try: # if OOM for batch-GPU, re-try single query pdist computation on GPU
pdist = torch.stack(
[torch.cdist(lookup, queries[0:1, q_idx], p=norm).squeeze(dim=0)
for q_idx in range(queries.shape[1])
]
).squeeze(dim=-1).T

except RuntimeError as e: # if OOM for single GPU, re-try single query on CPU
print("Encountered RuntimeError: {}".format(e))
print("Trying to move single query computation to CPU.")
lookup=lookup.to("cpu")
queries=queries.to("cpu")
pdist = torch.stack(
[torch.cdist(lookup, queries[0:1, q_idx], p=norm).squeeze(dim=0)
for q_idx in range(queries.shape[1])
]
).squeeze(dim=-1).T

print(pdist.shape)
return pdist

def get_NNs(self, threshold, random=False):
start = time.time()
p_dist = self.pdist(self.lookup_embs, self.query_embs)
self_hits = torch.isclose(p_dist, torch.zeros_like(p_dist), atol=1e-5)
# replace self-hits with infinite distance to avoid self-hit lookup
p_dist[self_hits] = float('inf')

if random: # this is only needed for benchmarking against random background
print("Making RANDOM predictions!")
@@ -265,7 +380,9 @@ def get_NNs(self, random=False):
else: # infer nearest neighbor indices
nn_dists, nn_idxs = torch.topk(
p_dist, self.num_NN, largest=False, dim=0)


print("Computing NN took: {:.4f}[s]".format(time.time()-start))
nn_dists, nn_idxs = nn_dists.to("cpu"), nn_idxs.to("cpu")
predictions = list()
n_test = len(self.query_ids)
for test_idx in range(n_test): # for all test proteins
@@ -275,14 +392,17 @@ def get_NNs(self, random=False):
for nn_iter, (nn_i, nn_d) in enumerate(zip(nn_idx, nn_dist)):
# index of nearest neighbour (nn) in train set
nn_i, nn_d = int(nn_i), float(nn_d)
# if a threshold is passed, skip all proteins above this threshold
if threshold is not None and nn_d > threshold:
continue
# get id of nn (infer annotation)
lookup_id = self.lookup_ids[nn_i]
lookup_label = self.lookupLabels[lookup_id]
query_label = self.queryLabels[query_id]
predictions.append(
(query_id, query_label, lookup_id, lookup_label, nn_d, nn_iter))
end = time.time()
print("Computing NN took: {:.4f}".format(end-start))
print("Computing NN took: {:.4f}[s]".format(end-start))
return predictions


@@ -343,6 +463,11 @@ def create_arg_parser():
default=1,
help="The number of nearest neighbors to retrieve via EAT." +
"Default: 1 (retrieve only THE nearest neighbor).")

parser.add_argument('--threshold', type=int,
default=None,
help="The Euclidean distance threshold below which nearest neighbors are retrieved via EAT." +
"Default: None (retrieve THE nearest neighbor, irrespective of distance).")
return parser


@@ -360,17 +485,24 @@ def main():
args.queryLabels)

num_NN = int(args.num_NN)
threshold = float(args.threshold) if args.threshold is not None else None
assert num_NN > 0, "Only a positive number of nearest neighbors can be retrieved."

use_tucker = int(args.use_tucker)
use_tucker = False if use_tucker == 0 else True


start=time.time()
eater = EAT(lookup_p, query_p, output_d,
use_tucker, num_NN, lookupLabels_p, queryLabels_p)
predictions = eater.get_NNs()

predictions = eater.get_NNs(threshold=threshold)
eater.write_predictions(predictions)
end=time.time()

print("Total time: {:.3f}[s] ({:.3f}[s]/protein)".format(
end-start, (end-start)/len(eater.query_ids)))

if queryLabels_p is not None:
print("Found labels to queries. Computing EAT performance ...")
evaluator = Evaluator(predictions)