
Commit

update
Yancey1989 committed Mar 5, 2024
1 parent efcc56b commit 6fc8883
Showing 9 changed files with 53 additions and 119 deletions.
16 changes: 9 additions & 7 deletions pytorch_blade/bazel_build.py
@@ -68,9 +68,9 @@ def __init__(self, *args, **kwargs):
"@org_disc_compiler//mlir/custom_ops:libdisc_custom_ops.so",
"//pytorch_blade:libtorch_blade.so",
"//pytorch_blade:_torch_blade.so",
"//tests/mhlo/torch-mlir-opt:torch-mlir-opt",
"//tests/torchscript:shape_analysis_tool",
"//tests/torch-disc-pdll:torch-disc-pdll",
#"//tests/mhlo/torch-mlir-opt:torch-mlir-opt",
#"//tests/torchscript:shape_analysis_tool",
#"//tests/torch-disc-pdll:torch-disc-pdll",
]

torch_major_version, torch_minor_version = self.torch_version.split(".")[:2]
@@ -264,15 +264,17 @@ def test(self):
env["GCC_HOST_COMPILER_PATH"] = env.get("GCC_HOST_COMPILER_PATH", which("gcc"))

self.test_suites = [
"//tests/mhlo/...",
"//pytorch_blade:torch_blade_test_suite",
"//tests/torch-disc-pdll/tests/...",
"@org_disc_compiler//mlir/ral:collective_ops_test",
#"//tests/mhlo/...",
#"//pytorch_blade:torch_blade_test_suite",
#"//tests/torch-disc-pdll/tests/...",
]

if (self.torch_major_version, self.torch_minor_version) > (1, 6):
# torchscript graph ir parser changed after torch 1.6.
# We will not test torchscript graph ir before torch 1.6
-            self.test_suites.append("//tests/torchscript/...")
+            #self.test_suites.append("//tests/torchscript/...")
+            pass

test_cmd = " ".join(
[self.shell_setting, self.test_cmd]
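A note on the version gate in the hunk above: it relies on Python tuple comparison, which only orders correctly once the components are integers (comparing the raw strings from `split(".")` would sort "1.10" before "1.6", and comparing str against int raises TypeError in Python 3). A minimal sketch with a hypothetical `parse_version` helper, not code from bazel_build.py:

```python
def parse_version(torch_version: str) -> tuple:
    """Parse a version string such as '1.13.1+cu117' into (1, 13)."""
    major, minor = torch_version.split(".")[:2]
    return int(major), int(minor)

assert parse_version("1.10.0") > (1, 6)          # int tuples: 10 > 6
assert ("1.10.0".split(".")[1] > "6") is False   # strings: "10" < "6"
```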
3 changes: 2 additions & 1 deletion pytorch_blade/torch_blade/dynamo/__init__.py
@@ -67,9 +67,10 @@ def _disc_compile(fx_g: fx.GraphModule, inps, use_ts=False, is_training=True) ->
v = v.type
new_kwargs[k] = v
node.kwargs = new_kwargs

+    print(fx_g.graph, flush=True)
fx_g.graph.lint()
fx_g.recompile()

f = torch.jit.script(fx_g)
torch._C._jit_pass_remove_mutation(f.graph)
if not is_training:
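The added line above is a debug dump placed just before the standard FX post-mutation steps. A self-contained sketch of the same pattern (illustrative; the `Mul2` module is made up here):

```python
import torch
import torch.fx as fx

class Mul2(torch.nn.Module):
    def forward(self, x):
        return x * 2

fx_g = fx.symbolic_trace(Mul2())
print(fx_g.graph, flush=True)  # flush so the dump survives a later crash
fx_g.graph.lint()              # verify the mutated graph is well-formed
fx_g.recompile()               # regenerate forward() from the edited graph
f = torch.jit.script(fx_g)     # a GraphModule can be scripted for JIT passes
```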
1 change: 1 addition & 0 deletions pytorch_blade/torch_blade/mlir/disc_engine_conversion.py
@@ -237,4 +237,5 @@ def fusion_block(block):

with tools.trust_tracing_shape():
fusion_block(graph)
+    print(graph, flush=True)
_disc_engine_conversion(c_module)
6 changes: 3 additions & 3 deletions scripts/python/tao_build.py
@@ -327,11 +327,11 @@ def bazel_build(target, flag=""):

flag = build_tao_compiler_add_flags_platform_alibaba(root, args, flag)

-    bazel_build(TARGET_TAO_COMPILER_MAIN, flag=flag)
+    #bazel_build(TARGET_TAO_COMPILER_MAIN, flag=flag)
bazel_build(TARGET_DISC_OPT, flag=flag)
# TODO:(fl237079) Support disc_replay for rocm version
-    if not args.rocm and not args.dcu:
-        bazel_build(TARGET_DISC_REPLAY, flag=flag)
+    #if not args.rocm and not args.dcu:
+    #    bazel_build(TARGET_DISC_REPLAY, flag=flag)
execute(
"cp -f -p {}/tao/third_party/ptxas/10.2/ptxas ./bazel-bin/decoupling/".format(
root
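For orientation, `bazel_build` here is a thin wrapper that shells out to Bazel with the accumulated flags. A minimal sketch of that shape (assumed; the real helper is defined earlier in tao_build.py):

```python
import subprocess

def bazel_build(target: str, flag: str = "") -> None:
    # Compose and run "bazel build <flags> <target>", failing fast on error.
    cmd = ["bazel", "build"] + flag.split() + [target]
    subprocess.check_call(cmd)
```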
1 change: 1 addition & 0 deletions tao_compiler/mlir/custom_ops/transpose_impl.cc
@@ -66,6 +66,7 @@ TAO_RAL_API("ral_transpose", "gpu", ral_transpose<float, 2>);
TAO_RAL_API("ral_transpose", "gpu", ral_transpose<float, 3>);
TAO_RAL_API("ral_transpose", "gpu", ral_transpose<Eigen::half, 2>);
TAO_RAL_API("ral_transpose", "gpu", ral_transpose<Eigen::half, 3>);
TAO_RAL_API("ral_transpose", "gpu", ral_transpose<bool, 2>);
#endif

} // namespace ral
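The one added line registers a bool rank-2 specialization alongside the existing float/half ones: conceptually, TAO_RAL_API inserts a (name, device, typed function) entry into the runtime's kernel registry so the right specialization can be dispatched. A Python analogue of that dispatch idea (entirely hypothetical, not the RAL API):

```python
# Toy kernel registry keyed on (op name, device, dtype, rank).
REGISTRY = {}

def register(name, device, dtype, rank, fn):
    REGISTRY[(name, device, dtype, rank)] = fn

def transpose_2d(x):
    # Pure-Python 2-D transpose; works for bool as well as float elements.
    return [list(row) for row in zip(*x)]

register("ral_transpose", "gpu", "bool", 2, transpose_2d)
assert REGISTRY[("ral_transpose", "gpu", "bool", 2)]([[True, False]]) == [[True], [False]]
```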
@@ -494,7 +494,7 @@ struct TransposeConverter : public OpRewritePattern<lmhlo::TransposeOp> {
if (rank != 2 && rank != 3) return failure();
// only rewriter custom library when switch 1 and 2 dimensions of
// a 3d tensor, that means permute = [0, 2, 1]
-    if (rank == 3 && (permutation[1] != 2 || permutation[2] != 1))
+    if (rank == 3 && (permutation[1] != 2 && permutation[2] != 1))
return failure();
bool on_gpu = placement_utils::isGpuMemRef(op->getOperand(0));
// TODO: support other device
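Per the comment above, this pattern targets a 3-D transpose that swaps the last two dimensions (permutation [0, 2, 1]) before routing it to the custom library. A NumPy sketch of what that permutation means:

```python
import numpy as np

x = np.arange(2 * 3 * 4).reshape(2, 3, 4)
y = np.transpose(x, (0, 2, 1))   # keep dim 0, swap dims 1 and 2
assert y.shape == (2, 4, 3)
assert y[1, 2, 0] == x[1, 0, 2]  # y[i, k, j] == x[i, j, k]
```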
53 changes: 34 additions & 19 deletions tao_compiler/mlir/disc/transforms/fusion_utils.cc
@@ -487,13 +487,28 @@ bool isRank2ScalarReduction(Operation* op) {
auto reduce_op = dyn_cast<lmhlo::ReduceOp>(op);
if (!reduce_op || reduce_op.getDimensions().getNumElements() != 1)
return false;
-  int rank = op->getOperand(2).getType().cast<MemRefType>().getRank();
-  // TODO(yancey): rewrite scalar reduction result to scalar tensor to avoid
-  // reshape to scalar tensor behand reduce op
-  Operation* reshapeOp = *op->getOperand(2).getUsers().begin();
-  if (isa<ReshapeOp>(reshapeOp) &&
-      reshapeOp->getOperand(1).getType().cast<MemRefType>().getRank() == 0) {
-    return true;
-  }
+  auto isRank0Tensor = [](Value v) -> bool {
+    return v.getType().cast<MemRefType>().getRank() == 0;
+  };
+  // TODO(yancey): it's a temporary solution to match scalar reduction, we need
+  // to erase the reshape op after scalar reduction, the result buffer of
+  // scalar reduction should be a scalar tensor instead of a <1xf32> tensor
+  {
+    Operation* reshapeOp = *op->getOperand(2).getUsers().begin();
+    if (isa<ReshapeOp>(reshapeOp) && isRank0Tensor(reshapeOp->getOperand(1))) {
+      return true;
+    }
+  }
+  {
+    Operation* convertOp = *op->getOperand(2).getUsers().begin();
+    if (isa<ConvertOp>(convertOp)) {
+      auto resultBuffer =
+          convertOp->getOperand(convertOp->getNumOperands() - 1);
+      for (auto user : resultBuffer.getUsers()) {
+        if (isa<ReshapeOp>(user) && isRank0Tensor(user->getOperand(1)))
+          return true;
+      }
+    }
+  }
return false;
}
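Data-level view of what the rewritten matcher accepts: a rank-2 reduce over one dimension leaves a rank-1 buffer (e.g. <1xf32>), and the matcher fires when that buffer feeds a reshape to rank 0, either directly or through a convert. A NumPy sketch of the shapes involved:

```python
import numpy as np

x = np.random.rand(1, 1024).astype(np.float32)
r = x.sum(axis=1)            # rank-2 reduce over one dim -> shape (1,)
r16 = r.astype(np.float16)   # the optional ConvertOp in the second branch
scalar = r16.reshape(())     # the rank-0 ReshapeOp both branches look for
assert r.shape == (1,) and scalar.ndim == 0
```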
@@ -1480,7 +1495,7 @@ bool BaseGpuFusionStrategy::isFusible(Operation* op) {
!isRank2ScalarReduction(op))) // || isScalarReduction(op)))
return false;

-  if (isa<lmhlo::TransposeOp>(op) && isRank2or3Transpose(op)) return false;
+  // if (isa<lmhlo::TransposeOp>(op) && isRank2or3Transpose(op)) return false;
return BaseFusionStrategy::isFusible(op);
}

@@ -1515,18 +1530,18 @@ bool BaseGpuFusionStrategy::tryFuse(ShapeAnalysis& shapeAnalysis,
if (cnt >= 2) {
return false;
}

-  if (has_rank2_col_reduction) {
-    const auto& results = target.getResults();
-    auto ref_shape = getEffectiveShape(target, results[0]);
-    if (llvm::any_of(results, [&](Value result) {
-          auto op = target.findLastWriter(result);
-          return isa<lmhlo::TransposeOp>(op);
-        })) {
-      return false;
-    }
-  }
+  /*
+  if (has_rank2_col_reduction) {
+    const auto& results = target.getResults();
+    auto ref_shape = getEffectiveShape(target, results[0]);
+    if (llvm::any_of(results, [&](Value result) {
+          auto op = target.findLastWriter(result);
+          return isa<lmhlo::TransposeOp>(op);
+        })) {
+      return false;
+    }
+  }
+  */
return BaseFusionStrategy::tryFuse(shapeAnalysis, lhs, rhs, target);
}

89 changes: 2 additions & 87 deletions tao_compiler/mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc
@@ -1258,31 +1258,7 @@ LogicalResult lowerWithScheduleRowReduction(ArrayRef<Operation*>, Operation*,
int vector_size = 1) {
return failure();
}
-/* Row reduction with 1 round warp shuffle
- *
- * RowPerBlock = threads / warpSize;
- * for (m = 0; m < rows; m += RowPerBlock) {
- *   for (n = 0; n < threads; ++n) {
- *     rowIdx = m + warpIdx;
- *     if (rowIdx < rows) {
- *       // intra-thread reduction
- *       sum = init_value;
- *       for (k = laneIdx; k < cols; k += warpSize) {
- *         sum += inputs[rowIdx][k];
- *       }
- *
- *       // inter-thread reduction via warp shuffle
- *       for (offset = warpSize / 2; offset > 0; offset /= 2) {
- *         sum += __shfl_xor(sum, offset);
- *       }
- *
- *       // write to output
- *       if (laneIdx == 0) {
- *         outputs[rowIdx] = sum
- *       }
- *     }
- *   }
- * }
- */

/**
* for (m = 0; m < block_nums; ++m) {
* for (n = 0; n < block_size; ++n) {
@@ -1456,38 +1432,7 @@ LogicalResult lowerWithScheduleParallelReduction(
b.create<memref::StoreOp>(loc, *(for_op_k.getResults().begin() + idx),
shared_mem_map[root_op], tid);
}
-      // acc_value = for_op_k.getResult(0);
-      // b.create<memref::StoreOp>(loc, acc_value, shared_mem, tid);
}
-  /*
-  {
-    Value init_value = b.create<memref::LoadOp>(
-        loc, cast<lmhlo::ReduceOp>(root_op).getInitValues()[0]);
-    SmallVector<Value, 2> init_values{init_value};
-    Value var_j = nullptr;
-    // for (; i < n; i += grid_size)
-    //   acc += inputs[i] + inputs[i + grid_size];
-    scf::ForOp for_op_k = createLoopAndSetInsPt(
-        b, loc, var_j, i, mn, grid_size, init_values);
-    for_op_k.getBody()->clear();
-    b.setInsertionPointToStart(for_op_k.getBody());
-    auto lhs = dominant_op->getOperand(0);
-    SmallVector<Value, 2> load_index({var_j, zero});
-    Value data = createLoadOrUseCachedValue(
-        loc, &b, dominant_op, lhs, load_index, b.saveInsertionPoint());
-    SmallVector<Value, 2> load_index2(
-        {b.create<arith::AddIOp>(loc, var_j, block_dim), zero});
-    Value data1 = createLoadOrUseCachedValue(
-        loc, &b, dominant_op, lhs, load_index2, b.saveInsertionPoint());
-    Value sum = b.create<arith::AddFOp>(loc, data, data1);
-    acc = b.create<arith::AddFOp>(loc, acc,
-                                  *for_op_k.getRegionIterArgs().begin());
-    b.create<scf::YieldOp>(loc, ValueRange{acc});
-    b.setInsertionPointAfter(for_op_k);
-    acc_value = for_op_k.getResult(0);
-    b.create<memref::StoreOp>(loc, acc_value, shared_mem, tid);
-  }
-  */
{
Value var_j = nullptr;
SmallVector<Value, 4> init_values = {};
@@ -1514,18 +1459,6 @@ LogicalResult lowerWithScheduleParallelReduction(
Value sum = (accum_factory[idx])(shm_val_1, shm_val_2);
b.create<memref::StoreOp>(loc, sum, shared_mem_map[root_op], tid);
}
-      /*
-      SmallVector<Value, 2> multidim_load_index({tid});
-      ValueRange load_index(multidim_load_index);
-      SmallVector<Value, 2> multidim_load_index2(
-          {b.create<arith::AddIOp>(loc, tid, stride_val)});
-      ValueRange load_index2(multidim_load_index2);
-      Value in_data = b.create<memref::LoadOp>(loc, shared_mem, load_index);
-      Value in_data1 = b.create<memref::LoadOp>(loc, shared_mem, load_index2);
-      Value sum = b.create<arith::AddFOp>(loc, in_data, in_data1);
-      b.create<memref::StoreOp>(loc, sum, shared_mem, tid);
-      */
b.create<gpu::BarrierOp>(loc);
b.create<scf::YieldOp>(loc, yield_values);
b.setInsertionPointAfter(if_tid_valid_op);
@@ -1534,7 +1467,7 @@ LogicalResult lowerWithScheduleParallelReduction(
b.create<gpu::BarrierOp>(loc);
{
// warp reduce
-      // if (tid < 16)
+      // if (tid < 32)
scf::IfOp if_tid_valid_op = b.create<scf::IfOp>(
loc, /*resultTypes*/ TypeRange{},
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, tid,
Expand All @@ -1556,14 +1489,6 @@ LogicalResult lowerWithScheduleParallelReduction(
b.create<memref::StoreOp>(loc, sum, shared_mem_map[root_op], tid);
b.create<gpu::BarrierOp>(loc);
}

-        // shm[tid] += shm[tid + stride];
-        // Value idx2 = b.create<arith::AddIOp>(loc, tid, stride_val);
-        // Value shm_val_1 = b.create<memref::LoadOp>(loc, shared_mem, tid);
-        // Value shm_val_2 = b.create<memref::LoadOp>(loc, shared_mem, idx2);
-        // Value sum = b.create<arith::AddFOp>(loc, shm_val_1, shm_val_2);
-        // b.create<memref::StoreOp>(loc, sum, shared_mem_map[root_op], tid);
-        // b.create<gpu::BarrierOp>(loc);
}
b.setInsertionPointAfter(if_tid_valid_op);
}
@@ -1590,16 +1515,7 @@ LogicalResult lowerWithScheduleParallelReduction(
getAtomicRMWKind(cast<lmhlo::ReduceOp>(root_op).getBody()), val,
root_op->getOperand(2), ValueRange({zero}));
}
-      /*
-      Value val = b.create<memref::LoadOp>(loc, shared_mem, tid);
-      Type root_element_type = getLhloOpsElementType(root_op);
-      b.create<memref::AtomicRMWOp>(
-          loc, root_element_type,
-          getAtomicRMWKind(cast<lmhlo::ReduceOp>(root_op).getBody()),
-          val, root_op->getOperand(2), ValueRange({zero}));
-      */
b.create<scf::YieldOp>(loc, yield_values);

b.setInsertionPointAfter(if_tid_zero_op);
}
b.setInsertionPointToEnd(local_workgroup.getBody());
@@ -4551,7 +4467,6 @@ LogicalResult HandleGpuFusionOp(OpBuilder& b, Operation* fusion,
} break;
case FusionType::kScalarReduction: {
auto kname = getFusionFullName(fusion_op);
-      llvm::errs() << "kScalarReduction <" << kname << ">\n";
LogicalResult r = lowerWithScheduleParallelReduction(
root_ops, dominant_op, fused_block);
if (failed(r)) {
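For readers following the schedule this file emits: each thread first accumulates a grid-stride slice into its shared-memory slot, a barrier-separated tree reduction then halves the active threads per step (the tail is what the `if (tid < 32)` warp-reduce guard above covers), and thread 0 finally folds the block's value into the output with an atomic RMW. A CPU simulation of that schedule (illustrative only, not the lowering itself):

```python
import numpy as np

def block_reduce(inputs: np.ndarray, block_size: int = 256) -> float:
    # Phase 1: grid-stride accumulation, one shared-memory slot per thread.
    shm = np.zeros(block_size, dtype=np.float64)
    for tid in range(block_size):
        shm[tid] = inputs[tid::block_size].sum()
    # Phase 2: tree reduction; on the GPU a gpu.barrier separates each step,
    # and the final steps run under the "if (tid < 32)" warp-reduce guard.
    stride = block_size // 2
    while stride > 0:
        for tid in range(stride):        # shm[tid] += shm[tid + stride]
            shm[tid] += shm[tid + stride]
        stride //= 2
    return shm[0]  # thread 0 would publish this via memref.atomic_rmw

x = np.arange(10_000, dtype=np.float64)
assert np.isclose(block_reduce(x), x.sum())
```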
@@ -145,7 +145,6 @@ struct BaseCudaContextState : public tao::ral::Context::Resource {
reportErrorIfAny(stream_executor::wrap::hipStreamSynchronize(stream), ctx,
"StreamSync");
#else
-    reportErrorIfAny(cuStreamSynchronize(stream), ctx, "StreamSync");
#endif
for (const_buffer_t buffer : device_persistent_buffers) {
gpu_allocator->dealloc(const_cast<buffer_t>(buffer));
