[batch] Add metadata server to batch jobs
daniel-goldstein committed Nov 21, 2023
1 parent df46a96 commit 21d24be
Showing 21 changed files with 318 additions and 84 deletions.
10 changes: 2 additions & 8 deletions batch/Dockerfile.worker
@@ -24,20 +24,14 @@ RUN echo "APT::Acquire::Retries \"5\";" > /etc/apt/apt.conf.d/80-retries && \
RUN curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg && \
curl -s -L https://nvidia.github.io/libnvidia-container/ubuntu22.04/libnvidia-container.list | \
sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \
tee /etc/apt/sources.list.d/nvidia-container-toolkit.list

RUN apt-get update && \
tee /etc/apt/sources.list.d/nvidia-container-toolkit.list && \
hail-apt-get-install nvidia-container-toolkit

{% elif global.cloud == "azure" %}
RUN apt-get update && \
hail-apt-get-install libreadline8

# https://github.com/Azure/azure-storage-fuse/issues/603
RUN hail-apt-get-install ca-certificates pkg-config libfuse-dev cmake libcurl4-gnutls-dev libgnutls28-dev uuid-dev libgcrypt20-dev && \
RUN hail-apt-get-install libreadline8 ca-certificates pkg-config libfuse-dev cmake libcurl4-gnutls-dev libgnutls28-dev uuid-dev libgcrypt20-dev && \
curl -LO https://packages.microsoft.com/config/ubuntu/22.04/packages-microsoft-prod.deb && \
dpkg -i packages-microsoft-prod.deb && \
apt-get update && \
hail-apt-get-install blobfuse2

{% else %}
3 changes: 3 additions & 0 deletions batch/batch/cloud/azure/worker/worker_api.py
@@ -63,6 +63,9 @@ async def user_container_registry_credentials(
'password': user_credentials.password,
}

def metadata_server(self):
raise NotImplementedError

def instance_config_from_config_dict(self, config_dict: Dict[str, str]) -> AzureSlimInstanceConfig:
return AzureSlimInstanceConfig.from_dict(config_dict)

6 changes: 3 additions & 3 deletions batch/batch/cloud/gcp/driver/create_instance.py
@@ -261,9 +261,9 @@ def scheduling() -> dict:
iptables --table nat --append POSTROUTING --source 172.20.0.0/15 --jump MASQUERADE
# [public]
# Block public traffic to the metadata server
iptables --append FORWARD --source 172.21.0.0/16 --destination 169.254.169.254 --jump DROP
# But allow the internal gateway
# Send public jobs' metadata server requests to the batch worker itself
iptables --table nat --append PREROUTING --source 172.21.0.0/16 --destination 169.254.169.254 -p tcp -j REDIRECT --to-ports 5555
# Allow the internal gateway
iptables --append FORWARD --destination $INTERNAL_GATEWAY_IP --jump ACCEPT
# And this worker
iptables --append FORWARD --destination $IP_ADDRESS --jump ACCEPT
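
A minimal sketch, not part of this commit, of what the new PREROUTING rule buys: a job on the public network issues an ordinary GCE-style metadata request to 169.254.169.254, and the REDIRECT rule delivers it to the worker's own metadata server on port 5555 instead of dropping it.

import asyncio
import aiohttp

async def main():
    # Ordinary GCE-style lookup as seen from inside a job container; the
    # PREROUTING REDIRECT above transparently routes this to the worker's
    # metadata server listening on port 5555.
    async with aiohttp.ClientSession() as session:
        async with session.get(
            'http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default/email',
            headers={'Metadata-Flavor': 'Google'},
        ) as resp:
            print(await resp.text())

asyncio.run(main())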
6 changes: 6 additions & 0 deletions batch/batch/cloud/gcp/worker/credentials.py
@@ -1,4 +1,5 @@
import base64
import json
from typing import Dict

from hailtop.auth.auth import IdentityProvider
@@ -10,6 +11,7 @@ class GCPUserCredentials(CloudUserCredentials):
def __init__(self, data: Dict[str, str]):
self._data = data
self._key = base64.b64decode(self._data['key.json']).decode()
self._key_json = json.loads(self._key)

@property
def cloud_env_name(self) -> str:
@@ -26,3 +28,7 @@ def key(self):
@property
def identity_provider_json(self):
return {'idp': IdentityProvider.GOOGLE.value}

@property
def email(self):
return self._key_json['client_email']
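
For illustration only (the email below is hypothetical): key.json is a base64-encoded GCP service account key file, and the new email property surfaces its client_email field, which the metadata server later uses to key per-job credentials. A sketch of the same decoding:

import base64
import json

# Hypothetical service account key payload, shaped like a real key.json.
key_json = {'type': 'service_account', 'client_email': 'job-sa@my-project.iam.gserviceaccount.com'}
data = {'key.json': base64.b64encode(json.dumps(key_json).encode()).decode()}

# The decoding GCPUserCredentials now performs in __init__:
key = base64.b64decode(data['key.json']).decode()
print(json.loads(key)['client_email'])  # job-sa@my-project.iam.gserviceaccount.com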
161 changes: 147 additions & 14 deletions batch/batch/cloud/gcp/worker/worker_api.py
@@ -1,14 +1,16 @@
import os
import tempfile
from typing import Dict, List
from typing import Dict, List, Optional, Set

import aiohttp
import orjson
from aiohttp import web

from hailtop import httpx
from hailtop.aiocloud import aiogoogle
from hailtop.utils import check_exec_output, retry_transient_errors
from hailtop.utils import check_exec_output

from ....worker.worker_api import CloudWorkerAPI, ContainerRegistryCredentials
from ....globals import HTTP_CLIENT_MAX_SIZE
from ....worker.worker_api import CloudWorkerAPI, ContainerRegistryCredentials, HailMetadataServer
from ..instance_config import GCPSlimInstanceConfig
from .credentials import GCPUserCredentials
from .disk import GCPDisk
@@ -22,15 +24,26 @@ class GCPWorkerAPI(CloudWorkerAPI[GCPUserCredentials]):
async def from_env() -> 'GCPWorkerAPI':
project = os.environ['PROJECT']
zone = os.environ['ZONE'].rsplit('/', 1)[1]
session = aiogoogle.GoogleSession()
return GCPWorkerAPI(project, zone, session)
worker_credentials = aiogoogle.GoogleInstanceMetadataCredentials()
http_session = httpx.ClientSession()
google_session = aiogoogle.GoogleSession(credentials=worker_credentials, http_session=http_session)
return GCPWorkerAPI(project, zone, worker_credentials, http_session, google_session)

def __init__(self, project: str, zone: str, session: aiogoogle.GoogleSession):
def __init__(
self,
project: str,
zone: str,
worker_credentials: aiogoogle.GoogleInstanceMetadataCredentials,
http_session: httpx.ClientSession,
session: aiogoogle.GoogleSession,
):
self.project = project
self.zone = zone
self._http_session = http_session
self._google_session = session
self._compute_client = aiogoogle.GoogleComputeClient(project, session=session)
self._gcsfuse_credential_files: Dict[str, str] = {}
self._worker_credentials = worker_credentials

@property
def cloud_specific_env_vars_for_user_jobs(self) -> List[str]:
@@ -54,20 +67,17 @@ def user_credentials(self, credentials: Dict[str, str]) -> GCPUserCredentials:
return GCPUserCredentials(credentials)

async def worker_container_registry_credentials(self, session: httpx.ClientSession) -> ContainerRegistryCredentials:
token_dict = await retry_transient_errors(
session.post_read_json,
'http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default/token',
headers={'Metadata-Flavor': 'Google'},
timeout=aiohttp.ClientTimeout(total=60), # type: ignore
)
access_token = token_dict['access_token']
access_token = await self._worker_credentials.access_token()
return {'username': 'oauth2accesstoken', 'password': access_token}

async def user_container_registry_credentials(
self, user_credentials: GCPUserCredentials
) -> ContainerRegistryCredentials:
return {'username': '_json_key', 'password': user_credentials.key}

def metadata_server(self) -> 'GoogleHailMetadataServer':
return GoogleHailMetadataServer(self.project, self._http_session)

def instance_config_from_config_dict(self, config_dict: Dict[str, str]) -> GCPSlimInstanceConfig:
return GCPSlimInstanceConfig.from_dict(config_dict)

@@ -128,3 +138,126 @@ async def close(self):

def __str__(self):
return f'project={self.project} zone={self.zone}'


class GoogleHailMetadataServer(HailMetadataServer[GCPUserCredentials]):
def __init__(self, project: str, http_session: httpx.ClientSession):
self._project = project
self._metadata_server_client = aiogoogle.GoogleMetadataServerClient(http_session)
self._ip_user_credentials: Dict[str, Dict[str, aiogoogle.GoogleServiceAccountCredentials]] = {}

@staticmethod
def _load_sa_credentials(creds: GCPUserCredentials) -> aiogoogle.GoogleServiceAccountCredentials:
return aiogoogle.GoogleServiceAccountCredentials(
orjson.loads(creds.key),
scopes=None,
)

def set_container_credentials(
self,
ip: str,
default_credentials: GCPUserCredentials,
additional_credentials: Optional[Set[GCPUserCredentials]] = None,
):
default_sa_credentials = self._load_sa_credentials(default_credentials)
credentials = {}
for name in ('default', default_credentials.email):
credentials[name] = default_sa_credentials
if additional_credentials:
for creds in additional_credentials:
credentials[creds.email] = self._load_sa_credentials(creds)
self._ip_user_credentials[ip] = credentials

async def clear_container_credentials(self, ip: str):
creds = self._ip_user_credentials.pop(ip)
for credential in creds.values():
await credential.close()

def _container_credentials(self, request: web.Request) -> Dict[str, aiogoogle.GoogleServiceAccountCredentials]:
assert request.remote
return self._ip_user_credentials[request.remote]

def _requested_credentials(self, request: web.Request) -> aiogoogle.GoogleServiceAccountCredentials:
assert self._requested_credentials_are_permitted(request)
email = request.match_info.get('gsa') or 'default'
return self._container_credentials(request)[email]

async def root(self, _):
return web.Response(text='computeMetadata/\n')

async def project_id(self, _):
return web.Response(text=self._project)

async def numeric_project_id(self, _):
return web.Response(text=await self._metadata_server_client.numeric_project_id())

async def service_accounts(self, request: web.Request):
accounts = '\n'.join(self._container_credentials(request))
return web.Response(text=f'{accounts}\n')

async def user_service_account(self, request: web.Request):
gsa_email = self._requested_credentials(request).email
recursive = request.query.get('recursive')
if recursive == 'true':
return web.json_response(
{
'aliases': ['default'],
'email': gsa_email,
'scopes': ['https://www.googleapis.com/auth/cloud-platform'],
},
)
return web.Response(text='aliases\nemail\nidentity\nscopes\ntoken\n')

async def user_email(self, request: web.Request):
return web.Response(text=self._requested_credentials(request).email)

# TODO Use scopes from the request
async def user_token(self, request: web.Request):
gsa_email = request.match_info['gsa']
creds = self._container_credentials(request)[gsa_email]
access_token = await creds.access_token_obj()
return web.json_response(
{
'access_token': access_token.token,
'expires_in': access_token.expires_in,
'token_type': 'Bearer',
}
)

async def _requested_credentials_are_permitted(self, request: web.Request) -> bool:
credentials = self._container_credentials(request)
return 'gsa' not in request.match_info or request.match_info['gsa'] in credentials.keys()

@web.middleware
async def configure_response(self, request: web.Request, handler):
response = await handler(request)
response.enable_compression()

# `gcloud` does not properly respect `charset`, which aiohttp automatically
# sets so we have to explicitly erase it
# See https://github.com/googleapis/google-auth-library-python/blob/b935298aaf4ea5867b5778bcbfc42408ba4ec02c/google/auth/compute_engine/_metadata.py#L170
if 'application/json' in response.headers['Content-Type']:
response.headers['Content-Type'] = 'application/json'
response.headers['Metadata-Flavor'] = 'Google'
response.headers['Server'] = 'Metadata Server for VM'
response.headers['X-XSS-Protection'] = '0'
response.headers['X-Frame-Options'] = 'SAMEORIGIN'
return response

def create_app(self) -> web.Application:
metadata_app = web.Application(
client_max_size=HTTP_CLIENT_MAX_SIZE,
middlewares=[self.reject_unassigned_identities, self.configure_response],
)
metadata_app.add_routes(
[
web.get('/', self.root),
web.get('/computeMetadata/v1/project/project-id', self.project_id),
web.get('/computeMetadata/v1/project/numeric-project-id', self.numeric_project_id),
web.get('/computeMetadata/v1/instance/service-accounts/', self.service_accounts),
web.get('/computeMetadata/v1/instance/service-accounts/{gsa}/', self.user_service_account),
web.get('/computeMetadata/v1/instance/service-accounts/{gsa}/email', self.user_email),
web.get('/computeMetadata/v1/instance/service-accounts/{gsa}/token', self.user_token),
]
)
return metadata_app
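
Because the routes above mimic the standard GCE metadata endpoints (project-id, service-accounts, email, token), stock clients that speak that protocol should work against the fake server unchanged. A hedged sketch using google-auth, whose compute_engine credentials fetch tokens from the metadata server; the GCE_METADATA_HOST override and the requests transport are assumptions about the client environment, not something this commit configures:

import os

# Optional override; google-auth defaults to metadata.google.internal, which the
# worker also points at 169.254.169.254 via /etc/hosts (see the worker.py hunk below).
os.environ.setdefault('GCE_METADATA_HOST', '169.254.169.254')

import google.auth.transport.requests  # requires google-auth and requests
from google.auth import compute_engine

# From inside a job, this should resolve to the worker's metadata server and
# yield a token for the job's own service account.
credentials = compute_engine.Credentials()
credentials.refresh(google.auth.transport.requests.Request())
print(credentials.service_account_email)
print(credentials.token is not None)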
34 changes: 33 additions & 1 deletion batch/batch/worker/worker.py
@@ -85,7 +85,7 @@
from ..publicly_available_images import publicly_available_images
from ..resource_usage import ResourceUsageMonitor
from ..semaphore import FIFOWeightedSemaphore
from ..worker.worker_api import CloudDisk, CloudWorkerAPI, ContainerRegistryCredentials
from ..worker.worker_api import CloudDisk, CloudWorkerAPI, ContainerRegistryCredentials, HailMetadataServer
from .credentials import CloudUserCredentials
from .jvm_entryway_protocol import EndOfStream, read_bool, read_int, read_str, write_int, write_str

@@ -209,6 +209,7 @@ def compose(auth: Union[MutableMapping, str, bytes], registry_addr: Optional[str

port_allocator: Optional['PortAllocator'] = None
network_allocator: Optional['NetworkAllocator'] = None
metadata_server: Optional[HailMetadataServer] = None

worker: Optional['Worker'] = None

@@ -265,6 +266,8 @@ async def init(self):
for service in HAIL_SERVICES:
hosts.write(f'{INTERNAL_GATEWAY_IP} {service}.hail\n')
hosts.write(f'{INTERNAL_GATEWAY_IP} internal.hail\n')
if CLOUD == 'gcp':
hosts.write('169.254.169.254 metadata metadata.google.internal')

# Jobs on the private network should have access to the metadata server
# and our vdc. The public network should not so we use google's public
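
A small sketch of the effect of the hosts entry above, assuming (as the surrounding comment suggests) that this hosts file is the one visible to job containers: metadata.google.internal resolves to the link-local metadata address, just as on a real GCE VM, so clients that use the hostname rather than the IP also reach the metadata server.

import socket

# Expected to print 169.254.169.254 for both names once the hosts entry is in place.
print(socket.gethostbyname('metadata.google.internal'))
print(socket.gethostbyname('metadata'))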
@@ -772,6 +775,7 @@ def __init__(
command: List[str],
cpu_in_mcpu: int,
memory_in_bytes: int,
user_credentials: Optional[CloudUserCredentials],
network: Optional[Union[bool, str]] = None,
port: Optional[int] = None,
timeout: Optional[int] = None,
@@ -789,6 +793,7 @@ def __init__(
self.command = command
self.cpu_in_mcpu = cpu_in_mcpu
self.memory_in_bytes = memory_in_bytes
self.user_credentials = user_credentials
self.network = network
self.port = port
self.timeout = timeout
@@ -989,6 +994,9 @@ async def _cleanup(self):

if self.netns:
assert network_allocator
assert metadata_server
if self.user_credentials:
await metadata_server.clear_container_credentials(self.netns.job_ip)
network_allocator.free(self.netns)
log.info(f'Freed the network namespace for {self}')
self.netns = None
@@ -1037,6 +1045,9 @@ async def _setup_network_namespace(self):
else:
assert self.network is None or self.network == 'public'
self.netns = await network_allocator.allocate_public()
if self.user_credentials:
assert metadata_server
metadata_server.set_container_credentials(self.netns.job_ip, self.user_credentials)
except asyncio.TimeoutError:
log.exception(network_allocator.task_manager.tasks)
raise
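
Pairing set_container_credentials with the job's netns IP is what lets the metadata server tell jobs apart: iptables REDIRECT rewrites only the destination of the packet, so on the server side request.remote is still the job's source IP. A minimal standalone sketch of that dispatch pattern, with hypothetical names and addresses:

from aiohttp import web

# Hypothetical registry mapping a job's network-namespace IP to its identity.
credentials_by_ip = {'172.21.0.10': 'job-sa@my-project.iam.gserviceaccount.com'}

async def default_email(request: web.Request) -> web.Response:
    # request.remote is the job container's IP because REDIRECT preserves the
    # packet's source address.
    email = credentials_by_ip.get(request.remote or '', '')
    if not email:
        raise web.HTTPForbidden()
    return web.Response(text=email)

app = web.Application()
app.add_routes([web.get('/computeMetadata/v1/instance/service-accounts/default/email', default_email)])

if __name__ == '__main__':
    web.run_app(app, host='0.0.0.0', port=5555)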
@@ -1468,6 +1479,7 @@ def copy_container(
cpu_in_mcpu=cpu_in_mcpu,
memory_in_bytes=memory_in_bytes,
volume_mounts=volume_mounts,
user_credentials=job.credentials,
env=[f'{job.credentials.cloud_env_name}={job.credentials.mount_path}'],
stdin=json.dumps(files),
)
@@ -1796,6 +1808,7 @@ def __init__(
command=job_spec['process']['command'],
cpu_in_mcpu=self.cpu_in_mcpu,
memory_in_bytes=self.memory_in_bytes,
user_credentials=self.credentials,
network=job_spec.get('network'),
port=job_spec.get('port'),
timeout=job_spec.get('timeout'),
@@ -2236,6 +2249,10 @@ async def run(self):

self.state = 'initializing'

assert self.jvm.container.container.netns
assert metadata_server
metadata_server.set_container_credentials(self.jvm.container.container.netns.job_ip, self.credentials)

await check_shell_output(f'xfs_quota -x -c "project -s -p {self.scratch} {self.project_id}" /host/')
await check_shell_output(
f'xfs_quota -x -c "limit -p bsoft={self.data_disk_storage_in_gib} bhard={self.data_disk_storage_in_gib} {self.project_id}" /host/'
@@ -2379,6 +2396,10 @@ async def cleanup(self):
f'while unmounting fuse blob storage {bucket} from {mount_path} for {self.jvm_name} for job {self.id}'
) from e

assert self.jvm.container.container.netns
assert metadata_server
await metadata_server.clear_container_credentials(self.jvm.container.container.netns.job_ip)

if self.jvm is not None:
self.worker.return_jvm(self.jvm)
self.jvm = None
@@ -2562,6 +2583,7 @@ async def create_and_start(
command=command,
cpu_in_mcpu=n_cores * 1000,
memory_in_bytes=total_memory_bytes,
user_credentials=None,
env=[f'HAIL_WORKER_OFF_HEAP_MEMORY_PER_CORE_MB={off_heap_memory_per_core_mib}', f'HAIL_CLOUD={CLOUD}'],
volume_mounts=volume_mounts,
log_path=f'/batch/jvm-container-logs/jvm-{index}.log',
@@ -3206,6 +3228,16 @@ async def healthcheck(self, request): # pylint: disable=unused-argument
return json_response(body)

async def run(self):
global metadata_server
assert CLOUD_WORKER_API
assert network_allocator
metadata_server = CLOUD_WORKER_API.metadata_server()
metadata_app_runner = web.AppRunner(metadata_server.create_app(), access_log_class=BatchWorkerAccessLogger)
await metadata_app_runner.setup()
# TODO Listen on link local IP
metadata_site = web.TCPSite(metadata_app_runner, '0.0.0.0', 5555)
await metadata_site.start()

app = web.Application(client_max_size=HTTP_CLIENT_MAX_SIZE)
app.add_routes(
[