Skip to content

Commit

Permalink
DPAT-1970 sagemaker integration poc
Browse files Browse the repository at this point in the history
  • Loading branch information
ymao2 committed Oct 23, 2023
1 parent 4c28ece commit 728c920
Show file tree
Hide file tree
Showing 11 changed files with 215 additions and 21 deletions.
80 changes: 80 additions & 0 deletions controlpanel/api/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -1103,3 +1103,83 @@ def delete_messages(self, queue, messages):
log.exception("Couldn't delete messages from queue %s", queue)
else:
return response


class AWSSageMaker(AWSService):

def __init__(self, assume_role_name=None, profile_name=None):
super(AWSSageMaker, self).__init__(
assume_role_name=assume_role_name,
profile_name=profile_name,
region_name=settings.SQS_REGION
)
self.client = self.boto3_session.client("sagemaker")
self.domain_id = settings.SAGEMAKER_DOMAIN_ID
self.session_duration = settings.SAGEMAKER_SESSION_DURATION
self.url_expires_seconds = settings.SAGEMAKER_URL_EXPIRE_SECONDS

def _get_error_code(self, exception):
if hasattr(exception, 'response'):
return exception.response.get('Error', {}).get('Code')
else:
return None

def describe_user_profile(self, user_profile_name):
try:
response = self.client.describe_user_profile(
DomainId=self.domain_id,
UserProfileName=user_profile_name
)
except botocore.exceptions.ClientError as ex:
log.exception("Couldn't retrieve the profile information for %s due to %s",
user_profile_name, ex.__str__())
if self._get_error_code(ex) == "ResourceNotFound":
return None
raise ex
else:
return response

def create_user_profile(self, user):
user_profile = self.describe_user_profile(user.user_profile_name)
if user_profile is None:
user_settings = {
'ExecutionRole': f"{iam_arn('role')}/{user.iam_role_name}"
}
try:
user_profile = self.client.create_user_profile(
DomainId=self.domain_id,
UserProfileName=user.user_profile_name,
UserSettings=user_settings
)
except botocore.exceptions.ClientError as ex:
log.exception("Couldn't create the user_profile fro %s due to %s",
user.user_profile_name, ex.__str__)
return user_profile

def create_presigned_domain_url(self, user):
try:
response = self.client.create_presigned_domain_url(
DomainId=self.domain_id,
UserProfileName=user.user_profile_name,
SessionExpirationDurationInSeconds=self.session_duration,
ExpiresInSeconds=self.url_expires_seconds
)
except botocore.exceptions.ClientError as ex:
log.exception("Couldn't create the user_profile fro %s due to %s",
user.user_profile_name, ex.__str__(),
ex.__str__())
else:
return response.get('AuthorizedUrl')

def delete_user_profile(self, user):
try:
response = self.client.delete_user_profile(
DomainId=self.domain_id,
UserProfileName=user.user_profile_name
)
except botocore.exceptions.ClientError as ex:
log.exception("Couldn't delete the user_profile fro %s due to %s",
user.user_profile_name,
ex.__str__())
else:
return response
13 changes: 13 additions & 0 deletions controlpanel/api/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
AWSParameterStore,
AWSPolicy,
AWSRole,
AWSSageMaker,
iam_arn,
s3_arn,
)
Expand Down Expand Up @@ -212,6 +213,7 @@ def __init__(self, user):

def _init_aws_services(self):
self.aws_role_service = self.create_aws_service(AWSRole)
self.aws_sagemaker_service = self.create_aws_service(AWSSageMaker)

@property
def user_helm_charts(self):
Expand Down Expand Up @@ -373,6 +375,17 @@ def has_required_installation_charts(self):
return False
return True

def prepare_sagemaker(self):
self.aws_sagemaker_service.create_user_profile(self.user)
presigned_domain_url = self.aws_sagemaker_service.create_presigned_domain_url(self.user)
return {
"UserProfileName": self.user.user_profile_name,
"PresignedDomainURL": presigned_domain_url,
}

def delete_user_sagemaker_profile(self):
return self.aws_sagemaker_service.delete_user_profile(self.user)


