request bytes range directly from S3
epwalsh committed Aug 7, 2023
1 parent cad86b8 commit 53fcdfd
Showing 2 changed files with 42 additions and 7 deletions.
olmo/data/memmap_dataset.py (5 additions, 7 deletions)
@@ -5,11 +5,10 @@
 
 import numpy as np
 import torch
-from cached_path import cached_path
 from torch.utils.data import Dataset
 
 from ..aliases import PathOrStr
-from ..util import file_size
+from ..util import file_size, get_bytes_range
 
 __all__ = ["MemMapDataset"]
 
@@ -78,11 +77,10 @@ def offsets(self) -> List[Tuple[int, int]]:
         return self._mmap_offsets
 
     def _read_chunk_from_memmap(self, path: PathOrStr, index: int) -> torch.Tensor:
-        index_start = index * self._item_size * self._chunk_size
-        with open(cached_path(path), "rb") as f:
-            f.seek(index_start)
-            buffer = f.read(self._item_size * self._chunk_size)
-        array = np.frombuffer(buffer, dtype=self.dtype)
+        bytes_start = index * self._item_size * self._chunk_size
+        num_bytes = self._item_size * self._chunk_size
+        buffer = get_bytes_range(path, bytes_start, num_bytes)
+        array = np.frombuffer(buffer, dtype=self.dtype)
         return torch.tensor(array.astype(np.int_), dtype=torch.long)
 
     def __len__(self) -> int:
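In effect, _read_chunk_from_memmap no longer downloads the whole data file through cached_path and seeks locally; it asks the storage backend for exactly the bytes of one chunk. A minimal sketch of the offset arithmetic, with an illustrative dtype and chunk size (not values taken from the commit):

    import numpy as np

    # Illustrative: 2-byte tokens (uint16), 1024 tokens per chunk.
    item_size = np.dtype(np.uint16).itemsize  # 2
    chunk_size = 1024
    index = 3  # fourth chunk in the file

    bytes_start = index * item_size * chunk_size  # 3 * 2 * 1024 = 6144
    num_bytes = item_size * chunk_size            # 2 * 1024 = 2048
    # get_bytes_range(path, 6144, 2048) fetches tokens 3072..4095 as raw bytes.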
olmo/util.py (37 additions, 0 deletions)
@@ -424,6 +424,28 @@ def upload(source: PathOrStr, target: str, save_overwrite: bool = False):
         raise NotImplementedError(f"Upload not implemented for '{parsed.scheme}' scheme")
 
 
+def get_bytes_range(source: PathOrStr, bytes_start: int, num_bytes: int) -> bytes:
+    if is_url(source):
+        from urllib.parse import urlparse
+
+        parsed = urlparse(str(source))
+        if parsed.scheme == "gs":
+            from cached_path import cached_path
+
+            # TODO: directly request range from GCS.
+            return get_bytes_range(cached_path(source), bytes_start, num_bytes)
+        elif parsed.scheme == "s3":
+            return _s3_get_bytes_range(parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes)
+        elif parsed.scheme == "file":
+            return get_bytes_range(str(source).replace("file://", "", 1), bytes_start, num_bytes)
+        else:
+            raise NotImplementedError(f"get_bytes_range not implemented for '{parsed.scheme}' files")
+    else:
+        with open(source, "rb") as f:
+            f.seek(bytes_start)
+            return f.read(num_bytes)
+
+
 def _gcs_file_size(bucket_name: str, key: str) -> int:
     from google.api_core.exceptions import NotFound
     from google.cloud import storage as gcs
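A quick usage sketch for the new helper (the bucket and file paths here are hypothetical):

    import numpy as np
    from olmo.util import get_bytes_range

    # Hypothetical locations; both calls return the same 2048 raw bytes.
    buf = get_bytes_range("s3://my-bucket/data/part-000.npy", 6144, 2048)
    buf = get_bytes_range("/local/data/part-000.npy", 6144, 2048)
    tokens = np.frombuffer(buf, dtype=np.uint16)  # 1024 uint16 tokens

Note that the gs scheme still falls back to downloading the whole object via cached_path (hence the TODO); only s3, file, and plain local paths avoid the full download.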
@@ -478,6 +500,21 @@ def _s3_file_size(bucket_name: str, key: str) -> int:
         raise FileNotFoundError(f"s3://{bucket_name}/{key}")
 
 
+def _s3_get_bytes_range(bucket_name: str, key: str, bytes_start: int, num_bytes: int) -> bytes:
+    import boto3
+    from botocore.exceptions import ClientError
+
+    s3_client = boto3.client("s3")
+    try:
+        return s3_client.get_object(
+            Bucket=bucket_name, Key=key, Range=f"bytes={bytes_start}-{bytes_start + num_bytes - 1}"
+        )["Body"].read()
+    except ClientError as e:
+        if int(e.response["Error"]["Code"]) != 404:
+            raise
+        raise FileNotFoundError(f"s3://{bucket_name}/{key}")
+
+
 def is_weight_decay_module(module: nn.Module) -> bool:
     """Returns true if the module should use weight decay."""
     from .model import LayerNormBase
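One detail worth flagging: the S3 Range header follows HTTP range semantics, so bytes=start-end is inclusive on both ends, which is why the end offset is bytes_start + num_bytes - 1. For example:

    bytes_start, num_bytes = 6144, 2048
    range_header = f"bytes={bytes_start}-{bytes_start + num_bytes - 1}"
    # "bytes=6144-8191" -> exactly 2048 bytes returned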
