Skip to content

Commit

Permalink
[shortfin] Add conversion host ops. (#482)
Browse files Browse the repository at this point in the history
Ops added: `convert`, `round`, `ceil`, `floor`, `trunc`

All ops were implemented to the same pattern, supporting fused
conversion and output array.

Fixes #315
  • Loading branch information
stellaraccident authored Nov 12, 2024
1 parent 35bc60d commit 4759dbc
Show file tree
Hide file tree
Showing 3 changed files with 417 additions and 8 deletions.
314 changes: 306 additions & 8 deletions shortfin/python/array_host_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,34 @@ Implemented for dtypes: float16, float32.
A device_array of dtype=int64, allocated on the host and not visible to the device.
)";

// Shared Python docstring for the `convert`, `round`, `ceil`, `floor` and
// `trunc` host ops.
static const char DOCSTRING_CONVERT[] =
R"(Does an elementwise conversion from one dtype to another.
The same behavior exists for several conversion ops:
* `convert` : element-wise conversion like a static cast.
* `round` : element-wise nearest integer to the input, rounding halfway cases
away from zero.
* `ceil` : element-wise smallest integer value not less than the input.
* `floor` : element-wise largest integer value not greater than the input.
* `trunc` : element-wise nearest integer not greater in magnitude than the input.
For nearest-integer conversions (round, ceil, floor, trunc), the input dtype
must be a floating point array, and the output must be a byte-aligned integer
type between 8 and 32 bits.
Args:
input: An input array. For `convert` this may be any supported dtype; for the
nearest-integer ops it must be of a floating point dtype.
dtype: If given, then this is the explicit output dtype.
out: If given, then the results are written to this array. This implies the
output dtype.
device_visible: Whether to make the result array visible to devices. Defaults to
False.
Returns:
A device_array of the requested dtype, or the input dtype if not specified.
)";

static const char DOCSTRING_FILL_RANDN[] =
R"(Fills an array with numbers sampled from the standard normal distribution.
Expand All @@ -63,7 +91,14 @@ static const char DOCSTRING_RANDOM_GENERATOR[] =
fixed number.
)";

} // namespace
// Switch-case helper for dtype dispatch. Expands to
// `case DType::dtype_name():` followed by a `return` of
// `compute.template operator()<cpp_type>()`, so the enclosing function
// returns whatever the selected instantiation returns.
// A callable named `compute` with a templated call operator must be in scope
// at the expansion site; the trailing `;` is supplied by the caller.
#define SF_UNARY_FUNCTION_CASE(dtype_name, cpp_type) \
case DType::dtype_name(): \
return compute.template operator()<cpp_type>()

// Switch-case helper for dtype dispatch. Expands to
// `case DType::dtype_name():`, a call to
// `compute.template operator()<cpp_type>()` whose result is discarded, and a
// `break` out of the enclosing switch.
// A callable named `compute` with a templated call operator must be in scope
// at the expansion site; the trailing `;` is supplied by the caller.
#define SF_UNARY_THUNK_CASE(dtype_name, cpp_type) \
case DType::dtype_name(): \
compute.template operator()<cpp_type>(); \
break

