diff --git a/bio_embeddings/embed/prottrans_t5_embedder.py b/bio_embeddings/embed/prottrans_t5_embedder.py index 2370d8fa..ae8b5e0a 100644 --- a/bio_embeddings/embed/prottrans_t5_embedder.py +++ b/bio_embeddings/embed/prottrans_t5_embedder.py @@ -131,16 +131,6 @@ def _embed_batch_impl( yield embedding - def embed_many( - self, sequences: Iterable[str], batch_size: Optional[int] = None - ) -> Generator[ndarray, None, None]: - if batch_size is not None: - raise RuntimeError( - "There is a bug in batching T5, so you currently must set batch_size to `None` for T5" - ) - - return super().embed_many(sequences, None) - @staticmethod def reduce_per_protein(embedding): return embedding.mean(axis=0) diff --git a/tests/test_embedder.py b/tests/test_embedder.py index e02d38b5..aade2e9f 100644 --- a/tests/test_embedder.py +++ b/tests/test_embedder.py @@ -200,35 +200,6 @@ def test_model_parameters_seqvec(caplog): weights_file="/none/existent/path", options_file="/none/existent/path" ) - -@pytest.mark.skipif(os.environ.get("SKIP_T5"), reason="T5 makes ci run out of disk") -@pytest.mark.skipif(os.environ.get("SKIP_SLOW_TESTS"), reason="This test is very slow") -def test_batching_t5_blocked(): - """Once the T5 bug is fixed, this should become a regression test""" - embedder = ProtTransT5BFDEmbedder() - with pytest.raises(RuntimeError): - embedder.embed_many([], batch_size=1000) - - -@pytest.mark.skipif(os.environ.get("SKIP_T5"), reason="T5 makes ci run out of disk") -@pytest.mark.skipif(os.environ.get("SKIP_SLOW_TESTS"), reason="This test is very slow") -def test_batching_t5(pytestconfig): - """Check that T5 batching is still failing""" - embedder = ProtTransT5BFDEmbedder() - fasta_file = pytestconfig.rootpath.joinpath("examples/docker/fasta.fa") - batch = [str(i.seq[:]) for i in read_fasta(str(fasta_file))] - embeddings_single_sequence = list( - super(ProtTransT5Embedder, embedder).embed_many(batch, batch_size=None) - ) - embeddings_batched = list( - super(ProtTransT5Embedder, embedder).embed_many(batch, batch_size=10000) - ) - for a, b in zip(embeddings_single_sequence, embeddings_batched): - assert not numpy.allclose(a, b) and numpy.allclose( - a, b, rtol=1.0e-4, atol=1.0e-5 - ) - - def test_half_precision_embedder(pytestconfig, caplog, tmp_path: Path): """Currently a dummy test"""