-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
16 changed files
with
305 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../mops |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../../LICENSE |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
include pyproject.toml | ||
include LICENSE | ||
include VERSION | ||
|
||
recursive-include lib * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Mops-torch | ||
|
||
This is the TorchScript version of mops |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../../mops-torch/VERSION |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../../mops-torch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
[project] | ||
name = "mops_torch" | ||
dynamic = ["version"] | ||
requires-python = ">=3.7" | ||
|
||
readme = "README.md" | ||
license = {text = "BSD-3-Clause"} | ||
description = "" # TODO | ||
authors = [ | ||
# TODO | ||
] | ||
|
||
dependencies = [ | ||
"torch >= 1.11", | ||
] | ||
|
||
keywords = [] # TODO | ||
classifiers = [ | ||
"Development Status :: 4 - Beta", | ||
"Intended Audience :: Science/Research", | ||
"License :: OSI Approved :: BSD License", | ||
"Operating System :: POSIX", | ||
"Operating System :: MacOS :: MacOS X", | ||
"Operating System :: Microsoft :: Windows", | ||
"Programming Language :: Python", | ||
"Programming Language :: Python :: 3", | ||
"Topic :: Scientific/Engineering", | ||
"Topic :: Scientific/Engineering :: Bio-Informatics", | ||
"Topic :: Scientific/Engineering :: Chemistry", | ||
"Topic :: Scientific/Engineering :: Physics", | ||
"Topic :: Software Development :: Libraries", | ||
"Topic :: Software Development :: Libraries :: Python Modules", | ||
] | ||
|
||
[project.urls] | ||
# homepage = "TODO" | ||
# documentation = "TODO" | ||
repository = "https://github.com/lab-cosmo/mops" | ||
# changelog = "TODO" | ||
|
||
### ======================================================================== ### | ||
|
||
[build-system] | ||
requires = [ | ||
"setuptools >=68", | ||
"cmake", | ||
"torch >= 1.11", | ||
] | ||
build-backend = "setuptools.build_meta" | ||
|
||
[tool.setuptools] | ||
zip-safe = false | ||
|
||
[tool.setuptools.packages.find] | ||
where = ["src"] | ||
include = ["mops*"] | ||
namespaces = true | ||
|
||
### ======================================================================== ### | ||
[tool.pytest.ini_options] | ||
python_files = ["*.py"] | ||
testpaths = ["tests"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import os | ||
import subprocess | ||
import sys | ||
|
||
from setuptools import Extension, setup | ||
from setuptools.command.bdist_egg import bdist_egg | ||
from setuptools.command.build_ext import build_ext | ||
|
||
ROOT = os.path.realpath(os.path.dirname(__file__)) | ||
|
||
|
||
class cmake_ext(build_ext): | ||
"""Build the native library using cmake""" | ||
|
||
def run(self): | ||
import torch | ||
|
||
source_dir = os.path.join(ROOT, "lib") | ||
build_dir = os.path.join(ROOT, "build") | ||
install_dir = os.path.join(os.path.realpath(self.build_lib), "mops", "torch") | ||
|
||
os.makedirs(build_dir, exist_ok=True) | ||
|
||
cmake_options = [ | ||
"-DCMAKE_BUILD_TYPE=Release", | ||
f"-DCMAKE_INSTALL_PREFIX={install_dir}", | ||
f"-DCMAKE_PREFIX_PATH={torch.utils.cmake_prefix_path}", | ||
] | ||
|
||
subprocess.run( | ||
["cmake", source_dir, *cmake_options], | ||
cwd=build_dir, | ||
check=True, | ||
) | ||
subprocess.run( | ||
[ | ||
"cmake", | ||
"--build", | ||
build_dir, | ||
"--config", | ||
"Release", | ||
"--target", | ||
"install", | ||
], | ||
check=True, | ||
) | ||
|
||
|
||
class bdist_egg_disabled(bdist_egg): | ||
"""Disabled version of bdist_egg | ||
Prevents setup.py install performing setuptools' default easy_install, | ||
which it should never ever do. | ||
""" | ||
|
||
def run(self): | ||
sys.exit( | ||
"Aborting implicit building of eggs.\nUse `pip install .` or " | ||
"`python -m build --wheel . && pip install dist/mops_torch-*.whl` " | ||
"to install from source." | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
with open(os.path.join(ROOT, "VERSION")) as fd: | ||
version = fd.read().strip() | ||
|
||
setup( | ||
version=version, | ||
ext_modules=[ | ||
Extension(name="mops_torch", sources=[]), | ||
], | ||
cmdclass={ | ||
"build_ext": cmake_ext, | ||
"bdist_egg": bdist_egg if "bdist_egg" in sys.argv else bdist_egg_disabled, | ||
}, | ||
package_data={ | ||
"mops-torch": [ | ||
"mops/torch/bin/*", | ||
"mops/torch/lib/*", | ||
"mops/torch/include/*", | ||
] | ||
}, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
import torch | ||
|
||
from ._c_lib import _lib_path | ||
|
||
torch.ops.load_library(_lib_path()) | ||
|
||
outer_product_scatter_add = torch.ops.mops.outer_product_scatter_add |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import os | ||
import sys | ||
|
||
_HERE = os.path.realpath(os.path.dirname(__file__)) | ||
|
||
# TODO: check that mops_torch was compiled for a compatible version of torch | ||
|
||
|
||
def _lib_path(): | ||
if sys.platform.startswith("darwin"): | ||
path = os.path.join(_HERE, "lib", "libmops_torch.dylib") | ||
windows = False | ||
elif sys.platform.startswith("linux"): | ||
path = os.path.join(_HERE, "lib", "libmops_torch.so") | ||
windows = False | ||
elif sys.platform.startswith("win"): | ||
path = os.path.join(_HERE, "bin", "mops_torch.dll") | ||
windows = True | ||
else: | ||
raise ImportError("Unknown platform. Please edit this file") | ||
|
||
if os.path.isfile(path): | ||
if windows: | ||
_check_dll(path) | ||
return path | ||
|
||
raise ImportError("Could not find mops_torch shared library at " + path) | ||
|
||
|
||
def _check_dll(path): | ||
""" | ||
Check if the DLL pointer size matches Python (32-bit or 64-bit) | ||
""" | ||
import platform | ||
import struct | ||
|
||
IMAGE_FILE_MACHINE_I386 = 332 | ||
IMAGE_FILE_MACHINE_AMD64 = 34404 | ||
|
||
machine = None | ||
with open(path, "rb") as fd: | ||
header = fd.read(2).decode(encoding="utf-8", errors="strict") | ||
if header != "MZ": | ||
raise ImportError(path + " is not a DLL") | ||
else: | ||
fd.seek(60) | ||
header = fd.read(4) | ||
header_offset = struct.unpack("<L", header)[0] | ||
fd.seek(header_offset + 4) | ||
header = fd.read(2) | ||
machine = struct.unpack("<H", header)[0] | ||
|
||
arch = platform.architecture()[0] | ||
if arch == "32bit": | ||
if machine != IMAGE_FILE_MACHINE_I386: | ||
raise ImportError("Python is 32-bit, but this DLL is not") | ||
elif arch == "64bit": | ||
if machine != IMAGE_FILE_MACHINE_AMD64: | ||
raise ImportError("Python is 64-bit, but this DLL is not") | ||
else: | ||
raise ImportError("Could not determine pointer size of Python") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import torch | ||
|
||
import mops.torch | ||
from mops import reference_implementations as ref | ||
|
||
|
||
torch.manual_seed(0xDEADBEEF) | ||
|
||
|
||
def test_opsa(): | ||
print(mops.torch) | ||
print(dir(mops.torch)) | ||
|
||
A = torch.rand(100, 20) | ||
B = torch.rand(100, 5) | ||
|
||
output_size = 10 | ||
|
||
indices = torch.sort( | ||
torch.randint(output_size, size=(100,), dtype=torch.int32) | ||
).values | ||
# substitute all 1s by 2s so as to test the no-neighbor case | ||
indices[indices == 1] = 2 | ||
|
||
reference = torch.tensor( | ||
ref.outer_product_scatter_add( | ||
A.numpy(), B.numpy(), indices.numpy(), output_size | ||
) | ||
) | ||
actual = mops.torch.outer_product_scatter_add(A, B, indices, output_size) | ||
assert torch.allclose(reference, actual) | ||
|
||
|
||
def test_opsa_grad(): | ||
A = torch.rand(100, 20, dtype=torch.float64, requires_grad=True) | ||
B = torch.rand(100, 5, dtype=torch.float64, requires_grad=True) | ||
|
||
output_size = 10 | ||
indices = torch.sort( | ||
torch.randint(output_size, size=(100,), dtype=torch.int32) | ||
).values | ||
|
||
assert torch.autograd.gradcheck( | ||
mops.torch.outer_product_scatter_add, | ||
(A, B, indices, output_size), | ||
fast_mode=True, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters