diff --git a/pytorch_blade/bazel_build.py b/pytorch_blade/bazel_build.py
index e41ebaba6d3..32f2aea10dc 100644
--- a/pytorch_blade/bazel_build.py
+++ b/pytorch_blade/bazel_build.py
@@ -68,9 +68,9 @@ def __init__(self, *args, **kwargs):
             "@org_disc_compiler//mlir/custom_ops:libdisc_custom_ops.so",
             "//pytorch_blade:libtorch_blade.so",
             "//pytorch_blade:_torch_blade.so",
-            "//tests/mhlo/torch-mlir-opt:torch-mlir-opt",
-            "//tests/torchscript:shape_analysis_tool",
-            "//tests/torch-disc-pdll:torch-disc-pdll",
+            #"//tests/mhlo/torch-mlir-opt:torch-mlir-opt",
+            #"//tests/torchscript:shape_analysis_tool",
+            #"//tests/torch-disc-pdll:torch-disc-pdll",
         ]
 
         torch_major_version, torch_minor_version = self.torch_version.split(".")[:2]
@@ -264,15 +264,17 @@ def test(self):
         env["GCC_HOST_COMPILER_PATH"] = env.get("GCC_HOST_COMPILER_PATH", which("gcc"))
 
         self.test_suites = [
-            "//tests/mhlo/...",
-            "//pytorch_blade:torch_blade_test_suite",
-            "//tests/torch-disc-pdll/tests/...",
+            "@org_disc_compiler//mlir/ral:collective_ops_test",
+            #"//tests/mhlo/...",
+            #"//pytorch_blade:torch_blade_test_suite",
+            #"//tests/torch-disc-pdll/tests/...",
         ]
 
         if (self.torch_major_version, self.torch_minor_version) > (1, 6):
             # torchscript graph ir parser changed after torch 1.6.
             # We will not test torchscript graph ir before torch 1.6
-            self.test_suites.append("//tests/torchscript/...")
+            #self.test_suites.append("//tests/torchscript/...")
+            pass
 
         test_cmd = " ".join(
             [self.shell_setting, self.test_cmd]
diff --git a/pytorch_blade/torch_blade/dynamo/__init__.py b/pytorch_blade/torch_blade/dynamo/__init__.py
index 6f7a58bacd0..48c9840eca8 100644
--- a/pytorch_blade/torch_blade/dynamo/__init__.py
+++ b/pytorch_blade/torch_blade/dynamo/__init__.py
@@ -67,9 +67,10 @@ def _disc_compile(fx_g: fx.GraphModule, inps, use_ts=False, is_training=True) ->
                 v = v.type
             new_kwargs[k] = v
         node.kwargs = new_kwargs
-
+    print(fx_g.graph, flush=True)
    fx_g.graph.lint()
    fx_g.recompile()
+
    f = torch.jit.script(fx_g)
    torch._C._jit_pass_remove_mutation(f.graph)
    if not is_training:
diff --git a/pytorch_blade/torch_blade/mlir/disc_engine_conversion.py b/pytorch_blade/torch_blade/mlir/disc_engine_conversion.py
index 49a83870d3d..abbbf449d19 100644
--- a/pytorch_blade/torch_blade/mlir/disc_engine_conversion.py
+++ b/pytorch_blade/torch_blade/mlir/disc_engine_conversion.py
@@ -237,4 +237,5 @@ def fusion_block(block):
 
     with tools.trust_tracing_shape():
         fusion_block(graph)
+    print(graph, flush=True)
     _disc_engine_conversion(c_module)
diff --git a/scripts/python/tao_build.py b/scripts/python/tao_build.py
index 0de6f474eaa..5fb5f485f1a 100755
--- a/scripts/python/tao_build.py
+++ b/scripts/python/tao_build.py
@@ -327,11 +327,11 @@ def bazel_build(target, flag=""):
     flag = build_tao_compiler_add_flags_platform_alibaba(root, args, flag)
 
-    bazel_build(TARGET_TAO_COMPILER_MAIN, flag=flag)
+    #bazel_build(TARGET_TAO_COMPILER_MAIN, flag=flag)
     bazel_build(TARGET_DISC_OPT, flag=flag)
 
     # TODO:(fl237079) Support disc_replay for rocm version
-    if not args.rocm and not args.dcu:
-        bazel_build(TARGET_DISC_REPLAY, flag=flag)
+    #if not args.rocm and not args.dcu:
+    #    bazel_build(TARGET_DISC_REPLAY, flag=flag)
     execute(
         "cp -f -p {}/tao/third_party/ptxas/10.2/ptxas ./bazel-bin/decoupling/".format(
             root
diff --git a/tao_compiler/mlir/custom_ops/transpose_impl.cc b/tao_compiler/mlir/custom_ops/transpose_impl.cc
index 1c4fde328b4..616b340f36e 100644
--- a/tao_compiler/mlir/custom_ops/transpose_impl.cc
+++ b/tao_compiler/mlir/custom_ops/transpose_impl.cc
@@ -66,6 +66,7 @@
 TAO_RAL_API("ral_transpose", "gpu", ral_transpose);
 TAO_RAL_API("ral_transpose", "gpu", ral_transpose);
 TAO_RAL_API("ral_transpose", "gpu", ral_transpose);
 TAO_RAL_API("ral_transpose", "gpu", ral_transpose);
+TAO_RAL_API("ral_transpose", "gpu", ral_transpose);
 #endif
 }  // namespace ral
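For context on the transpose_impl.cc hunk: TAO_RAL_API registers a kernel under a (name, device) key at static-initialization time, one registration per ral_transpose instantiation (the distinct template arguments of the four existing registrations and of the added line are not recoverable here), and the RAL runtime resolves compiled calls against that table. A minimal sketch of such a registry, with simplified signatures and hypothetical names, not BladeDISC's actual internals:

    #include <functional>
    #include <map>
    #include <string>
    #include <utility>

    // Hypothetical stand-in for the RAL registry: (api name, device) -> impl.
    using RalFunc = std::function<void(void** args)>;

    std::map<std::pair<std::string, std::string>, RalFunc>& ralRegistry() {
      static std::map<std::pair<std::string, std::string>, RalFunc> registry;
      return registry;
    }

    // A TAO_RAL_API-like macro can expand to one of these static registrars,
    // so each registration line above inserts one table entry at start-up.
    struct RalApiRegistrar {
      RalApiRegistrar(std::string name, std::string device, RalFunc fn) {
        ralRegistry()[{std::move(name), std::move(device)}] = std::move(fn);
      }
    };

    void ral_transpose_f32_2d(void** args) { /* dispatch to the real kernel */ }
    static RalApiRegistrar reg("ral_transpose", "gpu", ral_transpose_f32_2d);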
diff --git a/tao_compiler/mlir/disc/transforms/disc_lower_to_library_call.cc b/tao_compiler/mlir/disc/transforms/disc_lower_to_library_call.cc
index a3de94b8677..266ac1825f9 100755
--- a/tao_compiler/mlir/disc/transforms/disc_lower_to_library_call.cc
+++ b/tao_compiler/mlir/disc/transforms/disc_lower_to_library_call.cc
@@ -494,7 +494,7 @@ struct TransposeConverter : public OpRewritePattern<lmhlo::TransposeOp> {
     if (rank != 2 && rank != 3) return failure();
     // only rewrite to the custom library when swapping dimensions 1 and 2 of
     // a 3d tensor, that means permute = [0, 2, 1]
-    if (rank == 3 && (permutation[1] != 2 || permutation[2] != 1))
+    if (rank == 3 && (permutation[1] != 2 && permutation[2] != 1))
       return failure();
     bool on_gpu = placement_utils::isGpuMemRef(op->getOperand(0));
     // TODO: support other device
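The one-character change in TransposeConverter (|| to &&) is easy to misread: the guard rejects rank-3 transposes that should not go to the custom library. With ||, anything other than permute = [0, 2, 1] is rejected, matching the comment above it; with &&, a permutation is rejected only when both positions miss, so [1, 2, 0] and [2, 0, 1] now also reach the library call. A self-contained sketch of the two predicates, illustrative rather than the pass code itself:

    #include <array>
    #include <cstdint>
    #include <cstdio>

    // Pre-patch guard: reject unless the permutation is exactly [0, 2, 1].
    bool rejectOld(const std::array<int64_t, 3>& p) {
      return p[1] != 2 || p[2] != 1;
    }

    // Post-patch guard: reject only if BOTH positions differ, so e.g.
    // [1, 2, 0] (p[1] == 2) and [2, 0, 1] (p[2] == 1) are now accepted.
    bool rejectNew(const std::array<int64_t, 3>& p) {
      return p[1] != 2 && p[2] != 1;
    }

    int main() {
      const std::array<int64_t, 3> perms[] = {
          {0, 2, 1}, {1, 2, 0}, {2, 0, 1}, {1, 0, 2}};
      for (const auto& p : perms)
        std::printf("[%d,%d,%d] old rejects=%d, new rejects=%d\n", (int)p[0],
                    (int)p[1], (int)p[2], rejectOld(p), rejectNew(p));
    }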
diff --git a/tao_compiler/mlir/disc/transforms/fusion_utils.cc b/tao_compiler/mlir/disc/transforms/fusion_utils.cc
index 9614658f580..d88d0fc9b47 100644
--- a/tao_compiler/mlir/disc/transforms/fusion_utils.cc
+++ b/tao_compiler/mlir/disc/transforms/fusion_utils.cc
@@ -487,13 +487,28 @@ bool isRank2ScalarReduction(Operation* op) {
   auto reduce_op = dyn_cast<lmhlo::ReduceOp>(op);
   if (!reduce_op || reduce_op.getDimensions().getNumElements() != 1)
     return false;
-  int rank = op->getOperand(2).getType().cast<MemRefType>().getRank();
-  // TODO(yancey): rewrite scalar reduction result to scalar tensor to avoid
-  // reshape to scalar tensor behand reduce op
-  Operation* reshapeOp = *op->getOperand(2).getUsers().begin();
-  if (isa<lmhlo::ReshapeOp>(reshapeOp) &&
-      reshapeOp->getOperand(1).getType().cast<MemRefType>().getRank() == 0) {
-    return true;
+  auto isRank0Tensor = [](Value v) -> bool {
+    return v.getType().cast<MemRefType>().getRank() == 0;
+  };
+  // TODO(yancey): this is a temporary solution to match a scalar reduction;
+  // we should erase the reshape op behind the reduce op and make its result
+  // buffer a scalar tensor instead of a <1xf32> tensor
+  {
+    Operation* reshapeOp = *op->getOperand(2).getUsers().begin();
+    if (isa<lmhlo::ReshapeOp>(reshapeOp) && isRank0Tensor(reshapeOp->getOperand(1))) {
+      return true;
+    }
+  }
+  {
+    Operation* convertOp = *op->getOperand(2).getUsers().begin();
+    if (isa<lmhlo::ConvertOp>(convertOp)) {
+      auto resultBuffer =
+          convertOp->getOperand(convertOp->getNumOperands() - 1);
+      for (auto user : resultBuffer.getUsers()) {
+        if (isa<lmhlo::ReshapeOp>(user) && isRank0Tensor(user->getOperand(1)))
+          return true;
+      }
+    }
   }
   return false;
 }
@@ -1480,7 +1495,7 @@ bool BaseGpuFusionStrategy::isFusible(Operation* op) {
       !isRank2ScalarReduction(op)))  // || isScalarReduction(op)))
     return false;
 
-  if (isa<lmhlo::TransposeOp>(op) && isRank2or3Transpose(op)) return false;
+  // if (isa<lmhlo::TransposeOp>(op) && isRank2or3Transpose(op)) return false;
   return BaseFusionStrategy::isFusible(op);
 }
@@ -1515,18 +1530,18 @@ bool BaseGpuFusionStrategy::tryFuse(ShapeAnalysis& shapeAnalysis,
   if (cnt >= 2) {
     return false;
   }
-
-  if (has_rank2_col_reduction) {
-    const auto& results = target.getResults();
-    auto ref_shape = getEffectiveShape(target, results[0]);
-    if (llvm::any_of(results, [&](Value result) {
-          auto op = target.findLastWriter(result);
-          return isa<lmhlo::ReduceOp>(op);
-        })) {
-      return false;
+  /*
+  if (has_rank2_col_reduction) {
+    const auto& results = target.getResults();
+    auto ref_shape = getEffectiveShape(target, results[0]);
+    if (llvm::any_of(results, [&](Value result) {
+          auto op = target.findLastWriter(result);
+          return isa<lmhlo::ReduceOp>(op);
+        })) {
+      return false;
+    }
   }
-  }
-
+  */
   return BaseFusionStrategy::tryFuse(shapeAnalysis, lhs, rhs, target);
 }
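In dataflow terms, the new isRank2ScalarReduction matches a rank-2 lmhlo.reduce over a single dimension whose <1xf32> result feeds an lmhlo.reshape into a rank-0 buffer, either directly or through an intervening lmhlo.convert; the net effect is a full reduction down to a scalar. A host-side sketch of the matched computation, assuming a sum reduction body (illustrative only):

    #include <vector>

    // Matches: reduce over one dim of a rank-2 input whose other dim has
    // extent 1, producing a <1xf32> buffer, followed by an optional convert
    // (element-type change only) and a reshape into a rank-0 buffer.
    float scalarReduce(const std::vector<float>& row) {
      float acc = 0.0f;              // the reduce op's init value
      for (float v : row) acc += v;  // the rank-2 -> <1xf32> reduction
      return acc;  // the reshape to rank-0 merely reinterprets the buffer
    }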
diff --git a/tao_compiler/mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc b/tao_compiler/mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc
index 2029d0e3e5e..a3b915125c6 100644
--- a/tao_compiler/mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc
+++ b/tao_compiler/mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc
@@ -1258,31 +1258,7 @@ LogicalResult lowerWithScheduleRowReduction(ArrayRef<Operation*>, Operation*,
                                             int vector_size = 1) {
   return failure();
 }
-/* Row reduction with 1 round warp shuffle
- *
- * RowPerBlock = threads / warpSize;
- * for (m = 0; m < rows; m += RowPerBlock) {
- *   for (n = 0; n < threads; ++n) {
- *     rowIdx = m + warpIdx;
- *     if (rowIdx < rows) {
- *       // intra-thread reduction
- *       sum = init_value;
- *       for (k = laneIdx; k < cols; k += warpSize) {
- *         sum += inputs[rowIdx][k];
- *       }
- *
- *       // inter-thread reduction via warp shuffle
- *       for (offset = warpSize / 2; offset > 0; offset /= 2) {
- *         sum += __shfl_xor(sum, offset);
- *       }
- *
- *       // write to output
- *       if (laneIdx == 0) {
- *         outputs[rowIdx] = sum
- *       }
- *     }
- *   }
- */
+
 /**
  * for (m = 0; m < block_nums; ++m) {
  *   for (n = 0; n < block_size; ++n) {
@@ -1456,38 +1432,7 @@ LogicalResult lowerWithScheduleParallelReduction(
       b.create<memref::StoreOp>(loc, *(for_op_k.getResults().begin() + idx),
                                 shared_mem_map[root_op], tid);
     }
-    // acc_value = for_op_k.getResult(0);
-    // b.create<memref::StoreOp>(loc, acc_value, shared_mem, tid);
   }
-  /*
-  {
-    Value init_value = b.create<memref::LoadOp>(
-        loc, cast<lmhlo::ReduceOp>(root_op).getInitValues()[0]);
-    SmallVector<Value> init_values{init_value};
-    Value var_j = nullptr;
-    // for (; i < n; i += grid_size)
-    //   acc += inputs[i] + inputs[i + grid_size];
-    scf::ForOp for_op_k = createLoopAndSetInsPt(
-        b, loc, var_j, i, mn, grid_size, init_values);
-    for_op_k.getBody()->clear();
-    b.setInsertionPointToStart(for_op_k.getBody());
-    auto lhs = dominant_op->getOperand(0);
-    SmallVector<Value> load_index({var_j, zero});
-    Value data = createLoadOrUseCachedValue(
-        loc, &b, dominant_op, lhs, load_index, b.saveInsertionPoint());
-
-    SmallVector<Value> load_index2(
-        {b.create<arith::AddIOp>(loc, var_j, block_dim), zero});
-    Value data1 = createLoadOrUseCachedValue(
-        loc, &b, dominant_op, lhs, load_index2, b.saveInsertionPoint());
-
-    Value sum = b.create<arith::AddFOp>(loc, data, data1);
-    acc = b.create<arith::AddFOp>(loc, acc,
-                                  *for_op_k.getRegionIterArgs().begin());
-    b.create<scf::YieldOp>(loc, ValueRange{acc});
-    b.setInsertionPointAfter(for_op_k);
-    acc_value = for_op_k.getResult(0);
-    b.create<memref::StoreOp>(loc, acc_value, shared_mem, tid);
-  }
-  */
   {
     Value var_j = nullptr;
     SmallVector<Value> init_values = {};
@@ -1514,18 +1459,6 @@ LogicalResult lowerWithScheduleParallelReduction(
         Value sum = (accum_factory[idx])(shm_val_1, shm_val_2);
         b.create<memref::StoreOp>(loc, sum, shared_mem_map[root_op], tid);
       }
-      /*
-      SmallVector<Value> multidim_load_index({tid});
-      ValueRange load_index(multidim_load_index);
-
-      SmallVector<Value> multidim_load_index2(
-          {b.create<arith::AddIOp>(loc, tid, stride_val)});
-      ValueRange load_index2(multidim_load_index2);
-
-      Value in_data = b.create<memref::LoadOp>(loc, shared_mem, load_index);
-      Value in_data1 = b.create<memref::LoadOp>(loc, shared_mem, load_index2);
-      Value sum = b.create<arith::AddFOp>(loc, in_data, in_data1);
-      b.create<memref::StoreOp>(loc, sum, shared_mem, tid);
-      */
       b.create<gpu::BarrierOp>(loc);
       b.create<scf::YieldOp>(loc, yield_values);
       b.setInsertionPointAfter(if_tid_valid_op);
@@ -1534,7 +1467,7 @@ LogicalResult lowerWithScheduleParallelReduction(
     b.create<gpu::BarrierOp>(loc);
     {
       // warp reduce
-      // if (tid < 16)
+      // if (tid < 32)
      scf::IfOp if_tid_valid_op = b.create<scf::IfOp>(
          loc, /*resultTypes*/ TypeRange{},
          b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, tid,
@@ -1556,14 +1489,6 @@ LogicalResult lowerWithScheduleParallelReduction(
           b.create<memref::StoreOp>(loc, sum, shared_mem_map[root_op], tid);
           b.create<scf::YieldOp>(loc);
         }
-
-        // shm[tid] += shm[tid + stride];
-        // Value idx2 = b.create<arith::AddIOp>(loc, tid, stride_val);
-        // Value shm_val_1 = b.create<memref::LoadOp>(loc, shared_mem, tid);
-        // Value shm_val_2 = b.create<memref::LoadOp>(loc, shared_mem, idx2);
-        // Value sum = b.create<arith::AddFOp>(loc, shm_val_1, shm_val_2);
-        // b.create<memref::StoreOp>(loc, sum, shared_mem_map[root_op], tid);
-        // b.create<scf::YieldOp>(loc);
       }
       b.setInsertionPointAfter(if_tid_valid_op);
     }
@@ -1590,16 +1515,7 @@ LogicalResult lowerWithScheduleParallelReduction(
             getAtomicRMWKind(cast<lmhlo::ReduceOp>(root_op).getBody()), val,
             root_op->getOperand(2), ValueRange({zero}));
       }
-      /*
-      Value val = b.create<memref::LoadOp>(loc, shared_mem, tid);
-      Type root_element_type = getLhloOpsElementType(root_op);
-      b.create<memref::AtomicRMWOp>(
-          loc, root_element_type,
-          getAtomicRMWKind(cast<lmhlo::ReduceOp>(root_op).getBody()), val,
-          root_op->getOperand(2), ValueRange({zero}));
-      */
       b.create<scf::YieldOp>(loc, yield_values);
-      b.setInsertionPointAfter(if_tid_zero_op);
     }
     b.setInsertionPointToEnd(local_workgroup.getBody());
@@ -4551,7 +4467,6 @@ LogicalResult HandleGpuFusionOp(OpBuilder& b, Operation* fusion,
     } break;
     case FusionType::kScalarReduction: {
       auto kname = getFusionFullName(fusion_op);
-      llvm::errs() << "kScalarReduction <" << kname << ">\n";
       LogicalResult r = lowerWithScheduleParallelReduction(
           root_ops, dominant_op, fused_block);
       if (failed(r)) {
diff --git a/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.cc b/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.cc
index 0af72f978f2..b95c8805b8e 100644
--- a/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.cc
+++ b/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.cc
@@ -145,7 +145,6 @@ struct BaseCudaContextState : public tao::ral::Context::Resource {
     reportErrorIfAny(stream_executor::wrap::hipStreamSynchronize(stream), ctx,
                      "StreamSync");
 #else
-    reportErrorIfAny(cuStreamSynchronize(stream), ctx, "StreamSync");
 #endif
     for (const_buffer_t buffer : device_persistent_buffers) {
       gpu_allocator->dealloc(const_cast<buffer_t>(buffer));
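The lowerWithScheduleParallelReduction hunks above mostly strip dead commented-out drafts; what survives is a schedule in which each block accumulates a grid-strided slice of the input into shared memory, tree-reduces it (the tail of the tree being the "warp reduce" region whose comment now reads if (tid < 32)), and publishes its partial result with an atomic RMW on the output buffer. A sequential C++ emulation of that schedule, with assumed names and a power-of-two block size, not the emitted IR:

    #include <cstddef>
    #include <vector>

    // Host-side emulation of the generated scalar-reduction schedule.
    float parallelReduce(const std::vector<float>& in, int grid_size,
                         int block_size) {
      float out = 0.0f;  // the buffer updated via memref.atomic_rmw on device
      for (int block = 0; block < grid_size; ++block) {
        std::vector<float> shm(block_size, 0.0f);  // shared memory
        // Grid-stride phase: each thread folds its strided elements.
        for (int tid = 0; tid < block_size; ++tid)
          for (std::size_t i = (std::size_t)block * block_size + tid;
               i < in.size(); i += (std::size_t)grid_size * block_size)
            shm[tid] += in[i];
        // Tree-reduction phase; iterations with stride < 32 correspond to
        // the "if (tid < 32)" warp-reduce region in the lowering.
        for (int stride = block_size / 2; stride > 0; stride /= 2)
          for (int tid = 0; tid < stride; ++tid)  // emulates if (tid < stride)
            shm[tid] += shm[tid + stride];
        out += shm[0];  // thread 0's atomic update of the result buffer
      }
      return out;
    }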