Skip to content

Commit 1370df7

Browse files
authored
[ROCm] Support for building DeepRec on ROCm2.10.0. (#302)
1 parent cf3a3b1 commit 1370df7

19 files changed

+99
-61
lines changed

tensorflow/compiler/tf2tensorrt/plugin/plugin_cast.cu.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,7 @@ using nvinfer1::PluginFormat;
3838

3939
template <typename SrcT, typename DstT>
4040
__global__ void Cast(const SrcT* input, int num_elements, DstT* output) {
41-
for (int i : CudaGridRangeX(num_elements)) {
41+
for (int i : GpuGridRangeX(num_elements)) {
4242
output[i] = static_cast<DstT>(input[i]);
4343
}
4444
}

tensorflow/core/framework/resource_mgr.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -178,7 +178,7 @@ Status ResourceMgr::DoCreate(const string& container, TypeIndex type,
178178
// key can contain a StringPiece that borrows from the string in the value.
179179
ResourceAndName resource_and_name(resource, name);
180180
StringPiece borrowed_name(*resource_and_name.name);
181-
Container::value_type key_and_value(Key(type.hash_code(), borrowed_name),
181+
Container::value_type key_and_value(Key(type.hash_code(), std::string(borrowed_name.data(), borrowed_name.size())),
182182
std::move(resource_and_name));
183183

184184
if ((*b)->insert(std::move(key_and_value)).second) {

tensorflow/core/kernels/bias_grad_ali_op_cpu.h

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -195,6 +195,7 @@ void TwoColumnWiseReduction(const float* A, int m, int n, int lda, float* Y,
195195

196196
void MultipleColumnWiseReduction(const float* A, int m, int n, int lda,
197197
float* Y, bool overwrite = true) {
198+
#if defined(__GNUC__) && (__GNUC__ >6)
198199
#ifdef __AVX512F__
199200
if (n >= block_size_avx512) {
200201
int block_num = n / block_size_avx512;
@@ -241,6 +242,7 @@ void MultipleColumnWiseReduction(const float* A, int m, int n, int lda,
241242
}
242243
return;
243244
}
245+
#endif
244246
#endif
245247
// Normal cases.
246248
if (overwrite) {
@@ -270,6 +272,7 @@ void SumIntoOneRow(const float* A, int m, int n, int lda, float* Y,
270272
}
271273
}
272274

275+
#if defined(__GNUC__) && (__GNUC__ >6)
273276
#ifdef __AVX512F__
274277
void ColumnParallel_512(const CPUDevice& d, float* input_data, float* output_data,
275278
int sum_size, int channel) {
@@ -355,6 +358,7 @@ void ColumnParallel_256(const CPUDevice& d, float* input_data, float* output_dat
355358
_mm256_mask_storeu_ps(output_data + offset_block_end, mask, sum);
356359
}
357360
#endif
361+
#endif
358362

359363
template <typename T>
360364
void BiasGrad2DInternal(const CPUDevice& d, typename TTypes<T>::ConstFlat input,
@@ -453,6 +457,7 @@ struct BiasGrad2D<CPUDevice, float> {
453457
Eigen::DSizes<int, 2>& two_dims,
454458
typename TTypes<float>::Flat output) {
455459
auto thread_num = d.numThreads();
460+
#if defined(__GNUC__) && (__GNUC__ >6)
456461
#ifdef __AVX512F__
457462
if (two_dims[1] >= block_size_avx512 * thread_num) {
458463
ColumnParallel_512(d, (float*)(input.data()), output.data(),
@@ -465,6 +470,7 @@ struct BiasGrad2D<CPUDevice, float> {
465470
two_dims[0], two_dims[1]);
466471
return;
467472
}
473+
#endif
468474
#endif
469475
BiasGrad2DInternal<float>(d, input, two_dims, output);
470476
}

tensorflow/core/kernels/cwise_op_gpu_select.cu.cc

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16-
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
16+
#if GOOGLE_CUDA //|| TENSORFLOW_USE_ROCM
1717

1818
#define EIGEN_USE_GPU
1919

@@ -111,7 +111,7 @@ struct BatchSelectFunctor<GPUDevice, T> {
111111
template <typename T>
112112
__global__ void Select4ElementThenScalarFunctorKernel(
113113
const bool *c, const T *t, const T *e, size_t num, T *o) {
114-
CUDA_1D_KERNEL_LOOP(i, num) {
114+
GPU_1D_KERNEL_LOOP(i, num) {
115115
if (c[i]) {
116116
o[i] = t[0];
117117
} else {
@@ -123,7 +123,7 @@ __global__ void Select4ElementThenScalarFunctorKernel(
123123
template <typename T>
124124
__global__ void Select4ElementElseScalarFunctorKernel(
125125
const bool *c, const T *t, const T *e, size_t num, T *o) {
126-
CUDA_1D_KERNEL_LOOP(i, num) {
126+
GPU_1D_KERNEL_LOOP(i, num) {
127127
if (c[i]) {
128128
o[i] = t[i];
129129
} else {
@@ -174,7 +174,7 @@ struct Select4ElementScalarFunctor<GPUDevice, T> {
174174
template <typename T>
175175
__global__ void BatchSelect4BroadcastingThenScalarFunctorKernel(
176176
const bool *c, const T *t, const T *e, size_t batch, size_t batch_size, T *o) {
177-
CUDA_1D_KERNEL_LOOP(i, batch * batch_size) {
177+
GPU_1D_KERNEL_LOOP(i, batch * batch_size) {
178178
size_t offset = i / batch_size;
179179
if (c[offset]) {
180180
o[i] = t[0];
@@ -187,7 +187,7 @@ __global__ void BatchSelect4BroadcastingThenScalarFunctorKernel(
187187
template <typename T>
188188
__global__ void BatchSelect4BroadcastingElseScalarFunctorKernel(
189189
const bool *c, const T *t, const T *e, size_t batch, size_t batch_size, T *o) {
190-
CUDA_1D_KERNEL_LOOP(i, batch * batch_size) {
190+
GPU_1D_KERNEL_LOOP(i, batch * batch_size) {
191191
size_t offset = i / batch_size;
192192
if (c[offset]) {
193193
o[i] = t[i];

tensorflow/core/kernels/cwise_op_select.cc

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@ limitations under the License.
1515

1616
#define EIGEN_USE_THREADS
1717

18-
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
18+
#if GOOGLE_CUDA //|| TENSORFLOW_USE_ROCM
1919
#define EIGEN_USE_GPU
2020
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
2121

@@ -375,7 +375,7 @@ class SelectV2Op : public OpKernel {
375375

376376
TF_CALL_ALL_TYPES(REGISTER_SELECT);
377377

378-
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
378+
#if GOOGLE_CUDA //|| TENSORFLOW_USE_ROCM
379379

380380
// Registration of the GPU implementations.
381381
#define REGISTER_SELECT_GPU(type) \
@@ -524,6 +524,7 @@ struct SelectFunctorBase<Device, float> {
524524
typename TTypes<bool>::ConstFlat cond_flat,
525525
typename TTypes<float>::ConstFlat then_flat,
526526
typename TTypes<float>::ConstFlat else_flat) {
527+
#if defined(__GNUC__) && (__GNUC__ >6)
527528
#ifdef __AVX512F__
528529
const size_t num = cond_flat.size();
529530
const bool* c = cond_flat.data();
@@ -557,6 +558,7 @@ struct SelectFunctorBase<Device, float> {
557558
Assign(d, out, out);
558559
#else
559560
Assign(d, out, cond_flat.select(then_flat, else_flat));
561+
#endif
560562
#endif
561563
}
562564
};
@@ -834,6 +836,7 @@ struct BatchSelectFunctor<CPUDevice, float> {
834836
const float* t = then_flat_outer_dims.data();
835837
const float* e = else_flat_outer_dims.data();
836838

839+
#if defined(__GNUC__) && (__GNUC__ >6)
837840
#ifdef __AVX512F__
838841
size_t quotient = batch_size / float_alignment;
839842
int remainder = batch_size - (quotient * float_alignment);
@@ -888,6 +891,7 @@ struct BatchSelectFunctor<CPUDevice, float> {
888891
}
889892
}
890893
};
894+
#endif
891895
#endif
892896
auto cost = Eigen::TensorOpCost(sizeof(float) * batch_size * 2, // ld bytes
893897
sizeof(float) * batch_size, // st bytes

tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -307,7 +307,11 @@ class DynamicPartitionOpGPU : public AsyncOpKernel {
307307
c, status,
308308
errors::Internal("Failed to launch copy from device to host."), done);
309309

310+
#if GOOGLE_CUDA
310311
cudaDeviceSynchronize();
312+
#elif TENSORFLOW_USE_ROCM
313+
hipDeviceSynchronize();
314+
#endif
311315

312316
OpOutputList outputs;
313317
this->AllocateOutputs(c, &data, &partitions, &cpu_tensor, &outputs, done);

tensorflow/core/kernels/dynamic_stitch_op.cc

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -189,7 +189,7 @@ class DynamicStitchOpImplBase : public OpKernel {
189189
}
190190
};
191191

192-
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
192+
#if GOOGLE_CUDA //|| TENSORFLOW_USE_ROCM
193193

194194
template <typename T>
195195
void DynamicStitchGPUImpl(const Eigen::GpuDevice& gpu_device,
@@ -570,7 +570,7 @@ TF_CALL_variant(REGISTER_DYNAMIC_STITCH);
570570
TF_CALL_QUANTIZED_TYPES(REGISTER_DYNAMIC_STITCH);
571571
#undef REGISTER_DYNAMIC_STITCH
572572

573-
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
573+
#if GOOGLE_CUDA //|| TENSORFLOW_USE_ROCM
574574
#define REGISTER_DYNAMIC_STITCH_GPU(type) \
575575
REGISTER_KERNEL_BUILDER(Name("DynamicStitch") \
576576
.Device(DEVICE_GPU) \

tensorflow/core/kernels/dynamic_stitch_op_gpu.cu.cc

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16-
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
16+
#if GOOGLE_CUDA //|| TENSORFLOW_USE_ROCM
1717

1818
#define EIGEN_USE_GPU
1919

@@ -52,7 +52,7 @@ __global__ void DynamicStitchKernelV2(const int32 slice_size,
5252
const int32* input_indices,
5353
T** input_ptrs,
5454
T* output) {
55-
CUDA_1D_KERNEL_LOOP(output_index, output_size) {
55+
GPU_1D_KERNEL_LOOP(output_index, output_size) {
5656
const int32 slice_id = output_index / slice_size;
5757
const int32 slice_offset = output_index % slice_size;
5858
const int32 input_index = input_indices[slice_id];
@@ -65,7 +65,7 @@ __global__ void DynamicStitchKernelV2(const int32 slice_size,
6565
__global__ void InitializeIndicesFlatWork(int32* indices_flat_work,
6666
const int32 flat_work_size,
6767
const int32 val) {
68-
CUDA_1D_KERNEL_LOOP(output_index, flat_work_size) {
68+
GPU_1D_KERNEL_LOOP(output_index, flat_work_size) {
6969
indices_flat_work[output_index] = val;
7070
}
7171
}
@@ -80,7 +80,7 @@ __global__ void DynamicStitchPrepKernel(const int32* indices_flat,
8080
const int32 slice_size,
8181
const int32 output_size) {
8282

83-
CUDA_1D_KERNEL_LOOP(output_index, output_size) {
83+
GPU_1D_KERNEL_LOOP(output_index, output_size) {
8484
// for indices
8585
indices_flat_work[indices_flat[output_index]] = output_index;
8686
// find the partition id
@@ -123,7 +123,7 @@ void DynamicStitchGPUImplV2(const Eigen::GpuDevice& gpu_device,
123123
Tensor* input_ptrs,
124124
T* output) {
125125
const int32 output_size = first_dim_size * slice_size;
126-
auto config = GetCudaLaunchConfig(output_size, gpu_device);
126+
auto config = GetGpuLaunchConfig(output_size, gpu_device);
127127

128128
DynamicStitchKernelV2<T>
129129
<<<config.block_count, config.thread_per_block, 0, gpu_device.stream()>>>(
@@ -146,13 +146,13 @@ void DynamicStitchGPUPrep(const Eigen::GpuDevice& gpu_device,
146146
const int32 first_dim_size) {
147147

148148
// initialize indices_flat_work by -1
149-
auto config = GetCudaLaunchConfig(first_dim_size, gpu_device);
149+
auto config = GetGpuLaunchConfig(first_dim_size, gpu_device);
150150
InitializeIndicesFlatWork
151151
<<<config.block_count, config.thread_per_block, 0, gpu_device.stream()>>>(
152152
indices_flat_work->flat<int32>().data(),
153153
first_dim_size, -1);
154154

155-
config = GetCudaLaunchConfig(data_elements_size, gpu_device);
155+
config = GetGpuLaunchConfig(data_elements_size, gpu_device);
156156
DynamicStitchPrepKernel<T>
157157
<<<config.block_count, config.thread_per_block, 0, gpu_device.stream()>>>(
158158
indices_flat->flat<int32>().data(),

tensorflow/core/kernels/reshape_op.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -86,7 +86,7 @@ REGISTER_KERNEL_BUILDER(Name("Reshape")
8686
#undef REGISTER_SYCL_KERNEL
8787
#endif // TENSORFLOW_USE_SYCL
8888

89-
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
89+
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) //|| \
9090
(defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
9191
// A special GPU kernel for int32.
9292
// TODO(b/25387198): Also enable int32 in device memory. This kernel

tensorflow/core/kernels/segment_reduction_ops.cc

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -203,7 +203,7 @@ class SegmentReductionOp : public OpKernel {
203203
}
204204
};
205205

206-
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
206+
#if GOOGLE_CUDA //|| TENSORFLOW_USE_ROCM
207207
// SegmentSumGPUOp is a segment sum operator implemented for GPU only.
208208
// TODO: This implementation of SegmentSumGPUOp is sometimes slower than
209209
// its unsorted counterpart (mostly when problem size is small).
@@ -352,7 +352,7 @@ REGISTER_COMPLEX_CPU_KERNELS_ALL(complex128);
352352
#undef REGISTER_REAL_CPU_KERNELS_ALL
353353
#undef REGISTER_COMPLEX_CPU_KERNELS_ALL
354354

355-
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
355+
#if GOOGLE_CUDA //|| TENSORFLOW_USE_ROCM
356356
#define REGISTER_GPU_SORTED_KERNELS(type, index_type) \
357357
REGISTER_KERNEL_BUILDER(Name("SegmentSum") \
358358
.Device(DEVICE_GPU) \
@@ -637,7 +637,7 @@ REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex128);
637637
#undef REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL
638638
#undef REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL
639639

640-
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
640+
#if GOOGLE_CUDA //|| TENSORFLOW_USE_ROCM
641641
#define REGISTER_GPU_KERNEL_UNSORTEDSEGMENT( \
642642
name, type, index_type, initial_value_functor, reduction_kernel_functor) \
643643
REGISTER_KERNEL_BUILDER( \

tensorflow/core/kernels/segment_reduction_ops.h

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@ class OpKernelContext;
3434

3535
namespace functor {
3636

37-
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
37+
#if GOOGLE_CUDA //|| TENSORFLOW_USE_ROCM
3838
typedef Eigen::GpuDevice GPUDevice;
3939
// Functor for SegmentSumGPUOp.
4040
// output_rows: the number of output segments (unique segment ids in
@@ -88,7 +88,7 @@ struct SetValueDefault {
8888
Tensor* target,
8989
T default_value);
9090
};
91-
#endif
91+
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
9292

9393
template <typename Device, typename T, typename Index, typename InitialValueF,
9494
typename ReductionF>
@@ -99,7 +99,7 @@ struct UnsortedSegmentFunctor {
9999
typename TTypes<T, 2>::Tensor output);
100100
};
101101

102-
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
102+
#if GOOGLE_CUDA //|| TENSORFLOW_USE_ROCM
103103
// reduction functors for the gpu
104104
template <typename T>
105105
struct SumOpGpu {

0 commit comments

Comments
 (0)