Skip to content

Commit

Permalink
Merge pull request #413 from allenai/shanea/storage-cleaner-s3-upload-cleanup
Browse files Browse the repository at this point in the history

[Storage cleaner] S3 upload cleanup
  • Loading branch information
2015aroras authored Jan 22, 2024
2 parents 3053bfa + 5c7d9c6 commit 98425a5
Showing 1 changed file with 11 additions and 13 deletions.
24 changes: 11 additions & 13 deletions scripts/storage_cleaner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from argparse import ArgumentParser, _SubParsersAction
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urlparse
Expand All @@ -18,11 +17,12 @@
import google.cloud.storage as gcs
import torch
import wandb
from boto3.s3.transfer import TransferConfig
from cached_path import add_scheme_client, cached_path, set_cache_dir
from cached_path.schemes import S3Client
from google.api_core.exceptions import NotFound
from omegaconf import OmegaConf as om
from rich.progress import Progress, TaskID, track
from rich.progress import track

from olmo import util
from olmo.aliases import PathOrStr
Expand Down Expand Up @@ -579,30 +579,28 @@ def download_folder(self, directory_path: str, local_dest_folder: PathOrStr):
else:
raise ValueError(f"Path {directory_path} is not a valid directory")

# Upload a single local file to `key` in `bucket_name` through the adapter's S3 client.
# Both branches of `upload` route through this helper, so transfer settings live in one place.
def _upload_file(self, local_filepath: str, bucket_name: str, key: str):
# max_concurrency=4 bounds the number of threads boto3 uses for this one transfer.
transfer_config = TransferConfig(max_concurrency=4)
self._s3_client.upload_file(local_filepath, bucket_name, key, Config=transfer_config)

# Upload `local_src` to `dest_path`: a single file is uploaded to the bucket/key parsed from
# `dest_path`; a directory is walked recursively and each file is uploaded under `dest_path`.
# Raises ValueError when `local_src` is neither a file nor a directory.
# NOTE(review): this span is a diff hunk without +/- markers, so superseded lines (direct
# `self._s3_client.upload_file` calls, the per-file `Progress` bar and its callback) appear
# next to their replacements (`self._upload_file`, `track`) — confirm against the real file.
def upload(self, local_src: PathOrStr, dest_path: str):
if self.local_fs_adapter.is_file(str(local_src)):
bucket_name, key = self._get_bucket_name_and_key(dest_path)
self._s3_client.upload_file(str(local_src), bucket_name, key)
self._upload_file(str(local_src), bucket_name, key)

elif self.local_fs_adapter.is_dir(str(local_src)):
local_src = Path(local_src)

def upload_callback(progress: Progress, upload_task: TaskID, bytes_uploaded: int):
progress.update(upload_task, advance=bytes_uploaded)

for file_local_path in local_src.rglob("*"):
# The recursive listing is materialized into a list before being handed to `track`.
local_file_paths = list(local_src.rglob("*"))
for file_local_path in track(local_file_paths, description=f"Uploading to {dest_path}"):
if file_local_path.is_dir():
continue

# Map the local path to its destination by swapping the source-root text for `dest_path`.
# NOTE(review): `str.replace` substitutes every occurrence of the source-root string, not
# only the leading prefix — verify paths that repeat the root string map correctly.
file_dest_path = str(file_local_path).replace(str(local_src).rstrip("/"), dest_path.rstrip("/"))
bucket_name, key = self._get_bucket_name_and_key(file_dest_path)

with Progress(transient=True) as progress:
size_in_bytes = file_local_path.stat().st_size
upload_task = progress.add_task(f"Uploading {key}", total=size_in_bytes)
callback = partial(upload_callback, progress, upload_task)

self._s3_client.upload_file(str(file_local_path), bucket_name, key, Callback=callback)
# Upload only when `_is_file` reports false for this bucket/key.
# NOTE(review): presumably this skips objects already present at the destination, which
# would leave a stale or changed remote object untouched — confirm that is intended.
if not self._is_file(bucket_name, key):
self._upload_file(str(file_local_path), bucket_name, key)

else:
raise ValueError(f"Local source {local_src} does not correspond to a valid file or directory")
Expand Down

0 comments on commit 98425a5

Please sign in to comment.