diff --git a/scripts/writing/bench.py b/scripts/writing/bench.py new file mode 100644 index 000000000..55115fb02 --- /dev/null +++ b/scripts/writing/bench.py @@ -0,0 +1,597 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Generate a synthetic dataset.""" + +import json +import os +from argparse import ArgumentParser, Namespace +from collections import defaultdict +from functools import partial +from shutil import rmtree +from time import time +from typing import Any, Dict, Iterable, List, Optional, Tuple, TypeVar + +import lance +import numpy as np +import pyarrow as pa +import pyspark +import pyspark.sql +from delta import configure_spark_with_delta_pip +from numpy.random import Generator +from pyarrow import parquet as pq +from pyspark.sql.types import IntegerType, StringType, StructField, StructType +from tqdm import tqdm +from wurlitzer import pipes + +from streaming import CSVWriter, JSONWriter, MDSWriter +from streaming.base.format.base.tabulation import Tabulator + + +def _parse_args() -> Namespace: + """Parse command-line arguments. + + Returns: + Namespace: Command-line arguments. + """ + args = ArgumentParser() + + # Reproducibility. + args.add_argument('--seed', type=int, default=1337) + + # Dataset data distribution. + args.add_argument('--data_pos_prob', type=float, default=0.75) + args.add_argument('--data_low', type=int, default=-1_000_000_000) + args.add_argument('--data_high', type=int, default=1_000_000_000) + + # Sizes of datasets splits and shards. + args.add_argument('--small', type=int, default=1 << 15) + args.add_argument('--medium', type=int, default=1 << 20) + args.add_argument('--large', type=int, default=1 << 25) + args.add_argument('--size_limit', type=int, default=1 << 23) + args.add_argument('--samples_per_shard', type=int, default=1 << 18) + + # Outputs. + args.add_argument('--data_root', type=str, default='data/backends/') + # args.add_argument('--formats', type=str, default='csv,delta,jsonl,lance,mds,parquet') + args.add_argument('--formats', type=str, default='csv,jsonl,mds') + + # Introspection. + args.add_argument('--show_progress', type=int, default=1) + args.add_argument('--quiet_delta', type=int, default=1) + + return args.parse_args() + + +def _generate_int(rng: Generator, + pos_prob: float = 0.75, + low: int = -1_000_000_000, + high: int = 1_000_000_000) -> int: + """Pick a random integer to say in words. + + This is a synthetic dataset whose random numbers need to be distinct, deterministic given a + seed, and little else. We choose a distribution that seems the most pleasing to us. + + Properties: + * About 80% positive and 20% negative. + * Magnitude of up to a billion on either side of zero. + * Strongly skewed toward the origin, i.e. chosen uniformly across base-10 digit lengths (at + least until running out of integers of that length anyway). + + Args: + rng (Generator): NumPy random number generator. + pos_prob (float): Probability of output being positive. Defaults to ``0.75``. + low (int): Minimum of output range. Must be negative. Defaults to ``-1_000_000_000``. + high (int): Maximum of output range. Must be positive. Defaults to ``1_000_000_000``. 
+    """
+    if not 0 <= pos_prob <= 1:
+        raise ValueError(f'Invalid positive probability ``pos_prob``: 0 <= {pos_prob} <= 1.')
+
+    if not low < 0 < high:
+        raise ValueError(f'Invalid sampling range ``low`` and/or ``high``: {low} < 0 < {high}.')
+
+    is_pos = rng.uniform() < pos_prob
+    max_digits = np.log10(high) if is_pos else np.log10(-low)
+    exponent = rng.uniform(0, max_digits)
+    magnitude = int(10**exponent)
+    sign = is_pos * 2 - 1
+    return sign * magnitude
+
+
+def _generate_ints(count: int,
+                   seed: int = 0x1337,
+                   pos_prob: float = 0.75,
+                   low: int = -1_000_000_000,
+                   high: int = 1_000_000_000,
+                   show_progress: bool = True) -> List[int]:
+    """Sample until we have the given number of distinct integers.
+
+    Args:
+        count (int): How many samples to draw.
+        seed (int): Seed for the random number generator. Defaults to ``0x1337``.
+        pos_prob (float): Probability of output being positive. Defaults to ``0.75``.
+        low (int): Minimum of output range. Must be negative. Defaults to ``-1_000_000_000``.
+        high (int): Maximum of output range. Must be positive. Defaults to ``1_000_000_000``.
+        show_progress (bool): Whether to display a progress bar. Defaults to ``True``.
+
+    Returns:
+        List[int]: The integers that were drawn.
+    """
+    rng = np.random.default_rng(seed)
+    nums = set()
+    progress_bar = tqdm(total=count, leave=False) if show_progress else None
+    while len(nums) < count:
+        num = _generate_int(rng, pos_prob, low, high)
+        if num in nums:
+            continue
+
+        nums.add(num)
+        if progress_bar:
+            progress_bar.update(1)
+    if progress_bar:
+        progress_bar.close()
+
+    nums = sorted(nums)
+    rng.shuffle(nums)
+    return nums
+
+
+_ones = ('zero one two three four five six seven eight nine ten eleven twelve thirteen fourteen '
+         'fifteen sixteen seventeen eighteen nineteen').split()
+
+_tens = 'twenty thirty forty fifty sixty seventy eighty ninety'.split()
+
+
+def _int_to_words(num: int) -> List[str]:
+    """Say an integer as a list of words.
+
+    Args:
+        num (int): The integer.
+
+    Returns:
+        List[str]: The integer as a list of words.
+    """
+    if num < 0:
+        return ['negative'] + _int_to_words(-num)
+    elif num <= 19:
+        return [_ones[num]]
+    elif num < 100:
+        tens = [_tens[num // 10 - 2]]
+        ones = [_ones[num % 10]] if num % 10 else []
+        return tens + ones
+    elif num < 1_000:
+        hundreds = [_ones[num // 100], 'hundred']
+        etc = _int_to_words(num % 100) if num % 100 else []
+        return hundreds + etc
+    elif num < 1_000_000:
+        thousands = _int_to_words(num // 1_000) + ['thousand']
+        etc = _int_to_words(num % 1_000) if num % 1_000 else []
+        return thousands + etc
+    elif num < 1_000_000_000:
+        millions = _int_to_words(num // 1_000_000) + ['million']
+        etc = _int_to_words(num % 1_000_000) if num % 1_000_000 else []
+        return millions + etc
+    else:
+        raise ValueError(f'Integer out of range: -1,000,000,000 < {num} < +1,000,000,000.')
+
+
+def _int_to_text(num: int) -> str:
+    """Say an integer as text.
+
+    Args:
+        num (int): The integer.
+
+    Returns:
+        str: The integer as text.
+    """
+    words = _int_to_words(num)
+    return ' '.join(words)
+
+
+T = TypeVar('T')
+
+
+def _split(items: List[T], sizes: List[int]) -> List[List[T]]:
+    """Divide the given items across the splits given by their sizes.
+
+    Args:
+        items (List[Any]): The items to divide across the splits.
+        sizes (List[int]): Number of items per split.
+
+    Returns:
+        List[List[Any]]: Each split of items.
+    """
+    total = sum(sizes)
+    if len(items) != total:
+        raise ValueError(f'Number of items must match the combined size of the splits: ' +
+                         f'{len(items)} items vs splits of size {sizes} = {total}.')
+
+    splits = []
+    begin = 0
+    for size in sizes:
+        split = items[begin:begin + size]
+        splits.append(split)
+        begin += size
+
+    return splits
+
+
+def _generate(split2size: Dict[str, int],
+              seed: int = 0x1337,
+              pos_prob: float = 0.75,
+              low: int = -1_000_000_000,
+              high: int = 1_000_000_000,
+              show_progress: bool = True) -> Dict[str, Tuple[List[int], List[str]]]:
+    """Generate a dataset, made of splits, to be saved in different forms for comparison.
+
+    Args:
+        split2size (Dict[str, int]): Mapping of split name to size in samples.
+        seed (int): Seed for the random number generator. Defaults to ``0x1337``.
+        pos_prob (float): Probability of output being positive. Defaults to ``0.75``.
+        low (int): Minimum of output range. Must be negative. Defaults to ``-1_000_000_000``.
+        high (int): Maximum of output range. Must be positive. Defaults to ``1_000_000_000``.
+        show_progress (bool): Whether to show a progress bar. Defaults to ``True``.
+
+    Returns:
+        Dict[str, Tuple[List[int], List[str]]]: Mapping of split name to nums and texts.
+    """
+    split_sizes = []
+    total = 0
+    for split in sorted(split2size):
+        size = split2size[split]
+        split_sizes.append(size)
+        total += size
+
+    nums = _generate_ints(total, seed, pos_prob, low, high, show_progress)
+    nums_per_split = _split(nums, split_sizes)
+
+    texts = list(map(_int_to_text, nums))
+    texts_per_split = _split(texts, split_sizes)
+
+    dataset = {}
+    for index, split in enumerate(sorted(split2size)):
+        dataset[split] = nums_per_split[index], texts_per_split[index]
+
+    return dataset
+
+
+def _write_csv(nums: List[int],
+               txts: List[str],
+               root: str,
+               size_limit: Optional[int],
+               show_progress: bool = True) -> Dict[str, Any]:
+    """Save the dataset in Streaming CSV form.
+
+    Args:
+        nums (List[int]): The sample numbers.
+        txts (List[str]): The sample texts.
+        root (str): Root directory.
+        size_limit (int, optional): Maximum shard size in bytes, or no limit.
+        show_progress (bool): Whether to show a progress bar while saving. Defaults to ``True``.
+
+    Returns:
+        Dict[str, Any]: Write timing statistics.
+    """
+    columns = {
+        'num': 'int',
+        'txt': 'str',
+    }
+    out = CSVWriter(out=root, columns=columns, size_limit=size_limit)
+    with out:
+        each_sample = zip(nums, txts)
+        if show_progress:
+            each_sample = tqdm(each_sample, total=len(nums), leave=False)
+        for num, txt in each_sample:
+            sample = {
+                'num': num,
+                'txt': txt,
+            }
+            out.write(sample)
+    return out.timer.get_stats(10)
+
+
+def _write_jsonl(nums: List[int],
+                 txts: List[str],
+                 root: str,
+                 size_limit: Optional[int],
+                 show_progress: bool = True) -> Dict[str, Any]:
+    """Save the dataset in Streaming JSON form.
+
+    Args:
+        nums (List[int]): The sample numbers.
+        txts (List[str]): The sample texts.
+        root (str): Root directory.
+        size_limit (int, optional): Maximum shard size in bytes, or no limit.
+        show_progress (bool): Whether to show a progress bar while saving. Defaults to ``True``.
+
+    Returns:
+        Dict[str, Any]: Write timing statistics.
+    """
+    columns = {
+        'num': 'int',
+        'txt': 'str',
+    }
+    out = JSONWriter(out=root, columns=columns, size_limit=size_limit)
+    with out:
+        each_sample = zip(nums, txts)
+        if show_progress:
+            each_sample = tqdm(each_sample, total=len(nums), leave=False)
+        for num, txt in each_sample:
+            sample = {
+                'num': num,
+                'txt': txt,
+            }
+            out.write(sample)
+    return out.timer.get_stats(10)
+
+
+def _write_mds(nums: List[int],
+               txts: List[str],
+               root: str,
+               size_limit: Optional[int],
+               show_progress: bool = True) -> Dict[str, Any]:
+    """Save the dataset in Streaming MDS form.
+
+    Args:
+        nums (List[int]): The sample numbers.
+        txts (List[str]): The sample texts.
+        root (str): Root directory.
+        size_limit (int, optional): Maximum shard size in bytes, or no limit.
+        show_progress (bool): Whether to show a progress bar while saving. Defaults to ``True``.
+
+    Returns:
+        Dict[str, Any]: Write timing statistics.
+    """
+    columns = {
+        'num': 'int',
+        'txt': 'str',
+    }
+    out = MDSWriter(out=root, columns=columns, size_limit=size_limit)
+    with out:
+        each_sample = zip(nums, txts)
+        if show_progress:
+            each_sample = tqdm(each_sample, total=len(nums), leave=False)
+        for num, txt in each_sample:
+            sample = {
+                'num': num,
+                'txt': txt,
+            }
+            out.write(sample)
+    return out.timer.get_stats(10)
+
+
+def _write_parquet(nums: List[int],
+                   txts: List[str],
+                   root: str,
+                   samples_per_shard: int,
+                   show_progress: bool = True) -> None:
+    """Save the dataset in Parquet form.
+
+    Args:
+        nums (List[int]): The sample numbers.
+        txts (List[str]): The sample texts.
+        root (str): Root directory.
+        samples_per_shard (int): Maximum number of samples per shard.
+        show_progress (bool): Whether to show a progress bar while saving. Defaults to ``True``.
+    """
+    if not os.path.exists(root):
+        os.makedirs(root)
+    num_samples = len(nums)
+    num_shards = (num_samples + samples_per_shard - 1) // samples_per_shard
+    each_shard = range(num_shards)
+    if show_progress:
+        each_shard = tqdm(each_shard, total=num_shards, leave=False)
+    for i in each_shard:
+        begin = i * samples_per_shard
+        end = min(begin + samples_per_shard, num_samples)
+        shard_nums = nums[begin:end]
+        shard_txts = txts[begin:end]
+        path = os.path.join(root, f'{i:05}.parquet')
+        obj = {
+            'num': shard_nums,
+            'txt': shard_txts,
+        }
+        table = pa.Table.from_pydict(obj)
+        pq.write_table(table, path)
+
+
+def _write_delta(nums: List[int], txts: List[str], root: str, samples_per_shard: int) -> None:
+    """Save the dataset in Delta Lake form.
+
+    Args:
+        nums (List[int]): The sample numbers.
+        txts (List[str]): The sample texts.
+        root (str): Root directory.
+        samples_per_shard (int): Maximum number of samples per shard.
+    """
+    builder = pyspark.sql.SparkSession.builder.appName('prolix')  # pyright: ignore
+    builder = builder.config('spark.sql.extensions', 'io.delta.sql.DeltaSparkSessionExtension')
+    builder = builder.config('spark.sql.catalog.spark_catalog',
+                             'org.apache.spark.sql.delta.catalog.DeltaCatalog')
+    spark = configure_spark_with_delta_pip(builder).getOrCreate()
+    schema = StructType([
+        StructField('num', IntegerType(), False),
+        StructField('txt', StringType(), False),
+    ])
+    samples = list(zip(nums, txts))
+    df = spark.createDataFrame(samples, schema)
+    df.write.format('delta').option('maxRecordsPerFile', samples_per_shard).save(root)
+
+
+def _do_write_delta(nums: List[int],
+                    txts: List[str],
+                    root: str,
+                    samples_per_shard: int,
+                    quietly: bool = True) -> None:
+    """Save the dataset in Delta Lake form, possibly capturing stdout/stderr.
+
+    Args:
+        nums (List[int]): The sample numbers.
+        txts (List[str]): The sample texts.
+        root (str): Root directory.
+        samples_per_shard (int): Maximum number of samples per shard.
+        quietly (bool): Whether to capture the Delta logging. Defaults to ``True``.
+    """
+    write = lambda: _write_delta(nums, txts, root, samples_per_shard)
+    if quietly:
+        with pipes():
+            write()
+    else:
+        write()
+
+
+def _write_lance(nums: List[int], txts: List[str], root: str, samples_per_shard: int) -> None:
+    """Save the dataset in Lance form.
+
+    Args:
+        nums (List[int]): The sample numbers.
+        txts (List[str]): The sample texts.
+        root (str): Root directory.
+        samples_per_shard (int): Maximum number of samples per shard.
+    """
+    column_names = 'num', 'txt'
+    column_values = nums, txts
+    table = pa.Table.from_arrays(column_values, column_names)
+    lance.write_dataset(table, root, mode='create', max_rows_per_file=samples_per_shard)
+
+
+def _get_file_sizes(root: str) -> List[int]:
+    """Inventory what was written, collecting the size of each file.
+
+    Args:
+        root (str): Dataset root.
+
+    Returns:
+        List[int]: Size in bytes of each file written under the root.
+    """
+    sizes = []
+    for parent, _, file_basenames in sorted(os.walk(root)):
+        for basename in sorted(file_basenames):
+            path = os.path.join(parent, basename)
+            size = os.stat(path).st_size
+            sizes.append(size)
+    return sizes
+
+
+def _splits_by_size(dataset: Dict[str, Tuple[List[int], List[str]]]) -> Iterable[str]:
+    """Order a dataset's splits by their size in samples, then by name.
+
+    Args:
+        dataset (Dict[str, Tuple[List[int], List[str]]]): Mapping of split name to split data.
+
+    Returns:
+        Iterable[str]: Ordered split names.
+    """
+    size2splits = defaultdict(list)
+    for split, (nums, _) in dataset.items():
+        size2splits[len(nums)].append(split)
+
+    splits_by_size = []
+    for size in sorted(size2splits):
+        for split in sorted(size2splits[size]):
+            splits_by_size.append(split)
+
+    return splits_by_size
+
+
+def main(args: Namespace) -> None:
+    """Generate identical datasets in various formats for performance comparison.
+
+    Args:
+        args (Namespace): Command-line arguments.
+    """
+    # Configure the dataset writing statistics table printer.
+    table_columns = '''
+        < format 8
+        > sec 7
+        > samples 12
+        > usec/sp 8
+        > bytes 14
+        > files 6
+        > bytes/file 12
+        > max bytes/file 14
+    '''
+    table_indent = 4
+    table = Tabulator.from_conf(table_columns, table_indent * ' ')
+
+    # Normalize arguments.
+    format_names = args.formats.split(',') if args.formats else []
+    show_progress = bool(args.show_progress)
+    quiet_delta = bool(args.quiet_delta)
+
+    # Given args, now we know how to configure saving the dataset in each format.
+    format2write = {
+        'csv':
+            partial(_write_csv, size_limit=args.size_limit, show_progress=show_progress),
+        'delta':
+            partial(_do_write_delta, quietly=quiet_delta,
+                    samples_per_shard=args.samples_per_shard),
+        'jsonl':
+            partial(_write_jsonl, size_limit=args.size_limit, show_progress=show_progress),
+        'lance':
+            partial(_write_lance, samples_per_shard=args.samples_per_shard),
+        'mds':
+            partial(_write_mds, size_limit=args.size_limit, show_progress=show_progress),
+        'parquet':
+            partial(_write_parquet,
+                    samples_per_shard=args.samples_per_shard,
+                    show_progress=show_progress),
+    }
+
+    # Collect sizes of the splits to generate.
+    split2size = {
+        'small': args.small,
+        'medium': args.medium,
+        'large': args.large,
+    }
+
+    # Generate the dataset samples.
+ t0 = time() + dataset = _generate(split2size, args.seed, args.data_pos_prob, args.data_low, args.data_high, + show_progress) + elapsed = time() - t0 + print(f'Generate: {elapsed:.3f} sec.') + + # Wipe output directory if exists. + if os.path.exists(args.data_root): + print(f'Found directory at {args.data_root}, wiping it for reuse') + rmtree(args.data_root) + + # Write each split in each desired formats, in order of size. + pretty_int = lambda num: f'{num:,}' + for split in _splits_by_size(dataset): + print() + print(f'Write split: {split}') + print(table.draw_line()) + print(table.draw_header()) + print(table.draw_line()) + + nums, txts = dataset[split] + names_objs = [] + for format_name in format_names: + split_root = os.path.join(args.data_root, 'gold', format_name, split) + write = format2write[format_name] + + t0 = time() + try: + obj = write(nums, txts, split_root) + if obj: + names_objs.append((format_name, obj)) + except: + raise # Getting Delta Java OOMs at gigabyte size. + elapsed = time() - t0 + + file_sizes = _get_file_sizes(split_root) + row = { + 'format': format_name, + 'sec': f'{elapsed:.3f}', + 'samples': pretty_int(len(nums)), + 'usec/sp': f'{1e6 * elapsed / len(nums):.3f}', + 'bytes': pretty_int(sum(file_sizes)), + 'files': pretty_int(len(file_sizes)), + 'bytes/file': pretty_int(sum(file_sizes) // len(file_sizes)), + 'max bytes/file': pretty_int(max(file_sizes)), + } + print(table.draw_row(row)) + print(table.draw_line()) + + for name, obj in names_objs: + text = json.dumps(obj, indent=4, sort_keys=True) + print() + print(f'{name}:') + print(text) + + +if __name__ == '__main__': + main(_parse_args()) diff --git a/streaming/base/format/base/tabulation.py b/streaming/base/format/base/tabulation.py new file mode 100644 index 000000000..179194d6c --- /dev/null +++ b/streaming/base/format/base/tabulation.py @@ -0,0 +1,125 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Line by line text table printer.""" + +from typing import Any, Dict, List, Optional, Tuple + +from typing_extensions import Self + +__all__ = ['Tabulator'] + + +class Tabulator: + """Line by line text table printer. + + Example: + conf = ''' + < format 8 + > sec 6 + > samples 12 + > usec/sp 8 + > bytes 14 + > files 6 + > bytes/file 12 + > max bytes/file 14 + ''' + left = 4 * ' ' + tab = Tabulator.from_conf(conf, left) + + Args: + cols (List[Tuple[str, str, int]]: Each column config (i.e., just, name, width). + left (str, optional): Print this before each line (e.g., indenting). Defaults to ``None``. + """ + + def __init__(self, cols: List[Tuple[str, str, int]], left: Optional[str] = None) -> None: + self.cols = cols + self.col_justs = [] + self.col_names = [] + self.col_widths = [] + for just, name, width in cols: + if just not in {'<', '>'}: + raise ValueError(f'Invalid justify (must be one of "<" or ">"): {just}.') + + if not name: + raise ValueError('Name must be non-empty.') + elif width < len(name): + raise ValueError(f'Name is too wide for its column width: {width} vs {name}.') + + if width <= 0: + raise ValueError(f'Width must be positive, but got: {width}.') + + self.col_justs.append(just) + self.col_names.append(name) + self.col_widths.append(width) + + self.left = left + + self.box_chr_horiz = chr(0x2500) + self.box_chr_vert = chr(0x2502) + + @classmethod + def from_conf(cls, conf: str, left: Optional[str] = None) -> Self: + """Initialize a Tabulator from a text table defining its columns. + + Args: + conf (str): The table config. 
+ left (str, optional): Optional string that is printed before each line (e.g., indents). + """ + cols = [] + for line in conf.strip().split('\n'): + words = line.split() + + if len(words) < 3: + raise ValueError(f'Invalid col config (must be "just name width"): {line}.') + + just = words[0] + name = ' '.join(words[1:-1]) + width = int(words[-1]) + cols.append((just, name, width)) + return cls(cols, left) + + def draw_row(self, row: Dict[str, Any]) -> str: + """Draw a row, given a mapping of column name to field value. + + Args: + row (Dict[str, Any]): Mapping of column name to field value. + + Returns: + str: Text line. + """ + fields = [] + for just, name, width in self.cols: + val = row[name] + + txt = val if isinstance(val, str) else str(val) + if width < len(txt): + raise ValueError(f'Field is too wide for its column: column (just: {just}, ' + + f'name: {name}, width: {width}) vs field {txt}.') + + txt = txt.ljust(width) if just == '<' else txt.rjust(width) + fields.append(txt) + + left_txt = self.left or '' + fields_txt = f' {self.box_chr_vert} '.join(fields) + return f'{left_txt}{self.box_chr_vert} {fields_txt} {self.box_chr_vert}' + + def draw_header(self) -> str: + """Draw a header row. + + Returns: + str: Text line. + """ + row = dict(zip(self.col_names, self.col_names)) + return self.draw_row(row) + + def draw_line(self) -> str: + """Draw a divider row. + + Returns: + str: Text line. + """ + seps = (self.box_chr_horiz * width for width in self.col_widths) + row = dict(zip(self.col_names, seps)) + line = self.draw_row(row) + return line.replace(self.box_chr_vert, self.box_chr_horiz) diff --git a/streaming/base/format/base/timer.py b/streaming/base/format/base/timer.py new file mode 100644 index 000000000..4bc2331a4 --- /dev/null +++ b/streaming/base/format/base/timer.py @@ -0,0 +1,241 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""A recursive timer contextmanager, whose state is serializable, which calculates stats.""" + +from time import time_ns +from types import TracebackType +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import numpy as np +from numpy.typing import NDArray +from typing_extensions import Self + +__all__ = ['Timer'] + + +class Timer: + """A recursive timer contextmanager, whose state is serializable, which calculates stats. + + Args: + named_timers (List[Union[str, Tuple[str, Self]]], optional): List of pairs of (name, + Timer). Defaults to ``None``. + spans (List[Tuple[int, int]] | NDArray[np.int64], optional): Either a list of pairs of + (enter, exit) times, or the same in array form. Times are given in nanoseconds since + the epoch. ``None`` means the empty list. Defaults to ``None``. + """ + + def __init__(self, + named_timers: Optional[List[Union[str, Tuple[str, Self]]]] = None, + spans: Optional[Union[List[Tuple[int, int]], NDArray[np.int64]]] = None) -> None: + if spans is None: + self.spans = [] + elif isinstance(spans, list): + np.asarray(spans, np.int64) + self.spans = spans + else: + self.spans = spans.tolist() + + self.named_timers = [] + for named_timer in named_timers or []: + if isinstance(named_timer, str): + named_timer = named_timer, Timer() + name, timer = named_timer + if hasattr(self, name): + raise ValueError(f'Timer name is already taken: {name}.') + setattr(self, name, timer) + self.named_timers.append((name, timer)) + + @classmethod + def from_bytes(cls, data: bytes, offset: int = 0) -> Tuple[Self, int]: + """Efficiently deserialize our state from bytes. 
+
+        The offset is needed because this process is recursive.
+
+        Args:
+            data (bytes): Buffer containing its serialized form.
+            offset (int): Byte offset into the given buffer. Defaults to ``0``.
+
+        Returns:
+            Tuple[Self, int]: Pair of (loaded class, byte offset into the buffer afterward).
+        """
+        dtype = np.int64()
+
+        if len(data) % dtype.nbytes:
+            raise ValueError(f'`data` size must be divisible by {dtype.nbytes}, but got: ' +
+                             f'{len(data)}.')
+
+        if offset % dtype.nbytes:
+            raise ValueError(f'`offset` must be divisible by {dtype.nbytes}, but got: {offset}.')
+
+        arr = np.frombuffer(data, np.int64)
+        idx = offset // dtype.nbytes
+
+        num_spans = int(arr[idx])
+        idx += 1
+
+        spans = arr[idx:idx + num_spans * 2].reshape(num_spans, 2)
+        idx += num_spans * 2
+
+        num_named_timers = arr[idx]
+        idx += 1
+
+        named_timers = []
+        for _ in range(num_named_timers):
+            name_size = arr[idx]
+            idx += 1
+
+            pad_size = dtype.nbytes - name_size % dtype.nbytes
+            name_units = (name_size + pad_size) // 8
+            name_bytes = arr[idx:idx + name_units].tobytes()[:-pad_size]
+            name = name_bytes.decode('utf-8')
+            idx += name_units
+
+            subtimer, offset = cls.from_bytes(data, idx * dtype.nbytes)
+            idx = offset // dtype.nbytes
+
+            named_timer = name, subtimer
+            named_timers.append(named_timer)
+
+        timer = cls(named_timers, spans)
+        return timer, idx * dtype.nbytes
+
+    def to_bytes(self) -> bytes:
+        """Efficiently serialize our state to bytes.
+
+        Returns:
+            bytes: Serialized state.
+        """
+        num_spans = np.int64(len(self.spans))
+        spans = np.asarray(self.spans, np.int64)
+        num_named_timers = np.int64(len(self.named_timers))
+        parts = [num_spans.tobytes(), spans.tobytes(), num_named_timers.tobytes()]
+        dtype = np.int64()
+        for name, timer in self.named_timers:
+            name_bytes = name.encode('utf-8')
+            name_size = np.int64(len(name_bytes))
+            pad_size = dtype.nbytes - name_size % dtype.nbytes
+            pad_bytes = b'\0' * pad_size
+            parts += [name_size.tobytes(), name_bytes, pad_bytes, timer.to_bytes()]
+        return b''.join(parts)
+
+    def to_dynamic(self) -> Self:
+        """Convert our spans to a dynamic-size list of pairs for fast data collection.
+
+        Returns:
+            Self: This object.
+        """
+        if isinstance(self.spans, np.ndarray):
+            self.spans = self.spans.tolist()
+        for _, timer in self.named_timers:
+            timer.to_dynamic()
+        return self
+
+    def to_fixed(self) -> Self:
+        """Convert our spans to a fixed-size array for fast data analysis.
+
+        Returns:
+            Self: This object.
+        """
+        if isinstance(self.spans, list):
+            self.spans = np.asarray(self.spans, np.int64)
+        for _, timer in self.named_timers:
+            timer.to_fixed()
+        return self
+
+    def _get_duration_stats(self, num_groups: Optional[int] = 100) -> Dict[str, Any]:
+        """Calculate duration statistics.
+
+        Args:
+            num_groups (int, optional): Number of groups for quantiling. Defaults to ``100``.
+
+        Returns:
+            Dict[str, Any]: Duration statistics.
+        """
+        if isinstance(self.spans, list):
+            self.spans = np.asarray(self.spans, np.int64)
+
+        durs = self.spans[:, 1] - self.spans[:, 0]
+        durs = durs / 1e9
+
+        obj = {
+            'total': float(durs.sum()),
+        }
+
+        if 1 < len(durs):
+            obj.update({
+                'count': len(durs),
+                'min': float(min(durs)),
+                'max': float(max(durs)),
+                'mean': float(durs.mean()),
+                'std': float(durs.std()),
+            })
+
+            if num_groups:
+                fracs = np.linspace(0, 1, num_groups + 1)
+                obj['quantiles'] = np.quantile(durs, fracs).tolist()
+
+        return obj
+
+    def get_stats(self, num_groups: Optional[int] = 100) -> Dict[str, Any]:
+        """Get statistics.
+
+        Args:
+            num_groups (int, optional): Number of groups for quantiling.
Defaults to ``100``. + + Returns: + Dict[str, Any]: Recursive dict of duration statistics. + """ + obj = { + 'stats': self._get_duration_stats(num_groups), + } + + if self.named_timers: + named_timers = [] + for name, timer in self.named_timers: + named_timers.append([name, timer.get_stats(num_groups)]) + + whole = obj['stats']['total'] + sum_of_parts = 0 + for name, timer in named_timers: + part = timer['stats']['total'] + timer['stats']['frac_of_whole'] = part / whole + sum_of_parts += part + + for name, timer in named_timers: + part = timer['stats']['total'] + timer['stats']['frac_of_parts'] = part / sum_of_parts + + obj['named_timers'] = named_timers # pyright: ignore + + return obj + + def __enter__(self) -> Self: + """Enter context manager. + + Returns: + Self: This object. + """ + if isinstance(self.spans, np.ndarray): + self.spans = self.spans.tolist() + + span = time_ns(), 0 + self.spans.append(span) + return self + + def __exit__(self, + exc_type: Optional[Type[BaseException]] = None, + exc: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None) -> None: + """Exit context manager. + + Args: + exc_type (Type[BaseException], optional): Exception type. Defaults to ``None``. + exc (BaseException, optional): Exception. Defaults to ``None``. + traceback (TracebackType, optional): Traceback. Defaults to ``None``. + """ + if isinstance(self.spans, np.ndarray): + self.spans = self.spans.tolist() + + span = self.spans[-1] + self.spans[-1] = span[0], time_ns() diff --git a/streaming/base/format/base/writer.py b/streaming/base/format/base/writer.py index 7cc3add3d..8ea34588a 100644 --- a/streaming/base/format/base/writer.py +++ b/streaming/base/format/base/writer.py @@ -19,6 +19,7 @@ from typing_extensions import Self from streaming.base.compression import compress, get_compression_extension, is_compression +from streaming.base.format.base.timer import Timer from streaming.base.format.index import get_index_basename from streaming.base.hashing import get_hash, is_hash from streaming.base.storage.upload import CloudUploader @@ -111,23 +112,52 @@ def __init__(self, self.size_limit = size_limit_value self.extra_bytes_per_shard = extra_bytes_per_shard self.extra_bytes_per_sample = extra_bytes_per_sample + self.new_samples: List[bytes] self.new_shard_size: int - self.shards = [] - self.cloud_writer = CloudUploader.get(out, keep_local, kwargs.get('progress_bar', False), - kwargs.get('retry', 2)) - self.local = self.cloud_writer.local - self.remote = self.cloud_writer.remote + self._uploader = CloudUploader.get(out, keep_local, kwargs.get('progress_bar', False), + kwargs.get('retry', 2)) + self.local = self._uploader.local + self.remote = self._uploader.remote # `max_workers`: The maximum number of threads that can be executed in parallel. # One thread is responsible for uploading one shard file to a remote location. self.executor = ThreadPoolExecutor(max_workers=kwargs.get('max_workers', None)) # Create an event to track an exception in a thread. self.event = Event() + self.timer = self._get_timer() + self._reset_cache() + @classmethod + def _get_timer(cls) -> Timer: + """Get a timer tree for the process of writing a dataset. + + Returns: + Timer: The tree of timers. 
+ """ + return Timer([ + ('write', Timer([ + 'serialize_sample', + 'flush_shard', + ])), + ('finish', + Timer([ + 'flush_last_shard', + ('save_index', + Timer([ + 'serialize_index', + 'write_index', + 'wait_for_shard_uploads', + 'upload_index', + ])), + 'shutdown_executor', + 'remove_local', + ])), + ]) + def _reset_cache(self) -> None: """Reset our internal shard-building cache. @@ -245,53 +275,88 @@ def write(self, sample: Dict[str, Any]) -> None: self.cancel_future_jobs() raise Exception('One of the threads failed. Check other traceback for more ' + 'details.') - # Execute the task if there is no exception in any of the async threads. - new_sample = self.encode_sample(sample) - new_sample_size = len(new_sample) + self.extra_bytes_per_sample - if self.size_limit and self.size_limit < self.new_shard_size + new_sample_size: - self.flush_shard() - self._reset_cache() - self.new_samples.append(new_sample) - self.new_shard_size += new_sample_size - - def _write_index(self) -> None: + + with self.timer.write as ctx: # pyright: ignore + + with ctx.serialize_sample: + new_sample = self.encode_sample(sample) + new_sample_size = len(new_sample) + self.extra_bytes_per_sample + + with ctx.flush_shard: + if self.size_limit and self.size_limit < self.new_shard_size + new_sample_size: + self.flush_shard() + self._reset_cache() + + self.new_samples.append(new_sample) + self.new_shard_size += new_sample_size + + def _upload_file(self, basename: str) -> None: + """Do the file upload, calling the callback when done. + + Args: + basename (str): File basename. + """ + future = self.executor.submit(self._uploader.upload_file, basename) + future.add_done_callback(self.exception_callback) + + def _save_index(self) -> None: """Write the index, having written all the shards.""" + ctx = self.timer.finish.save_index # pyright: ignore + if self.new_samples: raise RuntimeError('Internal error: not all samples have been written.') + if self.event.is_set(): - # Shutdown the executor and cancel all the pending futures due to exception in one of - # the threads. + # Shutdown the executor and cancel all the pending futures due to exception in + # one of the threads. self.cancel_future_jobs() return - basename = get_index_basename() - filename = os.path.join(self.local, basename) - obj = { - 'version': 2, - 'shards': self.shards, - } - with open(filename, 'w') as out: - json.dump(obj, out, sort_keys=True) - # Execute the task if there is no exception in any of the async threads. - while self.executor._work_queue.qsize() > 0: - logger.debug( - f'Queue size: {self.executor._work_queue.qsize()}. Waiting for all ' + - f'shard files to get uploaded to {self.remote} before uploading index.json') - sleep(1) - logger.debug(f'Queue size: {self.executor._work_queue.qsize()}. Uploading ' + - f'index.json to {self.remote}') - future = self.executor.submit(self.cloud_writer.upload_file, basename) - future.add_done_callback(self.exception_callback) + + with ctx.serialize_index: + basename = get_index_basename() + filename = os.path.join(self.local, basename) + obj = { + 'version': 2, + 'shards': self.shards, + } + text = json.dumps(obj, sort_keys=True) + data = text.encode('utf-8') + + with ctx.write_index: + with open(filename, 'wb') as out: + out.write(data) + + with ctx.wait_for_shard_uploads: + while self.executor._work_queue.qsize(): + logger.debug( + f'Queue size: {self.executor._work_queue.qsize()}. 
Waiting for all ' + + f'shard files to get uploaded to {self.remote} before uploading index.json') + sleep(1) + + with ctx.upload_index: + logger.debug(f'Queue size: {self.executor._work_queue.qsize()}. Uploading ' + + f'index.json to {self.remote}') + self._upload_file(basename) def finish(self) -> None: """Finish writing samples.""" - if self.new_samples: - self.flush_shard() - self._reset_cache() - self._write_index() - logger.debug(f'Waiting for all shard files to get uploaded to {self.remote}') - self.executor.shutdown(wait=True) - if self.remote and not self.keep_local: - shutil.rmtree(self.local, ignore_errors=True) + ctx = self.timer.finish # pyright: ignore + + with ctx.flush_last_shard: + if self.new_samples: + self.flush_shard() + self._reset_cache() + + with ctx.save_index: + self._save_index() + + with ctx.shutdown_executor: + logger.debug(f'Waiting for all shard files to get uploaded to {self.remote}') + self.executor.shutdown(wait=True) + + with ctx.remove_local: + if self.remote and not self.keep_local: + shutil.rmtree(self.local, ignore_errors=True) def cancel_future_jobs(self) -> None: """Shutting down the executor and cancel all the pending jobs.""" @@ -315,12 +380,10 @@ def exception_callback(self, future: Future) -> None: Raises: exception: re-raise an exception """ - exception = future.exception() - if exception: - # Set the event to let other pool thread know about the exception - self.event.set() - # re-raise the same exception - raise exception + err = future.exception() + if err: + self.event.set() # Set the event to let other pool thread know about the exception. + raise err # Re-raise the same exception. def __enter__(self) -> Self: """Enter context manager. @@ -328,23 +391,33 @@ def __enter__(self) -> Self: Returns: Self: This object. """ + self.timer = self._get_timer() + self.timer.__enter__() return self - def __exit__(self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], - traceback: Optional[TracebackType]) -> None: + def __exit__(self, + exc_type: Optional[Type[BaseException]] = None, + exc: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None) -> None: """Exit context manager. Args: - exc_type (Type[BaseException], optional): Exc type. - exc (BaseException, optional): Exc. - traceback (TracebackType, optional): Traceback. + exc_type (Type[BaseException], optional): Exception type. Defaults to ``None``. + exc (BaseException, optional): Exception. Defaults to ``None``. + traceback (TracebackType, optional): Traceback. Defaults to ``None``. """ + ctx = self.timer + if self.event.is_set(): # Shutdown the executor and cancel all the pending futures due to exception in one of # the threads. self.cancel_future_jobs() return - self.finish() + + with ctx.finish: # pyright: ignore + self.finish() + + ctx.__exit__() class JointWriter(Writer): @@ -429,10 +502,7 @@ def flush_shard(self) -> None: obj.update(self.get_config()) self.shards.append(obj) - # Execute the task if there is no exception in any of the async threads. - future = self.executor.submit(self.cloud_writer.upload_file, zip_data_basename or - raw_data_basename) - future.add_done_callback(self.exception_callback) + self._upload_file(zip_data_basename or raw_data_basename) class SplitWriter(Writer): @@ -521,12 +591,5 @@ def flush_shard(self) -> None: obj.update(self.get_config()) self.shards.append(obj) - # Execute the task if there is no exception in any of the async threads. 
- future = self.executor.submit(self.cloud_writer.upload_file, zip_data_basename or - raw_data_basename) - future.add_done_callback(self.exception_callback) - - # Execute the task if there is no exception in any of the async threads. - future = self.executor.submit(self.cloud_writer.upload_file, zip_meta_basename or - raw_meta_basename) - future.add_done_callback(self.exception_callback) + self._upload_file(zip_data_basename or raw_data_basename) + self._upload_file(zip_meta_basename or raw_meta_basename)
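
Not part of the patch: a minimal sketch of how the per-stage timing added to `Writer` is meant to be consumed, mirroring what `scripts/writing/bench.py` does with `MDSWriter`. The output directory, sample values, and `size_limit` below are placeholders.

import json

from streaming import MDSWriter

columns = {'num': 'int', 'txt': 'str'}

# Writing inside the `with` block populates the writer's Timer tree:
# write -> (serialize_sample, flush_shard), finish -> (flush_last_shard, save_index, ...).
with MDSWriter(out='data/demo_mds', columns=columns, size_limit=1 << 16) as out:
    for num in range(1000):
        out.write({'num': num, 'txt': str(num)})

# After exit, the nested stats are a plain dict: total seconds per stage, plus
# frac_of_whole / frac_of_parts for each named sub-timer.
print(json.dumps(out.timer.get_stats(10), indent=4, sort_keys=True))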
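
Likewise, a small sketch of the `Tabulator` column mini-language added in `streaming/base/format/base/tabulation.py`: each config line is `just name width`, where `just` is `<` (left) or `>` (right). The three columns here are a shortened version of the set bench.py uses.

from streaming.base.format.base.tabulation import Tabulator

conf = '''
    < format 8
    > sec 7
    > samples 12
'''
table = Tabulator.from_conf(conf, 4 * ' ')

# Draw the table one line at a time, as bench.py does while writing each split.
print(table.draw_line())
print(table.draw_header())
print(table.draw_line())
print(table.draw_row({'format': 'mds', 'sec': '0.123', 'samples': '32,768'}))
print(table.draw_line())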
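
Finally, a sketch of the `Timer` tree from `streaming/base/format/base/timer.py` and its byte-level round trip; the names `outer` and `step` are arbitrary, and the offset check assumes `from_bytes` returns the offset just past the bytes it consumed.

from time import sleep

from streaming.base.format.base.timer import Timer

# Build a two-level timer tree; named sub-timers are exposed as attributes.
timer = Timer([('outer', Timer(['step']))])
with timer:
    with timer.outer:
        with timer.outer.step:
            sleep(0.01)

# State serializes to int64 words: span count, spans, then one
# (name length, padded name, subtimer) record per named child.
data = timer.to_bytes()
clone, end = Timer.from_bytes(data)
assert end == len(data)
print(clone.get_stats(None))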