// THCTensorMath.cu -- forked from pytorch/pytorch (140 lines, 116 loc, 3.79 KB)
#include "THCTensorMath.h"
#include "THCGeneral.h"
#include "THCTensorCopy.h"
#include "THCApply.cuh"
#include "THCNumerics.cuh"
#include "THCTensorMath.cuh"
#include "THCThrustAllocator.cuh"
#include "THCTensor.hpp"
#include "THCStream.h"
#include <thrust/copy.h>
#include <thrust/count.h>
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/functional.h>
#include <thrust/sequence.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/transform.h>
#if CUDA_VERSION >= 7000
#include <thrust/system/cuda/execution_policy.h>
#endif
#include <cfloat>
// Functor that remembers a single value and stores it into whatever element
// it is invoked on.
template <typename T>
struct TensorFillOp {
  // The value written by every invocation; fixed at construction time.
  const T val;

  // Captures the value to fill with.
  TensorFillOp(T fillValue) : val(fillValue) {}

  // Overwrites the element that `dst` points to with the captured value.
  __device__ __forceinline__ void operator()(T* dst) {
    dst[0] = val;
  }
};
// Adapted from the Thrust example at
// https://github.com/thrust/thrust/blob/master/examples/strided_range.cu
//
// Exposes every `stride`-th element of the range [first, last), beginning
// with `first` itself, as an iterator range of its own.
template <typename Iterator>
class strided_range
{
 public:
  typedef typename thrust::iterator_difference<Iterator>::type difference_type;

  // Converts a position within the strided view into an offset within the
  // underlying range by multiplying it with the stride.
  struct stride_functor
      : public thrust::unary_function<difference_type, difference_type>
  {
    difference_type stride;

    stride_functor(difference_type step)
        : stride(step) {}

    __host__ __device__
    difference_type operator()(const difference_type& position) const
    {
      return stride * position;
    }
  };

  typedef typename thrust::counting_iterator<difference_type> CountingIterator;
  typedef typename thrust::transform_iterator<stride_functor, CountingIterator> TransformIterator;
  typedef typename thrust::permutation_iterator<Iterator, TransformIterator> PermutationIterator;

  // Iterator type handed out by begin() and end().
  typedef PermutationIterator iterator;

  // Builds a strided view over the range [rangeBegin, rangeEnd).
  strided_range(Iterator rangeBegin, Iterator rangeEnd, difference_type step)
      : first(rangeBegin), last(rangeEnd), stride(step) {}

  // Iterator to the first element of the view (offset 0 of the range).
  iterator begin() const
  {
    CountingIterator positions(0);
    TransformIterator offsets(positions, stride_functor(stride));
    return PermutationIterator(first, offsets);
  }

  // Iterator one past the last element of the view.
  iterator end() const
  {
    // Round the element count up so that a trailing, incomplete stride
    // still contributes its leading element.
    const difference_type span = last - first;
    const difference_type count = (span + (stride - 1)) / stride;
    return begin() + count;
  }

 protected:
  Iterator first;
  Iterator last;
  difference_type stride;
};
// Functor that maps an integer `val` to (val / div) % size + TH_INDEX_BASE.
//
// NOTE(review): presumably `val` is a flattened element offset, `div` the
// stride of one dimension and `size` that dimension's extent, so the result
// is the coordinate along that dimension -- confirm against the callers in
// generic/THCTensorMath.cu. Neither `div` nor `size` is checked for zero.
struct idx_functor
{
  int64_t div;
  int64_t size;

  __host__ __device__
  idx_functor(int64_t div, int64_t size) : div(div), size(size) {}

  // const-qualified: the call does not modify the functor, and this lets it
  // be invoked through const objects and const references as well.
  __host__ __device__
  int64_t operator()(int64_t val) const {
    return (val / div) % size + TH_INDEX_BASE;
  }
};
// Predicate that reports whether a value differs from zero, where "zero" is
// the float literal 0.0 converted to T and the comparison is
// THCNumerics<T>::ne.
template <typename T>
struct NonZeroOp
{
  NonZeroOp() {}

  // Returns true when `lhs` compares not-equal to zero, false otherwise.
  __host__ __device__ bool operator()(T lhs) const {
    const T zero = ScalarConvert<float, T>::to(0.0);
    return THCNumerics<T>::ne(lhs, zero);
  }
};
// Functor producing the `index`-th value of an arithmetic sequence:
// start + step * index. The arithmetic is carried out in `accT` through
// THCNumerics<accT>::mul / ::add and the result is converted to `T`.
template<typename T, typename accT = T>
struct LinspaceOp {
  // Stores the first value of the sequence and the spacing between values.
  __host__ __device__ LinspaceOp(accT start, accT step)
      : start_(start), step_(step) { }

  // Returns the sequence value at position `index`.
  __device__ __forceinline__ T operator()(ptrdiff_t index) {
    const accT position = ScalarConvert<ptrdiff_t, accT>::to(index);
    const accT offset = THCNumerics<accT>::mul(step_, position);
    return ScalarConvert<accT, T>::to(THCNumerics<accT>::add(start_, offset));
  }

  const accT start_, step_;
};
// Functor producing the `index`-th value of a sequence whose exponents are
// evenly spaced: it applies THCNumerics<accT>::exp10 to
// start + step * index. The arithmetic is carried out in `accT` and the
// result is converted to `T`.
template<typename T, typename accT = T>
struct LogspaceOp {
  // Stores the first exponent and the spacing between successive exponents.
  __host__ __device__ LogspaceOp(accT start, accT step)
      : start_(start), step_(step) { }

  // Returns the sequence value at position `index`.
  __device__ __forceinline__ T operator()(ptrdiff_t index) {
    const accT position = ScalarConvert<ptrdiff_t, accT>::to(index);
    const accT exponent =
        THCNumerics<accT>::add(start_, THCNumerics<accT>::mul(step_, position));
    return ScalarConvert<accT, T>::to(THCNumerics<accT>::exp10(exponent));
  }

  const accT start_, step_;
};
#include "generic/THCTensorMath.cu"
#include "THCGenerateAllTypes.h"