| 
 | 1 | +__copyright__ = """  | 
 | 2 | +Copyright (C) 2024 University of Illinois Board of Trustees  | 
 | 3 | +"""  | 
 | 4 | + | 
 | 5 | +__license__ = """  | 
 | 6 | +Permission is hereby granted, free of charge, to any person obtaining a copy  | 
 | 7 | +of this software and associated documentation files (the "Software"), to deal  | 
 | 8 | +in the Software without restriction, including without limitation the rights  | 
 | 9 | +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell  | 
 | 10 | +copies of the Software, and to permit persons to whom the Software is  | 
 | 11 | +furnished to do so, subject to the following conditions:  | 
 | 12 | +
  | 
 | 13 | +The above copyright notice and this permission notice shall be included in  | 
 | 14 | +all copies or substantial portions of the Software.  | 
 | 15 | +
  | 
 | 16 | +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR  | 
 | 17 | +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,  | 
 | 18 | +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE  | 
 | 19 | +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER  | 
 | 20 | +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,  | 
 | 21 | +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN  | 
 | 22 | +THE SOFTWARE.  | 
 | 23 | +"""  | 
 | 24 | +from functools import partial, reduce  | 
 | 25 | + | 
 | 26 | +import cupy as cp  # type: ignore[import-untyped]  | 
 | 27 | + | 
 | 28 | +from arraycontext.container import is_array_container  | 
 | 29 | +from arraycontext.container.traversal import (  | 
 | 30 | +    multimap_reduce_array_container, rec_map_array_container,  | 
 | 31 | +    rec_map_reduce_array_container, rec_multimap_array_container,  | 
 | 32 | +    rec_multimap_reduce_array_container)  | 
 | 33 | +from arraycontext.fake_numpy import (  | 
 | 34 | +    BaseFakeNumpyLinalgNamespace, BaseFakeNumpyNamespace)  | 
 | 35 | + | 
 | 36 | + | 
 | 37 | +class CupyFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):  | 
 | 38 | +    # Everything is implemented in the base class for now.  | 
 | 39 | +    pass  | 
 | 40 | + | 
 | 41 | + | 
 | 42 | +_NUMPY_UFUNCS = {"abs", "sin", "cos", "tan", "arcsin", "arccos", "arctan",  | 
 | 43 | +                 "sinh", "cosh", "tanh", "exp", "log", "log10", "isnan",  | 
 | 44 | +                 "sqrt", "concatenate", "transpose",  | 
 | 45 | +                 "ones_like", "maximum", "minimum", "where", "conj", "arctan2",  | 
 | 46 | +                 }  | 
 | 47 | + | 
 | 48 | + | 
 | 49 | +class CupyFakeNumpyNamespace(BaseFakeNumpyNamespace):  | 
 | 50 | +    """  | 
 | 51 | +    A :mod:`numpy` mimic for :class:`CupyArrayContext`.  | 
 | 52 | +    """  | 
 | 53 | +    def _get_fake_numpy_linalg_namespace(self):  | 
 | 54 | +        return CupyFakeNumpyLinalgNamespace(self._array_context)  | 
 | 55 | + | 
 | 56 | +    def __getattr__(self, name):  | 
 | 57 | + | 
 | 58 | +        if name in _NUMPY_UFUNCS:  | 
 | 59 | +            from functools import partial  | 
 | 60 | +            return partial(rec_multimap_array_container,  | 
 | 61 | +                           getattr(cp, name))  | 
 | 62 | + | 
 | 63 | +        raise NotImplementedError  | 
 | 64 | + | 
 | 65 | +    def sum(self, a, axis=None, dtype=None):  | 
 | 66 | +        return rec_map_reduce_array_container(sum, partial(cp.sum,  | 
 | 67 | +                                                           axis=axis,  | 
 | 68 | +                                                           dtype=dtype),  | 
 | 69 | +                                              a)  | 
 | 70 | + | 
 | 71 | +    def min(self, a, axis=None):  | 
 | 72 | +        return rec_map_reduce_array_container(  | 
 | 73 | +                partial(reduce, cp.minimum), partial(cp.amin, axis=axis), a)  | 
 | 74 | + | 
 | 75 | +    def max(self, a, axis=None):  | 
 | 76 | +        return rec_map_reduce_array_container(  | 
 | 77 | +                partial(reduce, cp.maximum), partial(cp.amax, axis=axis), a)  | 
 | 78 | + | 
 | 79 | +    def stack(self, arrays, axis=0):  | 
 | 80 | +        return rec_multimap_array_container(  | 
 | 81 | +                lambda *args: cp.stack(args, axis=axis),  | 
 | 82 | +                *arrays)  | 
 | 83 | + | 
 | 84 | +    def broadcast_to(self, array, shape):  | 
 | 85 | +        return rec_map_array_container(partial(cp.broadcast_to, shape=shape), array)  | 
 | 86 | + | 
 | 87 | +    # {{{ relational operators  | 
 | 88 | + | 
 | 89 | +    def equal(self, x, y):  | 
 | 90 | +        return rec_multimap_array_container(cp.equal, x, y)  | 
 | 91 | + | 
 | 92 | +    def not_equal(self, x, y):  | 
 | 93 | +        return rec_multimap_array_container(cp.not_equal, x, y)  | 
 | 94 | + | 
 | 95 | +    def greater(self, x, y):  | 
 | 96 | +        return rec_multimap_array_container(cp.greater, x, y)  | 
 | 97 | + | 
 | 98 | +    def greater_equal(self, x, y):  | 
 | 99 | +        return rec_multimap_array_container(cp.greater_equal, x, y)  | 
 | 100 | + | 
 | 101 | +    def less(self, x, y):  | 
 | 102 | +        return rec_multimap_array_container(cp.less, x, y)  | 
 | 103 | + | 
 | 104 | +    def less_equal(self, x, y):  | 
 | 105 | +        return rec_multimap_array_container(cp.less_equal, x, y)  | 
 | 106 | + | 
 | 107 | +    # }}}  | 
 | 108 | + | 
 | 109 | +    def ravel(self, a, order="C"):  | 
 | 110 | +        return rec_map_array_container(partial(cp.ravel, order=order), a)  | 
 | 111 | + | 
 | 112 | +    def vdot(self, x, y, dtype=None):  | 
 | 113 | +        if dtype is not None:  | 
 | 114 | +            raise NotImplementedError("only 'dtype=None' supported.")  | 
 | 115 | + | 
 | 116 | +        return rec_multimap_reduce_array_container(sum, cp.vdot, x, y)  | 
 | 117 | + | 
 | 118 | +    def any(self, a):  | 
 | 119 | +        return rec_map_reduce_array_container(partial(reduce, cp.logical_or),  | 
 | 120 | +                                              lambda subary: cp.any(subary), a)  | 
 | 121 | + | 
 | 122 | +    def all(self, a):  | 
 | 123 | +        return rec_map_reduce_array_container(partial(reduce, cp.logical_and),  | 
 | 124 | +                                              lambda subary: cp.all(subary), a)  | 
 | 125 | + | 
 | 126 | +    def array_equal(self, a, b):  | 
 | 127 | +        if type(a) is not type(b):  | 
 | 128 | +            return False  | 
 | 129 | +        elif not is_array_container(a):  | 
 | 130 | +            if a.shape != b.shape:  | 
 | 131 | +                return False  | 
 | 132 | +            else:  | 
 | 133 | +                return cp.all(cp.equal(a, b))  | 
 | 134 | +        else:  | 
 | 135 | +            try:  | 
 | 136 | +                return multimap_reduce_array_container(partial(reduce,  | 
 | 137 | +                                                           cp.logical_and),  | 
 | 138 | +                                                   self.array_equal, a, b)  | 
 | 139 | +            except TypeError:  | 
 | 140 | +                return True  | 
 | 141 | + | 
 | 142 | +    def zeros_like(self, ary):  | 
 | 143 | +        return rec_multimap_array_container(cp.zeros_like, ary)  | 
 | 144 | + | 
 | 145 | +    def reshape(self, a, newshape, order="C"):  | 
 | 146 | +        return rec_map_array_container(  | 
 | 147 | +                lambda ary: ary.reshape(newshape, order=order),  | 
 | 148 | +                a)  | 
 | 149 | + | 
 | 150 | +    def arange(self, *args, **kwargs):  | 
 | 151 | +        return cp.arange(*args, **kwargs)  | 
 | 152 | + | 
 | 153 | +    def linspace(self, *args, **kwargs):  | 
 | 154 | +        return cp.linspace(*args, **kwargs)  | 
 | 155 | + | 
 | 156 | +# vim: fdm=marker  | 
0 commit comments