Skip to content

Commit

Permalink
TensorToImage : Fix x/y mixup
Browse files Browse the repository at this point in the history
  • Loading branch information
johnhaddon committed Nov 9, 2024
1 parent ddc62fc commit 5b06225
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
22 changes: 21 additions & 1 deletion python/GafferMLTest/TensorToImageTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,10 @@
import Gaffer
import GafferTest
import GafferImage
import GafferImageTest
import GafferML

class TensorToImageTest( GafferTest.TestCase ) :
class TensorToImageTest( GafferImageTest.ImageTestCase ) :

def testNoInput( self ) :

Expand Down Expand Up @@ -106,5 +107,24 @@ def testNonMatchingChannels( self ) :
with self.assertRaisesRegex( RuntimeError, 'Invalid channel "G' ) :
tensorToImage["out"].channelData( "G", imath.V2i( 0 ) )

def testRoundTripWithImageToTensor( self ) :

image = GafferImage.Checkerboard()

imageToTensor = GafferML.ImageToTensor()
imageToTensor["image"].setInput( image["out"] )
imageToTensor["channels"].setInput( image["out"]["channelNames"])

tensorToImage = GafferML.TensorToImage()
tensorToImage["tensor"].setInput( imageToTensor["tensor"] )
tensorToImage["channels"].setInput( image["out"]["channelNames"])

self.assertImagesEqual( tensorToImage["out"], image["out"] )

imageToTensor["interleaveChannels"].setValue( True )
tensorToImage["interleavedChannels"].setValue( True )

self.assertImagesEqual( tensorToImage["out"], image["out"] )

if __name__ == "__main__":
unittest.main()
4 changes: 2 additions & 2 deletions src/GafferML/TensorToImage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,14 @@ ImageShape imageShape( const Tensor *tensor, bool interleavedChannels )
if( interleavedChannels )
{
return {
Box2i( V2i( 0 ), V2i( (int)shape[i], (int)shape[i+1] ) ),
Box2i( V2i( 0 ), V2i( (int)shape[i+1], (int)shape[i] ) ),
(size_t)shape[i+2]
};
}
else
{
return {
Box2i( V2i( 0 ), V2i( (int)shape[i+1], (int)shape[i+2] ) ),
Box2i( V2i( 0 ), V2i( (int)shape[i+2], (int)shape[i+1] ) ),
(size_t)shape[i]
};
}
Expand Down

0 comments on commit 5b06225

Please sign in to comment.