Skip to content

Commit

Permalink
TensorToImage : Deal with some channel mismatch stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
johnhaddon committed Nov 9, 2024
1 parent 48d0e63 commit ddc62fc
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 5 deletions.
58 changes: 58 additions & 0 deletions python/GafferMLTest/TensorToImageTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,13 @@

import unittest

import imath

import IECore

import Gaffer
import GafferTest
import GafferImage
import GafferML

class TensorToImageTest( GafferTest.TestCase ) :
Expand All @@ -48,5 +53,58 @@ def testNoInput( self ) :
with self.assertRaisesRegex( Gaffer.ProcessException, "Empty tensor" ) :
node["out"].dataWindow()

def testNonMatchingChannels( self ) :

tensor = GafferML.Tensor(
IECore.Color3fVectorData( [ imath.Color3f( 1, 2, 3 ) ] ),
[ 1, 1, 3 ]
)

tensorToImage = GafferML.TensorToImage()
tensorToImage["tensor"].setValue( tensor )
tensorToImage["interleavedChannels"].setValue( True )
self.assertEqual( tensorToImage["out"].dataWindow(), imath.Box2i( imath.V2i( 0 ), imath.V2i( 1 ) ) )

# Only two channels specified.

tensorToImage["channels"].setValue( IECore.StringVectorData( [ "R", "G" ] ) )
self.assertEqual( tensorToImage["out"].channelNames(), IECore.StringVectorData( [ "R", "G" ] ) )
self.assertEqual( tensorToImage["out"].channelData( "R", imath.V2i( 0 ) )[0], 1 )
self.assertEqual( tensorToImage["out"].channelData( "G", imath.V2i( 0 ) )[0], 2 )

with self.assertRaisesRegex( RuntimeError, 'Invalid channel "B"' ) :
tensorToImage["out"].channelData( "B", imath.V2i( 0 ) )

# Duplicate channels specified. We just take the first.

tensorToImage["channels"].setValue( IECore.StringVectorData( [ "R", "R", "B" ] ) )
self.assertEqual( tensorToImage["out"].channelNames(), IECore.StringVectorData( [ "R", "B" ] ) )
self.assertEqual( tensorToImage["out"].channelData( "R", imath.V2i( 0 ) )[0], 1 )
self.assertEqual( tensorToImage["out"].channelData( "B", imath.V2i( 0 ) )[0], 3 )

with self.assertRaisesRegex( RuntimeError, 'Invalid channel "G' ) :
tensorToImage["out"].channelData( "G", imath.V2i( 0 ) )

# Too many channels specified. We error if the extra channel is accessed.

tensorToImage["channels"].setValue( IECore.StringVectorData( [ "R", "G", "B", "A" ] ) )
self.assertEqual( tensorToImage["out"].channelNames(), IECore.StringVectorData( [ "R", "G", "B", "A" ] ) )
self.assertEqual( tensorToImage["out"].channelData( "R", imath.V2i( 0 ) )[0], 1 )
self.assertEqual( tensorToImage["out"].channelData( "G", imath.V2i( 0 ) )[0], 2 )
self.assertEqual( tensorToImage["out"].channelData( "B", imath.V2i( 0 ) )[0], 3 )

with self.assertRaisesRegex( RuntimeError, 'Channel "A" out of range' ) :
tensorToImage["out"].channelData( "A", imath.V2i( 0 ) )

# Channels skipped by entering empty strings.

tensorToImage["channels"].setValue( IECore.StringVectorData( [ "R", "", "B" ] ) )
self.assertEqual( tensorToImage["out"].channelNames(), IECore.StringVectorData( [ "R", "B" ] ) )
self.assertEqual( tensorToImage["out"].channelData( "R", imath.V2i( 0 ) )[0], 1 )
self.assertEqual( tensorToImage["out"].channelData( "B", imath.V2i( 0 ) )[0], 3 )

with self.assertRaisesRegex( RuntimeError, 'Invalid channel "G' ) :
tensorToImage["out"].channelData( "G", imath.V2i( 0 ) )

if __name__ == "__main__":
unittest.main()
23 changes: 18 additions & 5 deletions src/GafferML/TensorToImage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ namespace
struct ImageShape
{
Box2i dataWindow;
int numChannels;
size_t numChannels;
};

ImageShape imageShape( const Tensor *tensor, bool interleavedChannels )
Expand Down Expand Up @@ -93,14 +93,14 @@ ImageShape imageShape( const Tensor *tensor, bool interleavedChannels )
{
return {
Box2i( V2i( 0 ), V2i( (int)shape[i], (int)shape[i+1] ) ),
(int)shape[i+2]
(size_t)shape[i+2]
};
}
else
{
return {
Box2i( V2i( 0 ), V2i( (int)shape[i+1], (int)shape[i+2] ) ),
(int)shape[i]
(size_t)shape[i]
};
}
}
Expand Down Expand Up @@ -238,6 +238,17 @@ IECore::ConstStringVectorDataPtr TensorToImage::computeChannelNames( const Gaffe
),
result->writable().end()
);
// Remove empty channel names since they wouldn't be valid in the output.
result->writable().erase(
std::remove_if(
result->writable().begin(),
result->writable().end(),
[] ( const string &channelName ) {
return channelName.empty();
}
),
result->writable().end()
);
return result;
}

Expand Down Expand Up @@ -275,8 +286,10 @@ IECore::ConstFloatVectorDataPtr TensorToImage::computeChannelData( const std::st
const size_t channelIndex = channelIt - channelsData->readable().begin();

const ImageShape imageShape = ::imageShape( tensorData.get(), interleavedChannels );
// TODO : ERROR IF CHANNEL INDEX IS OUTSIDE OF TENSOR BOUNDS
// AND ALLOW EMPTY CHANNEL NAME TO SKIP CHANNELS.
if( channelIndex >= imageShape.numChannels )
{
throw IECore::Exception( fmt::format( "Channel \"{}\" out of range", channelName ) );
}

FloatVectorDataPtr outData = new FloatVectorData;
vector<float> &out = outData->writable();
Expand Down

0 comments on commit ddc62fc

Please sign in to comment.