Skip to content

Commit

Permalink
[batch] Remove worker shim classes for cloud credentials (#14125)
Browse files Browse the repository at this point in the history
This change removes shim credential classes from the batch worker code
and attempts to consolidate credential handling in the worker API
classes.
  • Loading branch information
daniel-goldstein authored Jan 10, 2024
1 parent cb87096 commit 1a5f485
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 161 deletions.
44 changes: 0 additions & 44 deletions batch/batch/cloud/azure/worker/credentials.py

This file was deleted.

44 changes: 28 additions & 16 deletions batch/batch/cloud/azure/worker/worker_api.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
import abc
import base64
import os
import tempfile
from typing import Dict, List, Optional, Tuple

import aiohttp
import orjson

from hailtop import httpx
from hailtop.aiocloud import aioazure
from hailtop.auth.auth import IdentityProvider
from hailtop.utils import check_exec_output, retry_transient_errors, time_msecs

from ....worker.worker_api import CloudWorkerAPI, ContainerRegistryCredentials
from ..instance_config import AzureSlimInstanceConfig
from .credentials import AzureUserCredentials
from .disk import AzureDisk


class AzureWorkerAPI(CloudWorkerAPI[AzureUserCredentials]):
class AzureWorkerAPI(CloudWorkerAPI):
nameserver_ip = '168.63.129.16'

@staticmethod
Expand All @@ -37,48 +39,58 @@ def __init__(self, subscription_id: str, resource_group: str, acr_url: str, hail

@property
def cloud_specific_env_vars_for_user_jobs(self) -> List[str]:
return [f'HAIL_AZURE_OAUTH_SCOPE={self.hail_oauth_scope}']
idp_json = orjson.dumps({'idp': IdentityProvider.MICROSOFT.value}).decode('utf-8')
return [
f'HAIL_AZURE_OAUTH_SCOPE={self.hail_oauth_scope}',
'AZURE_APPLICATION_CREDENTIALS=/azure-credentials/key.json',
f'HAIL_IDENTITY_PROVIDER_JSON={idp_json}',
]

def create_disk(self, instance_name: str, disk_name: str, size_in_gb: int, mount_path: str) -> AzureDisk:
return AzureDisk(disk_name, instance_name, size_in_gb, mount_path)

def user_credentials(self, credentials: Dict[str, str]) -> AzureUserCredentials:
return AzureUserCredentials(credentials)

async def worker_container_registry_credentials(self, session: httpx.ClientSession) -> ContainerRegistryCredentials:
# https://docs.microsoft.com/en-us/azure/container-registry/container-registry-authentication?tabs=azure-cli#az-acr-login-with---expose-token
return {
'username': '00000000-0000-0000-0000-000000000000',
'password': await self.acr_refresh_token.token(session),
}

async def user_container_registry_credentials(
self, user_credentials: AzureUserCredentials
) -> ContainerRegistryCredentials:
return {
'username': user_credentials.username,
'password': user_credentials.password,
}
async def user_container_registry_credentials(self, credentials: Dict[str, str]) -> ContainerRegistryCredentials:
credentials = orjson.loads(base64.b64decode(credentials['key.json']).decode())
return {'username': credentials['appId'], 'password': credentials['password']}

def instance_config_from_config_dict(self, config_dict: Dict[str, str]) -> AzureSlimInstanceConfig:
return AzureSlimInstanceConfig.from_dict(config_dict)

def _blobfuse_credentials(self, credentials: Dict[str, str], account: str, container: str) -> str:
credentials = orjson.loads(base64.b64decode(credentials['key.json']).decode())
# https://github.com/Azure/azure-storage-fuse
return f'''
accountName {account}
authType SPN
servicePrincipalClientId {credentials["appId"]}
servicePrincipalClientSecret {credentials["password"]}
servicePrincipalTenantId {credentials["tenant"]}
containerName {container}
'''

def _write_blobfuse_credentials(
self,
credentials: AzureUserCredentials,
credentials: Dict[str, str],
account: str,
container: str,
mount_base_path_data: str,
) -> str:
if mount_base_path_data not in self._blobfuse_credential_files:
with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8', delete=False) as credsfile:
credsfile.write(credentials.blobfuse_credentials(account, container))
credsfile.write(self._blobfuse_credentials(credentials, account, container))
self._blobfuse_credential_files[mount_base_path_data] = credsfile.name
return self._blobfuse_credential_files[mount_base_path_data]

async def _mount_cloudfuse(
self,
credentials: AzureUserCredentials,
credentials: Dict[str, str],
mount_base_path_data: str,
mount_base_path_tmp: str,
config: dict,
Expand Down
28 changes: 0 additions & 28 deletions batch/batch/cloud/gcp/worker/credentials.py

This file was deleted.

30 changes: 17 additions & 13 deletions batch/batch/cloud/gcp/worker/worker_api.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
import base64
import os
import tempfile
from typing import Dict, List

import aiohttp
import orjson

from hailtop import httpx
from hailtop.aiocloud import aiogoogle
from hailtop.auth.auth import IdentityProvider
from hailtop.utils import check_exec_output, retry_transient_errors

from ....worker.worker_api import CloudWorkerAPI, ContainerRegistryCredentials
from ..instance_config import GCPSlimInstanceConfig
from .credentials import GCPUserCredentials
from .disk import GCPDisk


class GCPWorkerAPI(CloudWorkerAPI[GCPUserCredentials]):
class GCPWorkerAPI(CloudWorkerAPI):
nameserver_ip = '169.254.169.254'

# async because GoogleSession must be created inside a running event loop
Expand All @@ -33,7 +35,11 @@ def __init__(self, project: str, zone: str, compute_client: aiogoogle.GoogleComp

@property
def cloud_specific_env_vars_for_user_jobs(self) -> List[str]:
return []
idp_json = orjson.dumps({'idp': IdentityProvider.GOOGLE.value}).decode('utf-8')
return [
'GOOGLE_APPLICATION_CREDENTIALS=/gsa-key/key.json',
f'HAIL_IDENTITY_PROVIDER_JSON={idp_json}',
]

def create_disk(self, instance_name: str, disk_name: str, size_in_gb: int, mount_path: str) -> GCPDisk:
return GCPDisk(
Expand All @@ -46,9 +52,6 @@ def create_disk(self, instance_name: str, disk_name: str, size_in_gb: int, mount
compute_client=self._compute_client,
)

def user_credentials(self, credentials: Dict[str, str]) -> GCPUserCredentials:
return GCPUserCredentials(credentials)

async def worker_container_registry_credentials(self, session: httpx.ClientSession) -> ContainerRegistryCredentials:
token_dict = await retry_transient_errors(
session.post_read_json,
Expand All @@ -59,24 +62,25 @@ async def worker_container_registry_credentials(self, session: httpx.ClientSessi
access_token = token_dict['access_token']
return {'username': 'oauth2accesstoken', 'password': access_token}

async def user_container_registry_credentials(
self, user_credentials: GCPUserCredentials
) -> ContainerRegistryCredentials:
return {'username': '_json_key', 'password': user_credentials.key}
async def user_container_registry_credentials(self, credentials: Dict[str, str]) -> ContainerRegistryCredentials:
key = orjson.loads(base64.b64decode(credentials['key.json']).decode())
async with aiogoogle.GoogleServiceAccountCredentials(key) as sa_credentials:
access_token = await sa_credentials.access_token()
return {'username': 'oauth2accesstoken', 'password': access_token}

def instance_config_from_config_dict(self, config_dict: Dict[str, str]) -> GCPSlimInstanceConfig:
return GCPSlimInstanceConfig.from_dict(config_dict)

def _write_gcsfuse_credentials(self, credentials: GCPUserCredentials, mount_base_path_data: str) -> str:
def _write_gcsfuse_credentials(self, credentials: Dict[str, str], mount_base_path_data: str) -> str:
if mount_base_path_data not in self._gcsfuse_credential_files:
with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8', delete=False) as credsfile:
credsfile.write(credentials.key)
credsfile.write(base64.b64decode(credentials['key.json']).decode())
self._gcsfuse_credential_files[mount_base_path_data] = credsfile.name
return self._gcsfuse_credential_files[mount_base_path_data]

async def _mount_cloudfuse(
self,
credentials: GCPUserCredentials,
credentials: Dict[str, str],
mount_base_path_data: str,
mount_base_path_tmp: str,
config: dict,
Expand Down
18 changes: 0 additions & 18 deletions batch/batch/worker/credentials.py

This file was deleted.

Loading

0 comments on commit 1a5f485

Please sign in to comment.