Skip to content

Commit

Permalink
Merge pull request #500 from AVSLab/feature/499-fix-leaks-swig-typemaps
Browse files Browse the repository at this point in the history
Feature/499 Fix SWIG Eigen type maps to remove data leaks
  • Loading branch information
schaubh authored Nov 19, 2023
2 parents 5007758 + beb7158 commit a83d541
Showing 1 changed file with 25 additions and 3 deletions.
28 changes: 25 additions & 3 deletions src/architecture/_GeneralModuleFiles/swig_eigen.i
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ std::optional<std::pair<Py_ssize_t, Py_ssize_t>> getInputSize(PyObject *input)

PyObject *firstItem = PySequence_GetItem(input, 0);
Py_ssize_t numberColumns = PySequence_Check(firstItem) ? PySequence_Length(firstItem) : 1;
Py_DECREF(firstItem);
return {{numberRows, numberColumns}};
}
}
Expand Down Expand Up @@ -258,6 +259,8 @@ In that case, disambiguate using PyErr_Occurred().
template<class T>
T pyObjToEigenMatrix(PyObject *input)
{
using ScalarType = typename T::Scalar;

auto errorMsg = checkPyObjectIsMatrixLike<T>(input);
if (errorMsg)
{
Expand All @@ -278,23 +281,39 @@ T pyObjToEigenMatrix(PyObject *input)
for(Py_ssize_t row=0; row<numberRows; row++)
{
PyObject *rowPyObj = PySequence_GetItem(input, row);
bool rowPyObjIsSequence = PySequence_Check(rowPyObj);

for(Py_ssize_t col=0; col<numberColumns; col++)
{
// rowPyObj can be either a length 1 sequence or the value directly
PyObject *rowColPyObj = PySequence_Check(rowPyObj) ? PySequence_GetItem(rowPyObj, col) : rowPyObj;
std::variant<ScalarType, std::string> valueOrErrorMsg;
if (rowPyObjIsSequence)
{
PyObject *rowColPyObj = PySequence_GetItem(rowPyObj, col);
valueOrErrorMsg = castPyToC<ScalarType>(rowColPyObj);
Py_DECREF(rowColPyObj);
}
else
{
valueOrErrorMsg = castPyToC<ScalarType>(rowPyObj);
}

auto valueOrErrorMsg = castPyToC<typename T::Scalar>(rowColPyObj);
if (std::holds_alternative<std::string>(valueOrErrorMsg))
{
PyErr_SetString(PyExc_ValueError, (
"Row " + std::to_string(row) + ", Column " + std::to_string(col) +": "
+ std::get<std::string>(valueOrErrorMsg)
).c_str());

Py_DECREF(rowPyObj);

return {};
}

result(row, col) = std::get<typename T::Scalar>(valueOrErrorMsg);
}

Py_DECREF(rowPyObj);
}

return result;
Expand Down Expand Up @@ -358,9 +377,12 @@ void fillPyObjList(PyObject *input, const T& value)
PyObject *locRow = PyList_New(0);
for(auto j=0; j<value.outerSize(); j++)
{
PyList_Append(locRow, castCToPy<typename T::Scalar>(value(i,j)));
auto toAppend = castCToPy<typename T::Scalar>(value(i,j));
PyList_Append(locRow, toAppend);
Py_DECREF(toAppend);
}
PyList_Append(input, locRow);
Py_DECREF(locRow);
}
}

Expand Down

0 comments on commit a83d541

Please sign in to comment.