Fixes for x86 CI workflow (#26)
* Fix RelWithDebInfo build.

Signed-off-by: Ilya Enkovich <[email protected]>

* Skip fp8 cast tests on CPU.

Signed-off-by: Ilya Enkovich <[email protected]>

* Fix segfault.

Signed-off-by: Ilya Enkovich <[email protected]>

* [BACKEND] Update LLVM version to llvm/llvm-project@765206e (triton-lang#4059)

* Add -s option to pytest run.

Signed-off-by: Ilya Enkovich <[email protected]>

* Add a workaround for an LLVM bug causing a test failure on Skylake CPU.

Signed-off-by: Ilya Enkovich <[email protected]>

* Add a workaround for an LLVM fpext bug causing a test failure on Skylake CPU.

Signed-off-by: Ilya Enkovich <[email protected]>

* Fix formatting.

Signed-off-by: Ilya Enkovich <[email protected]>

---------

Signed-off-by: Ilya Enkovich <[email protected]>
Co-authored-by: Pablo Zimmermann <[email protected]>
ienkovich and karupayun authored Jun 18, 2024
1 parent 8f21d05 commit ede8a8e
Showing 6 changed files with 37 additions and 18 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build-test.yml
@@ -70,4 +70,4 @@ jobs:
- name: Run python unit tests
run: |
python -m pytest -n 32 --device cpu python/test/unit/language/test_core.py -m cpu
python -m pytest -s -n 32 --device cpu python/test/unit/language/test_core.py -m cpu
2 changes: 1 addition & 1 deletion cmake/llvm-hash.txt
@@ -1 +1 @@
3a8316216807d64a586b971f51695e23883331f7
765206e050453018e861637a08a4520f29238074
5 changes: 4 additions & 1 deletion python/src/ir.cc
@@ -1615,7 +1615,10 @@ void init_triton_ir(py::module &&m) {
});

::llvm::DebugFlag = true;
::llvm::setCurrentDebugTypes(debugTypes.data(), debugTypes.size());
// For release builds setCurrentDebugTypes is a macro, so avoid the
// namespace prefix.
using namespace llvm;
setCurrentDebugTypes(debugTypes.data(), debugTypes.size());
}

if (failed(self.run(mod.getOperation())))
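The `using namespace llvm;` change above is what fixes the RelWithDebInfo build: RelWithDebInfo defines NDEBUG, and in NDEBUG builds llvm/Support/Debug.h replaces setCurrentDebugTypes with a function-like macro, so a namespace-qualified call no longer parses. Below is a minimal, self-contained sketch of the same pitfall; the `fakellvm` namespace, the printf body, and the exact macro expansion are illustrative stand-ins, not the actual LLVM headers.

```cpp
#include <cstdio>

namespace fakellvm {
// Stand-in for the asserts-enabled (non-NDEBUG) function definition.
inline void setCurrentDebugTypes(const char **types, unsigned count) {
  std::printf("debug types enabled: %u\n", count);
  (void)types;
}
} // namespace fakellvm

#ifdef NDEBUG
// Release-style builds (RelWithDebInfo defines NDEBUG too): the name becomes
// a function-like macro, roughly as in llvm/Support/Debug.h.
#define setCurrentDebugTypes(X, N)                                            \
  do {                                                                        \
    (void)(X);                                                                \
    (void)(N);                                                                \
  } while (false)
#endif

int main() {
  using namespace fakellvm;
  const char *types[] = {"tritoncpu"};
  // Unqualified call: resolves to the macro in NDEBUG builds and to the
  // namespaced function otherwise, which is why the fix drops the qualifier.
  setCurrentDebugTypes(types, 1);
  // ::fakellvm::setCurrentDebugTypes(types, 1); // would expand to
  //                                             // "::fakellvm::do { ... }"
  //                                             // and fail to compile under
  //                                             // the macro
  return 0;
}
```

Compiling the sketch with and without -DNDEBUG exercises both paths; only the unqualified spelling builds in both configurations.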
15 changes: 15 additions & 0 deletions python/test/unit/language/test_core.py
@@ -1601,6 +1601,15 @@ def test_cast(dtype_x, dtype_z, bitcast, size, num_ctas, device):
if is_hip() and (dtype_z in ("bfloat16", "float8_e4m3fn") or dtype_x == "float8_e4m3fn"):
pytest.skip(f'test_cast{(dtype_x, dtype_z)} cast to bfloat16 not supported on HIP.')

if is_cpu() and (dtype_x in torch_float8_dtypes or dtype_z in torch_float8_dtypes):
pytest.skip(f'test_cast{(dtype_x, dtype_z)} is not supported on CPU.')

# fptrunc fp32->fp16 is broken in LLVM for large vectors:
# https://github.com/llvm/llvm-project/issues/95274
# TODO: remove the change after the bug is fixed.
if is_cpu() and dtype_x == "float32" and dtype_z == "float16":
size = 512

# bf16 vector cast is broken in LLVM for large vectors:
# https://github.com/llvm/llvm-project/issues/92471
# TODO: Remove the change after the bug is fixed.
@@ -2138,6 +2147,12 @@ def kernel(X, Z, BLOCK: tl.constexpr):
def test_reduce(op, dtype_str, shape, axis, keep_dims, num_ctas, device):
check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested

# fpext fp16->fp32 is broken in LLVM for large vectors:
# https://github.com/llvm/llvm-project/issues/95278
# TODO: remove the change after the bug is fixed.
if is_cpu() and dtype_str == "float16":
shape = (min(shape[0], 512), min(shape[1], 512))

@triton.jit
def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, IS_3D: tl.constexpr,
AXIS: tl.constexpr, KEEP_DIMS: tl.constexpr):
27 changes: 13 additions & 14 deletions python/tutorials/03-matrix-multiplication-cpu.py
@@ -154,13 +154,13 @@
import triton
import triton.language as tl


BLOCK_SIZE_M = 32
BLOCK_SIZE_N = 32
BLOCK_SIZE_K = 32
GROUP_SIZE_M = 8
USE_GPU = True


@triton.jit
def matmul_kernel(
# Pointers to matrices
@@ -227,7 +227,7 @@ def matmul_kernel(
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk

# Convert the accumulator to the output matrix C's type if needed.
c = accumulator

@@ -236,14 +236,13 @@ def matmul_kernel(
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]

#TODO: Currently masked load is not supported yet.
#c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
#tl.store(c_ptrs, c, mask=c_mask)
tl.store(c_ptrs, c)



# %%
# We can now create a convenience wrapper function that only takes two input tensors,
# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.
@@ -256,9 +255,10 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
M, K = a.shape
K, N = b.shape
#TODO: Currently masked load is not supported yet.
assert (M % BLOCK_SIZE_M == 0) and (N % BLOCK_SIZE_N == 0) and (K % BLOCK_SIZE_K == 0), "Masking currently not supported, Matrix dimensions must be multiples of block size"
assert (M % BLOCK_SIZE_M == 0) and (N % BLOCK_SIZE_N == 0) and (
K % BLOCK_SIZE_K == 0), "Masking currently not supported, Matrix dimensions must be multiples of block size"
if c is None:
# Allocates output.
# Allocates output.
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
else:
assert c.shape == (M, N), "Incompatible dimensions"
@@ -270,9 +270,7 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
a.stride(0), a.stride(1), #
b.stride(0), b.stride(1), #
c.stride(0), c.stride(1), #
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K, #
BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, #
GROUP_SIZE_M=GROUP_SIZE_M, #
)
return c
@@ -298,7 +296,8 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol):
print("✅ TritonCPU and TorchCPU match")
else:
print("❌ TritonCPU and TorchCPU differ, the maximum difference is "f'{torch.max(torch.abs(triton_output - torch_output))}')
print("❌ TritonCPU and TorchCPU differ, the maximum difference is "
f'{torch.max(torch.abs(triton_output - torch_output))}')

# %%
# Benchmark
@@ -326,13 +325,13 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol):
print("✅ TritonGPU and TorchGPU match")
else:
print("❌ TritonGPU and TorchGPU differ, the maximum difference is "f'{torch.max(torch.abs(triton_output - torch_output))}')
print("❌ TritonGPU and TorchGPU differ, the maximum difference is "
f'{torch.max(torch.abs(triton_output - torch_output))}')

LINE_VALS += ['triton-gpu', 'torch-gpu']
LINE_NAMES += ['TritonGPU', 'TorchGPU']
LINE_STYLES += [('yellow', '-'), ('red', '-')]


# %%
# Seems like we're good to go!

@@ -359,7 +358,6 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
f'matmul-performance-fp32 (BLOCK_SIZE_M={BLOCK_SIZE_M}, BLOCK_SIZE_N={BLOCK_SIZE_N}, BLOCK_SIZE_K={BLOCK_SIZE_K}, GROUP_SIZE_M={GROUP_SIZE_M})',
args={}, # Values for function arguments not in `x_names` and `y_name`.
))

def benchmark(M, N, K, provider):
import os

@@ -383,7 +381,8 @@ def benchmark(M, N, K, provider):
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, None), quantiles=quantiles)
elif provider == 'torch-cpu':
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b, out=c), quantiles=quantiles, is_cpu=True)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b, out=c), quantiles=quantiles,
is_cpu=True)
elif provider == 'triton-cpu-single':
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles, is_cpu=True)
4 changes: 3 additions & 1 deletion third_party/cpu/lib/TritonToTritonCPU/ReduceScanCommon.h
@@ -225,10 +225,12 @@ struct ReduceScanOpConversionBase : public OpConversionPattern<OpT> {
createShuffleDummies(Location loc, ValueRange inputs,
ConversionPatternRewriter &rewriter) const {
if (shuffleDummies.empty()) {
SmallVector<int64_t, 1> dummyShape({1});
for (auto val : inputs) {
auto ty = cast<VectorType>(val.getType());
shuffleDummies.push_back(rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(ty.cloneWith(1, ty.getElementType()))));
loc, rewriter.getZeroAttr(
ty.cloneWith(dummyShape, ty.getElementType()))));
}
}
return shuffleDummies;
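The hunk above belongs to the "Fix segfault" part of the commit: instead of passing an inline `1` as the shape, it keeps the one-element shape in a named `SmallVector` (`dummyShape`) that is created once and reused for every input. The relevant point is that the shape parameter of `cloneWith`-style APIs is an `ArrayRef<int64_t>`, a non-owning view, so the storage backing it must remain valid wherever the view is used. A small sketch of the same idiom, assuming LLVM's ADT headers are available; `numElements` is a hypothetical stand-in for an API that takes a shape as `ArrayRef<int64_t>`, not a function from this repository.

```cpp
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include <cstdint>

// Hypothetical consumer of a shape, analogous to handing a shape to a
// ShapedType-style cloneWith call: it only sees a non-owning view.
static int64_t numElements(llvm::ArrayRef<int64_t> shape) {
  int64_t n = 1;
  for (int64_t dim : shape)
    n *= dim;
  return n;
}

int main() {
  // Named storage for the one-element shape, created once and reused for
  // every iteration, mirroring `dummyShape` in the fix; every ArrayRef built
  // from it points at memory that is still alive.
  llvm::SmallVector<int64_t, 1> dummyShape({1});
  for (int i = 0; i < 4; ++i)
    (void)numElements(dummyShape);
  return 0;
}
```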
