Skip to content

Commit

Permalink
TruncatedMod and Unique Op fixes (#2750)
Browse files Browse the repository at this point in the history
  • Loading branch information
kranipa authored Sep 26, 2024
1 parent 5961809 commit 7c79153
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 22 deletions.
1 change: 0 additions & 1 deletion itex/core/kernels/gpu/cwise_op_mod.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ limitations under the License.

namespace itex {

REGISTER2(BinaryOp, GPU, "TruncateMod", functor::safe_mod, int32, int64);
REGISTER3(BinaryOp, GPU, "TruncateMod", functor::fmod, float, Eigen::bfloat16,
Eigen::half);

Expand Down
31 changes: 16 additions & 15 deletions itex/core/kernels/gpu/unique_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -440,24 +440,25 @@ Status DispatchRadixSort(OpKernelContext* context, const int32_t size,
keys_out = mutable_keys_out;
}

if (size <= KEYS_PER_ITEM * GROUP_SIZE) {
using Rsortor = GroupRadixSortor<
KeyT, /*key_per_item==*/KEYS_PER_ITEM, /*group_size=*/GROUP_SIZE,
/*subgroup_size =*/SUBGROUP_SIZE, sycl::group<1>, ValueT>;
// Compute the required local memory size
size_t local_memory_size = Rsortor::LocalStorage::SIZE;
const int32_t num_wg = 1;
sycl::range<1> global_range(num_wg * GROUP_SIZE);
sycl::range<1> local_range(GROUP_SIZE);

return LaunchRadixSortKernel<KeyT, ValueT, KEYS_PER_ITEM, SUBGROUP_SIZE,
Rsortor>(
stream, size, keys_in, indices_in, keys_out, indices_out, global_range,
local_range, local_memory_size, num_bits);
} else {
if (size > KEYS_PER_ITEM * GROUP_SIZE &&
!std::is_floating_point_v<KeyT>) { // DeviceRadixSort will write OOM for
// float/double point types.
return DispatchDeviceRadixSort(context, keys_in, indices_in, keys_out,
indices_out, size);
}
using Rsortor = GroupRadixSortor<
KeyT, /*key_per_item==*/KEYS_PER_ITEM, /*group_size=*/GROUP_SIZE,
/*subgroup_size =*/SUBGROUP_SIZE, sycl::group<1>, ValueT>;
// Compute the required local memory size
size_t local_memory_size = Rsortor::LocalStorage::SIZE;
const int32_t num_wg = 1;
sycl::range<1> global_range(num_wg * GROUP_SIZE);
sycl::range<1> local_range(GROUP_SIZE);

return LaunchRadixSortKernel<KeyT, ValueT, KEYS_PER_ITEM, SUBGROUP_SIZE,
Rsortor>(
stream, size, keys_in, indices_in, keys_out, indices_out, global_range,
local_range, local_memory_size, num_bits);
}

template <typename InputIteratorT, typename OutputIteratorT, typename BinaryOp>
Expand Down
10 changes: 4 additions & 6 deletions test/benchmark/test_TruncateMod.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,17 @@

try:
from intel_extension_for_tensorflow.python.test_func import test
INT_COMPUTE_TYPE = [dtypes.int32, dtypes.int64]
INT_COMPUTE_TYPE = [dtypes.int32]
except ImportError:
from tensorflow.python.platform import test
INT_COMPUTE_TYPE = [dtypes.int32, dtypes.int64]
INT_COMPUTE_TYPE = [dtypes.int32]

ITERATION = 5

class TruncateModTest(test.TestCase):
def _test_impl(self, x_size, y_size, dtype):
x = np.random.normal(size=x_size)
x = constant_op.constant(x, dtype=dtype)
y = np.random.normal(size=x_size)
y = constant_op.constant(y, dtype=dtype)
x = tf.random.uniform(shape=x_size, minval=0, maxval=100, dtype=dtype)
y = tf.random.uniform(shape=y_size, minval=1, maxval=100, dtype=dtype)
flush_cache()
out_gpu = tf.raw_ops.TruncateMod(x=x, y=y)

Expand Down

0 comments on commit 7c79153

Please sign in to comment.