Inference : Support cancellation
This isn't particularly useful, because in practice the UI never cancels the compute. I think this is because the Viewer computes the data window and format on the UI thread first, and for TensorToImage that depends on the output from the inference.
johnhaddon committed Nov 15, 2024
1 parent 50c8f82 commit 2691907
Showing 2 changed files with 89 additions and 8 deletions.
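The note above about the UI hinges on where the canceller comes from: the new code only terminates the ONNX run when the compute happens in a context that carries an IECore.Canceller, which in practice means a background task rather than the UI thread. Below is a rough Python sketch of driving the compute that way so the new cancellation path actually fires; the model path is a placeholder, and the use of `Gaffer.ParallelAlgo.callOnBackgroundThread()` and `BackgroundTask.cancelAndWait()` is an assumption based on how other background computes are driven in Gaffer, not something this commit adds.

import IECore
import Gaffer
import GafferML

inference = GafferML.Inference()
inference["model"].setValue( "/path/to/add.onnx" ) # placeholder path
inference.loadModel()
inference["in"][0].setValue( GafferML.Tensor( IECore.FloatVectorData( [ 1 ] * 60 ), [ 3, 4, 5 ] ) )
inference["in"][1].setValue( GafferML.Tensor( IECore.FloatVectorData( [ 2 ] * 60 ), [ 3, 4, 5 ] ) )

def pull() :
	# Runs on a background thread, in a context that carries a canceller.
	inference["out"][0].getValue()

# Cancelling the task sets the canceller that `waiter.wait( context->canceller() )`
# polls, which in turn calls `RunOptions::SetTerminate()` on the in-flight run.
task = Gaffer.ParallelAlgo.callOnBackgroundThread( inference["out"][0], pull )
task.cancelAndWait()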
19 changes: 19 additions & 0 deletions python/GafferMLTest/InferenceTest.py
@@ -45,6 +45,8 @@
import GafferTest
import GafferML

## \todo Test cancellation. For this, we need a model that takes long enough to compute
# but is small enough to package with the tests.
class InferenceTest( GafferTest.TestCase ) :

	def testLoadModel( self ) :
@@ -101,6 +103,23 @@ def testCompute( self ) :
			IECore.FloatVectorData( [ 4 ] * 60 )
		)

	def testComputeError( self ) :

		inference = GafferML.Inference()
		inference["model"].setValue( pathlib.Path( __file__ ).parent / "models" / "add.onnx" )
		inference.loadModel()

		inference["in"][0].setValue(
			GafferML.Tensor( IECore.FloatVectorData( [ 1 ] * 60 ), [ 60 ] )
		)

		inference["in"][1].setValue(
			GafferML.Tensor( IECore.FloatVectorData( [ 2 ] * 60 ), [ 3, 4, 5 ] )
		)

		with self.assertRaisesRegex( Gaffer.ProcessException, "Invalid rank for input" ) :
			inference["out"][0].getValue()

	def testModelSearchPaths( self ) :

		node = GafferML.Inference()
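The \todo above asks for a cancellation test once a model that is slow enough to cancel can be packaged with the tests. A sketch of roughly what that test could look like: `slow.onnx` is a hypothetical model, `import threading` would be needed alongside the existing imports, and it assumes cancellation surfaces as `IECore.Cancelled`, as it does for other cancelled Gaffer computes.

	def testCancellation( self ) :

		inference = GafferML.Inference()
		# Hypothetical model that computes slowly enough to be cancelled reliably.
		inference["model"].setValue( pathlib.Path( __file__ ).parent / "models" / "slow.onnx" )
		inference.loadModel()
		# ... set whatever inputs the model requires ...

		canceller = IECore.Canceller()
		exceptions = []

		def compute() :
			try :
				with Gaffer.Context( Gaffer.Context.current(), canceller ) :
					inference["out"][0].getValue()
			except Exception as e :
				exceptions.append( e )

		thread = threading.Thread( target = compute )
		thread.start()
		canceller.cancel()
		thread.join()

		self.assertEqual( len( exceptions ), 1 )
		self.assertIsInstance( exceptions[0], IECore.Cancelled )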
78 changes: 70 additions & 8 deletions src/GafferML/Inference.cpp
@@ -45,6 +45,7 @@
#include "onnxruntime_cxx_api.h"

#include <mutex>
#include <condition_variable>

using namespace std;
using namespace Imath;
@@ -100,6 +101,56 @@ Ort::Session &acquireSession( const std::string &fileName )
	return it->second;
}

struct AsyncWaiter
{

	AsyncWaiter( Ort::RunOptions &runOptions )
		: m_runOptions( runOptions )
	{
	}

	void wait( const IECore::Canceller *canceller )
	{
		while( true )
		{
			std::unique_lock<std::mutex> lock( m_mutex );
			m_conditionVariable.wait_for( lock, std::chrono::milliseconds( 100 ) );

			if( m_resultStatus )
			{
				// Run has completed. Throw if it errored or was cancelled,
				// otherwise return.
				Ort::ThrowOnError( *m_resultStatus );
				IECore::Canceller::check( canceller );
				return;
			}
			else if( canceller && canceller->cancelled() )
			{
				m_runOptions.SetTerminate();
			}
		}
	}

	static void callback( void *userData, OrtValue **outputs, size_t numOutputs, OrtStatusPtr status )
	{
		// Run has completed. Set status so we can pick it up in `wait()`.
		auto that = (AsyncWaiter *)userData;
		{
			std::unique_lock<std::mutex> lock( that->m_mutex );
			that->m_resultStatus = status;
		}
		that->m_conditionVariable.notify_all();
	}

	private :

		Ort::RunOptions &m_runOptions;
		std::mutex m_mutex;
		std::condition_variable m_conditionVariable;
		std::optional<OrtStatusPtr> m_resultStatus;

};

} // namespace

//////////////////////////////////////////////////////////////////////////
@@ -256,8 +307,9 @@ void Inference::compute( Gaffer::ValuePlug *output, const Gaffer::Context *conte
{
	if( output == inferencePlug() )
	{
		const string model = modelPlug()->getValue();
		// Set up input and output tensor arrays.

		const string model = modelPlug()->getValue();
		Ort::Session &session = acquireSession( model );

		vector<Ort::AllocatedStringPtr> inputNameOwners;
@@ -276,30 +328,40 @@ void Inference::compute( Gaffer::ValuePlug *output, const Gaffer::Context *conte

		vector<Ort::AllocatedStringPtr> outputNameOwners;
		vector<const char *> outputNames;
		vector<Ort::Value> outputs;
		for( auto &p : TensorPlug::OutputRange( *outPlug() ) )
		{
			int outputIndex = StringAlgo::numericSuffix( p->getName().string() );
			outputNameOwners.push_back( session.GetOutputNameAllocated( outputIndex, Ort::AllocatorWithDefaultOptions() ) );
			outputNames.push_back( outputNameOwners.back().get() );
			outputs.push_back( Ort::Value( nullptr ) );
		}

		// TODO : WE REALLY WANT TO BE ABLE TO CANCEL THIS
		// LOOKS POSSIBLE VIA RUNOPTIONS, BUT IT ISN'T POLLED - WE'D
		// NEED TO CALL `SetTerminate()` SOMEHOW.
		// MAYBE WE CAN USE `RunAsync()`?
		// Run inference asynchronously on an ONNX thread. This allows us
		// to check for cancellation via our AsyncWaiter.

		vector<Ort::Value> outputs = session.Run(
			Ort::RunOptions(), inputNames.data(),
		Ort::RunOptions runOptions;
		AsyncWaiter waiter( runOptions );

		session.RunAsync(
			runOptions, inputNames.data(),
			// The Ort C++ API wants us to pass `Ort::Value *`, but `Ort::Value`
			// is non-copyable and the original `Ort::Value` instances are in
			// separate TensorDatas and can't be moved. But `Ort::Value` has the
			// same layout as `OrtValue *` (the underlying C type) so we can
			// just reinterpret cast from the latter. Indeed, `Run()` is going
			// to cast straight back to `OrtValue *` to call the C API!
			reinterpret_cast<Ort::Value *>( inputs.data() ),
			inputs.size(), outputNames.data(), outputNames.size()
			inputs.size(),
			outputNames.data(),
			outputs.data(),
			outputNames.size(),
			waiter.callback,
			&waiter
		);

		waiter.wait( context->canceller() );

		CompoundObjectPtr result = new CompoundObject;
		for( size_t i = 0; i < outputs.size(); ++i )
		{
