From 44f86162cdb814d24dc78a9757ba0ca56c8b7e67 Mon Sep 17 00:00:00 2001 From: Valentin Zulkower Date: Mon, 16 Dec 2024 00:18:15 -0500 Subject: [PATCH] added tests --- test/data/boltz_input_ligand.yaml | 11 ++++++++++ test/data/boltz_input_multimer.yaml | 8 +++++++ test/data/boltz_input_single_protein.yaml | 5 +++++ test/test_models.py | 16 +++++++++++++- test/test_query_creation.py | 26 ++++++++++++++++++++++- 5 files changed, 64 insertions(+), 2 deletions(-) create mode 100644 test/data/boltz_input_ligand.yaml create mode 100644 test/data/boltz_input_multimer.yaml create mode 100644 test/data/boltz_input_single_protein.yaml diff --git a/test/data/boltz_input_ligand.yaml b/test/data/boltz_input_ligand.yaml new file mode 100644 index 0000000..868c5e1 --- /dev/null +++ b/test/data/boltz_input_ligand.yaml @@ -0,0 +1,11 @@ +version: 1 # Optional, defaults to 1 +sequences: + - protein: + id: [A, B] + sequence: MVTPEGNVSLVDESLLVGVTDEDRAVRSAHQ + - ligand: + id: [C, D] + ccd: SAH + - ligand: + id: [E, F] + smiles: N[C@@H](Cc1ccc(O)cc1)C(=O)O diff --git a/test/data/boltz_input_multimer.yaml b/test/data/boltz_input_multimer.yaml new file mode 100644 index 0000000..23048c7 --- /dev/null +++ b/test/data/boltz_input_multimer.yaml @@ -0,0 +1,8 @@ +version: 1 # Optional, defaults to 1 +sequences: + - protein: + id: A + sequence: MAHHHHHHVAVDAVSFTLLQDQLQSVLDTL + - protein: + id: B + sequence: MRYAFAAEATTCNAFWRNVDMTVTALYEVPLGVCTQDPDRW diff --git a/test/data/boltz_input_single_protein.yaml b/test/data/boltz_input_single_protein.yaml new file mode 100644 index 0000000..ade9a77 --- /dev/null +++ b/test/data/boltz_input_single_protein.yaml @@ -0,0 +1,5 @@ +version: 1 # Optional, defaults to 1 +sequences: + - protein: + id: A + sequence: MAHHHHHHVAVDAVSFTLLQDQLQSVLDTL diff --git a/test/test_models.py b/test/test_models.py index d939afc..0b07567 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -1,3 +1,5 @@ +import tempfile +from pathlib import Path import pytest from ginkgo_ai_client import ( @@ -6,6 +8,7 @@ MeanEmbeddingQuery, PromoterActivityQuery, DiffusionMaskedQuery, + BoltzStructurePredictionQuery, ) @@ -63,6 +66,7 @@ def test_promoter_activity(): assert "heart" in response.activity_by_tissue assert "liver" in response.activity_by_tissue + @pytest.mark.parametrize( "model, sequence", [ @@ -73,7 +77,7 @@ def test_promoter_activity(): def test_diffusion_masked_inference(model, sequence): client = GinkgoAIClient() query = DiffusionMaskedQuery( - sequence=sequence, #upper and lower cases + sequence=sequence, # upper and lower cases model=model, temperature=0.5, decoding_order_strategy="entropy", @@ -82,3 +86,13 @@ def test_diffusion_masked_inference(model, sequence): response = client.send_request(query) assert isinstance(response.sequence, str) assert "" not in response.sequence + + +def test_boltz_structure_prediction(): + client = GinkgoAIClient() + data_file = Path(__file__).parent / "data" / "boltz_input_single_chain.yaml" + query = BoltzStructurePredictionQuery.from_yaml_file(data_file) + response = client.send_request(query) + with tempfile.TemporaryDirectory() as temp_dir: + response.download_structure(Path(temp_dir) / "structure.cif") + response.download_structure(Path(temp_dir) / "structure.pdb") diff --git a/test/test_query_creation.py b/test/test_query_creation.py index dd30833..f300ba6 100644 --- a/test/test_query_creation.py +++ b/test/test_query_creation.py @@ -1,7 +1,11 @@ import pytest import re from pathlib import Path -from ginkgo_ai_client.queries import MeanEmbeddingQuery, PromoterActivityQuery +from ginkgo_ai_client.queries import ( + MeanEmbeddingQuery, + PromoterActivityQuery, + BoltzStructurePredictionQuery, +) def test_that_forgetting_to_name_arguments_raises_the_better_error_message(): @@ -37,3 +41,23 @@ def test_promoter_activity_iteration(): }, ) assert len(queries) == 50 + + +@pytest.mark.parametrize( + "filename, expected_sequences", + [ + ("boltz_input_ligand.yaml", 3), + ("boltz_input_multimer.yaml", 2), + ], +) +def test_boltz_structure_prediction_query_from_yaml_file(filename, expected_sequences): + query = BoltzStructurePredictionQuery.from_yaml_file( + Path(__file__).parent / "data" / filename + ) + assert len(query.sequences) == expected_sequences + + +def test_boltz_structure_prediction_query_from_protein_sequence(): + query = BoltzStructurePredictionQuery.from_protein_sequence(sequence="MLLKP") + sequences = query.model_dump(exclude_none=True)["sequences"] + assert sequences == [{"protein": {"id": "A", "sequence": "MLLKP"}}]