diff --git a/miniwdl-plugins/s3upload/miniwdl_s3upload.py b/miniwdl-plugins/s3upload/miniwdl_s3upload.py
index b443216d..7c005e96 100644
--- a/miniwdl-plugins/s3upload/miniwdl_s3upload.py
+++ b/miniwdl-plugins/s3upload/miniwdl_s3upload.py
@@ -27,11 +27,10 @@
 import threading
 import json
 import logging
-import time
-import random
 from pathlib import Path
 from urllib.parse import urlparse
-from typing import Dict, Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union, Set
+import sys
 
 import WDL
 from WDL import Env, Value, values_to_json
@@ -66,6 +65,14 @@ def get_s3_get_prefix(cfg: config.Loader) -> str:
     return s3prefix
 
 
+def tag_temporary_output_files(output_file_set, s3prefix):
+    for object_path in _uploaded_files.values():
+        if (object_path not in output_file_set) and (object_path not in _processed_files):
+            flag_temporary(object_path)
+            print(object_path, file=sys.stderr)
+            _processed_files.add(object_path)
+
+
 def flag_temporary(s3uri):
     uri = urlparse(s3uri)
     bucket, key = uri.hostname, uri.path[1:]
@@ -87,42 +94,6 @@ def flag_temporary(s3uri):
         pass
 
 
-def remove_temporary_flag(s3uri, retry=0):
-    """ Remove temporary flag from s3 if in outputs.json """
-    uri = urlparse(s3uri)
-    bucket, key = uri.hostname, uri.path[1:]
-    tags = s3_client.get_object_tagging(
-        Bucket=bucket,
-        Key=key,
-    )
-    remaining_tags = []
-    for tag in tags["TagSet"]:
-        if not (tag["Key"] == "intermediate_output" and tag["Value"] == "true"):
-            remaining_tags.append(tag)
-    try:
-        if remaining_tags:
-            s3_client.put_object_tagging(
-                Bucket=bucket,
-                Key=key,
-                Tagging={
-                    'TagSet': remaining_tags
-                },
-            )
-        elif len(tags["TagSet"]) > 0:  # Delete tags if they exist
-            s3_client.delete_object_tagging(
-                Bucket=bucket,
-                Key=key,
-            )
-    except botocore.exceptions.ClientError as e:
-        if retry > 3:
-            raise e
-        print(f"Error deleting tags for object {key} in bucket {bucket}: {e}")
-        delay = 20 + random.randint(0, 10)
-        print(f"Retrying in {delay} seconds...")
-        time.sleep(delay)
-        remove_temporary_flag(s3uri, retry+1)
-
-
 def inode(link: str):
     if re.match(r'^\w+://', link):
         return link
@@ -134,6 +105,7 @@ def inode(link: str):
 _cached_files: Dict[Tuple[int, int], Tuple[str, Env.Bindings[Value.Base]]] = {}
 _key_inputs: Dict[str, Env.Bindings[Value.Base]] = {}
 _uploaded_files_lock = threading.Lock()
+_processed_files: Set = set()
 
 
 def cache_put(cfg: config.Loader, logger: logging.Logger, key: str, outputs: Env.Bindings[Value.Base]):
@@ -222,8 +194,8 @@ def task(cfg, logger, run_id, run_dir, task, **recv):
     # ignore command/runtime/container
     recv = yield recv
 
-    def upload_file(abs_fn, s3uri, flag_temporary_file=False):
-        s3cp(logger, abs_fn, s3uri, flag_temporary_file=flag_temporary_file)
+    def upload_file(abs_fn, s3uri):
+        s3cp(logger, abs_fn, s3uri)
         # record in _uploaded_files (keyed by inode, so that it can be found from any
         # symlink or hardlink)
         with _uploaded_files_lock:
@@ -263,13 +235,13 @@ def _raise(ex):
                 for fn in files:
                     abs_fn = os.path.join(dn, fn)
                     s3uri = os.path.join(s3prefix, os.path.relpath(abs_fn, abs_output))
-                    upload_file(abs_fn, s3uri, flag_temporary_file=False)
+                    upload_file(abs_fn, s3uri)
         elif len(output_contents) == 1 and os.path.isfile(output_contents[0]):
             # file output
             basename = os.path.basename(output_contents[0])
             abs_fn = os.path.join(abs_output, basename)
             s3uri = os.path.join(s3prefix, basename)
-            upload_file(abs_fn, s3uri, flag_temporary_file=True)
+            upload_file(abs_fn, s3uri)
         else:  # file array output
             assert all(os.path.basename(abs_fn).isdigit() for abs_fn in output_contents), \
                 output_contents
@@ -278,7 +250,7 @@
                 assert len(fns) == 1
                 abs_fn = os.path.join(index_dir, fns[0])
                 s3uri = os.path.join(s3prefix, fns[0])
-                upload_file(abs_fn, s3uri, flag_temporary_file=False)
+                upload_file(abs_fn, s3uri)
 
 
     yield recv
@@ -299,7 +271,7 @@ def workflow(cfg, logger, run_id, run_dir, workflow, **recv):
             logger,
             recv["outputs"],
             run_dir,
-            os.path.join(get_s3_put_prefix(cfg), *run_id[1:]),
+            os.path.join(get_s3_put_prefix(cfg)),
            workflow.name,
        )
 
@@ -336,25 +308,27 @@ def rewriter(fd):
         json.dump(outputs_s3_json, outfile, indent=2)
         outfile.write("\n")
 
+    output_set = set()
     for output_file in outputs_s3_json.values():
         if isinstance(output_file, list):
             for filename in output_file:
-                remove_temporary_flag(filename)
+                output_set.add(filename)
         elif isinstance(output_file, str) and output_file.startswith("s3://"):
-            remove_temporary_flag(output_file)
+            output_set.add(output_file)
+
+    tag_temporary_output_files(output_file_set=output_set, s3prefix=s3prefix)
 
     s3cp(
         logger,
         fn,
-        os.environ.get("WDL_OUTPUT_URI", os.path.join(s3prefix, "outputs.s3.json")),
-        flag_temporary_file=False
+        os.environ.get("WDL_OUTPUT_URI", os.path.join(s3prefix, "outputs.s3.json"))
     )
 
 
 _s3parcp_lock = threading.Lock()
 
 
-def s3cp(logger, fn, s3uri, flag_temporary_file=False):
+def s3cp(logger, fn, s3uri):
     with _s3parcp_lock:
         # when uploading many small outputs from the same pipeline you end up with a
         # quick intense burst of load that can bump into the S3 rate limit
@@ -372,5 +346,3 @@ def s3cp(logger, fn, s3uri, flag_temporary_file=False):
                 )
             )
             raise WDL.Error.RuntimeError("failed: " + " ".join(cmd))
-        if flag_temporary_file:
-            flag_temporary(s3uri)
diff --git a/version b/version
index 959bb9d0..2aca8c01 100644
--- a/version
+++ b/version
@@ -1 +1 @@
-v1.4.5
+v1.4.6
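
Note on the tagging flow (not part of the patch): the old scheme tagged single-file task outputs as intermediate at upload time and then called remove_temporary_flag() to strip the tag from final outputs, retrying around S3 tagging errors. This patch inverts that: uploads go up untagged, and one end-of-workflow pass over _uploaded_files tags everything absent from outputs.s3.json. A minimal standalone sketch of that pass, assuming flag_temporary() applies the same intermediate_output=true tag that remove_temporary_flag() used to strip; the helper name tag_intermediates is hypothetical:

    from urllib.parse import urlparse
    import boto3

    s3 = boto3.client("s3")

    def tag_intermediates(uploaded_uris, final_output_uris, already_tagged):
        # uploaded_uris: every s3:// URI the run uploaded (_uploaded_files.values())
        # final_output_uris: URIs listed in outputs.s3.json (the real workflow outputs)
        # already_tagged: mirrors _processed_files, so repeated calls tag each object once
        for uri in uploaded_uris:
            if uri in final_output_uris or uri in already_tagged:
                continue
            parsed = urlparse(uri)
            s3.put_object_tagging(
                Bucket=parsed.hostname,
                Key=parsed.path[1:],  # drop the leading "/"
                Tagging={"TagSet": [{"Key": "intermediate_output", "Value": "true"}]},
            )
            already_tagged.add(uri)

Because only the complement of the output set is ever tagged, the tag-removal retry loop (and the time/random imports it needed) can be dropped outright.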