Merge remote-tracking branch 'upstream/main' into DIP_jpeg
Guan-schoolmate committed Oct 18, 2023
2 parents 671fd48 + 404b4e2 commit 67e9de1
Showing 12 changed files with 182 additions and 34 deletions.
8 changes: 3 additions & 5 deletions CMakeLists.txt
@@ -63,11 +63,9 @@ else()
# MLIR/LLVM Configuration
#-------------------------------------------------------------------------------

-# Allow passing external LLVM source, instead of forcing user using the vendored one
-if (NOT LLVM_PROJECT_SOURCE_DIR)
-  set(LLVM_PROJECT_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/llvm")
-  message(STATUS "Using LLVM Project ${LLVM_PROJECT_SOURCE_DIR}")
-endif()
+# Allow using out-of-tree llvm directory
+set(LLVM_PROJECT_SOURCE_DIR ${LLVM_MAIN_SRC_DIR}/..)
+message(STATUS "Using LLVM Project ${LLVM_PROJECT_SOURCE_DIR}")

set(MLIR_MAIN_SRC_DIR ${LLVM_MAIN_SRC_DIR}/../mlir)
set(MLIR_INCLUDE_DIR ${MLIR_MAIN_SRC_DIR}/include)
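Note: `LLVM_MAIN_SRC_DIR` is presumably set once the LLVM/MLIR CMake packages are located, so deriving `LLVM_PROJECT_SOURCE_DIR` from it lets the build point at any out-of-tree llvm-project checkout instead of assuming the vendored `llvm` submodule.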
14 changes: 10 additions & 4 deletions examples/MLIRPython/addmm.py
@@ -8,7 +8,13 @@ def foo(c, a, b):


foo_mlir = dynamo.optimize(DynamoCompiler)(foo)
-a = torch.randn(3, 2)
-b = torch.randn(2, 3)
-c = torch.randn(3, 3)
-foo_mlir(c, a, b)
+
+a_float32 = torch.randn(3, 2)
+b_float32 = torch.randn(2, 3)
+c_float32 = torch.randn(3, 3)
+foo_mlir(c_float32, a_float32, b_float32)
+
+a_int32 = torch.randint(10, (3, 2)).to(torch.int32)
+b_int32 = torch.randint(10, (2, 3)).to(torch.int32)
+c_int32 = torch.randint(10, (3, 3)).to(torch.int32)
+foo_mlir(c_int32, a_int32, b_int32)
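A quick eager-mode sanity check for the float32 path — assuming `foo` wraps `torch.addmm` as the file name suggests (the hunk header only shows the signature `def foo(c, a, b):`):

```python
import torch

def foo(c, a, b):
    # Assumed body; torch.addmm(c, a, b) computes c + a @ b.
    return torch.addmm(c, a, b)

a = torch.randn(3, 2)
b = torch.randn(2, 3)
c = torch.randn(3, 3)
assert torch.allclose(foo(c, a, b), c + a @ b)
```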
12 changes: 9 additions & 3 deletions examples/MLIRPython/arith_add.py
@@ -2,10 +2,16 @@
import torch
import torch._dynamo as dynamo


def foo(x, y):
return x + y


foo_mlir = dynamo.optimize(compiler.DynamoCompiler)(foo)
-in1 = torch.randn(10)
-in2 = torch.randn(10)
-foo_mlir(in1, in2)
+float32_in1 = torch.randn(10).to(torch.float32)
+float32_in2 = torch.randn(10).to(torch.float32)
+foo_mlir(float32_in1, float32_in2)
+
+int32_in1 = torch.randint(0, 10, (10,)).to(torch.int32)
+int32_in2 = torch.randint(0, 10, (10,)).to(torch.int32)
+foo_mlir(int32_in1, int32_in2)
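Worth noting: `torch.randn(10)` already produces float32 tensors by default, so the `.to(torch.float32)` casts are no-ops kept for symmetry with the int32 inputs, whose `torch.randint` results default to int64 and do need the cast.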
15 changes: 12 additions & 3 deletions examples/MLIRPython/buddy/compiler.py
@@ -72,15 +72,23 @@ def import_graph(self) -> ir.Module:
"""Import the FX graph, generate an MLIR module in high-level dialects.
Returns:
mlir.ir.Module: An MLIR moduel in high-level dialects.
mlir.ir.Module: An MLIR module in high-level dialects.
"""
with ir.InsertionPoint(self._module.body):
arguments = []
for arg in self._inputs:
shape_list = list(arg.shape)
f32 = ir.F32Type.get()
tensor_arg = ir.RankedTensorType.get(shape_list, f32)
dtype = arg.dtype
match dtype:
case torch.int32:
mlir_dtype = ir.IntegerType.get_signless(32)
case torch.float32:
mlir_dtype = ir.F32Type.get()
case _:
raise NotImplementedError(
f"Unsupported dtype {dtype} for argument {arg}")
tensor_arg = ir.RankedTensorType.get(shape_list, mlir_dtype)
arguments.append(tensor_arg)

@func.FuncOp.from_py_func(*arguments, name=self._func_name)
@@ -141,6 +149,7 @@ def Lowering(module: ir.Module):
print("-------------------------------------------------------------------")
print("Bufferizing the module ...")
pm = PassManager("builtin.module")
pm.add("func.func(tosa-to-linalg-named)")
pm.add("func.func(tosa-to-linalg)")
pm.add("func.func(tosa-to-tensor)")
pm.add("func.func(tosa-to-arith)")
31 changes: 16 additions & 15 deletions examples/MLIRPython/buddy/operators_gen.py
@@ -10,7 +10,7 @@

def _broadcast_shape(tensor_input1: ir.Value,
                     tensor_input2: ir.Value) -> List[int]:
-    """Calculate the broadcast shape of two tensors with broadcastable shapes
+    """Calculate the broadcast shape of two tensors with broadcastable shapes
    according to PyTorch's broadcast semantics: https://pytorch.org/docs/stable/notes/broadcasting.html"""
    shp1 = ir.RankedTensorType(tensor_input1.type).shape
    shp2 = ir.RankedTensorType(tensor_input2.type).shape
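The body of `_broadcast_shape` is collapsed in this view; for reference, a self-contained sketch of the PyTorch broadcasting rule it implements (align shapes from the trailing dimension; sizes must match or one of them must be 1):

```python
from typing import List

def broadcast_shape(shp1: List[int], shp2: List[int]) -> List[int]:
    """Broadcast two shapes per PyTorch semantics."""
    result = []
    for i in range(1, max(len(shp1), len(shp2)) + 1):
        d1 = shp1[-i] if i <= len(shp1) else 1
        d2 = shp2[-i] if i <= len(shp2) else 1
        if d1 != d2 and d1 != 1 and d2 != 1:
            raise ValueError(f"Shapes {shp1} and {shp2} are not broadcastable")
        result.append(max(d1, d2))
    return list(reversed(result))

assert broadcast_shape([3, 1], [2, 1, 4]) == [2, 3, 4]
```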
@@ -40,9 +40,9 @@ def AddOp(node: torch.fx.Node,
    input2 = symbol_table.get((str(node.args[1]), 0))
    broadcasted_shp = _broadcast_shape(input1, input2)
    sizes = broadcasted_shp
-    f32 = ir.F32Type.get()
-    addResultTensorType = ir.RankedTensorType.get(sizes, f32)
-    op = tosa.AddOp(addResultTensorType, input1, input2)
+    result_element_type = ir.RankedTensorType(input1.type).element_type
+    add_result_tensor_type = ir.RankedTensorType.get(sizes, result_element_type)
+    op = tosa.AddOp(add_result_tensor_type, input1, input2)
    return op


@@ -63,17 +63,18 @@ def AddMMOp(node: torch.fx.Node,
    mat2 = symbol_table.get((str(node.args[2]), 0))
    mat1_shp = ir.RankedTensorType(mat1.type).shape
    mat2_shp = ir.RankedTensorType(mat2.type).shape
-    result_shp = [mat1_shp[0], mat2_shp[1]]
-    f32 = ir.F32Type.get()
-    element = ir.FloatAttr.get(f32, 0.0)
-    tensor_type = ir.RankedTensorType.get(result_shp, f32)
-    attr = ir.DenseElementsAttr.get_splat(tensor_type, element)
-    matmul_result_buffer = arith.ConstantOp(tensor_type, attr).result
-    # Generate matmul operation.
-    matmul_op_result = linalg.matmul(mat1, mat2, outs=[matmul_result_buffer])
-
-    add_result_tensor_type = ir.RankedTensorType.get(result_shp, f32)
-    op = tosa.AddOp(add_result_tensor_type, input_, matmul_op_result)
+    mat1 = tosa.ReshapeOp(mat1, [1, *mat1_shp]).output
+    mat2 = tosa.ReshapeOp(mat2, [1, *mat2_shp]).output
+
+    matmul_result_shp = [1, mat1_shp[0], mat2_shp[1]]
+    result_element_type = ir.RankedTensorType(input_.type).element_type
+    matmul_result_type = ir.RankedTensorType.get(matmul_result_shp, result_element_type)
+    matmul_op = tosa.MatMulOp(matmul_result_type, mat1, mat2)
+    matmul_result = tosa.ReshapeOp(matmul_op.c, matmul_result_shp[1:])
+
+    add_result_shp = [mat1_shp[0], mat2_shp[1]]
+    add_result_tensor_type = ir.RankedTensorType.get(add_result_shp, result_element_type)
+    op = tosa.AddOp(add_result_tensor_type, input_, matmul_result)
    return op
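The rewrite drops the `linalg.matmul`-into-a-zero-constant-buffer pattern in favor of `tosa.MatMulOp`, which takes rank-3 (batched) operands — hence the reshapes through a unit batch dimension. A NumPy sketch of the shape bookkeeping:

```python
import numpy as np

M, K, N = 3, 2, 3
mat1, mat2 = np.ones((M, K)), np.ones((K, N))

# tosa.MatMulOp analogue: [1, M, K] @ [1, K, N] -> [1, M, N]
batched = mat1.reshape(1, M, K) @ mat2.reshape(1, K, N)
assert batched.shape == (1, M, N)
# Reshaping back to [M, N] recovers the plain matmul; adding input_ gives addmm.
assert np.array_equal(batched.reshape(M, N), mat1 @ mat2)
```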


7 changes: 7 additions & 0 deletions examples/MLIRSparseTensor/data/generated.mtx
@@ -0,0 +1,7 @@
; extended FROSTT format
2 4
4 4
1 2 30
1 3 4
3 3 4
4 3 7
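The new matrix appears to follow the extended FROSTT convention read by the MLIR sparse-tensor runtime — assuming the usual layout, the first data line is `rank nnz` (`2 4`), the second is the dimension sizes (`4 4`), and each remaining line is `coordinates... value` with 1-based coordinates. A throwaway parser sketch under that assumption:

```python
def read_frostt(path: str):
    """Parse an extended FROSTT file into (dims, {coords: value})."""
    with open(path) as f:
        lines = [ln.split() for ln in f if ln.strip() and not ln.startswith(";")]
    _rank, nnz = map(int, lines[0])
    dims = tuple(map(int, lines[1]))
    entries = {}
    for entry in lines[2 : 2 + nnz]:
        *coords, val = entry
        entries[tuple(int(c) - 1 for c in coords)] = float(val)  # to 0-based
    return dims, entries

# For this file: dims == (4, 4) and entries[(0, 1)] == 30.0, etc.
```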
24 changes: 23 additions & 1 deletion examples/MLIRSparseTensor/makefile
@@ -13,6 +13,7 @@ endif

SPARSE_MATRIX_A := ./data/sa.mtx
SPARSE_MATRIX_B := ./data/sb.mtx
+SPARSE_MATRIX_C := ./data/generated.mtx

sparse-tensor-fuse-tensor-cast-lower:
@${MLIR_OPT} ./sparse-tensor-fuse-tensor-cast.mlir\
@@ -28,7 +29,7 @@ sparse-tensor-new-translate:
sparse-tensor-new-run:
@${MLIR_OPT} ./sparse-tensor-new.mlir \
--sparse-compiler="enable-runtime-library=true" | \
-	TENSOR0=${SPARSE_MATRIX_A} TENSOR1=${SPARSE_MATRIX_B} \
+	TENSOR0=${SPARSE_MATRIX_A} TENSOR1=${SPARSE_MATRIX_B} TENSOR2=${SPARSE_MATRIX_C} \
${MLIR_CPU_RUNNER} -e main -O0 --entry-point-result=void \
--shared-libs=${MLIR_RUNNER_UTILS},${MLIR_C_RUNNER_UTILS}

@@ -104,3 +105,24 @@ sparse-tensor-expand-lower:
--linalg-generalize-named-ops \
--linalg-fuse-elementwise-ops \
--sparsification -o log.mlir

+# This target will show the original for-loop without vectorization,
+# which is useful to compare with the vectorized version.
+sparse-tensor-vectorization-linalg-lower:
+	@${MLIR_OPT} ./sparse-tensor-vectorization.mlir \
+	--linalg-generalize-named-ops \
+	--linalg-fuse-elementwise-ops \
+	--sparsification \
+	-o log.mlir
+sparse-tensor-vectorization-lower:
+	@${MLIR_OPT} ./sparse-tensor-vectorization.mlir \
+	--sparsification --cse \
+	--sparse-vectorization="vl=16" --cse \
+	-o log.mlir
+# This example is used for code verification only, as there is currently no Arm SVE machine for us to run the code on.
+# Do the same run, but with VLA enabled.
+sparse-tensor-vla-vectorization-lower:
+	@${MLIR_OPT} ./sparse-tensor-vectorization.mlir \
+	--sparsification --cse \
+	--sparse-vectorization="vl=16 enable-vla-vectorization=true" --cse \
+	-o log.mlir
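A quick way to see the effect: after `make sparse-tensor-vectorization-lower`, the sparsified loops in `log.mlir` should operate on `vector<16xf32>` values, while the VLA variant should emit scalable `vector<[16]xf32>` types instead — which is why that target is verification-only without Arm SVE hardware.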
4 changes: 4 additions & 0 deletions examples/MLIRSparseTensor/sparse-tensor-new.mlir
@@ -42,5 +42,9 @@ func.func @main() {
%val_buf_1 = memref.cast %val1 : memref<?xi32> to memref<*xi32>
call @printMemrefI32(%val_buf_1) : (memref<*xi32>) -> ()

+  %c2 = arith.constant 2 : index
+  %file2 = call @getTensorFilename(%c2) : (index) -> (!Filename)
+  sparse_tensor.out %1, %file2 : tensor<4x4xi32, #SparseMatrix>, !Filename
+
return
}
52 changes: 52 additions & 0 deletions examples/MLIRSparseTensor/sparse-tensor-vectorization.mlir
@@ -0,0 +1,52 @@
#SparseVector = #sparse_tensor.encoding<{
  dimLevelType = ["compressed"]
}>

#trait_mul = {
  indexing_maps = [
    affine_map<(i) -> (i)>,  // a
    affine_map<(i) -> (i)>,  // b
    affine_map<(i) -> (i)>   // x (out)
  ],
  iterator_types = ["parallel"],
  doc = "x(i) = a(i) * b(i)"
}

// Example for parallel loop vectorization
func.func @sparse_mul(%arga: tensor<1024xf32, #SparseVector>,
                      %argb: tensor<1024xf32>,
                      %argx: tensor<1024xf32>) -> tensor<1024xf32> {
  %0 = linalg.generic #trait_mul
    ins(%arga, %argb: tensor<1024xf32, #SparseVector>, tensor<1024xf32>)
    outs(%argx: tensor<1024xf32>) {
    ^bb(%a: f32, %b: f32, %x: f32):
      %0 = arith.mulf %a, %b : f32
      linalg.yield %0 : f32
  } -> tensor<1024xf32>
  return %0 : tensor<1024xf32>
}

#trait_reduction = {
  indexing_maps = [
    affine_map<(i) -> (i)>,  // a
    affine_map<(i) -> (i)>,  // b
    affine_map<(i) -> ()>    // x (out)
  ],
  iterator_types = ["reduction"],
  doc = "x += a(i) * b(i)"
}

// Example for reduction loop vectorization
func.func @sparse_reduction(%arga: tensor<1024xf32, #SparseVector>,
                            %argb: tensor<1024xf32>,
                            %argx: tensor<f32>) -> tensor<f32> {
  %0 = linalg.generic #trait_reduction
    ins(%arga, %argb: tensor<1024xf32, #SparseVector>, tensor<1024xf32>)
    outs(%argx: tensor<f32>) {
    ^bb(%a: f32, %b: f32, %x: f32):
      %0 = arith.mulf %a, %b : f32
      %1 = arith.addf %x, %0 : f32
      linalg.yield %1 : f32
  } -> tensor<f32>
  return %0 : tensor<f32>
}
12 changes: 11 additions & 1 deletion frontend/Interfaces/buddy/DAP/AudioContainer.h
@@ -55,7 +55,17 @@ template <typename T, size_t N> class Audio {

template <typename T, size_t N> bool Audio<T, N>::save(std::string filename) {
  if (!this->audioFile.samples) {
-    this->audioFile.samples.reset(this->data->release());
+    auto temp = this->data->release();
+    if constexpr (std::is_same_v<T, float>) {
+      for (int i = 0; i < audioFile.numSamples; i++) {
+        if (temp[i] != temp[i]) { // To handle NaN values
+          temp[i] = 0.9999999;
+        } else { // Clamp the values between -1.0 and 1.0
+          temp[i] = std::clamp(temp[i], float(-1.0), float(0.9999999));
+        }
+      }
+    }
+    this->audioFile.samples.reset(temp);
  }
  return this->audioFile.save(filename);
}
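The added loop sanitizes float samples before they reach the WAV writer: NaNs are replaced and everything is clamped to just below 1.0 — presumably because a full-scale 1.0 overflows the int16 range when the file is quantized (1.0 × 32768 > 32767). A NumPy sketch of the same normalization:

```python
import numpy as np

def sanitize(samples: np.ndarray) -> np.ndarray:
    # NaN -> 0.9999999, then clamp into [-1.0, 0.9999999].
    out = np.where(np.isnan(samples), 0.9999999, samples)
    return np.clip(out, -1.0, 0.9999999)
```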
34 changes: 32 additions & 2 deletions midend/lib/Conversion/LowerDAP/LowerDAPPass.cpp
@@ -204,6 +204,10 @@ class DAPIirLowering : public OpRewritePattern<dap::IirOp> {
VectorType vectorTy32 = VectorType::get({stride}, f32);

Value zr = rewriter.create<ConstantFloatOp>(loc, APFloat(float(0)), f32);
+    // calculate the upper bound of the FIR part <scf::ForOp>
+    Value strictN = rewriter.create<SubIOp>(loc, N, c2);
+    Value strideRem = rewriter.create<RemSIOp>(loc, strictN, strideVal);
+    Value upperN = rewriter.create<SubIOp>(loc, N, strideRem);

// loop over every row in SOS matrix
rewriter.create<scf::ForOp>(
@@ -245,10 +249,10 @@ class DAPIirLowering : public OpRewritePattern<dap::IirOp> {
builder.create<vector::BroadcastOp>(loc, vectorTy32, b2);

// A biquad filter expression:
-// y[n] = b0*x[n] + b1*x[n-1] + b2*x[n-2] + a1*y[n-1] + a2*y[n-2];
+// y[n] = b0*x[n] + b1*x[n-1] + b2*x[n-2] - a1*y[n-1] - a2*y[n-2];
// FIR part
builder.create<scf::ForOp>(
-loc, c2, N, strideVal, ValueRange{std::nullopt},
+loc, c2, upperN, strideVal, ValueRange{std::nullopt},
[&](OpBuilder &builder, Location loc, Value iv,
ValueRange itrargs) {
Value idx0 = iv;
@@ -275,6 +279,32 @@ class DAPIirLowering : public OpRewritePattern<dap::IirOp> {
builder.create<scf::YieldOp>(loc, std::nullopt);
});

+// process the remaining data of the FIR part
+Value idx1 = builder.create<SubIOp>(loc, upperN, c1);
+Value idx2 = builder.create<SubIOp>(loc, upperN, c2);
+Value in1 =
+    builder.create<memref::LoadOp>(loc, input, ValueRange{idx1});
+Value in2 =
+    builder.create<memref::LoadOp>(loc, input, ValueRange{idx2});
+
+builder.create<scf::ForOp>(
+    loc, upperN, N, c1, ValueRange{in1, in2},
+    [&](OpBuilder &builder, Location loc, Value iv,
+        ValueRange itrargs) {
+      Value in0 =
+          builder.create<memref::LoadOp>(loc, input, ValueRange{iv});
+
+      Value temp0 = builder.create<MulFOp>(loc, b0, in0);
+      Value temp1 = builder.create<MulFOp>(loc, b1, in1);
+      Value temp2 = builder.create<MulFOp>(loc, b2, in2);
+      Value sum0 = builder.create<AddFOp>(loc, temp0, temp1);
+      Value sum1 = builder.create<AddFOp>(loc, sum0, temp2);
+
+      builder.create<memref::StoreOp>(loc, sum1, output, ValueRange{iv});
+
+      builder.create<scf::YieldOp>(loc, std::vector<Value>{in0, in1});
+    });

// IIR part
builder.create<scf::ForOp>(
loc, c0, N, c1, ValueRange{z1, z2},
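Two things happen in this hunk: the biquad feedback terms get the correct sign (subtracted, matching the usual difference equation), and the vectorized FIR loop now stops at `upperN`, a whole-strides bound, with the new scalar loop covering the tail while carrying the previous two input samples as loop-carried values. A Python sketch of the filter semantics and the bound arithmetic (the concrete numbers are a worked example, not taken from the pass):

```python
import numpy as np

def biquad(x, b0, b1, b2, a1, a2):
    """Reference for one second-order section:
    y[n] = b0*x[n] + b1*x[n-1] + b2*x[n-2] - a1*y[n-1] - a2*y[n-2]."""
    y = np.zeros_like(x)
    for n in range(len(x)):
        y[n] = b0 * x[n]
        if n >= 1:
            y[n] += b1 * x[n - 1] - a1 * y[n - 1]
        if n >= 2:
            y[n] += b2 * x[n - 2] - a2 * y[n - 2]
    return y

# Loop-bound arithmetic for the FIR part, e.g. N = 1000, stride = 16:
N, stride = 1000, 16
upperN = N - (N - 2) % stride      # 1000 - (998 % 16) = 994
assert (upperN - 2) % stride == 0  # [2, upperN) is whole strides
# vector loop: n in [2, 994) step 16; scalar tail: n in [994, 1000)
```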
3 changes: 3 additions & 0 deletions requirements.txt
@@ -0,0 +1,3 @@
--pre --extra-index-url https://download.pytorch.org/whl/nightly/cpu
torch == 2.2.0.dev20231015+cpu
transformers == 4.33.1
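The first line is a pip option line: `--pre` plus the extra index makes `pip install -r requirements.txt` resolve the pinned nightly CPU build of torch from the PyTorch index while everything else still comes from PyPI.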
