Skip to content

Commit

Permalink
Merge branch 'master' into port-polyeval-cpu
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster authored Nov 20, 2023
2 parents 8cf2d79 + 4232a2e commit ca7d0a5
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 12 deletions.
4 changes: 2 additions & 2 deletions mops/include/mops/sap.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#ifndef MOPS_OUTER_PRODUCT_SCATTER_ADD_H
#define MOPS_OUTER_PRODUCT_SCATTER_ADD_H
#ifndef MOPS_SPARSE_ACCUMULATION_OF_PRODUCTS_H
#define MOPS_SPARSE_ACCUMULATION_OF_PRODUCTS_H

#include "mops/exports.h"
#include "mops/tensor.h"
Expand Down
4 changes: 2 additions & 2 deletions mops/include/mops/sap.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#ifndef MOPS_OUTER_PRODUCT_SCATTER_ADD_HPP
#define MOPS_OUTER_PRODUCT_SCATTER_ADD_HPP
#ifndef MOPS_SPARSE_ACCUMULATION_OF_PRODUCTS_HPP
#define MOPS_SPARSE_ACCUMULATION_OF_PRODUCTS_HPP

#include <cstddef>
#include <cstdint>
Expand Down
1 change: 0 additions & 1 deletion python/mops/src/mops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from .hpe import homogeneous_polynomial_evaluation # noqa
from .sap import sparse_accumulation_of_products # noqa
from .opsa import outer_product_scatter_add # noqa

47 changes: 46 additions & 1 deletion python/mops/src/mops/_c_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class mops_tensor_2d_f64_t(ctypes.Structure):

class mops_tensor_1d_f32_t(ctypes.Structure):
_fields_ = [
("data", ctypes.POINTER(ctypes.c_double)),
("data", ctypes.POINTER(ctypes.c_float)),
("shape", ctypes.ARRAY(ctypes.c_int64, 1)),
]

Expand Down Expand Up @@ -159,3 +159,48 @@ def setup_functions(lib):
mops_tensor_1d_i32_t,
]
lib.mops_cuda_outer_product_scatter_add_f64.restype = _check_status

# sparse_accumulation_of_products
lib.mops_sparse_accumulation_of_products_f32.argtypes = [
mops_tensor_2d_f32_t,
mops_tensor_2d_f32_t,
mops_tensor_2d_f32_t,
mops_tensor_1d_f32_t,
mops_tensor_1d_i32_t,
mops_tensor_1d_i32_t,
mops_tensor_1d_i32_t,
]
lib.mops_sparse_accumulation_of_products_f32.restype = _check_status

lib.mops_sparse_accumulation_of_products_f64.argtypes = [
mops_tensor_2d_f64_t,
mops_tensor_2d_f64_t,
mops_tensor_2d_f64_t,
mops_tensor_1d_f64_t,
mops_tensor_1d_i32_t,
mops_tensor_1d_i32_t,
mops_tensor_1d_i32_t,
]
lib.mops_sparse_accumulation_of_products_f64.restype = _check_status

lib.mops_cuda_sparse_accumulation_of_products_f32.argtypes = [
mops_tensor_2d_f32_t,
mops_tensor_2d_f32_t,
mops_tensor_2d_f32_t,
mops_tensor_1d_f32_t,
mops_tensor_1d_i32_t,
mops_tensor_1d_i32_t,
mops_tensor_1d_i32_t,
]
lib.mops_cuda_sparse_accumulation_of_products_f32.restype = _check_status

lib.mops_cuda_sparse_accumulation_of_products_f64.argtypes = [
mops_tensor_2d_f64_t,
mops_tensor_2d_f64_t,
mops_tensor_2d_f64_t,
mops_tensor_1d_f64_t,
mops_tensor_1d_i32_t,
mops_tensor_1d_i32_t,
mops_tensor_1d_i32_t,
]
lib.mops_cuda_sparse_accumulation_of_products_f64.restype = _check_status
5 changes: 4 additions & 1 deletion python/mops/tests/opsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,8 @@ def test_opsa_no_neighbors():

def test_opsa_wrong_type():

with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match="A must be a 2D array in opsa, got a 1D array"
):
opsa(np.array([1]), 2, 3, 4)
7 changes: 2 additions & 5 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,17 @@ deps =
wheel
cmake
twine
build

allowlist_externals =
bash

commands =
# check building sdist and wheels from a checkout
python setup.py sdist
python setup.py bdist_wheel
python -m build . --outdir dist
twine check dist/*.tar.gz
twine check dist/*.whl

# check building wheels from the sdist
bash -c "python -m pip wheel --verbose dist/mops-*.tar.gz -w dist/test"

; [testenv:torch-cxx-tests]
; package = skip
; passenv = *
Expand Down

0 comments on commit ca7d0a5

Please sign in to comment.