diff --git a/.bumpversion.cfg b/.bumpversion.cfg index a5a17bb..32da4e3 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 1.0.2 +current_version = 1.1.1 commit = True tag = False parse = (?P\d+)\.(?P\d+)\.(?P\d+) diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 16bd04a..ee39f93 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -22,7 +22,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v3 with: - python-version: '3.7' + python-version: '3.9' - name: Install dependencies run: | diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index eb3db7d..743af2c 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -23,7 +23,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v3 with: - python-version: '3.7' + python-version: '3.9' - name: Install dependencies run: | diff --git a/CHANGELOG.md b/CHANGELOG.md index 6caa440..f462a93 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,35 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [Unreleased] +## [1.1.1] - 2024-12-12 + +### Changed +- **Python Version** - Bump minimum python version to 3.9. +- **Rsync** - Properly triggers retry sequence +- **Rsync** - Gives a return code now + +### Fixed +- **Create** - Fix GPU AMI not being selected. +- **Parser** - Fix GPU flag not being passed properly to the config dict. +- **Create** - Better error reporting regarding RAM and CPU misconfigurations. + + +## [1.1.0] - 2024-02-26 + +### Added +- **Create** - Added `destroy_on_create` +- **Create** - Added `create_timeout` option +- **Common** - Moved all `n_list` functions to `get_nlist()` +- **Dependencies** - Updated dependencies and tested on latest versions +- **Create** - Set default boto3 session at beginning of create to resolve region bug +- **Create** + - Multi-AZ functionality + - Spot retries + - On-demand Failover + +### Changed +- **Create** - Configurable spot strategy +- **Documentation** - Updated with new changes ## [1.0.2] - 2022-10-27 @@ -33,12 +61,15 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - **GitHub** - Update action to build and publish package only when version is bumped. - **Forge** - Added automatic tag `forge-name` to allow `Name` tag to be changed. + ## [1.0.0] - 2022-09-27 ### Added - **Initial commit** - Forge source code, unittests, docs, pyproject.toml, README.md, and LICENSE files. -[unreleased]: https://github.com/carsdotcom/cars-forge/compare/v1.0.2...HEAD +[unreleased]: https://github.com/carsdotcom/cars-forge/compare/v1.1.1...HEAD +[1.1.1]: https://github.com/carsdotcom/cars-forge/compare/v1.1.0...v1.1.1 +[1.1.0]: https://github.com/carsdotcom/cars-forge/compare/v1.0.2...v1.1.0 [1.0.2]: https://github.com/carsdotcom/cars-forge/compare/v1.0.1...v1.0.2 [1.0.1]: https://github.com/carsdotcom/cars-forge/compare/v1.0.0...v1.0.1 [1.0.0]: https://github.com/carsdotcom/cars-forge/releases/tag/v1.0.0 diff --git a/README.md b/README.md index 39d80bc..8abce66 100755 --- a/README.md +++ b/README.md @@ -2,10 +2,10 @@ [![GitHub license](https://img.shields.io/github/license/carsdotcom/cars-forge?color=navy&label=License&logo=License&style=flat-square)](https://github.com/carsdotcom/cars-forge/blob/main/LICENSE) [![PyPI](https://img.shields.io/pypi/v/cars-forge?color=navy&style=flat-square)](https://pypi.org/project/cars-forge/) -![hacktoberfest](https://img.shields.io/github/issues/carsdotcom/cars-forge?color=orange&label=Hacktoberfest%202022&style=flat-square&?labelColor=black) ![PyPI - Downloads](https://img.shields.io/pypi/dm/cars-forge?color=navy&style=flat-square) ![GitHub Workflow Status (branch)](https://img.shields.io/github/workflow/status/carsdotcom/cars-forge/Publish%20Package/main?color=navy&style=flat-square) ![GitHub contributors](https://img.shields.io/github/contributors/carsdotcom/cars-forge?color=navy&style=flat-square) + --- ## About diff --git a/docs/environmental_yaml.md b/docs/environmental_yaml.md index f423a24..220e4f4 100644 --- a/docs/environmental_yaml.md +++ b/docs/environmental_yaml.md @@ -59,10 +59,18 @@ https://github.com/carsdotcom/cars-forge/blob/main/examples/env_yaml_example/exa constraints: [2.3, 3.0, 3.1] error: "Invalid Spark version. Only 2.3, 3.0, and 3.1 are supported." ``` -- **aws_az** - The [AWS availability zone](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/using-regions-availability-zones.html) where Forge will create the EC2 instance. Currently, Forge can run only in one AZ -- **aws_profile** - [AWS CLI profile](https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-profiles.html) to use +- **aws_az** - The [AWS availability zone](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/using-regions-availability-zones.html) where Forge will create the EC2 instance. If set, multi-az placement will be disabled. +- **aws_region** - The AWS region for Forge to run in- **aws_profile** - [AWS CLI profile](https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-profiles.html) to use - **aws_security_group** - [AWS Security Group](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-security-groups.html) for the instance -- **aws_subnet** - [AWS subnet](https://docs.aws.amazon.com/vpc/latest/userguide/configure-subnets.html) where the EC2s will run +- **aws_subnet** - [AWS subnet](https://docs.aws.amazon.com/vpc/latest/userguide/configure-subnets.html) where the EC2s will run +- **aws_multi_az** - [AWS subnet](https://docs.aws.amazon.com/vpc/latest/userguide/configure-subnets.html) where the EC2s will run organized by AZ + - E.g. + ```yaml + aws_multi_az: + us-east-1a: subnet-aaaaaaaaaaaaaaaaa + us-east-1b: subnet-bbbbbbbbbbbbbbbbb + us-east-1c: subnet-ccccccccccccccccc + ``` - **default_ratio** - Override the default ratio of RAM to CPU if the user does not provide one. Must be a list of the minimum and maximum. - default is [8, 8] - **ec2_amis** - A dictionary of dictionaries to store [AMI](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/AMIs.html) info. @@ -95,6 +103,8 @@ https://github.com/carsdotcom/cars-forge/blob/main/examples/env_yaml_example/exa ``` - **forge_env** - Name of the Forge environment. The user will refer to this in their yaml. - **forge_pem_secret** - The secret name where the `ec2_key` is stored +- **on_demand_failover** - If using engine mode and all spot attempts (market: spot + spot retries) have failed, run a final attempt using on-demand. +- **spot_retries** - If using engine mode, sets the number of times to retry a spot instance. Only retries if either market is spot. - **tags** - [Tags](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/Using_Tags.html) to apply to instances created by Forge. Follows the AWS tag format. - Forge also exposes all string, numeric, and some extra variables from the combined user and environmental configs that will be replaced at runtime by the matching values (e.g. `{name}` for job name, `{date}` for job date, etc.) See the [variables](variables.md) page for more details. - E.g. diff --git a/docs/yaml.md b/docs/yaml.md index dc8701a..b55ebc2 100644 --- a/docs/yaml.md +++ b/docs/yaml.md @@ -43,6 +43,7 @@ Each forge command certain parameters. A yaml file with all the parameters can b ``` - If running via the command line, a range of values is passed as: ``--market on-demand spot``. - **name** - Name of the instance/cluster +- **on_demand_failover** - If using engine mode and all spot attempts (market: spot + spot retries) have failed, run a final attempt using on-demand. - **ram** - Minimum amount of RAM required. Can be a range e.g. [16, 32]. - If using a cluster, you must specify both the master and worker. Master first, worker second. ```yaml @@ -76,5 +77,7 @@ Each forge command certain parameters. A yaml file with all the parameters can b - Use the `--all` flag to run the script on all the instances in a cluster. - E.g. `run_cmd: scripts/run.sh {env} {date} {ip}` - **service** - `cluster` or `single` +- **spot_strategy** - Select the [spot allocation strategy](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ec2/client/create_fleet.html). +- **spot_retries** - If using engine mode, sets the number of times to retry a spot instance. Only retries if either market is spot. - **user_data** - Custom script passed to instance. Will be run only once when the instance starts up. - **valid_time** - How many hours the fleet will stay up. After this time, all EC2s will be destroyed. The default is 8. \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 981805f..0bcc7ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,11 +5,20 @@ name = "cars-forge" description = "Create an on-demand/spot fleet of single or cluster EC2 instances." readme = "README.md" -requires-python = ">=3.7" +requires-python = ">=3.9" license = "Apache-2.0" authors = [ - {name = "Nikhil Patel", email = "npatel@cars.com"} + {name = "Nikhil Patel", email = "npatel@cars.com"}, + {name = "Gabriele Ron", email = "gron@cars.com"}, + {name = "Joao Moreira", email = "jmoreira@cars.com"} ] + +maintainers = [ + {name = "Nikhil Patel", email = "npatel@cars.com"}, + {name = "Gabriele Ron", email = "gron@cars.com"}, + {name = "Joao Moreira", email = "jmoreira@cars.com"} +] + keywords = [ "AWS", "EC2", @@ -19,6 +28,7 @@ keywords = [ "Cluster", "Jupyter" ] + classifiers = [ "Development Status :: 5 - Production/Stable", "Environment :: Console", @@ -28,24 +38,25 @@ classifiers = [ "Operating System :: Unix", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", ] + dynamic = ["version"] dependencies = [ - "boto3~=1.19.0", - "pyyaml~=5.3.0", - "schema~=0.7.0", + "boto3", + "pyyaml", + "schema", ] + [project.optional-dependencies] test = [ - "pytest~=7.1.0", - "pytest-cov~=4.0" + "pytest", + "pytest-cov" ] + dev = [ "bump2version~=1.0", ] diff --git a/src/forge/__init__.py b/src/forge/__init__.py index 5c92f2c..1781885 100755 --- a/src/forge/__init__.py +++ b/src/forge/__init__.py @@ -1,4 +1,4 @@ -__version__ = "1.0.2" +__version__ = "1.1.1" # Default values for forge's essential arguments DEFAULT_ARG_VALS = { @@ -11,7 +11,8 @@ 'destroy_after_failure': True, 'default_ratio': [8, 8], 'valid_time': 8, - 'ec2_max': 768 + 'ec2_max': 768, + 'spot_strategy': 'price-capacity-optimized' } # Required arguments for each Forge job diff --git a/src/forge/common.py b/src/forge/common.py index 02ea5c3..4fe65b4 100755 --- a/src/forge/common.py +++ b/src/forge/common.py @@ -14,6 +14,7 @@ from botocore.exceptions import ClientError, NoCredentialsError from . import DEFAULT_ARG_VALS, ADDITIONAL_KEYS +from .exceptions import ExitHandlerException logger = logging.getLogger(__name__) @@ -117,7 +118,8 @@ def ec2_ip(n, config): 'instance_type': i.get('InstanceType'), 'state': i.get('State').get('Name'), 'launch_time': i.get('LaunchTime'), - 'fleet_id': check_fleet_id(n, config) + 'fleet_id': check_fleet_id(n, config), + 'az': i.get('Placement')['AvailabilityZone'] } details.append(x) logger.debug('ec2_ip details is %s', details) @@ -142,6 +144,35 @@ def get_ip(details, states): return [(i['ip'], i['id']) for i in list(filter(lambda x: x['state'] in states, details))] +def get_nlist(config): + """get list of instance names based on service + + Parameters + ---------- + config : dict + Forge configuration data + + Returns + ------- + list + List of instance names + """ + date = config.get('date', '') + market = config.get('market', DEFAULT_ARG_VALS['market']) + name = config['name'] + service = config['service'] + + n_list = [] + if service == "cluster": + n_list.append(f'{name}-{market[0]}-{service}-master-{date}') + if config.get('rr_all'): + n_list.append(f'{name}-{market[-1]}-{service}-worker-{date}') + elif service == "single": + n_list.append(f'{name}-{market[0]}-{service}-{date}') + + return n_list + + @contextlib.contextmanager def key_file(secret_id, region, profile): """Safely retrieve a secret file from AWS for temporary use. @@ -320,6 +351,14 @@ def normalize_config(config): if config.get('aws_az'): config['region'] = config['aws_az'][:-1] + if config.get('aws_subnet') and not config.get('aws_multi_az'): + config['aws_multi_az'] = {config.get('aws_az'): config.get('aws_subnet')} + elif config.get('aws_subnet') and config.get('aws_multi_az'): + logger.warning('Both aws_multi_az and aws_subnet exist, defaulting to aws_multi_az') + + if config.get('aws_region'): + config['region'] = config['aws_region'] + if not config.get('ram') and not config.get('cpu') and config.get('ratio'): DEFAULT_ARG_VALS['default_ratio'] = config.pop('ratio') @@ -492,8 +531,8 @@ def get_ec2_pricing(ec2_type, market, config): float Hourly price of given EC2 type in given market. """ - region = config.get('region') - az = config.get('aws_az') + region = config['region'] + az = config['aws_az'] if market == 'spot': client = boto3.client('ec2') @@ -529,3 +568,14 @@ def get_ec2_pricing(ec2_type, market, config): price = float(price) return price + + +def exit_callback(config, exit: bool = False): + if config['job'] == 'engine' and (config.get('spot_retries') or (config.get('on_demand_failover') or config.get('market_failover'))): + logger.error('Error occurred, bubbling up error to handler.') + raise ExitHandlerException + + if exit: + sys.exit(1) + + pass diff --git a/src/forge/configure.py b/src/forge/configure.py index 76db581..ccc5578 100755 --- a/src/forge/configure.py +++ b/src/forge/configure.py @@ -5,7 +5,7 @@ import sys import yaml -from schema import Schema, And, Optional, SchemaError +from schema import Schema, And, Optional, Or, SchemaError, Use from .common import set_config_dir @@ -50,11 +50,13 @@ def check_env_yaml(env_yaml): """ schema = Schema({ 'forge_env': And(str, len, error='Invalid Environment Name'), - 'aws_az': And(str, len, error='Invalid AWS availability zone'), + Optional('aws_region'): And(str, len, error='Invalid AWS region'), + Optional('aws_az'): And(str, len, error='Invalid AWS availability zone'), + Optional('aws_subnet'): And(str, len, error='Invalid AWS Subnet'), 'ec2_amis': And(dict, len, error='Invalid AMI Dictionary'), - 'aws_subnet': And(str, len, error='Invalid AWS Subnet'), + Optional('aws_multi_az'): And(dict, len, error='Invalid AWS Subnet'), 'ec2_key': And(str, len, error='Invalid AWS key'), - 'aws_security_group': And(str, len, error='Invalid AWS Security Group'), + Optional('aws_security_group'): And(str, len, error='Invalid AWS Security Group'), 'forge_pem_secret': And(str, len, error='Invalid Name of Secret'), Optional('aws_profile'): And(str, len, error='Invalid AWS profile'), Optional('ratio'): And(list, len, error='Invalid default ratio'), @@ -62,7 +64,18 @@ def check_env_yaml(env_yaml): Optional('tags'): And(list, len, error="Invalid AWS tags"), Optional('excluded_ec2s'): And(list), Optional('additional_config'): And(list), - Optional('ec2_max'): And(int) + Optional('ec2_max'): And(int), + Optional('spot_strategy'): And(str, len, + Or( + 'lowest-price', + 'diversified', + 'capacity-optimized', + 'capacity-optimized-prioritized', + 'price-capacity-optimized'), + error='Invalid spot allocation strategy'), + Optional('on_demand_failover'): And(bool), + Optional('spot_retries'): And(Use(int), lambda x: x > 0), + Optional('create_timeout'): And(Use(int), lambda x: x > 0), }) try: validated = schema.validate(env_yaml) diff --git a/src/forge/create.py b/src/forge/create.py index 1ec4ccc..988f0e9 100755 --- a/src/forge/create.py +++ b/src/forge/create.py @@ -2,16 +2,18 @@ import base64 import logging import sys +import math import os import time from datetime import datetime, timedelta import boto3 +import botocore.exceptions from botocore.exceptions import ClientError from . import DEFAULT_ARG_VALS, REQUIRED_ARGS from .parser import add_basic_args, add_job_args, add_env_args, add_general_args -from .common import (ec2_ip, destroy_hook, set_boto_session, +from .common import (ec2_ip, destroy_hook, set_boto_session, exit_callback, user_accessible_vars, FormatEmpty, get_ec2_pricing) from .destroy import destroy @@ -166,6 +168,12 @@ def create_status(n, request, config): create_time = fleet_details.get('CreateTime') time_without_spot = 0 while current_status != 'fulfilled': + if config.get('create_timeout') and t > config['create_timeout']: + logger.error('Timeout of %s seconds hit for instance fulfillment; Aborting.', config['create_timeout']) + if destroy_flag: + destroy(config) + exit_callback(config, exit=True) + if current_status == 'pending_fulfillment': time.sleep(10) t += 10 @@ -177,7 +185,7 @@ def create_status(n, request, config): destroy(config) error_details = get_fleet_error(client, fleet_id, create_time) logger.error('Last status details: %s', error_details) - sys.exit(1) + exit_callback(config, exit=True) time.sleep(10) t += 10 time_without_spot += 10 @@ -198,7 +206,7 @@ def create_status(n, request, config): logger.error('The EC2 spot instance failed to start, please try again.') if destroy_flag: destroy(config) - sys.exit(1) + exit_callback(config, exit=True) logger.info('Finding EC2... - %ds elapsed', t) fleet_request_configs = client.describe_fleet_instances(FleetId=fleet_id) active_instances_list = fleet_request_configs.get('ActiveInstances') @@ -207,6 +215,7 @@ def create_status(n, request, config): list_len = len(ec2_id_list) logger.debug('EC2 list is: %s', ec2_id_list) + config['ec2_id_list'] = ec2_id_list time_without_instance = 0 for s in ec2_id_list: status = 'initializing' @@ -222,12 +231,12 @@ def create_status(n, request, config): logger.error('The EC2 spot instance failed to start, please try again.') if destroy_flag: destroy(config) - sys.exit(1) + exit_callback(config, exit=True) elif status not in {'initializing', 'ok'}: logger.error('Could not start instance. Last EC2 status: %s', status) if destroy_flag: destroy(config) - sys.exit(1) + exit_callback(config, exit=True) logger.info('EC2 initialized.') pricing(n, config, fleet_id) @@ -316,8 +325,10 @@ def create_template(n, config, task): market = market[-1] if task == 'cluster-worker' else market[0] if service: if len(user_ami) == 21 and user_ami[:4] == "ami-": - ami, disk, disk_device_name = (user_ami, config['disk'], config['disk_device_name']) + ami, disk, disk_device_name = (user_ami, user_disk, user_disk_device_name) else: + if gpu: + user_ami += '_gpu' ami_info = env_ami.get(user_ami) ami, disk, disk_device_name = (ami_info['ami'], ami_info['disk'], ami_info['disk_device_name']) @@ -474,7 +485,70 @@ def calc_machine_ranges(*, ram=None, cpu=None, ratio=None, workers=None): return job_ram, job_cpu, total_ram, sorted(ram2cpu_ratio) -def create_fleet(n, config, task): +def get_placement_az(config, instance_details, mode=None): + if not mode: + mode = 'balanced' + + region = config.get('region') + subnet = config['aws_multi_az'] + + client = boto3.client('ec2') + az_info = client.describe_availability_zones() + az_mapping = {x['ZoneId']: x['ZoneName'] for x in az_info['AvailabilityZones']} + + try: + response = client.get_spot_placement_scores( + TargetCapacity=instance_details['total_capacity'], + TargetCapacityUnitType=instance_details['capacity_unit'], + SingleAvailabilityZone=True, + RegionNames=[region], + InstanceRequirementsWithMetadata={ + 'ArchitectureTypes': ['x86_64'], # ToDo: Make configurable + 'InstanceRequirements': instance_details['override_instance_stats'] + }, + MaxResults=10, + ) + + placement = {az_mapping[x['AvailabilityZoneId']]: x['Score'] for x in response['SpotPlacementScores']} + logger.debug(placement) + except botocore.exceptions.ClientError as e: + logger.error('Permissions to pull spot placement scores are necessary') + logger.error(e) + placement = {} + + subnet_details = {} + + try: + for placement_az, placement_subnet in subnet.items(): + response = client.describe_subnets( + SubnetIds=[placement_subnet] + ) + + logger.debug(response) + + subnet_details[placement_az] = response['Subnets'][0]['AvailableIpAddressCount'] + + except botocore.exceptions.ClientError as e: + logger.error(e) + + if mode in ['balanced', 'placement']: + subnet_details = {k: int(math.sqrt(v)) for k, v in subnet_details.items()} + + if mode == 'placement': + placement = {k: v**2 for k, v in placement.items()} + + for k, v in subnet_details.items(): + if placement.get(k): + placement[k] += v + else: + placement[k] = v + + az = max(placement, key=placement.get) + + return az + + +def create_fleet(n, config, task, instance_details): """creates the AWS EC2 fleet Parameters @@ -485,6 +559,8 @@ def create_fleet(n, config, task): Forge configuration data task : str Forge service to run + instance_details: dict + EC2 instance details for create_fleet """ profile = config.get('aws_profile') valid = config.get('valid_time', DEFAULT_ARG_VALS['valid_time']) @@ -494,45 +570,16 @@ def create_fleet(n, config, task): now_utc = datetime.utcnow() now_utc = now_utc.replace(microsecond=0) valid_until = now_utc + timedelta(hours=int(valid)) - subnet = config.get('aws_subnet') - - ram = config.get('ram', None) - cpu = config.get('cpu', None) - ratio = config.get('ratio', None) - worker_count = config.get('workers', None) - destroy_flag = config.get('destroy_after_failure') - - if 'single' in n and max(len(ram or []), len(cpu or []), len(ratio or [])) > 1: - raise ValueError("Too many values provided for single job.") - elif 'cluster' in n and min(len(ram or [0, 0]), len(cpu or [0, 0]), len(ratio or [0, 0])) < 2: - raise ValueError("Too few values provided for cluster job.") - - def _check(x, i): - logger.debug('Get index %d of %s', i, x) - return x[i] if x and x[i:i + 1] else None - - if 'cluster-master' in n or 'single' in n: - ram, cpu, total_ram, ram2cpu_ratio = calc_machine_ranges(ram=_check(ram, 0), cpu=_check(cpu, 0), - ratio=_check(ratio, 0)) - worker_count = 1 - elif 'cluster-worker' in n: - ram, cpu, total_ram, ram2cpu_ratio = calc_machine_ranges(ram=_check(ram, 1), cpu=_check(cpu, 1), - ratio=_check(ratio, 1), workers=worker_count) - else: - logger.error("'%s' does not seem to be a valid cluster or single job.", n) - if destroy_flag: - destroy(config) - sys.exit(1) - - logger.debug('OVERRIDE DETAILS | RAM: %s out of %s | CPU: %s with ratio of %s', ram, total_ram, cpu, ram2cpu_ratio) + subnet = config.get('aws_multi_az') gpu = config.get('gpu_flag', False) market = config.get('market', DEFAULT_ARG_VALS['market']) - subnet = config.get('aws_subnet') + strategy = config.get('spot_strategy') market = market[-1] if 'cluster-worker' in n else market[0] set_boto_session(region, profile) + az = config['aws_az'] fmt = FormatEmpty() access_vars = user_accessible_vars(config, market=market, task=task) @@ -545,7 +592,7 @@ def _check(x, i): 'AllocationStrategy': 'lowest-price' }, 'SpotOptions': { - 'AllocationStrategy': 'capacity-optimized', + 'AllocationStrategy': strategy, 'InstanceInterruptionBehavior': 'terminate', 'MaintenanceStrategies': { 'CapacityRebalance': { @@ -555,8 +602,8 @@ def _check(x, i): } }, 'TargetCapacitySpecification': { - 'TotalTargetCapacity': worker_count or total_ram, - 'TargetCapacityUnitType': 'units' if worker_count else 'memory-mib', + 'TotalTargetCapacity': instance_details['total_capacity'], + 'TargetCapacityUnitType': instance_details['capacity_unit'], 'DefaultTargetCapacityType': market }, 'Type': 'maintain', @@ -575,22 +622,17 @@ def _check(x, i): if not tags: kwargs.pop('TagSpecifications') - override_instance_stats = { - 'MemoryMiB': {'Min': ram[0], 'Max': ram[1]}, - 'VCpuCount': {'Min': cpu[0], 'Max': cpu[1]}, - 'SpotMaxPricePercentageOverLowestPrice': 100, - 'MemoryGiBPerVCpu': {'Min': ram2cpu_ratio[0], 'Max': ram2cpu_ratio[1]} - } if gpu: - override_instance_stats['AcceleratorTypes'] = ['gpu'] + instance_details['override_instance_stats']['AcceleratorTypes'] = ['gpu'] if excluded_ec2s: - override_instance_stats['ExcludedInstanceTypes'] = excluded_ec2s + instance_details['override_instance_stats']['ExcludedInstanceTypes'] = excluded_ec2s launch_template_config = { 'LaunchTemplateSpecification': {'LaunchTemplateName': n, 'Version': '1'}, 'Overrides': [{ - 'SubnetId': subnet, - 'InstanceRequirements': override_instance_stats + 'SubnetId': subnet[az], + 'AvailabilityZone': az, + 'InstanceRequirements': instance_details['override_instance_stats'] }] } kwargs['LaunchTemplateConfigs'] = [launch_template_config] @@ -601,7 +643,7 @@ def _check(x, i): create_status(n, request, config) -def search_and_create(config, task): +def search_and_create(config, task, instance_details): """check for running instances and create new ones if necessary Parameters @@ -610,6 +652,8 @@ def search_and_create(config, task): Forge configuration data task : str Forge service to run + instance_details: dict + EC2 instance details for create_fleet """ if not config.get('ram') and not config.get('cpu'): logger.error('Please supply either a ram or cpu value to continue.') @@ -629,23 +673,93 @@ def search_and_create(config, task): e = detail[0] if e['state'] in ['running', 'stopped', 'stopping', 'pending']: logger.info('%s is %s, the IP is %s', task, e['state'], e['ip']) + + if config.get('destroy_on_create'): + logger.info('destroy_on_create true, destroying fleet.') + destroy(config) + create_template(n, config, task) + create_fleet(n, config, task, instance_details) else: if len(e['fleet_id']) != 0: logger.info('Fleet is running without EC2, will recreate it.') destroy(config) create_template(n, config, task) - create_fleet(n, config, task) + create_fleet(n, config, task, instance_details) elif len(detail) > 1 and task != 'cluster-worker': logger.info('Multiple %s instances running, destroying and recreating', task) destroy(config) create_template(n, config, task) - create_fleet(n, config, task) + create_fleet(n, config, task, instance_details) detail = ec2_ip(n, config) for e in detail: if e['state'] == 'running': logger.info('%s is running, the IP is %s', task, e['ip']) +def get_instance_details(config, task_list): + """calculate instance details & resources for fleet creation + Parameters + ---------- + config : dict + Forge configuration data + task_list : list + Forge services to get details of + """ + ram = config.get('ram', None) + cpu = config.get('cpu', None) + ratio = config.get('ratio', None) + service = config.get('service', None) + worker_count = config.get('workers', None) + destroy_flag = config.get('destroy_after_failure') + + rc_length = 1 if service == 'single' else 2 if service == 'cluster' else None + + if not ram and not cpu: + logger.error('Invalid configuration, either ram or cpu must be provided.') + if destroy_flag: + destroy(config) + sys.exit(1) + elif (ram and len(ram) != rc_length) or (cpu and len(cpu) != rc_length): + logger.error('Invalid configuration, ram or cpu must have one value for single jobs, and two for cluster jobs.') + if destroy_flag: + destroy(config) + sys.exit(1) + + instance_details = {} + + def _check(x, i): + logger.debug('Get index %d of %s', i, x) + return x[i] if x and x[i:i + 1] else None + + for task in task_list: + task_worker_count = worker_count + if 'cluster-master' in task or 'single' in task: + task_ram, task_cpu, total_ram, ram2cpu_ratio = calc_machine_ranges(ram=_check(ram, 0), cpu=_check(cpu, 0), ratio=_check(ratio, 0)) + task_worker_count = 1 + elif 'cluster-worker' in task: + task_ram, task_cpu, total_ram, ram2cpu_ratio = calc_machine_ranges(ram=_check(ram, 1), cpu=_check(cpu, 1), ratio=_check(ratio, 1), workers=task_worker_count) + else: + logger.error("'%s' does not seem to be a valid cluster or single job.", task) + if destroy_flag: + destroy(config) + sys.exit(1) + + logger.debug('%s OVERRIDE DETAILS | RAM: %s out of %s | CPU: %s with ratio of %s', task, task_ram, total_ram, task_cpu, ram2cpu_ratio) + + instance_details[task] = { + 'total_capacity': task_worker_count or total_ram, + 'capacity_unit': 'units' if task_worker_count else 'memory-mib', + 'override_instance_stats': { + 'MemoryMiB': {'Min': task_ram[0], 'Max': task_ram[1]}, + 'VCpuCount': {'Min': task_cpu[0], 'Max': task_cpu[1]}, + 'SpotMaxPricePercentageOverLowestPrice': 100, + 'MemoryGiBPerVCpu': {'Min': ram2cpu_ratio[0], 'Max': ram2cpu_ratio[1]} + } + } + + return instance_details + + def create(config): """creates EC2 instances based on config @@ -656,10 +770,21 @@ def create(config): """ sys.excepthook = destroy_hook + profile = config.get('aws_profile') + region = config.get('region') + + set_boto_session(region, profile) + service = config.get('service') + task_list = ['single'] - if service == 'single': - search_and_create(config, 'single') if service == 'cluster': - search_and_create(config, 'cluster-master') - search_and_create(config, 'cluster-worker') + task_list = ['cluster-master', 'cluster-worker'] + + instance_details = get_instance_details(config, task_list) + + if not config.get('aws_az'): + config['aws_az'] = get_placement_az(config, instance_details[task_list[-1]]) + + for task in task_list: + search_and_create(config, task, instance_details[task]) diff --git a/src/forge/destroy.py b/src/forge/destroy.py index a1e53e6..d50a127 100755 --- a/src/forge/destroy.py +++ b/src/forge/destroy.py @@ -60,6 +60,7 @@ def pricing(detail, config, market): if dif > max_dif: max_dif = dif ec2_type = e['instance_type'] + config['aws_az'] = e['az'] total_cost = get_ec2_pricing(ec2_type, market, config) if total_cost > 0: diff --git a/src/forge/engine.py b/src/forge/engine.py index 5f1cf97..2da0366 100755 --- a/src/forge/engine.py +++ b/src/forge/engine.py @@ -1,11 +1,20 @@ """Run a command on remote EC2, rsync user content, and execute it.""" -from . import REQUIRED_ARGS -from .parser import add_basic_args, add_job_args, add_env_args, add_general_args, add_action_args -from .create import create +import logging +import time + +import boto3 + +from . import DEFAULT_ARG_VALS, REQUIRED_ARGS +from .exceptions import ExitHandlerException +from .parser import add_basic_args, add_job_args, add_env_args, add_general_args, add_action_args, nonnegative_int_arg +from .create import create, ec2_ip from .rsync import rsync from .run import run +logger = logging.getLogger(__name__) + + def cli_engine(subparsers): """add engine parser to subparsers @@ -21,6 +30,9 @@ def cli_engine(subparsers): add_general_args(parser) add_action_args(parser) + parser.add_argument('--spot_retries', '--spot-retries', type=nonnegative_int_arg) + parser.add_argument('--on_demand_failover', '--on-demand-failover', action='store_true', dest='market_failover') + REQUIRED_ARGS['engine'] = list(set(REQUIRED_ARGS['create'] + REQUIRED_ARGS['rsync'] + REQUIRED_ARGS['run'] + @@ -40,7 +52,52 @@ def engine(config): int The exit status for run with 0 for success """ - create(config) - rsync(config) - status = run(config) + status = 4 + + try: + create(config) + logger.info('Waiting for 60s to ensure EC2 has finished starting up...') + time.sleep(60) + status = rsync(config) + status = run(config) + except ExitHandlerException: + # Check for spot instances and retries + if 'spot' in config['market']: + + name = config.get('name') + date = config.get('date', '') + market = config.get('market', DEFAULT_ARG_VALS['market']) + service = config.get('service') + + n_list = [] + if service == "cluster": + n_list.append(f'{name}-{market[0]}-{service}-master-{date}') + n_list.append(f'{name}-{market[-1]}-{service}-worker-{date}') + elif service == "single": + n_list.append(f'{name}-{market[0]}-{service}-{date}') + + for n in n_list: + flag = False + details = ec2_ip(n, config) + + for ec2 in details: + if ec2['state'] != 'running': + flag = True + + if flag: + break + else: + logger.critical('Bubble received but all instances are ok.') + status = 3 + return status + + if config.get('spot_retries', 0) > 0: + config['spot_retries'] -= 1 + status = engine(config) + elif config.get('on_demand_failover') or config.get('market_failover'): + config['market'][0] = config['market'][-1] = 'on-demand' + status = engine(config) + else: + status = 5 + return status diff --git a/src/forge/exceptions.py b/src/forge/exceptions.py new file mode 100644 index 0000000..c507967 --- /dev/null +++ b/src/forge/exceptions.py @@ -0,0 +1,3 @@ +class ExitHandlerException(Exception): + """Raised when there's an exception for the ExitHandler""" + pass diff --git a/src/forge/main.py b/src/forge/main.py index 719ae23..067063d 100755 --- a/src/forge/main.py +++ b/src/forge/main.py @@ -83,7 +83,7 @@ def execute(config): elif job == 'destroy': destroy(config) elif job == 'rsync': - rsync(config) + status = rsync(config) elif job == 'run': status = run(config) elif job == 'engine': diff --git a/src/forge/parser.py b/src/forge/parser.py index fb72e41..185d8d8 100755 --- a/src/forge/parser.py +++ b/src/forge/parser.py @@ -126,7 +126,8 @@ def add_job_args(parser): common_grp.add_argument('--disk', type=positive_int_arg) common_grp.add_argument('--valid_time', '--valid-time', type=positive_int_arg) common_grp.add_argument('--user_data', '--user-data', nargs='*') - common_grp.add_argument('--gpu', action='store_true', dest='gpu_flag') + common_grp.add_argument('--gpu', action='store_true', dest='gpu_flag', default=None) + common_grp.add_argument('--destroy_on_create', '--destroy-on-create', action='store_true', default=None) def add_action_args(parser): diff --git a/src/forge/rsync.py b/src/forge/rsync.py index 3e5cd46..7077b7e 100755 --- a/src/forge/rsync.py +++ b/src/forge/rsync.py @@ -5,8 +5,10 @@ import sys from . import DEFAULT_ARG_VALS, REQUIRED_ARGS +from .destroy import destroy +from .exceptions import ExitHandlerException from .parser import add_basic_args, add_general_args, add_env_args, add_action_args -from .common import ec2_ip, key_file, get_ip +from .common import ec2_ip, key_file, get_ip, get_nlist, exit_callback logger = logging.getLogger(__name__) @@ -39,12 +41,15 @@ def rsync(config): ---------- config : dict Forge configuration data + + Returns + ------- + int + The status of the rsync commands """ - name = config.get('name') - date = config.get('date', '') - market = config.get('market', DEFAULT_ARG_VALS['market']) - service = config.get('service') - rr_all = config.get('rr_all') + + destroy_flag = config.get('destroy_after_failure') + rval = 0 def _rsync(config, ip): """performs the rsync to a given ip @@ -72,24 +77,19 @@ def _rsync(config, ip): sys.exit(1) cmd = 'rsync -rave "ssh -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no' - cmd +=f' -i {pem_path}" {rsync_loc} root@{ip}:/root/' + cmd += f' -i {pem_path}" {rsync_loc} root@{ip}:/root/' try: output = subprocess.check_output( cmd, stderr=subprocess.STDOUT, shell=True, universal_newlines=True ) + logger.info('Rsync successful:\n%s', output) + return 0 except subprocess.CalledProcessError as exc: logger.error('Rsync failed:\n%s', exc.output) - else: - logger.info('Rsync successful:\n%s', output) + return exc.returncode - n_list = [] - if service == "cluster": - n_list.append(f'{name}-{market[0]}-{service}-master-{date}') - if rr_all: - n_list.append(f'{name}-{market[-1]}-{service}-worker-{date}') - elif service == "single": - n_list.append(f'{name}-{market[0]}-{service}-{date}') + n_list = get_nlist(config) for n in n_list: try: @@ -103,6 +103,14 @@ def _rsync(config, ip): for ip, _ in targets: logger.info('Rsync destination is %s', ip) - _rsync(config, ip) - except Exception as e: + rval = _rsync(config, ip) + if rval: + raise ValueError('Rsync command unsuccessful, ending attempts.') + except ValueError as e: logger.error('Got error %s when trying to rsync.', e) + try: + exit_callback(config) + except ExitHandlerException: + raise + + return rval diff --git a/src/forge/run.py b/src/forge/run.py index 906015f..bd8f245 100755 --- a/src/forge/run.py +++ b/src/forge/run.py @@ -5,8 +5,9 @@ import sys from . import DEFAULT_ARG_VALS, REQUIRED_ARGS +from .exceptions import ExitHandlerException from .parser import add_basic_args, add_general_args, add_env_args, add_action_args -from .common import ec2_ip, key_file, get_ip, destroy_hook, user_accessible_vars, FormatEmpty +from .common import ec2_ip, key_file, get_ip, destroy_hook, user_accessible_vars, FormatEmpty, exit_callback, get_nlist from .destroy import destroy logger = logging.getLogger(__name__) @@ -76,7 +77,6 @@ def _run(config, ip): pem_secret = config['forge_pem_secret'] region = config['region'] profile = config.get('aws_profile') - env = config['forge_env'] with key_file(pem_secret, region, profile) as pem_path: fmt = FormatEmpty() @@ -93,13 +93,7 @@ def _run(config, ip): ) return exc.returncode - n_list = [] - if service == "cluster": - n_list.append(f'{name}-{market[0]}-{service}-master-{date}') - if rr_all: - n_list.append(f'{name}-{market[-1]}-{service}-worker-{date}') - elif service == "single": - n_list.append(f'{name}-{market[0]}-{service}-{date}') + n_list = get_nlist(config) for n in n_list: try: @@ -127,10 +121,13 @@ def _run(config, ip): raise ValueError('Run command unsuccessful, ending attempts.') except ValueError as e: logger.error('Run command raised error: %s', e) - if destroy_flag: - logger.info('destroy_after_failure parameter True, running forge destroy...') - destroy(config) - - break + try: + exit_callback(config) + except ExitHandlerException: + raise + finally: + if destroy_flag: + logger.info('destroy_after_failure parameter True, running forge destroy...') + destroy(config) return rval diff --git a/src/forge/start.py b/src/forge/start.py index 3b18f79..8b8ff41 100755 --- a/src/forge/start.py +++ b/src/forge/start.py @@ -6,7 +6,7 @@ from . import REQUIRED_ARGS from .parser import add_basic_args, add_general_args, add_env_args -from .common import ec2_ip, get_ip, set_boto_session +from .common import ec2_ip, get_ip, set_boto_session, get_nlist logger = logging.getLogger(__name__) @@ -75,25 +75,11 @@ def start(config): config : dict Forge configuration data """ - name = config['name'] - date = config.get('date', '') - service = config['service'] market = config.get('market') - n_list = [] - if service == "cluster": - if market[0] == 'spot': - logger.error('Master is a spot instance; you cannot start a spot instance') - elif market[0] == 'on-demand': - n_list.append(f'{name}-{market[0]}-{service}-master-{date}') - - if market[-1] == 'spot': - logger.error('Worker is a spot fleet; you cannot start a spot fleet') - elif market[-1] == 'on-demand': - n_list.append(f'{name}-{market[0]}-{service}-worker-{date}') - elif service == "single": - if market[0] == 'spot': - logger.error('The instance is a spot instance; you cannot start a spot instance') - elif market[0] == 'on-demand': - n_list.append(f'{name}-{market[0]}-{service}-{date}') + if 'spot' in market: + logger.error('Master or worker is a spot instance; you cannot start a spot instance') + # sys.exit(1) # ToDo: Should we change the tests to reflect an exit or allow it to continue? + + n_list = get_nlist({**config, 'rr_all': True}) start_fleet(n_list, config) diff --git a/src/forge/stop.py b/src/forge/stop.py index 7426a21..7888a40 100755 --- a/src/forge/stop.py +++ b/src/forge/stop.py @@ -6,7 +6,7 @@ from . import REQUIRED_ARGS from .parser import add_basic_args, add_general_args, add_env_args -from .common import ec2_ip, get_ip, set_boto_session +from .common import ec2_ip, get_ip, set_boto_session, get_nlist logger = logging.getLogger(__name__) @@ -70,25 +70,12 @@ def stop(config): config : dict Forge configuration data """ - name = config['name'] - date = config.get('date', '') - service = config['service'] market = config.get('market') - n_list = [] - if service == "cluster": - if market[0] == 'spot': - logger.error('Master is a spot instance; you cannot stop a spot instance') - elif market[0] == 'on-demand': - n_list.append(f'{name}-{market[0]}-{service}-master-{date}') - - if market[-1] == 'spot': - logger.error('Worker is a spot fleet; you cannot stop a spot fleet') - elif market[-1] == 'on-demand': - n_list.append(f'{name}-{market[0]}-{service}-worker-{date}') - elif service == "single": - if market[0] == 'spot': - logger.error('The instance is a spot instance; you cannot stop a spot instance') - elif market[0] == 'on-demand': - n_list.append(f'{name}-{market[0]}-{service}-{date}') + if 'spot' in market: + logger.error('Master or worker is a spot instance; you cannot stop a spot instance') + # sys.exit(1) # ToDo: Should we change the tests to reflect an exit or allow it to continue? + + n_list = get_nlist({**config, 'rr_all': True}) + stop_fleet(n_list, config) diff --git a/src/forge/yaml_loader.py b/src/forge/yaml_loader.py index b9cd30f..a68cbb3 100755 --- a/src/forge/yaml_loader.py +++ b/src/forge/yaml_loader.py @@ -135,17 +135,20 @@ def _get_type(x): Optional('cpu'): And(list, error='Invalid CPU cores'), Optional('destroy_after_success'): And(bool), Optional('destroy_after_failure'): And(bool), + Optional('destroy_on_create'): And(bool), Optional('disk'): And(Use(int), positive_int), Optional('disk_device_name'): And(str, len, error='Invalid Device Name'), Optional('forge_env'): And(str, len, error='Invalid Environment'), Optional('gpu_flag'): And(bool), Optional('market'): And(Or(str, list)), Optional('name'): And(str, len, error='Invalid Name'), + Optional('on_demand_failover'): And(bool), Optional('ratio'): And(list), Optional('ram'): And(list, error='Invalid RAM'), Optional('rsync_path'): And(str), Optional('run_cmd'): And(str, len, error='Invalid run_cmd'), Optional('service'): And(str, len, Or('single', 'cluster'), error='Invalid Service'), + Optional('spot_retries'): And(Use(int), positive_int), Optional('user_data'): And(list), Optional('valid_time'): And(Use(int), positive_int), Optional('workers'): And(Use(int), positive_int), @@ -189,7 +192,7 @@ def load_config(args): logger.info('Checking config file: %s', args['yaml']) logger.debug('Required User config is %s', user_config) - env = args['forge_env'] or user_config.get('forge_env') + env = args.get('forge_env') or user_config.get('forge_env') if env is None: logger.error("'forge_env' variable required.") @@ -216,12 +219,11 @@ def load_config(args): env_config.update(normalize_config(env_config)) check_keys(env_config['region'], env_config.get('aws_profile')) - additional_config_data = env_config.pop('additional_config', None) + additional_config_data = env_config.pop('additional_config', []) additional_config = [] - if additional_config_data: - for i in additional_config_data: - ADDITIONAL_KEYS.append(i['name']) - additional_config.append(i) + for i in additional_config_data: + ADDITIONAL_KEYS.append(i['name']) + additional_config.append(i) logger.debug('Additional config options: %s', additional_config) diff --git a/tests/data/admin_configs/dev/dev.yaml b/tests/data/admin_configs/dev/dev.yaml index 5d1ab06..b9de8f2 100755 --- a/tests/data/admin_configs/dev/dev.yaml +++ b/tests/data/admin_configs/dev/dev.yaml @@ -1,6 +1,6 @@ forge_env: dev aws_profile: data-dev -aws_az: us-east-1a +aws_region: us-east-1 ec2_amis: single: - ami-123 @@ -14,9 +14,9 @@ ec2_amis: - ami-789 - 90 - /dev/sda1 -aws_subnet: subnet-123 +aws_multi_az: + us-east-1a: subnet-123 ec2_key: bdp -aws_security_group: sg-123 forge_pem_secret: forge-pem excluded_ec2s: ["t2.medium","t2.large","m4.large", "*g.*", "gd.*", "*metal*", "g4ad*"] tags: diff --git a/tests/test_common.py b/tests/test_common.py index f7c8a5f..78e85eb 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -124,6 +124,7 @@ def test_get_ec2_pricing_ondemand(mock_regions, mock_boto): """Test getting on-demand EC2 hourly pricing.""" exp_price = 0.123 region = 'us-east-1' + az = 'us-east-1a' long_region = 'US East (N. Virginia)' response = {'PriceList': [json.dumps( {"terms": {"OnDemand": { @@ -136,7 +137,7 @@ def test_get_ec2_pricing_ondemand(mock_regions, mock_boto): mock_products.return_value = response mock_regions.return_value = {region: long_region} - config = {'region': region} + config = {'region': region, 'aws_az': az} ec2_type = 'r5.large' act_price = common.get_ec2_pricing(ec2_type, 'on-demand', config) assert act_price == exp_price diff --git a/tests/test_create.py b/tests/test_create.py index 83586b9..d4f9f90 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -10,21 +10,25 @@ from forge import create +@mock.patch('forge.create.get_instance_details') @mock.patch('forge.create.search_and_create') -def test_create_single(mock_search_create): +def test_create_single(mock_search_create, mock_get_instance_details): """Test entry-point for creation of single instance.""" - config = {'service': 'single'} + config = {'service': 'single', 'aws_az': 'us-east-1a'} + mock_get_instance_details.return_value = {'single': {}} create.create(config) - mock_search_create.assert_called_once_with(config, 'single') + mock_search_create.assert_called_once_with(config, 'single', {}) +@mock.patch('forge.create.get_instance_details') @mock.patch('forge.create.search_and_create') -def test_create_cluster_master_workers(mock_search_create): +def test_create_cluster_master_workers(mock_search_create, mock_get_instance_details): """Test entry-point for creation of cluster with master and workers.""" - config = {'service': 'cluster', 'ram': [512]} + config = {'service': 'cluster', 'aws_az': 'us-east-1a'} + mock_get_instance_details.return_value = {'cluster-master': {}, 'cluster-worker': {}} create.create(config) mock_search_create.assert_has_calls([ - mock.call(config, 'cluster-master'), mock.call(config, 'cluster-worker') + mock.call(config, 'cluster-master', {}), mock.call(config, 'cluster-worker', {}) ]) diff --git a/tests/test_destroy.py b/tests/test_destroy.py new file mode 100644 index 0000000..4f14bad --- /dev/null +++ b/tests/test_destroy.py @@ -0,0 +1,58 @@ +from unittest import mock + +from forge import destroy + +import pytest + + +@mock.patch("forge.destroy.find_and_destroy") +@pytest.mark.parametrize("service, market", [("single", ["spot"]), ("cluster", ["spot", "spot"])]) +def test_destroy(mock_find_and_destroy, service, market): + config = { + "name": "test-run", + "date": "2021-02-01", + "service": service, + "market": market + } + + destroy.destroy(config) + + if service == 'single': + n = f'{config["name"]}-{market[0]}-{service}-{config["date"]}' + mock_find_and_destroy.assert_called_once_with(n, config) + + if service == 'cluster': + n1 = f'{config["name"]}-{market[0]}-{service}-master-{config["date"]}' + n2 = f'{config["name"]}-{market[-1]}-{service}-worker-{config["date"]}' + assert mock_find_and_destroy.call_args_list == [((n1, config),), ((n2, config),)] + + +@mock.patch("forge.destroy.ec2_ip") +@mock.patch("forge.destroy.pricing") +@mock.patch("forge.destroy.fleet_destroy") +@pytest.mark.parametrize("service, market", [("single", ["spot"]), ("cluster", ["spot", "spot"])]) +def test_find_and_destroy(mock_fleet_destroy, mock_pricing, mock_ec2_ip, service, market): + ip = "123.456.789" + fleet_id = "abc-123" + ec2_details = [{"ip": ip, "spot_id": ["abc"], "state": None, "fleet_id": fleet_id}] + mock_ec2_ip.return_value = ec2_details + + config = { + "name": "test-run", + "date": "2021-02-01", + "service": service, + "market": market + } + + destroy.destroy(config) + + if service == 'single': + n1 = f'{config["name"]}-{market[0]}-{service}-{config["date"]}' + assert mock_fleet_destroy.call_args_list == [((n1, fleet_id, config),)] + assert mock_pricing.call_args_list == [((ec2_details, config, market[0]),)] + + if service == 'cluster': + n2 = f'{config["name"]}-{market[0]}-{service}-master-{config["date"]}' + n3 = f'{config["name"]}-{market[-1]}-{service}-worker-{config["date"]}' + assert mock_fleet_destroy.call_args_list == [((n2, fleet_id, config),), ((n3, fleet_id, config),)] + assert mock_pricing.call_args_list == [((ec2_details, config, market[0]),), ((ec2_details, config, market[1]),)] diff --git a/tests/test_main.py b/tests/test_main.py index 25bb53f..a9f1794 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -43,7 +43,7 @@ def _load_admin_cfg(env): 'home_dir': os.path.dirname(FORGE_DIR), 'yaml_dir': os.path.join(TEST_DIR, 'data'), 'user': 'test_user', 'ami': 'single_ami', 'forge_version': False, 'config_dir': os.path.join(TEST_CFG_DIR, 'dev'), 'region': 'us-east-1', 'destroy_after_success': True, 'destroy_after_failure': True, 'default_ratio': [8, 8], - 'valid_time': 8, 'ec2_max': 768}), + 'valid_time': 8, 'ec2_max': 768, 'spot_strategy': 'price-capacity-optimized'}), # Destroy job; passing the relative path to a yaml (['forge', 'destroy', '--yaml', os.path.join(TEST_DIR_REL, 'data', 'single_intermediate.yaml')], {'name': 'test-single-intermediate', 'log_level': 'INFO', @@ -53,7 +53,7 @@ def _load_admin_cfg(env): 'yaml_dir': os.path.join(TEST_DIR, 'data'), 'user': 'test_user', 'ami': 'single_ami', 'forge_version': False, 'config_dir': os.path.join(TEST_CFG_DIR, 'dev'), 'region': 'us-east-1', 'destroy_after_success': True, 'destroy_after_failure': True, 'default_ratio': [8, 8], 'valid_time': 8, - 'ec2_max': 768}), + 'ec2_max': 768, 'spot_strategy': 'price-capacity-optimized'}), # Destroy job; overriding log_level (['forge', 'destroy', '--yaml', os.path.join(TEST_DIR, 'data', 'single_intermediate.yaml'), '--log_level', 'debug'], {'name': 'test-single-intermediate', 'log_level': 'DEBUG', @@ -63,7 +63,7 @@ def _load_admin_cfg(env): 'home_dir': os.path.dirname(FORGE_DIR), 'yaml_dir': os.path.join(TEST_DIR, 'data'), 'user': 'test_user', 'ami': 'single_ami', 'forge_version': False, 'config_dir': os.path.join(TEST_CFG_DIR, 'dev'), 'region': 'us-east-1', 'destroy_after_success': True, 'destroy_after_failure': True, 'default_ratio': [8, 8], - 'valid_time': 8, 'ec2_max': 768}), + 'valid_time': 8, 'ec2_max': 768, 'spot_strategy': 'price-capacity-optimized'}), # Destroy job; overriding market (['forge', 'destroy', '--yaml', os.path.join(TEST_DIR, 'data', 'single_intermediate.yaml'), '--market', 'on-demand'], @@ -73,7 +73,8 @@ def _load_admin_cfg(env): 'gpu_flag': False, 'app_dir': TEST_DIR, 'src_dir': FORGE_DIR, 'home_dir': os.path.dirname(FORGE_DIR), 'yaml_dir': os.path.join(TEST_DIR, 'data'), 'user': 'test_user', 'ami': 'single_ami', 'forge_version': False, 'config_dir': os.path.join(TEST_CFG_DIR, 'dev'), 'region': 'us-east-1', 'destroy_after_success': True, - 'destroy_after_failure': True, 'default_ratio': [8, 8], 'valid_time': 8, 'ec2_max': 768}), + 'destroy_after_failure': True, 'default_ratio': [8, 8], 'valid_time': 8, 'ec2_max': 768, + 'spot_strategy': 'price-capacity-optimized'}), # Destroy job; no market (['forge', 'destroy', '--yaml', os.path.join(TEST_DIR, 'data', 'single_basic.yaml'), '--forge_env', 'dev'], {'name': 'test-single-basic', 'log_level': 'INFO', 'yaml': os.path.join(TEST_DIR, 'data', 'single_basic.yaml'), @@ -82,7 +83,7 @@ def _load_admin_cfg(env): 'home_dir': os.path.dirname(FORGE_DIR), 'yaml_dir': os.path.join(TEST_DIR, 'data'), 'user': 'test_user', 'ami': 'single_ami', 'forge_version': False, 'config_dir': os.path.join(TEST_CFG_DIR, 'dev'), 'region': 'us-east-1', 'destroy_after_success': True, 'destroy_after_failure': True, 'default_ratio': [8, 8], - 'valid_time': 8, 'ec2_max': 768}), + 'valid_time': 8, 'ec2_max': 768, 'spot_strategy': 'price-capacity-optimized'}), # Configure job (['forge', 'configure'], {'forge_version': False, 'job': 'configure', 'log_level': 'INFO'}), @@ -95,7 +96,7 @@ def _load_admin_cfg(env): 'home_dir': os.path.dirname(FORGE_DIR), 'yaml_dir': os.path.join(TEST_DIR, 'data'), 'user': 'test_user', 'ami': 'single_ami', 'forge_version': False, 'config_dir': os.path.join(TEST_CFG_DIR, 'dev'), 'region': 'us-east-1', 'destroy_after_success': True, 'destroy_after_failure': True, 'default_ratio': [8, 8], - 'valid_time': 8, 'ec2_max': 768}), + 'valid_time': 8, 'ec2_max': 768, 'spot_strategy': 'price-capacity-optimized'}), # Create job; setting gpu (['forge', 'create', '--yaml', os.path.join(TEST_DIR, 'data', 'single_basic.yaml'), '--forge_env', 'dev', '--gpu'], {'name': 'test-single-basic', 'log_level': 'INFO', 'yaml': os.path.join(TEST_DIR, 'data', 'single_basic.yaml'), @@ -104,7 +105,7 @@ def _load_admin_cfg(env): 'home_dir': os.path.dirname(FORGE_DIR), 'yaml_dir': os.path.join(TEST_DIR, 'data'), 'user': 'test_user', 'ami': 'single_ami', 'forge_version': False, 'config_dir': os.path.join(TEST_CFG_DIR, 'dev'), 'region': 'us-east-1', 'destroy_after_success': True, 'destroy_after_failure': True, 'default_ratio': [8, 8], - 'valid_time': 8, 'ec2_max': 768}), + 'valid_time': 8, 'ec2_max': 768, 'spot_strategy': 'price-capacity-optimized'}), ]) def test_forge_main(mock_pass, mock_execute, mock_keys, mock_config_dir, cli_call, exp_config, load_admin_cfg): """Test the config after calling forge via the command line.""" diff --git a/tests/test_rsync.py b/tests/test_rsync.py index c71dabe..34ebef2 100644 --- a/tests/test_rsync.py +++ b/tests/test_rsync.py @@ -167,6 +167,7 @@ def test_rsync_fail(mock_ec2_ip, mock_get_ip, mock_key_file, mock_sub_chk, 'region': 'us-east-1', 'aws_profile': 'dev', 'rsync_path': rsync_path, + 'job': 'rsync', } expected_cmd = 'rsync -rave "ssh -o UserKnownHostsFile=/dev/null -o' expected_cmd += f' StrictHostKeyChecking=no -i {key_path}" {rsync_path}/* root@{ip}:/root/' @@ -175,7 +176,7 @@ def test_rsync_fail(mock_ec2_ip, mock_get_ip, mock_key_file, mock_sub_chk, returncode=123, cmd=expected_cmd ) - rsync.rsync(config) + assert rsync.rsync(config) == 123 mock_ec2_ip.assert_called_once_with( f"{config['name']}-spot-single-{config['date']}", config diff --git a/tests/test_run.py b/tests/test_run.py index c5340ce..ed2769c 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -72,6 +72,7 @@ def test_run_error(mock_ec2_ip, mock_get_ip, mock_key_file, mock_sub_run, caplog 'aws_profile': 'dev', 'forge_env': 'test', 'run_cmd': 'dummy.sh dev test', + 'job': 'run', } expected_cmd = [ 'ssh', '-t', '-o', 'UserKnownHostsFile=/dev/null', '-o', diff --git a/tests/test_start.py b/tests/test_start.py index 9c0abd5..c3513ac 100644 --- a/tests/test_start.py +++ b/tests/test_start.py @@ -55,14 +55,8 @@ def test_start_error_in_spot_instance(mock_start_fleet, caplog, service, markets "service": service, } error_msg = "" - if service == "cluster": - if markets[0] == "spot": - error_msg = "Master is a spot instance; you cannot start a spot instance" - elif markets[1] == "spot": - error_msg = "Worker is a spot fleet; you cannot start a spot fleet" - else: - if markets[0] == "spot": - error_msg = "The instance is a spot instance; you cannot start a spot instance" + if 'spot' in markets: + error_msg = 'Master or worker is a spot instance; you cannot start a spot instance' with caplog.at_level(logging.ERROR): start.start(config) diff --git a/tests/test_yaml_loader.py b/tests/test_yaml_loader.py index 430f7d4..1e81171 100644 --- a/tests/test_yaml_loader.py +++ b/tests/test_yaml_loader.py @@ -164,6 +164,11 @@ def test_check_user_yaml_invalid(mock_exit, bad_config, error_msg, caplog): 'date': '2021-01-01', 'forge_env': 'dev'}, {'forge_env': 'dev', 'service': 'single', 'ram': [[64]], 'aws_role': 'forge-test_role-dev', 'run_cmd': 'dummy.sh dev test', 'gpu_flag': True}), + # Job with runtime override of gpu instance + ({'yaml': os.path.join(TEST_DIR, 'data', 'single_basic_gpu.yaml'), + 'date': '2021-01-01', 'forge_env': 'dev', 'gpu_flag': False}, + {'forge_env': 'dev', 'service': 'single', 'ram': [[64]], 'aws_role': 'forge-test_role-dev', + 'run_cmd': 'dummy.sh dev test', 'gpu_flag': False}), # Job with no runtime overrides and on-demand instance ({'yaml': os.path.join(TEST_DIR, 'data', 'single_basic_ondemand.yaml'), 'date': '2021-01-01', 'forge_env': 'dev'},