* Scripts to parse DiskANN SSD index
* Removed unnecessary code check and fixed argparse for bool
* Added support for multi-sector nodes in the disk index

---------
Co-authored-by: Gopal Srinivasa <[email protected]>
Showing 5 changed files with 319 additions and 0 deletions.
@@ -0,0 +1,31 @@
import parse_common
import argparse


def get_data_type_code(data_type_name):
    if data_type_name == "float":
        return ('f', 4)
    elif data_type_name == "int8":
        return ('b', 1)
    elif data_type_name == "uint8":
        return ('B', 1)
    else:
        raise Exception("Only float, int8 and uint8 are supported.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Parse a file in .bin format")
    parser.add_argument("filename", help="The vector/matrix file to parse")
    parser.add_argument("data_type", help="Type of data in the vector file. Only float, int8 and uint8 are supported.")
    parser.add_argument("output_file", help="The file to write the parsed data to")
    args = parser.parse_args()

    data_type_code, data_type_size = get_data_type_code(args.data_type)

    datamat = parse_common.DataMat(data_type_code, data_type_size)
    datamat.load_bin(args.filename)

    with open(args.output_file, "w") as out_file:
        for i in range(len(datamat)):
            out_file.write(str(datamat[i].tolist()) + "\n")

    print("Parsed " + str(len(datamat)) + " vectors from " + args.filename + " and wrote output to " + args.output_file)
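For reference, the .bin layout this script consumes (and that DataMat.load_bin below reads) is a uint32 row count, a uint32 column count, and then the row-major data. A minimal sketch of writing such a file, not part of this commit; the file name and values are made up for illustration:

# Hedged sketch: write a tiny 3 x 4 float .bin file in the layout
# DataMat.load_bin expects: uint32 num_rows, uint32 num_cols, row-major data.
import struct
import array

rows, cols = 3, 4
values = array.array('f', [float(i) for i in range(rows * cols)])

with open("sample_data.bin", "wb") as f:
    f.write(struct.pack('I', rows))     # uint32 row count
    f.write(struct.pack('I', cols))     # uint32 column count
    values.tofile(f)                    # row-major payload

# The parser above could then be run as:
#   python <this_script>.py sample_data.bin float sample_data.tsv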
@@ -0,0 +1,77 @@
"""Read a DiskANN index."""
import argparse
import struct
import array

import parse_disk_index as pdi
import parse_pq as ppq


def main(index_path_prefix, data_type, output_file_prefix, use_pq_vectors):
    data_type_size = 0
    data_type_code = ''
    if data_type == "float":
        data_type_size = 4
        data_type_code = 'f'
    elif data_type == "int8":
        data_type_code = 'b'
        data_type_size = 1
    elif data_type == "uint8":
        data_type_code = 'B'
        data_type_size = 1
    else:
        raise Exception("Unsupported data type. Supported data types are float, int8 and uint8")

    print(str.format("Parsing DiskANN index at {0} with data type {1} and writing output to {2}. Use PQ vectors: {3}", index_path_prefix, data_type, output_file_prefix, use_pq_vectors))

    out_disk_index_file = output_file_prefix + "_disk.index.tsv"
    out_pq_vectors_file = output_file_prefix + "_compressed_vectors.tsv"
    out_pq_pivots_file = output_file_prefix + "_pivots.tsv"
    out_pq_chunks_file = output_file_prefix + "_chunk_offsets.tsv"
    out_centroids_file = output_file_prefix + "_centroids.tsv"

    print("** Parsing PQ data **")
    compressed_vectors = ppq.parse_compressed_vectors(index_path_prefix)
    pivots, centroids, chunk_offsets = ppq.parse_pivots_file(index_path_prefix)

    with open(out_pq_vectors_file, "w") as out_file:
        out_file.write("Id\tvector\n")
        for i in range(len(compressed_vectors)):
            out_file.write(str(i) + "\t" + str(compressed_vectors[i].tolist()) + "\n")
    print(str.format("** Wrote PQ data to file:{} **", out_pq_vectors_file))

    with open(out_pq_pivots_file, "w") as out_file:
        out_file.write("Pivots\n")
        for i in range(len(pivots)):
            out_file.write(str(pivots[i].tolist()) + "\n")
    print(str.format("** Wrote PQ pivots to file:{} **", out_pq_pivots_file))

    with open(out_centroids_file, "w") as out_file:
        out_file.write("Centroids\n")
        for i in range(len(centroids)):
            out_file.write(str(centroids[i].tolist()) + "\n")
    print(str.format("** Wrote PQ centroid data to file:{} **", out_centroids_file))

    with open(out_pq_chunks_file, "w") as out_file:
        out_file.write("Chunk offsets\n")
        for i in range(len(chunk_offsets)):
            out_file.write(str(chunk_offsets[i].tolist()) + "\n")
    print(str.format("** Wrote chunk offsets to file:{} **", out_pq_chunks_file))

    if use_pq_vectors:
        pdi.parse_index_with_PQ_vectors(index_path_prefix, data_type_code, data_type_size, out_disk_index_file, compressed_vectors)
    else:
        pdi.parse_index(index_path_prefix, data_type_code, data_type_size, out_disk_index_file)

    print("Parsed DiskANN index and wrote output to " + out_disk_index_file + ", " + out_pq_vectors_file + ", " + out_pq_pivots_file + ", " + out_centroids_file + ", " + out_pq_chunks_file)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Parse a DiskANN index')
    parser.add_argument('index_path_prefix', type=str, help='Path to the DiskANN index file without the extension')
    parser.add_argument('data_type', type=str, help='Data type of the vectors in the index. Supported data types are float, int8 and uint8')
    parser.add_argument('output_file_prefix', type=str, help='Output file prefix to write index and PQ vectors. The index is written in TSV format with the following columns: Id, vector, neighbours')
    parser.add_argument('--use_pq_vectors', default=False, action='store_true', help='Whether to replace FP vectors with PQ vectors in the output file.')
    args = parser.parse_args()
    main(args.index_path_prefix, args.data_type, args.output_file_prefix, args.use_pq_vectors)
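The driver can also be called programmatically. A hedged usage sketch with hypothetical paths (the index prefix and output prefix are placeholders, and the script's own file name is not shown in this commit view):

# Equivalent command line (script name and paths are placeholders):
#   python <this_script>.py /data/sift/index_prefix float /tmp/sift_parsed --use_pq_vectors
#
# Programmatic call, writing /tmp/sift_parsed_disk.index.tsv,
# /tmp/sift_parsed_compressed_vectors.tsv, /tmp/sift_parsed_pivots.tsv,
# /tmp/sift_parsed_chunk_offsets.tsv and /tmp/sift_parsed_centroids.tsv:
main("/data/sift/index_prefix", "float", "/tmp/sift_parsed", use_pq_vectors=True)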
@@ -0,0 +1,74 @@
import array
import struct

"""Constants"""
SECTOR_LEN = 4096
NUM_PQ_CENTROIDS = 256


class Node:
    def __init__(self, id, data_format_specifier, num_dims):
        self.id = id
        self.vector = None
        self.data_type = data_format_specifier
        self.num_dims = num_dims
        self.neighbors = None

    def __str__(self):
        if self.vector is None:
            raise Exception("Vector is not initialized")
        else:
            return str(self.id) + "\t" + str(self.vector.tolist()) + "\t" + str(self.neighbors.tolist())

    def load_from(self, file):
        self.vector = array.array(self.data_type)
        self.vector.fromfile(file, self.num_dims)
        num_neighbors = struct.unpack('I', file.read(4))[0]
        self.neighbors = array.array('I')  # unsigned int neighbor ids.
        self.neighbors.fromfile(file, num_neighbors)

    def add_neighbor(self, neighbor):
        self.neighbors.append(neighbor)

    def add_vector_dim(self, vector_dim):
        self.vector.append(vector_dim)


class DataMat:
    def __init__(self, array_format_specifier, datatype_size):
        self.num_rows = 0
        self.num_cols = 0
        self.data = None
        self.data_format_specifier = array_format_specifier
        self.data_type_size = datatype_size

    def load_bin(self, file_name):
        with open(file_name, "rb") as file:
            self.load_bin_from_opened_file(file)

    def load_bin_metadata_from_opened_file(self, file):
        self.num_rows = struct.unpack('I', file.read(4))[0]
        self.num_cols = struct.unpack('I', file.read(4))[0]
        print(file.name + ": #rows: " + str(self.num_rows) + ", #cols: " + str(self.num_cols))

    def load_bin_from_opened_file(self, file, file_offset_data=0):
        file.seek(file_offset_data, 0)
        self.load_bin_metadata_from_opened_file(file)
        self.data = array.array(self.data_format_specifier)
        self.data.fromfile(file, self.num_rows*self.num_cols)

    def load_data_only_from_opened_file(self, file, num_rows, num_cols, file_offset_data=0):
        file.seek(file_offset_data, 0)
        self.num_rows = num_rows
        self.num_cols = num_cols
        self.data = array.array(self.data_format_specifier)
        self.data.fromfile(file, self.num_rows*self.num_cols)

    def __len__(self):
        return self.num_rows

    def __getitem__(self, key):
        return self.data[key*self.num_cols:(key+1)*self.num_cols]
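Node.load_from implies the per-node record layout inside the disk index: the vector (num_dims values of the index's data type), a uint32 neighbor count, and then that many uint32 neighbor ids. A small round-trip sketch, not part of this commit, with made-up values and a bare record rather than a full 4096-byte sector:

# Hedged sketch: write one node record in the layout Node.load_from reads,
# then read it back with the class above. File name and values are illustrative.
import array
import struct
import parse_common as pc

vec = array.array('f', [0.1, 0.2, 0.3, 0.4])   # vector: num_dims values
nbrs = array.array('I', [7, 11, 42])           # neighbor ids
with open("one_node.bin", "wb") as f:
    vec.tofile(f)
    f.write(struct.pack('I', len(nbrs)))       # uint32 neighbor count
    nbrs.tofile(f)

node = pc.Node(id=0, data_format_specifier='f', num_dims=4)
with open("one_node.bin", "rb") as f:
    node.load_from(f)
print(node)   # prints id, vector and neighbor list, tab-separated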
@@ -0,0 +1,97 @@
import struct
import parse_common as pc
import pathlib


def process_sector(file, data_type_code, nodes_per_sector, num_dims, running_node_id, num_nodes):
    num_nodes_read = 0
    nodes = []
    while num_nodes_read < nodes_per_sector and running_node_id < num_nodes:
        node = pc.Node(running_node_id, data_type_code, num_dims)
        node.load_from(file)

        num_nodes_read += 1
        nodes.append(node)
        running_node_id += 1
    return nodes


def process_multi_sector_node(file, data_type_code, num_dims, running_node_id):
    node = pc.Node(running_node_id, data_type_code, num_dims)
    node.load_from(file)
    running_node_id += 1
    return [node]


def parse_index(index_path_prefix, data_type_code, data_type_size, out_graph_csv):
    return parse_index_with_PQ_vectors(index_path_prefix, data_type_code, data_type_size, out_graph_csv, None)


def parse_index_with_PQ_vectors(index_path_prefix, data_type_code, data_type_size, out_graph_csv, compressed_vectors=None):
    disk_index_file_name = index_path_prefix + "_disk.index"

    with open(out_graph_csv, "w") as out_file:
        out_file.write("Id\tvector\tneighbours\n")

        with open(disk_index_file_name, "rb") as index_file:
            num_entries = struct.unpack('I', index_file.read(4))[0]
            num_dims = struct.unpack('I', index_file.read(4))[0]

            if num_dims != 1 or num_entries != 9:
                raise Exception("Mismatch in metadata. Expected 1 dimension and 9 entries. Got " + str(num_dims) + " dimensions and " + str(num_entries) + " entries.")

            # Read the metadata
            num_nodes = struct.unpack('Q', index_file.read(8))[0]
            num_dims = struct.unpack('Q', index_file.read(8))[0]
            medoid_id = struct.unpack('Q', index_file.read(8))[0]
            max_node_len = struct.unpack('Q', index_file.read(8))[0]
            nnodes_per_sector = struct.unpack('Q', index_file.read(8))[0]

            metadata_file = pathlib.Path(out_graph_csv).stem + "_metadata.tsv"
            with open(metadata_file, "w") as metadata_out:
                str_metadata = "Num nodes: " + str(num_nodes) + "\n" + "Num dims: " + str(num_dims) + "\n" + "Medoid id: " + str(medoid_id) + "\n" + "Max node len: " + str(max_node_len) + "\n" + "Nodes per sector: " + str(nnodes_per_sector) + "\n"
                metadata_out.write(str_metadata)
                metadata_out.flush()

            print("Index properties: " + str(num_nodes) + " nodes, "
                  + str(num_dims) + " dimensions, medoid id: "
                  + str(medoid_id) + ", max node length: " + str(max_node_len)
                  + ", nodes per sector: " + str(nnodes_per_sector))

            # skip the first sector
            index_file.seek(pc.SECTOR_LEN, 0)

            sectors_per_node = 1
            if max_node_len > pc.SECTOR_LEN:
                if max_node_len % pc.SECTOR_LEN == 0:
                    sectors_per_node = max_node_len // pc.SECTOR_LEN
                else:
                    sectors_per_node = max_node_len // pc.SECTOR_LEN + 1

            nodes_read = 0
            sector_num = 1
            while nodes_read < num_nodes:
                nodes = []
                if sectors_per_node == 1:
                    nodes = process_sector(index_file, data_type_code, nnodes_per_sector, num_dims, nodes_read, num_nodes)
                    assert len(nodes) <= nnodes_per_sector
                else:
                    nodes = process_multi_sector_node(index_file, data_type_code, num_dims, nodes_read)
                    assert len(nodes) == 1

                for node in nodes:
                    if compressed_vectors is not None:
                        compressed_vector = compressed_vectors[node.id]
                        node.vector = compressed_vector
                    out_file.write(str(node))
                    out_file.write("\n")
                    out_file.flush()
                nodes_read += len(nodes)
                sector_num += sectors_per_node
                index_file.seek(sector_num * pc.SECTOR_LEN, 0)
                if sector_num % 100 == 0:
                    print("Processed " + str(sector_num) + " sectors and " + str(nodes_read) + " nodes.")

            print("Processed " + str(nodes_read) + " points from index in "
                  + disk_index_file_name + " and wrote output to " + out_graph_csv)
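As a companion to the branching above: sector 0 holds the metadata, and from sector 1 onward either several nodes are packed per 4096-byte sector or, when max_node_len exceeds SECTOR_LEN, one node spans several sectors. A small sketch of that arithmetic, assuming (as in the usual DiskANN layout) that max_node_len is the per-node byte budget of vector bytes plus a uint32 degree plus uint32 neighbor ids; the helper name and example numbers are illustrative, not part of this commit:

import math

SECTOR_LEN = 4096  # same constant as parse_common.SECTOR_LEN

def sector_layout(max_node_len):
    """Mirror the parser's branching: either several nodes per sector,
    or one node spread over ceil(max_node_len / SECTOR_LEN) sectors.
    Returns (nodes_per_sector, sectors_per_node)."""
    if max_node_len <= SECTOR_LEN:
        return SECTOR_LEN // max_node_len, 1
    return 1, math.ceil(max_node_len / SECTOR_LEN)

# Assumed example: 128-dim float vectors with degree up to 64 gives
# max_node_len = 128*4 + 4 + 64*4 = 772 bytes -> 5 nodes per sector.
print(sector_layout(772))    # (5, 1)
print(sector_layout(6000))   # (1, 2): one node spans two sectors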
@@ -0,0 +1,40 @@
import parse_common as pc


def parse_compressed_vectors(file_prefix):
    file_name = file_prefix + "_pq_compressed.bin"
    compressed_vectors = pc.DataMat('B', 1)
    compressed_vectors.load_bin(file_name)
    return compressed_vectors


def parse_pivots_file(file_prefix):
    file_name = file_prefix + "_pq_pivots.bin"
    with open(file_name, "rb") as file:
        metadata_mat = pc.DataMat('Q', 8)
        metadata_mat.load_bin_from_opened_file(file)
        num_metadata = metadata_mat.num_rows
        num_dims = metadata_mat.num_cols
        assert num_dims == 1 and (num_metadata == 4 or num_metadata == 5)

        for i in range(num_metadata):
            for j in range(num_dims):
                print(metadata_mat[i][j])
        print("\n")

        pivots = pc.DataMat('f', 4)
        pivots.load_bin_from_opened_file(file, metadata_mat[0][0])
        assert pivots.num_rows == pc.NUM_PQ_CENTROIDS

        centroids = pc.DataMat('f', 4)
        centroids.load_bin_from_opened_file(file, metadata_mat[1][0])
        assert centroids.num_rows == pivots.num_cols
        assert centroids.num_cols == 1

        # Assuming the new file format, where the chunk-offsets table location is the
        # third metadata entry (metadata_mat[2]); old index formats are not handled.
        chunk_offsets = pc.DataMat('I', 4)
        chunk_offsets.load_bin_from_opened_file(file, metadata_mat[2][0])
        # assert chunk_offsets.num_rows == pivots.num_cols + 1 or chunk_offsets.num_rows == 0
        assert chunk_offsets.num_cols == 1
        # Ignoring rotmat for now. Also ignoring diskPQ.

        return pivots, centroids, chunk_offsets
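The three tables returned here are enough to approximately reconstruct a full-dimensional vector from its PQ code. A hedged decode sketch, assuming the usual DiskANN PQ convention that pivots is NUM_PQ_CENTROIDS x dim, centroids holds the per-dimension shift subtracted before quantization, and chunk_offsets[c]..chunk_offsets[c+1] bound the dimensions covered by code byte c; the helper and the paths in the usage comment are illustrative, not part of this commit:

import array

def decode_pq_vector(code, pivots, centroids, chunk_offsets):
    """Approximately reconstruct one vector from its PQ code bytes.
    `code` is one row of the compressed-vectors DataMat (one byte per chunk)."""
    dim = pivots.num_cols
    approx = array.array('f', [0.0] * dim)
    for c in range(len(code)):
        start = chunk_offsets[c][0]
        end = chunk_offsets[c + 1][0]
        pivot_row = pivots[code[c]]              # selected centroid for this chunk
        for d in range(start, end):
            approx[d] = pivot_row[d] + centroids[d][0]
    return approx

# Usage (the index prefix is hypothetical):
# import parse_pq as ppq
# compressed = ppq.parse_compressed_vectors("/data/sift/index_prefix")
# pivots, centroids, chunk_offsets = ppq.parse_pivots_file("/data/sift/index_prefix")
# v0 = decode_pq_vector(compressed[0], pivots, centroids, chunk_offsets)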