struct PyRandomGenerator {
public:
Expand All @@ -85,9 +120,261 @@ struct PyRandomGenerator {
xt::random::default_engine_type engine_;
};

// Switch-case helper for dtype dispatch. Expands to
// `case DType::dtype_name():` followed by a `return` of
// `compute.template operator()<cpp_type>()`.
// A callable named `compute` with a templated call operator must be in scope
// at the expansion site; the trailing `;` is supplied by the caller.
#define SF_UNARY_COMPUTE_CASE(dtype_name, cpp_type) \
case DType::dtype_name(): \
return compute.template operator()<cpp_type>()
// Bindable entry point shared by all conversion ops. The element-wise work is
// delegated to `ConvertFunc::Invoke(input, dtype, out)`, which operates on an
// already-allocated output array.
//
// Args:
//   input: The array to convert.
//   dtype: Explicit output dtype. When absent, it is taken from `out` if that
//     is given, otherwise from `input`.
//   out: Destination array. When absent, a host array with the shape of
//     `input` and the resolved dtype is allocated on `input`'s device.
//   device_visible: Forwarded to `device_array::for_host` when the output is
//     allocated here; unused when `out` is supplied.
//
// Returns the output array.
// Throws std::invalid_argument if both `dtype` and `out` are given and their
// dtypes differ.
template <typename ConvertFunc>
device_array GenericElementwiseConvert(device_array &input,
                                       std::optional<DType> dtype,
                                       std::optional<device_array> out,
                                       bool device_visible) {
  // Resolve the output dtype: explicit dtype > dtype of `out` > input dtype.
  if (dtype) {
    if (out && out->dtype() != dtype) {
      throw std::invalid_argument(
          "if both dtype and out are specified, they must match");
    }
  } else if (out) {
    dtype = out->dtype();
  } else {
    dtype = input.dtype();
  }

  // Allocate the destination only when the caller did not provide one.
  if (!out) {
    out.emplace(device_array::for_host(input.device(), input.shape(), *dtype,
                                       device_visible));
  }

  device_array &result = *out;
  ConvertFunc::Invoke(input, *dtype, result);
  return result;
}

// Generic elementwise conversion functor.
//
// Maps `input` with its own element type and stores every element, cast to
// the element type selected by `dtype`, into `out`.
//
// Supported input and output dtypes: float16, float32, float64 and the signed
// and unsigned 8/16/32/64 bit integers.
// Throws std::invalid_argument if the input dtype or the requested output
// dtype is not in that set.
struct ConvertFunctor {
  static void Invoke(device_array &input, DType dtype, device_array &out) {
    SHORTFIN_TRACE_SCOPE_NAMED("PyHostOp::convert");
    // Instantiated once per supported input element type (outer switch below);
    // the inner switch selects the element type used for the store.
    auto compute = [&]<typename EltTy>() -> void {
      auto input_t = input.map_xtensor<EltTy>();
      // Casted output.
#define SF_STORE_CASE(dtype_name, cpp_type)     \
  case DType::dtype_name(): {                   \
    auto out_t = out.map_xtensor_w<cpp_type>(); \
    *out_t = xt::cast<cpp_type>(*input_t);      \
    break;                                      \
  }
      switch (dtype) {
        SF_STORE_CASE(float16, half_float::half);
        SF_STORE_CASE(float32, float);
        SF_STORE_CASE(float64, double);
        SF_STORE_CASE(uint8, uint8_t);
        SF_STORE_CASE(int8, int8_t);
        SF_STORE_CASE(uint16, uint16_t);
        SF_STORE_CASE(int16, int16_t);
        SF_STORE_CASE(uint32, uint32_t);
        SF_STORE_CASE(int32, int32_t);
        SF_STORE_CASE(uint64, uint64_t);
        SF_STORE_CASE(int64, int64_t);
        default:
          throw std::invalid_argument("Invalid output dtype for convert op");
      }

#undef SF_STORE_CASE
    };

    switch (input.dtype()) {
      SF_UNARY_THUNK_CASE(float16, half_float::half);
      SF_UNARY_THUNK_CASE(float32, float);
      SF_UNARY_THUNK_CASE(float64, double);
      SF_UNARY_THUNK_CASE(uint8, uint8_t);
      SF_UNARY_THUNK_CASE(int8, int8_t);
      SF_UNARY_THUNK_CASE(uint16, uint16_t);
      SF_UNARY_THUNK_CASE(int16, int16_t);
      SF_UNARY_THUNK_CASE(uint32, uint32_t);
      // Signed 32-bit input must be read as int32_t; reading it as uint32_t
      // would turn negative values into large positive ones before the cast.
      SF_UNARY_THUNK_CASE(int32, int32_t);
      SF_UNARY_THUNK_CASE(uint64, uint64_t);
      SF_UNARY_THUNK_CASE(int64, int64_t);
      default:
        // This switch is over the input dtype, so that is the one to report.
        throw std::invalid_argument(
            fmt::format("Unsupported dtype({}) for convert op",
                        input.dtype().name()));
    }
  }
};

// Converting round functor.
//
// Applies `xt::round` to every element of `input` and stores the result in
// `out`. When `dtype` equals the input dtype the rounded values are stored
// as-is; otherwise they are cast to the integer element type selected by
// `dtype`.
//
// Supported input dtypes: float16, float32.
// Supported casted output dtypes: uint8, int8, uint16, int16, uint32, int32.
// Throws std::invalid_argument for any other input or output dtype.
struct ConvertRoundFunctor {
  static void Invoke(device_array &input, DType dtype, device_array &out) {
    SHORTFIN_TRACE_SCOPE_NAMED("PyHostOp::round");
    auto compute = [&]<typename EltTy>() -> void {
      auto input_t = input.map_xtensor<EltTy>();
      auto rounded = xt::round(*input_t);
      if (input.dtype() == dtype) {
        // Same type output.
        auto out_t = out.map_xtensor_w<EltTy>();
        *out_t = rounded;
      } else {
        // Casted output.
#define SF_STORE_CASE(dtype_name, cpp_type)     \
  case DType::dtype_name(): {                   \
    auto out_t = out.map_xtensor_w<cpp_type>(); \
    *out_t = xt::cast<cpp_type>(rounded);       \
    break;                                      \
  }
        switch (dtype) {
          SF_STORE_CASE(uint8, uint8_t);
          SF_STORE_CASE(int8, int8_t);
          SF_STORE_CASE(uint16, uint16_t);
          SF_STORE_CASE(int16, int16_t);
          SF_STORE_CASE(uint32, uint32_t);
          SF_STORE_CASE(int32, int32_t);
          default:
            throw std::invalid_argument(
                "Invalid output dtype for converting nearest integer op");
        }
      }
#undef SF_STORE_CASE
    };

    switch (input.dtype()) {
      SF_UNARY_THUNK_CASE(float16, half_float::half);
      SF_UNARY_THUNK_CASE(float32, float);
      default:
        // This switch is over the input dtype, so report that one rather than
        // the requested output dtype.
        throw std::invalid_argument(fmt::format(
            "Unsupported dtype({}) for converting nearest integer op",
            input.dtype().name()));
    }
  }
};

// Converting ceil functor.
//
// Applies `xt::ceil` to every element of `input` and stores the result in
// `out`. When `dtype` equals the input dtype the values are stored as-is;
// otherwise they are cast to the integer element type selected by `dtype`.
//
// Supported input dtypes: float16, float32.
// Supported casted output dtypes: uint8, int8, uint16, int16, uint32, int32.
// Throws std::invalid_argument for any other input or output dtype.
struct ConvertCeilFunctor {
  static void Invoke(device_array &input, DType dtype, device_array &out) {
    SHORTFIN_TRACE_SCOPE_NAMED("PyHostOp::ceil");
    auto compute = [&]<typename EltTy>() -> void {
      auto input_t = input.map_xtensor<EltTy>();
      auto rounded = xt::ceil(*input_t);
      if (input.dtype() == dtype) {
        // Same type output.
        auto out_t = out.map_xtensor_w<EltTy>();
        *out_t = rounded;
      } else {
        // Casted output.
#define SF_STORE_CASE(dtype_name, cpp_type)     \
  case DType::dtype_name(): {                   \
    auto out_t = out.map_xtensor_w<cpp_type>(); \
    *out_t = xt::cast<cpp_type>(rounded);       \
    break;                                      \
  }
        switch (dtype) {
          SF_STORE_CASE(uint8, uint8_t);
          SF_STORE_CASE(int8, int8_t);
          SF_STORE_CASE(uint16, uint16_t);
          SF_STORE_CASE(int16, int16_t);
          SF_STORE_CASE(uint32, uint32_t);
          SF_STORE_CASE(int32, int32_t);
          default:
            throw std::invalid_argument(
                "Invalid output dtype for converting nearest integer op");
        }
      }
#undef SF_STORE_CASE
    };

    switch (input.dtype()) {
      SF_UNARY_THUNK_CASE(float16, half_float::half);
      SF_UNARY_THUNK_CASE(float32, float);
      default:
        // This switch is over the input dtype, so report that one rather than
        // the requested output dtype.
        throw std::invalid_argument(fmt::format(
            "Unsupported dtype({}) for converting nearest integer op",
            input.dtype().name()));
    }
  }
};

// Converting floor functor.
//
// Applies `xt::floor` to every element of `input` and stores the result in
// `out`. When `dtype` equals the input dtype the values are stored as-is;
// otherwise they are cast to the integer element type selected by `dtype`.
//
// Supported input dtypes: float16, float32.
// Supported casted output dtypes: uint8, int8, uint16, int16, uint32, int32.
// Throws std::invalid_argument for any other input or output dtype.
struct ConvertFloorFunctor {
  static void Invoke(device_array &input, DType dtype, device_array &out) {
    SHORTFIN_TRACE_SCOPE_NAMED("PyHostOp::floor");
    auto compute = [&]<typename EltTy>() -> void {
      auto input_t = input.map_xtensor<EltTy>();
      auto rounded = xt::floor(*input_t);
      if (input.dtype() == dtype) {
        // Same type output.
        auto out_t = out.map_xtensor_w<EltTy>();
        *out_t = rounded;
      } else {
        // Casted output.
#define SF_STORE_CASE(dtype_name, cpp_type)     \
  case DType::dtype_name(): {                   \
    auto out_t = out.map_xtensor_w<cpp_type>(); \
    *out_t = xt::cast<cpp_type>(rounded);       \
    break;                                      \
  }
        switch (dtype) {
          SF_STORE_CASE(uint8, uint8_t);
          SF_STORE_CASE(int8, int8_t);
          SF_STORE_CASE(uint16, uint16_t);
          SF_STORE_CASE(int16, int16_t);
          SF_STORE_CASE(uint32, uint32_t);
          SF_STORE_CASE(int32, int32_t);
          default:
            throw std::invalid_argument(
                "Invalid output dtype for converting nearest integer op");
        }
      }
#undef SF_STORE_CASE
    };

    switch (input.dtype()) {
      SF_UNARY_THUNK_CASE(float16, half_float::half);
      SF_UNARY_THUNK_CASE(float32, float);
      default:
        // This switch is over the input dtype, so report that one rather than
        // the requested output dtype.
        throw std::invalid_argument(fmt::format(
            "Unsupported dtype({}) for converting nearest integer op",
            input.dtype().name()));
    }
  }
};

// Converting trunc functor.
//
// Applies `xt::trunc` to every element of `input` and stores the result in
// `out`. When `dtype` equals the input dtype the values are stored as-is;
// otherwise they are cast to the integer element type selected by `dtype`.
//
// Supported input dtypes: float16, float32.
// Supported casted output dtypes: uint8, int8, uint16, int16, uint32, int32.
// Throws std::invalid_argument for any other input or output dtype.
struct ConvertTruncFunctor {
  static void Invoke(device_array &input, DType dtype, device_array &out) {
    SHORTFIN_TRACE_SCOPE_NAMED("PyHostOp::trunc");
    auto compute = [&]<typename EltTy>() -> void {
      auto input_t = input.map_xtensor<EltTy>();
      auto rounded = xt::trunc(*input_t);
      if (input.dtype() == dtype) {
        // Same type output.
        auto out_t = out.map_xtensor_w<EltTy>();
        *out_t = rounded;
      } else {
        // Casted output.
#define SF_STORE_CASE(dtype_name, cpp_type)     \
  case DType::dtype_name(): {                   \
    auto out_t = out.map_xtensor_w<cpp_type>(); \
    *out_t = xt::cast<cpp_type>(rounded);       \
    break;                                      \
  }
        switch (dtype) {
          SF_STORE_CASE(uint8, uint8_t);
          SF_STORE_CASE(int8, int8_t);
          SF_STORE_CASE(uint16, uint16_t);
          SF_STORE_CASE(int16, int16_t);
          SF_STORE_CASE(uint32, uint32_t);
          SF_STORE_CASE(int32, int32_t);
          default:
            throw std::invalid_argument(
                "Invalid output dtype for converting nearest integer op");
        }
      }
#undef SF_STORE_CASE
    };

    switch (input.dtype()) {
      SF_UNARY_THUNK_CASE(float16, half_float::half);
      SF_UNARY_THUNK_CASE(float32, float);
      default:
        // This switch is over the input dtype, so report that one rather than
        // the requested output dtype.
        throw std::invalid_argument(fmt::format(
            "Unsupported dtype({}) for converting nearest integer op",
            input.dtype().name()));
    }
  }
};

} // namespace

