Skip to content

Commit

Permalink
running inference on the endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
brifordwylie committed Apr 24, 2024
1 parent 8fbc441 commit 9737f2d
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions tests/create_aqsol_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
Endpoints:
- aqsol-regression-end
"""

import logging
import pandas as pd
import awswrangler as wr

Expand All @@ -20,6 +20,8 @@

from sageworks.core.transforms.data_to_features.light.molecular_descriptors import MolecularDescriptors
from sageworks.aws_service_broker.aws_service_broker import AWSServiceBroker
log = logging.getLogger("sageworks")


if __name__ == "__main__":
# This forces a refresh on all the data we get from the AWs Broker
Expand Down Expand Up @@ -94,7 +96,7 @@
rdkit_features = MolecularDescriptors("aqsol_data", "aqsol_mol_descriptors")
rdkit_features.set_output_tags(["aqsol", "public"])
query = "SELECT id, solubility, solubility_class, smiles FROM aqsol_data"
rdkit_features.transform(target_column="solubility", id_column="id", query=query, auto_one_hot=False)
rdkit_features.transform(target_column="solubility", id_column="id", query=query)

# Create the Molecular Descriptor based Regression Model
if recreate or not Model("aqsol-mol-regression").exists():
Expand Down Expand Up @@ -129,9 +131,15 @@
# Create the Molecular Descriptor Regression Endpoint
if recreate or not Endpoint("aqsol-mol-regression-end").exists():
m = Model("aqsol-mol-regression")
m.to_endpoint(name="aqsol-mol-regression-end", tags=["aqsol", "mol", "regression"])
end = m.to_endpoint(name="aqsol-mol-regression-end", tags=["aqsol", "mol", "regression"])

# Run inference on the endpoint
end.auto_inference(capture=True)

# Create the Molecular Descriptor Classification Endpoint
if recreate or not Endpoint("aqsol-mol-class-end").exists():
m = Model("aqsol-mol-class")
m.to_endpoint(name="aqsol-mol-class-end", tags=["aqsol", "mol", "classification"])
end = m.to_endpoint(name="aqsol-mol-class-end", tags=["aqsol", "mol", "classification"])

# Run inference on the endpoint
end.auto_inference(capture=True)

0 comments on commit 9737f2d

Please sign in to comment.