Skip to content

Commit

Permalink
Detect potential copyto usage and improve numpy.copyto support
Browse files Browse the repository at this point in the history
Detect a[:] = expr pattern and use np.copyto(a, expr) instead.
Use memcpy when possible (through broadcast_copy improvement).
Accept moved reference.
  • Loading branch information
serge-sans-paille committed Apr 12, 2023
1 parent 5396483 commit 46da5e5
Show file tree
Hide file tree
Showing 19 changed files with 306 additions and 82 deletions.
1 change: 1 addition & 0 deletions pythran/optimizations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"""

from .constant_folding import ConstantFolding, PartialConstantFolding
from .copyto import CopyTo
from .dead_code_elimination import DeadCodeElimination
from .forward_substitution import ForwardSubstitution, PreInliningForwardSubstitution
from .iter_transformation import IterTransformation
Expand Down
91 changes: 91 additions & 0 deletions pythran/optimizations/copyto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
""" Replaces a[:] = b by a call to numpy.copyto. """

from pythran.passmanager import Transformation
from pythran.analyses.ast_matcher import ASTMatcher, AST_any
from pythran.conversion import mangle
from pythran.utils import isnum

import gast as ast
import copy


class CopyTo(Transformation):

    """
    Replaces a[:] = b by a call to numpy.copyto.

    This is a slight extension to numpy.copyto as it assumes it also supports
    string and list as first argument.

    A slice counts as "full" when its upper bound is omitted, its lower bound
    is omitted or the literal 0, and its step is omitted or the literal 1.

    >>> import gast as ast
    >>> from pythran import passmanager, backend
    >>> node = ast.parse('a[:] = b')
    >>> pm = passmanager.PassManager("test")
    >>> _, node = pm.apply(CopyTo, node)
    >>> print(pm.dump(backend.Python, node))
    import numpy as __pythran_import_numpy
    __pythran_import_numpy.copyto(a, b)
    >>> node = ast.parse('a[:] = b[:]')
    >>> pm = passmanager.PassManager("test")
    >>> _, node = pm.apply(CopyTo, node)
    >>> print(pm.dump(backend.Python, node))
    import numpy as __pythran_import_numpy
    __pythran_import_numpy.copyto(a, b)
    """

    def isNone(self, node):
        """Tell whether `node` is absent or the literal ``None``."""
        if node is None:
            return True
        return isinstance(node, ast.Constant) and node.value is None

    @staticmethod
    def _is_int_constant(node, expected):
        """Tell whether `node` is an integer literal equal to `expected`.

        AST nodes never compare equal to plain Python numbers, so the value
        carried by the Constant node has to be inspected explicitly.
        """
        return (isinstance(node, ast.Constant) and
                isinstance(node.value, int) and
                node.value == expected)

    def is_full_slice(self, node):
        """Tell whether `node` is a slice selecting a whole dimension."""
        # FIXME: could accept a call to len for node.upper
        if not isinstance(node, ast.Slice):
            return False
        lower_ok = (self.isNone(node.lower) or
                    self._is_int_constant(node.lower, 0))
        upper_ok = self.isNone(node.upper)
        step_ok = (self.isNone(node.step) or
                   self._is_int_constant(node.step, 1))
        return lower_ok and upper_ok and step_ok

    def is_fully_sliced(self, node):
        """Tell whether `node` is a plain name subscripted by full slices only.

        Both ``name[:]`` and ``name[:, :, ...]`` qualify; any other subscript,
        or a subscripted value that is not a simple name, does not.
        """
        if not isinstance(node, ast.Subscript):
            return False
        if not isinstance(node.value, ast.Name):
            return False
        if self.is_full_slice(node.slice):
            return True
        elif isinstance(node.slice, ast.Tuple):
            return all(self.is_full_slice(elt) for elt in node.slice.elts)
        else:
            return False

    def visit_Module(self, node):
        """Visit the module, then prepend the mangled numpy import if needed.

        The import is only inserted when at least one assignment has been
        rewritten (``self.update`` is set by visit_Assign).
        """
        self.generic_visit(node)
        if self.update:
            import_alias = ast.alias(name='numpy', asname=mangle('numpy'))
            importIt = ast.Import(names=[import_alias])
            node.body.insert(0, importIt)
        return node

    def visit_Assign(self, node):
        """Rewrite ``a[:] = expr`` into a ``numpy.copyto(a, expr)`` statement.

        Assignments with several targets, or whose single target is not a
        fully sliced name, are returned untouched.
        """
        if len(node.targets) != 1:
            return node
        target, = node.targets
        if not self.is_fully_sliced(target):
            return node
        # A fully sliced right-hand side (``b[:]``) is replaced by the bare
        # name: copyto reads the whole source anyway.
        if self.is_fully_sliced(node.value):
            value = node.value.value
        else:
            value = node.value

        self.update = True
        return ast.Expr(
            ast.Call(
                ast.Attribute(ast.Name(mangle('numpy'), ast.Load(), None, None),
                              'copyto',
                              ast.Load()),
                [target.value, value],
                [])
        )


18 changes: 17 additions & 1 deletion pythran/pythonic/include/numpy/copyto.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,23 @@ PYTHONIC_NS_BEGIN
namespace numpy
{
template <class T, class pS, class E>
types::ndarray<T, pS> copyto(types::ndarray<T, pS> &out, E const &expr);
types::none_type copyto(types::ndarray<T, pS> &out, E const &expr);

template <class T, class pS, class E>
types::none_type copyto(types::ndarray<T, pS> &&out, E const &expr);

template <class T, class pS, class E>
types::none_type copyto(types::numpy_texpr<types::ndarray<T, pS>> &out, E const &expr);

template <class T, class pS, class E>
types::none_type copyto(types::numpy_texpr<types::ndarray<T, pS>> &&out, E const &expr);

// pythran extensions
// Generic fallback used for destinations that are not covered by the
// ndarray / numpy_texpr overloads declared above.
// Assigns `expr` to `out` indexed by a fast_contiguous_slice built from
// (0, none) and returns an empty none_type.
// NOTE(review): relies on E providing an operator[] accepting
// fast_contiguous_slice whose result is assignable from F -- confirm for every
// container this is meant to serve.
template <class E, class F>
types::none_type copyto(E &out, F const &expr) {
out[types::fast_contiguous_slice(0, types::none_type{})] = expr;
return {};
}

DEFINE_FUNCTOR(pythonic::numpy, copyto);
}
Expand Down
9 changes: 7 additions & 2 deletions pythran/pythonic/include/types/list.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ namespace types
typename std::remove_cv<typename std::remove_reference<T>::type>::type
_type;
typedef container<_type> container_type;
utils::shared_ref<container_type> data;
utils::shared_ref<container_type> _data;

template <class U>
friend class list;
Expand Down Expand Up @@ -100,6 +100,7 @@ namespace types
static const bool is_vectorizable =
types::is_vectorizable_dtype<dtype>::value &&
!std::is_same<S, slice>::value;
// NOTE(review): is_flat is computed with exactly the same expression as
// is_strided on the next line, so both flags are true together when S is
// `slice`, while is_vectorizable just above is disabled for that same case.
// Confirm whether `!std::is_same<slice, S>::value` was intended here.
static const bool is_flat = std::is_same<slice, S>::value;
static const bool is_strided = std::is_same<slice, S>::value;

using shape_t = types::array<long, value>;
Expand Down Expand Up @@ -192,7 +193,7 @@ namespace types
typename std::remove_cv<typename std::remove_reference<T>::type>::type
_type;
typedef container<_type> container_type;
utils::shared_ref<container_type> data;
utils::shared_ref<container_type> _data;

template <class U, class S>
friend class sliced_list;
Expand Down Expand Up @@ -220,6 +221,7 @@ namespace types
typedef typename utils::nested_container_value_type<list>::type dtype;
static const size_t value = utils::nested_container_depth<list>::value;
static const bool is_vectorizable = types::is_vectorizable<dtype>::value;
static const bool is_flat = true;
static const bool is_strided = false;

// constructors
Expand Down Expand Up @@ -325,6 +327,9 @@ namespace types
return fast(index);
}

// Raw pointer accessors (mutable and const): both return whatever
// `_data->data()` yields for the shared underlying container.
// NOTE(review): the declared return type is `dtype*`; presumably this is only
// meaningful when the stored element type is dtype itself (non-nested
// container) and when `_data` refers to an allocated container -- confirm.
dtype* data() { return _data->data();}
const dtype* data() const { return _data->data();}

// modifiers
template <class Tp>
void push_back(Tp &&x);
Expand Down
3 changes: 3 additions & 0 deletions pythran/pythonic/include/types/ndarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ namespace types
template <class T, class pS>
struct ndarray {
static const bool is_vectorizable = types::is_vectorizable<T>::value;
static const bool is_flat = true;
static const bool is_strided = false;

/* types */
Expand Down Expand Up @@ -646,6 +647,8 @@ namespace types
flat_iterator fend();

/* member functions */
// Raw pointer accessors (mutable and const): return the `buffer` member
// unchanged.
T* data() { return buffer;}
T const* data() const { return buffer;}
long flat_size() const;
bool may_overlap(ndarray const &) const;

Expand Down
2 changes: 2 additions & 0 deletions pythran/pythonic/include/types/numpy_broadcast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ namespace types
using broadcast_or_broadcasted_t =
typename broadcast_or_broadcasted<Tp, is_dtype<Tp>::value>::type;
static const bool is_vectorizable = true;
static const bool is_flat = false;
static const bool is_strided = false;
using dtype = typename std::remove_reference<T>::type::dtype;
using value_type = typename std::remove_reference<T>::type::value_type;
Expand Down Expand Up @@ -279,6 +280,7 @@ namespace types
// always)
using dtype = typename broadcast_dtype<T, B>::type;
static const bool is_vectorizable = types::is_vectorizable<dtype>::value;
static const bool is_flat = false;
static const bool is_strided = false;
using value_type = dtype;
using const_iterator = const_broadcast_iterator<dtype>;
Expand Down
1 change: 1 addition & 0 deletions pythran/pythonic/include/types/numpy_expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,7 @@ namespace types
Args>::type>::type::dtype>::value...>::value &&
types::is_vector_op<
Op, typename std::remove_reference<Args>::type::dtype...>::value;
static const bool is_flat = false;
static const bool is_strided =
utils::any_of<std::remove_reference<Args>::type::is_strided...>::value;

Expand Down
6 changes: 6 additions & 0 deletions pythran/pythonic/include/types/numpy_gexpr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,9 @@ namespace types
std::is_same<contiguous_normalized_slice,
typename std::tuple_element<
sizeof...(S) - 1, std::tuple<S...>>::type>::value);
static const bool is_flat =
std::remove_reference<Arg>::type::is_flat && value == 1 &&
utils::all_of<std::is_same<contiguous_normalized_slice, S>::value...>::value;
static const bool is_strided =
std::remove_reference<Arg>::type::is_strided ||
(((sizeof...(S) - count_long<S...>::value) == value) &&
Expand Down Expand Up @@ -874,6 +877,9 @@ namespace types

explicit operator bool() const;


// Raw pointer accessors (mutable and const): return the `buffer` member
// unchanged.
// NOTE(review): presumably `buffer` already points at the first element
// selected by this expression -- confirm against the constructor.
dtype* data() { return buffer;}
const dtype* data() const { return buffer;}
long flat_size() const;
long size() const;
ndarray<dtype, shape_t> copy() const
Expand Down
5 changes: 5 additions & 0 deletions pythran/pythonic/include/types/numpy_iexpr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ namespace types
static constexpr size_t value = std::remove_reference<Arg>::type::value - 1;
static const bool is_vectorizable =
std::remove_reference<Arg>::type::is_vectorizable;
static const bool is_flat =
std::remove_reference<Arg>::type::is_flat;
using dtype = typename std::remove_reference<Arg>::type::dtype;
using value_type =
typename std::remove_reference<decltype(numpy_iexpr_helper<value>::get(
Expand Down Expand Up @@ -334,6 +336,9 @@ namespace types
return (*this)[std::get<0>(index)];
}

// Raw pointer accessors (mutable and const): return the `buffer` member
// unchanged.
// NOTE(review): presumably `buffer` already points at the first element
// of the indexed sub-array -- confirm against the constructor.
dtype* data() { return buffer;}
const dtype* data() const { return buffer;}

private:
/* compute the buffer offset, returning the offset between the
* first element of the iexpr and the start of the buffer.
Expand Down
1 change: 1 addition & 0 deletions pythran/pythonic/include/types/numpy_texpr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ namespace types
struct numpy_texpr_2 {
static_assert(E::value == 2, "texpr only implemented for matrices");
static const bool is_vectorizable = false;
static const bool is_flat = false;
static const bool is_strided = true;
using Arg = E;

Expand Down
1 change: 1 addition & 0 deletions pythran/pythonic/include/types/numpy_vexpr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ namespace types

static constexpr size_t value = T::value;
static const bool is_vectorizable = false;
static const bool is_flat = false;
using dtype = typename dtype_of<T>::type;
using value_type = T;
static constexpr bool is_strided = T::is_strided;
Expand Down
1 change: 1 addition & 0 deletions pythran/pythonic/include/types/tuple.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,7 @@ namespace types
static const size_t value =
utils::nested_container_depth<array_base>::value;
static const bool is_vectorizable = true;
static const bool is_flat = true;
static const bool is_strided = false;

// flat_size implementation
Expand Down
56 changes: 53 additions & 3 deletions pythran/pythonic/numpy/copyto.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define PYTHONIC_NUMPY_COPYTO_HPP

#include "pythonic/include/numpy/copyto.hpp"
#include "pythonic//numpy/asarray.hpp"

#include "pythonic/utils/functor.hpp"
#include "pythonic/types/ndarray.hpp"
Expand All @@ -10,10 +11,59 @@ PYTHONIC_NS_BEGIN
namespace numpy
{
// Copy `expr` into the ndarray `out` and return an empty none_type.
// When may_overlap(out, expr) holds, `expr` is first passed through asarray()
// and that result is what gets copied (presumably to decouple the source from
// `out`'s storage -- confirm against asarray's contract); otherwise `expr` is
// copied directly. The copy itself is delegated to utils::broadcast_copy.
// NOTE(review): this span comes from a rendered diff. The first of the two
// signature lines and the `out[...] = expr;` / `return out;` pair appear to be
// the pre-change version displayed next to its replacement -- confirm before
// treating the text below as compilable.
template <class T, class pS, class E>
types::ndarray<T, pS> copyto(types::ndarray<T, pS> &out, E const &expr)
types::none_type copyto(types::ndarray<T, pS> &out, E const &expr)
{
out[types::contiguous_slice(0, types::none_type{})] = expr;
return out;
using out_type = types::ndarray<T, pS>;
if (may_overlap(out, expr)) {
auto aexpr = asarray(expr);
// Template arguments: destination type, source type, destination rank,
// rank difference between destination and E, and a flag that is true only
// when the destination is vectorizable, both dtypes are identical and E is
// vectorizable.
// NOTE(review): the flag and the rank difference are computed from E even
// though `aexpr` is the object being copied -- confirm this is intended.
utils::broadcast_copy < out_type &, decltype(aexpr), out_type::value,
(int)out_type::value - (int)utils::dim_of<E>::value,
out_type::is_vectorizable &&
std::is_same<typename out_type::dtype, typename types::dtype_of<E>::type>::value &&
types::is_vectorizable<E>::value > (out, aexpr);
}
else {
// Same template arguments as above, with E itself as the source type.
utils::broadcast_copy < out_type &, E, out_type::value,
(int)out_type::value - (int)utils::dim_of<E>::value,
out_type::is_vectorizable &&
std::is_same<typename out_type::dtype, typename types::dtype_of<E>::type>::value &&
types::is_vectorizable<E>::value > (out, expr);
}
return {};
}

// Overload accepting a moved (rvalue) ndarray destination.
// Inside the body `out` is a named variable, hence an lvalue, so the call
// below resolves to an lvalue-taking copyto overload rather than recursing.
template <class T, class pS, class E>
types::none_type copyto(types::ndarray<T, pS> &&out, E const &expr)
{
return copyto(out, expr);
}

// Copy `expr` into the transposed-array expression `out` and return an empty
// none_type. If may_overlap(out, expr) holds, the source is first passed
// through asarray() and that result is copied; otherwise `expr` is copied
// as is. The copy itself is delegated to utils::broadcast_copy.
template <class T, class pS, class E>
types::none_type copyto(types::numpy_texpr<types::ndarray<T, pS>> &out, E const &expr)
{
  using out_type = types::numpy_texpr<types::ndarray<T, pS>>;
  using src_dtype = typename types::dtype_of<E>::type;

  // Compile-time arguments shared by both broadcast_copy instantiations.
  // Rank difference between the destination and the source expression type.
  constexpr int rank_gap =
      static_cast<int>(out_type::value) - static_cast<int>(utils::dim_of<E>::value);
  // True only when the destination is vectorizable, both sides share the same
  // dtype and the source expression type is vectorizable.
  constexpr bool vector_form =
      out_type::is_vectorizable &&
      std::is_same<typename out_type::dtype, src_dtype>::value &&
      types::is_vectorizable<E>::value;

  if (may_overlap(out, expr)) {
    auto aexpr = asarray(expr);
    utils::broadcast_copy<out_type &, decltype(aexpr), out_type::value,
                          rank_gap, vector_form>(out, aexpr);
  }
  else {
    utils::broadcast_copy<out_type &, E, out_type::value,
                          rank_gap, vector_form>(out, expr);
  }
  return {};
}

// Overload accepting a moved (rvalue) transposed-array destination.
// Inside the body `out` is a named variable, hence an lvalue, so the call
// below resolves to an lvalue-taking copyto overload rather than recursing.
template <class T, class pS, class E>
types::none_type copyto(types::numpy_texpr<types::ndarray<T, pS>> &&out, E const &expr)
{
return copyto(out, expr);
}
}
PYTHONIC_NS_END
Expand Down
Loading

0 comments on commit 46da5e5

Please sign in to comment.