Skip to content

Commit

Permalink
ImageToTensor : Add view plug
Browse files Browse the repository at this point in the history
  • Loading branch information
johnhaddon committed Nov 6, 2024
1 parent f61f594 commit a256cf6
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 13 deletions.
4 changes: 4 additions & 0 deletions include/GafferML/ImageToTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#include "GafferImage/ImagePlug.h"

#include "Gaffer/ComputeNode.h"
#include "Gaffer/StringPlug.h"

namespace GafferML
{
Expand All @@ -59,6 +60,9 @@ class GAFFERML_API ImageToTensor : public Gaffer::ComputeNode
GafferImage::ImagePlug *imagePlug();
const GafferImage::ImagePlug *imagePlug() const;

Gaffer::StringPlug *viewPlug();
const Gaffer::StringPlug *viewPlug() const;

Gaffer::StringVectorDataPlug *channelsPlug();
const Gaffer::StringVectorDataPlug *channelsPlug() const;

Expand Down
37 changes: 37 additions & 0 deletions python/GafferMLTest/ImageToTensorTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@

import unittest

import imath

import IECore

import Gaffer
Expand Down Expand Up @@ -67,5 +69,40 @@ def testShufflingChannelsChangesHash( self ) :
tensor["channels"].setValue( IECore.StringVectorData( [ "B", "G", "R" ] ) )
self.assertNotEqual( tensor["tensor"].hash(), h1 )

def testView( self ) :

left = GafferImage.Constant()
left["color"].setValue( imath.Color4f( 1, 0, 0, 1 ) )
left["format"].setValue( GafferImage.Format( 1, 1 ) )

right = GafferImage.Constant()
right["color"].setValue( imath.Color4f( 0, 1, 0, 1 ) )
right["format"].setValue( GafferImage.Format( 1, 1 ) )

createViews = GafferImage.CreateViews()
createViews["views"].resize( 2 )
createViews["views"][0]["name"].setValue( "left" )
createViews["views"][0]["value"].setInput( left["out" ])
createViews["views"][1]["name"].setValue( "right" )
createViews["views"][1]["value"].setInput( right["out" ])

imageToTensor = GafferML.ImageToTensor()
imageToTensor["image"].setInput( createViews["out"] )

with self.assertRaisesRegex( Gaffer.ProcessException, "View does not exist" ) :
imageToTensor["tensor"].getValue()

imageToTensor["view"].setValue( "left" )
self.assertEqual(
imageToTensor["tensor"].getValue().asData(),
IECore.FloatVectorData( [ 1, 0, 0 ] )
)

imageToTensor["view"].setValue( "right" )
self.assertEqual(
imageToTensor["tensor"].getValue().asData(),
IECore.FloatVectorData( [ 0, 1, 0 ] )
)

if __name__ == "__main__":
unittest.main()
60 changes: 60 additions & 0 deletions python/GafferMLUI/ImageToTensorUI.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
##########################################################################
#
# Copyright (c) 2024, Cinesite VFX Ltd. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above
# copyright notice, this list of conditions and the following
# disclaimer.
#
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided with
# the distribution.
#
# * Neither the name of John Haddon nor the names of
# any other contributors to this software may be used to endorse or
# promote products derived from this software without specific prior
# written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
# IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
##########################################################################

import Gaffer
import GafferML

Gaffer.Metadata.registerNode(

GafferML.ImageToTensor,

# "description",
# """
# Converts Gaffer data to tensors for use with the Inference node.
# Potential data sources include PrimitiveVariableQuery nodes to fetch data
# from 3D scenes, or expressions to generate arbitrary input data.
# """,

plugs = {

"view" : [

"plugValueWidget:type", "GafferImageUI.ViewPlugValueWidget",

],

}
)
1 change: 1 addition & 0 deletions python/GafferMLUI/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
##########################################################################

from . import DataToTensorUI
from . import ImageToTensorUI
from . import TensorToImageUI
from . import InferenceUI

Expand Down
45 changes: 32 additions & 13 deletions src/GafferML/ImageToTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ ImageToTensor::ImageToTensor( const std::string &name )
{
storeIndexOfNextChild( g_firstPlugIndex );
addChild( new ImagePlug( "image", Plug::In ) );
addChild( new StringPlug( "view", Plug::In, "default" ) );
addChild( new StringVectorDataPlug( "channels", Plug::In, new StringVectorData( { "R", "G", "B" } ) ) );
addChild( new BoolPlug( "interleaveChannels" ) );
addChild( new TensorPlug( "tensor", Plug::Out ) );
Expand All @@ -81,44 +82,56 @@ const GafferImage::ImagePlug *ImageToTensor::imagePlug() const
return getChild<ImagePlug>( g_firstPlugIndex );
}

Gaffer::StringPlug *ImageToTensor::viewPlug()
{
return getChild<StringPlug>( g_firstPlugIndex + 1 );
}

const Gaffer::StringPlug *ImageToTensor::viewPlug() const
{
return getChild<StringPlug>( g_firstPlugIndex + 1 );
}

Gaffer::StringVectorDataPlug *ImageToTensor::channelsPlug()
{
return getChild<StringVectorDataPlug>( g_firstPlugIndex + 1 );
return getChild<StringVectorDataPlug>( g_firstPlugIndex + 2 );
}

const Gaffer::StringVectorDataPlug *ImageToTensor::channelsPlug() const
{
return getChild<StringVectorDataPlug>( g_firstPlugIndex + 1 );
return getChild<StringVectorDataPlug>( g_firstPlugIndex + 2 );
}

Gaffer::BoolPlug *ImageToTensor::interleaveChannelsPlug()
{
return getChild<BoolPlug>( g_firstPlugIndex + 2 );
return getChild<BoolPlug>( g_firstPlugIndex + 3 );
}

const Gaffer::BoolPlug *ImageToTensor::interleaveChannelsPlug() const
{
return getChild<BoolPlug>( g_firstPlugIndex + 2 );
return getChild<BoolPlug>( g_firstPlugIndex + 3 );
}

TensorPlug *ImageToTensor::tensorPlug()
{
return getChild<TensorPlug>( g_firstPlugIndex + 3 );
return getChild<TensorPlug>( g_firstPlugIndex + 4 );
}

const TensorPlug *ImageToTensor::tensorPlug() const
{
return getChild<TensorPlug>( g_firstPlugIndex + 3 );
return getChild<TensorPlug>( g_firstPlugIndex + 4 );
}

void ImageToTensor::affects( const Gaffer::Plug *input, AffectedPlugsContainer &outputs ) const
{
ComputeNode::affects( input, outputs );

if(
input == imagePlug()->viewNamesPlug() ||
input == imagePlug()->dataWindowPlug() ||
input == imagePlug()->channelNamesPlug() ||
input == imagePlug()->channelDataPlug() ||
input == viewPlug() ||
input == channelsPlug() ||
input == interleaveChannelsPlug()
)
Expand All @@ -133,12 +146,16 @@ void ImageToTensor::hash( const Gaffer::ValuePlug *output, const Gaffer::Context
{
ComputeNode::hash( output, context, h );

const Box2i dataWindow = imagePlug()->dataWindow();
ConstStringVectorDataPtr inChannels = imagePlug()->channelNamesPlug()->getValue();
ConstStringVectorDataPtr channels = channelsPlug()->getValue();

interleaveChannelsPlug()->hash( h );

ImagePlug::ViewScope viewScope( context );
const std::string view = viewPlug()->getValue();
viewScope.setViewNameChecked( &view, imagePlug()->viewNames().get() );

const Box2i dataWindow = imagePlug()->dataWindow();
ConstStringVectorDataPtr inChannels = imagePlug()->channelNamesPlug()->getValue();

ImageAlgo::parallelGatherTiles(
imagePlug(),
channels->readable(),
Expand All @@ -162,7 +179,6 @@ void ImageToTensor::hash( const Gaffer::ValuePlug *output, const Gaffer::Context
);

h.append( dataWindow );

}
else
{
Expand All @@ -174,13 +190,16 @@ void ImageToTensor::compute( Gaffer::ValuePlug *output, const Gaffer::Context *c
{
if( output == tensorPlug() )
{
const Box2i dataWindow = imagePlug()->dataWindow();
ConstStringVectorDataPtr inChannels = imagePlug()->channelNamesPlug()->getValue();
ConstStringVectorDataPtr channelsData = channelsPlug()->getValue();
const auto &channels = channelsData->readable();

const bool interleaveChannels = interleaveChannelsPlug()->getValue();

ImagePlug::ViewScope viewScope( context );
const std::string view = viewPlug()->getValue();
viewScope.setViewNameChecked( &view, imagePlug()->viewNames().get() );

const Box2i dataWindow = imagePlug()->dataWindow();
ConstStringVectorDataPtr inChannels = imagePlug()->channelNamesPlug()->getValue();
const size_t numPixels = dataWindow.size().x * dataWindow.size().y;

FloatVectorDataPtr bufferData = new FloatVectorData;
Expand Down

0 comments on commit a256cf6

Please sign in to comment.