Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw committed Feb 27, 2024
1 parent 46713a3 commit 83593cb
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 13 deletions.
4 changes: 2 additions & 2 deletions flytekit/core/data_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,10 +217,10 @@ def recursive_paths(f: str, t: str) -> typing.Tuple[str, str]:
if get_protocol(f) == "file":
local_fs = fsspec.filesystem("file")
if local_fs.exists(f) and local_fs.isdir(f):
print("Adding trailing sep to")
logger.debug("Adding trailing sep to")
f = os.path.join(f, "")
else:
print("Not adding trailing sep")
logger.debug("Not adding trailing sep")
else:
f = os.path.join(f, "")
t = os.path.join(t, "")
Expand Down
36 changes: 25 additions & 11 deletions flytekit/remote/remote_fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from fsspec.implementations.http import HTTPFileSystem
from fsspec.utils import get_protocol

from flytekit.exceptions.user import FlyteValueException
from flytekit.loggers import logger
from flytekit.tools.script_mode import hash_file

Expand All @@ -26,6 +27,7 @@
_DEFAULT_CALLBACK = NoOpCallback()
_PREFIX_KEY = "upload_prefix"
_HASHES_KEY = "hashes"
_IS_RECURSIVE_KEY = "is_recursive"
# This file system is not really a filesystem, so users aren't really able to specify the remote path,
# at least not yet.
REMOTE_PLACEHOLDER = "flyte://data"
Expand Down Expand Up @@ -153,7 +155,7 @@ def get_upload_link(
self._remote.default_domain,
md5_bytes,
remote_file_part,
filename_root=prefix if os.path.isdir(local_file_path) else None,
filename_root=prefix,
)
logger.debug(f"Resolved signed url {local_file_path} to {upload_response.native_url}")
return upload_response, content_length, md5_bytes
Expand All @@ -180,11 +182,25 @@ async def _put_file(

headers = {"Content-Length": str(content_length), "Content-MD5": b64encode(md5_bytes).decode("utf-8")}
headers.update(self._remote.get_extra_headers_for_protocol(resp.native_url))

Check warning on line 184 in flytekit/remote/remote_fs.py

View check run for this annotation

Codecov / codecov/patch

flytekit/remote/remote_fs.py#L184

Added line #L184 was not covered by tests
kwargs["headers"] = headers
rpath = resp.signed_url
FlytePathResolver.add_mapping(rpath, resp.native_url)
logger.debug(f"Writing {lpath} to {rpath}")
await super()._put_file(lpath, rpath, chunk_size, callback=callback, method=method, **kwargs)

with open(str(lpath), "+rb") as local_file:
content = local_file.read()
rsp = requests.put(

Check warning on line 188 in flytekit/remote/remote_fs.py

View check run for this annotation

Codecov / codecov/patch

flytekit/remote/remote_fs.py#L187-L188

Added lines #L187 - L188 were not covered by tests
resp.signed_url,
data=content,
headers=headers,
verify=False
if self._remote.config.platform.insecure_skip_verify is True
else self._remote.config.platform.ca_cert_file_path,
)

# Check both HTTP 201 and 200, because some storage backends (e.g. Azure) return 201 instead of 200.
if rsp.status_code not in (requests.codes["OK"], requests.codes["created"]):
raise FlyteValueException(

Check warning on line 199 in flytekit/remote/remote_fs.py

View check run for this annotation

Codecov / codecov/patch

flytekit/remote/remote_fs.py#L199

Added line #L199 was not covered by tests
rsp.status_code,
f"Request to send data {rpath} failed.\nResponse: {rsp.text}",
)

return resp.native_url

@staticmethod
Expand Down Expand Up @@ -266,11 +282,7 @@ async def _put(
"""
cp file.txt flyte://data/...
rpath gets ignored, so it doesn't matter what it is.
"""
if rpath != REMOTE_PLACEHOLDER:
logger.debug(f"FlyteFS doesn't yet support specifying full remote path, ignoring {rpath}")

# Hash everything at the top level
""" # Hash everything at the top level
file_info = self.get_hashes_and_lengths(pathlib.Path(lpath))
prefix = self.get_filename_root(file_info)

Expand All @@ -279,6 +291,8 @@ async def _put(
res = await super()._put(lpath, REMOTE_PLACEHOLDER, recursive, callback, batch_size, **kwargs)
if isinstance(res, list):
res = self.extract_common(res)
FlytePathResolver.add_mapping(rpath.strip("/"), res)
print(f"aaaa Writing {lpath} to {res}")

Check warning on line 295 in flytekit/remote/remote_fs.py

View check run for this annotation

Codecov / codecov/patch

flytekit/remote/remote_fs.py#L294-L295

Added lines #L294 - L295 were not covered by tests
return res

async def _isdir(self, path):
Expand Down

0 comments on commit 83593cb

Please sign in to comment.