c++ hyb part test passed
yzh119 committed Sep 19, 2022
1 parent 67008db commit 805d708
Showing 5 changed files with 222 additions and 26 deletions.
6 changes: 5 additions & 1 deletion CMakeLists.txt
@@ -277,7 +277,6 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS
  src/parser/*.cc
  src/printer/*.cc
  src/support/*.cc
)

tvm_file_glob(GLOB CODEGEN_SRCS
@@ -287,6 +286,9 @@ tvm_file_glob(GLOB CODEGEN_SRCS

list(APPEND COMPILER_SRCS ${CODEGEN_SRCS})

tvm_file_glob(GLOB SPARSE_SRCS
  src/sparse/*.cc
)
tvm_file_glob(GLOB_RECURSE RELAY_OP_SRCS
  src/relay/op/*.cc
)
@@ -310,6 +312,7 @@ list(APPEND COMPILER_SRCS ${RELAY_PASS_SRCS})
list(APPEND COMPILER_SRCS ${RELAY_BACKEND_SRCS})
list(APPEND COMPILER_SRCS ${RELAY_IR_SRCS})
list(APPEND COMPILER_SRCS ${RELAY_QNN_SRCS})
list(APPEND COMPILER_SRCS ${SPARSE_SRCS})

tvm_file_glob(GLOB DATATYPE_SRCS src/target/datatype/*.cc)
list(APPEND COMPILER_SRCS ${DATATYPE_SRCS})
@@ -320,6 +323,7 @@ tvm_file_glob(GLOB RUNTIME_SRCS
  src/runtime/vm/*.cc
)


if(BUILD_FOR_HEXAGON)
  if(NOT BUILD_STATIC_RUNTIME)
    # Allow undefined symbols (there will be some from libc).
2 changes: 1 addition & 1 deletion python/tvm/sparse/__init__.py
@@ -18,5 +18,5 @@
"""Python-interface for Sparse-TIR"""

from .lower import lower_sparse_iter, lower_sparse_buffer
from .format_rewrite import FormatRewriteRule, column_part_hyb
from .specialize import specialize_buffer
6 changes: 6 additions & 0 deletions python/tvm/sparse/format_rewrite.py
@@ -82,3 +82,9 @@ def __init__(
            idx_map,
            inv_idx_map,
        )  # type: ignore


def column_part_hyb(num_rows, num_cols, indptr_nd, indices_nd, num_col_parts, buckets):
    """Convert a CSR matrix to a column-partitioned ELL hybrid format.

    Columns are split into ``num_col_parts`` ranges; within each range, rows
    are grouped into ELL buckets whose widths are given by ``buckets``.
    Returns a pair (row_indices, col_indices) of arrays nested as
    [partition][bucket].
    """
    return _ffi_api.ColumnPartHyb(
        num_rows, num_cols, indptr_nd, indices_nd, num_col_parts, buckets  # type: ignore
    )
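For context, a minimal usage sketch of the new wrapper (illustrative only: the matrix values are made up, and the int32/CPU requirements mirror the checks in the C++ implementation below):

import numpy as np
import tvm
from tvm.sparse import column_part_hyb

# A small 4x4 CSR matrix (hypothetical example data).
indptr = np.array([0, 2, 3, 5, 6], dtype=np.int32)
indices = np.array([0, 3, 1, 0, 2, 3], dtype=np.int32)
indptr_nd = tvm.nd.array(indptr, device=tvm.cpu())
indices_nd = tvm.nd.array(indices, device=tvm.cpu())

# Split the columns into 2 partitions; bucket rows into ELL widths 1 and 2.
row_indices, col_indices = column_part_hyb(4, 4, indptr_nd, indices_nd, 2, [1, 2])

# Results are nested as [partition][bucket]; each leaf is an int32 NDArray.
print(row_indices[0][1].numpy())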
119 changes: 99 additions & 20 deletions src/sparse/format.cc
@@ -17,51 +17,130 @@
 * under the License.
 */

/*!
 * \file format.cc
 * \brief format conversion routine.
 */

#include <tvm/ir/expr.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/registry.h>

#include <algorithm>
#include <unordered_set>
#include <vector>

namespace tvm {

using runtime::NDArray;

Array<Array<Array<NDArray>>> ColumnPartHyb(int num_rows, int num_cols, NDArray indptr,
                                           NDArray indices, int num_col_parts,
                                           Array<Integer> buckets) {
  int partition_size = (num_cols + num_col_parts - 1) / num_col_parts;
  int num_bkts = buckets.size();
  std::vector<int> buckets_vec;
  for (const Integer& bucket_size : buckets) {
    buckets_vec.push_back(bucket_size->value);
  }

  CHECK_EQ(indptr->dtype.bits, 32) << "Only support int32 index data type, got "
                                   << int(indptr->dtype.bits) << " bits for indptr.";
  CHECK_EQ(indices->dtype.bits, 32) << "Only support int32 index data type, got "
                                    << int(indices->dtype.bits) << " bits for indices.";
  CHECK_EQ(indptr->device.device_type, kDLCPU) << "Only support ColumnPartHyb conversion on CPU.";
  CHECK_EQ(indices->device.device_type, kDLCPU) << "Only support ColumnPartHyb conversion on CPU.";
  int* indptr_data = static_cast<int*>(indptr->data);
  int* indices_data = static_cast<int*>(indices->data);
  std::vector<std::unordered_multiset<int>> degree_counter(num_col_parts);
  for (int i = 0; i < num_rows; ++i) {
    for (int j = indptr_data[i]; j < indptr_data[i + 1]; ++j) {
      int row_id = i;
      int col_id = indices_data[j];
      int part_id = col_id / partition_size;
      degree_counter[part_id].insert(row_id);
    }
  }

  /* (num_parts, num_buckets, ...) */
  std::vector<std::vector<std::vector<int>>> row_indices(num_col_parts);
  std::vector<std::vector<std::vector<int>>> col_indices(num_col_parts);
  // init row_indices and col_indices
  for (int part_id = 0; part_id < num_col_parts; ++part_id) {
    for (int bucket_id = 0; bucket_id < num_bkts; ++bucket_id) {
      row_indices[part_id].push_back(std::vector<int>());
      col_indices[part_id].push_back(std::vector<int>());
    }
  }
  for (int i = 0; i < num_rows; ++i) {
    for (int j = indptr_data[i]; j < indptr_data[i + 1]; ++j) {
      int row_id = i;
      int col_id = indices_data[j];
      int part_id = col_id / partition_size;
      int degree = degree_counter[part_id].count(row_id);
      // first bucket whose width is >= degree; overflow rows go to the largest bucket
      int bucket_id = std::upper_bound(buckets_vec.begin(), buckets_vec.end(), degree - 1) -
                      buckets_vec.begin();
      if (bucket_id == num_bkts) {
        bucket_id--;
      }
      int bucket_size = buckets_vec[bucket_id];
      bool create_new_bucket = false;
      int remainder = col_indices[part_id][bucket_id].size() % bucket_size;
      if (remainder != 0) {
        if (row_id != row_indices[part_id][bucket_id].back()) {
          // pad the unfinished ELL row of the previous source row with zeros
          for (int k = remainder; k < bucket_size; ++k) {
            col_indices[part_id][bucket_id].push_back(0);
          }
          create_new_bucket = true;
        }
      } else {
        create_new_bucket = true;
      }
      if (create_new_bucket) {
        ICHECK(col_indices[part_id][bucket_id].size() % bucket_size == 0) << "Invalid padding";
        row_indices[part_id][bucket_id].push_back(row_id);
      }
      col_indices[part_id][bucket_id].push_back(col_id);
    }
  }

  // final padding and conversion to NDArray
  Array<Array<NDArray>> row_indices_nd;
  Array<Array<NDArray>> col_indices_nd;
  for (int part_id = 0; part_id < num_col_parts; ++part_id) {
    Array<NDArray> row_indices_part_local;
    Array<NDArray> col_indices_part_local;
    for (int bucket_id = 0; bucket_id < num_bkts; ++bucket_id) {
      int bucket_size = buckets_vec[bucket_id];
      // padding
      int remainder = col_indices[part_id][bucket_id].size() % bucket_size;
      if (remainder != 0) {
        for (int k = remainder; k < bucket_size; ++k) {
          col_indices[part_id][bucket_id].push_back(0);
        }
      }
      // conversion to NDArray
      int nnz = row_indices[part_id][bucket_id].size();
      ICHECK(int(col_indices[part_id][bucket_id].size()) == nnz * bucket_size) << "Padding error.";
      NDArray row_indices_bucket_local = NDArray::Empty({nnz}, {kDLInt, 32, 1}, {kDLCPU, 0});
      NDArray col_indices_bucket_local =
          NDArray::Empty({nnz, bucket_size}, {kDLInt, 32, 1}, {kDLCPU, 0});
      row_indices_bucket_local.CopyFromBytes(row_indices[part_id][bucket_id].data(),
                                             nnz * sizeof(int));
      col_indices_bucket_local.CopyFromBytes(col_indices[part_id][bucket_id].data(),
                                             nnz * bucket_size * sizeof(int));
      row_indices_part_local.push_back(row_indices_bucket_local);
      col_indices_part_local.push_back(col_indices_bucket_local);
    }
    row_indices_nd.push_back(row_indices_part_local);
    col_indices_nd.push_back(col_indices_part_local);
  }

  return {row_indices_nd, col_indices_nd};
}

namespace sparse {
TVM_REGISTER_GLOBAL("tir.sparse.ColumnPartHyb").set_body_typed(ColumnPartHyb);
}  // namespace sparse
}  // namespace tvm
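To make the conversion rule concrete, here is a small sketch (not part of the commit) of the bucket assignment that the loop above implements with std::upper_bound: a row whose degree within one column partition is d is assigned to the first bucket whose width is at least d, and rows longer than the largest width are split into ceil(d / largest_width) ELL rows of that width, with column indices padded by zeros:

import bisect

bucket_sizes = [1, 2, 4]  # ELL bucket widths, ascending

for degree in [1, 2, 3, 4, 9]:
    # index of the first bucket whose width >= degree, clamped to the largest
    bucket_id = min(bisect.bisect_right(bucket_sizes, degree - 1), len(bucket_sizes) - 1)
    # number of ELL rows this source row occupies in that bucket
    num_ell_rows = -(-degree // bucket_sizes[bucket_id])  # ceiling division
    print(f"degree {degree} -> bucket {bucket_id}, {num_ell_rows} ELL row(s)")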
115 changes: 111 additions & 4 deletions tests/python/sparsetir/test_format_rewrite.py
@@ -15,8 +15,11 @@
# specific language governing permissions and limitations
# under the License.

import tvm
import dgl
import numpy as np
from tvm.sparse import FormatRewriteRule, column_part_hyb
from sparse_tir_scripts import csrmm
from sparse_tir_format_rewrite_scripts import (
    bsr,
@@ -121,7 +124,111 @@ def test_csrmm_padding_rewrite():
    tvm.ir.assert_structural_equal(mod["main"], padding_rewrite_with_preprocess, True)


def scipy_column_part_hyb(g, column_part, bucket_sizes):
    mat = g.adj(transpose=True, scipy_fmt="csr")
    buckets = bucket_sizes * column_part
    m = mat.shape[0]
    n = mat.shape[1]
    nnz = mat.nnz
    per_column_part_size = (n + column_part - 1) // column_part
    sub_mats = [
        mat[:, i * per_column_part_size : (i + 1) * per_column_part_size]
        for i in range(column_part)
    ]

    num_buckets = len(buckets)
    ell_n = []

    for partition in range(column_part):
        sub_mat = sub_mats[partition]
        in_degrees = sub_mat.indptr[1:] - sub_mat.indptr[:-1]
        for i, bucket_size in enumerate(bucket_sizes[:-1]):
            last_bucket_size = 0 if i == 0 else bucket_sizes[i - 1]
            ell_n.append(int(((in_degrees > last_bucket_size) & (in_degrees <= bucket_size)).sum()))
        sub_indegrees = in_degrees[in_degrees > bucket_sizes[-2]]
        ell_n.append(int(((sub_indegrees + bucket_sizes[-1] - 1) // bucket_sizes[-1]).sum()))

    ell_rows = []
    ell_indices = []

    for partition in range(column_part):
        sub_mat = sub_mats[partition]
        in_degrees = sub_mat.indptr[1:] - sub_mat.indptr[:-1]

        for i, bucket_size in enumerate(bucket_sizes[:-1]):
            last_bucket_size = 0 if i == 0 else bucket_sizes[i - 1]
            ell_rows.append(
                ((in_degrees > last_bucket_size) & (in_degrees <= bucket_size)).nonzero()[0]
            )
        ell_rows.append((in_degrees > bucket_sizes[-2]).nonzero()[0])

        for i, bucket_size in enumerate(bucket_sizes[:-1]):
            indices = np.zeros(
                (ell_n[partition * len(bucket_sizes) + i], bucket_size), dtype=np.int32
            )
            for j, row_id in enumerate(ell_rows[partition * len(bucket_sizes) + i]):
                row = sub_mat[row_id]
                indices[j, : row.nnz] = row.indices + partition * per_column_part_size
            ell_indices.append(indices)

        # split rows for the last bucket
        indices = np.zeros(
            (ell_n[(partition + 1) * len(bucket_sizes) - 1], bucket_sizes[-1]), dtype=np.int32
        )
        new_rows = np.zeros((ell_n[(partition + 1) * len(bucket_sizes) - 1],), dtype=np.int32)
        bucket_size = bucket_sizes[-1]
        i = 0
        for row_id in ell_rows[-1]:
            row = sub_mat[row_id]
            for start_offset in range(0, row.nnz, bucket_size):
                if start_offset + bucket_size >= row.nnz:
                    # last bucket
                    indices[i, : row.nnz - start_offset] = (
                        row.indices[start_offset:] + partition * per_column_part_size
                    )
                else:
                    indices[i] = (
                        row.indices[start_offset : start_offset + bucket_size]
                        + partition * per_column_part_size
                    )
                new_rows[i] = row_id
                i += 1

        ell_indices.append(indices)
        ell_rows[-1] = new_rows

    return ell_rows, ell_indices


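# Note: in both the C++ routine and this SciPy reference, a row whose
# in-degree exceeds the largest bucket width is split into several ELL rows of
# width bucket_sizes[-1], so the same row id can repeat in the last bucket.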
def test_column_part_hyb():
    g = dgl.rand_graph(1000, 10000).int()
    column_parts = 4
    buckets = [1, 2, 4, 8]
    indptr, indices, _ = g.adj_sparse("csc")
    indptr_nd = tvm.nd.array(indptr.numpy(), device=tvm.cpu())
    indices_nd = tvm.nd.array(indices.numpy(), device=tvm.cpu())
    # built-in c++ function
    row_indices, col_indices = column_part_hyb(
        g.num_dst_nodes(), g.num_src_nodes(), indptr_nd, indices_nd, column_parts, buckets
    )
    # compute indices with scipy
    row_indices_scipy, col_indices_scipy = scipy_column_part_hyb(g, column_parts, buckets)

    for part_id in range(column_parts):
        for bucket_id, _ in enumerate(buckets):
            assert np.array_equal(
                row_indices[part_id][bucket_id].numpy(),
                row_indices_scipy[part_id * len(buckets) + bucket_id],
            )
            assert np.array_equal(
                col_indices[part_id][bucket_id].numpy(),
                col_indices_scipy[part_id * len(buckets) + bucket_id],
            )


if __name__ == "__main__":
test_csrmm_bsr_rewrite()
test_csrmm_ell_rewrite()
test_csrmm_padding_rewrite()
# test_csrmm_bsr_rewrite()
# test_csrmm_ell_rewrite()
# test_csrmm_padding_rewrite()
test_column_part_hyb()
# test_condense()
