Skip to content

Commit

Permalink
Add multi-az functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
Macr0Nerd committed Jan 12, 2024
1 parent 023f542 commit 81f4c2b
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 58 deletions.
15 changes: 12 additions & 3 deletions src/forge/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ def ec2_ip(n, config):
'instance_type': i.get('InstanceType'),
'state': i.get('State').get('Name'),
'launch_time': i.get('LaunchTime'),
'fleet_id': check_fleet_id(n, config)
'fleet_id': check_fleet_id(n, config),
'az': i.get('Placement')['AvailabilityZone']
}
details.append(x)
logger.debug('ec2_ip details is %s', details)
Expand Down Expand Up @@ -320,6 +321,14 @@ def normalize_config(config):
if config.get('aws_az'):
config['region'] = config['aws_az'][:-1]

if config.get('aws_subnet') and not config.get('aws_multi_az'):
config['aws_multi_az'] = {config.get('aws_az'): config.get('aws_subnet')}
elif config.get('aws_subnet') and config.get('aws_multi_az'):
logger.warning('Both aws_multi_az and aws_subnet exist, defaulting to aws_multi_az')

if config.get('aws_region'):
config['region'] = config['aws_region']

if not config.get('ram') and not config.get('cpu') and config.get('ratio'):
DEFAULT_ARG_VALS['default_ratio'] = config.pop('ratio')

