Skip to content

Commit

Permalink
Merge pull request #14 from ginkgobioworks/siqi/unify_promoter_names
Browse files Browse the repository at this point in the history
refactor: add inference framework
  • Loading branch information
Zulko authored Dec 16, 2024
2 parents 5a32368 + d92ded4 commit 6ab6482
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 6 deletions.
13 changes: 8 additions & 5 deletions ginkgo_ai_client/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,9 +276,11 @@ class PromoterActivityQuery(QueryBase):
query_name: Optional[str] = None
The name of the query. It will appear in the API response and can be used to
handle exceptions.
model: str = "borzoi-human-fold0"
The model to use for the inference (only one default model is supported for now).
inference_framework: Literal["promoter-0"] = "promoter-0"
The inference framework to use for the inference. Currently only supports
borzoi_model: Literal["human-fold0"] = "human-fold0"
The model to use for the inference. Currently only supports the trained
model of "human-fold0".
Returns
-------
PromoterActivityResponse
Expand All @@ -291,7 +293,8 @@ class PromoterActivityQuery(QueryBase):
orf_sequence: str
tissue_of_interest: Dict[str, List[str]]
source: str
model: str = "borzoi-human-fold0"
inference_framework: Literal["promoter-0"] = "promoter-0"
borzoi_model: Literal["human-fold0"] = "human-fold0"
query_name: Optional[str] = None

def to_request_params(self) -> Dict:
Expand All @@ -303,7 +306,7 @@ def to_request_params(self) -> Dict:
"source": self.source,
}
return {
"model": self.model,
"model": f"borzoi-{self.borzoi_model}",
"text": json.dumps(data),
"transforms": [{"type": "PROMOTER_ACTIVITY"}],
}
Expand Down
19 changes: 18 additions & 1 deletion test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
DiffusionMaskedQuery,
BoltzStructurePredictionQuery,
)

from pydantic_core import ValidationError

@pytest.mark.parametrize(
"model, sequence, expected_sequence",
Expand Down Expand Up @@ -67,6 +67,23 @@ def test_promoter_activity():
assert "heart" in response.activity_by_tissue
assert "liver" in response.activity_by_tissue

def test_promoter_activity_fails_with_invalid_framework_name():
client = GinkgoAIClient()

with pytest.raises(ValidationError) as exc_info:
query = PromoterActivityQuery(
promoter_sequence="tgccagccatctgttgtttgcc",
orf_sequence="GTCCCACTGATGAACTGTGCT",
source="expression",
tissue_of_interest={
"heart": ["CNhs10608+", "CNhs10612+"],
"liver": ["CNhs10608+", "CNhs10612+"],
},
inference_framework="promoter-2" # invalid framework
)

assert "Input should be 'promoter-0' " in str(exc_info.value)
assert "type=literal_error" in str(exc_info.value)

@pytest.mark.parametrize(
"model, sequence",
Expand Down

0 comments on commit 6ab6482

Please sign in to comment.