update for multi-storage drivers (#338)
* update for multi-storage drivers

* fix s3client range query param

* fix typo, remove dev log
jreadey authored Apr 8, 2024
1 parent 8ff4791 commit a63a44d
Showing 14 changed files with 409 additions and 182 deletions.
1 change: 1 addition & 0 deletions hsds/basenode.py
@@ -567,6 +567,7 @@ def baseInit(node_type):
     app["start_time"] = int(time.time())  # seconds after epoch
     app["register_time"] = 0
     app["max_task_count"] = config.get("max_task_count")
+    app["storage_clients"] = {}  # storage client drivers

     is_standalone = config.getCmdLineArg("standalone")
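The new app["storage_clients"] dict gives each storage protocol its own cached driver instance. A minimal sketch of how such a registry might be consulted; get_storage_client and make_client are hypothetical names for illustration, not part of this commit:

    def make_client(protocol):
        # hypothetical factory: the real code would return an S3, POSIX-file,
        # or Azure blob driver depending on the protocol
        return object()

    def get_storage_client(app, protocol="s3://"):
        # look up the cached driver for this protocol, creating it on first use
        clients = app["storage_clients"]
        if protocol not in clients:
            clients[protocol] = make_client(protocol)
        return clients[protocol]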
7 changes: 6 additions & 1 deletion hsds/chunk_dn.py
@@ -279,7 +279,12 @@ async def PUT_Chunk(request):
         log.error(msg)
         raise HTTPInternalServerError()

-    input_arr = bytesToArray(input_bytes, select_dt, [num_elements, ])
+    try:
+        input_arr = bytesToArray(input_bytes, select_dt, [num_elements, ])
+    except ValueError as ve:
+        log.error(f"bytesToArray threw ValueError: {ve}")
+        raise HTTPInternalServerError()
+
     if bcshape:
         input_arr = input_arr.reshape(bcshape)
         log.debug(f"broadcasting {bcshape} to mshape {mshape}")
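The new guard matters because a request body whose byte count doesn't match the selection's dtype and element count cannot be converted to an array. A small self-contained sketch of the failure mode using plain numpy, which raises ValueError on a size mismatch just as bytesToArray does here:

    import numpy as np

    data = b"\x00" * 10  # pretend this arrived as the PUT body
    try:
        arr = np.frombuffer(data, dtype=np.int32)  # 10 is not a multiple of 4
    except ValueError as ve:
        print(f"conversion failed: {ve}")  # mapped to a 500 response above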
18 changes: 12 additions & 6 deletions hsds/datanode_lib.py
@@ -334,7 +334,9 @@ async def get_metadata_obj(app, obj_id, bucket=None):
     if not bucket:
         bucket = domain_bucket

-    if not bucket:
+    if bucket:
+        log.debug(f"get_metadata_obj - using bucket: {bucket}")
+    else:
         log.warn("get_metadata_obj - bucket is None")

     # don't call validateInPartition since this is used to pull in
@@ -355,6 +357,7 @@ async def get_metadata_obj(app, obj_id, bucket=None):
         obj_json = meta_cache[obj_id]
     else:
         s3_key = getS3Key(obj_id)
+        log.debug(f"get_metadata_obj - using s3_key: {s3_key}")
         pending_s3_read = app["pending_s3_read"]
         if obj_id in pending_s3_read:
             # already a read in progress, wait for it to complete
@@ -364,12 +367,10 @@ async def get_metadata_obj(app, obj_id, bucket=None):
             log.info(msg)
             store_read_timeout = float(config.get("store_read_timeout", default=2.0))
             log.debug(f"store_read_timeout: {store_read_timeout}")
-            store_read_sleep_interval = float(
-                config.get("store_read_sleep_interval", default=0.1)
-            )
+            store_read_sleep = float(config.get("store_read_sleep_interval", default=0.1))
             while time.time() - read_start_time < store_read_timeout:
                 log.debug(f"waiting for pending s3 read {s3_key}, sleeping")
-                await asyncio.sleep(store_read_sleep_interval)  # sleep for sub-second?
+                await asyncio.sleep(store_read_sleep)
                 if obj_id in meta_cache:
                     log.info(f"object {obj_id} has arrived!")
                     obj_json = meta_cache[obj_id]
@@ -1027,13 +1028,18 @@ async def get_chunk(
     log.debug(f"filter_ops: {filter_ops}")

     if s3path:
+        if s3path.startswith("s3://"):
+            bucket = "s3://"
+        else:
+            bucket = ""
         try:
-            bucket = getBucketFromStorURI(s3path)
+            bucket += getBucketFromStorURI(s3path)
             s3key = getKeyFromStorURI(s3path)
         except ValueError as ve:
             log.error(f"Invalid URI path: {s3path} exception: {ve}")
+            raise
             # raise HTTPInternalServerError()

         msg = f"Using s3path bucket: {bucket} and s3key: {s3key} "
         msg += f"offset: {s3offset} length: {s3size}"
         log.debug(msg)
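The consolidated store_read_sleep setting drives a simple poll-and-sleep loop: wait in short intervals until a concurrent read fills the cache or the timeout expires. A self-contained sketch of the same pattern; the names are illustrative, not the HSDS API:

    import asyncio
    import time

    async def wait_for_cached(cache: dict, key: str, timeout=2.0, interval=0.1):
        # poll the cache in short sleeps until the pending read lands
        start = time.time()
        while time.time() - start < timeout:
            if key in cache:
                return cache[key]
            await asyncio.sleep(interval)
        raise TimeoutError(f"pending read of {key} did not complete in {timeout}s")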
127 changes: 88 additions & 39 deletions hsds/util/domainUtil.py
@@ -51,22 +51,103 @@ def isIPAddress(s):
     return True


+def _stripProtocol(uri):
+    """ returns part of the uri or bucket name after any protocol specification:
+        'xyz://' or 'https://myaccount.blob.core.windows.net/'
+    """
+
+    if not uri or uri.startswith("/"):
+        return uri
+
+    n = uri.find("://")
+    if n < 0:
+        return uri
+
+    uri = uri[(n + 3):]
+    parts = uri.split("/")
+    if len(parts) == 1:
+        return uri
+    if parts[0].endswith(".blob.core.windows.net"):
+        # part of the URI to indicate azure blob storage, skip it
+        parts = parts[1:]
+    return "/".join(parts)
+
+
+def isValidBucketName(bucket):
+    """
+    Check whether the given bucket name is valid
+    """
+    is_valid = True
+
+    if bucket is None:
+        return True
+
+    bucket = _stripProtocol(bucket)
+
+    # Bucket names must contain at least 1 character
+    if len(bucket) < 1:
+        is_valid = False
+
+    # Bucket names can consist only of alphanumeric characters, underscores, dots, and hyphens
+    if not re.fullmatch("[a-zA-Z0-9_\\.\\-]+", bucket):
+        is_valid = False
+
+    return is_valid
+
+
 def getBucketForDomain(domain):
     """get the bucket for the domain or None
     if no bucket is given
     """
     if not domain:
         return None
-    if domain[0] == "/":
+
+    # strip s3://, file://, etc
+    domain_path = _stripProtocol(domain)
+    if domain_path.startswith("/"):
         # no bucket specified
         return None
-    index = domain.find("/")
-    if index < 0:
+
+    nchars = len(domain) - len(domain_path)
+    protocol = domain[:nchars]  # save this so we can re-attach to the bucket name
+
+    parts = domain_path.split("/")
+    if len(parts) < 2:
         # invalid domain?
         msg = f"invalid domain: {domain}"
         raise HTTPBadRequest(reason=msg)
-    if not isValidBucketName(domain[:index]):
+    bucket_name = parts[0]
+    if not isValidBucketName(bucket_name):
         return None
-    return domain[:index]
+
+    # fit back the protocol prefix if set
+    if protocol:
+        bucket = protocol
+    else:
+        bucket = ""
+    bucket += bucket_name
+    return bucket
+
+
+def getPathForDomain(domain):
+    """
+    Return the non-bucket part of the domain
+    """
+    if not domain:
+        return None
+
+    domain_path = _stripProtocol(domain)
+    if domain_path.startswith("/"):
+        # no bucket
+        return domain_path
+
+    nindex = domain_path.find("/")
+    if nindex > 0:
+        # don't include the bucket
+        domain_path = domain_path[nindex:]

+    return domain_path


 def getParentDomain(domain):
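Taken together, the rewritten helpers split a storage location into protocol, bucket, and path. A few illustrative inputs with the values the new code should produce; the expected results are read off the code above, not captured from a test run:

    # _stripProtocol drops the scheme, and the Azure account host if present:
    #   _stripProtocol("s3://mybucket/home/test")  -> "mybucket/home/test"
    #   _stripProtocol("https://myaccount.blob.core.windows.net/mycontainer/home")
    #                                              -> "mycontainer/home"
    #   _stripProtocol("/home/test")               -> "/home/test"  (unchanged)

    # getBucketForDomain re-attaches the protocol to the bucket it returns:
    #   getBucketForDomain("s3://mybucket/home/test")  -> "s3://mybucket"
    #   getBucketForDomain("/home/test")               -> None  (default bucket)

    # getPathForDomain returns the remainder:
    #   getPathForDomain("s3://mybucket/home/test")    -> "/home/test"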
@@ -161,7 +242,7 @@ def validateDomainPath(path):
     if len(path) < 1:
         raise ValueError("Domain path too short")
     if path == "/":
-        return  # default buckete, root folder
+        return  # default bucket, root folder
     if path[:-1].find("/") == -1:
         msg = "Domain path should have at least one '/' before trailing slash"
         raise ValueError(msg)
@@ -262,25 +343,13 @@ def getDomainFromRequest(request, validate=True, allow_dns=True):
         pass  # no bucket specified

     if bucket and validate:
-        if (bucket.find("/") >= 0) or (not isValidBucketName(bucket)):
+        if not isValidBucketName(bucket):
             raise ValueError(f"bucket name: {bucket} is not valid")
     if domain[0] == "/":
         domain = bucket + domain
     return domain


-def getPathForDomain(domain):
-    """
-    Return the non-bucket part of the domain
-    """
-    if not domain:
-        return None
-    index = domain.find("/")
-    if index < 1:
-        return domain  # no bucket
-    return domain[(index):]
-
-
 def verifyRoot(domain_json):
     """Throw bad request if we are expecting a domain,
     but got a folder instead
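The relaxed check works because isValidBucketName now strips any protocol prefix itself, so a bucket value like "s3://mybucket" passes while names with illegal characters still fail. Illustrative cases, read off the regex above:

    # isValidBucketName("mybucket")       -> True
    # isValidBucketName("s3://mybucket")  -> True   (protocol stripped first)
    # isValidBucketName("my bucket")      -> False  (space not allowed)
    # isValidBucketName(None)             -> True   (None means: default bucket)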
@@ -300,23 +369,3 @@ def getLimits():
     limits["max_request_size"] = int(config.get("max_request_size"))

     return limits
-
-
-def isValidBucketName(bucket):
-    """
-    Check whether the given bucket name is valid
-    """
-    is_valid = True
-
-    if bucket is None:
-        return True
-
-    # Bucket names must contain at least 1 character
-    if len(bucket) < 1:
-        is_valid = False
-
-    # Bucket names can consist only of alphanumeric characters, underscores, dots, and hyphens
-    if not re.fullmatch("[a-zA-Z0-9_\\.\\-]+", bucket):
-        is_valid = False
-
-    return is_valid
52 changes: 47 additions & 5 deletions hsds/util/idUtil.py
@@ -21,6 +21,43 @@
 from .. import hsds_logger as log


+S3_URI = "s3://"
+FILE_URI = "file://"
+AZURE_URI = "blob.core.windows.net/"  # preceded with "https://"
+
+
+def _getStorageProtocol(uri):
+    """ returns 's3://', 'file://', or 'https://...net/' prefix if present.
+        If the prefix is in the form: https://myaccount.blob.core.windows.net/mycontainer
+        (references Azure blob storage), return: https://myaccount.blob.core.windows.net/
+        otherwise None """
+
+    if not uri:
+        protocol = None
+    elif uri.startswith(S3_URI):
+        protocol = S3_URI
+    elif uri.startswith(FILE_URI):
+        protocol = FILE_URI
+    elif uri.startswith("https://") and uri.find(AZURE_URI) > 0:
+        n = uri.find(AZURE_URI) + len(AZURE_URI)
+        protocol = uri[:n]
+    elif uri.find("://") >= 0:
+        raise ValueError(f"storage uri: {uri} not supported")
+    else:
+        protocol = None
+    return protocol
+
+
+def _getBaseName(uri):
+    """ Return the part of the URI after the storage protocol (if any) """
+
+    protocol = _getStorageProtocol(uri)
+    if not protocol:
+        return uri
+    else:
+        return uri[len(protocol):]
+
+
 def getIdHash(id):
     """Return md5 prefix based on id value"""
     m = hashlib.new("md5")
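For reference, what the new helpers return for each supported URI style; expected values are read off the code above:

    # _getStorageProtocol("s3://mybucket/key")    -> "s3://"
    # _getStorageProtocol("file://data/key")      -> "file://"
    # _getStorageProtocol("https://myaccount.blob.core.windows.net/container/key")
    #     -> "https://myaccount.blob.core.windows.net/"
    # _getStorageProtocol("gcs://mybucket/key")   -> raises ValueError (unsupported)
    # _getStorageProtocol("g-abc123/chunk")       -> None  (no protocol)

    # _getBaseName strips whatever _getStorageProtocol found:
    # _getBaseName("s3://mybucket/key")           -> "mybucket/key"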
@@ -146,14 +183,19 @@ def getS3Key(id):
       Chunk ids have the chunk index added after the slash:
       "db/id[0:16]/d/id[16:32]/x_y_z

-    For domain id's return a key with the .domain suffix and no
-    preceeding slash
+    For domain id's:
+        Return a key with the .domain suffix and no preceding slash.
+        For non-default buckets, use the format: <bucket_name>/s3_key
+        If the id has a storage specifier ("s3://", "file://", etc.)
+        include that along with the bucket name.  e.g.: "s3://mybucket/a_folder/a_file.h5"
     """
-    if id.find("/") > 0:
+
+    base_id = _getBaseName(id)  # strip any s3://, etc.
+    if base_id.find("/") > 0:
         # a domain id
         domain_suffix = ".domain.json"
-        index = id.find("/") + 1
-        key = id[index:]
+        index = base_id.find("/") + 1
+        key = base_id[index:]
         if not key.endswith(domain_suffix):
             if key[-1] != "/":
                 key += "/"
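The effect of the _getBaseName call, based only on the lines shown (the rest of the function is truncated above):

    # in the branch shown, both of these ids produce the same intermediate key:
    #   getS3Key("s3://mybucket/home/test.domain.json")
    #   getS3Key("mybucket/home/test.domain.json")
    # base_id is "mybucket/home/test.domain.json" in both cases, so
    # key = "home/test.domain.json" once the bucket name is skipped;
    # per the docstring, the bucket name (and any storage specifier)
    # is folded back into the returned key for non-default buckets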
(Diff truncated: the remaining 9 changed files are not shown.)