Expand Down Expand Up @@ -492,8 +501,8 @@ def get_ec2_pricing(ec2_type, market, config):
float
Hourly price of given EC2 type in given market.
"""
region = config.get('region')
az = config.get('aws_az')
region = config['region']
az = config['aws_az']

if market == 'spot':
client = boto3.client('ec2')
Expand Down
8 changes: 5 additions & 3 deletions src/forge/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,13 @@ def check_env_yaml(env_yaml):
"""
schema = Schema({
'forge_env': And(str, len, error='Invalid Environment Name'),
'aws_az': And(str, len, error='Invalid AWS availability zone'),
Optional('aws_region'): And(str, len, error='Invalid AWS region'),
Optional('aws_az'): And(str, len, error='Invalid AWS availability zone'),
Optional('aws_subnet'): And(str, len, error='Invalid AWS Subnet'),
'ec2_amis': And(dict, len, error='Invalid AMI Dictionary'),
'aws_subnet': And(str, len, error='Invalid AWS Subnet'),
Optional('aws_multi_az'): And(dict, len, error='Invalid AWS Subnet'),
'ec2_key': And(str, len, error='Invalid AWS key'),
'aws_security_group': And(str, len, error='Invalid AWS Security Group'),
Optional('aws_security_group'): And(str, len, error='Invalid AWS Security Group'),
'forge_pem_secret': And(str, len, error='Invalid Name of Secret'),
Optional('aws_profile'): And(str, len, error='Invalid AWS profile'),
Optional('ratio'): And(list, len, error='Invalid default ratio'),
Expand Down
193 changes: 141 additions & 52 deletions src/forge/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
import base64
import logging
import sys
import math
import os
import time
from datetime import datetime, timedelta

import boto3
import botocore.exceptions
from botocore.exceptions import ClientError

from . import DEFAULT_ARG_VALS, REQUIRED_ARGS
Expand Down Expand Up @@ -474,7 +476,70 @@ def calc_machine_ranges(*, ram=None, cpu=None, ratio=None, workers=None):
return job_ram, job_cpu, total_ram, sorted(ram2cpu_ratio)


def create_fleet(n, config, task):
def get_placement_az(config, instance_details, mode=None):
    """Pick the availability zone in which to launch the fleet.

    Every AZ configured in ``config['aws_multi_az']`` gets a score made of its
    spot placement score plus a weight derived from the number of free IP
    addresses in its subnet. The AZ with the highest score is returned.

    Parameters
    ----------
    config : dict
        Forge configuration data. Reads ``region`` and ``aws_multi_az``
        (a mapping of AZ name to subnet ID).
    instance_details : dict
        EC2 instance details for a single task. Reads ``total_capacity``,
        ``capacity_unit`` and ``override_instance_stats``.
    mode : str, optional
        Weighting mode, ``'balanced'`` when falsy. ``'balanced'`` and
        ``'placement'`` replace the free IP count by its integer square root;
        ``'placement'`` additionally squares the spot placement score. Any
        other value combines the raw numbers.

    Returns
    -------
    str
        Name of the selected availability zone; always a key of
        ``config['aws_multi_az']``.

    Raises
    ------
    ValueError
        If none of the configured AZs could be scored.
    """
    if not mode:
        mode = 'balanced'

    region = config.get('region')
    subnet = config['aws_multi_az']

    client = boto3.client('ec2')
    az_info = client.describe_availability_zones()
    az_mapping = {x['ZoneId']: x['ZoneName'] for x in az_info['AvailabilityZones']}

    placement = {}
    try:
        response = client.get_spot_placement_scores(
            TargetCapacity=instance_details['total_capacity'],
            TargetCapacityUnitType=instance_details['capacity_unit'],
            SingleAvailabilityZone=True,
            RegionNames=[region],
            InstanceRequirementsWithMetadata={
                'ArchitectureTypes': ['x86_64'],  # ToDo: Make configurable
                'InstanceRequirements': instance_details['override_instance_stats']
            },
            MaxResults=10,
        )
    except botocore.exceptions.ClientError as e:
        logger.error('Permissions to pull spot placement scores are necessary')
        logger.error(e)
    else:
        for entry in response['SpotPlacementScores']:
            zone = az_mapping.get(entry['AvailabilityZoneId'])
            # Scores are reported for every AZ of the region; keep only the
            # AZs we own a subnet in, otherwise the caller cannot resolve a
            # subnet ID for the winner.
            if zone in subnet:
                placement[zone] = entry['Score']
        logger.debug(placement)

    subnet_details = {}
    for placement_az, placement_subnet in subnet.items():
        # Query each subnet on its own so a single failing lookup does not
        # discard the AZs that can still be scored.
        try:
            response = client.describe_subnets(SubnetIds=[placement_subnet])
        except botocore.exceptions.ClientError as e:
            logger.error(e)
            continue

        logger.debug(response)
        subnet_details[placement_az] = response['Subnets'][0]['AvailableIpAddressCount']

    if mode in ['balanced', 'placement']:
        # Dampen the IP count so it does not dwarf the placement score.
        subnet_details = {k: int(math.sqrt(v)) for k, v in subnet_details.items()}

    if mode == 'placement':
        placement = {k: v**2 for k, v in placement.items()}

    for k, v in subnet_details.items():
        placement[k] = placement.get(k, 0) + v

    if not placement:
        raise ValueError(
            'Unable to score any availability zone configured in aws_multi_az'
        )

    return max(placement, key=placement.get)


def create_fleet(n, config, task, instance_details):
"""creates the AWS EC2 fleet
Parameters
Expand All @@ -485,6 +550,8 @@ def create_fleet(n, config, task):
Forge configuration data
task : str
Forge service to run
instance_details: dict
EC2 instance details for create_fleet
"""
profile = config.get('aws_profile')
valid = config.get('valid_time', DEFAULT_ARG_VALS['valid_time'])
Expand All @@ -494,45 +561,15 @@ def create_fleet(n, config, task):
now_utc = datetime.utcnow()
now_utc = now_utc.replace(microsecond=0)
valid_until = now_utc + timedelta(hours=int(valid))
subnet = config.get('aws_subnet')

ram = config.get('ram', None)
cpu = config.get('cpu', None)
ratio = config.get('ratio', None)
worker_count = config.get('workers', None)
destroy_flag = config.get('destroy_after_failure')

if 'single' in n and max(len(ram or []), len(cpu or []), len(ratio or [])) > 1:
raise ValueError("Too many values provided for single job.")
elif 'cluster' in n and min(len(ram or [0, 0]), len(cpu or [0, 0]), len(ratio or [0, 0])) < 2:
raise ValueError("Too few values provided for cluster job.")

def _check(x, i):
logger.debug('Get index %d of %s', i, x)
return x[i] if x and x[i:i + 1] else None

if 'cluster-master' in n or 'single' in n:
ram, cpu, total_ram, ram2cpu_ratio = calc_machine_ranges(ram=_check(ram, 0), cpu=_check(cpu, 0),
ratio=_check(ratio, 0))
worker_count = 1
elif 'cluster-worker' in n:
ram, cpu, total_ram, ram2cpu_ratio = calc_machine_ranges(ram=_check(ram, 1), cpu=_check(cpu, 1),
ratio=_check(ratio, 1), workers=worker_count)
else:
logger.error("'%s' does not seem to be a valid cluster or single job.", n)
if destroy_flag:
destroy(config)
sys.exit(1)

logger.debug('OVERRIDE DETAILS | RAM: %s out of %s | CPU: %s with ratio of %s', ram, total_ram, cpu, ram2cpu_ratio)
subnet = config.get('aws_multi_az')

gpu = config.get('gpu_flag', False)
market = config.get('market', DEFAULT_ARG_VALS['market'])
subnet = config.get('aws_subnet')

market = market[-1] if 'cluster-worker' in n else market[0]

set_boto_session(region, profile)
az = config['aws_az']

fmt = FormatEmpty()
access_vars = user_accessible_vars(config, market=market, task=task)
Expand All @@ -555,8 +592,8 @@ def _check(x, i):
}
},
'TargetCapacitySpecification': {
'TotalTargetCapacity': worker_count or total_ram,
'TargetCapacityUnitType': 'units' if worker_count else 'memory-mib',
'TotalTargetCapacity': instance_details['total_capacity'],
'TargetCapacityUnitType': instance_details['capacity_unit'],
'DefaultTargetCapacityType': market
},
'Type': 'maintain',
Expand All @@ -575,22 +612,17 @@ def _check(x, i):
if not tags:
kwargs.pop('TagSpecifications')

