Commit 39b7ae9 (parent: 276981f), showing 7 changed files with 669 additions and 0 deletions.
@@ -0,0 +1,120 @@
"""Create Solubility Models in AWS/SageWorks | ||
We're using Pipelines to create different set of ML Artifacts in SageWorks | ||
""" | ||
|
||
import pandas as pd | ||
import numpy as np | ||
import logging | ||
|
||
from sageworks.api.data_source import DataSource | ||
from sageworks.api.feature_set import FeatureSet | ||
from sageworks.api.model import Model, ModelType | ||
from sageworks.api.endpoint import Endpoint | ||
|
||
from sageworks.core.transforms.data_to_features.light.molecular_descriptors import MolecularDescriptors | ||
from sageworks.aws_service_broker.aws_service_broker import AWSServiceBroker | ||
from sageworks.api.pipeline import Pipeline | ||
from sageworks.utils.pandas_utils import stratified_split | ||
|
||
log = logging.getLogger("sageworks") | ||
|
||
# Set our pipeline | ||
pipeline_name = "test_solubility_class_nightly_100_v0" | ||
|
||
|
||
if __name__ == "__main__":
    # This forces a refresh on all the data we get from the AWS Broker
    AWSServiceBroker().get_all_metadata(force_refresh=True)

    # Grab all the information from the Pipeline (as a dictionary)
    pipe = Pipeline(pipeline_name).pipeline

    # Get all the pipeline information
    model_features = pipe["model"]["feature_list"]
    data_source_input = pipe["data_source"]["input"]
    data_source_name = pipe["data_source"]["name"]
    data_source_tags = pipe["data_source"]["tags"]
    feature_set_name = pipe["feature_set"]["name"]
    feature_set_tags = pipe["feature_set"]["tags"]
    holdout = pipe["feature_set"]["holdout"]
    model_name = pipe["model"]["name"]
    model_type_str = pipe["model"]["model_type"]
    model_tags = pipe["model"]["tags"]
    model_target = pipe["model"]["target_column"]
    endpoint_name = pipe["endpoint"]["name"]
    endpoint_tags = pipe["endpoint"]["tags"]
    pipeline_id = pipe["pipeline"]

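    # Not part of the original commit: for orientation, the pipeline dictionary is
    # assumed to look roughly like this (keys inferred from the lookups above,
    # values are placeholders):
    # {
    #     "pipeline": "test_solubility_class_nightly_100_v0",
    #     "data_source": {"input": "s3://<bucket>/<path>.csv", "name": "...", "tags": ["..."]},
    #     "feature_set": {"name": "...", "tags": ["..."], "holdout": ["..."]},
    #     "model": {"name": "...", "model_type": "...", "target_column": "...",
    #               "feature_list": ["..."], "tags": ["..."]},
    #     "endpoint": {"name": "...", "tags": ["..."]},
    # }
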
    # Recreate flag in case you want to recreate the artifacts
    recreate = False

    # Create the aqsol_data DataSource
    if recreate or not DataSource(data_source_name).exists():
        # Grab the input and add some columns
        df = DataSource(data_source_input).pull_dataframe()

        # Remove 'weird' values
        log.important("Removing 'weird' values from the solubility data")
        log.important(f"Original Shape: {df.shape}")
        df = df[df["udm_asy_res_value"] != 4.7]
        df = df[df["udm_asy_res_value"] != 0]
        log.important(f"New Shape: {df.shape}")

        # Compute the log of the solubility (the 0 -> 1e-10 replace is a safety
        # net; rows with a value of 0 were already dropped above)
        df["udm_asy_res_value"] = df["udm_asy_res_value"].replace(0, 1e-10)
        df["log_s"] = np.log10(df["udm_asy_res_value"] / 1e6)

        # Create a solubility classification column
        # (log_s <= -5 -> "low", -5 < log_s <= -4 -> "medium", log_s > -4 -> "high")
        bins = [-float("inf"), -5, -4, float("inf")]
        labels = ["low", "medium", "high"]
        df["solubility_class"] = pd.cut(df["log_s"], bins=bins, labels=labels)

        # Now we'll create the DataSource with the new column
        DataSource(df, name=data_source_name, tags=data_source_tags)

    #
    # Molecular Descriptor Artifacts
    #
    # Create the rdkit FeatureSet (this is an example of using lower-level classes)
    if recreate or not FeatureSet(feature_set_name).exists():
        rdkit_features = MolecularDescriptors(data_source_name, feature_set_name)
        rdkit_features.set_output_tags(feature_set_tags)
        rdkit_features.transform(id_column="udm_mol_id")

        # Set the holdout ids for the FeatureSet
        fs = FeatureSet(feature_set_name)

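        # Not part of the original commit: the pipeline's "holdout" field is assumed
        # to take one of two forms, inferred from the parsing below:
        #   holdout = ["id_0001", "id_0002", ...]         # explicit list of holdout ids
        #   holdout = "<kind>:<test_size>:<column_name>"  # stratified-split spec string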
        # Holdout logic (might be a list of ids or a stratified split)
        if isinstance(holdout, list):
            fs.set_holdout_ids("udm_mol_id", holdout)
        else:
            # Stratified split, so we need to pull the parameters from the spec string
            test_size = float(holdout.split(":")[1])
            column_name = holdout.split(":")[2]
            df = fs.pull_dataframe()[["udm_mol_id", column_name]]

            # Perform the stratified split and set the holdout ids
            train, test = stratified_split(df, column_name=column_name, test_size=test_size)
            fs.set_holdout_ids("udm_mol_id", test["udm_mol_id"].tolist())

    # Create the Model (model_type_str is expected to be a valid ModelType value)
    model_type = ModelType(model_type_str)
    if recreate or not Model(model_name).exists():
        feature_set = FeatureSet(feature_set_name)
        feature_set.to_model(
            model_type,
            target_column=model_target,
            name=model_name,
            feature_list=model_features,
            tags=model_tags,
        )

    # Create the Endpoint
    if recreate or not Endpoint(endpoint_name).exists():
        m = Model(model_name)
        m.to_endpoint(name=endpoint_name, tags=endpoint_tags)
        end = Endpoint(endpoint_name)
        end.auto_inference(capture=True)
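
    # Not part of the original commit: a minimal, hypothetical spot-check that all
    # the pipeline artifacts now exist, reusing the exists() calls from above.
    for artifact in [DataSource(data_source_name), FeatureSet(feature_set_name),
                     Model(model_name), Endpoint(endpoint_name)]:
        log.important(f"{artifact.__class__.__name__} exists: {artifact.exists()}")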
File renamed without changes.
@@ -0,0 +1,104 @@
# Example Glue Job that goes from CSV to Model/Endpoint
import sys
import numpy as np
import pandas as pd

# SageWorks Imports
from sageworks.api.data_source import DataSource
from sageworks.api.feature_set import FeatureSet
from sageworks.api.model import Model, ModelType
from sageworks.api.endpoint import Endpoint
from sageworks.core.transforms.data_to_features.light.molecular_descriptors import (
    MolecularDescriptors,
)
from sageworks.core.transforms.pandas_transforms.pandas_to_features import (
    PandasToFeatures,
)
from sageworks.utils.config_manager import ConfigManager
from sageworks.utils.glue_utils import glue_args_to_dict

# Convert Glue Job Args to a Dictionary
glue_args = glue_args_to_dict(sys.argv)

# Set the SAGEWORKS_BUCKET for the ConfigManager
cm = ConfigManager()
cm.set_config("SAGEWORKS_BUCKET", glue_args["--sageworks-bucket"])
cm.set_config("REDIS_HOST", glue_args["--redis-host"])

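# Not part of the original commit: the Glue job is assumed to be launched with
# arguments like these (placeholder values), which glue_args_to_dict turns into
# {"--sageworks-bucket": "my-sageworks-bucket", "--redis-host": "my-redis-host"}:
#
#   --sageworks-bucket my-sageworks-bucket --redis-host my-redis-host
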
# Create a new Data Source from an S3 Path
# source_path = "s3://idb-forest-sandbox/physchemproperty/LogS/Null/gen_processed/2024_03_07_id_smiles.csv"
source_path = (
    "s3://idb-forest-sandbox/physchemproperty/assay_processed_collection/solubility/all/2024_03_07_id_smiles.csv"
)
my_data = DataSource(source_path, name="solubility_test_data")

# Pull the dataframe from the Data Source
df = DataSource("solubility_test_data").pull_dataframe()

# Convert to logS
# Note: This will make 0 -> -16, since log10(1e-10 / 1e6) = -16
df["udm_asy_res_value"] = df["udm_asy_res_value"].replace(0, 1e-10)
df["log_s"] = np.log10(df["udm_asy_res_value"] / 1e6)

# Create a solubility classification column
bins = [-float("inf"), -5, -4, float("inf")]
labels = ["low", "medium", "high"]
df["sol_class"] = pd.cut(df["log_s"], bins=bins, labels=labels)

# Compute molecular descriptors
molecular_features = MolecularDescriptors("solubility_test_data", "solubility_test_features")

# Okay we're going to use the guts of the class without actually doing the DS to FS transformation
molecular_features.input_df = df[:100]
molecular_features.transform_impl()
output_df = molecular_features.output_df
print(output_df.head())

# Create a Feature Set
to_features = PandasToFeatures("solubility_test_features", auto_one_hot=False)
to_features.set_input(output_df, target_column="log_s", id_column="udm_mol_bat_id")
to_features.set_output_tags(["test", "solubility"])
to_features.transform()

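# Not part of the original commit: a hypothetical sanity check that the new
# FeatureSet materialized, reusing API calls already used in this commit.
fs_check = FeatureSet("solubility_test_features")
if fs_check.exists():
    print(fs_check.pull_dataframe().head())
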
""" | ||
DataSource(source_path, name="solubility_test_data") | ||
# Create a Feature Set | ||
molecular_features = MolecularDescriptors("solubility_test_data", "solubility_test_features") | ||
molecular_features.set_output_tags(["test", "solubility", "molecular_descriptors"]) | ||
query = "SELECT udm_mol_bat_id, udm_asy_protocol, udm_prj_code, udm_asy_res_value, smiles FROM solubility_test_data" | ||
molecular_features.transform(target_column="solubility", id_column="udm_mol_bat_id", query=query, auto_one_hot=False) | ||
""" | ||
|
||
""" | ||
# Convert to logS | ||
# Note: This will make 0 -> -16 | ||
test_df["udm_asy_res_value"] = test_df["udm_asy_res_value"].replace(0, 1e-10) | ||
test_df["log_s"] = np.log10(test_df["udm_asy_res_value"] / 1e6) | ||
target_column = "log_s" | ||
meta = [ | ||
"write_time", | ||
"api_invocation_time", | ||
"is_deleted", | ||
"udm_asy_protocol", | ||
"udm_asy_cnd_format", | ||
"std_dev", | ||
"count", | ||
"udm_mol_id", | ||
"udm_asy_date", | ||
"udm_prj_code", | ||
"udm_asy_cnd_target", | ||
"udm_asy_cnd_time_hr", | ||
"smiles", | ||
"udm_mol_bat_slt_smiles", | ||
"udm_mol_bat_slv_smiles", | ||
"operator", | ||
"class", | ||
"event_time", | ||
] | ||
exclude = ["log_s", "udm_asy_res_value", "udm_mol_bat_id"] + meta | ||
feature_columns = [c for c in test_df.columns if c not in exclude] | ||
""" |