diff --git a/python/ngraph/__init__.py b/python/ngraph/__init__.py index 9ff20972a4c..c3e4120c79a 100644 --- a/python/ngraph/__init__.py +++ b/python/ngraph/__init__.py @@ -22,6 +22,7 @@ from ngraph.ops import asin from ngraph.ops import atan from ngraph.ops import avg_pool +from ngraph.ops import batch_norm from ngraph.ops import broadcast from ngraph.ops import ceiling from ngraph.ops import ceiling as ceil @@ -35,7 +36,9 @@ from ngraph.ops import dot from ngraph.ops import equal from ngraph.ops import exp +from ngraph.ops import function_call from ngraph.ops import floor +from ngraph.ops import get_output_element from ngraph.ops import greater from ngraph.ops import greater_eq from ngraph.ops import less diff --git a/python/ngraph/ops.py b/python/ngraph/ops.py index d2676a186e7..f4d7b29bbb4 100644 --- a/python/ngraph/ops.py +++ b/python/ngraph/ops.py @@ -20,11 +20,12 @@ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Node, NodeVector, \ Shape, Strides -from ngraph.impl.op import Abs, Acos, Add, Asin, Atan, AvgPool, Broadcast, Ceiling, Concat, \ - Constant, Convert, Convolution, Cos, Cosh, Divide, Dot, Equal, Exp, Floor, Greater, GreaterEq,\ - Less, LessEq, Log, Max, Maximum, MaxPool, Min, Minimum, Multiply, Negative, Not, NotEqual, \ - OneHot, Pad, Parameter, Product, Power, Relu, ReplaceSlice, Reshape, Reverse, Select, \ - Sign, Sin, Sinh, Slice, Softmax, Sqrt, Subtract, Sum, Tan, Tanh +from ngraph.impl.op import Abs, Acos, Add, Asin, Atan, AvgPool, BatchNorm, Broadcast, Ceiling,\ + Concat, Constant, Convert, Convolution, Cos, Cosh, Divide, Dot, Equal, Exp, Floor, \ + FunctionCall, GetOutputElement, Greater, GreaterEq, Less, LessEq, Log, Max, Maximum, MaxPool, \ + Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Pad, Parameter, Product, Power, Relu, \ + ReplaceSlice, Reshape, Reverse, Select, Sign, Sin, Sinh, Slice, Softmax, Sqrt, Subtract, Sum, \ + Tan, Tanh from typing import Iterable, List @@ -761,3 +762,33 @@ def reverse(node, reversed_axes, name=None): # type: (Node, List[int], str) -> :return: The new node with reversed axes. """ return Reverse(node, AxisSet(reversed_axes)) + + +@nameable_op +def batch_norm(eps, # type: float + gamma, # type: Node + beta, # type: Node + data, # type: Node + mean=None, # type: Node + variance=None, # type: Node + training=False, # type: bool + name=None, # type: str + ): + # type: (...) -> Node + """Return batch normalization node.""" + if mean is None and variance is None: + return BatchNorm(eps, gamma, beta, data) + else: + return BatchNorm(eps, gamma, beta, data, mean, variance, training) + + +@nameable_op +def function_call(function_to_call, args): # type: (Node, NodeVector) -> Node + """Return Function call op.""" + return FunctionCall(function_to_call, args) + + +@nameable_op +def get_output_element(data, index): # type: (Node, int) -> Node + """Return the `n`th element of the input tuple.""" + return GetOutputElement(data, index) diff --git a/python/pyngraph/ops/batch_norm.cpp b/python/pyngraph/ops/batch_norm.cpp index a9f1b09ea7f..ce9141d452e 100644 --- a/python/pyngraph/ops/batch_norm.cpp +++ b/python/pyngraph/ops/batch_norm.cpp @@ -33,6 +33,14 @@ void regclass_pyngraph_op_BatchNorm(py::module m) const std::shared_ptr&, const std::shared_ptr&, const std::shared_ptr&>()); + + batch_norm.def(py::init&, + const std::shared_ptr&, + const std::shared_ptr&, + const std::shared_ptr&, + const std::shared_ptr&, + bool&>()); } void regclass_pyngraph_op_BatchNormBackprop(py::module m) diff --git a/python/test/ngraph/test_basic.py b/python/test/ngraph/test_basic.py index 9a7c98fbd17..1f5aaebae3a 100644 --- a/python/test/ngraph/test_basic.py +++ b/python/test/ngraph/test_basic.py @@ -19,6 +19,7 @@ import ngraph as ng from test.ngraph.util import get_runtime, run_op_node +from ngraph.impl import Function, NodeVector @pytest.mark.parametrize('dtype', [np.float32, np.float64, @@ -48,6 +49,26 @@ def test_simple_computation_on_ndarrays(dtype): assert np.allclose(result, np.array([[630, 704], [782, 864]], dtype=dtype)) +def test_function_call(): + runtime = get_runtime() + dtype = int + shape = [2, 2] + parameter_a = ng.parameter(shape, dtype=dtype, name='A') + parameter_b = ng.parameter(shape, dtype=dtype, name='B') + parameter_c = ng.parameter(shape, dtype=dtype, name='C') + parameter_list = [parameter_a, parameter_b, parameter_c] + ops = ((parameter_a + parameter_b) * parameter_c) + func = Function(NodeVector([ops]), parameter_list, 'addmul') + fc = ng.function_call(func, NodeVector(parameter_list)) + computation = runtime.computation(fc, parameter_a, parameter_b, parameter_c) + + value_a = np.array([[1, 2], [3, 4]], dtype=dtype) + value_b = np.array([[5, 6], [7, 8]], dtype=dtype) + value_c = np.array([[9, 10], [11, 12]], dtype=dtype) + result = computation(value_a, value_b, value_c) + assert np.allclose(result, np.array([[54, 80], [110, 144]], dtype=dtype)) + + def test_serialization(): dtype = np.float32 manager_name = pytest.config.getoption('backend', default='CPU')