Skip to content

Commit

Permalink
more minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
mapleFU committed Jun 4, 2024
1 parent 6b13479 commit 8c132a8
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions cpp/src/arrow/extension/fixed_shape_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ std::shared_ptr<Array> FixedShapeTensorType::MakeArray(

Result<std::shared_ptr<Tensor>> FixedShapeTensorType::MakeTensor(
const std::shared_ptr<ExtensionScalar>& scalar) {
const auto ext_scalar = internal::checked_pointer_cast<ExtensionScalar>(scalar);
const auto* ext_scalar = internal::checked_cast<const ExtensionScalar*>(scalar.get());
const auto* ext_type =
internal::checked_cast<FixedShapeTensorType*>(scalar->type.get());
if (!is_fixed_width(*ext_type->value_type())) {
Expand All @@ -218,8 +218,8 @@ Result<std::shared_ptr<Tensor>> FixedShapeTensorType::MakeTensor(
if (array->null_count() > 0) {
return Status::Invalid("Cannot convert data with nulls to Tensor.");
}
const auto value_type =
internal::checked_pointer_cast<FixedWidthType>(ext_type->value_type());
const auto* value_type =
internal::checked_cast<const FixedWidthType*>(ext_type->value_type().get());
const auto byte_width = value_type->byte_width();

std::vector<int64_t> permutation = ext_type->permutation();
Expand All @@ -237,14 +237,14 @@ Result<std::shared_ptr<Tensor>> FixedShapeTensorType::MakeTensor(
}

std::vector<int64_t> strides;
RETURN_NOT_OK(ComputeStrides(*value_type.get(), shape, permutation, &strides));
RETURN_NOT_OK(ComputeStrides(*value_type, shape, permutation, &strides));
const auto start_position = array->offset() * byte_width;
const auto size = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1),
std::multiplies<>());
const auto buffer =
SliceBuffer(array->data()->buffers[1], start_position, size * byte_width);

return Tensor::Make(value_type, buffer, shape, strides, dim_names);
return Tensor::Make(ext_type->value_type(), buffer, shape, strides, dim_names);
}

Result<std::shared_ptr<FixedShapeTensorArray>> FixedShapeTensorArray::FromTensor(
Expand Down

0 comments on commit 8c132a8

Please sign in to comment.