Skip to content

Commit

Permalink
Tensor/DataToTensor : Support BoolVectorData
Browse files Browse the repository at this point in the history
  • Loading branch information
johnhaddon committed Nov 6, 2024
1 parent 4797fbc commit b60540a
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 16 deletions.
1 change: 1 addition & 0 deletions python/GafferMLTest/TensorTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class TensorTest( GafferTest.TestCase ) :
def testAsData( self ) :

for data in [
IECore.BoolVectorData( [ True, False, True ] ),
IECore.FloatVectorData( [ 1, 2, 3 ] ),
IECore.DoubleVectorData( [ 1, 2, 3 ] ),
IECore.IntVectorData( [ 1, 2, 3 ] ),
Expand Down
3 changes: 2 additions & 1 deletion python/GafferMLUI/DataToTensorUI.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ def setup( node, plugType ) :
node.setup( plugType() )

for plugType in (
Gaffer.BoolVectorDataPlug,
Gaffer.IntVectorDataPlug,
Gaffer.FloatVectorDataPlug,
None,
Expand All @@ -223,4 +224,4 @@ def setup( node, plugType ) :
}
)

return result
return result
49 changes: 34 additions & 15 deletions src/GafferML/Tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ void dispatchTensorData( const Ort::Value &value, F &&functor )
case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE :
functor( value.GetTensorData<double>() );
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL :
functor( value.GetTensorData<bool>() );
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 :
functor( value.GetTensorData<uint16_t>() );
break;
Expand Down Expand Up @@ -183,25 +186,41 @@ Tensor::Tensor( const IECore::ConstDataPtr &data, std::vector<int64_t> shape )

using DataType = remove_const_t<remove_pointer_t<decltype( typedData )>>;
using BaseType = typename DataType::BaseType;
if constexpr( HasTensorType<BaseType>::value )

if( !shape.size() )
{
if( !shape.size() )
// Automatically infer shape from type.
if constexpr( TypeTraits::IsVectorTypedData<DataType>::value )
{
shape.push_back( typedData->readable().size() );
using ShapeT = ShapeTraits<typename DataType::ValueType::value_type>;
shape.insert( shape.end(), begin( ShapeT::shape ), end( ShapeT::shape ) );
}
else
{
// Automatically infer shape from type.
if constexpr( TypeTraits::IsVectorTypedData<DataType>::value )
{
shape.push_back( typedData->readable().size() );
using ShapeT = ShapeTraits<typename DataType::ValueType::value_type>;
shape.insert( shape.end(), begin( ShapeT::shape ), end( ShapeT::shape ) );
}
else
{
using ShapeT = ShapeTraits<typename DataType::ValueType>;
shape.insert( shape.end(), begin( ShapeT::shape ), end( ShapeT::shape ) );
}
using ShapeT = ShapeTraits<typename DataType::ValueType>;
shape.insert( shape.end(), begin( ShapeT::shape ), end( ShapeT::shape ) );
}
}

Ort::MemoryInfo memoryInfo = Ort::MemoryInfo::CreateCpu( OrtArenaAllocator, OrtMemTypeDefault );

Ort::MemoryInfo memoryInfo = Ort::MemoryInfo::CreateCpu( OrtArenaAllocator, OrtMemTypeDefault );
if constexpr( std::is_same_v<DataType, BoolVectorData> )
{
// Special case for the vector of bool fiasco.
auto array = std::make_unique<bool[]>( typedData->readable().size() );
std::copy( typedData->readable().begin(), typedData->readable().end(), array.get() );
m_state = new State{
Ort::Value::CreateTensor(
memoryInfo.GetConst(),
array.get(), typedData->readable().size(),
shape.data(), shape.size()
),
nullptr
};
}
else if constexpr( HasTensorType<BaseType>::value )
{
m_state = new State{
Ort::Value::CreateTensor(
memoryInfo.GetConst(),
Expand Down

0 comments on commit b60540a

Please sign in to comment.