void BindArrayHostOps(py::module_ &m) {
// Simple op definitions.
Expand Down Expand Up @@ -121,8 +408,8 @@ void BindArrayHostOps(py::module_ &m) {
};

switch (input.dtype()) {
SF_UNARY_COMPUTE_CASE(float16, half_float::half);
SF_UNARY_COMPUTE_CASE(float32, float);
SF_UNARY_FUNCTION_CASE(float16, half_float::half);
SF_UNARY_FUNCTION_CASE(float32, float);
default:
throw std::invalid_argument(
fmt::format("Unsupported dtype({}) for operator argmax",
Expand Down Expand Up @@ -150,15 +437,26 @@ void BindArrayHostOps(py::module_ &m) {
};

switch (out.dtype()) {
SF_UNARY_COMPUTE_CASE(float16, half_float::half);
SF_UNARY_COMPUTE_CASE(float32, float);
SF_UNARY_FUNCTION_CASE(float16, half_float::half);
SF_UNARY_FUNCTION_CASE(float32, float);
default:
throw std::invalid_argument(
fmt::format("Unsupported dtype({}) for operator randn",
out.dtype().name()));
}
},
py::arg("out"), py::arg("generator") = py::none(), DOCSTRING_FILL_RANDN);

// Data-type conversion and rounding.
#define SF_DEF_CONVERT(py_name, target) \
m.def(py_name, target, py::arg("input"), py::kw_only(), \
py::arg("dtype") = py::none(), py::arg("out") = py::none(), \
py::arg("device_visible") = false, DOCSTRING_CONVERT)
SF_DEF_CONVERT("convert", GenericElementwiseConvert<ConvertFunctor>);
SF_DEF_CONVERT("ceil", GenericElementwiseConvert<ConvertCeilFunctor>);
SF_DEF_CONVERT("floor", GenericElementwiseConvert<ConvertFloorFunctor>);
SF_DEF_CONVERT("round", GenericElementwiseConvert<ConvertRoundFunctor>);
SF_DEF_CONVERT("trunc", GenericElementwiseConvert<ConvertTruncFunctor>);
}

} // namespace shortfin::python
10 changes: 10 additions & 0 deletions shortfin/python/shortfin/array/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,12 @@

# Ops.
# Aliases for attributes of `_sfl.array`, bound under the same names here.
argmax = _sfl.array.argmax
ceil = _sfl.array.ceil
convert = _sfl.array.convert
fill_randn = _sfl.array.fill_randn
floor = _sfl.array.floor
# NOTE(review): this rebinds `round` in this namespace, so code below that
# calls `round(...)` gets the array op rather than the builtin.
round = _sfl.array.round
trunc = _sfl.array.trunc
RandomGenerator = _sfl.array.RandomGenerator

__all__ = [
Expand Down Expand Up @@ -82,7 +87,12 @@
"DType",
# Ops.
"argmax",
"ceil",
"convert",
"fill_randn",
"floor",
"round",
"trunc",
"RandomGenerator",
]

Expand Down
Loading

0 comments on commit 4759dbc

Please sign in to comment.