From 9bb36693d910c82251b9450b21737712f58b944d Mon Sep 17 00:00:00 2001
From: Patrick Niemeyer
Date: Thu, 7 Mar 2024 17:39:47 -0600
Subject: [PATCH] str: Refactor to support using a large prime field. Yields a
 300x encoding performance increase :)
---
 str-twincoding/encoding/chunks.py       | 70 ++++++++----
 str-twincoding/encoding/fields.py       | 46 +++++++++
 str-twincoding/encoding/file_decoder.py | 51 ++++++----
 str-twincoding/encoding/file_encoder.py | 36 ++++---
 .../encoding/gf_file_rountrip_test.py   | 93 +++++++++++++++++++
 .../encoding/node_recovery_client.py    | 71 +++++++-------
 .../encoding/node_recovery_source.py    | 17 ++--
 str-twincoding/encoding/twin_coding.py  | 18 +++-
 8 files changed, 312 insertions(+), 90 deletions(-)
 create mode 100644 str-twincoding/encoding/fields.py
 create mode 100644 str-twincoding/encoding/gf_file_rountrip_test.py

diff --git a/str-twincoding/encoding/chunks.py b/str-twincoding/encoding/chunks.py
index 8aae00f15..774f3a22e 100644
--- a/str-twincoding/encoding/chunks.py
+++ b/str-twincoding/encoding/chunks.py
@@ -5,6 +5,7 @@
 import math
 
 import numpy as np
+from numpy.typing import NDArray
 from tqdm import tqdm
 
 
@@ -13,34 +14,51 @@
 # Base for classes that read chunks from a single input file.
 class ChunkReader:
-    def __init__(self, path: str, chunk_size: int):
+    def __init__(self, path: str, num_elements: int, element_size: int):
         self.path = path
-        self.chunk_size = chunk_size
+
+        self.num_elements = num_elements
+        self.element_size = element_size
+        self.chunk_size = num_elements * element_size
+
         self.file_length = os.path.getsize(path)
+
         self.num_chunks = math.ceil(self.file_length / self.chunk_size)
         self.mmap = None
 
-    # Returns an ndarray that shares memory with the mmap.
-    # The final chunk may be padded with zeros to fill the chunk size.
-    def get_chunk(self, i: int) -> np.ndarray:
+    # Returns an ndarray of shape (num_elements, element_size) bytes that shares memory with the mmap.
+    # The final chunk will be padded with zeros if required to fill the chunk size.
+    def get_chunk(self, i: int):
         if self.mmap is None:
             self.mmap = np.memmap(self.path, dtype='uint8', mode='r')
 
         start_idx = i * self.chunk_size
-        end_idx = (i + 1) * self.chunk_size
+        end_idx = start_idx + self.chunk_size
         if start_idx >= self.file_length:
             raise IndexError("Start index is out of bounds.")
 
+        # Ensure end_idx does not exceed the file length.
         end_idx = min(end_idx, self.file_length)
-        chunk = self.mmap[start_idx:end_idx]  # inclusive:exclusive
 
+        # Read the data chunk.
+        chunk = self.mmap[start_idx:end_idx]
+
+        # Pad if necessary.
         if end_idx - start_idx < self.chunk_size:
             padding_length = self.chunk_size - (end_idx - start_idx)
-            chunk = np.concatenate((chunk, np.zeros(padding_length, dtype=chunk.dtype)))
+            chunk = np.concatenate((chunk, np.zeros(padding_length, dtype='uint8')))
+
+        # Reshape the chunk to have the correct number of elements.
+        chunk = chunk.reshape((self.num_elements, self.element_size))
 
         return chunk
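Aside: a minimal, standalone sketch (not part of the patch) of the fixed-width big-endian conversion that get_chunk_ints below performs per element. Each element_size-byte row maps losslessly to an arbitrary-precision Python int and back:

    import numpy as np

    element_size = 31  # bytes per element, matching FIELD_SAFE_SCALAR_SIZE_BYTES below
    chunk = np.frombuffer(bytes(range(62)), dtype='uint8').reshape((2, element_size))

    # Each (element_size,) row of uint8 becomes one arbitrary-precision int...
    ints = [int.from_bytes(row.tobytes(), byteorder='big') for row in chunk]

    # ...and converts back losslessly at the same fixed width.
    assert all(int(x).to_bytes(element_size, byteorder='big') == row.tobytes()
               for x, row in zip(ints, chunk))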
+    # Returns an ndarray of num_elements ints, where each element's bytes are
+    # interpreted as a single big-endian (possibly very large) int.
+    def get_chunk_ints(self, i: int):
+        chunk = self.get_chunk(i)
+        return np.array([int.from_bytes(bytes=element, byteorder='big') for element in chunk])
+
     def update_pbar(self, ci: int, pbar: tqdm, start: float):
         rate = ci * self.chunk_size / (time.time() - start)
         pbar.set_postfix({"Rate": f"{rate / (1024 * 1024):.4f}MB/s"}, refresh=True)
@@ -49,20 +67,25 @@ def update_pbar(self, ci: int, pbar: tqdm, start: float):
 
 # Base for classes that read chunks from multiple files in parallel.
 class ChunksReader:
-    def __init__(self,
-                 file_map: dict[str, int],
-                 chunk_size: int):
+    def __init__(self, file_map: dict[str, int], num_elements: int, element_size: int):
         self.files: [str] = list(file_map.keys())
         self.files_indices: [int] = list(file_map.values())
-        self.chunk_size = chunk_size
+
+        self.num_elements = num_elements
+        self.element_size = element_size
+        self.chunk_size = num_elements * element_size
+
         self.num_chunks = self.validate_files(
-            files=self.files, file_indices=self.files_indices, chunk_size=chunk_size)
+            files=self.files, file_indices=self.files_indices, chunk_size=self.chunk_size)
         self.mmaps = None
 
     # Validate that the files are of the same length and are a multiple of chunk_size.
     # Return the number of chunks.
     @staticmethod
     def validate_files(files: list[str], file_indices: [int], chunk_size: int) -> int:
+
+        print("chunk size:", chunk_size)
+
         file_sizes = [os.path.getsize(file_path) for file_path in files]
 
         # Check if all files are the same length
@@ -80,9 +103,10 @@ def validate_files(files: list[str], file_indices: [int], chunk_size: int) -> int:
 
         return file_size // chunk_size
 
-    # Read data chunks at index i from each of the (k) files.
-    # The files have been previously validated to be a multiple of chunk_size.
-    def get_chunks(self, i: int) -> [np.ndarray]:
+    # Read data chunks at index i from each of the specified files.
+    # Return a list of ndarrays, each of shape (num_elements, element_size) bytes.
+    # The files are expected to be a multiple of chunk_size.
+    def get_chunks(self, i: int) -> list[NDArray]:
         if i >= self.num_chunks:
             raise IndexError("chunk index is out of bounds.")
 
@@ -91,7 +115,19 @@ def get_chunks(self, i: int) -> [np.ndarray]:
 
         start_idx = i * self.chunk_size
         end_idx = (i + 1) * self.chunk_size
-        return [mmap[start_idx:end_idx] for mmap in self.mmaps]
+
+        # The files have been previously validated to be a multiple of chunk_size.
+        file_chunks = [mmap[start_idx:end_idx] for mmap in self.mmaps]
+
+        # Reshape each chunk to have the correct number of elements.
+        return [chunk.reshape((self.num_elements, self.element_size)) for chunk in file_chunks]
+
+    # Read data chunks at index i from each of the (k) files.
+    # Return a list of ndarrays, each containing num_elements (possibly very large) ints.
+    def get_chunks_ints(self, i: int) -> list[NDArray]:
+        file_chunks = self.get_chunks(i)
+        return [np.array([int.from_bytes(bytes=element, byteorder='big') for element in chunk])
+                for chunk in file_chunks]
 
     def update_pbar(self, ci: int, num_files: int, pbar: tqdm, start: float):
         rate = ci * self.chunk_size * num_files / (time.time() - start)
diff --git a/str-twincoding/encoding/fields.py b/str-twincoding/encoding/fields.py
new file mode 100644
index 000000000..a6cbb2666
--- /dev/null
+++ b/str-twincoding/encoding/fields.py
@@ -0,0 +1,46 @@
+import galois
+import numpy as np
+from galois import FieldArray
+from numpy.typing import NDArray
+
+# The largest prime that fits in 256 bits / 32 bytes.
+p = 2 ** 256 - 189
+
+# For very large primes galois can take a long time to infer the primitive element; specifying it is much faster.
+primitive_element = 2
+
+# Field scalars are read at the rounded-down byte size to guarantee that they remain less than p.
+FIELD_SAFE_SCALAR_SIZE_BYTES: int = 31
+
+# The stored size of an element is the rounded-up byte size.
+FIELD_ELEMENT_SIZE_BYTES: int = 32
+
+
+# Initialize the Galois field object for the prime field GF(p) defined above.
+def get_field():
+    # Order / characteristic is p and degree is 1 for a prime field.
+    return galois.GF(p, 1, primitive_element=primitive_element, verify=False)
+
+
+def symbol_to_bytes(
+        symbol: FieldArray,
+        element_size: int,
+) -> bytes:
+    return int(symbol).to_bytes(element_size, byteorder='big')
+
+
+# Take the list or array of GF symbols and render them to a list of byte strings, each of length element_size.
+def symbols_to_bytes_list(
+        symbols: list[FieldArray] | NDArray[FieldArray],
+        element_size: int,
+) -> list[bytes]:
+    return [symbol_to_bytes(el, element_size) for el in symbols]
+
+
+# Take the list or array of GF symbols and render them to a flattened byte string of
+# length len(symbols) * element_size.
+def symbols_to_bytes(
+        symbols: list[FieldArray] | NDArray[FieldArray],
+        element_size: int,
+) -> bytes:
+    return b''.join(symbols_to_bytes_list(symbols, element_size))
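Aside (illustration only, not part of the patch): the 31-byte read / 32-byte store split is safe because any 31-byte value is at most 2^248 - 1 < p, so every scalar read from disk lifts into GF(p), while a general field element can need all 32 bytes:

    import os
    from encoding.fields import (get_field, symbol_to_bytes,
                                 FIELD_SAFE_SCALAR_SIZE_BYTES, FIELD_ELEMENT_SIZE_BYTES)

    GF = get_field()

    # Any 31-byte scalar is < 2**248, comfortably below p = 2**256 - 189.
    scalar = int.from_bytes(os.urandom(FIELD_SAFE_SCALAR_SIZE_BYTES), byteorder='big')
    assert scalar < 2 ** 248 < 2 ** 256 - 189

    element = GF(scalar)                                         # lift into the field
    stored = symbol_to_bytes(element, FIELD_ELEMENT_SIZE_BYTES)  # stored as 32 bytes
    assert int.from_bytes(stored, byteorder='big') == scalar     # lossless round trip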
diff --git a/str-twincoding/encoding/file_decoder.py b/str-twincoding/encoding/file_decoder.py
index 76eccdc15..4f2947d92 100644
--- a/str-twincoding/encoding/file_decoder.py
+++ b/str-twincoding/encoding/file_decoder.py
@@ -3,10 +3,15 @@
 import time
 import uuid
 from collections import OrderedDict
-import galois
+from typing import Any
+
 import numpy as np
+from galois import FieldArray
+from numpy import ndarray
+from numpy.typing import NDArray
 from tqdm import tqdm
 
+from encoding.fields import get_field, FIELD_SAFE_SCALAR_SIZE_BYTES, FIELD_ELEMENT_SIZE_BYTES, symbols_to_bytes
 from storage.storage_model import NodeType, EncodedFile
 from storage.repository import Repository
 
@@ -44,16 +49,20 @@
         self.output_path = output_path or f"decoded_{uuid.uuid4()}.dat"
         self.overwrite = overwrite
 
+        assert org_file_length is not None
         self.org_file_length = org_file_length
-        chunk_size = self.k  # individual columns of size k
-        super().__init__(file_map=file_map, chunk_size=chunk_size)
+
+        num_elements = self.k
+        super().__init__(file_map=file_map,
+                         num_elements=num_elements,
+                         element_size=FIELD_ELEMENT_SIZE_BYTES)
+        # print(f"num_elements = 'k' = {num_elements}, element size = {self.element_size}")
 
     # Init a file decoder from an encoded file dir. The dir must contain a config.json file and
     # at least k files of the same type.
     @staticmethod
     def from_encoded_dir(path: str, output_path: str = None, overwrite: bool = False):
-        file_config = EncodedFile.load(os.path.join(path, 'config.json'))
+        file_config = EncodedFile.load(os.path.join(path, 'config.json'))
         assert file_config.type0.k == file_config.type1.k, "Config node types must have the same k."
         recover_from_files = FileDecoder.get_threshold_files(path, k=file_config.type0.k)
         if os.path.basename(list(recover_from_files)[0]).startswith("type0_"):
@@ -86,9 +95,9 @@ def get_threshold_files(cls, files_dir: str, k: int) -> dict[str, int]:
 
     # Decode the file to the output path.
     def decode(self):
-        with open_output_file(output_path=self.output_path, overwrite=self.overwrite) as out:
+        with (open_output_file(output_path=self.output_path, overwrite=self.overwrite) as out):
             k, n = self.node_type.k, self.node_type.n
-            GF = galois.GF(2 ** 8)
+            GF = get_field()
             G = rs_generator_matrix(GF, k=k, n=n)
             g = G[:, self.files_indices]
             ginv = np.linalg.inv(g)
@@ -97,34 +106,40 @@ def decode(self):
             start = time.time()
             with tqdm(total=self.num_chunks, desc='Decoding', unit='chunk') as pbar:
                 for ci in range(self.num_chunks):
-                    chunks = self.get_chunks(ci)
+                    # A list of ndarrays, each containing num_elements big ints.
+                    file_chunks_ints = self.get_chunks_ints(ci)
+
+                    # Reshape each chunk as a stack of column vectors forming a k x k matrix.
+                    matrix = np.hstack([chunk.reshape(-1, 1) for chunk in file_chunks_ints])
 
-                    # Decode each chunk as a stack of column vectors forming a k x k matrix
-                    matrix = np.hstack([chunk.reshape(-1, 1) for chunk in chunks])
+                    # Decode the original data.
                     decoded = GF(matrix) @ ginv
+
                     if self.transpose:
                         decoded = decoded.T
-                    bytes = decoded.reshape(-1).tobytes()
 
-                    # Trim the last chunk if it is padded
-                    size = (ci + 1) * self.chunk_size * k
-                    if size > self.org_file_length:
-                        bytes = bytes[:self.org_file_length - size]
-
-                    # Write the data to the output file
-                    out.write(bytes)
+                    # Flatten the matrix back to an array of symbols.
+                    symbols: NDArray[FieldArray] = decoded.reshape(-1)
 
+                    # Write the data to the output file with each symbol converted to bytes at the original size.
+                    out.write(symbols_to_bytes(symbols, FIELD_SAFE_SCALAR_SIZE_BYTES))
 
                     # Progress bar
                     self.update_pbar(ci=ci, num_files=k, pbar=pbar, start=start)
+                ...
+            ...
         ...
 
+        # Trim the output file to the original file length to account for padding at ingestion time.
+        with open(self.output_path, 'rb+') as f:
+            f.truncate(self.org_file_length)
+
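To make the decode step concrete: a small self-contained check (a sketch over a tiny prime field rather than the 256-bit one) that multiplying encoded columns by the inverse of the chosen k columns of G recovers the message block, which is the identity decode() applies chunk by chunk:

    import numpy as np
    import galois
    from encoding.twin_coding import rs_generator_matrix

    GF = galois.GF(257)                    # tiny stand-in for the 256-bit field
    k, n = 3, 5
    G = rs_generator_matrix(GF, k=k, n=n)  # k x n generator matrix

    M = GF(np.arange(1, k * k + 1).reshape(k, k))  # a k x k message block
    encoded = M @ G                                # n encoded columns

    indices = [0, 2, 4]                    # any k of the n shards survive
    g = G[:, indices]
    recovered = encoded[:, indices] @ np.linalg.inv(g)
    assert np.array_equal(recovered, M)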
     def close(self):
         [mm.close() for mm in self.mmaps]
 
 
 if __name__ == '__main__':
     repo = Repository.default()
-    filename = 'file_1KB.dat'
+    filename = 'file_1MB.dat'
     original_file = repo.tmp_file_path(filename)
     encoded_file = repo.file_dir_path(filename)
     file_status = repo.file_status(filename)
diff --git a/str-twincoding/encoding/file_encoder.py b/str-twincoding/encoding/file_encoder.py
index 8bcdfa2e3..1b058db7d 100644
--- a/str-twincoding/encoding/file_encoder.py
+++ b/str-twincoding/encoding/file_encoder.py
@@ -1,9 +1,13 @@
 import hashlib
 import os
 from contextlib import ExitStack
-import galois
+
+from galois import FieldArray
 from icecream import ic
+import numpy as np
+from numpy.typing import NDArray
 
+from encoding.fields import get_field, FIELD_SAFE_SCALAR_SIZE_BYTES, FIELD_ELEMENT_SIZE_BYTES, symbols_to_bytes
 from encoding.chunks import ChunkReader
 from encoding.twin_coding import rs_generator_matrix, Code, twin_code
 from storage.storage_model import EncodedFile, NodeType0, NodeType1
@@ -37,14 +41,21 @@
         assert node_type0.n > node_type0.k and node_type1.n > node_type1.k, "The node type must have n > k."
         self.node_type0 = node_type0
+        print(f"node_type0 = {node_type0}")
         self.node_type1 = node_type1
+        print(f"node_type1 = {node_type1}")
         self.k = node_type0.k
+        print(f"k = {self.k}")
         self.path = input_file
         self.output_dir = output_path or input_file + '.encoded'
         self.overwrite = overwrite
         self._file_hash = None
-        chunk_size = self.k ** 2
-        super().__init__(path=input_file, chunk_size=chunk_size)
+        num_elements = self.k ** 2
+        super().__init__(path=input_file,
+                         num_elements=num_elements,
+                         element_size=FIELD_SAFE_SCALAR_SIZE_BYTES)
+        print(f"element size = {self.element_size}, num_elements = 'k^2' = {num_elements}")
+        print(f"FileEncoder: chunk size = {self.chunk_size}, num_chunks = {self.num_chunks}")
 
     # Initialize the output directory that will hold the erasure-encoded chunks.
     def init_output_dir(self) -> bool:
@@ -73,7 +84,7 @@ def encode(self):
             return
 
         # The symbol space
-        GF = galois.GF(2 ** 8)
+        GF = get_field()
 
         # The two coding schemes.
         k, n0, n1 = self.k, self.node_type0.n, self.node_type1.n
@@ -97,15 +108,18 @@
             start = time.time()
             with tqdm(total=self.num_chunks, desc='Encoding', unit='chunk') as pbar:
                 for ci in range(self.num_chunks):
-                    # Twin code the chunk
-                    chunk = self.get_chunk(ci)
-                    cols0, cols1 = twin_code(GF(chunk), C0, C1)
+                    # Get the next chunk, converting each element to a big integer.
+                    chunk_ints: NDArray[int] = self.get_chunk_ints(ci)
+
+                    # Twin code the chunk (returns two arrays of column-vector symbols).
+                    cols0, cols1 = twin_code(GF(chunk_ints), C0, C1)
 
                     # Write the data to the respective files
+                    # print(f"Writing chunk {ci} to files.")
                     for fi in range(n0):
-                        files0[fi].write(cols0[fi].tobytes())
+                        files0[fi].write(symbols_to_bytes(cols0[fi], FIELD_ELEMENT_SIZE_BYTES))
                     for fi in range(n1):
-                        files1[fi].write(cols1[fi].tobytes())
+                        files1[fi].write(symbols_to_bytes(cols1[fi], FIELD_ELEMENT_SIZE_BYTES))
 
                     self.update_pbar(ci=ci, pbar=pbar, start=start)
                 ...
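A consequence of the 31-byte-in / 32-byte-out scheme above, worth spelling out (an observation, not stated in the patch): encoded shards grow by a factor of 32/31 (about 3.2%) on top of the usual n/k Reed-Solomon expansion. A quick sketch of the arithmetic, using the k and n values from the __main__ examples:

    FIELD_SAFE_SCALAR_SIZE_BYTES = 31  # plaintext bytes consumed per symbol
    FIELD_ELEMENT_SIZE_BYTES = 32      # bytes stored per encoded symbol
    k, n0, n1 = 3, 5, 7                # from the example Code(k=3, n=5) / Code(k=3, n=7)

    element_overhead = FIELD_ELEMENT_SIZE_BYTES / FIELD_SAFE_SCALAR_SIZE_BYTES
    for n in (n0, n1):
        # Each chunk of k*k scalars (31 bytes each) yields n columns of k elements (32 bytes each).
        expansion = (n * k * FIELD_ELEMENT_SIZE_BYTES) / (k * k * FIELD_SAFE_SCALAR_SIZE_BYTES)
        assert abs(expansion - (n / k) * element_overhead) < 1e-12
        print(f"n={n}: {expansion:.3f}x the original size (ignoring final-chunk padding)")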
@@ -128,13 +142,13 @@ def close(self):
     repo = Repository.default()
 
     # Random test file
-    filename = 'file_1KB.dat'
+    filename = 'file_1MB.dat'
     file = repo.tmp_file_path(filename)
     ic(file)
     # If the file doesn't exist create it
     if not os.path.exists(file):
         with open(file, "wb") as f:
-            f.write(os.urandom(1024))
+            f.write(os.urandom(1 * 1024 * 1024))
 
     encoder = FileEncoder(
         node_type0=NodeType0(k=3, n=5, encoding='reed_solomon'),
diff --git a/str-twincoding/encoding/gf_file_rountrip_test.py b/str-twincoding/encoding/gf_file_rountrip_test.py
new file mode 100644
index 000000000..ec36c132d
--- /dev/null
+++ b/str-twincoding/encoding/gf_file_rountrip_test.py
@@ -0,0 +1,93 @@
+import filecmp
+import os
+
+from galois import FieldArray
+from icecream import ic
+
+from encoding.fields import get_field, FIELD_SAFE_SCALAR_SIZE_BYTES, symbols_to_bytes
+from encoding.chunks import ChunkReader
+from storage.repository import Repository
+from tqdm import tqdm
+import time
+
+
+# Test the round trip of a file through the Galois field symbol encoding and decoding.
+class GFFileRoundtripTest(ChunkReader):
+    def __init__(self, input_file: str, output_path: str = None):
+        self.path = input_file
+        self.output_path = output_path
+        self.k = 3  # example
+        num_elements = self.k ** 2
+        super().__init__(path=input_file,
+                         num_elements=num_elements,
+                         element_size=FIELD_SAFE_SCALAR_SIZE_BYTES)
+        print(f"element size = {self.element_size}, num_elements = 'k^2' = {num_elements}")
+        print(f"GFFileRoundtripTest: chunk size = {self.chunk_size}, num_chunks = {self.num_chunks}")
+
+    # Encode the file to the output path.
+    def encode(self):
+        # The symbol space
+        GF = get_field()
+
+        # Round trip chunks to field elements and back.
+        with open(self.output_path, 'wb') as outfile:
+            start = time.time()
+            with tqdm(total=self.num_chunks, desc='Encoding', unit='chunk') as pbar:
+                for ci in range(self.num_chunks):
+                    # Get a chunk as ints.
+                    chunk_ints = self.get_chunk_ints(ci)
+
+                    # Convert to symbols.
+                    chunk_gf: FieldArray = GF(chunk_ints)
+
+                    # Write the data back to the output file.
+                    outfile.write(symbols_to_bytes(chunk_gf, FIELD_SAFE_SCALAR_SIZE_BYTES))
+
+                    self.update_pbar(ci=ci, pbar=pbar, start=start)
+                ...
+            ...
+        ...
+        # Trim the output file to the original file length to account for padding at ingestion time.
+        with open(self.output_path, 'rb+') as f:
+            # Trim to the original input file length.
+            f.truncate(self.file_length)
+
+
+def test_gf_file_roundtrip():
+    repo = Repository.default()
+
+    # Random test file
+    filename = 'file_1MB.dat'
+    file = repo.tmp_file_path(filename)
+    outpath = repo.tmp_file_path('gf_test.dat')
+    ic(file, outpath)
+    # If the file doesn't exist create it
+    if not os.path.exists(file):
+        with open(file, "wb") as f:
+            f.write(os.urandom(1 * 1024 * 1024))
+
+    encoder = GFFileRoundtripTest(
+        input_file=file,
+        output_path=outpath,
+    )
+    encoder.encode()
+    print("Passed" if filecmp.cmp(file, outpath) else "Failed")
+
+
+def test_simple():
+    # Generate 31 random bytes.
+    bytes = os.urandom(31)
+
+    # Convert to hex.
+    print(f"Bytes: {len(bytes)}, {bytes.hex()}")
+    GF = get_field()
+    el_int = int.from_bytes(bytes=bytes, byteorder='big')
+    el_gf: FieldArray = GF(el_int)
+    print(f"Element: {el_gf}")
+    print("back to int: ", int(el_gf))
+    print("back to bytes: ", int(el_gf).to_bytes(31, byteorder='big').hex())
+
+
+if __name__ == '__main__':
+    test_simple()
+    test_gf_file_roundtrip()
diff --git a/str-twincoding/encoding/node_recovery_client.py b/str-twincoding/encoding/node_recovery_client.py
index 98c07c69d..7ad20fcd5 100644
--- a/str-twincoding/encoding/node_recovery_client.py
+++ b/str-twincoding/encoding/node_recovery_client.py
@@ -4,11 +4,12 @@
 import time
 import uuid
 from collections import OrderedDict
-import galois
 import numpy as np
+from numpy.typing import NDArray
 from tqdm import tqdm
 
 from encoding.chunks import ChunksReader, open_output_file
+from encoding.fields import FIELD_ELEMENT_SIZE_BYTES, get_field, symbols_to_bytes
 from encoding.twin_coding import rs_generator_matrix
 from storage.storage_model import NodeType, NodeType1
 from storage.repository import Repository
@@ -37,8 +38,10 @@
         self.output_path = output_path or f"decoded_{uuid.uuid4()}.dat"
         self.overwrite = overwrite
 
-        # chunk size is 1 symbol (byte) from each file
-        super().__init__(file_map=file_map, chunk_size=1)
+        # The chunk size is one symbol from each file.
+        super().__init__(file_map=file_map,
+                         num_elements=1,
+                         element_size=FIELD_ELEMENT_SIZE_BYTES)
 
     # Map recovery files in a directory. Exactly k recovery files should be present.
     @staticmethod
     def map_files(files_dir: str,
@@ -61,13 +64,13 @@ def map_files(files_dir: str,
 
     def recover_node(self):
         print(f"Recovering node to: {self.output_path}")
-        GF = galois.GF(2 ** 8)
+        GF = get_field()
         G = rs_generator_matrix(GF, self.recovery_source_node_type.k, self.recovery_source_node_type.n)
         with open_output_file(output_path=self.output_path, overwrite=self.overwrite) as out:
             start = time.time()
             with tqdm(total=self.num_chunks, desc='Recovery', unit='chunk') as pbar:
                 for ci in range(self.num_chunks):
-                    chunks: [np.ndarray] = self.get_chunks(ci)
+                    chunks: list[NDArray] = self.get_chunks_ints(ci)
 
                     # turn chunks into a single column vector
                     col = GF(np.concatenate(chunks))
 
                     # recovery source node type's encoding matrix.
                     g = G[:, self.files_indices]
                     ginv = np.linalg.inv(g)
-                    recovered = (col @ ginv).tobytes()
+                    recovered = col @ ginv
 
-                    # Write the data to the output file
-                    out.write(recovered)
+                    # Write the recovered symbols to the output file at the encoded field element size.
+                    out.write(symbols_to_bytes(recovered, FIELD_ELEMENT_SIZE_BYTES))
 
                     # Progress bar
                     self.update_pbar(ci=ci, num_files=self.k, pbar=pbar, start=start)
 
@@ -86,27 +89,31 @@
 
 if __name__ == '__main__':
-    file = 'file_1KB.dat'
-    repo = Repository.default()
-
-    # Use recovery files generated for type 1 node index 0 to recover the lost data shard.
-    recovery_files_dir = repo.file_dir_path(file)
-    recover_node_type = 1
-    recover_node_index = 0
-    recovered_shard = repo.tmp_file_path(
-        f'recovered_{file}_type{recover_node_type}_node{recover_node_index}.dat')
-
-    NodeRecoveryClient(
-        recovery_source_node_type=NodeType1(k=3, n=5, encoding='reed_solomon'),
-        file_map=NodeRecoveryClient.map_files(
-            files_dir=recovery_files_dir,
-            recover_node_type=recover_node_type,
-            recover_node_index=recover_node_index,
-            k=3),
-        output_path=recovered_shard,
-        overwrite=True
-    ).recover_node()
-
-    original_shard = repo.shard_path(file, node_type=1, node_index=0)
-    print("Passed" if filecmp.cmp(original_shard, recovered_shard) else "Failed")
-    ...
+    def main():
+        file = 'file_1MB.dat'
+        repo = Repository.default()
+
+        # Use recovery files generated for type 1 node index 0 to recover the lost data shard.
+        recovery_files_dir = repo.file_dir_path(file)
+        recover_node_type = 1
+        recover_node_index = 0
+        recovered_shard = repo.tmp_file_path(
+            f'recovered_{file}_type{recover_node_type}_node{recover_node_index}.dat')
+
+        NodeRecoveryClient(
+            recovery_source_node_type=NodeType1(k=3, n=5, encoding='reed_solomon'),
+            file_map=NodeRecoveryClient.map_files(
+                files_dir=recovery_files_dir,
+                recover_node_type=recover_node_type,
+                recover_node_index=recover_node_index,
+                k=3),
+            output_path=recovered_shard,
+            overwrite=True
+        ).recover_node()
+
+        original_shard = repo.shard_path(file, node_type=1, node_index=0)
+        print("Passed" if filecmp.cmp(original_shard, recovered_shard) else "Failed")
+        ...
+
+
+    main()
diff --git a/str-twincoding/encoding/node_recovery_source.py b/str-twincoding/encoding/node_recovery_source.py
index 74acde403..8f1d4b714 100644
--- a/str-twincoding/encoding/node_recovery_source.py
+++ b/str-twincoding/encoding/node_recovery_source.py
@@ -1,8 +1,7 @@
 import time
-import galois
-from icecream import ic
 from tqdm import tqdm
 
+from encoding.fields import FIELD_ELEMENT_SIZE_BYTES, get_field, symbol_to_bytes
 from storage.renderable import Renderable
 from storage.storage_model import NodeType
 from storage.repository import Repository
@@ -33,7 +32,11 @@
         output_path: str = None,
         overwrite: bool = False
     ):
-        super().__init__(path=data_path, chunk_size=recover_node_type.k)
+        super().__init__(path=data_path,
+                         num_elements=recover_node_type.k,
+                         # The encoded shard source file contains elements of the full field size.
+                         element_size=FIELD_ELEMENT_SIZE_BYTES)
+
         recover_node_type.assert_reed_solomon()
         self.recover_node_type = recover_node_type
         assert recover_node_index < recover_node_type.n, "Recover node index must be less than n."
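The exchange between NodeRecoverySource and NodeRecoveryClient reduces to one linear-algebra identity. A sketch (tiny field, one shared code for both node types, and assuming, per the twin coding construction, that type-0 helpers hold columns of M^T times the generator matrix; illustration only, not part of the patch):

    import numpy as np
    import galois
    from encoding.twin_coding import rs_generator_matrix

    GF = galois.GF(257)                  # tiny stand-in for the 256-bit prime field
    k, n = 3, 5
    G = rs_generator_matrix(GF, k, n)
    M = GF(np.arange(9).reshape(k, k))   # the original k x k message block

    e = G[:, 0]                          # encoding vector of the failed node
    lost_column = M @ e                  # the shard column to be rebuilt

    helpers = [1, 2, 3]                  # any k surviving nodes of the twin type
    # Source side (render): each helper applies the failed node's encoding
    # vector to its own column, M.T @ G[:, j], producing one symbol per chunk.
    symbols = GF(np.array([int((M.T @ G[:, j]) @ e) for j in helpers]))

    # Client side (recover_node): invert the helpers' encoding submatrix.
    recovered = symbols @ np.linalg.inv(G[:, helpers])
    assert np.array_equal(recovered, lost_column)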
@@ -87,7 +90,7 @@ def for_repo(
 
     # Generate the node recovery file for the client node
     def render(self):
-        GF = galois.GF(2 ** 8)
+        GF = get_field()
 
         # The encoding vector of the failed node is the i'th column of the generator matrix of its type.
         G = rs_generator_matrix(GF, self.recover_node_type.k, self.recover_node_type.n)
         encoding_vector = G[:, self.recover_node_index]
@@ -95,9 +98,9 @@ def render(self):
             start = time.time()
             with tqdm(total=self.num_chunks, desc='Gen Recovery', unit='chunk') as pbar:
                 for ci in range(self.num_chunks):
-                    chunk = GF(self.get_chunk(ci))
+                    chunk = GF(self.get_chunk_ints(ci))
                     symbol = encoding_vector @ chunk
-                    out.write(symbol)
+                    out.write(symbol_to_bytes(symbol, FIELD_ELEMENT_SIZE_BYTES))
                     self.update_pbar(ci=ci, pbar=pbar, start=start)
                 ...
 
@@ -109,7 +112,7 @@
 # --recover_node_index 0 --source_node_index 0 file_1KB.dat
 
 def main():
-    filename = 'file_1KB.dat'
+    filename = 'file_1MB.dat'
     repo = Repository.default()
 
     # The node and shard to recover
diff --git a/str-twincoding/encoding/twin_coding.py b/str-twincoding/encoding/twin_coding.py
index d587ad5c7..7f92d2fa8 100644
--- a/str-twincoding/encoding/twin_coding.py
+++ b/str-twincoding/encoding/twin_coding.py
@@ -1,5 +1,9 @@
 import numpy as np
 import galois
+from galois import FieldArray
+from numpy.typing import NDArray
+
+from encoding.fields import get_field
 
 
 #
@@ -35,14 +39,14 @@
 # share k but may have different encoded lengths n.
 #
 # Return:
-# This method returns the two lists of n column vectors to be stored at the
+# This method returns the two arrays of column vectors to be stored at the
 # respective node types. The two node sets will be different sizes if the codes
 # have different n values.
 #
 # The original paper:
 # https://www.cs.cmu.edu/~nihars/publications/repairAnyStorageCode_ISIT2011.pdf
 #
-def twin_code(message: np.ndarray, C0: 'Code', C1: 'Code') -> (np.ndarray, np.ndarray):
+def twin_code(message: np.ndarray, C0: 'Code', C1: 'Code') -> tuple[NDArray[FieldArray], NDArray[FieldArray]]:
     assert C0.k == C1.k
 
     # Reshape the message into k x k matrix
@@ -81,7 +85,8 @@
 # https://en.wikipedia.org/wiki/Reed%E2%80%93Solomon_error_correction
 #
 def rs_generator_matrix(GF: galois.GF, k: int, n: int):
-    eval_points = GF.elements[1:n + 1]  # n consecutive elements [1, 2, 3, 4, 5]
+    # Produce n consecutive elements of GF (enumerating GF.elements is infeasible for a huge field).
+    eval_points = [GF(i) for i in range(1, n + 1)]
     matrix = GF(np.zeros(shape=(k, n), dtype=int))
     for row in range(k):
         for col in range(n):
@@ -103,19 +108,22 @@ def __init__(self, GF: galois.GF, k: int, n: int, G: np.matrix):
 
 #
 if __name__ == "__main__":
     # The symbol space
-    GF = galois.GF(2 ** 8)
+    GF = get_field()
 
     # The two coding schemes.
     k = 3  # The schemes share k but may have different encoded lengths n
     C0 = Code(k=k, n=5, GF=GF, G=rs_generator_matrix(GF, k=k, n=5))
+    print("initialized C0")
     C1 = Code(k=k, n=7, GF=GF, G=rs_generator_matrix(GF, k=k, n=7))
+    print("initialized C1")
 
-    message = GF([1, 2, 3, 5, 8, 13, 21, 34, 55])  # k^2 = 9 symbols in GF(2^8)
+    message = GF([1, 2, 3, 5, 8, 13, 21, 34, 55])  # k^2 = 9 symbols in GF(p)
     print("message:", message)
 
     # Twin code the message
     nodes0, nodes1 = twin_code(message, C0, C1)
     print(f"{len(nodes0)} type 0 nodes\n{len(nodes1)} type 1 nodes")
+    print(f"nodes0[0] = {nodes0[0]}, len(nodes0[0]) = {len(nodes0[0])}")
 
 #
 # Simulate regular data collection: Gather data from any k nodes and
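Finally, a short illustration (tiny field, not part of the patch) of the Vandermonde structure rs_generator_matrix produces and the any-k-columns-invertible (MDS) property that decode() and recover_node() rely on:

    from itertools import combinations
    import numpy as np
    import galois
    from encoding.twin_coding import rs_generator_matrix

    GF = galois.GF(257)
    k, n = 3, 5
    G = rs_generator_matrix(GF, k, n)

    # Vandermonde structure: G[row][col] = a ** row with evaluation points a = 1..n.
    expected = GF([[pow(a, row, 257) for a in range(1, n + 1)] for row in range(k)])
    assert np.array_equal(G, expected)

    # Any k of the n columns form an invertible matrix, so any k surviving
    # shards (or recovery symbols) suffice.
    for cols in combinations(range(n), k):
        assert np.linalg.det(G[:, list(cols)]) != 0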