From d681fe620d2da89873f33d6a1514c7decc390a33 Mon Sep 17 00:00:00 2001 From: stxue1 <122345910+stxue1@users.noreply.github.com> Date: Mon, 29 Apr 2024 14:21:32 -0700 Subject: [PATCH] Replace all usage of boto2 with boto3 (#4868) * Take out boto2 from awsProvisioner.py * Add mypy stub file for s3 * Lazy import aws to avoid dependency if extra is not installed yet * Also lazy import in tests * Separate out wdl kubernetes test to avoid missing dependency * Add unittest main * Fix wdl CI to run separated tests * Fix typo in lookup * Update moto and remove leftover line in node.py * Remove all instances of boto * Fix issues with boto return types and grab attributes before deleting * Remove some unnecessary abstraction * Fix improperly types in ec2.py * Ensure UUID is a string for boto3 * No more boto * Remove comments * Move attribute initialization * Properly delete all attributes of the item * Move out pager and use pager for select to get around output limits * Turn getter into method * Remove comment in setup.py * Remove commented dead import * Remove stray boto import * Apply suggestions from code review Co-authored-by: Adam Novak * Rename, rearrange some code * Revert not passing Value's to attributes when deleting attributes in SDB * Fix missed changed var names * Change ordering of jobstorexists exception to fix improper output on exception --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Adam Novak --- Makefile | 2 +- requirements-aws.txt | 1 - setup.py | 11 - src/toil/__init__.py | 231 ---------- src/toil/batchSystems/awsBatch.py | 10 +- src/toil/jobStores/abstractJobStore.py | 2 +- src/toil/jobStores/aws/jobStore.py | 485 +++++++++++--------- src/toil/jobStores/aws/utils.py | 62 ++- src/toil/jobStores/conftest.py | 4 +- src/toil/lib/aws/__init__.py | 11 +- src/toil/lib/aws/session.py | 1 - src/toil/lib/aws/utils.py | 51 +- src/toil/lib/ec2.py | 29 +- src/toil/provisioners/aws/awsProvisioner.py | 49 +- src/toil/test/__init__.py | 7 +- src/toil/test/jobStores/jobStoreTest.py | 16 +- src/toil/test/provisioners/clusterTest.py | 6 +- 17 files changed, 428 insertions(+), 550 deletions(-) diff --git a/Makefile b/Makefile index a8c457b819..2355b7d080 100644 --- a/Makefile +++ b/Makefile @@ -135,7 +135,7 @@ clean_sdist: # Setting SET_OWNER_TAG will tag cloud resources so that UCSC's cloud murder bot won't kill them. test: check_venv check_build_reqs TOIL_OWNER_TAG="shared" \ - python -m pytest --durations=0 --strict-markers --log-level DEBUG --log-cli-level INFO -r s $(cov) -n $(threads) --dist loadscope $(tests) -m "$(marker)" + python -m pytest --durations=0 --strict-markers --log-level DEBUG --log-cli-level INFO -r s $(cov) -n $(threads) --dist loadscope $(tests) -m "$(marker)" --color=yes test_debug: check_venv check_build_reqs TOIL_OWNER_TAG="$(whoami)" \ diff --git a/requirements-aws.txt b/requirements-aws.txt index b5a76b21ee..e15b9fb959 100644 --- a/requirements-aws.txt +++ b/requirements-aws.txt @@ -1,4 +1,3 @@ -boto>=2.48.0, <3 boto3-stubs[s3,sdb,iam,sts,boto3,ec2,autoscaling]>=1.28.3.post2, <2 mypy-boto3-iam>=1.28.3.post2, <2 # Need to force .post1 to be replaced mypy-boto3-s3>=1.28.3.post2, <2 diff --git a/setup.py b/setup.py index 4430d9dc49..6847542799 100755 --- a/setup.py +++ b/setup.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# import imp import os import types from importlib.machinery import SourceFileLoader @@ -142,16 +141,6 @@ def import_version(): })) os.rename(f.name, 'src/toil/version.py') - # Unfortunately, we can't use a straight import here because that would also load the stuff - # defined in "src/toil/__init__.py" which imports modules from external dependencies that may - # yet to be installed when setup.py is invoked. - # - # This is also the reason we cannot switch from the "deprecated" imp library - # and use: - # from importlib.machinery import SourceFileLoader - # return SourceFileLoader('toil.version', path='src/toil/version.py').load_module() - # - # Because SourceFileLoader will error and load "src/toil/__init__.py" . loader = SourceFileLoader('toil.version', 'src/toil/version.py') mod = types.ModuleType(loader.name) loader.exec_module(mod) diff --git a/src/toil/__init__.py b/src/toil/__init__.py index e231f68646..375ca4b7ca 100644 --- a/src/toil/__init__.py +++ b/src/toil/__init__.py @@ -440,7 +440,6 @@ def logProcessContext(config: "Config") -> None: try: - from boto import provider from botocore.credentials import (JSONFileCache, RefreshableCredentials, create_credential_resolver) @@ -476,235 +475,5 @@ def str_to_datetime(s): datetime.datetime(1970, 1, 1, 0, 0) """ return datetime.strptime(s, datetime_format) - - - class BotoCredentialAdapter(provider.Provider): - """ - Boto 2 Adapter to use AWS credentials obtained via Boto 3's credential finding logic. - - This allows for automatic role assumption - respecting the Boto 3 config files, even when parts of the app still use - Boto 2. - - This class also handles caching credentials in multi-process environments - to avoid loads of processes swamping the EC2 metadata service. - """ - - # TODO: We take kwargs because new boto2 versions have an 'anon' - # argument and we want to be future proof - - def __init__(self, name, access_key=None, secret_key=None, - security_token=None, profile_name=None, **kwargs): - """Create a new BotoCredentialAdapter.""" - # TODO: We take kwargs because new boto2 versions have an 'anon' - # argument and we want to be future proof - - if (name == 'aws' or name is None) and access_key is None and not kwargs.get('anon', False): - # We are on AWS and we don't have credentials passed along and we aren't anonymous. - # We will backend into a boto3 resolver for getting credentials. - # Make sure to enable boto3's own caching, so we can share that - # cache with pure boto3 code elsewhere in Toil. - # Keep synced with toil.lib.aws.session.establish_boto3_session - self._boto3_resolver = create_credential_resolver(Session(profile=profile_name), cache=JSONFileCache()) - else: - # We will use the normal flow - self._boto3_resolver = None - - # Pass along all the arguments - super().__init__(name, access_key=access_key, - secret_key=secret_key, security_token=security_token, - profile_name=profile_name, **kwargs) - - def get_credentials(self, access_key=None, secret_key=None, security_token=None, profile_name=None): - """ - Make sure our credential fields are populated. - - Called by the base class constructor. - """ - if self._boto3_resolver is not None: - # Go get the credentials from the cache, or from boto3 if not cached. - # We need to be eager here; having the default None - # _credential_expiry_time makes the accessors never try to refresh. - self._obtain_credentials_from_cache_or_boto3() - else: - # We're not on AWS, or they passed a key, or we're anonymous. - # Use the normal route; our credentials shouldn't expire. 
- super().get_credentials(access_key=access_key, - secret_key=secret_key, security_token=security_token, - profile_name=profile_name) - - def _populate_keys_from_metadata_server(self): - """ - Hack to catch _credential_expiry_time being too soon and refresh the credentials. - - This override is misnamed; it's actually the only hook we have to catch - _credential_expiry_time being too soon and refresh the credentials. We - actually just go back and poke the cache to see if it feels like - getting us new credentials. - - Boto 2 hardcodes a refresh within 5 minutes of expiry: - https://github.com/boto/boto/blob/591911db1029f2fbb8ba1842bfcc514159b37b32/boto/provider.py#L247 - - Boto 3 wants to refresh 15 or 10 minutes before expiry: - https://github.com/boto/botocore/blob/8d3ea0e61473fba43774eb3c74e1b22995ee7370/botocore/credentials.py#L279 - - So if we ever want to refresh, Boto 3 wants to refresh too. - """ - # This should only happen if we have expiring credentials, which we should only get from boto3 - if self._boto3_resolver is None: - raise RuntimeError("The Boto3 resolver should not be None.") - - self._obtain_credentials_from_cache_or_boto3() - - @retry() - def _obtain_credentials_from_boto3(self): - """ - Fill our credential fields from Boto 3. - - We know the current cached credentials are not good, and that we - need to get them from Boto 3. Fill in our credential fields - (_access_key, _secret_key, _security_token, - _credential_expiry_time) from Boto 3. - """ - # We get a Credentials object - # - # or a RefreshableCredentials, or None on failure. - creds = self._boto3_resolver.load_credentials() - - if creds is None: - try: - resolvers = str(self._boto3_resolver.providers) - except: - resolvers = "(Resolvers unavailable)" - raise RuntimeError("Could not obtain AWS credentials from Boto3. Resolvers tried: " + resolvers) - - # Make sure the credentials actually has some credentials if it is lazy - creds.get_frozen_credentials() - - # Get when the credentials will expire, if ever - if isinstance(creds, RefreshableCredentials): - # Credentials may expire. - # Get a naive UTC datetime like boto 2 uses from the boto 3 time. - self._credential_expiry_time = creds._expiry_time.astimezone(timezone('UTC')).replace(tzinfo=None) - else: - # Credentials never expire - self._credential_expiry_time = None - - # Then, atomically get all the credentials bits. They may be newer than we think they are, but never older. - frozen = creds.get_frozen_credentials() - - # Copy them into us - self._access_key = frozen.access_key - self._secret_key = frozen.secret_key - self._security_token = frozen.token - - def _obtain_credentials_from_cache_or_boto3(self): - """ - Get the cached credentials. - - Or retrieve them from Boto 3 and cache them - (or wait for another cooperating process to do so) if they are missing - or not fresh enough. - """ - cache_path = '~/.cache/aws/cached_temporary_credentials' - path = os.path.expanduser(cache_path) - tmp_path = path + '.tmp' - while True: - log.debug('Attempting to read cached credentials from %s.', path) - try: - with open(path) as f: - content = f.read() - if content: - record = content.split('\n') - if len(record) != 4: - raise RuntimeError("Number of cached credentials is not 4.") - self._access_key = record[0] - self._secret_key = record[1] - self._security_token = record[2] - self._credential_expiry_time = str_to_datetime(record[3]) - else: - log.debug('%s is empty. 
Credentials are not temporary.', path) - self._obtain_credentials_from_boto3() - return - except OSError as e: - if e.errno == errno.ENOENT: - log.debug('Cached credentials are missing.') - dir_path = os.path.dirname(path) - if not os.path.exists(dir_path): - log.debug('Creating parent directory %s', dir_path) - try: - # A race would be ok at this point - os.makedirs(dir_path, exist_ok=True) - except OSError as e2: - if e2.errno == errno.EROFS: - # Sometimes we don't actually have write access to ~. - # We may be running in a non-writable Toil container. - # We should just go get our own credentials - log.debug('Cannot use the credentials cache because we are working on a read-only filesystem.') - self._obtain_credentials_from_boto3() - else: - raise - else: - raise - else: - if self._credentials_need_refresh(): - log.debug('Cached credentials are expired.') - else: - log.debug('Cached credentials exist and are still fresh.') - return - # We get here if credentials are missing or expired - log.debug('Racing to create %s.', tmp_path) - # Only one process, the winner, will succeed - try: - fd = os.open(tmp_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o600) - except OSError as e: - if e.errno == errno.EEXIST: - log.debug('Lost the race to create %s. Waiting on winner to remove it.', tmp_path) - while os.path.exists(tmp_path): - time.sleep(0.1) - log.debug('Winner removed %s. Trying from the top.', tmp_path) - else: - raise - else: - try: - log.debug('Won the race to create %s. Requesting credentials from backend.', tmp_path) - self._obtain_credentials_from_boto3() - except: - os.close(fd) - fd = None - log.debug('Failed to obtain credentials, removing %s.', tmp_path) - # This unblocks the losers. - os.unlink(tmp_path) - # Bail out. It's too likely to happen repeatedly - raise - else: - if self._credential_expiry_time is None: - os.close(fd) - fd = None - log.debug('Credentials are not temporary. Leaving %s empty and renaming it to %s.', - tmp_path, path) - # No need to actually cache permanent credentials, - # because we know we aren't getting them from the - # metadata server or by assuming a role. Those both - # give temporary credentials. - else: - log.debug('Writing credentials to %s.', tmp_path) - with os.fdopen(fd, 'w') as fh: - fd = None - fh.write('\n'.join([ - self._access_key, - self._secret_key, - self._security_token, - datetime_to_str(self._credential_expiry_time)])) - log.debug('Wrote credentials to %s. Renaming to %s.', tmp_path, path) - os.rename(tmp_path, path) - return - finally: - if fd is not None: - os.close(fd) - - - provider.Provider = BotoCredentialAdapter - except ImportError: pass diff --git a/src/toil/batchSystems/awsBatch.py b/src/toil/batchSystems/awsBatch.py index b0b564768a..29751f370d 100644 --- a/src/toil/batchSystems/awsBatch.py +++ b/src/toil/batchSystems/awsBatch.py @@ -36,7 +36,7 @@ from argparse import ArgumentParser, _ArgumentGroup from typing import Any, Dict, Iterator, List, Optional, Set, Union -from boto.exception import BotoServerError +from botocore.exceptions import ClientError from toil import applianceSelf from toil.batchSystems.abstractBatchSystem import (EXIT_STATUS_UNAVAILABLE_VALUE, @@ -376,7 +376,7 @@ def shutdown(self) -> None: # Get rid of the job definition we are using if we can. self._destroy_job_definition() - @retry(errors=[BotoServerError]) + @retry(errors=[ClientError]) def _try_terminate(self, aws_id: str) -> None: """ Internal function. Should not be called outside this class. 
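
Usage sketch, not part of the diff: the hunks above swap boto2's BotoServerError for botocore's ClientError in the @retry decorators of awsBatch.py. A minimal standalone illustration of the same pattern, assuming the AWS extra is installed and credentials are configured; the helper name cancel_batch_job and the us-west-2 region are invented for the example:

    from botocore.exceptions import ClientError
    from toil.lib.aws.session import establish_boto3_session
    from toil.lib.retry import retry

    # A boto3 Batch client; the region is chosen arbitrarily for illustration.
    batch_client = establish_boto3_session().client("batch", region_name="us-west-2")

    @retry(errors=[ClientError])
    def cancel_batch_job(job_id: str) -> None:
        # boto3/botocore surface server-side failures as ClientError, so the retry
        # decorator now lists ClientError where it previously listed BotoServerError.
        batch_client.terminate_job(jobId=job_id, reason="Killed by Toil")
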
@@ -392,7 +392,7 @@ def _try_terminate(self, aws_id: str) -> None: # Kill the AWS Batch job self.client.terminate_job(jobId=aws_id, reason='Killed by Toil') - @retry(errors=[BotoServerError]) + @retry(errors=[ClientError]) def _wait_until_stopped(self, aws_id: str) -> None: """ Internal function. Should not be called outside this class. @@ -418,7 +418,7 @@ def _wait_until_stopped(self, aws_id: str) -> None: logger.info('Waiting for killed job %s to stop', self.aws_id_to_bs_id.get(aws_id, aws_id)) time.sleep(2) - @retry(errors=[BotoServerError]) + @retry(errors=[ClientError]) def _get_or_create_job_definition(self) -> str: """ Internal function. Should not be called outside this class. @@ -482,7 +482,7 @@ def _get_or_create_job_definition(self) -> str: return self.job_definition - @retry(errors=[BotoServerError]) + @retry(errors=[ClientError]) def _destroy_job_definition(self) -> None: """ Internal function. Should not be called outside this class. diff --git a/src/toil/jobStores/abstractJobStore.py b/src/toil/jobStores/abstractJobStore.py index aff0ddfb18..3b3aead3dd 100644 --- a/src/toil/jobStores/abstractJobStore.py +++ b/src/toil/jobStores/abstractJobStore.py @@ -155,7 +155,7 @@ def __init__(self, locator: str, prefix: str): class JobStoreExistsException(LocatorException): """Indicates that the specified job store already exists.""" - def __init__(self, prefix: str, locator: str): + def __init__(self, locator: str, prefix: str): """ :param str locator: The location of the job store """ diff --git a/src/toil/jobStores/aws/jobStore.py b/src/toil/jobStores/aws/jobStore.py index 82f3c522ba..08800262f8 100644 --- a/src/toil/jobStores/aws/jobStore.py +++ b/src/toil/jobStores/aws/jobStore.py @@ -23,16 +23,16 @@ import uuid from contextlib import contextmanager from io import BytesIO -from typing import List, Optional, IO +from typing import List, Optional, IO, Dict, Union, Generator, Tuple, cast, TYPE_CHECKING from urllib.parse import ParseResult, parse_qs, urlencode, urlsplit, urlunsplit -import boto.s3.connection -import boto.sdb -from boto.exception import SDBResponseError from botocore.exceptions import ClientError +from mypy_boto3_sdb import SimpleDBClient +from mypy_boto3_sdb.type_defs import ReplaceableItemTypeDef, ReplaceableAttributeTypeDef, SelectResultTypeDef, ItemTypeDef, AttributeTypeDef, DeletableItemTypeDef, UpdateConditionTypeDef import toil.lib.encryption as encryption from toil.fileStores import FileID +from toil.job import Job, JobDescription from toil.jobStores.abstractJobStore import (AbstractJobStore, ConcurrentFileModificationException, JobStoreExistsException, @@ -43,7 +43,6 @@ ServerSideCopyProhibitedError, copyKeyMultipart, fileSizeAndTime, - monkeyPatchSdbConnection, no_such_sdb_domain, retry_sdb, sdb_unavailable, @@ -61,7 +60,7 @@ get_object_for_url, list_objects_for_url, retry_s3, - retryable_s3_errors) + retryable_s3_errors, boto3_pager, get_item_from_attributes) from toil.lib.compatibility import compat_bytes from toil.lib.ec2nodes import EC2Regions from toil.lib.exceptions import panic @@ -70,6 +69,9 @@ from toil.lib.objects import InnerClass from toil.lib.retry import get_error_code, get_error_status, retry +if TYPE_CHECKING: + from toil import Config + boto3_session = establish_boto3_session() s3_boto3_resource = boto3_session.resource('s3') s3_boto3_client = boto3_session.client('s3') @@ -85,6 +87,12 @@ class ChecksumError(Exception): """Raised when a download from AWS does not contain the correct data.""" +class DomainDoesNotExist(Exception): + """Raised 
when a domain that is expected to exist does not exist.""" + def __init__(self, domain_name): + super().__init__(f"Expected domain {domain_name} to exist!") + + class AWSJobStore(AbstractJobStore): """ A job store that uses Amazon's S3 for file storage and SimpleDB for storing job info and @@ -134,17 +142,17 @@ def __init__(self, locator: str, partSize: int = 50 << 20) -> None: logger.debug("Instantiating %s for region %s and name prefix '%s'", self.__class__, region, namePrefix) self.region = region - self.namePrefix = namePrefix - self.partSize = partSize - self.jobsDomain = None - self.filesDomain = None - self.filesBucket = None - self.db = self._connectSimpleDB() + self.name_prefix = namePrefix + self.part_size = partSize + self.jobs_domain_name: Optional[str] = None + self.files_domain_name: Optional[str] = None + self.files_bucket = None + self.db = boto3_session.client(service_name="sdb", region_name=region) self.s3_resource = boto3_session.resource('s3', region_name=self.region) self.s3_client = self.s3_resource.meta.client - def initialize(self, config): + def initialize(self, config: "Config") -> None: if self._registered: raise JobStoreExistsException(self.locator, "aws") self._registered = None @@ -159,36 +167,38 @@ def initialize(self, config): self._registered = True @property - def sseKeyPath(self): + def sseKeyPath(self) -> Optional[str]: return self.config.sseKey - def resume(self): + def resume(self) -> None: if not self._registered: raise NoSuchJobStoreException(self.locator, "aws") self._bind(create=False) super().resume() - def _bind(self, create=False, block=True, check_versioning_consistency=True): + def _bind(self, create: bool = False, block: bool = True, check_versioning_consistency: bool = True) -> None: def qualify(name): assert len(name) <= self.maxNameLen - return self.namePrefix + self.nameSeparator + name + return self.name_prefix + self.nameSeparator + name # The order in which this sequence of events happens is important. We can easily handle the # inability to bind a domain, but it is a little harder to handle some cases of binding the # jobstore bucket. Maintaining this order allows for an easier `destroy` method. - if self.jobsDomain is None: - self.jobsDomain = self._bindDomain(qualify('jobs'), create=create, block=block) - if self.filesDomain is None: - self.filesDomain = self._bindDomain(qualify('files'), create=create, block=block) - if self.filesBucket is None: - self.filesBucket = self._bindBucket(qualify('files'), - create=create, - block=block, - versioning=True, - check_versioning_consistency=check_versioning_consistency) + if self.jobs_domain_name is None: + self.jobs_domain_name = qualify("jobs") + self._bindDomain(self.jobs_domain_name, create=create, block=block) + if self.files_domain_name is None: + self.files_domain_name = qualify("files") + self._bindDomain(self.files_domain_name, create=create, block=block) + if self.files_bucket is None: + self.files_bucket = self._bindBucket(qualify('files'), + create=create, + block=block, + versioning=True, + check_versioning_consistency=check_versioning_consistency) @property - def _registered(self): + def _registered(self) -> Optional[bool]: """ A optional boolean property indicating whether this job store is registered. The registry is the authority on deciding if a job store exists or not. 
If True, this job @@ -205,55 +215,60 @@ def _registered(self): # store destruction, indicates a job store in transition, reflecting the fact that 3.3.0 # may leak buckets or domains even though the registry reports 'False' for them. We # can't handle job stores that were partially created by 3.3.0, though. - registry_domain = self._bindDomain(domain_name='toil-registry', - create=False, - block=False) - if registry_domain is None: + registry_domain_name = "toil-registry" + try: + self._bindDomain(domain_name=registry_domain_name, + create=False, + block=False) + except DomainDoesNotExist: return False - else: - for attempt in retry_sdb(): - with attempt: - attributes = registry_domain.get_attributes(item_name=self.namePrefix, - attribute_name='exists', - consistent_read=True) - try: - exists = attributes['exists'] - except KeyError: - return False - else: - if exists == 'True': - return True - elif exists == 'False': - return None - else: - assert False + + for attempt in retry_sdb(): + with attempt: + get_result = self.db.get_attributes(DomainName=registry_domain_name, + ItemName=self.name_prefix, + AttributeNames=['exists'], + ConsistentRead=True) + attributes: List[AttributeTypeDef] = get_result.get("Attributes", []) # the documentation says 'Attributes' should always exist, but this is not true + exists: Optional[str] = get_item_from_attributes(attributes=attributes, name="exists") + if exists is None: + return False + elif exists == 'True': + return True + elif exists == 'False': + return None + else: + assert False @_registered.setter - def _registered(self, value): - - registry_domain = self._bindDomain(domain_name='toil-registry', - # Only create registry domain when registering or - # transitioning a store - create=value is not False, - block=False) - if registry_domain is None and value is False: + def _registered(self, value: bool) -> None: + registry_domain_name = "toil-registry" + try: + self._bindDomain(domain_name=registry_domain_name, + # Only create registry domain when registering or + # transitioning a store + create=value is not False, + block=False) + except DomainDoesNotExist: pass else: for attempt in retry_sdb(): with attempt: if value is False: - registry_domain.delete_attributes(item_name=self.namePrefix) + self.db.delete_attributes(DomainName=registry_domain_name, + ItemName=self.name_prefix) else: if value is True: - attributes = dict(exists='True') + attributes: List[ReplaceableAttributeTypeDef] = [{"Name": "exists", "Value": "True", "Replace": True}] elif value is None: - attributes = dict(exists='False') + attributes = [{"Name": "exists", "Value": "False", "Replace": True}] else: assert False - registry_domain.put_attributes(item_name=self.namePrefix, - attributes=attributes) + self.db.put_attributes(DomainName=registry_domain_name, + ItemName=self.name_prefix, + Attributes=attributes) - def _checkItem(self, item, enforce: bool = True): + def _checkItem(self, item: ItemTypeDef, enforce: bool = True) -> None: """ Make sure that the given SimpleDB item actually has the attributes we think it should. @@ -261,32 +276,48 @@ def _checkItem(self, item, enforce: bool = True): If enforce is false, log but don't throw. 
""" + self._checkAttributes(item["Attributes"], enforce) - if "overlargeID" not in item: + def _checkAttributes(self, attributes: List[AttributeTypeDef], enforce: bool = True) -> None: + if get_item_from_attributes(attributes=attributes, name="overlargeID") is None: logger.error("overlargeID attribute isn't present: either SimpleDB entry is " - "corrupt or jobstore is from an extremely old Toil: %s", item) + "corrupt or jobstore is from an extremely old Toil: %s", attributes) if enforce: raise RuntimeError("encountered SimpleDB entry missing required attribute " "'overlargeID'; is your job store ancient?") - def _awsJobFromItem(self, item): - self._checkItem(item) - if item.get("overlargeID", None): - assert self.file_exists(item["overlargeID"]) + def _awsJobFromAttributes(self, attributes: List[AttributeTypeDef]) -> Job: + """ + Get a Toil Job object from attributes that are defined in an item from the DB + :param attributes: List of attributes + :return: Toil job + """ + self._checkAttributes(attributes) + overlarge_id_value = get_item_from_attributes(attributes=attributes, name="overlargeID") + if overlarge_id_value: + assert self.file_exists(overlarge_id_value) # This is an overlarge job, download the actual attributes # from the file store logger.debug("Loading overlarge job from S3.") - with self.read_file_stream(item["overlargeID"]) as fh: + with self.read_file_stream(overlarge_id_value) as fh: binary = fh.read() else: - binary, _ = SDBHelper.attributesToBinary(item) + binary, _ = SDBHelper.attributesToBinary(attributes) assert binary is not None job = pickle.loads(binary) if job is not None: job.assignConfig(self.config) return job - def _awsJobToItem(self, job): + def _awsJobFromItem(self, item: ItemTypeDef) -> Job: + """ + Get a Toil Job object from an item from the DB + :param item: ItemTypeDef + :return: Toil Job + """ + return self._awsJobFromAttributes(item["Attributes"]) + + def _awsJobToAttributes(self, job: JobDescription) -> List[AttributeTypeDef]: binary = pickle.dumps(job, protocol=pickle.HIGHEST_PROTOCOL) if len(binary) > SDBHelper.maxBinarySize(extraReservedChunks=1): # Store as an overlarge job in S3 @@ -297,65 +328,82 @@ def _awsJobToItem(self, job): else: item = SDBHelper.binaryToAttributes(binary) item["overlargeID"] = "" - return item + return SDBHelper.attributeDictToList(item) + + def _awsJobToItem(self, job: JobDescription, name: str) -> ItemTypeDef: + return {"Name": name, "Attributes": self._awsJobToAttributes(job)} jobsPerBatchInsert = 25 @contextmanager - def batch(self): + def batch(self) -> None: self._batchedUpdates = [] yield batches = [self._batchedUpdates[i:i + self.jobsPerBatchInsert] for i in range(0, len(self._batchedUpdates), self.jobsPerBatchInsert)] for batch in batches: + items: List[ReplaceableItemTypeDef] = [] for jobDescription in batch: + item_attributes: List[ReplaceableAttributeTypeDef] = [] jobDescription.pre_update_hook() - items = {compat_bytes(jobDescription.jobStoreID): self._awsJobToItem(jobDescription) for jobDescription in batch} + item_name = compat_bytes(jobDescription.jobStoreID) + got_job_attributes: List[AttributeTypeDef] = self._awsJobToAttributes(jobDescription) + for each_attribute in got_job_attributes: + new_attribute: ReplaceableAttributeTypeDef = {"Name": each_attribute["Name"], + "Value": each_attribute["Value"], + "Replace": True} + item_attributes.append(new_attribute) + items.append({"Name": item_name, + "Attributes": item_attributes}) + for attempt in retry_sdb(): with attempt: - assert 
self.jobsDomain.batch_put_attributes(items) + self.db.batch_put_attributes(DomainName=self.jobs_domain_name, Items=items) self._batchedUpdates = None - def assign_job_id(self, job_description): + def assign_job_id(self, job_description: JobDescription) -> None: jobStoreID = self._new_job_id() logger.debug("Assigning ID to job %s", jobStoreID) job_description.jobStoreID = jobStoreID - def create_job(self, job_description): + def create_job(self, job_description: JobDescription) -> JobDescription: if hasattr(self, "_batchedUpdates") and self._batchedUpdates is not None: self._batchedUpdates.append(job_description) else: self.update_job(job_description) return job_description - def job_exists(self, job_id): + def job_exists(self, job_id: Union[bytes, str]) -> bool: for attempt in retry_sdb(): with attempt: - return bool(self.jobsDomain.get_attributes( - item_name=compat_bytes(job_id), - attribute_name=[SDBHelper.presenceIndicator()], - consistent_read=True)) + return len(self.db.get_attributes(DomainName=self.jobs_domain_name, + ItemName=compat_bytes(job_id), + AttributeNames=[SDBHelper.presenceIndicator()], + ConsistentRead=True).get("Attributes", [])) > 0 - def jobs(self): - result = None + def jobs(self) -> Generator[Job, None, None]: + job_items: Optional[List[ItemTypeDef]] = None for attempt in retry_sdb(): with attempt: - result = list(self.jobsDomain.select( - consistent_read=True, - query="select * from `%s`" % self.jobsDomain.name)) - assert result is not None - for jobItem in result: + job_items = boto3_pager(self.db.select, + "Items", + ConsistentRead=True, + SelectExpression="select * from `%s`" % self.jobs_domain_name) + assert job_items is not None + for jobItem in job_items: yield self._awsJobFromItem(jobItem) - def load_job(self, job_id): - item = None + def load_job(self, job_id: FileID) -> Job: + item_attributes = None for attempt in retry_sdb(): with attempt: - item = self.jobsDomain.get_attributes(compat_bytes(job_id), consistent_read=True) - if not item: + item_attributes = self.db.get_attributes(DomainName=self.jobs_domain_name, + ItemName=compat_bytes(job_id), + ConsistentRead=True).get("Attributes", []) + if not item_attributes: raise NoSuchJobException(job_id) - job = self._awsJobFromItem(item) + job = self._awsJobFromAttributes(item_attributes) if job is None: raise NoSuchJobException(job_id) logger.debug("Loaded job %s", job_id) @@ -364,10 +412,12 @@ def load_job(self, job_id): def update_job(self, job_description): logger.debug("Updating job %s", job_description.jobStoreID) job_description.pre_update_hook() - item = self._awsJobToItem(job_description) + job_attributes = self._awsJobToAttributes(job_description) + update_attributes: List[ReplaceableAttributeTypeDef] = [{"Name": attribute["Name"], "Value": attribute["Value"], "Replace": True} + for attribute in job_attributes] for attempt in retry_sdb(): with attempt: - assert self.jobsDomain.put_attributes(compat_bytes(job_description.jobStoreID), item) + self.db.put_attributes(DomainName=self.jobs_domain_name, ItemName=compat_bytes(job_description.jobStoreID), Attributes=update_attributes) itemsPerBatchDelete = 25 @@ -376,49 +426,53 @@ def delete_job(self, job_id): logger.debug("Deleting job %s", job_id) # If the job is overlarge, delete its file from the filestore - item = None for attempt in retry_sdb(): with attempt: - item = self.jobsDomain.get_attributes(compat_bytes(job_id), consistent_read=True) + attributes = self.db.get_attributes(DomainName=self.jobs_domain_name, + ItemName=compat_bytes(job_id), + 
ConsistentRead=True).get("Attributes", []) # If the overlargeID has fallen off, maybe we partially deleted the # attributes of the item? Or raced on it? Or hit SimpleDB being merely # eventually consistent? We should still be able to get rid of it. - self._checkItem(item, enforce = False) - if item.get("overlargeID", None): + self._checkAttributes(attributes, enforce=False) + overlarge_id_value = get_item_from_attributes(attributes=attributes, name="overlargeID") + if overlarge_id_value: logger.debug("Deleting job from filestore") - self.delete_file(item["overlargeID"]) + self.delete_file(overlarge_id_value) for attempt in retry_sdb(): with attempt: - self.jobsDomain.delete_attributes(item_name=compat_bytes(job_id)) - items = None + self.db.delete_attributes(DomainName=self.jobs_domain_name, ItemName=compat_bytes(job_id)) + items: Optional[List[ItemTypeDef]] = None for attempt in retry_sdb(): with attempt: - items = list(self.filesDomain.select( - consistent_read=True, - query=f"select version from `{self.filesDomain.name}` where ownerID='{job_id}'")) + items = list(boto3_pager(self.db.select, + "Items", + ConsistentRead=True, + SelectExpression=f"select version from `{self.files_domain_name}` where ownerID='{job_id}'")) assert items is not None if items: logger.debug("Deleting %d file(s) associated with job %s", len(items), job_id) n = self.itemsPerBatchDelete batches = [items[i:i + n] for i in range(0, len(items), n)] for batch in batches: - itemsDict = {item.name: None for item in batch} + delete_items: List[DeletableItemTypeDef] = [{"Name": item["Name"]} for item in batch] for attempt in retry_sdb(): with attempt: - self.filesDomain.batch_delete_attributes(itemsDict) + self.db.batch_delete_attributes(DomainName=self.files_domain_name, Items=delete_items) for item in items: - version = item.get('version') + item: ItemTypeDef + version = get_item_from_attributes(attributes=item["Attributes"], name="version") for attempt in retry_s3(): with attempt: if version: - self.s3_client.delete_object(Bucket=self.filesBucket.name, - Key=compat_bytes(item.name), + self.s3_client.delete_object(Bucket=self.files_bucket.name, + Key=compat_bytes(item["Name"]), VersionId=version) else: - self.s3_client.delete_object(Bucket=self.filesBucket.name, - Key=compat_bytes(item.name)) + self.s3_client.delete_object(Bucket=self.files_bucket.name, + Key=compat_bytes(item["Name"])) - def get_empty_file_store_id(self, jobStoreID=None, cleanup=False, basename=None): + def get_empty_file_store_id(self, jobStoreID=None, cleanup=False, basename=None) -> FileID: info = self.FileInfo.create(jobStoreID if cleanup else None) with info.uploadStream() as _: # Empty @@ -427,7 +481,8 @@ def get_empty_file_store_id(self, jobStoreID=None, cleanup=False, basename=None) logger.debug("Created %r.", info) return info.fileID - def _import_file(self, otherCls, uri, shared_file_name=None, hardlink=False, symlink=True): + def _import_file(self, otherCls, uri: ParseResult, shared_file_name: Optional[str] = None, + hardlink: bool = False, symlink: bool = True) -> Optional[FileID]: try: if issubclass(otherCls, AWSJobStore): srcObj = get_object_for_url(uri, existing=True) @@ -450,7 +505,7 @@ def _import_file(self, otherCls, uri, shared_file_name=None, hardlink=False, sym # copy if exception return super()._import_file(otherCls, uri, shared_file_name=shared_file_name) - def _export_file(self, otherCls, file_id, uri): + def _export_file(self, otherCls, file_id: FileID, uri: ParseResult) -> None: try: if issubclass(otherCls, AWSJobStore): dstObj 
= get_object_for_url(uri) @@ -474,11 +529,11 @@ def _url_exists(cls, url: ParseResult) -> bool: return cls._get_is_directory(url) @classmethod - def _get_size(cls, url): + def _get_size(cls, url: ParseResult) -> int: return get_object_for_url(url, existing=True).content_length @classmethod - def _read_from_url(cls, url, writable): + def _read_from_url(cls, url: ParseResult, writable): srcObj = get_object_for_url(url, existing=True) srcObj.download_fileobj(writable) return ( @@ -496,7 +551,7 @@ def _open_url(cls, url: ParseResult) -> IO[bytes]: return response['Body'] @classmethod - def _write_to_url(cls, readable, url, executable=False): + def _write_to_url(cls, readable, url: ParseResult, executable: bool = False) -> None: dstObj = get_object_for_url(url) logger.debug("Uploading %s", dstObj.key) @@ -518,10 +573,10 @@ def _get_is_directory(cls, url: ParseResult) -> bool: return len(list_objects_for_url(url)) > 0 @classmethod - def _supports_url(cls, url, export=False): + def _supports_url(cls, url: ParseResult, export: bool = False) -> bool: return url.scheme.lower() == 's3' - def write_file(self, local_path, job_id=None, cleanup=False): + def write_file(self, local_path: FileID, job_id: Optional[FileID] = None, cleanup: bool = False) -> FileID: info = self.FileInfo.create(job_id if cleanup else None) info.upload(local_path, not self.config.disableJobStoreChecksumVerification) info.save() @@ -529,7 +584,7 @@ def write_file(self, local_path, job_id=None, cleanup=False): return info.fileID @contextmanager - def write_file_stream(self, job_id=None, cleanup=False, basename=None, encoding=None, errors=None): + def write_file_stream(self, job_id: Optional[FileID] = None, cleanup: bool = False, basename=None, encoding=None, errors=None): info = self.FileInfo.create(job_id if cleanup else None) with info.uploadStream(encoding=encoding, errors=errors) as writable: yield writable, info.fileID @@ -613,7 +668,7 @@ def read_logs(self, callback, read_all=False): itemsProcessed = 0 for info in self._read_logs(callback, self.statsFileOwnerID): - info._ownerID = self.readStatsFileOwnerID + info._ownerID = str(self.readStatsFileOwnerID) # boto3 requires strings info.save() itemsProcessed += 1 @@ -627,10 +682,10 @@ def _read_logs(self, callback, ownerId): items = None for attempt in retry_sdb(): with attempt: - items = list(self.filesDomain.select( - consistent_read=True, - query="select * from `{}` where ownerID='{}'".format( - self.filesDomain.name, str(ownerId)))) + items = boto3_pager(self.db.select, + "Items", + ConsistentRead=True, + SelectExpression="select * from `{}` where ownerID='{}'".format(self.files_domain_name, str(ownerId))) assert items is not None for item in items: info = self.FileInfo.fromItem(item) @@ -647,10 +702,10 @@ def get_public_url(self, jobStoreFileID): with info.uploadStream(allowInlining=False) as f: f.write(info.content) - self.filesBucket.Object(compat_bytes(jobStoreFileID)).Acl().put(ACL='public-read') + self.files_bucket.Object(compat_bytes(jobStoreFileID)).Acl().put(ACL='public-read') url = self.s3_client.generate_presigned_url('get_object', - Params={'Bucket': self.filesBucket.name, + Params={'Bucket': self.files_bucket.name, 'Key': compat_bytes(jobStoreFileID), 'VersionId': info.version}, ExpiresIn=self.publicUrlExpiration.total_seconds()) @@ -675,16 +730,6 @@ def get_shared_public_url(self, shared_file_name): self._requireValidSharedFileName(shared_file_name) return self.get_public_url(self._shared_file_id(shared_file_name)) - def _connectSimpleDB(self): - """ - :rtype: 
SDBConnection - """ - db = boto.sdb.connect_to_region(self.region) - if db is None: - raise ValueError("Could not connect to SimpleDB. Make sure '%s' is a valid SimpleDB region." % self.region) - monkeyPatchSdbConnection(db) - return db - def _bindBucket(self, bucket_name: str, create: bool = False, @@ -716,7 +761,7 @@ def bucket_retry_predicate(error): """ if (isinstance(error, ClientError) and - get_error_status(error) in (404, 409)): + get_error_status(error) in (404, 409)): # Handle cases where the bucket creation is in a weird state that might let us proceed. # https://github.com/BD2KGenomics/toil/issues/955 # https://github.com/BD2KGenomics/toil/issues/995 @@ -759,7 +804,7 @@ def bucket_retry_predicate(error): # NoSuchBucket. We let that kick us back up to the # main retry loop. assert ( - get_bucket_region(bucket_name) == self.region + get_bucket_region(bucket_name) == self.region ), f"bucket_name: {bucket_name}, {get_bucket_region(bucket_name)} != {self.region}" tags = build_tag_dict_from_env() @@ -814,8 +859,10 @@ def bucket_retry_predicate(error): return bucket - def _bindDomain(self, domain_name, create=False, block=True): + def _bindDomain(self, domain_name: str, create: bool = False, block: bool = True) -> None: """ + Return the Boto3 domain name representing the SDB domain. When create=True, it will + create the domain if it does not exist. Return the Boto Domain object representing the SDB domain of the given name. If the domain does not exist and `create` is True, it will be created. @@ -823,11 +870,11 @@ def _bindDomain(self, domain_name, create=False, block=True): :param bool create: True if domain should be created if it doesn't exist - :param bool block: If False, return None if the domain doesn't exist. If True, wait until + :param bool block: If False, raise DomainDoesNotExist if the domain doesn't exist. If True, wait until domain appears. This parameter is ignored if create is True. - :rtype: Domain|None - :raises SDBResponseError: If `block` is True and the domain still doesn't exist after the + :rtype: None + :raises ClientError: If `block` is True and the domain still doesn't exist after the retry timeout expires. 
""" logger.debug("Binding to job store domain '%s'.", domain_name) @@ -837,15 +884,17 @@ def _bindDomain(self, domain_name, create=False, block=True): for attempt in retry_sdb(**retryargs): with attempt: try: - return self.db.get_domain(domain_name) - except SDBResponseError as e: + self.db.domain_metadata(DomainName=domain_name) + return + except ClientError as e: if no_such_sdb_domain(e): if create: - return self.db.create_domain(domain_name) + self.db.create_domain(DomainName=domain_name) + return elif block: raise else: - return None + raise DomainDoesNotExist(domain_name) else: raise @@ -957,7 +1006,7 @@ def content(self, content): self.version = '' @classmethod - def create(cls, ownerID): + def create(cls, ownerID: str): return cls(str(uuid.uuid4()), ownerID, encrypted=cls.outer.sseKeyPath is not None) @classmethod @@ -968,18 +1017,22 @@ def presenceIndicator(cls): def exists(cls, jobStoreFileID): for attempt in retry_sdb(): with attempt: - return bool(cls.outer.filesDomain.get_attributes( - item_name=compat_bytes(jobStoreFileID), - attribute_name=[cls.presenceIndicator()], - consistent_read=True)) + return bool(cls.outer.db.get_attributes(DomainName=cls.outer.files_domain_name, + ItemName=compat_bytes(jobStoreFileID), + AttributeNames=[cls.presenceIndicator()], + ConsistentRead=True).get("Attributes", [])) @classmethod def load(cls, jobStoreFileID): for attempt in retry_sdb(): with attempt: self = cls.fromItem( - cls.outer.filesDomain.get_attributes(item_name=compat_bytes(jobStoreFileID), - consistent_read=True)) + { + "Name": compat_bytes(jobStoreFileID), + "Attributes": cls.outer.db.get_attributes(DomainName=cls.outer.files_domain_name, + ItemName=compat_bytes(jobStoreFileID), + ConsistentRead=True).get("Attributes", []) + }) return self @classmethod @@ -1009,7 +1062,7 @@ def loadOrFail(cls, jobStoreFileID, customName=None): return self @classmethod - def fromItem(cls, item): + def fromItem(cls, item: ItemTypeDef): """ Convert an SDB item to an instance of this class. @@ -1022,31 +1075,26 @@ def strOrNone(s): return s if s is None else str(s) # ownerID and encrypted are the only mandatory attributes - ownerID = strOrNone(item.get('ownerID')) - encrypted = item.get('encrypted') + ownerID, encrypted, version, checksum = SDBHelper.get_attributes_from_item(item, ["ownerID", "encrypted", "version", "checksum"]) if ownerID is None: assert encrypted is None return None else: - version = strOrNone(item['version']) - checksum = strOrNone(item.get('checksum')) encrypted = strict_bool(encrypted) - content, numContentChunks = cls.attributesToBinary(item) + content, numContentChunks = cls.attributesToBinary(item["Attributes"]) if encrypted: sseKeyPath = cls.outer.sseKeyPath if sseKeyPath is None: raise AssertionError('Content is encrypted but no key was provided.') if content is not None: content = encryption.decrypt(content, sseKeyPath) - self = cls(fileID=item.name, ownerID=ownerID, encrypted=encrypted, version=version, + self = cls(fileID=item["Name"], ownerID=ownerID, encrypted=encrypted, version=version, content=content, numContentChunks=numContentChunks, checksum=checksum) return self - def toItem(self): + def toItem(self) -> Tuple[Dict[str, str], int]: """ - Convert this instance to an attribute dictionary suitable for SDB put_attributes(). - - :rtype: (dict,int) + Convert this instance to a dictionary of attribute names to values :return: the attributes dict and an integer specifying the the number of chunk attributes in the dictionary that are used for storing inlined content. 
@@ -1060,9 +1108,9 @@ def toItem(self): content = encryption.encrypt(content, sseKeyPath) assert content is None or isinstance(content, bytes) attributes = self.binaryToAttributes(content) - numChunks = attributes['numChunks'] - attributes.update(dict(ownerID=self.ownerID, - encrypted=self.encrypted, + numChunks = int(attributes['numChunks']) + attributes.update(dict(ownerID=self.ownerID or '', + encrypted=str(self.encrypted), version=self.version or '', checksum=self.checksum or '')) return attributes, numChunks @@ -1077,32 +1125,47 @@ def maxInlinedSize(): def save(self): attributes, numNewContentChunks = self.toItem() + attributes_boto3 = SDBHelper.attributeDictToList(attributes) # False stands for absence - expected = ['version', False if self.previousVersion is None else self.previousVersion] + if self.previousVersion is None: + expected: UpdateConditionTypeDef = {"Name": 'version', "Exists": False} + else: + expected = {"Name": 'version', "Value": cast(str, self.previousVersion)} try: for attempt in retry_sdb(): with attempt: - assert self.outer.filesDomain.put_attributes(item_name=compat_bytes(self.fileID), - attributes=attributes, - expected_value=expected) + self.outer.db.put_attributes(DomainName=self.outer.files_domain_name, + ItemName=compat_bytes(self.fileID), + Attributes=[{"Name": attribute["Name"], "Value": attribute["Value"], "Replace": True} + for attribute in attributes_boto3], + Expected=expected) # clean up the old version of the file if necessary and safe if self.previousVersion and (self.previousVersion != self.version): for attempt in retry_s3(): with attempt: - self.outer.s3_client.delete_object(Bucket=self.outer.filesBucket.name, + self.outer.s3_client.delete_object(Bucket=self.outer.files_bucket.name, Key=compat_bytes(self.fileID), VersionId=self.previousVersion) self._previousVersion = self._version if numNewContentChunks < self._numContentChunks: residualChunks = range(numNewContentChunks, self._numContentChunks) - attributes = [self._chunkName(i) for i in residualChunks] + residual_chunk_names = [self._chunkName(i) for i in residualChunks] + # boto3 requires providing the value as well as the name in the attribute, and we don't store it locally + # the php sdk resolves this issue by not requiring the Value key https://github.com/aws/aws-sdk-php/issues/185 + # but this doesnt extend to boto3 + delete_attributes = self.outer.db.get_attributes(DomainName=self.outer.files_domain_name, + ItemName=compat_bytes(self.fileID), + AttributeNames=[chunk for chunk in residual_chunk_names]).get("Attributes") for attempt in retry_sdb(): with attempt: - self.outer.filesDomain.delete_attributes(compat_bytes(self.fileID), - attributes=attributes) + self.outer.db.delete_attributes(DomainName=self.outer.files_domain_name, + ItemName=compat_bytes(self.fileID), + Attributes=delete_attributes) + self.outer.db.get_attributes(DomainName=self.outer.files_domain_name, ItemName=compat_bytes(self.fileID)) + self._numContentChunks = numNewContentChunks - except SDBResponseError as e: - if e.error_code == 'ConditionalCheckFailed': + except ClientError as e: + if get_error_code(e) == 'ConditionalCheckFailed': raise ConcurrentFileModificationException(self.fileID) else: raise @@ -1122,10 +1185,10 @@ def upload(self, localFilePath, calculateChecksum=True): self.checksum = self._get_file_checksum(localFilePath) if calculateChecksum else None self.version = uploadFromPath(localFilePath, resource=resource, - bucketName=self.outer.filesBucket.name, + bucketName=self.outer.files_bucket.name, 
fileID=compat_bytes(self.fileID), headerArgs=headerArgs, - partSize=self.outer.partSize) + partSize=self.outer.part_size) def _start_checksum(self, to_match=None, algorithm='sha1'): """ @@ -1172,7 +1235,7 @@ def _finish_checksum(self, checksum_in_progress): # We expected a particular hash if result_hash != checksum_in_progress[2]: raise ChecksumError('Checksum mismatch. Expected: %s Actual: %s' % - (checksum_in_progress[2], result_hash)) + (checksum_in_progress[2], result_hash)) return '$'.join([checksum_in_progress[0], result_hash]) @@ -1203,7 +1266,7 @@ def uploadStream(self, multipart=True, allowInlining=True, encoding=None, errors class MultiPartPipe(WritablePipe): def readFrom(self, readable): # Get the first block of data we want to put - buf = readable.read(store.partSize) + buf = readable.read(store.part_size) assert isinstance(buf, bytes) if allowInlining and len(buf) <= info.maxInlinedSize(): @@ -1218,7 +1281,7 @@ def readFrom(self, readable): info._update_checksum(hasher, buf) client = store.s3_client - bucket_name = store.filesBucket.name + bucket_name = store.files_bucket.name headerArgs = info._s3EncryptionArgs() for attempt in retry_s3(): @@ -1232,7 +1295,6 @@ def readFrom(self, readable): parts = [] logger.debug('Multipart upload started as %s', uploadId) - for attempt in retry_s3(): with attempt: for i in range(CONSISTENCY_TICKS): @@ -1241,8 +1303,8 @@ def readFrom(self, readable): MaxUploads=1, Prefix=compat_bytes(info.fileID)) if ('Uploads' in response and - len(response['Uploads']) != 0 and - response['Uploads'][0]['UploadId'] == uploadId): + len(response['Uploads']) != 0 and + response['Uploads'][0]['UploadId'] == uploadId): logger.debug('Multipart upload visible as %s', uploadId) break @@ -1267,7 +1329,7 @@ def readFrom(self, readable): parts.append({"PartNumber": part_num + 1, "ETag": part["ETag"]}) # Get the next block of data we want to put - buf = readable.read(info.outer.partSize) + buf = readable.read(info.outer.part_size) assert isinstance(buf, bytes) if len(buf) == 0: # Don't allow any part other than the very first to be empty. @@ -1283,7 +1345,7 @@ def readFrom(self, readable): else: - while not store._getBucketVersioning(store.filesBucket.name): + while not store._getBucketVersioning(store.files_bucket.name): logger.warning('Versioning does not appear to be enabled yet. 
Deferring multipart ' 'upload completion...') time.sleep(1) @@ -1340,7 +1402,7 @@ def readFrom(self, readable): info._update_checksum(hasher, buf) info.checksum = info._finish_checksum(hasher) - bucket_name = store.filesBucket.name + bucket_name = store.files_bucket.name headerArgs = info._s3EncryptionArgs() client = store.s3_client @@ -1421,7 +1483,7 @@ def copyFrom(self, srcObj): srcBucketName=compat_bytes(srcObj.bucket_name), srcKeyName=compat_bytes(srcObj.key), srcKeyVersion=compat_bytes(srcObj.version_id), - dstBucketName=compat_bytes(self.outer.filesBucket.name), + dstBucketName=compat_bytes(self.outer.files_bucket.name), dstKeyName=compat_bytes(self._fileID), sseAlgorithm='AES256', sseKey=self._getSSEKey()) @@ -1444,7 +1506,7 @@ def copyTo(self, dstObj): # encrypted = True if self.outer.sseKeyPath else False with attempt: copyKeyMultipart(resource, - srcBucketName=compat_bytes(self.outer.filesBucket.name), + srcBucketName=compat_bytes(self.outer.files_bucket.name), srcKeyName=compat_bytes(self.fileID), srcKeyVersion=compat_bytes(self.version), dstBucketName=compat_bytes(dstObj.bucket_name), @@ -1461,7 +1523,7 @@ def download(self, localFilePath, verifyChecksum=True): f.write(self.content) elif self.version: headerArgs = self._s3EncryptionArgs() - obj = self.outer.filesBucket.Object(compat_bytes(self.fileID)) + obj = self.outer.files_bucket.Object(compat_bytes(self.fileID)) for attempt in retry_s3(predicate=lambda e: retryable_s3_errors(e) or isinstance(e, ChecksumError)): with attempt: @@ -1493,7 +1555,7 @@ def writeTo(self, writable): writable.write(info.content) elif info.version: headerArgs = info._s3EncryptionArgs() - obj = info.outer.filesBucket.Object(compat_bytes(info.fileID)) + obj = info.outer.files_bucket.Object(compat_bytes(info.fileID)) for attempt in retry_s3(): with attempt: obj.download_fileobj(writable, ExtraArgs={'VersionId': info.version, **headerArgs}) @@ -1540,15 +1602,16 @@ def transform(self, readable, writable): def delete(self): store = self.outer if self.previousVersion is not None: + expected: UpdateConditionTypeDef = {"Name": 'version', "Value": cast(str, self.previousVersion)} for attempt in retry_sdb(): with attempt: - store.filesDomain.delete_attributes( - compat_bytes(self.fileID), - expected_values=['version', self.previousVersion]) + store.db.delete_attributes(DomainName=store.files_domain_name, + ItemName=compat_bytes(self.fileID), + Expected=expected) if self.previousVersion: for attempt in retry_s3(): with attempt: - store.s3_client.delete_object(Bucket=store.filesBucket.name, + store.s3_client.delete_object(Bucket=store.files_bucket.name, Key=compat_bytes(self.fileID), VersionId=self.previousVersion) @@ -1561,7 +1624,7 @@ def getSize(self): elif self.version: for attempt in retry_s3(): with attempt: - obj = self.outer.filesBucket.Object(compat_bytes(self.fileID)) + obj = self.outer.files_bucket.Object(compat_bytes(self.fileID)) return obj.content_length else: return 0 @@ -1630,22 +1693,22 @@ def destroy(self): pass # TODO: Add other failure cases to be ignored here. 
self._registered = None - if self.filesBucket is not None: - self._delete_bucket(self.filesBucket) - self.filesBucket = None - for name in 'filesDomain', 'jobsDomain': - domain = getattr(self, name) - if domain is not None: - self._delete_domain(domain) + if self.files_bucket is not None: + self._delete_bucket(self.files_bucket) + self.files_bucket = None + for name in 'files_domain_name', 'jobs_domain_name': + domainName = getattr(self, name) + if domainName is not None: + self._delete_domain(domainName) setattr(self, name, None) self._registered = False - def _delete_domain(self, domain): + def _delete_domain(self, domainName): for attempt in retry_sdb(): with attempt: try: - domain.delete() - except SDBResponseError as e: + self.db.delete_domain(DomainName=domainName) + except ClientError as e: if not no_such_sdb_domain(e): raise diff --git a/src/toil/jobStores/aws/utils.py b/src/toil/jobStores/aws/utils.py index 48ef581ff7..81deb1ff5f 100644 --- a/src/toil/jobStores/aws/utils.py +++ b/src/toil/jobStores/aws/utils.py @@ -17,12 +17,12 @@ import os import types from ssl import SSLError -from typing import Optional, cast, TYPE_CHECKING +from typing import Optional, cast, TYPE_CHECKING, Dict, List, Tuple from boto3.s3.transfer import TransferConfig -from boto.exception import SDBResponseError from botocore.client import Config from botocore.exceptions import ClientError +from mypy_boto3_sdb.type_defs import ItemTypeDef, AttributeTypeDef from toil.lib.aws import session, AWSServerErrors from toil.lib.aws.utils import connection_reset, get_bucket_region @@ -125,11 +125,11 @@ def _maxEncodedSize(cls): return cls._maxChunks() * cls.maxValueSize @classmethod - def binaryToAttributes(cls, binary): + def binaryToAttributes(cls, binary) -> Dict[str, str]: """ Turn a bytestring, or None, into SimpleDB attributes. """ - if binary is None: return {'numChunks': 0} + if binary is None: return {'numChunks': '0'} assert isinstance(binary, bytes) assert len(binary) <= cls.maxBinarySize() # The use of compression is just an optimization. 
We can't include it in the maxValueSize @@ -143,10 +143,41 @@ def binaryToAttributes(cls, binary): assert len(encoded) <= cls._maxEncodedSize() n = cls.maxValueSize chunks = (encoded[i:i + n] for i in range(0, len(encoded), n)) - attributes = {cls._chunkName(i): chunk for i, chunk in enumerate(chunks)} - attributes.update({'numChunks': len(attributes)}) + attributes = {cls._chunkName(i): chunk.decode("utf-8") for i, chunk in enumerate(chunks)} + attributes.update({'numChunks': str(len(attributes))}) return attributes + @classmethod + def attributeDictToList(cls, attributes: Dict[str, str]) -> List[AttributeTypeDef]: + """ + Convert the attribute dict (ex: from binaryToAttributes) into a list of attribute typed dicts + to be compatible with boto3 argument syntax + :param attributes: Dict[str, str], attribute in object form + :return: List[AttributeTypeDef], list of attributes in typed dict form + """ + return [{"Name": name, "Value": value} for name, value in attributes.items()] + + @classmethod + def attributeListToDict(cls, attributes: List[AttributeTypeDef]) -> Dict[str, str]: + """ + Convert the attribute boto3 representation of list of attribute typed dicts + back to a dictionary with name, value pairs + :param attribute: List[AttributeTypeDef, attribute in typed dict form + :return: Dict[str, str], attribute in dict form + """ + return {attribute["Name"]: attribute["Value"] for attribute in attributes} + + @classmethod + def get_attributes_from_item(cls, item: ItemTypeDef, keys: List[str]) -> List[Optional[str]]: + return_values: List[Optional[str]] = [None for _ in keys] + mapped_indices: Dict[str, int] = {name: index for index, name in enumerate(keys)} + for attribute in item["Attributes"]: + name = attribute["Name"] + value = attribute["Value"] + if name in mapped_indices: + return_values[mapped_indices[name]] = value + return return_values + @classmethod def _chunkName(cls, i): return str(i).zfill(3) @@ -165,14 +196,21 @@ def presenceIndicator(cls): return 'numChunks' @classmethod - def attributesToBinary(cls, attributes): + def attributesToBinary(cls, attributes: List[AttributeTypeDef]) -> Tuple[bytes, int]: """ :rtype: (str|None,int) :return: the binary data and the number of chunks it was composed from """ - chunks = [(int(k), v) for k, v in attributes.items() if cls._isValidChunkName(k)] + chunks = [] + numChunks: int = 0 + for attribute in attributes: + name = attribute["Name"] + value = attribute["Value"] + if cls._isValidChunkName(name): + chunks.append((int(name), value)) + if name == "numChunks": + numChunks = int(value) chunks.sort() - numChunks = int(attributes['numChunks']) if numChunks: serializedJob = b''.join(v.encode() for k, v in chunks) compressed = base64.b64decode(serializedJob) @@ -429,9 +467,9 @@ def sdb_unavailable(e): def no_such_sdb_domain(e): - return (isinstance(e, SDBResponseError) - and e.error_code - and e.error_code.endswith('NoSuchDomain')) + return (isinstance(e, ClientError) + and get_error_code(e) + and get_error_code(e).endswith('NoSuchDomain')) def retryable_ssl_error(e): diff --git a/src/toil/jobStores/conftest.py b/src/toil/jobStores/conftest.py index bc402b4dec..b90874c8de 100644 --- a/src/toil/jobStores/conftest.py +++ b/src/toil/jobStores/conftest.py @@ -17,7 +17,7 @@ collect_ignore = [] try: - import boto - print(boto.__file__) # prevent this import from being removed + import boto3 + print(boto3.__file__) # prevent this import from being removed except ImportError: collect_ignore.append("aws") diff --git a/src/toil/lib/aws/__init__.py 
diff --git a/src/toil/lib/aws/__init__.py b/src/toil/lib/aws/__init__.py
index 93a221b93b..e978caf7bd 100644
--- a/src/toil/lib/aws/__init__.py
+++ b/src/toil/lib/aws/__init__.py
@@ -95,12 +95,14 @@ def get_aws_zone_from_metadata() -> Optional[str]:
 
 def get_aws_zone_from_boto() -> Optional[str]:
     """
-    Get the AWS zone from the Boto config file, if it is configured and the
-    boto module is available.
+    Get the AWS zone from the Boto3 config file or from AWS_DEFAULT_REGION, if it is configured and the
+    boto3 module is available.
     """
     try:
-        import boto
-        zone = boto.config.get('Boto', 'ec2_region_name')
+        import boto3
+        boto3_session = boto3.session.Session()
+        # this should check AWS_DEFAULT_REGION and ~/.aws/config
+        zone = boto3_session.region_name
         if zone is not None:
             zone += 'a'  # derive an availability zone in the region
         return zone
diff --git a/src/toil/lib/aws/session.py b/src/toil/lib/aws/session.py
index 96cacae0bb..bd03db78de 100644
--- a/src/toil/lib/aws/session.py
+++ b/src/toil/lib/aws/session.py
@@ -17,7 +17,6 @@
 import threading
 from typing import Dict, Optional, Tuple, cast, Union, Literal, overload, TypeVar
 
-import boto
 import boto3
 import boto3.resources.base
 import botocore
diff --git a/src/toil/lib/aws/utils.py b/src/toil/lib/aws/utils.py
index bee4973dfd..7ce75321ed 100644
--- a/src/toil/lib/aws/utils.py
+++ b/src/toil/lib/aws/utils.py
@@ -27,6 +27,7 @@
                     cast)
 from urllib.parse import ParseResult
 
+from mypy_boto3_sdb.type_defs import AttributeTypeDef
 from toil.lib.aws import session, AWSRegionName, AWSServerErrors
 from toil.lib.misc import printq
 from toil.lib.retry import (DEFAULT_DELAYS,
@@ -34,10 +35,9 @@
                             get_error_code,
                             get_error_status,
                             old_retry,
-                            retry)
+                            retry, ErrorCondition)
 
 try:
-    from boto.exception import BotoServerError, S3ResponseError
     from botocore.exceptions import ClientError
     from mypy_boto3_iam import IAMClient, IAMServiceResource
     from mypy_boto3_s3 import S3Client, S3ServiceResource
@@ -45,7 +45,6 @@
     from mypy_boto3_s3.service_resource import Bucket, Object
     from mypy_boto3_sdb import SimpleDBClient
 except ImportError:
-    BotoServerError = None  # type: ignore
     ClientError = None  # type: ignore
     # AWS/boto extra is not installed
 
@@ -74,7 +73,6 @@ def delete_iam_role(
     role_name: str, region: Optional[str] = None, quiet: bool = True
 ) -> None:
-
     # TODO: the Boto3 type hints are a bit oversealous here; they want hundreds
     # of overloads of the client-getting methods to exist based on the literal
     # string passed in, to return exactly the right kind of client or resource.
@@ -137,10 +135,10 @@ def retryable_s3_errors(e: Exception) -> bool:
     Return true if this is an error from S3 that looks like we ought to retry our request.
""" return (connection_reset(e) - or (isinstance(e, BotoServerError) and e.status in (429, 500)) - or (isinstance(e, BotoServerError) and e.code in THROTTLED_ERROR_CODES) + or (isinstance(e, ClientError) and get_error_status(e) in (429, 500)) + or (isinstance(e, ClientError) and get_error_code(e) in THROTTLED_ERROR_CODES) # boto3 errors - or (isinstance(e, (S3ResponseError, ClientError)) and get_error_code(e) in THROTTLED_ERROR_CODES) + or (isinstance(e, ClientError) and get_error_code(e) in THROTTLED_ERROR_CODES) or (isinstance(e, ClientError) and 'BucketNotEmpty' in str(e)) or (isinstance(e, ClientError) and e.response.get('ResponseMetadata', {}).get('HTTPStatusCode') == 409 and 'try again' in str(e)) or (isinstance(e, ClientError) and e.response.get('ResponseMetadata', {}).get('HTTPStatusCode') in (404, 429, 500, 502, 503, 504))) @@ -434,3 +432,42 @@ def flatten_tags(tags: Dict[str, str]) -> List[Dict[str, str]]: Convert tags from a key to value dict into a list of 'Key': xxx, 'Value': xxx dicts. """ return [{'Key': k, 'Value': v} for k, v in tags.items()] + + +def boto3_pager(requestor_callable: Callable[..., Any], result_attribute_name: str, + **kwargs: Any) -> Iterable[Any]: + """ + Yield all the results from calling the given Boto 3 method with the + given keyword arguments, paging through the results using the Marker or + NextToken, and fetching out and looping over the list in the response + with the given attribute name. + """ + + # Recover the Boto3 client, and the name of the operation + client = requestor_callable.__self__ # type: ignore[attr-defined] + op_name = requestor_callable.__name__ + + # grab a Boto 3 built-in paginator. See + # + paginator = client.get_paginator(op_name) + + for page in paginator.paginate(**kwargs): + # Invoke it and go through the pages, yielding from them + yield from page.get(result_attribute_name, []) + + +def get_item_from_attributes(attributes: List[AttributeTypeDef], name: str) -> Any: + """ + Given a list of attributes, find the attribute associated with the name and return its corresponding value. + + The `attribute_list` will be a list of TypedDict's (which boto3 SDB functions commonly return), + where each TypedDict has a "Name" and "Value" key value pair. + This function grabs the value out of the associated TypedDict. + + If the attribute with the name does not exist, the function will return None. + + :param attributes: list of attributes as List[AttributeTypeDef] + :param name: name of the attribute + :return: value of the attribute + """ + return next((attribute["Value"] for attribute in attributes if attribute["Name"] == name), None) diff --git a/src/toil/lib/ec2.py b/src/toil/lib/ec2.py index 8b3baa94b3..b7f2b439e4 100644 --- a/src/toil/lib/ec2.py +++ b/src/toil/lib/ec2.py @@ -6,8 +6,6 @@ import botocore.client from boto3.resources.base import ServiceResource -from boto.ec2.instance import Instance as Boto2Instance -from boto.ec2.spotinstancerequest import SpotInstanceRequest from toil.lib.aws.session import establish_boto3_session from toil.lib.aws.utils import flatten_tags @@ -102,8 +100,8 @@ def wait_instances_running(boto3_ec2: EC2Client, instances: Iterable[InstanceTyp entered the running state as soon as it does. 
diff --git a/src/toil/lib/ec2.py b/src/toil/lib/ec2.py
index 8b3baa94b3..b7f2b439e4 100644
--- a/src/toil/lib/ec2.py
+++ b/src/toil/lib/ec2.py
@@ -6,8 +6,6 @@
 import botocore.client
 from boto3.resources.base import ServiceResource
 
-from boto.ec2.instance import Instance as Boto2Instance
-from boto.ec2.spotinstancerequest import SpotInstanceRequest
 
 from toil.lib.aws.session import establish_boto3_session
 from toil.lib.aws.utils import flatten_tags
@@ -102,8 +100,8 @@ def wait_instances_running(boto3_ec2: EC2Client, instances: Iterable[InstanceTyp
     entered the running state as soon as it does.
 
     :param EC2Client boto3_ec2: the EC2 connection to use for making requests
-    :param Iterable[Boto2Instance] instances: the instances to wait on
-    :rtype: Iterable[Boto2Instance]
+    :param Iterable[InstanceTypeDef] instances: the instances to wait on
+    :rtype: Iterable[InstanceTypeDef]
     """
     running_ids = set()
     other_ids = set()
@@ -136,7 +134,7 @@ def wait_instances_running(boto3_ec2: EC2Client, instances: Iterable[InstanceTyp
             instances = [instance for reservation in described_instances["Reservations"] for instance in reservation["Instances"]]
 
 
-def wait_spot_requests_active(boto3_ec2: EC2Client, requests: Iterable[SpotInstanceRequestTypeDef], timeout: float = None, tentative: bool = False) -> Iterable[List[SpotInstanceRequest]]:
+def wait_spot_requests_active(boto3_ec2: EC2Client, requests: Iterable[SpotInstanceRequestTypeDef], timeout: float = None, tentative: bool = False) -> Iterable[List[SpotInstanceRequestTypeDef]]:
     """
     Wait until no spot request in the given iterator is in the 'open' state or, optionally,
     a timeout occurs. Yield spot requests as soon as they leave the 'open' state.
@@ -171,17 +169,17 @@ def spot_request_not_found(e: Exception) -> bool:
             batch = []
             for r in requests:
                 r: SpotInstanceRequestTypeDef  # pycharm thinks it is a string
-                if r['State']['Name'] == 'open':
+                if r['State'] == 'open':
                     open_ids.add(r['InstanceId'])
-                    if r['Status']['Code'] == 'pending-evaluation':
+                    if r['Status'] == 'pending-evaluation':
                         eval_ids.add(r['InstanceId'])
-                    elif r['Status']['Code'] == 'pending-fulfillment':
+                    elif r['Status'] == 'pending-fulfillment':
                         fulfill_ids.add(r['InstanceId'])
                     else:
                         logger.info(
                             'Request %s entered status %s indicating that it will not be '
-                            'fulfilled anytime soon.', r['InstanceId'], r['Status']['Code'])
-                elif r['State']['Name'] == 'active':
+                            'fulfilled anytime soon.', r['InstanceId'], r['Status'])
+                elif r['State'] == 'active':
                     if r['InstanceId'] in active_ids:
                         raise RuntimeError("A request was already added to the list of active requests. Maybe there are duplicate requests.")
                     active_ids.add(r['InstanceId'])
@@ -248,14 +246,15 @@ def spotRequestNotFound(e):
                                                tentative=tentative):
         instance_ids = []
         for request in batch:
-            if request.state == 'active':
-                instance_ids.append(request.instance_id)
+            request: SpotInstanceRequestTypeDef
+            if request["State"] == 'active':
+                instance_ids.append(request["InstanceId"])
                 num_active += 1
             else:
                 logger.info(
                     'Request %s in unexpected state %s.',
-                    request.id,
-                    request.state)
+                    request["InstanceId"],
+                    request["State"])
                 num_other += 1
         if instance_ids:
             # This next line is the reason we batch. It's so we can get multiple instances in
@@ -281,7 +280,7 @@ def create_ondemand_instances(boto3_ec2: EC2Client, image_id: str, spec: Mapping
     Requests the RunInstances EC2 API call but accounts for the race between recently created
     instance profiles, IAM roles and an instance creation that refers to them.
 
-    :rtype: List[Boto2Instance]
+    :rtype: List[InstanceTypeDef]
     """
     instance_type = spec['InstanceType']
    logger.info('Creating %s instance(s) ... ', instance_type)
diff --git a/src/toil/provisioners/aws/awsProvisioner.py b/src/toil/provisioners/aws/awsProvisioner.py
index a1254cfeae..eddbb00a31 100644
--- a/src/toil/provisioners/aws/awsProvisioner.py
+++ b/src/toil/provisioners/aws/awsProvisioner.py
@@ -51,7 +51,7 @@
                           get_policy_permissions,
                           policy_permissions_allow)
 from toil.lib.aws.session import AWSConnectionManager
-from toil.lib.aws.utils import create_s3_bucket, flatten_tags
+from toil.lib.aws.utils import create_s3_bucket, flatten_tags, boto3_pager
 from toil.lib.conversions import human2bytes
 from toil.lib.ec2 import (a_short_time,
                           create_auto_scaling_group,
@@ -88,6 +88,7 @@
 from mypy_boto3_iam.client import IAMClient
 from mypy_boto3_ec2.type_defs import DescribeInstancesResultTypeDef, InstanceTypeDef, TagTypeDef, BlockDeviceMappingTypeDef, EbsBlockDeviceTypeDef, FilterTypeDef, SpotInstanceRequestTypeDef, TagDescriptionTypeDef, SecurityGroupTypeDef, \
     CreateSecurityGroupResultTypeDef, IpPermissionTypeDef, ReservationTypeDef
+from mypy_boto3_s3.literals import BucketLocationConstraintType
 
 logger = logging.getLogger(__name__)
 logging.getLogger("boto").setLevel(logging.CRITICAL)
@@ -222,7 +223,11 @@ def __init__(self, clusterName: Optional[str], clusterType: Optional[str], zone:
         # Call base class constructor, which will call createClusterSettings()
         # or readClusterSettings()
         super().__init__(clusterName, clusterType, zone, nodeStorage, nodeStorageOverrides, enable_fuse)
-        self._leader_subnet: str = self._get_default_subnet(self._zone)
+
+        if self._zone is None:
+            logger.warning("Leader zone was never initialized before creating AWS provisioner. Defaulting to cluster zone.")
+
+        self._leader_subnet: str = self._get_default_subnet(self._zone or zone)
         self._tags: Dict[str, Any] = {}
 
         # After self.clusterName is set, generate a valid name for the S3 bucket associated with this cluster
@@ -567,7 +572,7 @@ def _get_subnet_acls(self, subnet: str) -> List[str]:
         }]
 
         # TODO: Can't we use the resource's network_acls.filter(Filters=)?
-        return [item['NetworkAclId'] for item in self._pager(ec2.describe_network_acls,
+        return [item['NetworkAclId'] for item in boto3_pager(ec2.describe_network_acls,
                                                              'NetworkAcls',
                                                              Filters=filters)]
 
@@ -619,7 +624,7 @@ def getKubernetesCloudProvider(self) -> Optional[str]:
 
         return 'aws'
 
-    def getNodeShape(self, instance_type: str, preemptible: bool = False) -> Shape:
+    def getNodeShape(self, instance_type: str, preemptible: bool=False) -> Shape:
         """
         Get the Shape for the given instance type (e.g. 't2.medium').
         """
@@ -813,7 +818,7 @@ def _recover_node_type_bid(self, node_type: Set[str], spot_bid: Optional[float])
 
         return spot_bid
 
-    def addNodes(self, nodeTypes: Set[str], numNodes: int, preemptible: bool, spotBid: Optional[float] = None) -> int:
+    def addNodes(self, nodeTypes: Set[str], numNodes: int, preemptible: bool, spotBid: Optional[float]=None) -> int:
         # Grab the AWS connection we need
         boto3_ec2 = get_client(service_name='ec2', region_name=self._region)
         assert self._leaderPrivateIP
@@ -1259,7 +1264,6 @@ def _createSecurityGroups(self) -> List[str]:
         """
         Create security groups for the cluster. Returns a list of their IDs.
""" - def group_not_found(e: ClientError) -> bool: retry = (get_error_status(e) == 400 and 'does not exist in default VPC' in get_error_body(e)) return retry @@ -1595,27 +1599,6 @@ def _boto2_pager(self, requestor_callable: Callable[[...], Any], result_attribut else: break - def _pager(self, requestor_callable: Callable[..., Any], result_attribute_name: str, - **kwargs: Any) -> Iterable[Any]: - """ - Yield all the results from calling the given Boto 3 method with the - given keyword arguments, paging through the results using the Marker or - NextToken, and fetching out and looping over the list in the response - with the given attribute name. - """ - - # Recover the Boto3 client, and the name of the operation - client = requestor_callable.__self__ # type: ignore[attr-defined] - op_name = requestor_callable.__name__ - - # grab a Boto 3 built-in paginator. See - # - paginator = client.get_paginator(op_name) - - for page in paginator.paginate(**kwargs): - # Invoke it and go through the pages, yielding from them - yield from page.get(result_attribute_name, []) - @awsRetry def _getRoleNames(self) -> List[str]: """ @@ -1624,7 +1607,7 @@ def _getRoleNames(self) -> List[str]: results = [] boto3_iam = self.aws.client(self._region, 'iam') - for result in self._pager(boto3_iam.list_roles, 'Roles'): + for result in boto3_pager(boto3_iam.list_roles, 'Roles'): # For each Boto2 role object # Grab out the name cast(RoleTypeDef, result) @@ -1642,10 +1625,10 @@ def _getInstanceProfileNames(self) -> List[str]: results = [] boto3_iam = self.aws.client(self._region, 'iam') - for result in self._pager(boto3_iam.list_instance_profiles, + for result in boto3_pager(boto3_iam.list_instance_profiles, 'InstanceProfiles'): - # For each Boto role object # Grab out the name + cast(InstanceProfileTypeDef, result) name = result['InstanceProfileName'] if self._is_our_namespaced_name(name): # If it looks like ours, it is ours. @@ -1663,7 +1646,7 @@ def _getRoleInstanceProfileNames(self, role_name: str) -> List[str]: # Grab the connection we need to use for this operation. boto3_iam: IAMClient = self.aws.client(self._region, 'iam') - return [item['InstanceProfileName'] for item in self._pager(boto3_iam.list_instance_profiles_for_role, + return [item['InstanceProfileName'] for item in boto3_pager(boto3_iam.list_instance_profiles_for_role, 'InstanceProfiles', RoleName=role_name)] @@ -1682,7 +1665,7 @@ def _getRolePolicyArns(self, role_name: str) -> List[str]: # TODO: we don't currently use attached policies. - return [item['PolicyArn'] for item in self._pager(boto3_iam.list_attached_role_policies, + return [item['PolicyArn'] for item in boto3_pager(boto3_iam.list_attached_role_policies, 'AttachedPolicies', RoleName=role_name)] @@ -1696,7 +1679,7 @@ def _getRoleInlinePolicyNames(self, role_name: str) -> List[str]: # Grab the connection we need to use for this operation. 
         boto3_iam: IAMClient = self.aws.client(self._region, 'iam')
 
-        return list(self._pager(boto3_iam.list_role_policies, 'PolicyNames', RoleName=role_name))
+        return list(boto3_pager(boto3_iam.list_role_policies, 'PolicyNames', RoleName=role_name))
 
     def full_policy(self, resource: str) -> Dict[str, Any]:
         """
diff --git a/src/toil/test/__init__.py b/src/toil/test/__init__.py
index 0556f4f1c2..eb8cab4c25 100644
--- a/src/toil/test/__init__.py
+++ b/src/toil/test/__init__.py
@@ -372,14 +372,15 @@ def needs_aws_s3(test_item: MT) -> MT:
     # TODO: we just check for generic access to the AWS account
     test_item = _mark_test('aws-s3', needs_online(test_item))
     try:
-        from boto import config
-        boto_credentials = config.get('Credentials', 'aws_access_key_id')
+        from boto3 import Session
+        session = Session()
+        boto3_credentials = session.get_credentials()
     except ImportError:
         return unittest.skip("Install Toil with the 'aws' extra to include this test.")(
             test_item
         )
     from toil.lib.aws import running_on_ec2
-    if not (boto_credentials or os.path.exists(os.path.expanduser('~/.aws/credentials')) or running_on_ec2()):
+    if not (boto3_credentials or os.path.exists(os.path.expanduser('~/.aws/credentials')) or running_on_ec2()):
         return unittest.skip("Configure AWS credentials to include this test.")(test_item)
     return test_item
 
diff --git a/src/toil/test/jobStores/jobStoreTest.py b/src/toil/test/jobStores/jobStoreTest.py
index 8079d2defb..6f74fc9b47 100644
--- a/src/toil/test/jobStores/jobStoreTest.py
+++ b/src/toil/test/jobStores/jobStoreTest.py
@@ -651,7 +651,7 @@ def testImportExportFile(self, otherCls, size, moveExports):
         :param int size: the size of the file to test importing/exporting with
         """
         # Prepare test file in other job store
-        self.jobstore_initialized.partSize = cls.mpTestPartSize
+        self.jobstore_initialized.part_size = cls.mpTestPartSize
         self.jobstore_initialized.moveExports = moveExports
 
         # Test assumes imports are not linked
@@ -707,7 +707,7 @@ def testImportSharedFile(self, otherCls):
             to import from or export to
         """
         # Prepare test file in other job store
-        self.jobstore_initialized.partSize = cls.mpTestPartSize
+        self.jobstore_initialized.part_size = cls.mpTestPartSize
         other = otherCls('testSharedFiles')
         store = other._externalStore()
 
@@ -1295,7 +1295,6 @@ def testSDBDomainsDeletedOnFailedJobstoreBucketCreation(self):
         failed to be created.  We simulate a failed jobstore bucket creation by using a bucket in a
         different region with the same name.
         """
-        from boto.sdb import connect_to_region
         from botocore.exceptions import ClientError
 
         from toil.jobStores.aws.jobStore import BucketLocationConflictException
@@ -1335,13 +1334,16 @@ def testSDBDomainsDeletedOnFailedJobstoreBucketCreation(self):
             except BucketLocationConflictException:
                 # Catch the expected BucketLocationConflictException and ensure that the bound
                 # domains don't exist in SDB.
-                sdb = connect_to_region(self.awsRegion())
+                sdb = establish_boto3_session().client(region_name=self.awsRegion(), service_name="sdb")
                 next_token = None
                 allDomainNames = []
                 while True:
-                    domains = sdb.get_all_domains(max_domains=100, next_token=next_token)
-                    allDomainNames.extend([x.name for x in domains])
-                    next_token = domains.next_token
+                    if next_token is None:
+                        domains = sdb.list_domains(MaxNumberOfDomains=100)
+                    else:
+                        domains = sdb.list_domains(MaxNumberOfDomains=100, NextToken=next_token)
+                    allDomainNames.extend(domains["DomainNames"])
+                    next_token = domains.get("NextToken")
                     if next_token is None:
                         break
                 self.assertFalse([d for d in allDomainNames if testJobStoreUUID in d])
diff --git a/src/toil/test/provisioners/clusterTest.py b/src/toil/test/provisioners/clusterTest.py
index 6a129eff5f..95c1c0b7d4 100644
--- a/src/toil/test/provisioners/clusterTest.py
+++ b/src/toil/test/provisioners/clusterTest.py
@@ -38,13 +38,9 @@ def __init__(self, methodName: str) -> None:
         self.clusterType = 'mesos'
         self.zone = get_best_aws_zone()
         assert self.zone is not None, "Could not determine AWS availability zone to test in; is TOIL_AWS_ZONE set?"
-        # We need a boto2 connection to EC2 to check on the cluster.
-        # Since we are protected by needs_aws_ec2 we can import from boto.
-        import boto.ec2
         self.region = zone_to_region(self.zone)
-        self.boto2_ec2 = boto.ec2.connect_to_region(self.region)
 
-        # Get connection to AWS with boto3/boto2
+        # Get connection to AWS
         self.aws = AWSConnectionManager()
 
         # Where should we put our virtualenv?
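
With the boto2 EC2 connection gone, the cluster test can obtain a boto3 EC2 client on demand from the connection manager, following the same pattern awsProvisioner.py uses for IAM above. A hedged sketch (not part of the patch); the filter is chosen only for illustration:

    # Inside the test class, after self.aws = AWSConnectionManager():
    boto3_ec2 = self.aws.client(self.region, 'ec2')
    # Look up running instances; describe_instances returns reservations, each
    # of which wraps a list of instance dicts.
    described = boto3_ec2.describe_instances(
        Filters=[{'Name': 'instance-state-name', 'Values': ['running']}])
    for reservation in described['Reservations']:
        for instance in reservation['Instances']:
            print(instance['InstanceId'])
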