class App(EntityResource):
"""
Expand Down
4 changes: 4 additions & 0 deletions controlpanel/api/models/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ def get_id_token(self):
def iam_role_name(self):
return cluster.User(self).iam_role_name

@property
def user_profile_name(self):
return f"{self.slug}"

@property
def k8s_namespace(self):
return cluster.User(self).k8s_namespace
Expand Down
2 changes: 1 addition & 1 deletion controlpanel/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ class Meta:

class ToolDeploymentSerializer(serializers.Serializer):
old_chart_name = serializers.CharField(max_length=64, required=False)
version = serializers.CharField(max_length=64, required=True)
version = serializers.CharField(max_length=64, required=False)

def validate_version(self, value):
try:
Expand Down
27 changes: 19 additions & 8 deletions controlpanel/api/views/tool_deployments.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,17 @@ class ToolDeploymentAPIView(GenericAPIView):
serializer_class = serializers.ToolDeploymentSerializer
permission_classes = (IsAuthenticated,)

def _deploy(self, chart_name, data):
def _sagemaker_deploy(self, tool_name, data):
start_background_task(
"tool.deploy",
{
"tool_name": tool_name,
"user_id": self.request.user.id,
"id_token": self.request.user.get_id_token()
},
)

def _deploy_tool(self, tool_name, data):
"""
This is the most backwards thing you'll see for a while. The helm
task to deploy the tool apparently must happen when the view class
Expand All @@ -34,30 +44,31 @@ def _deploy(self, chart_name, data):
# attribute and then split on "__" to extract them. Why? Because we
# need both pieces of information to kick off the background helm
# deploy.
tool_name, tool_version, tool_id = chart_info.split("__")
_, tool_version, tool_id = chart_info.split("__")

# Kick off the helm chart as a background task.
start_background_task(
"tool.deploy",
{
"tool_name": chart_name,
"tool_name": tool_name,
"version": tool_version,
"tool_id": tool_id,
"user_id": self.request.user.id,
"id_token": self.request.user.get_id_token(),
"old_chart_name": old_chart_name,
},
)
return None

def post(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)

chart_name = self.kwargs["tool_name"]
tool_name = self.kwargs["tool_name"]
tool_action = self.kwargs["action"]
tool_action_function = getattr(self, f"_{tool_action}", None)
tool_action_function = getattr(self, f"_{tool_name}_{tool_action}", None)
if tool_action_function and callable(tool_action_function):
tool_action_function(chart_name, request.data)
return Response(status=status.HTTP_200_OK)
tool_action_function(tool_name, request.data)
else:
return Response(status=status.HTTP_400_BAD_REQUEST)
self._deploy_tool(tool_name, request.data)
return Response(status=status.HTTP_200_OK)
35 changes: 33 additions & 2 deletions controlpanel/frontend/consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@

