diff --git a/src/nupic/bindings/algorithms.i b/src/nupic/bindings/algorithms.i index 95a410ab98..5a1f94003b 100644 --- a/src/nupic/bindings/algorithms.i +++ b/src/nupic/bindings/algorithms.i @@ -1163,11 +1163,11 @@ void forceRetentionOfImageSensorLiteLibrary(void) { self._initFromCapnpPyBytes(proto.as_builder().to_bytes()) # copy * 2 %} - inline void compute(PyObject *py_x, bool learn, PyObject *py_y) + inline void compute(PyObject *py_inputArray, bool learn, PyObject *py_activeArray) { - PyArrayObject* x = (PyArrayObject*) py_x; - PyArrayObject* y = (PyArrayObject*) py_y; - self->compute((nupic::UInt*) PyArray_DATA(x), (bool)learn, (nupic::UInt*) PyArray_DATA(y)); + nupic::CheckedNumpyVectorWeakRefT inputArray(py_inputArray); + nupic::CheckedNumpyVectorWeakRefT activeArray(py_activeArray); + self->compute(inputArray.begin(), learn, activeArray.begin()); } inline void stripUnlearnedColumns(PyObject *py_x) diff --git a/src/nupic/py_support/NumpyVector.hpp b/src/nupic/py_support/NumpyVector.hpp index 9f16bf8541..4d2c4f7cf2 100644 --- a/src/nupic/py_support/NumpyVector.hpp +++ b/src/nupic/py_support/NumpyVector.hpp @@ -31,6 +31,7 @@ #include // For nupic::Real. #include // For NTA_ASSERT #include // For std::copy. +#include // for 'type_id' namespace nupic { @@ -438,6 +439,31 @@ namespace nupic { PyArrayObject* pyArray_; }; + /** + * Similar to NumpyVectorWeakRefT but also provides extra type checking + */ + template + class CheckedNumpyVectorWeakRefT : public NumpyVectorWeakRefT + { + public: + CheckedNumpyVectorWeakRefT(PyObject* pyArray) + : NumpyVectorWeakRefT(pyArray) + { + if (PyArray_NDIM(this->pyArray_) != 1) + { + NTA_THROW << "Expecting 1D array " + << "but got " << PyArray_NDIM(this->pyArray_) << "D array"; + } + if (!PyArray_EquivTypenums( + PyArray_TYPE(this->pyArray_), LookupNumpyDType((const T *) 0))) + { + boost::typeindex::stl_type_index expectedType = + boost::typeindex::stl_type_index::type_id(); + NTA_THROW << "Expecting '" << expectedType.pretty_name() << "' " + << "but got '" << PyArray_DTYPE(this->pyArray_)->type << "'"; + } + } + }; } // End namespace nupic. #endif