Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,53 +13,58 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "random_op_impl.cuh"

// Initializes one cuRAND PRNG state per launched thread.
// Expects a 1-D launch; thread t owns globalState[t], so the state buffer
// must hold at least gridDim.x * blockDim.x entries.
__global__ void SetupKernel(int seed, curandState *globalState) {
  const unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x;
  // One shared seed; the thread index is the subsequence, so each thread
  // gets a statistically independent stream.
  curand_init(seed, tid, 0, &globalState[tid]);
}

// Fills output[0..count) with standard-normal samples drawn from the
// per-thread cuRAND states prepared by SetupKernel.
// Grid-stride loop: correct for any 1-D launch configuration.
// NOTE(review): assumes globalState holds at least
// max(gridDim.x * blockDim.x, count) entries -- confirm against the caller.
template <typename T>
__global__ void NormalKernel(curandState *globalState, T *output, size_t count) {
  // size_t index: the 32-bit product blockIdx.x * blockDim.x overflows for
  // very large launches.
  size_t id = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  // Work on a register-resident copy; a global-memory round trip per sample
  // would be needlessly slow.
  curandState localState = globalState[id];

  while (id < count) {
    output[id] = (T)curand_normal(&localState);
    // Persist the *advanced* state. (The previous code stored it before
    // drawing, so a subsequent launch reusing globalState would replay the
    // same value.)
    globalState[id] = localState;
    id += blockDim.x * gridDim.x;
  }
}

__device__ bool dev_error_res = false;

// Fills output[0..count) with values uniformly scaled into
// [input1[0], input2[0]) and reports validity through dev_error_res:
// false when the range is empty/inverted (kernel aborts), true otherwise.
// NOTE(review): curand_uniform returns values in (0, 1], so u == 1.0 can
// yield exactly input2[0] after truncation -- confirm the intended interval.
// NOTE(review): assumes globalState holds at least
// max(gridDim.x * blockDim.x, count) entries -- confirm against the caller.
template <typename T>
__global__ void UniformIntKernel(curandState *globalState, T *input1, size_t input_size_1,
                                 T *input2, size_t input_size_2, T *output, size_t count) {
  // Every thread evaluates the same scalar condition, so all of them either
  // abort together or proceed together.
  if (!(input1[0] < input2[0])) {
    dev_error_res = false;
    return;
  }

  // size_t index: the 32-bit product blockIdx.x * blockDim.x overflows for
  // very large launches.
  size_t id = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  // Register-resident copy of this thread's PRNG state.
  curandState localState = globalState[id];

  while (id < count) {
    output[id] = (T)(curand_uniform(&localState) * (input2[0] - input1[0])) + input1[0];
    // Persist the *advanced* state (the previous code stored it before
    // drawing, which would replay values on a subsequent launch).
    globalState[id] = localState;
    id += blockDim.x * gridDim.x;
  }

  dev_error_res = true;
}

// Fills output[0..count) with uniform samples from the per-thread cuRAND
// states prepared by SetupKernel. Grid-stride loop: correct for any 1-D
// launch configuration.
// NOTE(review): curand_uniform returns values in (0, 1] -- confirm callers
// expect that interval rather than [0, 1).
// NOTE(review): assumes globalState holds at least
// max(gridDim.x * blockDim.x, count) entries -- confirm against the caller.
template <typename T>
__global__ void UniformRealKernel(curandState *globalState, T *output, size_t count) {
  // size_t index: the 32-bit product blockIdx.x * blockDim.x overflows for
  // very large launches.
  size_t id = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  // Register-resident copy of this thread's PRNG state.
  curandState localState = globalState[id];

  while (id < count) {
    output[id] = (T)curand_uniform(&localState);
    // Persist the *advanced* state (the previous code stored it before
    // drawing, which would replay values on a subsequent launch).
    globalState[id] = localState;
    id += blockDim.x * gridDim.x;
  }
}
Expand All @@ -75,7 +80,10 @@ void StandardNormal(int seed, int seed2, curandState *globalState, T *output, si
} else {
RNG_seed = static_cast<int>(rd());
}
NormalKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(RNG_seed, globalState, output, count);

SetupKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(RNG_seed, globalState);
NormalKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(globalState, output, count);

return;
}

Expand All @@ -92,8 +100,9 @@ bool UniformInt(int seed, int seed2, curandState *globalState, T *input1, size_t
RNG_seed = static_cast<int>(rd());
}
bool host_error_res = false;
SetupKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(RNG_seed, globalState);
UniformIntKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>
(RNG_seed, globalState, input1, input_size_1, input2, input_size_2, output, count);
(globalState, input1, input_size_1, input2, input_size_2, output, count);
cudaDeviceSynchronize();
cudaMemcpyFromSymbol(&host_error_res, dev_error_res, sizeof(bool));
return host_error_res;
Expand All @@ -110,22 +119,8 @@ void UniformReal(int seed, int seed2, curandState *globalState, T *output, size_
} else {
RNG_seed = static_cast<int>(rd());
}
UniformRealKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(RNG_seed, globalState, output, count);
return;
}

template<typename S>
void TruncatedNormal(int seed, int seed2, curandState *globalState, S *output, size_t count, cudaStream_t cuda_stream) {
int RNG_seed = 0;
std::random_device rd;
if (seed2 != 0) {
RNG_seed = seed2;
} else if (seed != 0) {
RNG_seed = seed;
} else {
RNG_seed = static_cast<int>(rd());
}
TruncatedNormalKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(RNG_seed, globalState, output, count);
SetupKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(RNG_seed, globalState);
UniformRealKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(globalState, output, count);
return;
}

Expand All @@ -143,9 +138,3 @@ template CUDA_LIB_EXPORT void UniformReal<float>(int seed, int seed2, curandStat
float *output, size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void UniformReal<int>(int seed, int seed2, curandState *globalState,
int *output, size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void TruncatedNormal<half>(int seed, int seed2, curandState *globalState,
half *output, size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void TruncatedNormal<float>(int seed, int seed2, curandState *globalState,
float *output, size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void TruncatedNormal<double>(int seed, int seed2, curandState *globalState,
double *output, size_t count, cudaStream_t cuda_stream);