Skip to content
This repository has been archived by the owner on Jun 28, 2024. It is now read-only.

Commit

Permalink
fix lit test
Browse files Browse the repository at this point in the history
  • Loading branch information
micmelesse committed May 12, 2023
1 parent 8b55aa3 commit 275fead
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 9 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,6 @@ cmake-build-*
# cache dumps
triton_cache*
log_*

#
python/triton/third_party/cuda/bin/ptxas
2 changes: 1 addition & 1 deletion bin/triton-translate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
}

llvm::LLVMContext llvmContext;
#ifdef USE_ROCM // USE_ROCM doesnot work here
#ifdef USE_ROCM
auto llvmir = translateTritonGPUToLLVMIR(&llvmContext, *module,
SMArch.getValue(), true /*isRocm*/);
#else
Expand Down
15 changes: 11 additions & 4 deletions python/triton/tools/aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,24 +83,31 @@
print(module)
sys.exit(0)

if not args.sm:
raise argparse.ArgumentError(None, "Must specify --sm for PTX compilation")
# set arch depending on platform
if args.gfx:
arch = args.gfx
elif args.sm:
arch = args.sm
else:
raise argparse.ArgumentError(None, "Must specify --sm or --gfx for ttgir compilation")

# triton-ir -> triton-gpu-ir
module = tc.ttir_to_ttgir(module, num_warps=args.num_warps)
module = tc.optimize_ttgir(module, num_stages=3, arch=args.sm)
module = tc.optimize_ttgir(module, num_stages=3, arch=arch)
if args.target == 'triton-gpu-ir':
print(module.str())
sys.exit(0)

# triton-gpu-ir -> llvm-ir
module = tc.ttgir_to_llir(module, extern_libs=None, arch=args.sm)
module = tc.ttgir_to_llir(module, extern_libs=None, arch=arch)
if args.target == 'llvm-ir':
print(module)
sys.exit(0)

# llvm-ir -> ptx
if args.target == 'ptx':
if not args.sm:
raise argparse.ArgumentError(None, "Must specify --sm for PTX compilation")
if not args.ptx_version:
raise argparse.ArgumentError(None, "Must specify --ptx-version for PTX compilation")
module = tc.llir_to_ptx(module, arch=args.sm, ptx_version=args.ptx_version)
Expand Down
4 changes: 2 additions & 2 deletions scripts/amd/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ chmod -R 777 $LOG_DIR

bash scripts/amd/clean.sh
bash scripts/amd/build.sh
bash scripts/amd/test.sh 2>&1 |tee $LOG_DIR/test.log
# bash scripts/amd/test.sh 2>&1 |tee $LOG_DIR/test.log
# bash scripts/amd/pytorch.sh 2>&1 |tee $LOG_DIR/test.log
# bash scripts/amd/lit.sh 2>&1 |tee $LOG_DIR/lit.log
bash scripts/amd/lit.sh 2>&1 |tee $LOG_DIR/lit.log
# bash scripts/amd/test.sh backtrace 2>&1 |tee $LOG_DIR/backtrace.log
# bash scripts/amd/cache_print.sh 2>&1 |tee $LOG_DIR/cache.log
2 changes: 1 addition & 1 deletion scripts/amd/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ chmod -R 777 $LOG_DIR
sh scripts/amd/clean.sh

UNIT_TEST="python/test/unit/language/test_core_amd.py"
# UNIT_TEST="python/test/unit/language/test_core.py::test_empty_kernel[float32]"
# UNIT_TEST="python/test/unit/runtime/test_cache.py::test_compile_in_subproc"
# UNIT_TEST="python/test/unit/language/test_core_amd.py::test_shift_op[int8-int8-<<]"
# UNIT_TEST="python/test/unit/language/test_core_amd.py::test_shift_op[int32-int32->>]"
# UNIT_TEST="python/test/unit/language/test_core.py::test_empty_kernel[float32]"
# UNIT_TEST="python/test/unit/language/test_core.py::test_bin_op"
# UNIT_TEST="python/test/unit/language/test_core.py::test_bin_op[float32-float32-+]"
# UNIT_TEST="python/test/unit/language/test_core.py::test_bin_op[int8-float16-%]"
Expand Down
2 changes: 1 addition & 1 deletion test/Target/tritongpu_to_llvmir.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %PYTHON -m triton.tools.aot %s --target=llvm-ir --sm=80 | FileCheck %s
// RUN: %PYTHON -m triton.tools.aot %s --target=llvm-ir --gfx=90a | FileCheck %s

// == LLVM IR check begin ==
// CHECK-LABEL: ; ModuleID = 'LLVMDialectModule'
Expand Down

0 comments on commit 275fead

Please sign in to comment.