Skip to content

Commit

Permalink
Add Python bindings to mops-torch
Browse files Browse the repository at this point in the history
  • Loading branch information
Luthaf committed Nov 21, 2023
1 parent 2046740 commit 207fb89
Show file tree
Hide file tree
Showing 16 changed files with 305 additions and 15 deletions.
18 changes: 7 additions & 11 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,10 @@ jobs:
python -m pip install --upgrade pip
python -m pip install tox
- name: run C++ build and tests
run: |
mkdir build
cd build
cmake -DMOPS_TESTS=ON ../mops/
cmake --build .
ctest
- name: run Python tests
run: python -m tox
- name: run tests
run: tox
env:
PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu

# check that we can build Python wheels on any Python version
python-build:
Expand All @@ -57,7 +51,9 @@ jobs:
- name: install python dependencies
run: |
python -m pip install --upgrade pip
python -m pip install tox wheel
python -m pip install tox
- name: python build tests
run: tox -e build-python
env:
PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu
5 changes: 3 additions & 2 deletions mops-torch/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ if (${MOPS_TORCH_MAIN_PROJECT})
endif()

find_package(Torch 1.11 REQUIRED)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../mops mops)

set(BUILD_SHARED_LIBS OFF)
add_subdirectory(mops EXCLUDE_FROM_ALL)

