// Copyright (c) OpenMMLab. All rights reserved.

- #include "src/turbomind/kernels/reduce_kernel_utils.cuh"
- #include "src/turbomind/models/llama/llama_utils.h"
- #include "src/turbomind/utils/cuda_utils.h"
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
+ #include <type_traits>
+ #include <vector>
+
#include <cuda_fp16.h>
#include <curand_kernel.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/host_vector.h>
- #include <vector>
+
+ #include "src/turbomind/models/llama/llama_utils.h"
+ #include "src/turbomind/utils/cuda_utils.h"

namespace turbomind {

CmpMode compare_mode = kCmpRead;
// CmpMode compare_mode = kCmpWrite;

- template<typename T>
- struct abs_diff_t {
-     using type = T;
- };
-
- template<>
- struct abs_diff_t<half> {
-     using type = float;
- };
-
- template<>
- struct abs_diff_t<__nv_bfloat16> {
-     using type = float;
- };
-
- template<typename T>
- struct abs_diff: public thrust::unary_function<thrust::tuple<T, T>, typename abs_diff_t<T>::type> {
-     __host__ __device__ float operator()(thrust::tuple<T, T> x) const
-     {
-         using R = typename abs_diff_t<T>::type;
-         auto r = R(thrust::get<0>(x)) - R(thrust::get<1>(x));
-         return r < R(0) ? -r : r;
-     }
- };
-
template<typename T>
void CheckNan(const T* ptr, size_t size, std::string key, cudaStream_t stream)
{
@@ -64,10 +41,8 @@ void CheckNan(const T* ptr, size_t size, std::string key, cudaStream_t stream)
template<typename T>
void CmpRead(T* ptr, size_t size, std::string key, cudaStream_t stream)
{
-     // wait for b
-     check_cuda_error(cudaStreamSynchronize(stream));
    // read a from file
-     thrust::host_vector<T> h_a(size);
+     std::vector<T> h_a(size);
    {
        const auto filename = "tmp/" + key + ".cmp";
        std::ifstream ifs(filename, std::ios::binary);
@@ -86,17 +61,21 @@ void CmpRead(T* ptr, size_t size, std::string key, cudaStream_t stream)
        }
        ifs.read((char*)h_a.data(), sizeof(T) * h_a.size());
    }
-     // copy a to device
-     thrust::device_vector<T> a = h_a;
-     // create abs(a - b) iterator
-     thrust::device_ptr<T> dev_ptr(ptr);
-     auto zip_iter = thrust::make_zip_iterator(thrust::make_tuple(a.begin(), dev_ptr));
-     auto transform_iter = thrust::make_transform_iterator(zip_iter, abs_diff<T>{});
-     // sum(abs(a - b))
-     auto asum = thrust::reduce(thrust::device, transform_iter, transform_iter + size);
+     std::vector<T> h_b(size);
+     check_cuda_error(cudaMemcpyAsync(h_b.data(), ptr, sizeof(T) * size, cudaMemcpyDefault, stream));
+     check_cuda_error(cudaStreamSynchronize(stream));
+
+     using Tacc = std::conditional_t<std::is_integral_v<T>, int64_t, float>;
+
+     Tacc asum{};
+     for (size_t i = 0; i < size; ++i) {
+         asum += std::abs((Tacc)h_a[i] - (Tacc)h_b[i]);
+     }
+
    std::cerr << key << ": " << asum << " " << asum / size << "\n";

    check_cuda_error(cudaMemcpyAsync(ptr, h_a.data(), sizeof(T) * h_a.size(), cudaMemcpyDefault, stream));
+     check_cuda_error(cudaStreamSynchronize(stream));
}

template<typename T>
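For reference, the hunks above drop the Thrust device-side reduction and instead copy the buffer back to the host, accumulating |a - b| in a wider type (int64_t for integral T, float otherwise). A minimal standalone sketch of that pattern follows; the CmpHost helper and the main driver are made up for illustration, and plain CUDA runtime calls stand in for the repository's check_cuda_error wrapper.

// Standalone sketch (not the repository code): host-side comparison of a
// reference buffer "a" against a device buffer "b", as CmpRead now does.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <type_traits>
#include <vector>

#include <cuda_runtime.h>

template<typename T>
void CmpHost(const T* d_b, const std::vector<T>& h_a, cudaStream_t stream)
{
    // Copy the device buffer back to the host before comparing.
    std::vector<T> h_b(h_a.size());
    cudaMemcpyAsync(h_b.data(), d_b, sizeof(T) * h_b.size(), cudaMemcpyDefault, stream);
    cudaStreamSynchronize(stream);

    // Accumulate |a - b| in a wider type to limit overflow / precision loss.
    using Tacc = std::conditional_t<std::is_integral_v<T>, int64_t, float>;
    Tacc asum{};
    for (size_t i = 0; i < h_a.size(); ++i) {
        asum += std::abs((Tacc)h_a[i] - (Tacc)h_b[i]);
    }
    std::printf("sum |a-b| = %g, mean = %g\n", (double)asum, (double)asum / h_a.size());
}

int main()
{
    // Hypothetical data: "a" on the host and an identical copy on the device.
    std::vector<float> h_a(1024, 1.0f);
    float* d_b{};
    cudaMalloc(&d_b, sizeof(float) * h_a.size());
    cudaMemcpy(d_b, h_a.data(), sizeof(float) * h_a.size(), cudaMemcpyDefault);

    CmpHost(d_b, h_a, nullptr);  // identical buffers, so both values print as 0

    cudaFree(d_b);
    return 0;
}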