Skip to content

Commit

Permalink
(#32) Fixed chain multiplication factorization pass. Also fixed compo…
Browse files Browse the repository at this point in the history
…und multiplication (D=A*B*C) for dense tensors.
  • Loading branch information
pthomadakis committed Oct 13, 2023
1 parent 96baa11 commit ac292ab
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 0 deletions.
30 changes: 30 additions & 0 deletions integration_test/compound_exps/Dense_chain_mult_matrix.ta
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# RUN: comet-opt --convert-ta-to-it --convert-to-loops --convert-to-llvm %s &> Dense_chain_mult_matrix.llvm
# RUN: mlir-cpu-runner Dense_chain_mult_matrix.llvm -O3 -e main -entry-point-result=void -shared-libs=%mlir_utility_library_dir/libmlir_runner_utils%shlibext,%comet_utility_library_dir/libcomet_runner_utils%shlibext | FileCheck %s


def main() {
#IndexLabel Declarations
IndexLabel [i] = [2];
IndexLabel [j] = [2];
IndexLabel [k] = [5];
IndexLabel [l] = [2];

#Tensor Declarations
Tensor<double> A([i, j], {Dense});
Tensor<double> B([j, k], {Dense});
Tensor<double> C([k, l], {Dense});
Tensor<double> D([i, l], {Dense});

#Tensor Fill Operation
A[i, j] = 2.2;
B[j, k] = 3.4;
C[k, l] = 1.0;
D[i, l] = 0.0;

D[i, l] = A[i, j] * B[j, k] * C[k,l];
print(D);
}

# Print the result for verification.
# CHECK: data =
# CHECK-NEXT: 74.8,74.8,74.8,74.8,
30 changes: 30 additions & 0 deletions integration_test/opts/chain_mult_factorize.ta
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# RUN: comet-opt --opt-multiop-factorize --convert-ta-to-it --convert-to-loops --convert-to-llvm %s &> chain_mult_factorize.llvm
# RUN: mlir-cpu-runner chain_mult_factorize.llvm -O3 -e main -entry-point-result=void -shared-libs=%mlir_utility_library_dir/libmlir_runner_utils%shlibext,%comet_utility_library_dir/libcomet_runner_utils%shlibext | FileCheck %s


def main() {
#IndexLabel Declarations
IndexLabel [i] = [2];
IndexLabel [j] = [2];
IndexLabel [k] = [5];
IndexLabel [l] = [2];

#Tensor Declarations
Tensor<double> A([i, j], {Dense});
Tensor<double> B([j, k], {Dense});
Tensor<double> C([k, l], {Dense});
Tensor<double> D([i, l], {Dense});

#Tensor Fill Operation
A[i, j] = 2.2;
B[j, k] = 3.4;
C[k, l] = 1.0;
D[i, l] = 0.0;

D[i, l] = A[i, j] * B[j, k] * C[k,l];
print(D);
}

# Print the result for verification.
# CHECK: data =
# CHECK-NEXT: 74.8,74.8,74.8,74.8,
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ void addTensorDecl(t op)
if (ret_format.compare("Dense") == 0)
{
itensor = builder.create<DenseTensorDeclOp>(location, ret_value.getType(), lbls_value, ret_format);
builder.create<TensorFillOp>(location, itensor, builder.getF64FloatAttr(0));
}
else
{
Expand Down

0 comments on commit ac292ab

Please sign in to comment.