Commit 9bb3669

str: Refactor to support using a large prime field. Yields a 300x encoding performance increase :)

patniemeyer committed Mar 7, 2024
1 parent 3e3f704 commit 9bb3669
Showing 8 changed files with 312 additions and 90 deletions.
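For context on the refactor, here is a minimal sketch (illustrative only, using the encoding.fields helpers introduced in this commit; the sample data is made up) of the convention the new code adopts: scalars are read from files as 31-byte values, which guarantees they are less than p, and field elements are written back out at the full 32 bytes.

import os
from encoding.fields import (get_field, symbols_to_bytes,
                             FIELD_SAFE_SCALAR_SIZE_BYTES, FIELD_ELEMENT_SIZE_BYTES)

GF = get_field()  # prime field of order p = 2**256 - 189

# Read raw data as 31-byte scalars: any 31-byte value is guaranteed to be < p.
raw = os.urandom(4 * FIELD_SAFE_SCALAR_SIZE_BYTES)
scalars = [int.from_bytes(raw[i:i + FIELD_SAFE_SCALAR_SIZE_BYTES], byteorder='big')
           for i in range(0, len(raw), FIELD_SAFE_SCALAR_SIZE_BYTES)]
symbols = GF(scalars)

# Write field elements at the full 32 bytes so any value below p round-trips.
encoded = symbols_to_bytes(symbols, FIELD_ELEMENT_SIZE_BYTES)
assert len(encoded) == 4 * FIELD_ELEMENT_SIZE_BYTES

Working on whole 256-bit field elements instead of per-byte GF(2^8) arithmetic is presumably where the encoding speedup comes from.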
70 changes: 53 additions & 17 deletions str-twincoding/encoding/chunks.py
@@ -5,6 +5,7 @@

import math
import numpy as np
from numpy._typing import NDArray
from tqdm import tqdm


@@ -13,34 +14,51 @@

# Base for classes that read chunks from a single input file.
class ChunkReader:
def __init__(self, path: str, chunk_size: int):
def __init__(self, path: str, num_elements: int, element_size: int):
self.path = path
self.chunk_size = chunk_size

self.num_elements = num_elements
self.element_size = element_size
self.chunk_size = num_elements * element_size

self.file_length = os.path.getsize(path)

self.num_chunks = math.ceil(self.file_length / self.chunk_size)
self.mmap = None

# Returns an ndarray that shares memory with the mmap.
# The final chunk may be padded with zeros to fill the chunk size.
def get_chunk(self, i: int) -> np.ndarray:
# Returns an ndarray of shape (num_elements, element_size) bytes that shares memory with the mmap.
# The final chunk will be padded with zeros if required to fill the chunk size.
def get_chunk(self, i: int):
if self.mmap is None:
self.mmap = np.memmap(self.path, dtype='uint8', mode='r')

start_idx = i * self.chunk_size
end_idx = (i + 1) * self.chunk_size
end_idx = start_idx + self.chunk_size

if start_idx >= self.file_length:
raise IndexError("Start index is out of bounds.")

# Ensure end_idx does not exceed the file length.
end_idx = min(end_idx, self.file_length)
chunk = self.mmap[start_idx:end_idx] # inclusive:exclusive

# Read the data chunk.
chunk = self.mmap[start_idx:end_idx]

# Pad if necessary.
if end_idx - start_idx < self.chunk_size:
padding_length = self.chunk_size - (end_idx - start_idx)
chunk = np.concatenate((chunk, np.zeros(padding_length, dtype=chunk.dtype)))
chunk = np.concatenate((chunk, np.zeros(padding_length, dtype='uint8')))

# Reshape the chunk to have the correct number of elements.
chunk = chunk.reshape((self.num_elements, self.element_size))

return chunk

# Returns an ndarray of num_elements ints, where each element's bytes are interpreted as a (possibly very large) big-endian int.
def get_chunk_ints(self, i: int):
chunk = self.get_chunk(i)
return np.array([int.from_bytes(bytes=element, byteorder='big') for element in chunk])

def update_pbar(self, ci: int, pbar: tqdm, start: float):
rate = ci * self.chunk_size / (time.time() - start)
pbar.set_postfix({"Rate": f"{rate / (1024 * 1024):.4f}MB/s"}, refresh=True)
@@ -49,20 +67,25 @@ def update_pbar(self, ci: int, pbar: tqdm, start: float):

# Base for classes that read chunks from multiple files in parallel.
class ChunksReader:
def __init__(self,
file_map: dict[str, int],
chunk_size: int):
def __init__(self, file_map: dict[str, int], num_elements: int, element_size: int):
self.files: [str] = list(file_map.keys())
self.files_indices: [int] = list(file_map.values())
self.chunk_size = chunk_size

self.num_elements = num_elements
self.element_size = element_size
self.chunk_size = num_elements * element_size

self.num_chunks = self.validate_files(
files=self.files, file_indices=self.files_indices, chunk_size=chunk_size)
files=self.files, file_indices=self.files_indices, chunk_size=self.chunk_size)
self.mmaps = None

# Validate that the files are of the same length and are a multiple of chunk_size.
# Return the number of chunks.
@staticmethod
def validate_files(files: list[str], file_indices: [int], chunk_size: int) -> int:

print("chunk size:", chunk_size)

file_sizes = [os.path.getsize(file_path) for file_path in files]

# Check if all files are the same length
@@ -80,9 +103,10 @@ def validate_files(files: list[str], file_indices: [int], chunk_size: int) -> int:

return file_size // chunk_size

# Read data chunks at index i from each of the (k) files.
# The files have been previously validated to be a multiple of chunk_size.
def get_chunks(self, i: int) -> [np.ndarray]:
# Read data chunks at index i from each of the specified files.
# Return a list of ndarrays, each of shape (num_elements, element_size) bytes.
# The file sizes are expected to be a multiple of chunk_size.
def get_chunks(self, i: int) -> [NDArray]:
if i >= self.num_chunks:
raise IndexError("chunk index is out of bounds.")

@@ -91,7 +115,19 @@ def get_chunks(self, i: int) -> [np.ndarray]:

start_idx = i * self.chunk_size
end_idx = (i + 1) * self.chunk_size
return [mmap[start_idx:end_idx] for mmap in self.mmaps]

# The files have been previously validated to be a multiple of chunk_size.
file_chunks = [mmap[start_idx:end_idx] for mmap in self.mmaps]

# Reshape each chunk to have the correct number of elements.
return [chunk.reshape((self.num_elements, self.element_size)) for chunk in file_chunks]

# Read data chunks at index i from each of the (k) files.
# Return a list of ndarrays, each containing num_elements (possibly very large) ints.
def get_chunks_ints(self, i: int) -> [NDArray]:
file_chunks = self.get_chunks(i)
return [np.array([int.from_bytes(bytes=element, byteorder='big') for element in chunk])
for chunk in file_chunks]

def update_pbar(self, ci: int, num_files: int, pbar: tqdm, start: float):
rate = ci * self.chunk_size * num_files / (time.time() - start)
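To illustrate the updated reader interface in chunks.py, a minimal usage sketch (the file name and element count are hypothetical):

from encoding.chunks import ChunkReader
from encoding.fields import FIELD_SAFE_SCALAR_SIZE_BYTES

reader = ChunkReader(path='file_1MB.dat', num_elements=9,
                     element_size=FIELD_SAFE_SCALAR_SIZE_BYTES)
chunk = reader.get_chunk(0)      # uint8 ndarray of shape (9, 31), zero-padded if short
ints = reader.get_chunk_ints(0)  # array of 9 big-endian integers, one per 31-byte element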
46 changes: 46 additions & 0 deletions str-twincoding/encoding/fields.py
@@ -0,0 +1,46 @@
import galois
import numpy as np
from galois import FieldArray
from numpy._typing import NDArray

# The largest prime that fits in 256 bits / 32 bytes
p = 2 ** 256 - 189

# For very large primes galois can take a long time to infer the primitive element; specifying it is much faster.
primitive_element = 2

# The field scalars are read at the rounded down byte size (to guarantee that they remain less than p).
FIELD_SAFE_SCALAR_SIZE_BYTES: int = 31

# The stored size of the element is the rounded up byte size.
FIELD_ELEMENT_SIZE_BYTES: int = 32


# Initialize the Galois field object for the prime field of order p defined above.
def get_field():
# Order / characteristic is p and degree is 1 for a prime field
return galois.GF(p, 1, primitive_element=primitive_element, verify=False)


def symbol_to_bytes(
symbol: FieldArray,
element_size: int,
) -> bytes:
return int(symbol).to_bytes(element_size, byteorder='big')


# Take the list or array of GF symbols and render them to a list of byte strings, each of length element_size.
def symbols_to_bytes_list(
symbols: list[FieldArray] | NDArray[FieldArray],
element_size: int,
) -> list[bytes]:
return [symbol_to_bytes(el, element_size) for el in symbols]


# Take the list or array of GF symbols and render them to a flattened byte string of
# length len(symbols) * element_size.
def symbols_to_bytes(
symbols: list[FieldArray] | NDArray[FieldArray],
element_size: int,
) -> bytes:
return b''.join(symbols_to_bytes_list(symbols, element_size))
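
A small round-trip sketch of the helpers above (the value is illustrative): a 31-byte scalar always stays below p, and the stored form pads it to the full 32 bytes.

from encoding.fields import (get_field, p, symbol_to_bytes,
                             FIELD_SAFE_SCALAR_SIZE_BYTES, FIELD_ELEMENT_SIZE_BYTES)

GF = get_field()

scalar = int.from_bytes(b'\xff' * FIELD_SAFE_SCALAR_SIZE_BYTES, byteorder='big')
assert scalar < p  # 31 bytes of data can never reach p = 2**256 - 189

symbol = GF(scalar)
stored = symbol_to_bytes(symbol, FIELD_ELEMENT_SIZE_BYTES)
assert len(stored) == FIELD_ELEMENT_SIZE_BYTES
assert int.from_bytes(stored, byteorder='big') == scalar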
51 changes: 33 additions & 18 deletions str-twincoding/encoding/file_decoder.py
@@ -3,10 +3,15 @@
import time
import uuid
from collections import OrderedDict
import galois
from typing import Any

import numpy as np
from galois import FieldArray
from numpy import ndarray
from numpy._typing import NDArray
from tqdm import tqdm

from encoding.fields import get_field, FIELD_SAFE_SCALAR_SIZE_BYTES, FIELD_ELEMENT_SIZE_BYTES, symbols_to_bytes
from storage.storage_model import NodeType, EncodedFile
from storage.repository import Repository

@@ -44,16 +49,20 @@ def __init__(self,

self.output_path = output_path or f"decoded_{uuid.uuid4()}.dat"
self.overwrite = overwrite
assert org_file_length is not None
self.org_file_length = org_file_length

chunk_size = self.k # individual columns of size k
super().__init__(file_map=file_map, chunk_size=chunk_size)
num_elements = self.k
super().__init__(file_map=file_map,
num_elements=num_elements,
element_size=FIELD_ELEMENT_SIZE_BYTES)
# print(f"num_elements = 'k' = {num_elements}, element size = {self.element_size}")

# Init a file decoder from an encoded file dir. The dir must contain a config.json file and
# at least k files of the same type.
@staticmethod
def from_encoded_dir(path: str, output_path: str = None, overwrite: bool = False):
file_config = EncodedFile.load(os.path.join(path, 'config.json'))
file_config = EncodedFile.load(os.path.join(path, 'config.json'))
assert file_config.type0.k == file_config.type1.k, "Config node types must have the same k."
recover_from_files = FileDecoder.get_threshold_files(path, k=file_config.type0.k)
if os.path.basename(list(recover_from_files)[0]).startswith("type0_"):
@@ -86,9 +95,9 @@ def get_threshold_files(cls, files_dir: str, k: int) -> dict[str, int]:

# Decode the file to the output path.
def decode(self):
with open_output_file(output_path=self.output_path, overwrite=self.overwrite) as out:
with (open_output_file(output_path=self.output_path, overwrite=self.overwrite) as out):
k, n = self.node_type.k, self.node_type.n
GF = galois.GF(2 ** 8)
GF = get_field()
G = rs_generator_matrix(GF, k=k, n=n)
g = G[:, self.files_indices]
ginv = np.linalg.inv(g)
@@ -97,34 +106,40 @@ def decode(self):
start = time.time()
with tqdm(total=self.num_chunks, desc='Decoding', unit='chunk') as pbar:
for ci in range(self.num_chunks):
chunks = self.get_chunks(ci)
# list of ndarrays, each containing num_elements big ints.
file_chunks_ints = self.get_chunks_ints(ci)

# Reshape each chunk as a stack of column vectors forming a k x k matrix
matrix = np.hstack([chunk.reshape(-1, 1) for chunk in file_chunks_ints])

# Decode each chunk as a stack of column vectors forming a k x k matrix
matrix = np.hstack([chunk.reshape(-1, 1) for chunk in chunks])
# Decode the original data
decoded = GF(matrix) @ ginv

if self.transpose:
decoded = decoded.T
bytes = decoded.reshape(-1).tobytes()

# Trim the last chunk if it is padded
size = (ci + 1) * self.chunk_size * k
if size > self.org_file_length:
bytes = bytes[:self.org_file_length - size]

# Write the data to the output file
out.write(bytes)
# Flatten the matrix back to an array of symbols
symbols: NDArray[FieldArray] = decoded.reshape(-1)
# Write the data to the output file with each symbol converted to bytes at the original size.
out.write(symbols_to_bytes(symbols, FIELD_SAFE_SCALAR_SIZE_BYTES))

# Progress bar
self.update_pbar(ci=ci, num_files=k, pbar=pbar, start=start)
...
...
...

# Trim the output file to the original file length to account for padding at ingestion time.
with open(self.output_path, 'rb+') as f:
f.truncate(self.org_file_length)

def close(self):
[mm.close() for mm in self.mmaps]


if __name__ == '__main__':
repo = Repository.default()
filename = 'file_1KB.dat'
filename = 'file_1MB.dat'
original_file = repo.tmp_file_path(filename)
encoded_file = repo.file_dir_path(filename)
file_status = repo.file_status(filename)
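The decode step in this file relies on any k columns of the Reed-Solomon generator matrix forming an invertible k x k matrix, so multiplying the stacked node columns by that inverse recovers the original data block. A toy sketch of that identity over the same field (k, n, the column choice, and the data block are made up):

import numpy as np
from encoding.fields import get_field
from encoding.twin_coding import rs_generator_matrix

GF = get_field()
k, n = 3, 5
G = rs_generator_matrix(GF, k=k, n=n)          # k x n generator matrix
M = GF(np.arange(1, k * k + 1).reshape(k, k))  # toy k x k data block
encoded = M @ G                                # columns play the role of the per-node chunks

indices = [0, 2, 4]                            # any k of the n columns will do
ginv = np.linalg.inv(G[:, indices])
assert np.array_equal(encoded[:, indices] @ ginv, M)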
36 changes: 25 additions & 11 deletions str-twincoding/encoding/file_encoder.py
@@ -1,9 +1,13 @@
import hashlib
import os
from contextlib import ExitStack
import galois

from galois import FieldArray
from icecream import ic
import numpy as np
from numpy._typing import NDArray

from encoding.fields import get_field, FIELD_SAFE_SCALAR_SIZE_BYTES, FIELD_ELEMENT_SIZE_BYTES, symbols_to_bytes
from encoding.chunks import ChunkReader
from encoding.twin_coding import rs_generator_matrix, Code, twin_code
from storage.storage_model import EncodedFile, NodeType0, NodeType1
@@ -37,14 +41,21 @@ def __init__(self,
assert node_type0.n > node_type0.k and node_type1.n > node_type1.k, "The node type must have n > k."

self.node_type0 = node_type0
print(f"node_type0 = {node_type0}")
self.node_type1 = node_type1
print(f"node_type1 = {node_type1}")
self.k = node_type0.k
print(f"k = {self.k}")
self.path = input_file
self.output_dir = output_path or input_file + '.encoded'
self.overwrite = overwrite
self._file_hash = None
chunk_size = self.k ** 2
super().__init__(path=input_file, chunk_size=chunk_size)
num_elements = self.k ** 2
super().__init__(path=input_file,
num_elements=num_elements,
element_size=FIELD_SAFE_SCALAR_SIZE_BYTES)
print(f"element size = {self.element_size}, num_elements = 'k^2' = {num_elements}")
print(f"FileEncoder: chunk size = {self.chunk_size}, num_chunks = {self.num_chunks}")

# Initialize the output directory that will hold the erasure-encoded chunks.
def init_output_dir(self) -> bool:
@@ -73,7 +84,7 @@ def encode(self):
return

# The symbol space
GF = galois.GF(2 ** 8)
GF = get_field()

# The two coding schemes.
k, n0, n1 = self.k, self.node_type0.n, self.node_type1.n
@@ -97,15 +108,18 @@
start = time.time()
with tqdm(total=self.num_chunks, desc='Encoding', unit='chunk') as pbar:
for ci in range(self.num_chunks):
# Twin code the chunk
chunk = self.get_chunk(ci)
cols0, cols1 = twin_code(GF(chunk), C0, C1)
# Get the next chunk, converting each element to a big integer
chunk_ints: NDArray[int] = self.get_chunk_ints(ci)

# Twin code the chunk (returns two lists of ndarrays of symbols)
cols0, cols1 = twin_code(GF(chunk_ints), C0, C1)

# Write the data to the respective files
# print(f"Writing chunk {ci} to files.")
for fi in range(n0):
files0[fi].write(cols0[fi].tobytes())
files0[fi].write(symbols_to_bytes(cols0[fi], FIELD_ELEMENT_SIZE_BYTES))
for fi in range(n1):
files1[fi].write(cols1[fi].tobytes())
files1[fi].write(symbols_to_bytes(cols1[fi], FIELD_ELEMENT_SIZE_BYTES))

self.update_pbar(ci=ci, pbar=pbar, start=start)
...
@@ -128,13 +142,13 @@ def close(self):
repo = Repository.default()

# Random test file
filename = 'file_1KB.dat'
filename = 'file_1MB.dat'
file = repo.tmp_file_path(filename)
ic(file)
# If the file doesn't exist, create it
if not os.path.exists(file):
with open(file, "wb") as f:
f.write(os.urandom(1024))
f.write(os.urandom(1 * 1024 * 1024))

encoder = FileEncoder(
node_type0=NodeType0(k=3, n=5, encoding='reed_solomon'),
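As a back-of-the-envelope check on the encoder's new chunk geometry (numbers are illustrative; k = 3 matches the example above): each input chunk holds k^2 scalars of 31 bytes, and each node file receives k elements of 32 bytes per chunk.

from encoding.fields import FIELD_SAFE_SCALAR_SIZE_BYTES, FIELD_ELEMENT_SIZE_BYTES

k = 3
input_bytes_per_chunk = k * k * FIELD_SAFE_SCALAR_SIZE_BYTES  # 279 bytes of the original file per chunk
node_bytes_per_chunk = k * FIELD_ELEMENT_SIZE_BYTES           # 96 bytes written per node per chunk
# Each node stores roughly 1/k of the file, plus the 32/31 element-size overhead.
print(node_bytes_per_chunk / input_bytes_per_chunk)           # ~0.344 vs 1/3 ~ 0.333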