override_instance_stats = {
'MemoryMiB': {'Min': ram[0], 'Max': ram[1]},
'VCpuCount': {'Min': cpu[0], 'Max': cpu[1]},
'SpotMaxPricePercentageOverLowestPrice': 100,
'MemoryGiBPerVCpu': {'Min': ram2cpu_ratio[0], 'Max': ram2cpu_ratio[1]}
}
if gpu:
override_instance_stats['AcceleratorTypes'] = ['gpu']
instance_details['override_instance_stats']['AcceleratorTypes'] = ['gpu']
if excluded_ec2s:
override_instance_stats['ExcludedInstanceTypes'] = excluded_ec2s
instance_details['override_instance_stats']['ExcludedInstanceTypes'] = excluded_ec2s

launch_template_config = {
'LaunchTemplateSpecification': {'LaunchTemplateName': n, 'Version': '1'},
'Overrides': [{
'SubnetId': subnet,
'InstanceRequirements': override_instance_stats
'SubnetId': subnet[az],
'AvailabilityZone': az,
'InstanceRequirements': instance_details['override_instance_stats']
}]
}
kwargs['LaunchTemplateConfigs'] = [launch_template_config]
Expand All @@ -601,7 +633,7 @@ def _check(x, i):
create_status(n, request, config)


def search_and_create(config, task):
def search_and_create(config, task, instance_details):
"""check for running instances and create new ones if necessary
Parameters
Expand All @@ -610,6 +642,8 @@ def search_and_create(config, task):
Forge configuration data
task : str
Forge service to run
instance_details: dict
EC2 instance details for create_fleet
"""
if not config.get('ram') and not config.get('cpu'):
logger.error('Please supply either a ram or cpu value to continue.')
Expand All @@ -634,18 +668,67 @@ def search_and_create(config, task):
logger.info('Fleet is running without EC2, will recreate it.')
destroy(config)
create_template(n, config, task)
create_fleet(n, config, task)
create_fleet(n, config, task, instance_details)
elif len(detail) > 1 and task != 'cluster-worker':
logger.info('Multiple %s instances running, destroying and recreating', task)
destroy(config)
create_template(n, config, task)
create_fleet(n, config, task)
create_fleet(n, config, task, instance_details)
detail = ec2_ip(n, config)
for e in detail:
if e['state'] == 'running':
logger.info('%s is running, the IP is %s', task, e['ip'])