add_library(mops_torch SHARED
"src/register.cpp"
Expand Down Expand Up @@ -97,7 +99,6 @@ install(TARGETS mops_torch
)

install(DIRECTORY ${PROJECT_SOURCE_DIR}/include/ DESTINATION ${INCLUDE_INSTALL_DIR})
install(DIRECTORY ${PROJECT_BINARY_DIR}/include/ DESTINATION ${INCLUDE_INSTALL_DIR})

# Install files to find mops in CMake projects
configure_file(
Expand Down
1 change: 1 addition & 0 deletions mops-torch/mops
7 changes: 6 additions & 1 deletion mops-torch/src/register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,10 @@
#include "mops/torch/opsa.hpp"

TORCH_LIBRARY(mops, m) {
m.def("outer_product_scatter_add", mops_torch::outer_product_scatter_add);
m.def(
"outer_product_scatter_add("
"Tensor A, Tensor B, Tensor indices_output, int output_size"
") -> Tensor",
mops_torch::outer_product_scatter_add
);
}
1 change: 1 addition & 0 deletions python/mops-torch/LICENSE
5 changes: 5 additions & 0 deletions python/mops-torch/MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
include pyproject.toml
include LICENSE
include VERSION

recursive-include lib *
3 changes: 3 additions & 0 deletions python/mops-torch/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Mops-torch

This is the TorchScript version of mops
1 change: 1 addition & 0 deletions python/mops-torch/VERSION
1 change: 1 addition & 0 deletions python/mops-torch/lib
62 changes: 62 additions & 0 deletions python/mops-torch/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
[project]
name = "mops_torch"
dynamic = ["version"]
requires-python = ">=3.7"

readme = "README.md"
license = {text = "BSD-3-Clause"}
description = "" # TODO
authors = [
# TODO
]

dependencies = [
"torch >= 1.11",
]

keywords = [] # TODO
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: BSD License",
"Operating System :: POSIX",
"Operating System :: MacOS :: MacOS X",
"Operating System :: Microsoft :: Windows",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Bio-Informatics",
"Topic :: Scientific/Engineering :: Chemistry",
"Topic :: Scientific/Engineering :: Physics",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
]

[project.urls]
# homepage = "TODO"
# documentation = "TODO"
repository = "https://github.com/lab-cosmo/mops"
# changelog = "TODO"

### ======================================================================== ###

[build-system]
requires = [
"setuptools >=68",
"cmake",
"torch >= 1.11",
]
build-backend = "setuptools.build_meta"

[tool.setuptools]
zip-safe = false

[tool.setuptools.packages.find]
where = ["src"]
include = ["mops*"]
namespaces = true

### ======================================================================== ###
[tool.pytest.ini_options]
python_files = ["*.py"]
testpaths = ["tests"]
84 changes: 84 additions & 0 deletions python/mops-torch/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import os
import subprocess
import sys

from setuptools import Extension, setup
from setuptools.command.bdist_egg import bdist_egg
from setuptools.command.build_ext import build_ext

ROOT = os.path.realpath(os.path.dirname(__file__))


class cmake_ext(build_ext):
"""Build the native library using cmake"""

def run(self):
import torch

source_dir = os.path.join(ROOT, "lib")
build_dir = os.path.join(ROOT, "build")
install_dir = os.path.join(os.path.realpath(self.build_lib), "mops", "torch")

os.makedirs(build_dir, exist_ok=True)

cmake_options = [
"-DCMAKE_BUILD_TYPE=Release",
f"-DCMAKE_INSTALL_PREFIX={install_dir}",
f"-DCMAKE_PREFIX_PATH={torch.utils.cmake_prefix_path}",
]

subprocess.run(
["cmake", source_dir, *cmake_options],
cwd=build_dir,
check=True,
)
subprocess.run(
[
"cmake",
"--build",
build_dir,
"--config",
"Release",
"--target",
"install",
],
check=True,
)


class bdist_egg_disabled(bdist_egg):
"""Disabled version of bdist_egg
Prevents setup.py install performing setuptools' default easy_install,
which it should never ever do.
"""

def run(self):
sys.exit(
"Aborting implicit building of eggs.\nUse `pip install .` or "
"`python -m build --wheel . && pip install dist/mops_torch-*.whl` "
"to install from source."
)


if __name__ == "__main__":
with open(os.path.join(ROOT, "VERSION")) as fd:
version = fd.read().strip()

setup(
version=version,
ext_modules=[
Extension(name="mops_torch", sources=[]),
],
cmdclass={
"build_ext": cmake_ext,
"bdist_egg": bdist_egg if "bdist_egg" in sys.argv else bdist_egg_disabled,
},
package_data={
"mops-torch": [
"mops/torch/bin/*",
"mops/torch/lib/*",
"mops/torch/include/*",
]
},
)
7 changes: 7 additions & 0 deletions python/mops-torch/src/mops/torch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import torch

from ._c_lib import _lib_path

torch.ops.load_library(_lib_path())

outer_product_scatter_add = torch.ops.mops.outer_product_scatter_add
61 changes: 61 additions & 0 deletions python/mops-torch/src/mops/torch/_c_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import os
import sys

_HERE = os.path.realpath(os.path.dirname(__file__))

# TODO: check that mops_torch was compiled for a compatible version of torch


def _lib_path():
if sys.platform.startswith("darwin"):
path = os.path.join(_HERE, "lib", "libmops_torch.dylib")
windows = False
elif sys.platform.startswith("linux"):
path = os.path.join(_HERE, "lib", "libmops_torch.so")
windows = False
elif sys.platform.startswith("win"):
path = os.path.join(_HERE, "bin", "mops_torch.dll")
windows = True
else:
raise ImportError("Unknown platform. Please edit this file")

if os.path.isfile(path):
if windows:
_check_dll(path)
return path

raise ImportError("Could not find mops_torch shared library at " + path)


def _check_dll(path):
"""
Check if the DLL pointer size matches Python (32-bit or 64-bit)
"""
import platform
import struct

IMAGE_FILE_MACHINE_I386 = 332
IMAGE_FILE_MACHINE_AMD64 = 34404

machine = None
with open(path, "rb") as fd:
header = fd.read(2).decode(encoding="utf-8", errors="strict")
if header != "MZ":
raise ImportError(path + " is not a DLL")
else:
fd.seek(60)
header = fd.read(4)
header_offset = struct.unpack("<L", header)[0]
fd.seek(header_offset + 4)
header = fd.read(2)
machine = struct.unpack("<H", header)[0]

arch = platform.architecture()[0]
if arch == "32bit":
if machine != IMAGE_FILE_MACHINE_I386:
raise ImportError("Python is 32-bit, but this DLL is not")
elif arch == "64bit":
if machine != IMAGE_FILE_MACHINE_AMD64:
raise ImportError("Python is 64-bit, but this DLL is not")
else:
raise ImportError("Could not determine pointer size of Python")
47 changes: 47 additions & 0 deletions python/mops-torch/tests/opsa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import torch

import mops.torch
from mops import reference_implementations as ref


torch.manual_seed(0xDEADBEEF)


def test_opsa():
print(mops.torch)
print(dir(mops.torch))

A = torch.rand(100, 20)
B = torch.rand(100, 5)

output_size = 10

indices = torch.sort(
torch.randint(output_size, size=(100,), dtype=torch.int32)
).values
# substitute all 1s by 2s so as to test the no-neighbor case
indices[indices == 1] = 2

reference = torch.tensor(
ref.outer_product_scatter_add(
A.numpy(), B.numpy(), indices.numpy(), output_size
)
)
actual = mops.torch.outer_product_scatter_add(A, B, indices, output_size)
assert torch.allclose(reference, actual)


def test_opsa_grad():
A = torch.rand(100, 20, dtype=torch.float64, requires_grad=True)
B = torch.rand(100, 5, dtype=torch.float64, requires_grad=True)

output_size = 10
indices = torch.sort(
torch.randint(output_size, size=(100,), dtype=torch.int32)
).values

assert torch.autograd.gradcheck(
mops.torch.outer_product_scatter_add,
(A, B, indices, output_size),
fast_mode=True,
)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def run(self):
os.makedirs(build_dir, exist_ok=True)

cmake_options = [
"-DCMAKE_BUILD_TYPE=Debug",
"-DCMAKE_BUILD_TYPE=Release",
"-DBUILD_SHARED_LIBS=ON",
f"-DCMAKE_INSTALL_PREFIX={install_dir}",
]
Expand Down
15 changes: 15 additions & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ min_version = 4.0
# execute `tox` in the command-line without anything else
envlist =
python-tests
torch-tests
cxx-tests
torch-cxx-tests

Expand All @@ -15,6 +16,18 @@ deps =
commands =
pytest --import-mode=append {posargs}


[testenv:torch-tests]
passenv = *
deps =
pytest
torch

changedir = python/mops-torch
commands =
pip install .
pytest --import-mode=append --assert=plain {posargs}

[testenv:cxx-tests]
package = skip
passenv = *
Expand Down Expand Up @@ -62,5 +75,7 @@ allowlist_externals =
commands =
# check building sdist and wheels from a checkout
python -m build . --outdir dist
python -m build python/mops-torch --outdir dist

twine check dist/*.tar.gz
twine check dist/*.whl

0 comments on commit 207fb89

Please sign in to comment.