Skip to content

Commit

Permalink
adding Pipelines to the Artifacts View and REPL
Browse files Browse the repository at this point in the history
  • Loading branch information
brifordwylie committed Apr 16, 2024
1 parent 3347f44 commit 5b8f55a
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 3 deletions.
22 changes: 22 additions & 0 deletions src/sageworks/api/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from sageworks.utils.config_manager import ConfigManager
from sageworks.utils.datetime_utils import datetime_string
from sageworks.utils.aws_utils import num_columns_ds, num_columns_fs, aws_url
from sageworks.api.pipeline_manager import PipelineManager


class Meta:
Expand All @@ -37,6 +38,9 @@ def __init__(self):
self.aws_broker = AWSServiceBroker()
self.cm = ConfigManager()

# Pipeline Manager
self.pipeline_manager = PipelineManager()

def account(self) -> dict:
"""Print out the AWS Account Info
Expand Down Expand Up @@ -344,6 +348,20 @@ def endpoints_deep(self, refresh: bool = False) -> dict:
"""
return self.aws_broker.get_metadata(ServiceCategory.ENDPOINTS, force_refresh=refresh)

def pipelines(self, refresh: bool = False) -> pd.DataFrame:
    """Build a DataFrame summarizing the SageWorks Pipelines

    Args:
        refresh (bool, optional): Force a refresh of the metadata. Defaults to False.
            NOTE(review): this flag is not currently passed on to the pipeline
            manager, so it has no effect here -- confirm whether that is intended.

    Returns:
        pd.DataFrame: A summary of the SageWorks Pipelines
    """
    # Whatever the pipeline manager lists is handed directly to the DataFrame constructor
    return pd.DataFrame(self.pipeline_manager.list_pipelines())

def _remove_sageworks_meta(self, data: dict) -> dict:
"""Internal: Recursively remove any keys with 'sageworks_' in them"""

Expand Down Expand Up @@ -400,6 +418,10 @@ def _remove_sageworks_meta(self, data: dict) -> dict:
print("\n\n*** Endpoints ***")
pprint(meta.endpoints())

# Get the Pipelines
print("\n\n*** Pipelines ***")
pprint(meta.pipelines())

# Now do a deep dive on all the Artifacts
print("\n\n#")
print("# Deep Dives ***")
Expand Down
14 changes: 14 additions & 0 deletions src/sageworks/api/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,17 @@ def _get_pipeline(self) -> dict:
json_object = json.loads(response["Body"].read())
return json_object

def __repr__(self) -> str:
    """Render this pipeline as its class name wrapping the JSON-encoded details

    Returns:
        str: ``ClassName(<details>)`` where the details are the result of
        ``self.details()`` dumped as JSON with a 4-space indent
    """
    type_name = self.__class__.__name__
    pretty_details = json.dumps(self.details(), indent=4)
    return f"{type_name}({pretty_details})"


if __name__ == "__main__":
"""Exercise the Pipeline Class"""
Expand All @@ -116,3 +127,6 @@ def _get_pipeline(self) -> dict:

# Execute the Pipeline
my_pipeline.execute()

# Print the Representation of the Pipeline
print(my_pipeline)
18 changes: 15 additions & 3 deletions src/sageworks/repl/sageworks_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def __init__(self):
self.commands["feature_sets"] = self.feature_sets
self.commands["models"] = self.models
self.commands["endpoints"] = self.endpoints
self.commands["pipelines"] = self.pipelines
self.commands["log_debug"] = self.log_debug
self.commands["log_info"] = self.log_info
self.commands["log_important"] = self.log_important
Expand Down Expand Up @@ -209,6 +210,8 @@ def import_sageworks(self):
self.commands["Endpoint"] = importlib.import_module("sageworks.api.endpoint").Endpoint
self.commands["Monitor"] = importlib.import_module("sageworks.api.monitor").Monitor
self.commands["Meta"] = importlib.import_module("sageworks.api.meta").Meta
self.commands["Pipeline"] = importlib.import_module("sageworks.api.pipeline").Pipeline
self.commands["PipelineManager"] = importlib.import_module("sageworks.api.pipeline_manager").PipelineManager
self.commands["PluginManager"] = importlib.import_module("sageworks.utils.plugin_manager").PluginManager

# These are 'nice to have' imports
Expand Down Expand Up @@ -241,6 +244,7 @@ def help_txt():
- feature_sets: List all the FeatureSets in AWS
- models: List all the Models in AWS
- endpoints: List all the Endpoints in AWS
- pipelines: List all the SageWorks Pipelines
- config: Show the current SageWorks Config
- status: Show the current SageWorks Status
- log_(debug/info/important/warning): Set the SageWorks log level
Expand All @@ -266,10 +270,15 @@ def summary(self):
# Pad the name to 15 characters
name = (name + " " * 15)[:15]

# Sanity check the dataframe
if df.empty:
examples = ""

# Get the first three items in the first column
examples = ", ".join(df.iloc[:, 0].tolist())
if len(examples) > 70:
examples = examples[:70] + "..."
else:
examples = ", ".join(df.iloc[:, 0].tolist())
if len(examples) > 70:
examples = examples[:70] + "..."

# Print the summary
cprint(["lightpurple", "\t" + name, "lightgreen", str(df.shape[0]) + " ", "purple_blue", examples])
Expand All @@ -292,6 +301,9 @@ def models(self):
def endpoints(self):
    """Return the endpoints summary produced by the artifacts text view"""
    endpoints_view = self.artifacts_text_view.endpoints_summary()
    return endpoints_view

def pipelines(self):
    """Return the pipelines summary produced by the artifacts text view"""
    pipelines_view = self.artifacts_text_view.pipelines_summary()
    return pipelines_view

@staticmethod
def log_debug():
logging.getLogger("sageworks").setLevel(logging.DEBUG)
Expand Down
5 changes: 5 additions & 0 deletions src/sageworks/views/artifacts_text_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def view_data(self) -> Dict[str, pd.DataFrame]:
"FEATURE_SETS": self.feature_sets_summary(),
"MODELS": self.models_summary(),
"ENDPOINTS": self.endpoints_summary(),
"PIPELINES": self.pipelines_summary(),
}
return summary_data

Expand Down Expand Up @@ -79,6 +80,10 @@ def endpoints_summary(self) -> pd.DataFrame:
"""Get summary data about the SageWorks Endpoints"""
return self.meta.endpoints()

def pipelines_summary(self) -> pd.DataFrame:
    """Fetch the SageWorks Pipelines summary from the Meta API object

    Returns:
        pd.DataFrame: Whatever ``self.meta.pipelines()`` returns, unmodified
    """
    pipelines_df = self.meta.pipelines()
    return pipelines_df


if __name__ == "__main__":
import time
Expand Down

0 comments on commit 5b8f55a

Please sign in to comment.