def get_instance_details(config, task_list):
    """Calculate instance details & resources for fleet creation.

    Parameters
    ----------
    config : dict
        Forge configuration data. Reads ``ram``, ``cpu``, ``ratio``,
        ``workers`` and ``destroy_after_failure``.
    task_list : list
        Forge services to get details of. Each name must contain ``single``,
        ``cluster-master`` or ``cluster-worker``.

    Returns
    -------
    dict
        Maps every task to a dict with the keys ``total_capacity``,
        ``capacity_unit`` and ``override_instance_stats``.

    Notes
    -----
    An unrecognised task is logged and terminates the process via
    ``sys.exit(1)``, destroying existing resources first when
    ``destroy_after_failure`` is set.
    """
    ram = config.get('ram', None)
    cpu = config.get('cpu', None)
    ratio = config.get('ratio', None)
    configured_workers = config.get('workers', None)
    destroy_flag = config.get('destroy_after_failure')

    instance_details = {}

    def _check(x, i):
        logger.debug('Get index %d of %s', i, x)
        return x[i] if x and x[i:i + 1] else None

    for task in task_list:
        # The worker count is tracked per task: the master/single node is
        # always exactly one unit and must not leak that value into the sizing
        # of a worker task handled by a later iteration.
        if 'cluster-master' in task or 'single' in task:
            task_workers = 1
            task_ram, task_cpu, total_ram, ram2cpu_ratio = calc_machine_ranges(
                ram=_check(ram, 0), cpu=_check(cpu, 0), ratio=_check(ratio, 0))
        elif 'cluster-worker' in task:
            task_workers = configured_workers
            task_ram, task_cpu, total_ram, ram2cpu_ratio = calc_machine_ranges(
                ram=_check(ram, 1), cpu=_check(cpu, 1), ratio=_check(ratio, 1),
                workers=task_workers)
        else:
            logger.error("'%s' does not seem to be a valid cluster or single job.", task)
            if destroy_flag:
                destroy(config)
            sys.exit(1)

        logger.debug('%s OVERRIDE DETAILS | RAM: %s out of %s | CPU: %s with ratio of %s',
                     task, task_ram, total_ram, task_cpu, ram2cpu_ratio)

        instance_details[task] = {
            # Without an explicit worker count the fleet is sized by memory.
            'total_capacity': task_workers or total_ram,
            'capacity_unit': 'units' if task_workers else 'memory-mib',
            'override_instance_stats': {
                'MemoryMiB': {'Min': task_ram[0], 'Max': task_ram[1]},
                'VCpuCount': {'Min': task_cpu[0], 'Max': task_cpu[1]},
                'SpotMaxPricePercentageOverLowestPrice': 100,
                'MemoryGiBPerVCpu': {'Min': ram2cpu_ratio[0], 'Max': ram2cpu_ratio[1]}
            }
        }

    return instance_details


def create(config):
"""creates EC2 instances based on config
Expand All @@ -657,9 +740,15 @@ def create(config):
sys.excepthook = destroy_hook

service = config.get('service')
task_list = ['single']

if service == 'single':
search_and_create(config, 'single')
if service == 'cluster':
search_and_create(config, 'cluster-master')
search_and_create(config, 'cluster-worker')
task_list = ['cluster-master', 'cluster-worker']

instance_details = get_instance_details(config, task_list)

if not config.get('aws_az'):
config['aws_az'] = get_placement_az(config, instance_details[task_list[-1]])

for task in task_list:
search_and_create(config, task, instance_details[task])
1 change: 1 addition & 0 deletions src/forge/destroy.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def pricing(detail, config, market):
if dif > max_dif:
max_dif = dif
ec2_type = e['instance_type']
config['aws_az'] = e['az']
total_cost = get_ec2_pricing(ec2_type, market, config)

if total_cost > 0:
Expand Down

0 comments on commit 81f4c2b

Please sign in to comment.