Skip to content

Commit

Permalink
update base docker image
Browse files Browse the repository at this point in the history
  • Loading branch information
charlifu committed Jun 5, 2024
1 parent 69ce080 commit a6af475
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Dockerfile.rocm
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# default base image
ARG BASE_IMAGE="rocm/pytorch:rocm6.1.1_ubuntu20.04_py3.9_pytorch_release-2.1.2"
ARG BASE_IMAGE="rocm/pytorch:rocm6.1.1_ubuntu20.04_py3.9_pytorch_staging"

ARG COMMON_WORKDIR=/app

Expand Down
4 changes: 2 additions & 2 deletions csrc/quantization/fp8/gemm_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ torch::Tensor fp8_gemm(torch::Tensor& a, torch::Tensor& b, torch::Tensor& scaleA
auto d_scaleD = scaleD.data_ptr();

auto handle = at::cuda::getCurrentCUDABlasLtHandle();
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
auto stream = at::cuda::getCurrentCUDAStream();

hipblaslt_ext::GemmPreference gemmPref;
gemmPref.setMaxWorkspaceBytes(0);
Expand Down Expand Up @@ -218,7 +218,7 @@ torch::Tensor fp8_gemm_16(
auto d_scaleB = transpose_result ? scaleA.data_ptr() : scaleB.data_ptr();

auto handle = at::cuda::getCurrentCUDABlasLtHandle();
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
auto stream = at::cuda::getCurrentCUDAStream();

hipblaslt_ext::GemmPreference gemmPref;
gemmPref.setMaxWorkspaceBytes(0);
Expand Down

0 comments on commit a6af475

Please sign in to comment.