# First-party/Local
from controlpanel.api import cluster
from controlpanel.api.cluster import ( # TOOL_IDLED,; TOOL_READY,
from controlpanel.api.cluster import (
HOME_RESET_FAILED,
HOME_RESETTING,
TOOL_DEPLOY_FAILED,
TOOL_DEPLOYING,
TOOL_RESTARTING,
TOOL_READY
)
from controlpanel.api.models import (
App,
Expand Down Expand Up @@ -151,7 +152,14 @@ def app_ip_ranges_delete(self, message):
if ip_range.apps.count() == 0:
ip_range.delete()

def tool_deploy(self, message):
def _deploy_sagemaker(self, message):
user = User.objects.get(auth0_id=message["user_id"])
update_sagemaker_status(user, TOOL_DEPLOYING)
result = cluster.User(user).prepare_sagemaker()
# result = {"PresignedDomainURL": "http://testing"}
update_sagemaker_status(user, TOOL_READY, data=result)

def _general_tool_deploy(self, message):
"""
Deploy the named tool for the specified user
Expects a message with `tool_name`, `version` and `user_id` values
Expand All @@ -178,6 +186,13 @@ def tool_deploy(self, message):
else:
log.debug(f"Deployed {tool.name} for {user}")

def tool_deploy(self, message):
tool_action_function = getattr(self, f"_deploy_{message.get('tool_name')}", None)
if tool_action_function and callable(tool_action_function):
tool_action_function(message)
else:
self._general_tool_deploy(message)

def _send_to_sentry(self, error):
if os.environ.get("SENTRY_DSN"):
# Third-party
Expand Down Expand Up @@ -253,6 +268,22 @@ def update_tool_status(tool_deployment, id_token, status):
)


def update_sagemaker_status(user, status, data=None):
payload = {
"toolName": "sagemaker",
"status": status,
}
if data:
payload.update(data)
send_sse(
user.auth0_id,
{
"event": "toolStatus",
"data": json.dumps(payload),
},
)


def update_home_status(home_directory, status):
"""
Update the user with the status of their home directory reset task.
Expand Down
2 changes: 1 addition & 1 deletion controlpanel/frontend/jinja2/base.html
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@
<script src="{{ static('jquery-ui/jquery-ui.min.js') }}"></script>
<link href="{{ static('jquery-ui/themes/base/jquery-ui.min.css') }}" rel="stylesheet" />

<script src="{{ static('app.js') }}?version=v0.29.35"></script>
<script src="{{ static('app.js') }}?version=v0.29.35.8"></script>
<script>window.moj.init();</script>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id={{ google_analytics_id }}"></script>
Expand Down
48 changes: 43 additions & 5 deletions controlpanel/frontend/jinja2/tool-list.html
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,49 @@ <h2 class="govuk-heading-m">Airflow</h2>
You can <a href="{{ aws_service_url }}" target="_blank" rel="noopener"> access AWS services such as S3 and Athena via the AWS Console (opens in new tab).</a>
</p>

{% if ip_range_feature_enabled %}
<p class="govuk-body">
ip ranges has been enabled
</p>
{% endif %}
{% if settings.features.sagemaker %}
<h2 class="govuk-heading-m">Sagemaker</h2>
<div class="govuk-grid-row tool sse-listener" data-tool-name="sagemaker">
<div class="govuk-grid-column-two-thirds">
<p class="govuk-body">
Amazon SageMaker is a cloud based machine-learning platform that enables developers to create, train, and deploy machine-learning models on the cloud. It also enables developers to deploy ML models on embedded systems and edge-devices.
</p>
</div>
<div class="govuk-grid-column-one-third">
<p class="govuk-!-margin-bottom-1">Status:
<span class="govuk-!-font-weight-bold tool-status-label">
{% if deployment %}
{{ deployment.status | default("") }}
{% else %}
Not deployed
{% endif %}
</span>
</p>
<form style="display: inline;" id="form-sagemaker">
{{ csrf_input }}
<button class="govuk-button govuk-button--secondary govuk-!-margin-right-1 govuk-!-margin-top-0 js-confirm tool-action"
data-action-name="deploy"
data-form-target="form-sagemaker"
data-form-url="{{ url('tool-deployments', kwargs={'tool_name': 'sagemaker', 'action': 'deploy'}) }}"
id="deploy-sagemaker"
data-confirm-message="Do you wish to access the Sagemaker?">
Deploy
</button>
</form>

<button class="govuk-button govuk-button--secondary govuk-!-margin-right-1 govuk-!-margin-top-0 tool-action"
data-action-name="open"
onclick="window.open('', '_blank');"
rel="noopener"
target="_blank"
disabled>
Open
</button>
</form>

</div>
</div>

{% endif %}

{% endblock %}
15 changes: 13 additions & 2 deletions controlpanel/frontend/static/javascripts/modules/tool-status.js
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,13 @@ moj.Modules.toolStatus = {
// maybe have a Cancel button? Report issue?
break;
case 'READY':
toolstatus.showActions(listener, ['open', 'restart']);
toolstatus.updateAppVersion(listener, data);
if (data.toolName == 'sagemaker'){
toolstatus.showActions(listener, ['open']);
toolstatus.updatePreSignedURL(listener, data);
} else {
toolstatus.showActions(listener, ['open', 'restart']);
toolstatus.updateAppVersion(listener, data);
}
toolstatus.updateMessage("The tool has been deployed")
break;
case 'IDLED':
Expand Down Expand Up @@ -157,4 +162,10 @@ moj.Modules.toolStatus = {
// the "Deploy" button needs to be disabled
deployButton.disabled = notInstalledSelected || installedSelected;
},

// update sagemaker's the pre_signed_url to the open link
updatePreSignedURL(listener, data) {
const button = listener.querySelector(`${this.buttonClass}[data-action-name='open']`);
button.setAttribute("onclick", "window.open('" + data.PresignedDomainURL + "', '_blank');");
},
};
4 changes: 4 additions & 0 deletions controlpanel/settings/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,3 +565,7 @@
CELERY_IMPORTS = [
"controlpanel.api.tasks.handlers"
]

SAGEMAKER_DOMAIN_ID = os.environ.get("SAGEMAKER_DOMAIN_ID", "testing")
SAGEMAKER_SESSION_DURATION = os.environ.get("SAGEMAKER_SESSION_DURATION", 43200)
SAGEMAKER_URL_EXPIRE_SECONDS = os.environ.get("SAGEMAKER_URL_EXPIRE_SECONDS", 300)
6 changes: 4 additions & 2 deletions settings.yaml
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
enabled_features:
redirect_legacy_api_urls:
_DEFAULT: true
s3_folders:
_DEFAULT: false
_HOST_dev: false
_HOST_prod: false
_HOST_alpha: false
sagemaker:
_HOST_dev: true
_HOST_prod: false
_HOST_alpha: false


AWS_SERVICE_URL:
Expand Down

0 comments on commit 728c920

Please sign in to comment.