Merge branch 'reactivate_t5_batching' into 'develop'
Reactivate T5 batching

See merge request sacdallago/bio_embeddings!182
konstin committed May 14, 2021
2 parents a794aaa + e91cbf3 commit 9248cb6
Showing 2 changed files with 0 additions and 39 deletions.
10 changes: 0 additions & 10 deletions bio_embeddings/embed/prottrans_t5_embedder.py
@@ -131,16 +131,6 @@ def _embed_batch_impl(

            yield embedding

    def embed_many(
        self, sequences: Iterable[str], batch_size: Optional[int] = None
    ) -> Generator[ndarray, None, None]:
        if batch_size is not None:
            raise RuntimeError(
                "There is a bug in batching T5, so you currently must set batch_size to `None` for T5"
            )

        return super().embed_many(sequences, None)

    @staticmethod
    def reduce_per_protein(embedding):
        return embedding.mean(axis=0)
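Context for readers (not part of the commit): with the embed_many override above deleted, the T5 embedders fall back to the generic embed_many of their base class, so callers may pass a batch_size again. A minimal usage sketch under that assumption; the toy sequences and the chosen batch_size are made up for illustration:

# Sketch, not part of the diff: batching a few hypothetical sequences with T5.
from bio_embeddings.embed import ProtTransT5BFDEmbedder

embedder = ProtTransT5BFDEmbedder()
sequences = ["SEQVENCEPRTEIN", "MKTAYIAKQR"]  # hypothetical toy inputs
# Before this commit, any non-None batch_size raised a RuntimeError for T5.
per_residue = list(embedder.embed_many(sequences, batch_size=2000))
per_protein = [embedder.reduce_per_protein(e) for e in per_residue]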
29 changes: 0 additions & 29 deletions tests/test_embedder.py
@@ -200,35 +200,6 @@ def test_model_parameters_seqvec(caplog):
weights_file="/none/existent/path", options_file="/none/existent/path"
)


@pytest.mark.skipif(os.environ.get("SKIP_T5"), reason="T5 makes ci run out of disk")
@pytest.mark.skipif(os.environ.get("SKIP_SLOW_TESTS"), reason="This test is very slow")
def test_batching_t5_blocked():
    """Once the T5 bug is fixed, this should become a regression test"""
    embedder = ProtTransT5BFDEmbedder()
    with pytest.raises(RuntimeError):
        embedder.embed_many([], batch_size=1000)


@pytest.mark.skipif(os.environ.get("SKIP_T5"), reason="T5 makes ci run out of disk")
@pytest.mark.skipif(os.environ.get("SKIP_SLOW_TESTS"), reason="This test is very slow")
def test_batching_t5(pytestconfig):
    """Check that T5 batching is still failing"""
    embedder = ProtTransT5BFDEmbedder()
    fasta_file = pytestconfig.rootpath.joinpath("examples/docker/fasta.fa")
    batch = [str(i.seq[:]) for i in read_fasta(str(fasta_file))]
    embeddings_single_sequence = list(
        super(ProtTransT5Embedder, embedder).embed_many(batch, batch_size=None)
    )
    embeddings_batched = list(
        super(ProtTransT5Embedder, embedder).embed_many(batch, batch_size=10000)
    )
    for a, b in zip(embeddings_single_sequence, embeddings_batched):
        assert not numpy.allclose(a, b) and numpy.allclose(
            a, b, rtol=1.0e-4, atol=1.0e-5
        )


def test_half_precision_embedder(pytestconfig, caplog, tmp_path: Path):
"""Currently a dummy test"""

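Context for readers (not part of the commit): the removed test_batching_t5_blocked noted it should become a regression test once the bug was fixed. A hedged sketch of what such a test might look like, reusing the loose tolerance from the removed test_batching_t5; the read_fasta import path, fixture location, and tolerances mirror the deleted code but are otherwise assumptions:

import numpy
from bio_embeddings.embed import ProtTransT5BFDEmbedder
from bio_embeddings.utilities import read_fasta


def test_batching_t5_regression(pytestconfig):
    """Hypothetical follow-up: batched and unbatched T5 embeddings should agree within tolerance."""
    embedder = ProtTransT5BFDEmbedder()
    fasta_file = pytestconfig.rootpath.joinpath("examples/docker/fasta.fa")
    batch = [str(entry.seq[:]) for entry in read_fasta(str(fasta_file))]
    unbatched = list(embedder.embed_many(batch, batch_size=None))
    batched = list(embedder.embed_many(batch, batch_size=10000))
    for single, grouped in zip(unbatched, batched):
        # Exact equality is not expected across batch layouts; the loose
        # tolerance mirrors the comparison in the removed test.
        assert numpy.allclose(single, grouped, rtol=1.0e-4, atol=1.0e-5)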