From 81f4c2b60ec5d0c493cc0a50381cb39e740f1306 Mon Sep 17 00:00:00 2001 From: "Gabriele A. Ron" Date: Fri, 12 Jan 2024 11:02:54 -0600 Subject: [PATCH] Add multi-az functionality --- src/forge/common.py | 15 +++- src/forge/configure.py | 8 +- src/forge/create.py | 193 ++++++++++++++++++++++++++++++----------- src/forge/destroy.py | 1 + 4 files changed, 159 insertions(+), 58 deletions(-) diff --git a/src/forge/common.py b/src/forge/common.py index 02ea5c3..a463e53 100755 --- a/src/forge/common.py +++ b/src/forge/common.py @@ -117,7 +117,8 @@ def ec2_ip(n, config): 'instance_type': i.get('InstanceType'), 'state': i.get('State').get('Name'), 'launch_time': i.get('LaunchTime'), - 'fleet_id': check_fleet_id(n, config) + 'fleet_id': check_fleet_id(n, config), + 'az': i.get('Placement')['AvailabilityZone'] } details.append(x) logger.debug('ec2_ip details is %s', details) @@ -320,6 +321,14 @@ def normalize_config(config): if config.get('aws_az'): config['region'] = config['aws_az'][:-1] + if config.get('aws_subnet') and not config.get('aws_multi_az'): + config['aws_multi_az'] = {config.get('aws_az'): config.get('aws_subnet')} + elif config.get('aws_subnet') and config.get('aws_multi_az'): + logger.warning('Both aws_multi_az and aws_subnet exist, defaulting to aws_multi_az') + + if config.get('aws_region'): + config['region'] = config['aws_region'] + if not config.get('ram') and not config.get('cpu') and config.get('ratio'): DEFAULT_ARG_VALS['default_ratio'] = config.pop('ratio') @@ -492,8 +501,8 @@ def get_ec2_pricing(ec2_type, market, config): float Hourly price of given EC2 type in given market. """ - region = config.get('region') - az = config.get('aws_az') + region = config['region'] + az = config['aws_az'] if market == 'spot': client = boto3.client('ec2') diff --git a/src/forge/configure.py b/src/forge/configure.py index 76db581..ba1fb30 100755 --- a/src/forge/configure.py +++ b/src/forge/configure.py @@ -50,11 +50,13 @@ def check_env_yaml(env_yaml): """ schema = Schema({ 'forge_env': And(str, len, error='Invalid Environment Name'), - 'aws_az': And(str, len, error='Invalid AWS availability zone'), + Optional('aws_region'): And(str, len, error='Invalid AWS region'), + Optional('aws_az'): And(str, len, error='Invalid AWS availability zone'), + Optional('aws_subnet'): And(str, len, error='Invalid AWS Subnet'), 'ec2_amis': And(dict, len, error='Invalid AMI Dictionary'), - 'aws_subnet': And(str, len, error='Invalid AWS Subnet'), + Optional('aws_multi_az'): And(dict, len, error='Invalid AWS Subnet'), 'ec2_key': And(str, len, error='Invalid AWS key'), - 'aws_security_group': And(str, len, error='Invalid AWS Security Group'), + Optional('aws_security_group'): And(str, len, error='Invalid AWS Security Group'), 'forge_pem_secret': And(str, len, error='Invalid Name of Secret'), Optional('aws_profile'): And(str, len, error='Invalid AWS profile'), Optional('ratio'): And(list, len, error='Invalid default ratio'), diff --git a/src/forge/create.py b/src/forge/create.py index 1ec4ccc..ff53c9a 100755 --- a/src/forge/create.py +++ b/src/forge/create.py @@ -2,11 +2,13 @@ import base64 import logging import sys +import math import os import time from datetime import datetime, timedelta import boto3 +import botocore.exceptions from botocore.exceptions import ClientError from . import DEFAULT_ARG_VALS, REQUIRED_ARGS @@ -474,7 +476,70 @@ def calc_machine_ranges(*, ram=None, cpu=None, ratio=None, workers=None): return job_ram, job_cpu, total_ram, sorted(ram2cpu_ratio) -def create_fleet(n, config, task): +def get_placement_az(config, instance_details, mode=None): + if not mode: + mode = 'balanced' + + region = config.get('region') + subnet = config['aws_multi_az'] + + client = boto3.client('ec2') + az_info = client.describe_availability_zones() + az_mapping = {x['ZoneId']: x['ZoneName'] for x in az_info['AvailabilityZones']} + + try: + response = client.get_spot_placement_scores( + TargetCapacity=instance_details['total_capacity'], + TargetCapacityUnitType=instance_details['capacity_unit'], + SingleAvailabilityZone=True, + RegionNames=[region], + InstanceRequirementsWithMetadata={ + 'ArchitectureTypes': ['x86_64'], # ToDo: Make configurable + 'InstanceRequirements': instance_details['override_instance_stats'] + }, + MaxResults=10, + ) + + placement = {az_mapping[x['AvailabilityZoneId']]: x['Score'] for x in response['SpotPlacementScores']} + logger.debug(placement) + except botocore.exceptions.ClientError as e: + logger.error('Permissions to pull spot placement scores are necessary') + logger.error(e) + placement = {} + + subnet_details = {} + + try: + for placement_az, placement_subnet in subnet.items(): + response = client.describe_subnets( + SubnetIds=[placement_subnet] + ) + + logger.debug(response) + + subnet_details[placement_az] = response['Subnets'][0]['AvailableIpAddressCount'] + + except botocore.exceptions.ClientError as e: + logger.error(e) + + if mode in ['balanced', 'placement']: + subnet_details = {k: int(math.sqrt(v)) for k, v in subnet_details.items()} + + if mode == 'placement': + placement = {k: v**2 for k, v in placement.items()} + + for k, v in subnet_details.items(): + if placement.get(k): + placement[k] += v + else: + placement[k] = v + + az = max(placement, key=placement.get) + + return az + + +def create_fleet(n, config, task, instance_details): """creates the AWS EC2 fleet Parameters @@ -485,6 +550,8 @@ def create_fleet(n, config, task): Forge configuration data task : str Forge service to run + instance_details: dict + EC2 instance details for create_fleet """ profile = config.get('aws_profile') valid = config.get('valid_time', DEFAULT_ARG_VALS['valid_time']) @@ -494,45 +561,15 @@ def create_fleet(n, config, task): now_utc = datetime.utcnow() now_utc = now_utc.replace(microsecond=0) valid_until = now_utc + timedelta(hours=int(valid)) - subnet = config.get('aws_subnet') - - ram = config.get('ram', None) - cpu = config.get('cpu', None) - ratio = config.get('ratio', None) - worker_count = config.get('workers', None) - destroy_flag = config.get('destroy_after_failure') - - if 'single' in n and max(len(ram or []), len(cpu or []), len(ratio or [])) > 1: - raise ValueError("Too many values provided for single job.") - elif 'cluster' in n and min(len(ram or [0, 0]), len(cpu or [0, 0]), len(ratio or [0, 0])) < 2: - raise ValueError("Too few values provided for cluster job.") - - def _check(x, i): - logger.debug('Get index %d of %s', i, x) - return x[i] if x and x[i:i + 1] else None - - if 'cluster-master' in n or 'single' in n: - ram, cpu, total_ram, ram2cpu_ratio = calc_machine_ranges(ram=_check(ram, 0), cpu=_check(cpu, 0), - ratio=_check(ratio, 0)) - worker_count = 1 - elif 'cluster-worker' in n: - ram, cpu, total_ram, ram2cpu_ratio = calc_machine_ranges(ram=_check(ram, 1), cpu=_check(cpu, 1), - ratio=_check(ratio, 1), workers=worker_count) - else: - logger.error("'%s' does not seem to be a valid cluster or single job.", n) - if destroy_flag: - destroy(config) - sys.exit(1) - - logger.debug('OVERRIDE DETAILS | RAM: %s out of %s | CPU: %s with ratio of %s', ram, total_ram, cpu, ram2cpu_ratio) + subnet = config.get('aws_multi_az') gpu = config.get('gpu_flag', False) market = config.get('market', DEFAULT_ARG_VALS['market']) - subnet = config.get('aws_subnet') market = market[-1] if 'cluster-worker' in n else market[0] set_boto_session(region, profile) + az = config['aws_az'] fmt = FormatEmpty() access_vars = user_accessible_vars(config, market=market, task=task) @@ -555,8 +592,8 @@ def _check(x, i): } }, 'TargetCapacitySpecification': { - 'TotalTargetCapacity': worker_count or total_ram, - 'TargetCapacityUnitType': 'units' if worker_count else 'memory-mib', + 'TotalTargetCapacity': instance_details['total_capacity'], + 'TargetCapacityUnitType': instance_details['capacity_unit'], 'DefaultTargetCapacityType': market }, 'Type': 'maintain', @@ -575,22 +612,17 @@ def _check(x, i): if not tags: kwargs.pop('TagSpecifications') - override_instance_stats = { - 'MemoryMiB': {'Min': ram[0], 'Max': ram[1]}, - 'VCpuCount': {'Min': cpu[0], 'Max': cpu[1]}, - 'SpotMaxPricePercentageOverLowestPrice': 100, - 'MemoryGiBPerVCpu': {'Min': ram2cpu_ratio[0], 'Max': ram2cpu_ratio[1]} - } if gpu: - override_instance_stats['AcceleratorTypes'] = ['gpu'] + instance_details['override_instance_stats']['AcceleratorTypes'] = ['gpu'] if excluded_ec2s: - override_instance_stats['ExcludedInstanceTypes'] = excluded_ec2s + instance_details['override_instance_stats']['ExcludedInstanceTypes'] = excluded_ec2s launch_template_config = { 'LaunchTemplateSpecification': {'LaunchTemplateName': n, 'Version': '1'}, 'Overrides': [{ - 'SubnetId': subnet, - 'InstanceRequirements': override_instance_stats + 'SubnetId': subnet[az], + 'AvailabilityZone': az, + 'InstanceRequirements': instance_details['override_instance_stats'] }] } kwargs['LaunchTemplateConfigs'] = [launch_template_config] @@ -601,7 +633,7 @@ def _check(x, i): create_status(n, request, config) -def search_and_create(config, task): +def search_and_create(config, task, instance_details): """check for running instances and create new ones if necessary Parameters @@ -610,6 +642,8 @@ def search_and_create(config, task): Forge configuration data task : str Forge service to run + instance_details: dict + EC2 instance details for create_fleet """ if not config.get('ram') and not config.get('cpu'): logger.error('Please supply either a ram or cpu value to continue.') @@ -634,18 +668,67 @@ def search_and_create(config, task): logger.info('Fleet is running without EC2, will recreate it.') destroy(config) create_template(n, config, task) - create_fleet(n, config, task) + create_fleet(n, config, task, instance_details) elif len(detail) > 1 and task != 'cluster-worker': logger.info('Multiple %s instances running, destroying and recreating', task) destroy(config) create_template(n, config, task) - create_fleet(n, config, task) + create_fleet(n, config, task, instance_details) detail = ec2_ip(n, config) for e in detail: if e['state'] == 'running': logger.info('%s is running, the IP is %s', task, e['ip']) +def get_instance_details(config, task_list): + """calculate instance details & resources for fleet creation + Parameters + ---------- + config : dict + Forge configuration data + task_list : list + Forge services to get details of + """ + ram = config.get('ram', None) + cpu = config.get('cpu', None) + ratio = config.get('ratio', None) + worker_count = config.get('workers', None) + destroy_flag = config.get('destroy_after_failure') + + instance_details = {} + + def _check(x, i): + logger.debug('Get index %d of %s', i, x) + return x[i] if x and x[i:i + 1] else None + + for task in task_list: + if 'cluster-master' in task or 'single' in task: + task_ram, task_cpu, total_ram, ram2cpu_ratio = calc_machine_ranges(ram=_check(ram, 0), cpu=_check(cpu, 0), ratio=_check(ratio, 0)) + worker_count = 1 + elif 'cluster-worker' in task: + task_ram, task_cpu, total_ram, ram2cpu_ratio = calc_machine_ranges(ram=_check(ram, 1), cpu=_check(cpu, 1), ratio=_check(ratio, 1), workers=worker_count) + else: + logger.error("'%s' does not seem to be a valid cluster or single job.", task) + if destroy_flag: + destroy(config) + sys.exit(1) + + logger.debug('%s OVERRIDE DETAILS | RAM: %s out of %s | CPU: %s with ratio of %s', task, ram, total_ram, cpu, ram2cpu_ratio) + + instance_details[task] = { + 'total_capacity': worker_count or total_ram, + 'capacity_unit': 'units' if worker_count else 'memory-mib', + 'override_instance_stats': { + 'MemoryMiB': {'Min': task_ram[0], 'Max': task_ram[1]}, + 'VCpuCount': {'Min': task_cpu[0], 'Max': task_cpu[1]}, + 'SpotMaxPricePercentageOverLowestPrice': 100, + 'MemoryGiBPerVCpu': {'Min': ram2cpu_ratio[0], 'Max': ram2cpu_ratio[1]} + } + } + + return instance_details + + def create(config): """creates EC2 instances based on config @@ -657,9 +740,15 @@ def create(config): sys.excepthook = destroy_hook service = config.get('service') + task_list = ['single'] - if service == 'single': - search_and_create(config, 'single') if service == 'cluster': - search_and_create(config, 'cluster-master') - search_and_create(config, 'cluster-worker') + task_list = ['cluster-master', 'cluster-worker'] + + instance_details = get_instance_details(config, task_list) + + if not config.get('aws_az'): + config['aws_az'] = get_placement_az(config, instance_details[task_list[-1]]) + + for task in task_list: + search_and_create(config, task, instance_details[task]) diff --git a/src/forge/destroy.py b/src/forge/destroy.py index a1e53e6..d50a127 100755 --- a/src/forge/destroy.py +++ b/src/forge/destroy.py @@ -60,6 +60,7 @@ def pricing(detail, config, market): if dif > max_dif: max_dif = dif ec2_type = e['instance_type'] + config['aws_az'] = e['az'] total_cost = get_ec2_pricing(ec2_type, market, config) if total_cost > 0: