[SC-230138] Refactor tagging to only occur after the workflow (#110)
rzlim08 authored Apr 3, 2023
1 parent bf7e1ca commit 3cd78ea
Showing 2 changed files with 25 additions and 53 deletions.
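In short, the refactor moves S3 tagging out of the per-task upload path: files are uploaded untagged, and a single pass after the workflow compares every uploaded URI against the final outputs in outputs.s3.json and flags the rest as temporary. Below is a minimal sketch of that post-workflow pass, assuming boto3; the names s3_client, _uploaded_files, _processed_files and the intermediate_output tag key mirror the plugin code in the diff that follows, while the body of flag_temporary is an assumption (it is collapsed in this view).

# Illustrative sketch only, not the committed code.
import sys
from urllib.parse import urlparse

import boto3

s3_client = boto3.client("s3")

_uploaded_files = {}      # inode -> "s3://..." URI, recorded as each task uploads
_processed_files = set()  # URIs already tagged, so the pass is not repeated


def flag_temporary(s3uri):
    """Tag one uploaded object as a temporary (non-final) output; assumed body."""
    uri = urlparse(s3uri)
    bucket, key = uri.hostname, uri.path[1:]
    s3_client.put_object_tagging(
        Bucket=bucket,
        Key=key,
        Tagging={"TagSet": [{"Key": "intermediate_output", "Value": "true"}]},
    )


def tag_temporary_output_files(output_file_set, s3prefix):
    """After the workflow, tag every uploaded file that is not a final output.

    s3prefix is accepted only to mirror the plugin's signature.
    """
    for object_path in _uploaded_files.values():
        if object_path not in output_file_set and object_path not in _processed_files:
            flag_temporary(object_path)
            print(object_path, file=sys.stderr)
            _processed_files.add(object_path)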
76 changes: 24 additions & 52 deletions miniwdl-plugins/s3upload/miniwdl_s3upload.py
@@ -27,11 +27,10 @@
import threading
import json
import logging
-import time
-import random
from pathlib import Path
from urllib.parse import urlparse
-from typing import Dict, Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union, Set
+import sys

import WDL
from WDL import Env, Value, values_to_json
@@ -66,6 +65,14 @@ def get_s3_get_prefix(cfg: config.Loader) -> str:
    return s3prefix


+def tag_temporary_output_files(output_file_set, s3prefix):
+    for object_path in _uploaded_files.values():
+        if (object_path not in output_file_set) and (object_path not in _processed_files):
+            flag_temporary(object_path)
+            print(object_path, file=sys.stderr)
+            _processed_files.add(object_path)


def flag_temporary(s3uri):
    uri = urlparse(s3uri)
    bucket, key = uri.hostname, uri.path[1:]
@@ -87,42 +94,6 @@ def flag_temporary(s3uri):
pass


-def remove_temporary_flag(s3uri, retry=0):
-    """ Remove temporary flag from s3 if in outputs.json """
-    uri = urlparse(s3uri)
-    bucket, key = uri.hostname, uri.path[1:]
-    tags = s3_client.get_object_tagging(
-        Bucket=bucket,
-        Key=key,
-    )
-    remaining_tags = []
-    for tag in tags["TagSet"]:
-        if not (tag["Key"] == "intermediate_output" and tag["Value"] == "true"):
-            remaining_tags.append(tag)
-    try:
-        if remaining_tags:
-            s3_client.put_object_tagging(
-                Bucket=bucket,
-                Key=key,
-                Tagging={
-                    'TagSet': remaining_tags
-                },
-            )
-        elif len(tags["TagSet"]) > 0:  # Delete tags if they exist
-            s3_client.delete_object_tagging(
-                Bucket=bucket,
-                Key=key,
-            )
-    except botocore.exceptions.ClientError as e:
-        if retry > 3:
-            raise e
-        print(f"Error deleting tags for object {key} in bucket {bucket}: {e}")
-        delay = 20 + random.randint(0, 10)
-        print(f"Retrying in {delay} seconds...")
-        time.sleep(delay)
-        remove_temporary_flag(s3uri, retry+1)

def inode(link: str):
    if re.match(r'^\w+://', link):
        return link
@@ -134,6 +105,7 @@ def inode(link: str):
_cached_files: Dict[Tuple[int, int], Tuple[str, Env.Bindings[Value.Base]]] = {}
_key_inputs: Dict[str, Env.Bindings[Value.Base]] = {}
_uploaded_files_lock = threading.Lock()
+_processed_files: Set = set()


def cache_put(cfg: config.Loader, logger: logging.Logger, key: str, outputs: Env.Bindings[Value.Base]):
@@ -222,8 +194,8 @@ def task(cfg, logger, run_id, run_dir, task, **recv):
# ignore command/runtime/container
recv = yield recv

-def upload_file(abs_fn, s3uri, flag_temporary_file=False):
-    s3cp(logger, abs_fn, s3uri, flag_temporary_file=flag_temporary_file)
+def upload_file(abs_fn, s3uri):
+    s3cp(logger, abs_fn, s3uri)
# record in _uploaded_files (keyed by inode, so that it can be found from any
# symlink or hardlink)
with _uploaded_files_lock:
@@ -263,13 +235,13 @@ def _raise(ex):
for fn in files:
abs_fn = os.path.join(dn, fn)
s3uri = os.path.join(s3prefix, os.path.relpath(abs_fn, abs_output))
-upload_file(abs_fn, s3uri, flag_temporary_file=False)
+upload_file(abs_fn, s3uri)
elif len(output_contents) == 1 and os.path.isfile(output_contents[0]):
# file output
basename = os.path.basename(output_contents[0])
abs_fn = os.path.join(abs_output, basename)
s3uri = os.path.join(s3prefix, basename)
-upload_file(abs_fn, s3uri, flag_temporary_file=True)
+upload_file(abs_fn, s3uri)
else:
# file array output
assert all(os.path.basename(abs_fn).isdigit() for abs_fn in output_contents), output_contents
@@ -278,7 +250,7 @@ def _raise(ex):
assert len(fns) == 1
abs_fn = os.path.join(index_dir, fns[0])
s3uri = os.path.join(s3prefix, fns[0])
-upload_file(abs_fn, s3uri, flag_temporary_file=False)
+upload_file(abs_fn, s3uri)
yield recv


@@ -299,7 +271,7 @@ def workflow(cfg, logger, run_id, run_dir, workflow, **recv):
logger,
recv["outputs"],
run_dir,
-os.path.join(get_s3_put_prefix(cfg), *run_id[1:]),
+os.path.join(get_s3_put_prefix(cfg)),
workflow.name,
)

@@ -336,25 +308,27 @@ def rewriter(fd):
json.dump(outputs_s3_json, outfile, indent=2)
outfile.write("\n")

+output_set = set()
for output_file in outputs_s3_json.values():
if isinstance(output_file, list):
for filename in output_file:
-remove_temporary_flag(filename)
+output_set.add(filename)
elif isinstance(output_file, str) and output_file.startswith("s3://"):
-remove_temporary_flag(output_file)
+output_set.add(output_file)

+tag_temporary_output_files(output_file_set=output_set, s3prefix=s3prefix)

s3cp(
logger,
fn,
os.environ.get("WDL_OUTPUT_URI", os.path.join(s3prefix, "outputs.s3.json")),
flag_temporary_file=False
os.environ.get("WDL_OUTPUT_URI", os.path.join(s3prefix, "outputs.s3.json"))
)


_s3parcp_lock = threading.Lock()


-def s3cp(logger, fn, s3uri, flag_temporary_file=False):
+def s3cp(logger, fn, s3uri):
with _s3parcp_lock:
# when uploading many small outputs from the same pipeline you end up with a
# quick intense burst of load that can bump into the S3 rate limit
@@ -372,5 +346,3 @@
)
)
raise WDL.Error.RuntimeError("failed: " + " ".join(cmd))
-if flag_temporary_file:
-    flag_temporary(s3uri)
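Since intermediates are now marked with the intermediate_output=true tag (the same key/value the removed remove_temporary_flag looked for), a run can be spot-checked afterwards by reading tags back with get_object_tagging, the call the old code already used. A small helper sketch; the URI in the usage comment is hypothetical.

import boto3
from urllib.parse import urlparse


def is_intermediate(s3uri: str) -> bool:
    """Return True if the object carries the intermediate_output=true tag."""
    s3 = boto3.client("s3")
    uri = urlparse(s3uri)
    tags = s3.get_object_tagging(Bucket=uri.hostname, Key=uri.path[1:])
    return any(
        t["Key"] == "intermediate_output" and t["Value"] == "true"
        for t in tags["TagSet"]
    )


# Example (hypothetical URI):
# is_intermediate("s3://my-bucket/runs/sample-1/intermediate.txt")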
2 changes: 1 addition & 1 deletion version
@@ -1 +1 @@
v1.4.5
v1.4.6
