diff --git a/flytekit/clients/friendly.py b/flytekit/clients/friendly.py index 3a516fc1d2..58038d12ec 100644 --- a/flytekit/clients/friendly.py +++ b/flytekit/clients/friendly.py @@ -987,6 +987,7 @@ def get_upload_signed_url( filename: typing.Optional[str] = None, expires_in: typing.Optional[datetime.timedelta] = None, filename_root: typing.Optional[str] = None, + add_content_md5_metadata: bool = True, ) -> _data_proxy_pb2.CreateUploadLocationResponse: """ Get a signed url to be used during fast registration @@ -1000,22 +1001,27 @@ def get_upload_signed_url( the generated url :param filename_root: If provided will be used as the root of the filename. If not, Admin will use a hash This option is useful when uploading a series of files that you want to be grouped together. + :param add_content_md5_metadata: If true, the content md5 will be added to the metadata in signed URL :rtype: flyteidl.service.dataproxy_pb2.CreateUploadLocationResponse """ - expires_in_pb = None - if expires_in: - expires_in_pb = Duration() - expires_in_pb.FromTimedelta(expires_in) - return super(SynchronousFlyteClient, self).create_upload_location( - _data_proxy_pb2.CreateUploadLocationRequest( - project=project, - domain=domain, - content_md5=content_md5, - filename=filename, - expires_in=expires_in_pb, - filename_root=filename_root, + try: + expires_in_pb = None + if expires_in: + expires_in_pb = Duration() + expires_in_pb.FromTimedelta(expires_in) + return super(SynchronousFlyteClient, self).create_upload_location( + _data_proxy_pb2.CreateUploadLocationRequest( + project=project, + domain=domain, + content_md5=content_md5, + filename=filename, + expires_in=expires_in_pb, + filename_root=filename_root, + add_content_md5_metadata=add_content_md5_metadata, + ) ) - ) + except Exception as e: + raise RuntimeError(f"Failed to get signed url for {filename}, reason: {e}") def get_download_signed_url( self, native_url: str, expires_in: datetime.timedelta = None diff --git 
a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index f07fc727e6..d77145c950 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -217,10 +217,10 @@ def recursive_paths(f: str, t: str) -> typing.Tuple[str, str]: if get_protocol(f) == "file": local_fs = fsspec.filesystem("file") if local_fs.exists(f) and local_fs.isdir(f): - print("Adding trailing sep to") + logger.debug("Adding trailing sep to") f = os.path.join(f, "") else: - print("Not adding trailing sep") + logger.debug("Not adding trailing sep") else: f = os.path.join(f, "") t = os.path.join(t, "") diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 076e73efad..40715de1d2 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -806,6 +806,7 @@ def upload_file( to_upload: pathlib.Path, project: typing.Optional[str] = None, domain: typing.Optional[str] = None, + filename_root: typing.Optional[str] = None, ) -> typing.Tuple[bytes, str]: """ Function will use remote's client to hash and then upload the file using Admin's data proxy service. @@ -813,6 +814,7 @@ def upload_file( :param to_upload: Must be a single file :param project: Project to upload under, if not supplied will use the remote's default :param domain: Domain to upload under, if not specified will use the remote's default + :param filename_root: If provided will be used as the root of the filename. If not, Admin will use a hash :return: The uploaded location. 
""" if not to_upload.is_file(): @@ -825,9 +827,11 @@ def upload_file( domain=domain or self.default_domain, content_md5=md5_bytes, filename=to_upload.name, + filename_root=filename_root, ) extra_headers = self.get_extra_headers_for_protocol(upload_location.native_url) + extra_headers.update(upload_location.headers) encoded_md5 = b64encode(md5_bytes) with open(str(to_upload), "+rb") as local_file: content = local_file.read() diff --git a/flytekit/remote/remote_fs.py b/flytekit/remote/remote_fs.py index 4eb9d8ebc6..10131f63fa 100644 --- a/flytekit/remote/remote_fs.py +++ b/flytekit/remote/remote_fs.py @@ -7,12 +7,10 @@ import random import threading import typing -from base64 import b64encode from uuid import UUID import fsspec import requests -from flyteidl.service.dataproxy_pb2 import CreateUploadLocationResponse from fsspec.callbacks import NoOpCallback from fsspec.implementations.http import HTTPFileSystem from fsspec.utils import get_protocol @@ -73,13 +71,6 @@ def _upload_chunk(self, final=False): self.buffer.seek(0) data = self.buffer.read() - # h = hashlib.md5() - # h.update(data) - # md5 = h.digest() - # l = len(data) - # - # headers = {"Content-Length": str(l), "Content-MD5": md5} - try: res = self._remote.client.get_upload_signed_url( self._remote.default_project, @@ -132,32 +123,6 @@ async def _get_file(self, rpath, lpath, **kwargs): """ raise NotImplementedError("FlyteFS currently doesn't support downloading files.") - def get_upload_link( - self, - local_file_path: str, - remote_file_part: str, - prefix: str, - hashes: HashStructure, - ) -> typing.Tuple[CreateUploadLocationResponse, int, bytes]: - if not pathlib.Path(local_file_path).exists(): - raise AssertionError(f"File {local_file_path} does not exist") - - p = pathlib.Path(typing.cast(str, local_file_path)) - k = str(p.absolute()) - if k in hashes: - md5_bytes, content_length = hashes[k] - else: - raise AssertionError(f"File {local_file_path} not found in hashes") - upload_response = 
self._remote.client.get_upload_signed_url( - self._remote.default_project, - self._remote.default_domain, - md5_bytes, - remote_file_part, - filename_root=prefix, - ) - logger.debug(f"Resolved signed url {local_file_path} to {upload_response.native_url}") - return upload_response, content_length, md5_bytes - async def _put_file( self, lpath, @@ -171,20 +136,11 @@ async def _put_file( fsspec will call this method to upload a file. If recursive, rpath will already be individual files. Make the request and upload, but then how do we get the s3 paths back to the user? """ - # remove from kwargs otherwise super() call will fail - p = kwargs.pop(_PREFIX_KEY) - hashes = kwargs.pop(_HASHES_KEY) - # Parse rpath, strip out everything that doesn't make sense. - rpath = rpath.replace(f"{REMOTE_PLACEHOLDER}/", "", 1) - resp, content_length, md5_bytes = self.get_upload_link(lpath, rpath, p, hashes) - - headers = {"Content-Length": str(content_length), "Content-MD5": b64encode(md5_bytes).decode("utf-8")} - kwargs["headers"] = headers - rpath = resp.signed_url - FlytePathResolver.add_mapping(rpath, resp.native_url) - logger.debug(f"Writing {lpath} to {rpath}") - await super()._put_file(lpath, rpath, chunk_size, callback=callback, method=method, **kwargs) - return resp.native_url + prefix = kwargs.pop(_PREFIX_KEY) + _, native_url = self._remote.upload_file( + pathlib.Path(lpath), self._remote.default_project, self._remote.default_domain, prefix + ) + return native_url @staticmethod def extract_common(native_urls: typing.List[str]) -> str: @@ -266,9 +222,6 @@ async def _put( cp file.txt flyte://data/... rpath gets ignored, so it doesn't matter what it is. 
""" - if rpath != REMOTE_PLACEHOLDER: - logger.debug(f"FlyteFS doesn't yet support specifying full remote path, ignoring {rpath}") - # Hash everything at the top level file_info = self.get_hashes_and_lengths(pathlib.Path(lpath)) prefix = self.get_filename_root(file_info) @@ -278,6 +231,7 @@ async def _put( res = await super()._put(lpath, REMOTE_PLACEHOLDER, recursive, callback, batch_size, **kwargs) if isinstance(res, list): res = self.extract_common(res) + FlytePathResolver.add_mapping(rpath.strip(os.path.sep), res) return res async def _isdir(self, path): diff --git a/tests/flytekit/unit/clients/test_friendly.py b/tests/flytekit/unit/clients/test_friendly.py index 94cae04c25..8d1096e764 100644 --- a/tests/flytekit/unit/clients/test_friendly.py +++ b/tests/flytekit/unit/clients/test_friendly.py @@ -29,10 +29,10 @@ def test_list_projects_paginated(mock_raw_list_projects): @_mock.patch("flytekit.clients.friendly._RawSynchronousFlyteClient.create_upload_location") def test_create_upload_location(mock_raw_create_upload_location): client = _SynchronousFlyteClient(PlatformConfig.for_endpoint("a.b.com", True)) - client.get_upload_signed_url("foo", "bar", bytes(), "baz.qux", timedelta(minutes=42)) + client.get_upload_signed_url("foo", "bar", bytes(), "baz.qux", timedelta(minutes=42), add_content_md5_metadata=True) duration_pb = Duration() duration_pb.FromTimedelta(timedelta(minutes=42)) create_upload_location_request = _data_proxy_pb2.CreateUploadLocationRequest( - project="foo", domain="bar", filename="baz.qux", expires_in=duration_pb + project="foo", domain="bar", filename="baz.qux", expires_in=duration_pb, add_content_md5_metadata=True ) mock_raw_create_upload_location.assert_called_with(create_upload_location_request)