diff --git a/docs/website/pages/docs/loading.mdx b/docs/website/pages/docs/loading.mdx index fb69882..850db4c 100644 --- a/docs/website/pages/docs/loading.mdx +++ b/docs/website/pages/docs/loading.mdx @@ -307,7 +307,7 @@ int main() // Create an input tensor uint64_t shape[]{1}; auto tensor = carton::Tensor(carton::DataType::kString, shape); - tensor.set_string(0, "Today is a good [MASK]."); + tensor.at(0) = "Today is a good [MASK]."; // Create a map of inputs std::unordered_map inputs; @@ -322,7 +322,7 @@ int main() const auto scores_data = static_cast(scores.data()); - std::cout << "Got output token: " << tokens.get_string(0) << std::endl; + std::cout << "Got output token: " << tokens.at(0) << std::endl; std::cout << "Got output scores: " << scores_data[0] << std::endl; } ``` diff --git a/docs/website/pages/quickstart.mdx b/docs/website/pages/quickstart.mdx index b5a7945..f13a0b2 100644 --- a/docs/website/pages/quickstart.mdx +++ b/docs/website/pages/quickstart.mdx @@ -63,7 +63,7 @@ int main() // Create an input tensor uint64_t shape[]{1}; auto tensor = carton::Tensor(carton::DataType::kString, shape); - tensor.set_string(0, "Today is a good [MASK]."); + tensor.at(0) = "Today is a good [MASK]."; // Create a map of inputs std::unordered_map inputs; @@ -78,7 +78,7 @@ int main() const auto scores_data = static_cast(scores.data()); - std::cout << "Got output token: " << tokens.get_string(0) << std::endl; + std::cout << "Got output token: " << tokens.at(0) << std::endl; std::cout << "Got output scores: " << scores_data[0] << std::endl; } ``` diff --git a/source/carton-bindings-c/src/tensor.rs b/source/carton-bindings-c/src/tensor.rs index adbfed4..59e5b58 100644 --- a/source/carton-bindings-c/src/tensor.rs +++ b/source/carton-bindings-c/src/tensor.rs @@ -185,8 +185,9 @@ impl CartonTensor { } } - /// For a string tensor, get a string at a particular (flattened) index into the tensor. + /// For a string tensor, get a string at a particular flattened index into the tensor. /// Note: any returned pointers are only valid until the tensor is modified. + /// Note: `index` should take strides into account. #[no_mangle] pub extern "C" fn carton_tensor_get_string( &self, @@ -196,7 +197,10 @@ impl CartonTensor { ) { if let carton_core::types::Tensor::String(v) = &self.inner { let view = v.view(); - let item = view.iter().nth(index as _).unwrap(); + let ptr = view.as_ptr(); + + // TODO: assert that the index is in bounds + let item = unsafe { &*ptr.add(index as _) }; unsafe { *string_out = item.as_ptr() as *const _; *strlen_out = item.len() as _; @@ -206,13 +210,18 @@ impl CartonTensor { } } - /// For a string tensor, set a string at a particular (flattened) index. + /// For a string tensor, set a string at a particular flattened index. + /// Copies the null-terminated string `string` into the tensor at the specified index. + /// Note: `index` should take strides into account. #[no_mangle] pub extern "C" fn carton_tensor_set_string(&mut self, index: u64, string: *const c_char) { let new = unsafe { CStr::from_ptr(string).to_str().unwrap().to_owned() }; self.carton_tensor_set_string_inner(index, new); } + /// For a string tensor, set a string at a particular flattened index. + /// Copies `strlen` bytes of `string` into the tensor at the specified index. + /// Note: `index` should take strides into account. #[no_mangle] pub extern "C" fn carton_tensor_set_string_with_strlen( &mut self, @@ -232,7 +241,10 @@ impl CartonTensor { fn carton_tensor_set_string_inner(&mut self, index: u64, string: String) { if let carton_core::types::Tensor::String(v) = &mut self.inner { let mut view = v.view_mut(); - let item = view.iter_mut().nth(index as _).unwrap(); + let ptr = view.as_mut_ptr(); + + // TODO: assert that the index is in bounds + let item = unsafe { &mut *ptr.add(index as _) }; *item = string; } else { panic!("Tried to call `set_string` on a non-string tensor") diff --git a/source/carton-bindings-cpp/src/carton.hh b/source/carton-bindings-cpp/src/carton.hh index 04b2e8c..7428336 100644 --- a/source/carton-bindings-cpp/src/carton.hh +++ b/source/carton-bindings-cpp/src/carton.hh @@ -93,6 +93,19 @@ namespace carton friend class TensorMap; + template + friend class TensorStringValue; + + // For a string tensor, set a string at a particular flattened index + // This will copy data from the provided string_view. + // Note: `index` should take strides into account. + void set_string(uint64_t index, std::string_view string); + + // For a string tensor, get a string at a particular flattened index + // Note: the returned view is only valid until the tensor is modified. + // Note: `index` should take strides into account. + std::string_view get_string(uint64_t index) const; + public: // Create a tensor with dtype `dtype` and shape `shape` Tensor(DataType dtype, std::span shape); @@ -131,15 +144,57 @@ namespace carton // Note: the returned span is only valid while this Tensor is in scope std::span strides() const; - // For a string tensor, set a string at a particular (flattened) index - // This will copy data from the provided string_view. - // TODO: do some template magic to make this easy to use - void set_string(uint64_t index, std::string_view string); + // Using the accessor methods can be faster than `at` when accessing many elements + // because they avoid making function calls on each element access. + // See `TensorAccessor` below for usage. + template + auto accessor(); + + // Using the accessor methods can be faster than `at` when accessing many elements + // because they avoid making function calls on each element access. + // See `TensorAccessor` below for usage. + template + auto accessor() const; + + // Get an element at an index + // This is a convenience wrapper that creates an `accessor` and uses it + // Consider explicitly creating an accessor if you need to access many elements + template + auto at(Index... index) const; + + // Get an element at an index + // This is a convenience wrapper that creates an `accessor` and uses it + // Consider explicitly creating an accessor if you need to access many elements + template + auto at(Index... index); + }; - // For a string tensor, get a string at a particular (flattened) index - // Note: the returned view is only valid until the tensor is modified. - // TODO: do some template magic to make this easy to use - std::string_view get_string(uint64_t index) const; + // The return type of the `accessor` methods of `Tensor` + template + class TensorAccessor + { + private: + DataContainer data_; + + // The strides of the tensor + const std::span strides_; + + friend class Tensor; + TensorAccessor(DataContainer data, std::span strides) : data_(data), strides_(strides) {} + + public: + // Return the element at `index` + // One value of `index` must be provided for each dimension. + // + // ``` + // auto acc = t.accessor(); + // auto val = acc[1, 2, 3]; + // ``` + // + // Note: For string values, the returned view is only valid until the tensor is modified. Users + // should make a copy if they need to both persist the value and modify the tensor. + template + auto operator[](Index... index) const; }; template diff --git a/source/carton-bindings-cpp/src/carton_impl.hh b/source/carton-bindings-cpp/src/carton_impl.hh index 15cbb45..6298b19 100644 --- a/source/carton-bindings-cpp/src/carton_impl.hh +++ b/source/carton-bindings-cpp/src/carton_impl.hh @@ -82,6 +82,121 @@ namespace carton } } + // Utility to let us read and write strings more easily + // This is returned by `TensorAccessor` when indexing string tensors + template + class TensorStringValue + { + private: + T &tensor_; + + uint64_t index_; + + public: + TensorStringValue(T &tensor, uint64_t index) : tensor_(tensor), index_(index) {} + + // Assignment of a string type + void operator=(std::string_view val) + { + tensor_.set_string(index_, val); + } + + // Reading of a string type + operator std::string_view() const + { + return tensor_.get_string(index_); + } + }; + + template + std::ostream &operator<<(std::ostream &os, const TensorStringValue &v) + { + os << std::string_view(v); + return os; + } + + // Impl for Tensor + // Using the accessor methods can be faster when accessing many elements because + // they avoid making function calls on each element access + template + auto Tensor::accessor() + { + // TODO: assert N == ndims + // TODO: assert data type + if constexpr (std::is_same_v || std::is_same_v) + { + return TensorAccessor(*this, strides()); + } + else + { + return TensorAccessor(data(), strides()); + } + } + + // Using the accessor methods can be faster when accessing many elements because + // they avoid making function calls on each element access + template + auto Tensor::accessor() const + { + // TODO: assert N == ndims + // TODO: assert data type + if constexpr (std::is_same_v || std::is_same_v) + { + return TensorAccessor(*this, strides()); + } + else + { + return TensorAccessor(data(), strides()); + } + } + + template + auto Tensor::at(Index... index) const + { + constexpr auto N = sizeof...(Index); + auto acc = accessor(); + return acc.operator[](std::forward(index)...); + } + + template + auto Tensor::at(Index... index) + { + constexpr auto N = sizeof...(Index); + auto acc = accessor(); + return acc.operator[](std::forward(index)...); + } + + // Impl for TensorAccessor + template + template + auto TensorAccessor::operator[](Index... index) const + { + constexpr auto num_indices = sizeof...(Index); + static_assert(NumDims == num_indices, "Incorrect number of indices"); + + // Compute the index. This all gets flattened out at compile time + int i = 0; + + // Basically sets up a dot product of `index` and `strides` + auto offset = ([&] + { return index * strides_[i++]; }() + + ...); + + if constexpr (std::is_same_v || std::is_same_v) + { + // Handle string tensors separately + // For convenience, we allow T to be std::string, but we always use `std::string_view` + // to avoid unnecessary copies. + return TensorStringValue(data_, offset); + } + else + { + // Numeric tensors + static_assert(std::is_arithmetic_v, "accessor() only supports string and numeric tensors"); + return static_cast(data_)[offset * sizeof(T)]; + } + } + // Impl for AsyncNotifier template AsyncNotifier::AsyncNotifier() : AsyncNotifierBase() {} diff --git a/source/carton-bindings-cpp/tests/callback.cc b/source/carton-bindings-cpp/tests/callback.cc index 885b390..5fe017a 100644 --- a/source/carton-bindings-cpp/tests/callback.cc +++ b/source/carton-bindings-cpp/tests/callback.cc @@ -32,10 +32,12 @@ void infer_callback(Result infer_result, void *arg) const auto tokens = out.get_and_remove("tokens"); const auto scores = out.get_and_remove("scores"); - const auto scores_data = static_cast(scores.data()); + // Can use a template arg of `std::string` or `std::string_view` + std::cout << "Got output token: " << tokens.at(0) << std::endl; + std::cout << "Got output scores: " << scores.at(0) << std::endl; - std::cout << "Got output token: " << tokens.get_string(0) << std::endl; - std::cout << "Got output scores: " << scores_data[0] << std::endl; + assert(tokens.at(0) == std::string_view("day")); + assert(std::abs(scores.at(0) - 14.5513) < 0.0001); exit(0); } @@ -47,7 +49,7 @@ void load_callback(Result model_result, void *arg) uint64_t shape[]{1}; auto tensor = Tensor(DataType::kString, shape); - tensor.set_string(0, "Today is a good [MASK]."); + tensor.at(0) = "Today is a good [MASK]."; std::unordered_map inputs; inputs.insert(std::make_pair("input", std::move(tensor))); diff --git a/source/carton-bindings-cpp/tests/future.cc b/source/carton-bindings-cpp/tests/future.cc index 65b7bcb..a3bcddb 100644 --- a/source/carton-bindings-cpp/tests/future.cc +++ b/source/carton-bindings-cpp/tests/future.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include "../src/carton.hh" @@ -28,7 +29,9 @@ int main() uint64_t shape[]{1}; auto tensor = Tensor(DataType::kString, shape); - tensor.set_string(0, "Today is a good [MASK]."); + + // Can use a template arg of `std::string` or `std::string_view` + tensor.at(0) = "Today is a good [MASK]."; std::unordered_map inputs; inputs.insert(std::make_pair("input", std::move(tensor))); @@ -41,6 +44,9 @@ int main() const auto scores_data = static_cast(scores.data()); - std::cout << "Got output token: " << tokens.get_string(0) << std::endl; + std::cout << "Got output token: " << tokens.at(0) << std::endl; std::cout << "Got output scores: " << scores_data[0] << std::endl; + + assert(tokens.at(0) == std::string_view("day")); + assert(std::abs(scores_data[0] - 14.5513) < 0.0001); } \ No newline at end of file diff --git a/source/carton-bindings-cpp/tests/notifier.cc b/source/carton-bindings-cpp/tests/notifier.cc index b8b7b3d..3a714ff 100644 --- a/source/carton-bindings-cpp/tests/notifier.cc +++ b/source/carton-bindings-cpp/tests/notifier.cc @@ -38,7 +38,7 @@ int main() uint64_t shape[]{1}; auto tensor = Tensor(DataType::kString, shape); - tensor.set_string(0, "Today is a good [MASK]."); + tensor.at(0) = "Today is a good [MASK]."; std::unordered_map inputs; inputs.insert(std::make_pair("input", std::move(tensor))); @@ -56,6 +56,13 @@ int main() const auto scores_data = static_cast(scores.data()); - std::cout << "Got output token: " << tokens.get_string(0) << std::endl; + // If you're accessing a few elements, you can just use `.at`, but we'll use + // an accessor here for testing + const auto token_accessor = tokens.accessor(); + + std::cout << "Got output token: " << token_accessor[0] << std::endl; std::cout << "Got output scores: " << scores_data[0] << std::endl; + + assert(token_accessor[0] == std::string_view("day")); + assert(std::abs(scores_data[0] - 14.5513) < 0.0001); } \ No newline at end of file