diff --git a/python/GafferMLTest/TensorToImageTest.py b/python/GafferMLTest/TensorToImageTest.py index 63c523a5293..b7a2b05d7a5 100644 --- a/python/GafferMLTest/TensorToImageTest.py +++ b/python/GafferMLTest/TensorToImageTest.py @@ -43,9 +43,10 @@ import Gaffer import GafferTest import GafferImage +import GafferImageTest import GafferML -class TensorToImageTest( GafferTest.TestCase ) : +class TensorToImageTest( GafferImageTest.ImageTestCase ) : def testNoInput( self ) : @@ -106,5 +107,24 @@ def testNonMatchingChannels( self ) : with self.assertRaisesRegex( RuntimeError, 'Invalid channel "G' ) : tensorToImage["out"].channelData( "G", imath.V2i( 0 ) ) + def testRoundTripWithImageToTensor( self ) : + + image = GafferImage.Checkerboard() + + imageToTensor = GafferML.ImageToTensor() + imageToTensor["image"].setInput( image["out"] ) + imageToTensor["channels"].setInput( image["out"]["channelNames"]) + + tensorToImage = GafferML.TensorToImage() + tensorToImage["tensor"].setInput( imageToTensor["tensor"] ) + tensorToImage["channels"].setInput( image["out"]["channelNames"]) + + self.assertImagesEqual( tensorToImage["out"], image["out"] ) + + imageToTensor["interleaveChannels"].setValue( True ) + tensorToImage["interleavedChannels"].setValue( True ) + + self.assertImagesEqual( tensorToImage["out"], image["out"] ) + if __name__ == "__main__": unittest.main() diff --git a/src/GafferML/TensorToImage.cpp b/src/GafferML/TensorToImage.cpp index d51ec9d77cf..3f2ec34716a 100644 --- a/src/GafferML/TensorToImage.cpp +++ b/src/GafferML/TensorToImage.cpp @@ -92,14 +92,14 @@ ImageShape imageShape( const Tensor *tensor, bool interleavedChannels ) if( interleavedChannels ) { return { - Box2i( V2i( 0 ), V2i( (int)shape[i], (int)shape[i+1] ) ), + Box2i( V2i( 0 ), V2i( (int)shape[i+1], (int)shape[i] ) ), (size_t)shape[i+2] }; } else { return { - Box2i( V2i( 0 ), V2i( (int)shape[i+1], (int)shape[i+2] ) ), + Box2i( V2i( 0 ), V2i( (int)shape[i+2], (int)shape[i+1] ) ), (size_t)shape[i] }; }