Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: add inference framework #14

Merged
merged 5 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions ginkgo_ai_client/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,9 +276,11 @@ class PromoterActivityQuery(QueryBase):
query_name: Optional[str] = None
The name of the query. It will appear in the API response and can be used to
handle exceptions.
model: str = "borzoi-human-fold0"
The model to use for the inference (only one default model is supported for now).

inference_framework: Literal["promoter-0"] = "promoter-0"
The inference framework to use for the inference. Currently only supports
"promoter-0".
borzoi_model: Literal["human-fold0"] = "human-fold0"
The model to use for the inference. Currently only supports the trained
model of "human-fold0".
Returns
-------
PromoterActivityResponse
Expand All @@ -291,19 +293,21 @@ class PromoterActivityQuery(QueryBase):
orf_sequence: str
tissue_of_interest: Dict[str, List[str]]
source: str
model: str = "borzoi-human-fold0"
inference_framework: Literal["promoter-0"] = "promoter-0"
borzoi_model: Literal["human-fold0"] = "human-fold0"
query_name: Optional[str] = None

def to_request_params(self) -> Dict:
    """Build the web-API request payload for this promoter activity query.

    Returns
    -------
    Dict
        A dictionary with three keys: ``"model"`` (the string ``"borzoi-"``
        followed by ``self.borzoi_model``), ``"text"`` (a JSON string holding
        the promoter sequence, ORF sequence, tissues of interest and source),
        and ``"transforms"`` (a single ``PROMOTER_ACTIVITY`` transform).
    """
    # TODO: update the web API so the conversion isn't necessary
    # The payload's "model" value is the borzoi model name with a "borzoi-" prefix.
    model_name = f"borzoi-{self.borzoi_model}"
    data = {
        "prom": self.promoter_sequence,
        "orf": self.orf_sequence,
        "tissue_of_interest": self.tissue_of_interest,
        "source": self.source,
    }
    return {
        # Only the derived name is sent; the former `self.model` attribute
        # was replaced by `borzoi_model` and must not be referenced here.
        "model": model_name,
        "text": json.dumps(data),
        "transforms": [{"type": "PROMOTER_ACTIVITY"}],
    }
Expand Down
19 changes: 18 additions & 1 deletion test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
DiffusionMaskedQuery,
BoltzStructurePredictionQuery,
)

from pydantic_core import ValidationError

@pytest.mark.parametrize(
"model, sequence, expected_sequence",
Expand Down Expand Up @@ -67,6 +67,23 @@ def test_promoter_activity():
assert "heart" in response.activity_by_tissue
assert "liver" in response.activity_by_tissue

def test_promoter_activity_invalid_framework():
    """Check that an unsupported ``inference_framework`` fails validation.

    ``PromoterActivityQuery`` is constructed with ``"promoter-1"`` and the
    test expects a ``ValidationError`` whose text names the accepted value
    ``'promoter-0'`` and carries the ``literal_error`` type.
    """
    # No API client is created: the query is only constructed, never sent.
    with pytest.raises(ValidationError) as exc_info:
        PromoterActivityQuery(
            promoter_sequence="tgccagccatctgttgtttgcc",
            orf_sequence="GTCCCACTGATGAACTGTGCT",
            source="expression",
            tissue_of_interest={
                "heart": ["CNhs10608+", "CNhs10612+"],
                "liver": ["CNhs10608+", "CNhs10612+"],
            },
            inference_framework="promoter-1",  # invalid framework
        )

    # These checks sit after the `with` block so they run once the
    # expected exception has been captured.
    error_text = str(exc_info.value)
    assert "Input should be 'promoter-0'" in error_text
    assert "type=literal_error" in error_text

@pytest.mark.parametrize(
"model, sequence",
Expand Down
Loading