Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix FlyteFS #2208

Merged
merged 17 commits into from
Mar 12, 2024
32 changes: 19 additions & 13 deletions flytekit/clients/friendly.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,6 +987,7 @@
filename: typing.Optional[str] = None,
expires_in: typing.Optional[datetime.timedelta] = None,
filename_root: typing.Optional[str] = None,
add_content_md5_metadata: bool = True,
) -> _data_proxy_pb2.CreateUploadLocationResponse:
"""
Get a signed url to be used during fast registration
Expand All @@ -1000,22 +1001,27 @@
the generated url
:param filename_root: If provided will be used as the root of the filename. If not, Admin will use a hash
This option is useful when uploading a series of files that you want to be grouped together.
:param add_content_md5_metadata: If true, the content md5 will be added to the metadata in signed URL
:rtype: flyteidl.service.dataproxy_pb2.CreateUploadLocationResponse
"""
expires_in_pb = None
if expires_in:
expires_in_pb = Duration()
expires_in_pb.FromTimedelta(expires_in)
return super(SynchronousFlyteClient, self).create_upload_location(
_data_proxy_pb2.CreateUploadLocationRequest(
project=project,
domain=domain,
content_md5=content_md5,
filename=filename,
expires_in=expires_in_pb,
filename_root=filename_root,
try:
expires_in_pb = None
if expires_in:
expires_in_pb = Duration()
expires_in_pb.FromTimedelta(expires_in)
return super(SynchronousFlyteClient, self).create_upload_location(
_data_proxy_pb2.CreateUploadLocationRequest(
project=project,
domain=domain,
content_md5=content_md5,
filename=filename,
expires_in=expires_in_pb,
filename_root=filename_root,
add_content_md5_metadata=add_content_md5_metadata,
)
)
)
except Exception as e:
raise RuntimeError(f"Failed to get signed url for {filename}, reason: {e}")

Check warning on line 1024 in flytekit/clients/friendly.py

View check run for this annotation

Codecov / codecov/patch

flytekit/clients/friendly.py#L1023-L1024

Added lines #L1023 - L1024 were not covered by tests

def get_download_signed_url(
self, native_url: str, expires_in: datetime.timedelta = None
Expand Down
4 changes: 2 additions & 2 deletions flytekit/core/data_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,10 +217,10 @@ def recursive_paths(f: str, t: str) -> typing.Tuple[str, str]:
if get_protocol(f) == "file":
local_fs = fsspec.filesystem("file")
if local_fs.exists(f) and local_fs.isdir(f):
print("Adding trailing sep to")
logger.debug("Adding trailing sep to")
f = os.path.join(f, "")
else:
print("Not adding trailing sep")
logger.debug("Not adding trailing sep")
else:
f = os.path.join(f, "")
t = os.path.join(t, "")
Expand Down
4 changes: 4 additions & 0 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,13 +806,15 @@ def upload_file(
to_upload: pathlib.Path,
project: typing.Optional[str] = None,
domain: typing.Optional[str] = None,
filename_root: typing.Optional[str] = None,
pingsutw marked this conversation as resolved.
Show resolved Hide resolved
) -> typing.Tuple[bytes, str]:
"""
Function will use remote's client to hash and then upload the file using Admin's data proxy service.

:param to_upload: Must be a single file
:param project: Project to upload under, if not supplied will use the remote's default
:param domain: Domain to upload under, if not specified will use the remote's default
:param filename_root: If provided will be used as the root of the filename. If not, Admin will use a hash
:return: The uploaded location.
"""
if not to_upload.is_file():
Expand All @@ -825,9 +827,11 @@ def upload_file(
domain=domain or self.default_domain,
content_md5=md5_bytes,
filename=to_upload.name,
filename_root=filename_root,
)

extra_headers = self.get_extra_headers_for_protocol(upload_location.native_url)
extra_headers.update(upload_location.headers)
encoded_md5 = b64encode(md5_bytes)
with open(str(to_upload), "+rb") as local_file:
content = local_file.read()
Expand Down
58 changes: 6 additions & 52 deletions flytekit/remote/remote_fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,10 @@
import random
import threading
import typing
from base64 import b64encode
from uuid import UUID

import fsspec
import requests
from flyteidl.service.dataproxy_pb2 import CreateUploadLocationResponse
from fsspec.callbacks import NoOpCallback
from fsspec.implementations.http import HTTPFileSystem
from fsspec.utils import get_protocol
Expand Down Expand Up @@ -73,13 +71,6 @@
self.buffer.seek(0)
data = self.buffer.read()

# h = hashlib.md5()
# h.update(data)
# md5 = h.digest()
# l = len(data)
#
# headers = {"Content-Length": str(l), "Content-MD5": md5}

try:
res = self._remote.client.get_upload_signed_url(
self._remote.default_project,
Expand Down Expand Up @@ -132,32 +123,6 @@
"""
raise NotImplementedError("FlyteFS currently doesn't support downloading files.")

def get_upload_link(
self,
local_file_path: str,
remote_file_part: str,
prefix: str,
hashes: HashStructure,
) -> typing.Tuple[CreateUploadLocationResponse, int, bytes]:
if not pathlib.Path(local_file_path).exists():
raise AssertionError(f"File {local_file_path} does not exist")

p = pathlib.Path(typing.cast(str, local_file_path))
k = str(p.absolute())
if k in hashes:
md5_bytes, content_length = hashes[k]
else:
raise AssertionError(f"File {local_file_path} not found in hashes")
upload_response = self._remote.client.get_upload_signed_url(
self._remote.default_project,
self._remote.default_domain,
md5_bytes,
remote_file_part,
filename_root=prefix,
)
logger.debug(f"Resolved signed url {local_file_path} to {upload_response.native_url}")
return upload_response, content_length, md5_bytes

async def _put_file(
self,
lpath,
Expand All @@ -171,20 +136,11 @@
fsspec will call this method to upload a file. If recursive, rpath will already be individual files.
Make the request and upload, but then how do we get the s3 paths back to the user?
"""
# remove from kwargs otherwise super() call will fail
p = kwargs.pop(_PREFIX_KEY)
hashes = kwargs.pop(_HASHES_KEY)
# Parse rpath, strip out everything that doesn't make sense.
rpath = rpath.replace(f"{REMOTE_PLACEHOLDER}/", "", 1)
resp, content_length, md5_bytes = self.get_upload_link(lpath, rpath, p, hashes)

headers = {"Content-Length": str(content_length), "Content-MD5": b64encode(md5_bytes).decode("utf-8")}
kwargs["headers"] = headers
rpath = resp.signed_url
FlytePathResolver.add_mapping(rpath, resp.native_url)
logger.debug(f"Writing {lpath} to {rpath}")
await super()._put_file(lpath, rpath, chunk_size, callback=callback, method=method, **kwargs)
return resp.native_url
prefix = kwargs.pop(_PREFIX_KEY)
_, native_url = self._remote.upload_file(

Check warning on line 140 in flytekit/remote/remote_fs.py

View check run for this annotation

Codecov / codecov/patch

flytekit/remote/remote_fs.py#L139-L140

Added lines #L139 - L140 were not covered by tests
pathlib.Path(lpath), self._remote.default_project, self._remote.default_domain, prefix
)
return native_url

Check warning on line 143 in flytekit/remote/remote_fs.py

View check run for this annotation

Codecov / codecov/patch

flytekit/remote/remote_fs.py#L143

Added line #L143 was not covered by tests

@staticmethod
def extract_common(native_urls: typing.List[str]) -> str:
Expand Down Expand Up @@ -266,9 +222,6 @@
cp file.txt flyte://data/...
rpath gets ignored, so it doesn't matter what it is.
"""
if rpath != REMOTE_PLACEHOLDER:
logger.debug(f"FlyteFS doesn't yet support specifying full remote path, ignoring {rpath}")

# Hash everything at the top level
file_info = self.get_hashes_and_lengths(pathlib.Path(lpath))
prefix = self.get_filename_root(file_info)
Expand All @@ -278,6 +231,7 @@
res = await super()._put(lpath, REMOTE_PLACEHOLDER, recursive, callback, batch_size, **kwargs)
if isinstance(res, list):
res = self.extract_common(res)
FlytePathResolver.add_mapping(rpath.strip(os.path.sep), res)

Check warning on line 234 in flytekit/remote/remote_fs.py

View check run for this annotation

Codecov / codecov/patch

flytekit/remote/remote_fs.py#L234

Added line #L234 was not covered by tests
return res

async def _isdir(self, path):
Expand Down
4 changes: 2 additions & 2 deletions tests/flytekit/unit/clients/test_friendly.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ def test_list_projects_paginated(mock_raw_list_projects):
@_mock.patch("flytekit.clients.friendly._RawSynchronousFlyteClient.create_upload_location")
def test_create_upload_location(mock_raw_create_upload_location):
client = _SynchronousFlyteClient(PlatformConfig.for_endpoint("a.b.com", True))
client.get_upload_signed_url("foo", "bar", bytes(), "baz.qux", timedelta(minutes=42))
client.get_upload_signed_url("foo", "bar", bytes(), "baz.qux", timedelta(minutes=42), add_content_md5_metadata=True)
duration_pb = Duration()
duration_pb.FromTimedelta(timedelta(minutes=42))
create_upload_location_request = _data_proxy_pb2.CreateUploadLocationRequest(
project="foo", domain="bar", filename="baz.qux", expires_in=duration_pb
project="foo", domain="bar", filename="baz.qux", expires_in=duration_pb, add_content_md5_metadata=True
)
mock_raw_create_upload_location.assert_called_with(create_upload_location_request)
Loading