Skip to content

Commit

Permalink
fix on sm<80
Browse files Browse the repository at this point in the history
  • Loading branch information
grimoire committed Dec 5, 2023
1 parent 3c48119 commit 61d99cf
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/turbomind/kernels/activation_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ struct ReluActivation<__nv_bfloat162> {
static __device__ __forceinline__ __nv_bfloat162 apply(const __nv_bfloat162& val)
{
const __nv_bfloat16 zero_bf16 = static_cast<__nv_bfloat16>(0.0f);
return make_bfloat162(val.x > zero_bf16 ? val.x : zero_bf16, val.y > zero_bf16 ? val.y : zero_bf16);
return ::make_bfloat162(val.x > zero_bf16 ? val.x : zero_bf16, val.y > zero_bf16 ? val.y : zero_bf16);
}
};
#endif
Expand Down
2 changes: 1 addition & 1 deletion src/turbomind/models/llama/llama_decoder_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

#include "src/turbomind/macro.h"
#include "src/turbomind/models/llama/llama_decoder_kernels.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/cuda_type_utils.cuh"
#include "src/turbomind/utils/cuda_utils.h"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cuda_fp16.h>
Expand Down
4 changes: 2 additions & 2 deletions src/turbomind/utils/cuda_type_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -507,12 +507,12 @@ __device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val)
template<>
__device__ inline __nv_bfloat16 cuda_abs(__nv_bfloat16 val)
{
return fabs(val);
return fabs(cuda_cast<float>(val));
}
template<>
__device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val)
{
return make_bfloat162(fabs(val.x), fabs(val.y));
return make_bfloat162(fabs(cuda_cast<float>(val.x)), fabs(cuda_cast<float>(val.y)));
}
#endif

Expand Down

0 comments on commit 61d99cf

Please sign in to comment.