diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 666dd39..950f98e 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -43,24 +43,13 @@ jobs: - name: Install dependencies run: | python -m pip install -U pip - python -m pip install -U coveralls coverage[toml] nox + python -m pip install -U nox - name: Run tests run: | python -m nox --non-interactive --error-on-missing-interpreter \ --session ${{ matrix.session }}-${{ matrix.python-version }} - - name: Combine and upload coverage - run: | - python -m coverage combine - python -m coverage xml -i - python -m coveralls --service=github - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - COVERALLS_PARALLEL: true - COVERALLS_FLAG_NAME: ${{matrix.os}}-${{matrix.session}}-${{matrix.python-version}} - - tests-pymc: runs-on: ubuntu-latest defaults: @@ -88,19 +77,3 @@ jobs: run: | python -m nox --non-interactive --error-on-missing-interpreter \ --session pymc_mamba-3.10 - - coverage: - needs: tests - runs-on: ubuntu-latest - steps: - - name: Setup Python - uses: actions/setup-python@v5 - with: - python-version: "3.9" - - name: Finish coverage collection - run: | - python -m pip install -U pip - python -m pip install -U coveralls - python -m coveralls --finish - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index ec02e3b..139f934 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -1,24 +1,34 @@ name: Wheels on: - release: - types: [published] + push: + branches: + - main + tags: + - "*" + workflow_dispatch: + inputs: + prerelease: + description: "Run a pre-release, testing the build" + required: false + type: boolean + default: false jobs: build_wheels: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-latest, windows-latest, macos-latest] + os: + - "ubuntu-22.04" + - "macos-12" + - "macos-14" + - "windows-latest" steps: - uses: actions/checkout@v4 with: submodules: true fetch-depth: 0 - uses: pypa/cibuildwheel@v2.19.0 - env: - CIBW_SKIP: "pp* *-win32 *-manylinux_i686 *-musllinux*" - CIBW_MANYLINUX_X86_64_IMAGE: manylinux2014 - CIBW_ARCHS_MACOS: "x86_64 universal2 arm64" - uses: actions/upload-artifact@v4 with: name: binary-${{ matrix.os }} @@ -34,28 +44,32 @@ jobs: - uses: actions/setup-python@v5 name: Install Python with: - python-version: "3.9" - - name: Build sdist + python-version: "3.10" + - name: Install dependencies run: | python -m pip install -U pip - python -m pip install -U build - python -m build --sdist . + python -m pip install -U build twine + - name: Build sdist + run: python -m build --sdist . + - name: Check the sdist + run: python -m twine check dist/*.tar.gz - uses: actions/upload-artifact@v4 with: name: sdist path: dist/*.tar.gz upload_pypi: + environment: + name: pypi + url: https://pypi.org/p/celerite2 + permissions: + id-token: write needs: [build_wheels, build_sdist] runs-on: ubuntu-latest - if: github.event_name == 'release' && github.event.action == 'published' + if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') steps: - uses: actions/download-artifact@v4 with: path: dist merge-multiple: true - - uses: pypa/gh-action-pypi-publish@v1.8.14 - with: - user: __token__ - password: ${{ secrets.pypi_password }} diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..a4c9f09 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,25 @@ +cmake_minimum_required(VERSION 3.15...3.27) +project( + ${SKBUILD_PROJECT_NAME} + VERSION ${SKBUILD_PROJECT_VERSION} + LANGUAGES CXX) + +set(PYBIND11_NEWPYTHON ON) +find_package(pybind11 CONFIG REQUIRED) + +include_directories( + "c++/include" + "c++/vendor/eigen" + "python/celerite2") + +pybind11_add_module(driver "python/celerite2/driver.cpp") +target_compile_features(driver PUBLIC cxx_std_14) +install(TARGETS driver LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME}) + +pybind11_add_module(backprop "python/celerite2/backprop.cpp") +target_compile_features(backprop PUBLIC cxx_std_14) +install(TARGETS backprop LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME}) + +pybind11_add_module(xla_ops "python/celerite2/jax/xla_ops.cpp") +target_compile_features(xla_ops PUBLIC cxx_std_14) +install(TARGETS xla_ops LIBRARY DESTINATION "${SKBUILD_PROJECT_NAME}/jax") diff --git a/c++/vendor/eigen b/c++/vendor/eigen index 0fd6b4f..d791d48 160000 --- a/c++/vendor/eigen +++ b/c++/vendor/eigen @@ -1 +1 @@ -Subproject commit 0fd6b4f71dd85b2009ee4d1aeb296e2c11fc9d68 +Subproject commit d791d48859c6fc7850c9fd5270d2b236c818068d diff --git a/noxfile.py b/noxfile.py index bc7e8ed..7d9bfe8 100644 --- a/noxfile.py +++ b/noxfile.py @@ -1,7 +1,7 @@ import nox ALL_PYTHON_VS = ["3.8", "3.9", "3.10"] -TEST_CMD = ["coverage", "run", "-m", "pytest", "-v"] +TEST_CMD = ["python", "-m", "pytest", "-v"] def _session_run(session, path): diff --git a/pyproject.toml b/pyproject.toml index 1ef3c02..dc9ea68 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,12 +1,61 @@ -[build-system] -requires = [ - "setuptools>=40.6.0", - "wheel", - "setuptools_scm", - "oldest-supported-numpy", - "pybind11>=2.4", +[project] +name = "celerite2" +description = "Fast and scalable Gaussian Processes in 1D" +authors = [{ name = "Dan Foreman-Mackey", email = "foreman.mackey@gmail.com" }] +readme = "README.md" +requires-python = ">=3.9" +license = { text = "MIT License" } +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Programming Language :: Python", +] +dynamic = ["version"] +dependencies = ["numpy"] + +[project.optional-dependencies] +test = ["pytest", "scipy", "celerite"] +pymc3 = ["pymc3>=3.9", "numpy<1.22", "xarray<2023.10.0"] +theano = ["pymc3>=3.9", "numpy<1.22", "xarray<2023.10.0"] +pymc = ["pymc>=5.9.2"] +jax = ["jax"] +docs = [ + "sphinx", + "sphinx-material", + "sphinx_copybutton", + "breathe", + "myst-nb", + "matplotlib", + "scipy", + "emcee", + "pymc>=5", + "tqdm", + "numpyro", ] -build-backend = "setuptools.build_meta" +tutorials = ["matplotlib", "scipy", "emcee", "pymc>=5", "tqdm", "numpyro"] + +[project.urls] +"Homepage" = "https://celerite2.readthedocs.io" +"Source" = "https://github.com/exoplanet-dev/celerite2" +"Bug Tracker" = "https://github.com/exoplanet-dev/celerite2/issues" + +[build-system] +requires = ["scikit-build-core", "numpy", "pybind11"] +build-backend = "scikit_build_core.build" + +[tool.scikit-build] +sdist.exclude = [] +sdist.include = ["python/celerite2/celerite2_version.py"] +metadata.version.provider = "scikit_build_core.metadata.setuptools_scm" + +[tool.setuptools_scm] +write_to = "python/celerite2/celerite2_version.py" + +[tool.cibuildwheel] +skip = "pp* *-win32 *-musllinux_* *-manylinux_i686" [tool.black] line-length = 79 diff --git a/python/celerite2/backprop.cpp b/python/celerite2/backprop.cpp index e5c813a..37478bd 100644 --- a/python/celerite2/backprop.cpp +++ b/python/celerite2/backprop.cpp @@ -9,782 +9,914 @@ namespace py = pybind11; namespace celerite2 { namespace driver { -auto factor_fwd(py::array_t t, py::array_t c, py::array_t a, - py::array_t U, py::array_t V, py::array_t d, - py::array_t W, py::array_t S) { - // Request buffers - py::buffer_info tbuf = t.request(); - py::buffer_info cbuf = c.request(); - py::buffer_info abuf = a.request(); - py::buffer_info Ubuf = U.request(); - py::buffer_info Vbuf = V.request(); - py::buffer_info dbuf = d.request(); - py::buffer_info Wbuf = W.request(); - py::buffer_info Sbuf = S.request(); - - // Parse dimensions - if (tbuf.ndim <= 0) throw std::invalid_argument("Invalid number of dimensions: t"); - py::ssize_t N = tbuf.shape[0]; - if (cbuf.ndim <= 0) throw std::invalid_argument("Invalid number of dimensions: c"); - py::ssize_t J = cbuf.shape[0]; - - // Check shapes - if (tbuf.ndim != 1 || tbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: t"); - if (cbuf.ndim != 1 || cbuf.shape[0] != J) throw std::invalid_argument("Invalid shape: c"); - if (abuf.ndim != 1 || abuf.shape[0] != N) throw std::invalid_argument("Invalid shape: a"); - if (Ubuf.ndim != 2 || Ubuf.shape[0] != N || Ubuf.shape[1] != J) throw std::invalid_argument("Invalid shape: U"); - if (Vbuf.ndim != 2 || Vbuf.shape[0] != N || Vbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: V"); - if (dbuf.ndim != 1 || dbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: d"); - if (Wbuf.ndim != 2 || Wbuf.shape[0] != N || Wbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: W"); - if (Sbuf.ndim != 3 || Sbuf.shape[0] != N || Sbuf.shape[1] != J || Sbuf.shape[2] != J) throw std::invalid_argument("Invalid shape: S"); - - Eigen::Index flag = 0; -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_((const double *)tbuf.ptr, N, 1); \ - Eigen::Map> c_((const double *)cbuf.ptr, J, 1); \ - Eigen::Map a_((const double *)abuf.ptr, N, 1); \ - Eigen::Map::value>> U_((const double *)Ubuf.ptr, N, J); \ - Eigen::Map::value>> V_((const double *)Vbuf.ptr, N, J); \ - Eigen::Map d_((double *)dbuf.ptr, N, 1); \ - Eigen::Map::value>> W_((double *)Wbuf.ptr, N, J); \ - Eigen::Map::value>> S_((double *)Sbuf.ptr, N, J *J); \ - flag = celerite2::core::factor(t_, c_, a_, U_, V_, d_, W_, S_); \ - } - UNWRAP_CASES_MOST +auto factor_fwd ( + py::array_t t, + py::array_t c, + py::array_t a, + py::array_t U, + py::array_t V, + py::array_t d, + py::array_t W, + py::array_t S +) { + // Request buffers + py::buffer_info tbuf = t.request(); + py::buffer_info cbuf = c.request(); + py::buffer_info abuf = a.request(); + py::buffer_info Ubuf = U.request(); + py::buffer_info Vbuf = V.request(); + py::buffer_info dbuf = d.request(); + py::buffer_info Wbuf = W.request(); + py::buffer_info Sbuf = S.request(); + + // Parse dimensions + if (tbuf.ndim <= 0) + throw std::invalid_argument("Invalid number of dimensions: t"); + py::ssize_t N = tbuf.shape[0]; + if (cbuf.ndim <= 0) + throw std::invalid_argument("Invalid number of dimensions: c"); + py::ssize_t J = cbuf.shape[0]; + + // Check shapes + if (tbuf.ndim != 1 || tbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: t"); + if (cbuf.ndim != 1 || cbuf.shape[0] != J) throw std::invalid_argument("Invalid shape: c"); + if (abuf.ndim != 1 || abuf.shape[0] != N) throw std::invalid_argument("Invalid shape: a"); + if (Ubuf.ndim != 2 || Ubuf.shape[0] != N || Ubuf.shape[1] != J) throw std::invalid_argument("Invalid shape: U"); + if (Vbuf.ndim != 2 || Vbuf.shape[0] != N || Vbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: V"); + if (dbuf.ndim != 1 || dbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: d"); + if (Wbuf.ndim != 2 || Wbuf.shape[0] != N || Wbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: W"); + if (Sbuf.ndim != 3 || Sbuf.shape[0] != N || Sbuf.shape[1] != J || Sbuf.shape[2] != J) throw std::invalid_argument("Invalid shape: S"); + + Eigen::Index flag = 0; +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t_((const double *)tbuf.ptr, N, 1); \ + Eigen::Map> c_((const double *)cbuf.ptr, J, 1); \ + Eigen::Map a_((const double *)abuf.ptr, N, 1); \ + Eigen::Map::value>> U_((const double *)Ubuf.ptr, N, J); \ + Eigen::Map::value>> V_((const double *)Vbuf.ptr, N, J); \ + Eigen::Map d_((double *)dbuf.ptr, N, 1); \ + Eigen::Map::value>> W_((double *)Wbuf.ptr, N, J); \ + Eigen::Map::value>> S_((double *)Sbuf.ptr, N, J * J); \ + flag = celerite2::core::factor(t_, c_, a_, U_, V_, d_, W_, S_); \ + } + UNWRAP_CASES_MOST #undef FIXED_SIZE_MAP - if (flag) throw backprop_linalg_exception(); - return std::make_tuple(d, W, S); + if (flag) throw backprop_linalg_exception(); + return std::make_tuple(d, W, S); } -auto factor_rev(py::array_t t, py::array_t c, py::array_t a, - py::array_t U, py::array_t V, py::array_t d, - py::array_t W, py::array_t S, py::array_t bd, - py::array_t bW, py::array_t bt, py::array_t bc, - py::array_t ba, py::array_t bU, py::array_t bV) { - // Request buffers - py::buffer_info tbuf = t.request(); - py::buffer_info cbuf = c.request(); - py::buffer_info abuf = a.request(); - py::buffer_info Ubuf = U.request(); - py::buffer_info Vbuf = V.request(); - py::buffer_info dbuf = d.request(); - py::buffer_info Wbuf = W.request(); - py::buffer_info Sbuf = S.request(); - py::buffer_info bdbuf = bd.request(); - py::buffer_info bWbuf = bW.request(); - py::buffer_info btbuf = bt.request(); - py::buffer_info bcbuf = bc.request(); - py::buffer_info babuf = ba.request(); - py::buffer_info bUbuf = bU.request(); - py::buffer_info bVbuf = bV.request(); - - // Parse dimensions - if (tbuf.ndim <= 0) throw std::invalid_argument("Invalid number of dimensions: t"); - py::ssize_t N = tbuf.shape[0]; - if (cbuf.ndim <= 0) throw std::invalid_argument("Invalid number of dimensions: c"); - py::ssize_t J = cbuf.shape[0]; - - // Check shapes - if (tbuf.ndim != 1 || tbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: t"); - if (cbuf.ndim != 1 || cbuf.shape[0] != J) throw std::invalid_argument("Invalid shape: c"); - if (abuf.ndim != 1 || abuf.shape[0] != N) throw std::invalid_argument("Invalid shape: a"); - if (Ubuf.ndim != 2 || Ubuf.shape[0] != N || Ubuf.shape[1] != J) throw std::invalid_argument("Invalid shape: U"); - if (Vbuf.ndim != 2 || Vbuf.shape[0] != N || Vbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: V"); - if (dbuf.ndim != 1 || dbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: d"); - if (Wbuf.ndim != 2 || Wbuf.shape[0] != N || Wbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: W"); - if (Sbuf.ndim != 3 || Sbuf.shape[0] != N || Sbuf.shape[1] != J || Sbuf.shape[2] != J) throw std::invalid_argument("Invalid shape: S"); - if (bdbuf.ndim != 1 || bdbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: bd"); - if (bWbuf.ndim != 2 || bWbuf.shape[0] != N || bWbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: bW"); - if (btbuf.ndim != 1 || btbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: bt"); - if (bcbuf.ndim != 1 || bcbuf.shape[0] != J) throw std::invalid_argument("Invalid shape: bc"); - if (babuf.ndim != 1 || babuf.shape[0] != N) throw std::invalid_argument("Invalid shape: ba"); - if (bUbuf.ndim != 2 || bUbuf.shape[0] != N || bUbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: bU"); - if (bVbuf.ndim != 2 || bVbuf.shape[0] != N || bVbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: bV"); - -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_((const double *)tbuf.ptr, N, 1); \ - Eigen::Map> c_((const double *)cbuf.ptr, J, 1); \ - Eigen::Map a_((const double *)abuf.ptr, N, 1); \ - Eigen::Map::value>> U_((const double *)Ubuf.ptr, N, J); \ - Eigen::Map::value>> V_((const double *)Vbuf.ptr, N, J); \ - Eigen::Map d_((const double *)dbuf.ptr, N, 1); \ - Eigen::Map::value>> W_((const double *)Wbuf.ptr, N, J); \ - Eigen::Map::value>> S_((const double *)Sbuf.ptr, N, J *J); \ - Eigen::Map bd_((const double *)bdbuf.ptr, N, 1); \ - Eigen::Map::value>> bW_((const double *)bWbuf.ptr, N, J); \ - Eigen::Map bt_((double *)btbuf.ptr, N, 1); \ - Eigen::Map> bc_((double *)bcbuf.ptr, J, 1); \ - Eigen::Map ba_((double *)babuf.ptr, N, 1); \ - Eigen::Map::value>> bU_((double *)bUbuf.ptr, N, J); \ - Eigen::Map::value>> bV_((double *)bVbuf.ptr, N, J); \ - celerite2::core::factor_rev(t_, c_, a_, U_, V_, d_, W_, S_, bd_, bW_, bt_, bc_, ba_, bU_, bV_); \ - } - UNWRAP_CASES_FEW +auto factor_rev ( + py::array_t t, + py::array_t c, + py::array_t a, + py::array_t U, + py::array_t V, + py::array_t d, + py::array_t W, + py::array_t S, + py::array_t bd, + py::array_t bW, + py::array_t bt, + py::array_t bc, + py::array_t ba, + py::array_t bU, + py::array_t bV +) { + // Request buffers + py::buffer_info tbuf = t.request(); + py::buffer_info cbuf = c.request(); + py::buffer_info abuf = a.request(); + py::buffer_info Ubuf = U.request(); + py::buffer_info Vbuf = V.request(); + py::buffer_info dbuf = d.request(); + py::buffer_info Wbuf = W.request(); + py::buffer_info Sbuf = S.request(); + py::buffer_info bdbuf = bd.request(); + py::buffer_info bWbuf = bW.request(); + py::buffer_info btbuf = bt.request(); + py::buffer_info bcbuf = bc.request(); + py::buffer_info babuf = ba.request(); + py::buffer_info bUbuf = bU.request(); + py::buffer_info bVbuf = bV.request(); + + // Parse dimensions + if (tbuf.ndim <= 0) + throw std::invalid_argument("Invalid number of dimensions: t"); + py::ssize_t N = tbuf.shape[0]; + if (cbuf.ndim <= 0) + throw std::invalid_argument("Invalid number of dimensions: c"); + py::ssize_t J = cbuf.shape[0]; + + // Check shapes + if (tbuf.ndim != 1 || tbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: t"); + if (cbuf.ndim != 1 || cbuf.shape[0] != J) throw std::invalid_argument("Invalid shape: c"); + if (abuf.ndim != 1 || abuf.shape[0] != N) throw std::invalid_argument("Invalid shape: a"); + if (Ubuf.ndim != 2 || Ubuf.shape[0] != N || Ubuf.shape[1] != J) throw std::invalid_argument("Invalid shape: U"); + if (Vbuf.ndim != 2 || Vbuf.shape[0] != N || Vbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: V"); + if (dbuf.ndim != 1 || dbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: d"); + if (Wbuf.ndim != 2 || Wbuf.shape[0] != N || Wbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: W"); + if (Sbuf.ndim != 3 || Sbuf.shape[0] != N || Sbuf.shape[1] != J || Sbuf.shape[2] != J) throw std::invalid_argument("Invalid shape: S"); + if (bdbuf.ndim != 1 || bdbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: bd"); + if (bWbuf.ndim != 2 || bWbuf.shape[0] != N || bWbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: bW"); + if (btbuf.ndim != 1 || btbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: bt"); + if (bcbuf.ndim != 1 || bcbuf.shape[0] != J) throw std::invalid_argument("Invalid shape: bc"); + if (babuf.ndim != 1 || babuf.shape[0] != N) throw std::invalid_argument("Invalid shape: ba"); + if (bUbuf.ndim != 2 || bUbuf.shape[0] != N || bUbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: bU"); + if (bVbuf.ndim != 2 || bVbuf.shape[0] != N || bVbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: bV"); + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t_((const double *)tbuf.ptr, N, 1); \ + Eigen::Map> c_((const double *)cbuf.ptr, J, 1); \ + Eigen::Map a_((const double *)abuf.ptr, N, 1); \ + Eigen::Map::value>> U_((const double *)Ubuf.ptr, N, J); \ + Eigen::Map::value>> V_((const double *)Vbuf.ptr, N, J); \ + Eigen::Map d_((const double *)dbuf.ptr, N, 1); \ + Eigen::Map::value>> W_((const double *)Wbuf.ptr, N, J); \ + Eigen::Map::value>> S_((const double *)Sbuf.ptr, N, J * J); \ + Eigen::Map bd_((const double *)bdbuf.ptr, N, 1); \ + Eigen::Map::value>> bW_((const double *)bWbuf.ptr, N, J); \ + Eigen::Map bt_((double *)btbuf.ptr, N, 1); \ + Eigen::Map> bc_((double *)bcbuf.ptr, J, 1); \ + Eigen::Map ba_((double *)babuf.ptr, N, 1); \ + Eigen::Map::value>> bU_((double *)bUbuf.ptr, N, J); \ + Eigen::Map::value>> bV_((double *)bVbuf.ptr, N, J); \ + celerite2::core::factor_rev(t_, c_, a_, U_, V_, d_, W_, S_, bd_, bW_, bt_, bc_, ba_, bU_, bV_); \ + } + UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP - return std::make_tuple(bt, bc, ba, bU, bV); + return std::make_tuple(bt, bc, ba, bU, bV); } -auto solve_lower_fwd(py::array_t t, py::array_t c, py::array_t U, - py::array_t W, py::array_t Y, py::array_t Z, - py::array_t F) { - // Request buffers - py::buffer_info tbuf = t.request(); - py::buffer_info cbuf = c.request(); - py::buffer_info Ubuf = U.request(); - py::buffer_info Wbuf = W.request(); - py::buffer_info Ybuf = Y.request(); - py::buffer_info Zbuf = Z.request(); - py::buffer_info Fbuf = F.request(); - - // Parse dimensions - if (tbuf.ndim <= 0) throw std::invalid_argument("Invalid number of dimensions: t"); - py::ssize_t N = tbuf.shape[0]; - if (cbuf.ndim <= 0) throw std::invalid_argument("Invalid number of dimensions: c"); - py::ssize_t J = cbuf.shape[0]; - if (Ybuf.ndim <= 1) throw std::invalid_argument("Invalid number of dimensions: Y"); - py::ssize_t nrhs = Ybuf.shape[1]; - - // Check shapes - if (tbuf.ndim != 1 || tbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: t"); - if (cbuf.ndim != 1 || cbuf.shape[0] != J) throw std::invalid_argument("Invalid shape: c"); - if (Ubuf.ndim != 2 || Ubuf.shape[0] != N || Ubuf.shape[1] != J) throw std::invalid_argument("Invalid shape: U"); - if (Wbuf.ndim != 2 || Wbuf.shape[0] != N || Wbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: W"); - if (Ybuf.ndim != 2 || Ybuf.shape[0] != N || Ybuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Y"); - if (Zbuf.ndim != 2 || Zbuf.shape[0] != N || Zbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Z"); - if (Fbuf.ndim != 3 || Fbuf.shape[0] != N || Fbuf.shape[1] != J || Fbuf.shape[2] != nrhs) throw std::invalid_argument("Invalid shape: F"); - -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_((const double *)tbuf.ptr, N, 1); \ - Eigen::Map> c_((const double *)cbuf.ptr, J, 1); \ - Eigen::Map::value>> U_((const double *)Ubuf.ptr, N, J); \ - Eigen::Map::value>> W_((const double *)Wbuf.ptr, N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_((const double *)Ybuf.ptr, N, 1); \ - Eigen::Map Z_((double *)Zbuf.ptr, N, 1); \ - Eigen::Map::value>> F_((double *)Fbuf.ptr, N, J); \ - Z_.setZero(); \ - celerite2::core::solve_lower(t_, c_, U_, W_, Y_, Z_, F_); \ - } else { \ - Eigen::Map> Y_((const double *)Ybuf.ptr, N, nrhs); \ - Eigen::Map> Z_((double *)Zbuf.ptr, N, nrhs); \ - Eigen::Map> F_((double *)Fbuf.ptr, N, J *nrhs); \ - Z_.setZero(); \ - celerite2::core::solve_lower(t_, c_, U_, W_, Y_, Z_, F_); \ - } \ - } - UNWRAP_CASES_MOST + +auto solve_lower_fwd ( + py::array_t t, + py::array_t c, + py::array_t U, + py::array_t W, + py::array_t Y, + py::array_t Z, + py::array_t F +) { + // Request buffers + py::buffer_info tbuf = t.request(); + py::buffer_info cbuf = c.request(); + py::buffer_info Ubuf = U.request(); + py::buffer_info Wbuf = W.request(); + py::buffer_info Ybuf = Y.request(); + py::buffer_info Zbuf = Z.request(); + py::buffer_info Fbuf = F.request(); + + // Parse dimensions + if (tbuf.ndim <= 0) + throw std::invalid_argument("Invalid number of dimensions: t"); + py::ssize_t N = tbuf.shape[0]; + if (cbuf.ndim <= 0) + throw std::invalid_argument("Invalid number of dimensions: c"); + py::ssize_t J = cbuf.shape[0]; + if (Ybuf.ndim <= 1) + throw std::invalid_argument("Invalid number of dimensions: Y"); + py::ssize_t nrhs = Ybuf.shape[1]; + + // Check shapes + if (tbuf.ndim != 1 || tbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: t"); + if (cbuf.ndim != 1 || cbuf.shape[0] != J) throw std::invalid_argument("Invalid shape: c"); + if (Ubuf.ndim != 2 || Ubuf.shape[0] != N || Ubuf.shape[1] != J) throw std::invalid_argument("Invalid shape: U"); + if (Wbuf.ndim != 2 || Wbuf.shape[0] != N || Wbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: W"); + if (Ybuf.ndim != 2 || Ybuf.shape[0] != N || Ybuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Y"); + if (Zbuf.ndim != 2 || Zbuf.shape[0] != N || Zbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Z"); + if (Fbuf.ndim != 3 || Fbuf.shape[0] != N || Fbuf.shape[1] != J || Fbuf.shape[2] != nrhs) throw std::invalid_argument("Invalid shape: F"); + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t_((const double *)tbuf.ptr, N, 1); \ + Eigen::Map> c_((const double *)cbuf.ptr, J, 1); \ + Eigen::Map::value>> U_((const double *)Ubuf.ptr, N, J); \ + Eigen::Map::value>> W_((const double *)Wbuf.ptr, N, J); \ + if (nrhs == 1) { \ + Eigen::Map Y_((const double *)Ybuf.ptr, N, 1); \ + Eigen::Map Z_((double *)Zbuf.ptr, N, 1); \ + Eigen::Map::value>> F_((double *)Fbuf.ptr, N, J); \ + Z_.setZero(); \ + celerite2::core::solve_lower(t_, c_, U_, W_, Y_, Z_, F_); \ + } else { \ + Eigen::Map> Y_((const double *)Ybuf.ptr, N, nrhs); \ + Eigen::Map> Z_((double *)Zbuf.ptr, N, nrhs); \ + Eigen::Map> F_((double *)Fbuf.ptr, N, J * nrhs); \ + Z_.setZero(); \ + celerite2::core::solve_lower(t_, c_, U_, W_, Y_, Z_, F_); \ + } \ + } + UNWRAP_CASES_MOST #undef FIXED_SIZE_MAP - return std::make_tuple(Z, F); + return std::make_tuple(Z, F); } -auto solve_lower_rev(py::array_t t, py::array_t c, py::array_t U, - py::array_t W, py::array_t Y, py::array_t Z, - py::array_t F, py::array_t bZ, - py::array_t bt, py::array_t bc, - py::array_t bU, py::array_t bW, - py::array_t bY) { - // Request buffers - py::buffer_info tbuf = t.request(); - py::buffer_info cbuf = c.request(); - py::buffer_info Ubuf = U.request(); - py::buffer_info Wbuf = W.request(); - py::buffer_info Ybuf = Y.request(); - py::buffer_info Zbuf = Z.request(); - py::buffer_info Fbuf = F.request(); - py::buffer_info bZbuf = bZ.request(); - py::buffer_info btbuf = bt.request(); - py::buffer_info bcbuf = bc.request(); - py::buffer_info bUbuf = bU.request(); - py::buffer_info bWbuf = bW.request(); - py::buffer_info bYbuf = bY.request(); - - // Parse dimensions - if (tbuf.ndim <= 0) throw std::invalid_argument("Invalid number of dimensions: t"); - py::ssize_t N = tbuf.shape[0]; - if (cbuf.ndim <= 0) throw std::invalid_argument("Invalid number of dimensions: c"); - py::ssize_t J = cbuf.shape[0]; - if (Ybuf.ndim <= 1) throw std::invalid_argument("Invalid number of dimensions: Y"); - py::ssize_t nrhs = Ybuf.shape[1]; - - // Check shapes - if (tbuf.ndim != 1 || tbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: t"); - if (cbuf.ndim != 1 || cbuf.shape[0] != J) throw std::invalid_argument("Invalid shape: c"); - if (Ubuf.ndim != 2 || Ubuf.shape[0] != N || Ubuf.shape[1] != J) throw std::invalid_argument("Invalid shape: U"); - if (Wbuf.ndim != 2 || Wbuf.shape[0] != N || Wbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: W"); - if (Ybuf.ndim != 2 || Ybuf.shape[0] != N || Ybuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Y"); - if (Zbuf.ndim != 2 || Zbuf.shape[0] != N || Zbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Z"); - if (Fbuf.ndim != 3 || Fbuf.shape[0] != N || Fbuf.shape[1] != J || Fbuf.shape[2] != nrhs) throw std::invalid_argument("Invalid shape: F"); - if (bZbuf.ndim != 2 || bZbuf.shape[0] != N || bZbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: bZ"); - if (btbuf.ndim != 1 || btbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: bt"); - if (bcbuf.ndim != 1 || bcbuf.shape[0] != J) throw std::invalid_argument("Invalid shape: bc"); - if (bUbuf.ndim != 2 || bUbuf.shape[0] != N || bUbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: bU"); - if (bWbuf.ndim != 2 || bWbuf.shape[0] != N || bWbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: bW"); - if (bYbuf.ndim != 2 || bYbuf.shape[0] != N || bYbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: bY"); - -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_((const double *)tbuf.ptr, N, 1); \ - Eigen::Map> c_((const double *)cbuf.ptr, J, 1); \ - Eigen::Map::value>> U_((const double *)Ubuf.ptr, N, J); \ - Eigen::Map::value>> W_((const double *)Wbuf.ptr, N, J); \ - Eigen::Map bt_((double *)btbuf.ptr, N, 1); \ - Eigen::Map> bc_((double *)bcbuf.ptr, J, 1); \ - Eigen::Map::value>> bU_((double *)bUbuf.ptr, N, J); \ - Eigen::Map::value>> bW_((double *)bWbuf.ptr, N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_((const double *)Ybuf.ptr, N, 1); \ - Eigen::Map Z_((const double *)Zbuf.ptr, N, 1); \ - Eigen::Map::value>> F_((const double *)Fbuf.ptr, N, J); \ - Eigen::Map bZ_((const double *)bZbuf.ptr, N, 1); \ - Eigen::Map bY_((double *)bYbuf.ptr, N, 1); \ - celerite2::core::solve_lower_rev(t_, c_, U_, W_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bW_, bY_); \ - } else { \ - Eigen::Map> Y_((const double *)Ybuf.ptr, N, nrhs); \ - Eigen::Map> Z_((const double *)Zbuf.ptr, N, nrhs); \ - Eigen::Map> F_((const double *)Fbuf.ptr, N, J *nrhs); \ - Eigen::Map> bZ_((const double *)bZbuf.ptr, N, nrhs); \ - Eigen::Map> bY_((double *)bYbuf.ptr, N, nrhs); \ - celerite2::core::solve_lower_rev(t_, c_, U_, W_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bW_, bY_); \ - } \ - } - UNWRAP_CASES_FEW +auto solve_lower_rev ( + py::array_t t, + py::array_t c, + py::array_t U, + py::array_t W, + py::array_t Y, + py::array_t Z, + py::array_t F, + py::array_t bZ, + py::array_t bt, + py::array_t bc, + py::array_t bU, + py::array_t bW, + py::array_t bY +) { + // Request buffers + py::buffer_info tbuf = t.request(); + py::buffer_info cbuf = c.request(); + py::buffer_info Ubuf = U.request(); + py::buffer_info Wbuf = W.request(); + py::buffer_info Ybuf = Y.request(); + py::buffer_info Zbuf = Z.request(); + py::buffer_info Fbuf = F.request(); + py::buffer_info bZbuf = bZ.request(); + py::buffer_info btbuf = bt.request(); + py::buffer_info bcbuf = bc.request(); + py::buffer_info bUbuf = bU.request(); + py::buffer_info bWbuf = bW.request(); + py::buffer_info bYbuf = bY.request(); + + // Parse dimensions + if (tbuf.ndim <= 0) + throw std::invalid_argument("Invalid number of dimensions: t"); + py::ssize_t N = tbuf.shape[0]; + if (cbuf.ndim <= 0) + throw std::invalid_argument("Invalid number of dimensions: c"); + py::ssize_t J = cbuf.shape[0]; + if (Ybuf.ndim <= 1) + throw std::invalid_argument("Invalid number of dimensions: Y"); + py::ssize_t nrhs = Ybuf.shape[1]; + + // Check shapes + if (tbuf.ndim != 1 || tbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: t"); + if (cbuf.ndim != 1 || cbuf.shape[0] != J) throw std::invalid_argument("Invalid shape: c"); + if (Ubuf.ndim != 2 || Ubuf.shape[0] != N || Ubuf.shape[1] != J) throw std::invalid_argument("Invalid shape: U"); + if (Wbuf.ndim != 2 || Wbuf.shape[0] != N || Wbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: W"); + if (Ybuf.ndim != 2 || Ybuf.shape[0] != N || Ybuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Y"); + if (Zbuf.ndim != 2 || Zbuf.shape[0] != N || Zbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Z"); + if (Fbuf.ndim != 3 || Fbuf.shape[0] != N || Fbuf.shape[1] != J || Fbuf.shape[2] != nrhs) throw std::invalid_argument("Invalid shape: F"); + if (bZbuf.ndim != 2 || bZbuf.shape[0] != N || bZbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: bZ"); + if (btbuf.ndim != 1 || btbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: bt"); + if (bcbuf.ndim != 1 || bcbuf.shape[0] != J) throw std::invalid_argument("Invalid shape: bc"); + if (bUbuf.ndim != 2 || bUbuf.shape[0] != N || bUbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: bU"); + if (bWbuf.ndim != 2 || bWbuf.shape[0] != N || bWbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: bW"); + if (bYbuf.ndim != 2 || bYbuf.shape[0] != N || bYbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: bY"); + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t_((const double *)tbuf.ptr, N, 1); \ + Eigen::Map> c_((const double *)cbuf.ptr, J, 1); \ + Eigen::Map::value>> U_((const double *)Ubuf.ptr, N, J); \ + Eigen::Map::value>> W_((const double *)Wbuf.ptr, N, J); \ + Eigen::Map bt_((double *)btbuf.ptr, N, 1); \ + Eigen::Map> bc_((double *)bcbuf.ptr, J, 1); \ + Eigen::Map::value>> bU_((double *)bUbuf.ptr, N, J); \ + Eigen::Map::value>> bW_((double *)bWbuf.ptr, N, J); \ + if (nrhs == 1) { \ + Eigen::Map Y_((const double *)Ybuf.ptr, N, 1); \ + Eigen::Map Z_((const double *)Zbuf.ptr, N, 1); \ + Eigen::Map::value>> F_((const double *)Fbuf.ptr, N, J); \ + Eigen::Map bZ_((const double *)bZbuf.ptr, N, 1); \ + Eigen::Map bY_((double *)bYbuf.ptr, N, 1); \ + celerite2::core::solve_lower_rev(t_, c_, U_, W_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bW_, bY_); \ + } else { \ + Eigen::Map> Y_((const double *)Ybuf.ptr, N, nrhs); \ + Eigen::Map> Z_((const double *)Zbuf.ptr, N, nrhs); \ + Eigen::Map> F_((const double *)Fbuf.ptr, N, J * nrhs); \ + Eigen::Map> bZ_((const double *)bZbuf.ptr, N, nrhs); \ + Eigen::Map> bY_((double *)bYbuf.ptr, N, nrhs); \ + celerite2::core::solve_lower_rev(t_, c_, U_, W_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bW_, bY_); \ + } \ + } + UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP - return std::make_tuple(bt, bc, bU, bW, bY); + return std::make_tuple(bt, bc, bU, bW, bY); } -auto solve_upper_fwd(py::array_t t, py::array_t c, py::array_t U, - py::array_t W, py::array_t Y, py::array_t Z, - py::array_t F) { - // Request buffers - py::buffer_info tbuf = t.request(); - py::buffer_info cbuf = c.request(); - py::buffer_info Ubuf = U.request(); - py::buffer_info Wbuf = W.request(); - py::buffer_info Ybuf = Y.request(); - py::buffer_info Zbuf = Z.request(); - py::buffer_info Fbuf = F.request(); - - // Parse dimensions - if (tbuf.ndim <= 0) throw std::invalid_argument("Invalid number of dimensions: t"); - py::ssize_t N = tbuf.shape[0]; - if (cbuf.ndim <= 0) throw std::invalid_argument("Invalid number of dimensions: c"); - py::ssize_t J = cbuf.shape[0]; - if (Ybuf.ndim <= 1) throw std::invalid_argument("Invalid number of dimensions: Y"); - py::ssize_t nrhs = Ybuf.shape[1]; - - // Check shapes - if (tbuf.ndim != 1 || tbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: t"); - if (cbuf.ndim != 1 || cbuf.shape[0] != J) throw std::invalid_argument("Invalid shape: c"); - if (Ubuf.ndim != 2 || Ubuf.shape[0] != N || Ubuf.shape[1] != J) throw std::invalid_argument("Invalid shape: U"); - if (Wbuf.ndim != 2 || Wbuf.shape[0] != N || Wbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: W"); - if (Ybuf.ndim != 2 || Ybuf.shape[0] != N || Ybuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Y"); - if (Zbuf.ndim != 2 || Zbuf.shape[0] != N || Zbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Z"); - if (Fbuf.ndim != 3 || Fbuf.shape[0] != N || Fbuf.shape[1] != J || Fbuf.shape[2] != nrhs) throw std::invalid_argument("Invalid shape: F"); - -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_((const double *)tbuf.ptr, N, 1); \ - Eigen::Map> c_((const double *)cbuf.ptr, J, 1); \ - Eigen::Map::value>> U_((const double *)Ubuf.ptr, N, J); \ - Eigen::Map::value>> W_((const double *)Wbuf.ptr, N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_((const double *)Ybuf.ptr, N, 1); \ - Eigen::Map Z_((double *)Zbuf.ptr, N, 1); \ - Eigen::Map::value>> F_((double *)Fbuf.ptr, N, J); \ - Z_.setZero(); \ - celerite2::core::solve_upper(t_, c_, U_, W_, Y_, Z_, F_); \ - } else { \ - Eigen::Map> Y_((const double *)Ybuf.ptr, N, nrhs); \ - Eigen::Map> Z_((double *)Zbuf.ptr, N, nrhs); \ - Eigen::Map> F_((double *)Fbuf.ptr, N, J *nrhs); \ - Z_.setZero(); \ - celerite2::core::solve_upper(t_, c_, U_, W_, Y_, Z_, F_); \ - } \ - } - UNWRAP_CASES_MOST + +auto solve_upper_fwd ( + py::array_t t, + py::array_t c, + py::array_t U, + py::array_t W, + py::array_t Y, + py::array_t Z, + py::array_t F +) { + // Request buffers + py::buffer_info tbuf = t.request(); + py::buffer_info cbuf = c.request(); + py::buffer_info Ubuf = U.request(); + py::buffer_info Wbuf = W.request(); + py::buffer_info Ybuf = Y.request(); + py::buffer_info Zbuf = Z.request(); + py::buffer_info Fbuf = F.request(); + + // Parse dimensions + if (tbuf.ndim <= 0) + throw std::invalid_argument("Invalid number of dimensions: t"); + py::ssize_t N = tbuf.shape[0]; + if (cbuf.ndim <= 0) + throw std::invalid_argument("Invalid number of dimensions: c"); + py::ssize_t J = cbuf.shape[0]; + if (Ybuf.ndim <= 1) + throw std::invalid_argument("Invalid number of dimensions: Y"); + py::ssize_t nrhs = Ybuf.shape[1]; + + // Check shapes + if (tbuf.ndim != 1 || tbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: t"); + if (cbuf.ndim != 1 || cbuf.shape[0] != J) throw std::invalid_argument("Invalid shape: c"); + if (Ubuf.ndim != 2 || Ubuf.shape[0] != N || Ubuf.shape[1] != J) throw std::invalid_argument("Invalid shape: U"); + if (Wbuf.ndim != 2 || Wbuf.shape[0] != N || Wbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: W"); + if (Ybuf.ndim != 2 || Ybuf.shape[0] != N || Ybuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Y"); + if (Zbuf.ndim != 2 || Zbuf.shape[0] != N || Zbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Z"); + if (Fbuf.ndim != 3 || Fbuf.shape[0] != N || Fbuf.shape[1] != J || Fbuf.shape[2] != nrhs) throw std::invalid_argument("Invalid shape: F"); + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t_((const double *)tbuf.ptr, N, 1); \ + Eigen::Map> c_((const double *)cbuf.ptr, J, 1); \ + Eigen::Map::value>> U_((const double *)Ubuf.ptr, N, J); \ + Eigen::Map::value>> W_((const double *)Wbuf.ptr, N, J); \ + if (nrhs == 1) { \ + Eigen::Map Y_((const double *)Ybuf.ptr, N, 1); \ + Eigen::Map Z_((double *)Zbuf.ptr, N, 1); \ + Eigen::Map::value>> F_((double *)Fbuf.ptr, N, J); \ + Z_.setZero(); \ + celerite2::core::solve_upper(t_, c_, U_, W_, Y_, Z_, F_); \ + } else { \ + Eigen::Map> Y_((const double *)Ybuf.ptr, N, nrhs); \ + Eigen::Map> Z_((double *)Zbuf.ptr, N, nrhs); \ + Eigen::Map> F_((double *)Fbuf.ptr, N, J * nrhs); \ + Z_.setZero(); \ + celerite2::core::solve_upper(t_, c_, U_, W_, Y_, Z_, F_); \ + } \ + } + UNWRAP_CASES_MOST #undef FIXED_SIZE_MAP - return std::make_tuple(Z, F); + return std::make_tuple(Z, F); } -auto solve_upper_rev(py::array_t t, py::array_t c, py::array_t U, - py::array_t W, py::array_t Y, py::array_t Z, - py::array_t F, py::array_t bZ, - py::array_t bt, py::array_t bc, - py::array_t bU, py::array_t bW, - py::array_t bY) { - // Request buffers - py::buffer_info tbuf = t.request(); - py::buffer_info cbuf = c.request(); - py::buffer_info Ubuf = U.request(); - py::buffer_info Wbuf = W.request(); - py::buffer_info Ybuf = Y.request(); - py::buffer_info Zbuf = Z.request(); - py::buffer_info Fbuf = F.request(); - py::buffer_info bZbuf = bZ.request(); - py::buffer_info btbuf = bt.request(); - py::buffer_info bcbuf = bc.request(); - py::buffer_info bUbuf = bU.request(); - py::buffer_info bWbuf = bW.request(); - py::buffer_info bYbuf = bY.request(); - - // Parse dimensions - if (tbuf.ndim <= 0) throw std::invalid_argument("Invalid number of dimensions: t"); - py::ssize_t N = tbuf.shape[0]; - if (cbuf.ndim <= 0) throw std::invalid_argument("Invalid number of dimensions: c"); - py::ssize_t J = cbuf.shape[0]; - if (Ybuf.ndim <= 1) throw std::invalid_argument("Invalid number of dimensions: Y"); - py::ssize_t nrhs = Ybuf.shape[1]; - - // Check shapes - if (tbuf.ndim != 1 || tbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: t"); - if (cbuf.ndim != 1 || cbuf.shape[0] != J) throw std::invalid_argument("Invalid shape: c"); - if (Ubuf.ndim != 2 || Ubuf.shape[0] != N || Ubuf.shape[1] != J) throw std::invalid_argument("Invalid shape: U"); - if (Wbuf.ndim != 2 || Wbuf.shape[0] != N || Wbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: W"); - if (Ybuf.ndim != 2 || Ybuf.shape[0] != N || Ybuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Y"); - if (Zbuf.ndim != 2 || Zbuf.shape[0] != N || Zbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Z"); - if (Fbuf.ndim != 3 || Fbuf.shape[0] != N || Fbuf.shape[1] != J || Fbuf.shape[2] != nrhs) throw std::invalid_argument("Invalid shape: F"); - if (bZbuf.ndim != 2 || bZbuf.shape[0] != N || bZbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: bZ"); - if (btbuf.ndim != 1 || btbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: bt"); - if (bcbuf.ndim != 1 || bcbuf.shape[0] != J) throw std::invalid_argument("Invalid shape: bc"); - if (bUbuf.ndim != 2 || bUbuf.shape[0] != N || bUbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: bU"); - if (bWbuf.ndim != 2 || bWbuf.shape[0] != N || bWbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: bW"); - if (bYbuf.ndim != 2 || bYbuf.shape[0] != N || bYbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: bY"); - -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_((const double *)tbuf.ptr, N, 1); \ - Eigen::Map> c_((const double *)cbuf.ptr, J, 1); \ - Eigen::Map::value>> U_((const double *)Ubuf.ptr, N, J); \ - Eigen::Map::value>> W_((const double *)Wbuf.ptr, N, J); \ - Eigen::Map bt_((double *)btbuf.ptr, N, 1); \ - Eigen::Map> bc_((double *)bcbuf.ptr, J, 1); \ - Eigen::Map::value>> bU_((double *)bUbuf.ptr, N, J); \ - Eigen::Map::value>> bW_((double *)bWbuf.ptr, N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_((const double *)Ybuf.ptr, N, 1); \ - Eigen::Map Z_((const double *)Zbuf.ptr, N, 1); \ - Eigen::Map::value>> F_((const double *)Fbuf.ptr, N, J); \ - Eigen::Map bZ_((const double *)bZbuf.ptr, N, 1); \ - Eigen::Map bY_((double *)bYbuf.ptr, N, 1); \ - celerite2::core::solve_upper_rev(t_, c_, U_, W_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bW_, bY_); \ - } else { \ - Eigen::Map> Y_((const double *)Ybuf.ptr, N, nrhs); \ - Eigen::Map> Z_((const double *)Zbuf.ptr, N, nrhs); \ - Eigen::Map> F_((const double *)Fbuf.ptr, N, J *nrhs); \ - Eigen::Map> bZ_((const double *)bZbuf.ptr, N, nrhs); \ - Eigen::Map> bY_((double *)bYbuf.ptr, N, nrhs); \ - celerite2::core::solve_upper_rev(t_, c_, U_, W_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bW_, bY_); \ - } \ - } - UNWRAP_CASES_FEW +auto solve_upper_rev ( + py::array_t t, + py::array_t c, + py::array_t U, + py::array_t W, + py::array_t Y, + py::array_t Z, + py::array_t F, + py::array_t bZ, + py::array_t bt, + py::array_t bc, + py::array_t bU, + py::array_t bW, + py::array_t bY +) { + // Request buffers + py::buffer_info tbuf = t.request(); + py::buffer_info cbuf = c.request(); + py::buffer_info Ubuf = U.request(); + py::buffer_info Wbuf = W.request(); + py::buffer_info Ybuf = Y.request(); + py::buffer_info Zbuf = Z.request(); + py::buffer_info Fbuf = F.request(); + py::buffer_info bZbuf = bZ.request(); + py::buffer_info btbuf = bt.request(); + py::buffer_info bcbuf = bc.request(); + py::buffer_info bUbuf = bU.request(); + py::buffer_info bWbuf = bW.request(); + py::buffer_info bYbuf = bY.request(); + + // Parse dimensions + if (tbuf.ndim <= 0) + throw std::invalid_argument("Invalid number of dimensions: t"); + py::ssize_t N = tbuf.shape[0]; + if (cbuf.ndim <= 0) + throw std::invalid_argument("Invalid number of dimensions: c"); + py::ssize_t J = cbuf.shape[0]; + if (Ybuf.ndim <= 1) + throw std::invalid_argument("Invalid number of dimensions: Y"); + py::ssize_t nrhs = Ybuf.shape[1]; + + // Check shapes + if (tbuf.ndim != 1 || tbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: t"); + if (cbuf.ndim != 1 || cbuf.shape[0] != J) throw std::invalid_argument("Invalid shape: c"); + if (Ubuf.ndim != 2 || Ubuf.shape[0] != N || Ubuf.shape[1] != J) throw std::invalid_argument("Invalid shape: U"); + if (Wbuf.ndim != 2 || Wbuf.shape[0] != N || Wbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: W"); + if (Ybuf.ndim != 2 || Ybuf.shape[0] != N || Ybuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Y"); + if (Zbuf.ndim != 2 || Zbuf.shape[0] != N || Zbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Z"); + if (Fbuf.ndim != 3 || Fbuf.shape[0] != N || Fbuf.shape[1] != J || Fbuf.shape[2] != nrhs) throw std::invalid_argument("Invalid shape: F"); + if (bZbuf.ndim != 2 || bZbuf.shape[0] != N || bZbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: bZ"); + if (btbuf.ndim != 1 || btbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: bt"); + if (bcbuf.ndim != 1 || bcbuf.shape[0] != J) throw std::invalid_argument("Invalid shape: bc"); + if (bUbuf.ndim != 2 || bUbuf.shape[0] != N || bUbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: bU"); + if (bWbuf.ndim != 2 || bWbuf.shape[0] != N || bWbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: bW"); + if (bYbuf.ndim != 2 || bYbuf.shape[0] != N || bYbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: bY"); + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t_((const double *)tbuf.ptr, N, 1); \ + Eigen::Map> c_((const double *)cbuf.ptr, J, 1); \ + Eigen::Map::value>> U_((const double *)Ubuf.ptr, N, J); \ + Eigen::Map::value>> W_((const double *)Wbuf.ptr, N, J); \ + Eigen::Map bt_((double *)btbuf.ptr, N, 1); \ + Eigen::Map> bc_((double *)bcbuf.ptr, J, 1); \ + Eigen::Map::value>> bU_((double *)bUbuf.ptr, N, J); \ + Eigen::Map::value>> bW_((double *)bWbuf.ptr, N, J); \ + if (nrhs == 1) { \ + Eigen::Map Y_((const double *)Ybuf.ptr, N, 1); \ + Eigen::Map Z_((const double *)Zbuf.ptr, N, 1); \ + Eigen::Map::value>> F_((const double *)Fbuf.ptr, N, J); \ + Eigen::Map bZ_((const double *)bZbuf.ptr, N, 1); \ + Eigen::Map bY_((double *)bYbuf.ptr, N, 1); \ + celerite2::core::solve_upper_rev(t_, c_, U_, W_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bW_, bY_); \ + } else { \ + Eigen::Map> Y_((const double *)Ybuf.ptr, N, nrhs); \ + Eigen::Map> Z_((const double *)Zbuf.ptr, N, nrhs); \ + Eigen::Map> F_((const double *)Fbuf.ptr, N, J * nrhs); \ + Eigen::Map> bZ_((const double *)bZbuf.ptr, N, nrhs); \ + Eigen::Map> bY_((double *)bYbuf.ptr, N, nrhs); \ + celerite2::core::solve_upper_rev(t_, c_, U_, W_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bW_, bY_); \ + } \ + } + UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP - return std::make_tuple(bt, bc, bU, bW, bY); + return std::make_tuple(bt, bc, bU, bW, bY); } -auto matmul_lower_fwd(py::array_t t, py::array_t c, py::array_t U, - py::array_t V, py::array_t Y, py::array_t Z, - py::array_t F) { - // Request buffers - py::buffer_info tbuf = t.request(); - py::buffer_info cbuf = c.request(); - py::buffer_info Ubuf = U.request(); - py::buffer_info Vbuf = V.request(); - py::buffer_info Ybuf = Y.request(); - py::buffer_info Zbuf = Z.request(); - py::buffer_info Fbuf = F.request(); - - // Parse dimensions - if (tbuf.ndim <= 0) throw std::invalid_argument("Invalid number of dimensions: t"); - py::ssize_t N = tbuf.shape[0]; - if (cbuf.ndim <= 0) throw std::invalid_argument("Invalid number of dimensions: c"); - py::ssize_t J = cbuf.shape[0]; - if (Ybuf.ndim <= 1) throw std::invalid_argument("Invalid number of dimensions: Y"); - py::ssize_t nrhs = Ybuf.shape[1]; - - // Check shapes - if (tbuf.ndim != 1 || tbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: t"); - if (cbuf.ndim != 1 || cbuf.shape[0] != J) throw std::invalid_argument("Invalid shape: c"); - if (Ubuf.ndim != 2 || Ubuf.shape[0] != N || Ubuf.shape[1] != J) throw std::invalid_argument("Invalid shape: U"); - if (Vbuf.ndim != 2 || Vbuf.shape[0] != N || Vbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: V"); - if (Ybuf.ndim != 2 || Ybuf.shape[0] != N || Ybuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Y"); - if (Zbuf.ndim != 2 || Zbuf.shape[0] != N || Zbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Z"); - if (Fbuf.ndim != 3 || Fbuf.shape[0] != N || Fbuf.shape[1] != J || Fbuf.shape[2] != nrhs) throw std::invalid_argument("Invalid shape: F"); - -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_((const double *)tbuf.ptr, N, 1); \ - Eigen::Map> c_((const double *)cbuf.ptr, J, 1); \ - Eigen::Map::value>> U_((const double *)Ubuf.ptr, N, J); \ - Eigen::Map::value>> V_((const double *)Vbuf.ptr, N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_((const double *)Ybuf.ptr, N, 1); \ - Eigen::Map Z_((double *)Zbuf.ptr, N, 1); \ - Eigen::Map::value>> F_((double *)Fbuf.ptr, N, J); \ - Z_.setZero(); \ - celerite2::core::matmul_lower(t_, c_, U_, V_, Y_, Z_, F_); \ - } else { \ - Eigen::Map> Y_((const double *)Ybuf.ptr, N, nrhs); \ - Eigen::Map> Z_((double *)Zbuf.ptr, N, nrhs); \ - Eigen::Map> F_((double *)Fbuf.ptr, N, J *nrhs); \ - Z_.setZero(); \ - celerite2::core::matmul_lower(t_, c_, U_, V_, Y_, Z_, F_); \ - } \ - } - UNWRAP_CASES_MOST + +auto matmul_lower_fwd ( + py::array_t t, + py::array_t c, + py::array_t U, + py::array_t V, + py::array_t Y, + py::array_t Z, + py::array_t F +) { + // Request buffers + py::buffer_info tbuf = t.request(); + py::buffer_info cbuf = c.request(); + py::buffer_info Ubuf = U.request(); + py::buffer_info Vbuf = V.request(); + py::buffer_info Ybuf = Y.request(); + py::buffer_info Zbuf = Z.request(); + py::buffer_info Fbuf = F.request(); + + // Parse dimensions + if (tbuf.ndim <= 0) + throw std::invalid_argument("Invalid number of dimensions: t"); + py::ssize_t N = tbuf.shape[0]; + if (cbuf.ndim <= 0) + throw std::invalid_argument("Invalid number of dimensions: c"); + py::ssize_t J = cbuf.shape[0]; + if (Ybuf.ndim <= 1) + throw std::invalid_argument("Invalid number of dimensions: Y"); + py::ssize_t nrhs = Ybuf.shape[1]; + + // Check shapes + if (tbuf.ndim != 1 || tbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: t"); + if (cbuf.ndim != 1 || cbuf.shape[0] != J) throw std::invalid_argument("Invalid shape: c"); + if (Ubuf.ndim != 2 || Ubuf.shape[0] != N || Ubuf.shape[1] != J) throw std::invalid_argument("Invalid shape: U"); + if (Vbuf.ndim != 2 || Vbuf.shape[0] != N || Vbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: V"); + if (Ybuf.ndim != 2 || Ybuf.shape[0] != N || Ybuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Y"); + if (Zbuf.ndim != 2 || Zbuf.shape[0] != N || Zbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Z"); + if (Fbuf.ndim != 3 || Fbuf.shape[0] != N || Fbuf.shape[1] != J || Fbuf.shape[2] != nrhs) throw std::invalid_argument("Invalid shape: F"); + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t_((const double *)tbuf.ptr, N, 1); \ + Eigen::Map> c_((const double *)cbuf.ptr, J, 1); \ + Eigen::Map::value>> U_((const double *)Ubuf.ptr, N, J); \ + Eigen::Map::value>> V_((const double *)Vbuf.ptr, N, J); \ + if (nrhs == 1) { \ + Eigen::Map Y_((const double *)Ybuf.ptr, N, 1); \ + Eigen::Map Z_((double *)Zbuf.ptr, N, 1); \ + Eigen::Map::value>> F_((double *)Fbuf.ptr, N, J); \ + Z_.setZero(); \ + celerite2::core::matmul_lower(t_, c_, U_, V_, Y_, Z_, F_); \ + } else { \ + Eigen::Map> Y_((const double *)Ybuf.ptr, N, nrhs); \ + Eigen::Map> Z_((double *)Zbuf.ptr, N, nrhs); \ + Eigen::Map> F_((double *)Fbuf.ptr, N, J * nrhs); \ + Z_.setZero(); \ + celerite2::core::matmul_lower(t_, c_, U_, V_, Y_, Z_, F_); \ + } \ + } + UNWRAP_CASES_MOST #undef FIXED_SIZE_MAP - return std::make_tuple(Z, F); + return std::make_tuple(Z, F); } -auto matmul_lower_rev(py::array_t t, py::array_t c, py::array_t U, - py::array_t V, py::array_t Y, py::array_t Z, - py::array_t F, py::array_t bZ, - py::array_t bt, py::array_t bc, - py::array_t bU, py::array_t bV, - py::array_t bY) { - // Request buffers - py::buffer_info tbuf = t.request(); - py::buffer_info cbuf = c.request(); - py::buffer_info Ubuf = U.request(); - py::buffer_info Vbuf = V.request(); - py::buffer_info Ybuf = Y.request(); - py::buffer_info Zbuf = Z.request(); - py::buffer_info Fbuf = F.request(); - py::buffer_info bZbuf = bZ.request(); - py::buffer_info btbuf = bt.request(); - py::buffer_info bcbuf = bc.request(); - py::buffer_info bUbuf = bU.request(); - py::buffer_info bVbuf = bV.request(); - py::buffer_info bYbuf = bY.request(); - - // Parse dimensions - if (tbuf.ndim <= 0) throw std::invalid_argument("Invalid number of dimensions: t"); - py::ssize_t N = tbuf.shape[0]; - if (cbuf.ndim <= 0) throw std::invalid_argument("Invalid number of dimensions: c"); - py::ssize_t J = cbuf.shape[0]; - if (Ybuf.ndim <= 1) throw std::invalid_argument("Invalid number of dimensions: Y"); - py::ssize_t nrhs = Ybuf.shape[1]; - - // Check shapes - if (tbuf.ndim != 1 || tbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: t"); - if (cbuf.ndim != 1 || cbuf.shape[0] != J) throw std::invalid_argument("Invalid shape: c"); - if (Ubuf.ndim != 2 || Ubuf.shape[0] != N || Ubuf.shape[1] != J) throw std::invalid_argument("Invalid shape: U"); - if (Vbuf.ndim != 2 || Vbuf.shape[0] != N || Vbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: V"); - if (Ybuf.ndim != 2 || Ybuf.shape[0] != N || Ybuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Y"); - if (Zbuf.ndim != 2 || Zbuf.shape[0] != N || Zbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Z"); - if (Fbuf.ndim != 3 || Fbuf.shape[0] != N || Fbuf.shape[1] != J || Fbuf.shape[2] != nrhs) throw std::invalid_argument("Invalid shape: F"); - if (bZbuf.ndim != 2 || bZbuf.shape[0] != N || bZbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: bZ"); - if (btbuf.ndim != 1 || btbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: bt"); - if (bcbuf.ndim != 1 || bcbuf.shape[0] != J) throw std::invalid_argument("Invalid shape: bc"); - if (bUbuf.ndim != 2 || bUbuf.shape[0] != N || bUbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: bU"); - if (bVbuf.ndim != 2 || bVbuf.shape[0] != N || bVbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: bV"); - if (bYbuf.ndim != 2 || bYbuf.shape[0] != N || bYbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: bY"); - -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_((const double *)tbuf.ptr, N, 1); \ - Eigen::Map> c_((const double *)cbuf.ptr, J, 1); \ - Eigen::Map::value>> U_((const double *)Ubuf.ptr, N, J); \ - Eigen::Map::value>> V_((const double *)Vbuf.ptr, N, J); \ - Eigen::Map bt_((double *)btbuf.ptr, N, 1); \ - Eigen::Map> bc_((double *)bcbuf.ptr, J, 1); \ - Eigen::Map::value>> bU_((double *)bUbuf.ptr, N, J); \ - Eigen::Map::value>> bV_((double *)bVbuf.ptr, N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_((const double *)Ybuf.ptr, N, 1); \ - Eigen::Map Z_((const double *)Zbuf.ptr, N, 1); \ - Eigen::Map::value>> F_((const double *)Fbuf.ptr, N, J); \ - Eigen::Map bZ_((const double *)bZbuf.ptr, N, 1); \ - Eigen::Map bY_((double *)bYbuf.ptr, N, 1); \ - celerite2::core::matmul_lower_rev(t_, c_, U_, V_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bV_, bY_); \ - } else { \ - Eigen::Map> Y_((const double *)Ybuf.ptr, N, nrhs); \ - Eigen::Map> Z_((const double *)Zbuf.ptr, N, nrhs); \ - Eigen::Map> F_((const double *)Fbuf.ptr, N, J *nrhs); \ - Eigen::Map> bZ_((const double *)bZbuf.ptr, N, nrhs); \ - Eigen::Map> bY_((double *)bYbuf.ptr, N, nrhs); \ - celerite2::core::matmul_lower_rev(t_, c_, U_, V_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bV_, bY_); \ - } \ - } - UNWRAP_CASES_FEW +auto matmul_lower_rev ( + py::array_t t, + py::array_t c, + py::array_t U, + py::array_t V, + py::array_t Y, + py::array_t Z, + py::array_t F, + py::array_t bZ, + py::array_t bt, + py::array_t bc, + py::array_t bU, + py::array_t bV, + py::array_t bY +) { + // Request buffers + py::buffer_info tbuf = t.request(); + py::buffer_info cbuf = c.request(); + py::buffer_info Ubuf = U.request(); + py::buffer_info Vbuf = V.request(); + py::buffer_info Ybuf = Y.request(); + py::buffer_info Zbuf = Z.request(); + py::buffer_info Fbuf = F.request(); + py::buffer_info bZbuf = bZ.request(); + py::buffer_info btbuf = bt.request(); + py::buffer_info bcbuf = bc.request(); + py::buffer_info bUbuf = bU.request(); + py::buffer_info bVbuf = bV.request(); + py::buffer_info bYbuf = bY.request(); + + // Parse dimensions + if (tbuf.ndim <= 0) + throw std::invalid_argument("Invalid number of dimensions: t"); + py::ssize_t N = tbuf.shape[0]; + if (cbuf.ndim <= 0) + throw std::invalid_argument("Invalid number of dimensions: c"); + py::ssize_t J = cbuf.shape[0]; + if (Ybuf.ndim <= 1) + throw std::invalid_argument("Invalid number of dimensions: Y"); + py::ssize_t nrhs = Ybuf.shape[1]; + + // Check shapes + if (tbuf.ndim != 1 || tbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: t"); + if (cbuf.ndim != 1 || cbuf.shape[0] != J) throw std::invalid_argument("Invalid shape: c"); + if (Ubuf.ndim != 2 || Ubuf.shape[0] != N || Ubuf.shape[1] != J) throw std::invalid_argument("Invalid shape: U"); + if (Vbuf.ndim != 2 || Vbuf.shape[0] != N || Vbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: V"); + if (Ybuf.ndim != 2 || Ybuf.shape[0] != N || Ybuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Y"); + if (Zbuf.ndim != 2 || Zbuf.shape[0] != N || Zbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Z"); + if (Fbuf.ndim != 3 || Fbuf.shape[0] != N || Fbuf.shape[1] != J || Fbuf.shape[2] != nrhs) throw std::invalid_argument("Invalid shape: F"); + if (bZbuf.ndim != 2 || bZbuf.shape[0] != N || bZbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: bZ"); + if (btbuf.ndim != 1 || btbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: bt"); + if (bcbuf.ndim != 1 || bcbuf.shape[0] != J) throw std::invalid_argument("Invalid shape: bc"); + if (bUbuf.ndim != 2 || bUbuf.shape[0] != N || bUbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: bU"); + if (bVbuf.ndim != 2 || bVbuf.shape[0] != N || bVbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: bV"); + if (bYbuf.ndim != 2 || bYbuf.shape[0] != N || bYbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: bY"); + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t_((const double *)tbuf.ptr, N, 1); \ + Eigen::Map> c_((const double *)cbuf.ptr, J, 1); \ + Eigen::Map::value>> U_((const double *)Ubuf.ptr, N, J); \ + Eigen::Map::value>> V_((const double *)Vbuf.ptr, N, J); \ + Eigen::Map bt_((double *)btbuf.ptr, N, 1); \ + Eigen::Map> bc_((double *)bcbuf.ptr, J, 1); \ + Eigen::Map::value>> bU_((double *)bUbuf.ptr, N, J); \ + Eigen::Map::value>> bV_((double *)bVbuf.ptr, N, J); \ + if (nrhs == 1) { \ + Eigen::Map Y_((const double *)Ybuf.ptr, N, 1); \ + Eigen::Map Z_((const double *)Zbuf.ptr, N, 1); \ + Eigen::Map::value>> F_((const double *)Fbuf.ptr, N, J); \ + Eigen::Map bZ_((const double *)bZbuf.ptr, N, 1); \ + Eigen::Map bY_((double *)bYbuf.ptr, N, 1); \ + celerite2::core::matmul_lower_rev(t_, c_, U_, V_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bV_, bY_); \ + } else { \ + Eigen::Map> Y_((const double *)Ybuf.ptr, N, nrhs); \ + Eigen::Map> Z_((const double *)Zbuf.ptr, N, nrhs); \ + Eigen::Map> F_((const double *)Fbuf.ptr, N, J * nrhs); \ + Eigen::Map> bZ_((const double *)bZbuf.ptr, N, nrhs); \ + Eigen::Map> bY_((double *)bYbuf.ptr, N, nrhs); \ + celerite2::core::matmul_lower_rev(t_, c_, U_, V_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bV_, bY_); \ + } \ + } + UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP - return std::make_tuple(bt, bc, bU, bV, bY); + return std::make_tuple(bt, bc, bU, bV, bY); } -auto matmul_upper_fwd(py::array_t t, py::array_t c, py::array_t U, - py::array_t V, py::array_t Y, py::array_t Z, - py::array_t F) { - // Request buffers - py::buffer_info tbuf = t.request(); - py::buffer_info cbuf = c.request(); - py::buffer_info Ubuf = U.request(); - py::buffer_info Vbuf = V.request(); - py::buffer_info Ybuf = Y.request(); - py::buffer_info Zbuf = Z.request(); - py::buffer_info Fbuf = F.request(); - - // Parse dimensions - if (tbuf.ndim <= 0) throw std::invalid_argument("Invalid number of dimensions: t"); - py::ssize_t N = tbuf.shape[0]; - if (cbuf.ndim <= 0) throw std::invalid_argument("Invalid number of dimensions: c"); - py::ssize_t J = cbuf.shape[0]; - if (Ybuf.ndim <= 1) throw std::invalid_argument("Invalid number of dimensions: Y"); - py::ssize_t nrhs = Ybuf.shape[1]; - - // Check shapes - if (tbuf.ndim != 1 || tbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: t"); - if (cbuf.ndim != 1 || cbuf.shape[0] != J) throw std::invalid_argument("Invalid shape: c"); - if (Ubuf.ndim != 2 || Ubuf.shape[0] != N || Ubuf.shape[1] != J) throw std::invalid_argument("Invalid shape: U"); - if (Vbuf.ndim != 2 || Vbuf.shape[0] != N || Vbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: V"); - if (Ybuf.ndim != 2 || Ybuf.shape[0] != N || Ybuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Y"); - if (Zbuf.ndim != 2 || Zbuf.shape[0] != N || Zbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Z"); - if (Fbuf.ndim != 3 || Fbuf.shape[0] != N || Fbuf.shape[1] != J || Fbuf.shape[2] != nrhs) throw std::invalid_argument("Invalid shape: F"); - -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_((const double *)tbuf.ptr, N, 1); \ - Eigen::Map> c_((const double *)cbuf.ptr, J, 1); \ - Eigen::Map::value>> U_((const double *)Ubuf.ptr, N, J); \ - Eigen::Map::value>> V_((const double *)Vbuf.ptr, N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_((const double *)Ybuf.ptr, N, 1); \ - Eigen::Map Z_((double *)Zbuf.ptr, N, 1); \ - Eigen::Map::value>> F_((double *)Fbuf.ptr, N, J); \ - Z_.setZero(); \ - celerite2::core::matmul_upper(t_, c_, U_, V_, Y_, Z_, F_); \ - } else { \ - Eigen::Map> Y_((const double *)Ybuf.ptr, N, nrhs); \ - Eigen::Map> Z_((double *)Zbuf.ptr, N, nrhs); \ - Eigen::Map> F_((double *)Fbuf.ptr, N, J *nrhs); \ - Z_.setZero(); \ - celerite2::core::matmul_upper(t_, c_, U_, V_, Y_, Z_, F_); \ - } \ - } - UNWRAP_CASES_MOST + +auto matmul_upper_fwd ( + py::array_t t, + py::array_t c, + py::array_t U, + py::array_t V, + py::array_t Y, + py::array_t Z, + py::array_t F +) { + // Request buffers + py::buffer_info tbuf = t.request(); + py::buffer_info cbuf = c.request(); + py::buffer_info Ubuf = U.request(); + py::buffer_info Vbuf = V.request(); + py::buffer_info Ybuf = Y.request(); + py::buffer_info Zbuf = Z.request(); + py::buffer_info Fbuf = F.request(); + + // Parse dimensions + if (tbuf.ndim <= 0) + throw std::invalid_argument("Invalid number of dimensions: t"); + py::ssize_t N = tbuf.shape[0]; + if (cbuf.ndim <= 0) + throw std::invalid_argument("Invalid number of dimensions: c"); + py::ssize_t J = cbuf.shape[0]; + if (Ybuf.ndim <= 1) + throw std::invalid_argument("Invalid number of dimensions: Y"); + py::ssize_t nrhs = Ybuf.shape[1]; + + // Check shapes + if (tbuf.ndim != 1 || tbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: t"); + if (cbuf.ndim != 1 || cbuf.shape[0] != J) throw std::invalid_argument("Invalid shape: c"); + if (Ubuf.ndim != 2 || Ubuf.shape[0] != N || Ubuf.shape[1] != J) throw std::invalid_argument("Invalid shape: U"); + if (Vbuf.ndim != 2 || Vbuf.shape[0] != N || Vbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: V"); + if (Ybuf.ndim != 2 || Ybuf.shape[0] != N || Ybuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Y"); + if (Zbuf.ndim != 2 || Zbuf.shape[0] != N || Zbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Z"); + if (Fbuf.ndim != 3 || Fbuf.shape[0] != N || Fbuf.shape[1] != J || Fbuf.shape[2] != nrhs) throw std::invalid_argument("Invalid shape: F"); + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t_((const double *)tbuf.ptr, N, 1); \ + Eigen::Map> c_((const double *)cbuf.ptr, J, 1); \ + Eigen::Map::value>> U_((const double *)Ubuf.ptr, N, J); \ + Eigen::Map::value>> V_((const double *)Vbuf.ptr, N, J); \ + if (nrhs == 1) { \ + Eigen::Map Y_((const double *)Ybuf.ptr, N, 1); \ + Eigen::Map Z_((double *)Zbuf.ptr, N, 1); \ + Eigen::Map::value>> F_((double *)Fbuf.ptr, N, J); \ + Z_.setZero(); \ + celerite2::core::matmul_upper(t_, c_, U_, V_, Y_, Z_, F_); \ + } else { \ + Eigen::Map> Y_((const double *)Ybuf.ptr, N, nrhs); \ + Eigen::Map> Z_((double *)Zbuf.ptr, N, nrhs); \ + Eigen::Map> F_((double *)Fbuf.ptr, N, J * nrhs); \ + Z_.setZero(); \ + celerite2::core::matmul_upper(t_, c_, U_, V_, Y_, Z_, F_); \ + } \ + } + UNWRAP_CASES_MOST #undef FIXED_SIZE_MAP - return std::make_tuple(Z, F); + return std::make_tuple(Z, F); } -auto matmul_upper_rev(py::array_t t, py::array_t c, py::array_t U, - py::array_t V, py::array_t Y, py::array_t Z, - py::array_t F, py::array_t bZ, - py::array_t bt, py::array_t bc, - py::array_t bU, py::array_t bV, - py::array_t bY) { - // Request buffers - py::buffer_info tbuf = t.request(); - py::buffer_info cbuf = c.request(); - py::buffer_info Ubuf = U.request(); - py::buffer_info Vbuf = V.request(); - py::buffer_info Ybuf = Y.request(); - py::buffer_info Zbuf = Z.request(); - py::buffer_info Fbuf = F.request(); - py::buffer_info bZbuf = bZ.request(); - py::buffer_info btbuf = bt.request(); - py::buffer_info bcbuf = bc.request(); - py::buffer_info bUbuf = bU.request(); - py::buffer_info bVbuf = bV.request(); - py::buffer_info bYbuf = bY.request(); - - // Parse dimensions - if (tbuf.ndim <= 0) throw std::invalid_argument("Invalid number of dimensions: t"); - py::ssize_t N = tbuf.shape[0]; - if (cbuf.ndim <= 0) throw std::invalid_argument("Invalid number of dimensions: c"); - py::ssize_t J = cbuf.shape[0]; - if (Ybuf.ndim <= 1) throw std::invalid_argument("Invalid number of dimensions: Y"); - py::ssize_t nrhs = Ybuf.shape[1]; - - // Check shapes - if (tbuf.ndim != 1 || tbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: t"); - if (cbuf.ndim != 1 || cbuf.shape[0] != J) throw std::invalid_argument("Invalid shape: c"); - if (Ubuf.ndim != 2 || Ubuf.shape[0] != N || Ubuf.shape[1] != J) throw std::invalid_argument("Invalid shape: U"); - if (Vbuf.ndim != 2 || Vbuf.shape[0] != N || Vbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: V"); - if (Ybuf.ndim != 2 || Ybuf.shape[0] != N || Ybuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Y"); - if (Zbuf.ndim != 2 || Zbuf.shape[0] != N || Zbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Z"); - if (Fbuf.ndim != 3 || Fbuf.shape[0] != N || Fbuf.shape[1] != J || Fbuf.shape[2] != nrhs) throw std::invalid_argument("Invalid shape: F"); - if (bZbuf.ndim != 2 || bZbuf.shape[0] != N || bZbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: bZ"); - if (btbuf.ndim != 1 || btbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: bt"); - if (bcbuf.ndim != 1 || bcbuf.shape[0] != J) throw std::invalid_argument("Invalid shape: bc"); - if (bUbuf.ndim != 2 || bUbuf.shape[0] != N || bUbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: bU"); - if (bVbuf.ndim != 2 || bVbuf.shape[0] != N || bVbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: bV"); - if (bYbuf.ndim != 2 || bYbuf.shape[0] != N || bYbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: bY"); - -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_((const double *)tbuf.ptr, N, 1); \ - Eigen::Map> c_((const double *)cbuf.ptr, J, 1); \ - Eigen::Map::value>> U_((const double *)Ubuf.ptr, N, J); \ - Eigen::Map::value>> V_((const double *)Vbuf.ptr, N, J); \ - Eigen::Map bt_((double *)btbuf.ptr, N, 1); \ - Eigen::Map> bc_((double *)bcbuf.ptr, J, 1); \ - Eigen::Map::value>> bU_((double *)bUbuf.ptr, N, J); \ - Eigen::Map::value>> bV_((double *)bVbuf.ptr, N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_((const double *)Ybuf.ptr, N, 1); \ - Eigen::Map Z_((const double *)Zbuf.ptr, N, 1); \ - Eigen::Map::value>> F_((const double *)Fbuf.ptr, N, J); \ - Eigen::Map bZ_((const double *)bZbuf.ptr, N, 1); \ - Eigen::Map bY_((double *)bYbuf.ptr, N, 1); \ - celerite2::core::matmul_upper_rev(t_, c_, U_, V_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bV_, bY_); \ - } else { \ - Eigen::Map> Y_((const double *)Ybuf.ptr, N, nrhs); \ - Eigen::Map> Z_((const double *)Zbuf.ptr, N, nrhs); \ - Eigen::Map> F_((const double *)Fbuf.ptr, N, J *nrhs); \ - Eigen::Map> bZ_((const double *)bZbuf.ptr, N, nrhs); \ - Eigen::Map> bY_((double *)bYbuf.ptr, N, nrhs); \ - celerite2::core::matmul_upper_rev(t_, c_, U_, V_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bV_, bY_); \ - } \ - } - UNWRAP_CASES_FEW +auto matmul_upper_rev ( + py::array_t t, + py::array_t c, + py::array_t U, + py::array_t V, + py::array_t Y, + py::array_t Z, + py::array_t F, + py::array_t bZ, + py::array_t bt, + py::array_t bc, + py::array_t bU, + py::array_t bV, + py::array_t bY +) { + // Request buffers + py::buffer_info tbuf = t.request(); + py::buffer_info cbuf = c.request(); + py::buffer_info Ubuf = U.request(); + py::buffer_info Vbuf = V.request(); + py::buffer_info Ybuf = Y.request(); + py::buffer_info Zbuf = Z.request(); + py::buffer_info Fbuf = F.request(); + py::buffer_info bZbuf = bZ.request(); + py::buffer_info btbuf = bt.request(); + py::buffer_info bcbuf = bc.request(); + py::buffer_info bUbuf = bU.request(); + py::buffer_info bVbuf = bV.request(); + py::buffer_info bYbuf = bY.request(); + + // Parse dimensions + if (tbuf.ndim <= 0) + throw std::invalid_argument("Invalid number of dimensions: t"); + py::ssize_t N = tbuf.shape[0]; + if (cbuf.ndim <= 0) + throw std::invalid_argument("Invalid number of dimensions: c"); + py::ssize_t J = cbuf.shape[0]; + if (Ybuf.ndim <= 1) + throw std::invalid_argument("Invalid number of dimensions: Y"); + py::ssize_t nrhs = Ybuf.shape[1]; + + // Check shapes + if (tbuf.ndim != 1 || tbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: t"); + if (cbuf.ndim != 1 || cbuf.shape[0] != J) throw std::invalid_argument("Invalid shape: c"); + if (Ubuf.ndim != 2 || Ubuf.shape[0] != N || Ubuf.shape[1] != J) throw std::invalid_argument("Invalid shape: U"); + if (Vbuf.ndim != 2 || Vbuf.shape[0] != N || Vbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: V"); + if (Ybuf.ndim != 2 || Ybuf.shape[0] != N || Ybuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Y"); + if (Zbuf.ndim != 2 || Zbuf.shape[0] != N || Zbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Z"); + if (Fbuf.ndim != 3 || Fbuf.shape[0] != N || Fbuf.shape[1] != J || Fbuf.shape[2] != nrhs) throw std::invalid_argument("Invalid shape: F"); + if (bZbuf.ndim != 2 || bZbuf.shape[0] != N || bZbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: bZ"); + if (btbuf.ndim != 1 || btbuf.shape[0] != N) throw std::invalid_argument("Invalid shape: bt"); + if (bcbuf.ndim != 1 || bcbuf.shape[0] != J) throw std::invalid_argument("Invalid shape: bc"); + if (bUbuf.ndim != 2 || bUbuf.shape[0] != N || bUbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: bU"); + if (bVbuf.ndim != 2 || bVbuf.shape[0] != N || bVbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: bV"); + if (bYbuf.ndim != 2 || bYbuf.shape[0] != N || bYbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: bY"); + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t_((const double *)tbuf.ptr, N, 1); \ + Eigen::Map> c_((const double *)cbuf.ptr, J, 1); \ + Eigen::Map::value>> U_((const double *)Ubuf.ptr, N, J); \ + Eigen::Map::value>> V_((const double *)Vbuf.ptr, N, J); \ + Eigen::Map bt_((double *)btbuf.ptr, N, 1); \ + Eigen::Map> bc_((double *)bcbuf.ptr, J, 1); \ + Eigen::Map::value>> bU_((double *)bUbuf.ptr, N, J); \ + Eigen::Map::value>> bV_((double *)bVbuf.ptr, N, J); \ + if (nrhs == 1) { \ + Eigen::Map Y_((const double *)Ybuf.ptr, N, 1); \ + Eigen::Map Z_((const double *)Zbuf.ptr, N, 1); \ + Eigen::Map::value>> F_((const double *)Fbuf.ptr, N, J); \ + Eigen::Map bZ_((const double *)bZbuf.ptr, N, 1); \ + Eigen::Map bY_((double *)bYbuf.ptr, N, 1); \ + celerite2::core::matmul_upper_rev(t_, c_, U_, V_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bV_, bY_); \ + } else { \ + Eigen::Map> Y_((const double *)Ybuf.ptr, N, nrhs); \ + Eigen::Map> Z_((const double *)Zbuf.ptr, N, nrhs); \ + Eigen::Map> F_((const double *)Fbuf.ptr, N, J * nrhs); \ + Eigen::Map> bZ_((const double *)bZbuf.ptr, N, nrhs); \ + Eigen::Map> bY_((double *)bYbuf.ptr, N, nrhs); \ + celerite2::core::matmul_upper_rev(t_, c_, U_, V_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bV_, bY_); \ + } \ + } + UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP - return std::make_tuple(bt, bc, bU, bV, bY); + return std::make_tuple(bt, bc, bU, bV, bY); } -auto general_matmul_lower_fwd(py::array_t t1, py::array_t t2, - py::array_t c, py::array_t U, - py::array_t V, py::array_t Y, - py::array_t Z, py::array_t F) { - // Request buffers - py::buffer_info t1buf = t1.request(); - py::buffer_info t2buf = t2.request(); - py::buffer_info cbuf = c.request(); - py::buffer_info Ubuf = U.request(); - py::buffer_info Vbuf = V.request(); - py::buffer_info Ybuf = Y.request(); - py::buffer_info Zbuf = Z.request(); - py::buffer_info Fbuf = F.request(); - - // Parse dimensions - if (t1buf.ndim <= 0) throw std::invalid_argument("Invalid number of dimensions: t1"); - py::ssize_t N = t1buf.shape[0]; - if (t2buf.ndim <= 0) throw std::invalid_argument("Invalid number of dimensions: t2"); - py::ssize_t M = t2buf.shape[0]; - if (cbuf.ndim <= 0) throw std::invalid_argument("Invalid number of dimensions: c"); - py::ssize_t J = cbuf.shape[0]; - if (Ybuf.ndim <= 1) throw std::invalid_argument("Invalid number of dimensions: Y"); - py::ssize_t nrhs = Ybuf.shape[1]; - - // Check shapes - if (t1buf.ndim != 1 || t1buf.shape[0] != N) throw std::invalid_argument("Invalid shape: t1"); - if (t2buf.ndim != 1 || t2buf.shape[0] != M) throw std::invalid_argument("Invalid shape: t2"); - if (cbuf.ndim != 1 || cbuf.shape[0] != J) throw std::invalid_argument("Invalid shape: c"); - if (Ubuf.ndim != 2 || Ubuf.shape[0] != N || Ubuf.shape[1] != J) throw std::invalid_argument("Invalid shape: U"); - if (Vbuf.ndim != 2 || Vbuf.shape[0] != M || Vbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: V"); - if (Ybuf.ndim != 2 || Ybuf.shape[0] != M || Ybuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Y"); - if (Zbuf.ndim != 2 || Zbuf.shape[0] != N || Zbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Z"); - if (Fbuf.ndim != 3 || Fbuf.shape[0] != M || Fbuf.shape[1] != J || Fbuf.shape[2] != nrhs) throw std::invalid_argument("Invalid shape: F"); - -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t1_((const double *)t1buf.ptr, N, 1); \ - Eigen::Map t2_((const double *)t2buf.ptr, M, 1); \ - Eigen::Map> c_((const double *)cbuf.ptr, J, 1); \ - Eigen::Map::value>> U_((const double *)Ubuf.ptr, N, J); \ - Eigen::Map::value>> V_((const double *)Vbuf.ptr, M, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_((const double *)Ybuf.ptr, M, 1); \ - Eigen::Map Z_((double *)Zbuf.ptr, N, 1); \ - Eigen::Map::value>> F_((double *)Fbuf.ptr, M, J); \ - Z_.setZero(); \ - celerite2::core::general_matmul_lower(t1_, t2_, c_, U_, V_, Y_, Z_, F_); \ - } else { \ - Eigen::Map> Y_((const double *)Ybuf.ptr, M, nrhs); \ - Eigen::Map> Z_((double *)Zbuf.ptr, N, nrhs); \ - Eigen::Map> F_((double *)Fbuf.ptr, M, J *nrhs); \ - Z_.setZero(); \ - celerite2::core::general_matmul_lower(t1_, t2_, c_, U_, V_, Y_, Z_, F_); \ - } \ - } - UNWRAP_CASES_MOST + +auto general_matmul_lower_fwd ( + py::array_t t1, + py::array_t t2, + py::array_t c, + py::array_t U, + py::array_t V, + py::array_t Y, + py::array_t Z, + py::array_t F +) { + // Request buffers + py::buffer_info t1buf = t1.request(); + py::buffer_info t2buf = t2.request(); + py::buffer_info cbuf = c.request(); + py::buffer_info Ubuf = U.request(); + py::buffer_info Vbuf = V.request(); + py::buffer_info Ybuf = Y.request(); + py::buffer_info Zbuf = Z.request(); + py::buffer_info Fbuf = F.request(); + + // Parse dimensions + if (t1buf.ndim <= 0) + throw std::invalid_argument("Invalid number of dimensions: t1"); + py::ssize_t N = t1buf.shape[0]; + if (t2buf.ndim <= 0) + throw std::invalid_argument("Invalid number of dimensions: t2"); + py::ssize_t M = t2buf.shape[0]; + if (cbuf.ndim <= 0) + throw std::invalid_argument("Invalid number of dimensions: c"); + py::ssize_t J = cbuf.shape[0]; + if (Ybuf.ndim <= 1) + throw std::invalid_argument("Invalid number of dimensions: Y"); + py::ssize_t nrhs = Ybuf.shape[1]; + + // Check shapes + if (t1buf.ndim != 1 || t1buf.shape[0] != N) throw std::invalid_argument("Invalid shape: t1"); + if (t2buf.ndim != 1 || t2buf.shape[0] != M) throw std::invalid_argument("Invalid shape: t2"); + if (cbuf.ndim != 1 || cbuf.shape[0] != J) throw std::invalid_argument("Invalid shape: c"); + if (Ubuf.ndim != 2 || Ubuf.shape[0] != N || Ubuf.shape[1] != J) throw std::invalid_argument("Invalid shape: U"); + if (Vbuf.ndim != 2 || Vbuf.shape[0] != M || Vbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: V"); + if (Ybuf.ndim != 2 || Ybuf.shape[0] != M || Ybuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Y"); + if (Zbuf.ndim != 2 || Zbuf.shape[0] != N || Zbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Z"); + if (Fbuf.ndim != 3 || Fbuf.shape[0] != M || Fbuf.shape[1] != J || Fbuf.shape[2] != nrhs) throw std::invalid_argument("Invalid shape: F"); + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t1_((const double *)t1buf.ptr, N, 1); \ + Eigen::Map t2_((const double *)t2buf.ptr, M, 1); \ + Eigen::Map> c_((const double *)cbuf.ptr, J, 1); \ + Eigen::Map::value>> U_((const double *)Ubuf.ptr, N, J); \ + Eigen::Map::value>> V_((const double *)Vbuf.ptr, M, J); \ + if (nrhs == 1) { \ + Eigen::Map Y_((const double *)Ybuf.ptr, M, 1); \ + Eigen::Map Z_((double *)Zbuf.ptr, N, 1); \ + Eigen::Map::value>> F_((double *)Fbuf.ptr, M, J); \ + Z_.setZero(); \ + celerite2::core::general_matmul_lower(t1_, t2_, c_, U_, V_, Y_, Z_, F_); \ + } else { \ + Eigen::Map> Y_((const double *)Ybuf.ptr, M, nrhs); \ + Eigen::Map> Z_((double *)Zbuf.ptr, N, nrhs); \ + Eigen::Map> F_((double *)Fbuf.ptr, M, J * nrhs); \ + Z_.setZero(); \ + celerite2::core::general_matmul_lower(t1_, t2_, c_, U_, V_, Y_, Z_, F_); \ + } \ + } + UNWRAP_CASES_MOST #undef FIXED_SIZE_MAP - return std::make_tuple(Z, F); + return std::make_tuple(Z, F); } -auto general_matmul_upper_fwd(py::array_t t1, py::array_t t2, - py::array_t c, py::array_t U, - py::array_t V, py::array_t Y, - py::array_t Z, py::array_t F) { - // Request buffers - py::buffer_info t1buf = t1.request(); - py::buffer_info t2buf = t2.request(); - py::buffer_info cbuf = c.request(); - py::buffer_info Ubuf = U.request(); - py::buffer_info Vbuf = V.request(); - py::buffer_info Ybuf = Y.request(); - py::buffer_info Zbuf = Z.request(); - py::buffer_info Fbuf = F.request(); - - // Parse dimensions - if (t1buf.ndim <= 0) throw std::invalid_argument("Invalid number of dimensions: t1"); - py::ssize_t N = t1buf.shape[0]; - if (t2buf.ndim <= 0) throw std::invalid_argument("Invalid number of dimensions: t2"); - py::ssize_t M = t2buf.shape[0]; - if (cbuf.ndim <= 0) throw std::invalid_argument("Invalid number of dimensions: c"); - py::ssize_t J = cbuf.shape[0]; - if (Ybuf.ndim <= 1) throw std::invalid_argument("Invalid number of dimensions: Y"); - py::ssize_t nrhs = Ybuf.shape[1]; - - // Check shapes - if (t1buf.ndim != 1 || t1buf.shape[0] != N) throw std::invalid_argument("Invalid shape: t1"); - if (t2buf.ndim != 1 || t2buf.shape[0] != M) throw std::invalid_argument("Invalid shape: t2"); - if (cbuf.ndim != 1 || cbuf.shape[0] != J) throw std::invalid_argument("Invalid shape: c"); - if (Ubuf.ndim != 2 || Ubuf.shape[0] != N || Ubuf.shape[1] != J) throw std::invalid_argument("Invalid shape: U"); - if (Vbuf.ndim != 2 || Vbuf.shape[0] != M || Vbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: V"); - if (Ybuf.ndim != 2 || Ybuf.shape[0] != M || Ybuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Y"); - if (Zbuf.ndim != 2 || Zbuf.shape[0] != N || Zbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Z"); - if (Fbuf.ndim != 3 || Fbuf.shape[0] != M || Fbuf.shape[1] != J || Fbuf.shape[2] != nrhs) throw std::invalid_argument("Invalid shape: F"); - -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t1_((const double *)t1buf.ptr, N, 1); \ - Eigen::Map t2_((const double *)t2buf.ptr, M, 1); \ - Eigen::Map> c_((const double *)cbuf.ptr, J, 1); \ - Eigen::Map::value>> U_((const double *)Ubuf.ptr, N, J); \ - Eigen::Map::value>> V_((const double *)Vbuf.ptr, M, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_((const double *)Ybuf.ptr, M, 1); \ - Eigen::Map Z_((double *)Zbuf.ptr, N, 1); \ - Eigen::Map::value>> F_((double *)Fbuf.ptr, M, J); \ - Z_.setZero(); \ - celerite2::core::general_matmul_upper(t1_, t2_, c_, U_, V_, Y_, Z_, F_); \ - } else { \ - Eigen::Map> Y_((const double *)Ybuf.ptr, M, nrhs); \ - Eigen::Map> Z_((double *)Zbuf.ptr, N, nrhs); \ - Eigen::Map> F_((double *)Fbuf.ptr, M, J *nrhs); \ - Z_.setZero(); \ - celerite2::core::general_matmul_upper(t1_, t2_, c_, U_, V_, Y_, Z_, F_); \ - } \ - } - UNWRAP_CASES_MOST +auto general_matmul_upper_fwd ( + py::array_t t1, + py::array_t t2, + py::array_t c, + py::array_t U, + py::array_t V, + py::array_t Y, + py::array_t Z, + py::array_t F +) { + // Request buffers + py::buffer_info t1buf = t1.request(); + py::buffer_info t2buf = t2.request(); + py::buffer_info cbuf = c.request(); + py::buffer_info Ubuf = U.request(); + py::buffer_info Vbuf = V.request(); + py::buffer_info Ybuf = Y.request(); + py::buffer_info Zbuf = Z.request(); + py::buffer_info Fbuf = F.request(); + + // Parse dimensions + if (t1buf.ndim <= 0) + throw std::invalid_argument("Invalid number of dimensions: t1"); + py::ssize_t N = t1buf.shape[0]; + if (t2buf.ndim <= 0) + throw std::invalid_argument("Invalid number of dimensions: t2"); + py::ssize_t M = t2buf.shape[0]; + if (cbuf.ndim <= 0) + throw std::invalid_argument("Invalid number of dimensions: c"); + py::ssize_t J = cbuf.shape[0]; + if (Ybuf.ndim <= 1) + throw std::invalid_argument("Invalid number of dimensions: Y"); + py::ssize_t nrhs = Ybuf.shape[1]; + + // Check shapes + if (t1buf.ndim != 1 || t1buf.shape[0] != N) throw std::invalid_argument("Invalid shape: t1"); + if (t2buf.ndim != 1 || t2buf.shape[0] != M) throw std::invalid_argument("Invalid shape: t2"); + if (cbuf.ndim != 1 || cbuf.shape[0] != J) throw std::invalid_argument("Invalid shape: c"); + if (Ubuf.ndim != 2 || Ubuf.shape[0] != N || Ubuf.shape[1] != J) throw std::invalid_argument("Invalid shape: U"); + if (Vbuf.ndim != 2 || Vbuf.shape[0] != M || Vbuf.shape[1] != J) throw std::invalid_argument("Invalid shape: V"); + if (Ybuf.ndim != 2 || Ybuf.shape[0] != M || Ybuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Y"); + if (Zbuf.ndim != 2 || Zbuf.shape[0] != N || Zbuf.shape[1] != nrhs) throw std::invalid_argument("Invalid shape: Z"); + if (Fbuf.ndim != 3 || Fbuf.shape[0] != M || Fbuf.shape[1] != J || Fbuf.shape[2] != nrhs) throw std::invalid_argument("Invalid shape: F"); + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t1_((const double *)t1buf.ptr, N, 1); \ + Eigen::Map t2_((const double *)t2buf.ptr, M, 1); \ + Eigen::Map> c_((const double *)cbuf.ptr, J, 1); \ + Eigen::Map::value>> U_((const double *)Ubuf.ptr, N, J); \ + Eigen::Map::value>> V_((const double *)Vbuf.ptr, M, J); \ + if (nrhs == 1) { \ + Eigen::Map Y_((const double *)Ybuf.ptr, M, 1); \ + Eigen::Map Z_((double *)Zbuf.ptr, N, 1); \ + Eigen::Map::value>> F_((double *)Fbuf.ptr, M, J); \ + Z_.setZero(); \ + celerite2::core::general_matmul_upper(t1_, t2_, c_, U_, V_, Y_, Z_, F_); \ + } else { \ + Eigen::Map> Y_((const double *)Ybuf.ptr, M, nrhs); \ + Eigen::Map> Z_((double *)Zbuf.ptr, N, nrhs); \ + Eigen::Map> F_((double *)Fbuf.ptr, M, J * nrhs); \ + Z_.setZero(); \ + celerite2::core::general_matmul_upper(t1_, t2_, c_, U_, V_, Y_, Z_, F_); \ + } \ + } + UNWRAP_CASES_MOST #undef FIXED_SIZE_MAP - return std::make_tuple(Z, F); + return std::make_tuple(Z, F); } } // namespace driver } // namespace celerite2 PYBIND11_MODULE(backprop, m) { - py::register_exception(m, "LinAlgError"); - m.def("factor_fwd", &celerite2::driver::factor_fwd); - m.def("factor_rev", &celerite2::driver::factor_rev); - m.def("solve_lower_fwd", &celerite2::driver::solve_lower_fwd); - m.def("solve_lower_rev", &celerite2::driver::solve_lower_rev); - m.def("solve_upper_fwd", &celerite2::driver::solve_upper_fwd); - m.def("solve_upper_rev", &celerite2::driver::solve_upper_rev); - m.def("matmul_lower_fwd", &celerite2::driver::matmul_lower_fwd); - m.def("matmul_lower_rev", &celerite2::driver::matmul_lower_rev); - m.def("matmul_upper_fwd", &celerite2::driver::matmul_upper_fwd); - m.def("matmul_upper_rev", &celerite2::driver::matmul_upper_rev); - m.def("general_matmul_lower_fwd", &celerite2::driver::general_matmul_lower_fwd); - m.def("general_matmul_upper_fwd", &celerite2::driver::general_matmul_upper_fwd); + py::register_exception(m, "LinAlgError"); + m.def("factor_fwd", &celerite2::driver::factor_fwd); + m.def("factor_rev", &celerite2::driver::factor_rev); + m.def("solve_lower_fwd", &celerite2::driver::solve_lower_fwd); + m.def("solve_lower_rev", &celerite2::driver::solve_lower_rev); + m.def("solve_upper_fwd", &celerite2::driver::solve_upper_fwd); + m.def("solve_upper_rev", &celerite2::driver::solve_upper_rev); + m.def("matmul_lower_fwd", &celerite2::driver::matmul_lower_fwd); + m.def("matmul_lower_rev", &celerite2::driver::matmul_lower_rev); + m.def("matmul_upper_fwd", &celerite2::driver::matmul_upper_fwd); + m.def("matmul_upper_rev", &celerite2::driver::matmul_upper_rev); + m.def("general_matmul_lower_fwd", &celerite2::driver::general_matmul_lower_fwd); + m.def("general_matmul_upper_fwd", &celerite2::driver::general_matmul_upper_fwd); #ifdef VERSION_INFO m.attr("__version__") = VERSION_INFO; diff --git a/python/celerite2/driver.hpp b/python/celerite2/driver.hpp index c69c88a..2907213 100644 --- a/python/celerite2/driver.hpp +++ b/python/celerite2/driver.hpp @@ -30,12 +30,6 @@ struct backprop_linalg_exception : public std::exception { case 2: FIXED_SIZE_MAP(2); break; \ case 3: FIXED_SIZE_MAP(3); break; \ case 4: FIXED_SIZE_MAP(4); break; \ - case 5: FIXED_SIZE_MAP(5); break; \ - case 6: FIXED_SIZE_MAP(6); break; \ - case 7: FIXED_SIZE_MAP(7); break; \ - case 8: FIXED_SIZE_MAP(8); break; \ - case 9: FIXED_SIZE_MAP(9); break; \ - case 10: FIXED_SIZE_MAP(10); break; \ default: FIXED_SIZE_MAP(Eigen::Dynamic); \ } @@ -117,6 +111,6 @@ struct order<1> { const static int value = Eigen::ColMajor; }; -}; // namespace driver -}; // namespace celerite2 +}; // namespace driver +}; // namespace celerite2 #endif // _CELERITE2_PYTHON_DRIVER_HPP_DEFINED_ diff --git a/python/celerite2/jax/ops.py b/python/celerite2/jax/ops.py index bc6e983..ffbc5f8 100644 --- a/python/celerite2/jax/ops.py +++ b/python/celerite2/jax/ops.py @@ -20,13 +20,14 @@ import json from collections import OrderedDict from functools import partial +from itertools import chain import numpy as np import pkg_resources from jax import core, lax from jax import numpy as jnp from jax.core import ShapedArray -from jax.interpreters import ad, xla +from jax.interpreters import ad, mlir, xla from jax.lib import xla_client from celerite2.jax import xla_ops @@ -99,45 +100,27 @@ def _abstract_eval(spec, *args): ) -def _translation_rule(name, spec, c, *args): - shapes = tuple(c.get_shape(arg) for arg in args) +def _lowering_rule(name, spec, ctx: mlir.LoweringRuleContext, *args): + if any(a.dtype != np.float64 for a in chain(ctx.avals_in, ctx.avals_out)): + raise ValueError(f"{spec['name']} requires float64 precision") + shapes = [a.shape for a in ctx.avals_in] dims = OrderedDict( - (s["name"], shapes[s["coords"][0]].dimensions()[s["coords"][1]]) + (s["name"], shapes[s["coords"][0]][s["coords"][1]]) for s in spec["dimensions"] ) - if any(shape.element_type() != np.float64 for shape in shapes): - raise ValueError(f"{spec['name']} requires float64 precision") - - return xops.CustomCallWithLayout( - c, + return mlir.custom_call( name, - operands=tuple( - xops.ConstantLiteral(c, np.int32(v)) for v in dims.values() - ) + operands=tuple(mlir.ir_constant(np.int32(v)) for v in dims.values()) + args, - shape_with_layout=xla_client.Shape.tuple_shape( - tuple( - xla_client.Shape.array_shape( - jnp.dtype(np.float64), - tuple(dims[k] for k in s["shape"]), - tuple(range(len(s["shape"]) - 1, -1, -1)), - ) - for s in spec["outputs"] + spec["extra_outputs"] - ) - ), - operand_shapes_with_layout=tuple( - xla_client.Shape.array_shape(jnp.dtype(jnp.int32), (), ()) - for _ in range(len(dims)) - ) - + tuple( - xla_client.Shape.array_shape( - jnp.dtype(np.float64), - tuple(dims[k] for k in s["shape"]), - tuple(range(len(s["shape"]) - 1, -1, -1)), - ) - for s in spec["inputs"] - ), - ) + operand_layouts=[()] * len(dims) + + _default_layouts(aval.shape for aval in ctx.avals_in), + result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], + result_layouts=_default_layouts(aval.shape for aval in ctx.avals_out), + ).results + + +def _default_layouts(shapes): + return [range(len(shape) - 1, -1, -1) for shape in shapes] def _jvp(prim, jvp_prim, spec, arg_values, arg_tangents): @@ -186,7 +169,7 @@ def _rev_abstract_eval(spec, *args): ) -def _rev_translation_rule(name, spec, c, *args): +def _rev_lowering_rule(name, spec, ctx, *args): rev_spec = dict( name=f"{spec['name']}_rev", dimensions=spec["dimensions"], @@ -197,7 +180,7 @@ def _rev_translation_rule(name, spec, c, *args): outputs=spec["inputs"], extra_outputs=[], ) - return _translation_rule(name, rev_spec, c, *args) + return _lowering_rule(name, rev_spec, ctx, *args) def _build_op(name, spec): @@ -209,15 +192,15 @@ def _build_op(name, spec): prim.multiple_results = True prim.def_impl(partial(xla.apply_primitive, prim)) prim.def_abstract_eval(partial(_abstract_eval, spec)) - xla.backend_specific_translations["cpu"][prim] = partial( - _translation_rule, name, spec + mlir.register_lowering( + prim, partial(_lowering_rule, name, spec), platform="cpu" ) if not spec["has_rev"]: return prim, None xla_client.register_custom_call_target( - name + b"_rev", + name + "_rev", getattr(xla_ops, f"{spec['name']}_rev")(), platform="cpu", ) @@ -235,8 +218,10 @@ def _build_op(name, spec): # Handle reverse pass using custom op rev_prim.def_impl(partial(xla.apply_primitive, rev_prim)) rev_prim.def_abstract_eval(partial(_rev_abstract_eval, spec)) - xla.backend_specific_translations["cpu"][rev_prim] = partial( - _rev_translation_rule, name + b"_rev", spec + mlir.register_lowering( + rev_prim, + partial(_rev_lowering_rule, name + "_rev", spec), + platform="cpu", ) return prim, rev_prim @@ -248,22 +233,22 @@ def _build_op(name, spec): definitions = {spec["name"]: spec for spec in json.load(f)} -factor_p, factor_rev_p = _build_op(b"celerite2_factor", definitions["factor"]) +factor_p, factor_rev_p = _build_op("celerite2_factor", definitions["factor"]) solve_lower_p, solve_lower_rev_p = _build_op( - b"celerite2_solve_lower", definitions["solve_lower"] + "celerite2_solve_lower", definitions["solve_lower"] ) solve_upper_p, solve_upper_rev_p = _build_op( - b"celerite2_solve_upper", definitions["solve_upper"] + "celerite2_solve_upper", definitions["solve_upper"] ) matmul_lower_p, matmul_lower_rev_p = _build_op( - b"celerite2_matmul_lower", definitions["matmul_lower"] + "celerite2_matmul_lower", definitions["matmul_lower"] ) matmul_upper_p, matmul_upper_rev_p = _build_op( - b"celerite2_matmul_upper", definitions["matmul_upper"] + "celerite2_matmul_upper", definitions["matmul_upper"] ) general_matmul_lower_p, _ = _build_op( - b"celerite2_general_matmul_lower", definitions["general_matmul_lower"] + "celerite2_general_matmul_lower", definitions["general_matmul_lower"] ) general_matmul_upper_p, _ = _build_op( - b"celerite2_general_matmul_upper", definitions["general_matmul_upper"] + "celerite2_general_matmul_upper", definitions["general_matmul_upper"] ) diff --git a/python/celerite2/jax/xla_ops.cpp b/python/celerite2/jax/xla_ops.cpp index a081633..73a2a55 100644 --- a/python/celerite2/jax/xla_ops.cpp +++ b/python/celerite2/jax/xla_ops.cpp @@ -2,566 +2,578 @@ // NOTE: Changes should be made to the template #include +#include +#include +#include +#include #include "../driver.hpp" namespace py = pybind11; using namespace celerite2::driver; -auto factor(void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int J = *reinterpret_cast(in[1]); - - const double *t = reinterpret_cast(in[2]); - const double *c = reinterpret_cast(in[3]); - const double *a = reinterpret_cast(in[4]); - const double *U = reinterpret_cast(in[5]); - const double *V = reinterpret_cast(in[6]); - double *d = reinterpret_cast(out[0]); - double *W = reinterpret_cast(out[1]); - double *S = reinterpret_cast(out[2]); - -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_(t, N, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map a_(a, N, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> V_(V, N, J); \ - Eigen::Map d_(d, N, 1); \ - Eigen::Map::value>> W_(W, N, J); \ - Eigen::Map::value>> S_(S, N, J *J); \ - Eigen::Index flag = celerite2::core::factor(t_, c_, a_, U_, V_, d_, W_, S_); \ - if (flag) d_.setZero(); \ - } - UNWRAP_CASES_MOST + +auto factor (void *out_tuple, const void **in) { + void **out = reinterpret_cast(out_tuple); + const int N = *reinterpret_cast(in[0]); + const int J = *reinterpret_cast(in[1]); + + const double *t = reinterpret_cast(in[2]); + const double *c = reinterpret_cast(in[3]); + const double *a = reinterpret_cast(in[4]); + const double *U = reinterpret_cast(in[5]); + const double *V = reinterpret_cast(in[6]); + double *d = reinterpret_cast(out[0]); + double *W = reinterpret_cast(out[1]); + double *S = reinterpret_cast(out[2]); + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t_(t, N, 1); \ + Eigen::Map> c_(c, J, 1); \ + Eigen::Map a_(a, N, 1); \ + Eigen::Map::value>> U_(U, N, J); \ + Eigen::Map::value>> V_(V, N, J); \ + Eigen::Map d_(d, N, 1); \ + Eigen::Map::value>> W_(W, N, J); \ + Eigen::Map::value>> S_(S, N, J * J); \ + Eigen::Index flag = celerite2::core::factor(t_, c_, a_, U_, V_, d_, W_, S_); \ + if (flag) d_.setZero(); \ + } + UNWRAP_CASES_MOST #undef FIXED_SIZE_MAP } -auto factor_rev(void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int J = *reinterpret_cast(in[1]); - - const double *t = reinterpret_cast(in[2]); - const double *c = reinterpret_cast(in[3]); - const double *a = reinterpret_cast(in[4]); - const double *U = reinterpret_cast(in[5]); - const double *V = reinterpret_cast(in[6]); - const double *d = reinterpret_cast(in[7]); - const double *W = reinterpret_cast(in[8]); - const double *S = reinterpret_cast(in[9]); - const double *bd = reinterpret_cast(in[10]); - const double *bW = reinterpret_cast(in[11]); - double *bt = reinterpret_cast(out[0]); - double *bc = reinterpret_cast(out[1]); - double *ba = reinterpret_cast(out[2]); - double *bU = reinterpret_cast(out[3]); - double *bV = reinterpret_cast(out[4]); - -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_(t, N, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map a_(a, N, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> V_(V, N, J); \ - Eigen::Map d_(d, N, 1); \ - Eigen::Map::value>> W_(W, N, J); \ - Eigen::Map::value>> S_(S, N, J *J); \ - Eigen::Map bd_(bd, N, 1); \ - Eigen::Map::value>> bW_(bW, N, J); \ - Eigen::Map bt_(bt, N, 1); \ - Eigen::Map> bc_(bc, J, 1); \ - Eigen::Map ba_(ba, N, 1); \ - Eigen::Map::value>> bU_(bU, N, J); \ - Eigen::Map::value>> bV_(bV, N, J); \ - celerite2::core::factor_rev(t_, c_, a_, U_, V_, d_, W_, S_, bd_, bW_, bt_, bc_, ba_, bU_, bV_); \ - } - UNWRAP_CASES_FEW +auto factor_rev (void *out_tuple, const void **in) { + void **out = reinterpret_cast(out_tuple); + const int N = *reinterpret_cast(in[0]); + const int J = *reinterpret_cast(in[1]); + + const double *t = reinterpret_cast(in[2]); + const double *c = reinterpret_cast(in[3]); + const double *a = reinterpret_cast(in[4]); + const double *U = reinterpret_cast(in[5]); + const double *V = reinterpret_cast(in[6]); + const double *d = reinterpret_cast(in[7]); + const double *W = reinterpret_cast(in[8]); + const double *S = reinterpret_cast(in[9]); + const double *bd = reinterpret_cast(in[10]); + const double *bW = reinterpret_cast(in[11]); + double *bt = reinterpret_cast(out[0]); + double *bc = reinterpret_cast(out[1]); + double *ba = reinterpret_cast(out[2]); + double *bU = reinterpret_cast(out[3]); + double *bV = reinterpret_cast(out[4]); + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t_(t, N, 1); \ + Eigen::Map> c_(c, J, 1); \ + Eigen::Map a_(a, N, 1); \ + Eigen::Map::value>> U_(U, N, J); \ + Eigen::Map::value>> V_(V, N, J); \ + Eigen::Map d_(d, N, 1); \ + Eigen::Map::value>> W_(W, N, J); \ + Eigen::Map::value>> S_(S, N, J * J); \ + Eigen::Map bd_(bd, N, 1); \ + Eigen::Map::value>> bW_(bW, N, J); \ + Eigen::Map bt_(bt, N, 1); \ + Eigen::Map> bc_(bc, J, 1); \ + Eigen::Map ba_(ba, N, 1); \ + Eigen::Map::value>> bU_(bU, N, J); \ + Eigen::Map::value>> bV_(bV, N, J); \ + celerite2::core::factor_rev(t_, c_, a_, U_, V_, d_, W_, S_, bd_, bW_, bt_, bc_, ba_, bU_, bV_); \ + } + UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP } -auto solve_lower(void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int J = *reinterpret_cast(in[1]); - const int nrhs = *reinterpret_cast(in[2]); - - const double *t = reinterpret_cast(in[3]); - const double *c = reinterpret_cast(in[4]); - const double *U = reinterpret_cast(in[5]); - const double *W = reinterpret_cast(in[6]); - const double *Y = reinterpret_cast(in[7]); - double *Z = reinterpret_cast(out[0]); - double *F = reinterpret_cast(out[1]); - -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_(t, N, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> W_(W, N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y, N, 1); \ - Eigen::Map Z_(Z, N, 1); \ - Eigen::Map::value>> F_(F, N, J); \ - Z_.setZero(); \ - celerite2::core::solve_lower(t_, c_, U_, W_, Y_, Z_, F_); \ - } else { \ - Eigen::Map> Y_(Y, N, nrhs); \ - Eigen::Map> Z_(Z, N, nrhs); \ - Eigen::Map> F_(F, N, J *nrhs); \ - Z_.setZero(); \ - celerite2::core::solve_lower(t_, c_, U_, W_, Y_, Z_, F_); \ - } \ - } - UNWRAP_CASES_MOST + +auto solve_lower (void *out_tuple, const void **in) { + void **out = reinterpret_cast(out_tuple); + const int N = *reinterpret_cast(in[0]); + const int J = *reinterpret_cast(in[1]); + const int nrhs = *reinterpret_cast(in[2]); + + const double *t = reinterpret_cast(in[3]); + const double *c = reinterpret_cast(in[4]); + const double *U = reinterpret_cast(in[5]); + const double *W = reinterpret_cast(in[6]); + const double *Y = reinterpret_cast(in[7]); + double *Z = reinterpret_cast(out[0]); + double *F = reinterpret_cast(out[1]); + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t_(t, N, 1); \ + Eigen::Map> c_(c, J, 1); \ + Eigen::Map::value>> U_(U, N, J); \ + Eigen::Map::value>> W_(W, N, J); \ + if (nrhs == 1) { \ + Eigen::Map Y_(Y, N, 1); \ + Eigen::Map Z_(Z, N, 1); \ + Eigen::Map::value>> F_(F, N, J); \ + Z_.setZero(); \ + celerite2::core::solve_lower(t_, c_, U_, W_, Y_, Z_, F_); \ + } else { \ + Eigen::Map> Y_(Y, N, nrhs); \ + Eigen::Map> Z_(Z, N, nrhs); \ + Eigen::Map> F_(F, N, J * nrhs); \ + Z_.setZero(); \ + celerite2::core::solve_lower(t_, c_, U_, W_, Y_, Z_, F_); \ + } \ + } + UNWRAP_CASES_MOST #undef FIXED_SIZE_MAP } -auto solve_lower_rev(void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int J = *reinterpret_cast(in[1]); - const int nrhs = *reinterpret_cast(in[2]); - - const double *t = reinterpret_cast(in[3]); - const double *c = reinterpret_cast(in[4]); - const double *U = reinterpret_cast(in[5]); - const double *W = reinterpret_cast(in[6]); - const double *Y = reinterpret_cast(in[7]); - const double *Z = reinterpret_cast(in[8]); - const double *F = reinterpret_cast(in[9]); - const double *bZ = reinterpret_cast(in[10]); - double *bt = reinterpret_cast(out[0]); - double *bc = reinterpret_cast(out[1]); - double *bU = reinterpret_cast(out[2]); - double *bW = reinterpret_cast(out[3]); - double *bY = reinterpret_cast(out[4]); - -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_(t, N, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> W_(W, N, J); \ - Eigen::Map bt_(bt, N, 1); \ - Eigen::Map> bc_(bc, J, 1); \ - Eigen::Map::value>> bU_(bU, N, J); \ - Eigen::Map::value>> bW_(bW, N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y, N, 1); \ - Eigen::Map Z_(Z, N, 1); \ - Eigen::Map::value>> F_(F, N, J); \ - Eigen::Map bZ_(bZ, N, 1); \ - Eigen::Map bY_(bY, N, 1); \ - celerite2::core::solve_lower_rev(t_, c_, U_, W_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bW_, bY_); \ - } else { \ - Eigen::Map> Y_(Y, N, nrhs); \ - Eigen::Map> Z_(Z, N, nrhs); \ - Eigen::Map> F_(F, N, J *nrhs); \ - Eigen::Map> bZ_(bZ, N, nrhs); \ - Eigen::Map> bY_(bY, N, nrhs); \ - celerite2::core::solve_lower_rev(t_, c_, U_, W_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bW_, bY_); \ - } \ - } - UNWRAP_CASES_FEW +auto solve_lower_rev (void *out_tuple, const void **in) { + void **out = reinterpret_cast(out_tuple); + const int N = *reinterpret_cast(in[0]); + const int J = *reinterpret_cast(in[1]); + const int nrhs = *reinterpret_cast(in[2]); + + const double *t = reinterpret_cast(in[3]); + const double *c = reinterpret_cast(in[4]); + const double *U = reinterpret_cast(in[5]); + const double *W = reinterpret_cast(in[6]); + const double *Y = reinterpret_cast(in[7]); + const double *Z = reinterpret_cast(in[8]); + const double *F = reinterpret_cast(in[9]); + const double *bZ = reinterpret_cast(in[10]); + double *bt = reinterpret_cast(out[0]); + double *bc = reinterpret_cast(out[1]); + double *bU = reinterpret_cast(out[2]); + double *bW = reinterpret_cast(out[3]); + double *bY = reinterpret_cast(out[4]); + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t_(t, N, 1); \ + Eigen::Map> c_(c, J, 1); \ + Eigen::Map::value>> U_(U, N, J); \ + Eigen::Map::value>> W_(W, N, J); \ + Eigen::Map bt_(bt, N, 1); \ + Eigen::Map> bc_(bc, J, 1); \ + Eigen::Map::value>> bU_(bU, N, J); \ + Eigen::Map::value>> bW_(bW, N, J); \ + if (nrhs == 1) { \ + Eigen::Map Y_(Y, N, 1); \ + Eigen::Map Z_(Z, N, 1); \ + Eigen::Map::value>> F_(F, N, J); \ + Eigen::Map bZ_(bZ, N, 1); \ + Eigen::Map bY_(bY, N, 1); \ + celerite2::core::solve_lower_rev(t_, c_, U_, W_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bW_, bY_); \ + } else { \ + Eigen::Map> Y_(Y, N, nrhs); \ + Eigen::Map> Z_(Z, N, nrhs); \ + Eigen::Map> F_(F, N, J * nrhs); \ + Eigen::Map> bZ_(bZ, N, nrhs); \ + Eigen::Map> bY_(bY, N, nrhs); \ + celerite2::core::solve_lower_rev(t_, c_, U_, W_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bW_, bY_); \ + } \ + } + UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP } -auto solve_upper(void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int J = *reinterpret_cast(in[1]); - const int nrhs = *reinterpret_cast(in[2]); - - const double *t = reinterpret_cast(in[3]); - const double *c = reinterpret_cast(in[4]); - const double *U = reinterpret_cast(in[5]); - const double *W = reinterpret_cast(in[6]); - const double *Y = reinterpret_cast(in[7]); - double *Z = reinterpret_cast(out[0]); - double *F = reinterpret_cast(out[1]); - -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_(t, N, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> W_(W, N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y, N, 1); \ - Eigen::Map Z_(Z, N, 1); \ - Eigen::Map::value>> F_(F, N, J); \ - Z_.setZero(); \ - celerite2::core::solve_upper(t_, c_, U_, W_, Y_, Z_, F_); \ - } else { \ - Eigen::Map> Y_(Y, N, nrhs); \ - Eigen::Map> Z_(Z, N, nrhs); \ - Eigen::Map> F_(F, N, J *nrhs); \ - Z_.setZero(); \ - celerite2::core::solve_upper(t_, c_, U_, W_, Y_, Z_, F_); \ - } \ - } - UNWRAP_CASES_MOST + +auto solve_upper (void *out_tuple, const void **in) { + void **out = reinterpret_cast(out_tuple); + const int N = *reinterpret_cast(in[0]); + const int J = *reinterpret_cast(in[1]); + const int nrhs = *reinterpret_cast(in[2]); + + const double *t = reinterpret_cast(in[3]); + const double *c = reinterpret_cast(in[4]); + const double *U = reinterpret_cast(in[5]); + const double *W = reinterpret_cast(in[6]); + const double *Y = reinterpret_cast(in[7]); + double *Z = reinterpret_cast(out[0]); + double *F = reinterpret_cast(out[1]); + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t_(t, N, 1); \ + Eigen::Map> c_(c, J, 1); \ + Eigen::Map::value>> U_(U, N, J); \ + Eigen::Map::value>> W_(W, N, J); \ + if (nrhs == 1) { \ + Eigen::Map Y_(Y, N, 1); \ + Eigen::Map Z_(Z, N, 1); \ + Eigen::Map::value>> F_(F, N, J); \ + Z_.setZero(); \ + celerite2::core::solve_upper(t_, c_, U_, W_, Y_, Z_, F_); \ + } else { \ + Eigen::Map> Y_(Y, N, nrhs); \ + Eigen::Map> Z_(Z, N, nrhs); \ + Eigen::Map> F_(F, N, J * nrhs); \ + Z_.setZero(); \ + celerite2::core::solve_upper(t_, c_, U_, W_, Y_, Z_, F_); \ + } \ + } + UNWRAP_CASES_MOST #undef FIXED_SIZE_MAP } -auto solve_upper_rev(void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int J = *reinterpret_cast(in[1]); - const int nrhs = *reinterpret_cast(in[2]); - - const double *t = reinterpret_cast(in[3]); - const double *c = reinterpret_cast(in[4]); - const double *U = reinterpret_cast(in[5]); - const double *W = reinterpret_cast(in[6]); - const double *Y = reinterpret_cast(in[7]); - const double *Z = reinterpret_cast(in[8]); - const double *F = reinterpret_cast(in[9]); - const double *bZ = reinterpret_cast(in[10]); - double *bt = reinterpret_cast(out[0]); - double *bc = reinterpret_cast(out[1]); - double *bU = reinterpret_cast(out[2]); - double *bW = reinterpret_cast(out[3]); - double *bY = reinterpret_cast(out[4]); - -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_(t, N, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> W_(W, N, J); \ - Eigen::Map bt_(bt, N, 1); \ - Eigen::Map> bc_(bc, J, 1); \ - Eigen::Map::value>> bU_(bU, N, J); \ - Eigen::Map::value>> bW_(bW, N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y, N, 1); \ - Eigen::Map Z_(Z, N, 1); \ - Eigen::Map::value>> F_(F, N, J); \ - Eigen::Map bZ_(bZ, N, 1); \ - Eigen::Map bY_(bY, N, 1); \ - celerite2::core::solve_upper_rev(t_, c_, U_, W_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bW_, bY_); \ - } else { \ - Eigen::Map> Y_(Y, N, nrhs); \ - Eigen::Map> Z_(Z, N, nrhs); \ - Eigen::Map> F_(F, N, J *nrhs); \ - Eigen::Map> bZ_(bZ, N, nrhs); \ - Eigen::Map> bY_(bY, N, nrhs); \ - celerite2::core::solve_upper_rev(t_, c_, U_, W_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bW_, bY_); \ - } \ - } - UNWRAP_CASES_FEW +auto solve_upper_rev (void *out_tuple, const void **in) { + void **out = reinterpret_cast(out_tuple); + const int N = *reinterpret_cast(in[0]); + const int J = *reinterpret_cast(in[1]); + const int nrhs = *reinterpret_cast(in[2]); + + const double *t = reinterpret_cast(in[3]); + const double *c = reinterpret_cast(in[4]); + const double *U = reinterpret_cast(in[5]); + const double *W = reinterpret_cast(in[6]); + const double *Y = reinterpret_cast(in[7]); + const double *Z = reinterpret_cast(in[8]); + const double *F = reinterpret_cast(in[9]); + const double *bZ = reinterpret_cast(in[10]); + double *bt = reinterpret_cast(out[0]); + double *bc = reinterpret_cast(out[1]); + double *bU = reinterpret_cast(out[2]); + double *bW = reinterpret_cast(out[3]); + double *bY = reinterpret_cast(out[4]); + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t_(t, N, 1); \ + Eigen::Map> c_(c, J, 1); \ + Eigen::Map::value>> U_(U, N, J); \ + Eigen::Map::value>> W_(W, N, J); \ + Eigen::Map bt_(bt, N, 1); \ + Eigen::Map> bc_(bc, J, 1); \ + Eigen::Map::value>> bU_(bU, N, J); \ + Eigen::Map::value>> bW_(bW, N, J); \ + if (nrhs == 1) { \ + Eigen::Map Y_(Y, N, 1); \ + Eigen::Map Z_(Z, N, 1); \ + Eigen::Map::value>> F_(F, N, J); \ + Eigen::Map bZ_(bZ, N, 1); \ + Eigen::Map bY_(bY, N, 1); \ + celerite2::core::solve_upper_rev(t_, c_, U_, W_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bW_, bY_); \ + } else { \ + Eigen::Map> Y_(Y, N, nrhs); \ + Eigen::Map> Z_(Z, N, nrhs); \ + Eigen::Map> F_(F, N, J * nrhs); \ + Eigen::Map> bZ_(bZ, N, nrhs); \ + Eigen::Map> bY_(bY, N, nrhs); \ + celerite2::core::solve_upper_rev(t_, c_, U_, W_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bW_, bY_); \ + } \ + } + UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP } -auto matmul_lower(void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int J = *reinterpret_cast(in[1]); - const int nrhs = *reinterpret_cast(in[2]); - - const double *t = reinterpret_cast(in[3]); - const double *c = reinterpret_cast(in[4]); - const double *U = reinterpret_cast(in[5]); - const double *V = reinterpret_cast(in[6]); - const double *Y = reinterpret_cast(in[7]); - double *Z = reinterpret_cast(out[0]); - double *F = reinterpret_cast(out[1]); - -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_(t, N, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> V_(V, N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y, N, 1); \ - Eigen::Map Z_(Z, N, 1); \ - Eigen::Map::value>> F_(F, N, J); \ - Z_.setZero(); \ - celerite2::core::matmul_lower(t_, c_, U_, V_, Y_, Z_, F_); \ - } else { \ - Eigen::Map> Y_(Y, N, nrhs); \ - Eigen::Map> Z_(Z, N, nrhs); \ - Eigen::Map> F_(F, N, J *nrhs); \ - Z_.setZero(); \ - celerite2::core::matmul_lower(t_, c_, U_, V_, Y_, Z_, F_); \ - } \ - } - UNWRAP_CASES_MOST + +auto matmul_lower (void *out_tuple, const void **in) { + void **out = reinterpret_cast(out_tuple); + const int N = *reinterpret_cast(in[0]); + const int J = *reinterpret_cast(in[1]); + const int nrhs = *reinterpret_cast(in[2]); + + const double *t = reinterpret_cast(in[3]); + const double *c = reinterpret_cast(in[4]); + const double *U = reinterpret_cast(in[5]); + const double *V = reinterpret_cast(in[6]); + const double *Y = reinterpret_cast(in[7]); + double *Z = reinterpret_cast(out[0]); + double *F = reinterpret_cast(out[1]); + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t_(t, N, 1); \ + Eigen::Map> c_(c, J, 1); \ + Eigen::Map::value>> U_(U, N, J); \ + Eigen::Map::value>> V_(V, N, J); \ + if (nrhs == 1) { \ + Eigen::Map Y_(Y, N, 1); \ + Eigen::Map Z_(Z, N, 1); \ + Eigen::Map::value>> F_(F, N, J); \ + Z_.setZero(); \ + celerite2::core::matmul_lower(t_, c_, U_, V_, Y_, Z_, F_); \ + } else { \ + Eigen::Map> Y_(Y, N, nrhs); \ + Eigen::Map> Z_(Z, N, nrhs); \ + Eigen::Map> F_(F, N, J * nrhs); \ + Z_.setZero(); \ + celerite2::core::matmul_lower(t_, c_, U_, V_, Y_, Z_, F_); \ + } \ + } + UNWRAP_CASES_MOST #undef FIXED_SIZE_MAP } -auto matmul_lower_rev(void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int J = *reinterpret_cast(in[1]); - const int nrhs = *reinterpret_cast(in[2]); - - const double *t = reinterpret_cast(in[3]); - const double *c = reinterpret_cast(in[4]); - const double *U = reinterpret_cast(in[5]); - const double *V = reinterpret_cast(in[6]); - const double *Y = reinterpret_cast(in[7]); - const double *Z = reinterpret_cast(in[8]); - const double *F = reinterpret_cast(in[9]); - const double *bZ = reinterpret_cast(in[10]); - double *bt = reinterpret_cast(out[0]); - double *bc = reinterpret_cast(out[1]); - double *bU = reinterpret_cast(out[2]); - double *bV = reinterpret_cast(out[3]); - double *bY = reinterpret_cast(out[4]); - -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_(t, N, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> V_(V, N, J); \ - Eigen::Map bt_(bt, N, 1); \ - Eigen::Map> bc_(bc, J, 1); \ - Eigen::Map::value>> bU_(bU, N, J); \ - Eigen::Map::value>> bV_(bV, N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y, N, 1); \ - Eigen::Map Z_(Z, N, 1); \ - Eigen::Map::value>> F_(F, N, J); \ - Eigen::Map bZ_(bZ, N, 1); \ - Eigen::Map bY_(bY, N, 1); \ - celerite2::core::matmul_lower_rev(t_, c_, U_, V_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bV_, bY_); \ - } else { \ - Eigen::Map> Y_(Y, N, nrhs); \ - Eigen::Map> Z_(Z, N, nrhs); \ - Eigen::Map> F_(F, N, J *nrhs); \ - Eigen::Map> bZ_(bZ, N, nrhs); \ - Eigen::Map> bY_(bY, N, nrhs); \ - celerite2::core::matmul_lower_rev(t_, c_, U_, V_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bV_, bY_); \ - } \ - } - UNWRAP_CASES_FEW +auto matmul_lower_rev (void *out_tuple, const void **in) { + void **out = reinterpret_cast(out_tuple); + const int N = *reinterpret_cast(in[0]); + const int J = *reinterpret_cast(in[1]); + const int nrhs = *reinterpret_cast(in[2]); + + const double *t = reinterpret_cast(in[3]); + const double *c = reinterpret_cast(in[4]); + const double *U = reinterpret_cast(in[5]); + const double *V = reinterpret_cast(in[6]); + const double *Y = reinterpret_cast(in[7]); + const double *Z = reinterpret_cast(in[8]); + const double *F = reinterpret_cast(in[9]); + const double *bZ = reinterpret_cast(in[10]); + double *bt = reinterpret_cast(out[0]); + double *bc = reinterpret_cast(out[1]); + double *bU = reinterpret_cast(out[2]); + double *bV = reinterpret_cast(out[3]); + double *bY = reinterpret_cast(out[4]); + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t_(t, N, 1); \ + Eigen::Map> c_(c, J, 1); \ + Eigen::Map::value>> U_(U, N, J); \ + Eigen::Map::value>> V_(V, N, J); \ + Eigen::Map bt_(bt, N, 1); \ + Eigen::Map> bc_(bc, J, 1); \ + Eigen::Map::value>> bU_(bU, N, J); \ + Eigen::Map::value>> bV_(bV, N, J); \ + if (nrhs == 1) { \ + Eigen::Map Y_(Y, N, 1); \ + Eigen::Map Z_(Z, N, 1); \ + Eigen::Map::value>> F_(F, N, J); \ + Eigen::Map bZ_(bZ, N, 1); \ + Eigen::Map bY_(bY, N, 1); \ + celerite2::core::matmul_lower_rev(t_, c_, U_, V_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bV_, bY_); \ + } else { \ + Eigen::Map> Y_(Y, N, nrhs); \ + Eigen::Map> Z_(Z, N, nrhs); \ + Eigen::Map> F_(F, N, J * nrhs); \ + Eigen::Map> bZ_(bZ, N, nrhs); \ + Eigen::Map> bY_(bY, N, nrhs); \ + celerite2::core::matmul_lower_rev(t_, c_, U_, V_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bV_, bY_); \ + } \ + } + UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP } -auto matmul_upper(void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int J = *reinterpret_cast(in[1]); - const int nrhs = *reinterpret_cast(in[2]); - - const double *t = reinterpret_cast(in[3]); - const double *c = reinterpret_cast(in[4]); - const double *U = reinterpret_cast(in[5]); - const double *V = reinterpret_cast(in[6]); - const double *Y = reinterpret_cast(in[7]); - double *Z = reinterpret_cast(out[0]); - double *F = reinterpret_cast(out[1]); - -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_(t, N, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> V_(V, N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y, N, 1); \ - Eigen::Map Z_(Z, N, 1); \ - Eigen::Map::value>> F_(F, N, J); \ - Z_.setZero(); \ - celerite2::core::matmul_upper(t_, c_, U_, V_, Y_, Z_, F_); \ - } else { \ - Eigen::Map> Y_(Y, N, nrhs); \ - Eigen::Map> Z_(Z, N, nrhs); \ - Eigen::Map> F_(F, N, J *nrhs); \ - Z_.setZero(); \ - celerite2::core::matmul_upper(t_, c_, U_, V_, Y_, Z_, F_); \ - } \ - } - UNWRAP_CASES_MOST + +auto matmul_upper (void *out_tuple, const void **in) { + void **out = reinterpret_cast(out_tuple); + const int N = *reinterpret_cast(in[0]); + const int J = *reinterpret_cast(in[1]); + const int nrhs = *reinterpret_cast(in[2]); + + const double *t = reinterpret_cast(in[3]); + const double *c = reinterpret_cast(in[4]); + const double *U = reinterpret_cast(in[5]); + const double *V = reinterpret_cast(in[6]); + const double *Y = reinterpret_cast(in[7]); + double *Z = reinterpret_cast(out[0]); + double *F = reinterpret_cast(out[1]); + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t_(t, N, 1); \ + Eigen::Map> c_(c, J, 1); \ + Eigen::Map::value>> U_(U, N, J); \ + Eigen::Map::value>> V_(V, N, J); \ + if (nrhs == 1) { \ + Eigen::Map Y_(Y, N, 1); \ + Eigen::Map Z_(Z, N, 1); \ + Eigen::Map::value>> F_(F, N, J); \ + Z_.setZero(); \ + celerite2::core::matmul_upper(t_, c_, U_, V_, Y_, Z_, F_); \ + } else { \ + Eigen::Map> Y_(Y, N, nrhs); \ + Eigen::Map> Z_(Z, N, nrhs); \ + Eigen::Map> F_(F, N, J * nrhs); \ + Z_.setZero(); \ + celerite2::core::matmul_upper(t_, c_, U_, V_, Y_, Z_, F_); \ + } \ + } + UNWRAP_CASES_MOST #undef FIXED_SIZE_MAP } -auto matmul_upper_rev(void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int J = *reinterpret_cast(in[1]); - const int nrhs = *reinterpret_cast(in[2]); - - const double *t = reinterpret_cast(in[3]); - const double *c = reinterpret_cast(in[4]); - const double *U = reinterpret_cast(in[5]); - const double *V = reinterpret_cast(in[6]); - const double *Y = reinterpret_cast(in[7]); - const double *Z = reinterpret_cast(in[8]); - const double *F = reinterpret_cast(in[9]); - const double *bZ = reinterpret_cast(in[10]); - double *bt = reinterpret_cast(out[0]); - double *bc = reinterpret_cast(out[1]); - double *bU = reinterpret_cast(out[2]); - double *bV = reinterpret_cast(out[3]); - double *bY = reinterpret_cast(out[4]); - -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_(t, N, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> V_(V, N, J); \ - Eigen::Map bt_(bt, N, 1); \ - Eigen::Map> bc_(bc, J, 1); \ - Eigen::Map::value>> bU_(bU, N, J); \ - Eigen::Map::value>> bV_(bV, N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y, N, 1); \ - Eigen::Map Z_(Z, N, 1); \ - Eigen::Map::value>> F_(F, N, J); \ - Eigen::Map bZ_(bZ, N, 1); \ - Eigen::Map bY_(bY, N, 1); \ - celerite2::core::matmul_upper_rev(t_, c_, U_, V_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bV_, bY_); \ - } else { \ - Eigen::Map> Y_(Y, N, nrhs); \ - Eigen::Map> Z_(Z, N, nrhs); \ - Eigen::Map> F_(F, N, J *nrhs); \ - Eigen::Map> bZ_(bZ, N, nrhs); \ - Eigen::Map> bY_(bY, N, nrhs); \ - celerite2::core::matmul_upper_rev(t_, c_, U_, V_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bV_, bY_); \ - } \ - } - UNWRAP_CASES_FEW +auto matmul_upper_rev (void *out_tuple, const void **in) { + void **out = reinterpret_cast(out_tuple); + const int N = *reinterpret_cast(in[0]); + const int J = *reinterpret_cast(in[1]); + const int nrhs = *reinterpret_cast(in[2]); + + const double *t = reinterpret_cast(in[3]); + const double *c = reinterpret_cast(in[4]); + const double *U = reinterpret_cast(in[5]); + const double *V = reinterpret_cast(in[6]); + const double *Y = reinterpret_cast(in[7]); + const double *Z = reinterpret_cast(in[8]); + const double *F = reinterpret_cast(in[9]); + const double *bZ = reinterpret_cast(in[10]); + double *bt = reinterpret_cast(out[0]); + double *bc = reinterpret_cast(out[1]); + double *bU = reinterpret_cast(out[2]); + double *bV = reinterpret_cast(out[3]); + double *bY = reinterpret_cast(out[4]); + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t_(t, N, 1); \ + Eigen::Map> c_(c, J, 1); \ + Eigen::Map::value>> U_(U, N, J); \ + Eigen::Map::value>> V_(V, N, J); \ + Eigen::Map bt_(bt, N, 1); \ + Eigen::Map> bc_(bc, J, 1); \ + Eigen::Map::value>> bU_(bU, N, J); \ + Eigen::Map::value>> bV_(bV, N, J); \ + if (nrhs == 1) { \ + Eigen::Map Y_(Y, N, 1); \ + Eigen::Map Z_(Z, N, 1); \ + Eigen::Map::value>> F_(F, N, J); \ + Eigen::Map bZ_(bZ, N, 1); \ + Eigen::Map bY_(bY, N, 1); \ + celerite2::core::matmul_upper_rev(t_, c_, U_, V_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bV_, bY_); \ + } else { \ + Eigen::Map> Y_(Y, N, nrhs); \ + Eigen::Map> Z_(Z, N, nrhs); \ + Eigen::Map> F_(F, N, J * nrhs); \ + Eigen::Map> bZ_(bZ, N, nrhs); \ + Eigen::Map> bY_(bY, N, nrhs); \ + celerite2::core::matmul_upper_rev(t_, c_, U_, V_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bV_, bY_); \ + } \ + } + UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP } -auto general_matmul_lower(void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int M = *reinterpret_cast(in[1]); - const int J = *reinterpret_cast(in[2]); - const int nrhs = *reinterpret_cast(in[3]); - - const double *t1 = reinterpret_cast(in[4]); - const double *t2 = reinterpret_cast(in[5]); - const double *c = reinterpret_cast(in[6]); - const double *U = reinterpret_cast(in[7]); - const double *V = reinterpret_cast(in[8]); - const double *Y = reinterpret_cast(in[9]); - double *Z = reinterpret_cast(out[0]); - double *F = reinterpret_cast(out[1]); - -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t1_(t1, N, 1); \ - Eigen::Map t2_(t2, M, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> V_(V, M, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y, M, 1); \ - Eigen::Map Z_(Z, N, 1); \ - Eigen::Map::value>> F_(F, M, J); \ - Z_.setZero(); \ - F_.setZero(); \ - celerite2::core::general_matmul_lower(t1_, t2_, c_, U_, V_, Y_, Z_, F_); \ - } else { \ - Eigen::Map> Y_(Y, M, nrhs); \ - Eigen::Map> Z_(Z, N, nrhs); \ - Eigen::Map> F_(F, M, J *nrhs); \ - Z_.setZero(); \ - F_.setZero(); \ - celerite2::core::general_matmul_lower(t1_, t2_, c_, U_, V_, Y_, Z_, F_); \ - } \ - } - UNWRAP_CASES_MOST + +auto general_matmul_lower (void *out_tuple, const void **in) { + void **out = reinterpret_cast(out_tuple); + const int N = *reinterpret_cast(in[0]); + const int M = *reinterpret_cast(in[1]); + const int J = *reinterpret_cast(in[2]); + const int nrhs = *reinterpret_cast(in[3]); + + const double *t1 = reinterpret_cast(in[4]); + const double *t2 = reinterpret_cast(in[5]); + const double *c = reinterpret_cast(in[6]); + const double *U = reinterpret_cast(in[7]); + const double *V = reinterpret_cast(in[8]); + const double *Y = reinterpret_cast(in[9]); + double *Z = reinterpret_cast(out[0]); + double *F = reinterpret_cast(out[1]); + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t1_(t1, N, 1); \ + Eigen::Map t2_(t2, M, 1); \ + Eigen::Map> c_(c, J, 1); \ + Eigen::Map::value>> U_(U, N, J); \ + Eigen::Map::value>> V_(V, M, J); \ + if (nrhs == 1) { \ + Eigen::Map Y_(Y, M, 1); \ + Eigen::Map Z_(Z, N, 1); \ + Eigen::Map::value>> F_(F, M, J); \ + Z_.setZero(); \ + celerite2::core::general_matmul_lower(t1_, t2_, c_, U_, V_, Y_, Z_, F_); \ + } else { \ + Eigen::Map> Y_(Y, M, nrhs); \ + Eigen::Map> Z_(Z, N, nrhs); \ + Eigen::Map> F_(F, M, J * nrhs); \ + Z_.setZero(); \ + celerite2::core::general_matmul_lower(t1_, t2_, c_, U_, V_, Y_, Z_, F_); \ + } \ + } + UNWRAP_CASES_MOST #undef FIXED_SIZE_MAP } -auto general_matmul_upper(void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int M = *reinterpret_cast(in[1]); - const int J = *reinterpret_cast(in[2]); - const int nrhs = *reinterpret_cast(in[3]); - - const double *t1 = reinterpret_cast(in[4]); - const double *t2 = reinterpret_cast(in[5]); - const double *c = reinterpret_cast(in[6]); - const double *U = reinterpret_cast(in[7]); - const double *V = reinterpret_cast(in[8]); - const double *Y = reinterpret_cast(in[9]); - double *Z = reinterpret_cast(out[0]); - double *F = reinterpret_cast(out[1]); - -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t1_(t1, N, 1); \ - Eigen::Map t2_(t2, M, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> V_(V, M, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y, M, 1); \ - Eigen::Map Z_(Z, N, 1); \ - Eigen::Map::value>> F_(F, M, J); \ - Z_.setZero(); \ - F_.setZero(); \ - celerite2::core::general_matmul_upper(t1_, t2_, c_, U_, V_, Y_, Z_, F_); \ - } else { \ - Eigen::Map> Y_(Y, M, nrhs); \ - Eigen::Map> Z_(Z, N, nrhs); \ - Eigen::Map> F_(F, M, J *nrhs); \ - Z_.setZero(); \ - F_.setZero(); \ - celerite2::core::general_matmul_upper(t1_, t2_, c_, U_, V_, Y_, Z_, F_); \ - } \ - } - UNWRAP_CASES_MOST +auto general_matmul_upper (void *out_tuple, const void **in) { + void **out = reinterpret_cast(out_tuple); + const int N = *reinterpret_cast(in[0]); + const int M = *reinterpret_cast(in[1]); + const int J = *reinterpret_cast(in[2]); + const int nrhs = *reinterpret_cast(in[3]); + + const double *t1 = reinterpret_cast(in[4]); + const double *t2 = reinterpret_cast(in[5]); + const double *c = reinterpret_cast(in[6]); + const double *U = reinterpret_cast(in[7]); + const double *V = reinterpret_cast(in[8]); + const double *Y = reinterpret_cast(in[9]); + double *Z = reinterpret_cast(out[0]); + double *F = reinterpret_cast(out[1]); + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t1_(t1, N, 1); \ + Eigen::Map t2_(t2, M, 1); \ + Eigen::Map> c_(c, J, 1); \ + Eigen::Map::value>> U_(U, N, J); \ + Eigen::Map::value>> V_(V, M, J); \ + if (nrhs == 1) { \ + Eigen::Map Y_(Y, M, 1); \ + Eigen::Map Z_(Z, N, 1); \ + Eigen::Map::value>> F_(F, M, J); \ + Z_.setZero(); \ + celerite2::core::general_matmul_upper(t1_, t2_, c_, U_, V_, Y_, Z_, F_); \ + } else { \ + Eigen::Map> Y_(Y, M, nrhs); \ + Eigen::Map> Z_(Z, N, nrhs); \ + Eigen::Map> F_(F, M, J * nrhs); \ + Z_.setZero(); \ + celerite2::core::general_matmul_upper(t1_, t2_, c_, U_, V_, Y_, Z_, F_); \ + } \ + } + UNWRAP_CASES_MOST #undef FIXED_SIZE_MAP } + +// https://en.cppreference.com/w/cpp/numeric/bit_cast +template +typename std::enable_if::value && std::is_trivially_copyable::value, To>::type +bit_cast(const From &src) noexcept { + static_assert(std::is_trivially_constructible::value, + "This implementation additionally requires destination type to be trivially constructible"); + + To dst; + memcpy(&dst, &src, sizeof(To)); + return dst; +} + +template +py::capsule encapsulate_function(T* fn) { + return py::capsule(bit_cast(fn), "xla._CUSTOM_CALL_TARGET"); +} + PYBIND11_MODULE(xla_ops, m) { - m.def("factor", []() { - const char *name = "xla._CUSTOM_CALL_TARGET"; - return py::capsule((void *)&factor, name); - }); - m.def("factor_rev", []() { - const char *name = "xla._CUSTOM_CALL_TARGET"; - return py::capsule((void *)&factor_rev, name); - }); - m.def("solve_lower", []() { - const char *name = "xla._CUSTOM_CALL_TARGET"; - return py::capsule((void *)&solve_lower, name); - }); - m.def("solve_lower_rev", []() { - const char *name = "xla._CUSTOM_CALL_TARGET"; - return py::capsule((void *)&solve_lower_rev, name); - }); - m.def("solve_upper", []() { - const char *name = "xla._CUSTOM_CALL_TARGET"; - return py::capsule((void *)&solve_upper, name); - }); - m.def("solve_upper_rev", []() { - const char *name = "xla._CUSTOM_CALL_TARGET"; - return py::capsule((void *)&solve_upper_rev, name); - }); - m.def("matmul_lower", []() { - const char *name = "xla._CUSTOM_CALL_TARGET"; - return py::capsule((void *)&matmul_lower, name); - }); - m.def("matmul_lower_rev", []() { - const char *name = "xla._CUSTOM_CALL_TARGET"; - return py::capsule((void *)&matmul_lower_rev, name); - }); - m.def("matmul_upper", []() { - const char *name = "xla._CUSTOM_CALL_TARGET"; - return py::capsule((void *)&matmul_upper, name); - }); - m.def("matmul_upper_rev", []() { - const char *name = "xla._CUSTOM_CALL_TARGET"; - return py::capsule((void *)&matmul_upper_rev, name); - }); - m.def("general_matmul_lower", []() { - const char *name = "xla._CUSTOM_CALL_TARGET"; - return py::capsule((void *)&general_matmul_lower, name); - }); - m.def("general_matmul_upper", []() { - const char *name = "xla._CUSTOM_CALL_TARGET"; - return py::capsule((void *)&general_matmul_upper, name); - }); + m.def("factor", []() { + return encapsulate_function(factor); + }); + m.def("factor_rev", []() { + return encapsulate_function(factor_rev); + }); + m.def("solve_lower", []() { + return encapsulate_function(solve_lower); + }); + m.def("solve_lower_rev", []() { + return encapsulate_function(solve_lower_rev); + }); + m.def("solve_upper", []() { + return encapsulate_function(solve_upper); + }); + m.def("solve_upper_rev", []() { + return encapsulate_function(solve_upper_rev); + }); + m.def("matmul_lower", []() { + return encapsulate_function(matmul_lower); + }); + m.def("matmul_lower_rev", []() { + return encapsulate_function(matmul_lower_rev); + }); + m.def("matmul_upper", []() { + return encapsulate_function(matmul_upper); + }); + m.def("matmul_upper_rev", []() { + return encapsulate_function(matmul_upper_rev); + }); + m.def("general_matmul_lower", []() { + return encapsulate_function(general_matmul_lower); + }); + m.def("general_matmul_upper", []() { + return encapsulate_function(general_matmul_upper); + }); } diff --git a/python/celerite2/pymc/celerite2.py b/python/celerite2/pymc/celerite2.py index 13386d9..25f891f 100644 --- a/python/celerite2/pymc/celerite2.py +++ b/python/celerite2/pymc/celerite2.py @@ -102,7 +102,7 @@ def marginal(self, name, **kwargs): self._U, self._W, self._d, - **kwargs + **kwargs, ) def conditional( diff --git a/python/spec/generate.py b/python/spec/generate.py index 70ad71f..14323ad 100644 --- a/python/spec/generate.py +++ b/python/spec/generate.py @@ -5,6 +5,7 @@ import os from pathlib import Path +import pkg_resources from jinja2 import Environment, FileSystemLoader, select_autoescape base = Path(os.path.dirname(os.path.abspath(__file__))) diff --git a/python/spec/templates/jax/xla_ops.cpp b/python/spec/templates/jax/xla_ops.cpp index 97f1485..bd22248 100644 --- a/python/spec/templates/jax/xla_ops.cpp +++ b/python/spec/templates/jax/xla_ops.cpp @@ -1,4 +1,8 @@ #include +#include +#include +#include +#include #include "../driver.hpp" namespace py = pybind11; @@ -133,16 +137,32 @@ auto {{mod.name}}_rev (void *out_tuple, const void **in) { } {% endif %} {% endfor %} + +// https://en.cppreference.com/w/cpp/numeric/bit_cast +template +typename std::enable_if::value && std::is_trivially_copyable::value, To>::type +bit_cast(const From &src) noexcept { + static_assert(std::is_trivially_constructible::value, + "This implementation additionally requires destination type to be trivially constructible"); + + To dst; + memcpy(&dst, &src, sizeof(To)); + return dst; +} + +template +py::capsule encapsulate_function(T* fn) { + return py::capsule(bit_cast(fn), "xla._CUSTOM_CALL_TARGET"); +} + PYBIND11_MODULE(xla_ops, m) { {%- for mod in spec %} m.def("{{mod.name}}", []() { - const char *name = "xla._CUSTOM_CALL_TARGET"; - return py::capsule((void *)&{{mod.name}}, name); + return encapsulate_function({{mod.name}}); }); {%- if mod.has_rev %} m.def("{{mod.name}}_rev", []() { - const char *name = "xla._CUSTOM_CALL_TARGET"; - return py::capsule((void *)&{{mod.name}}_rev, name); + return encapsulate_function({{mod.name}}_rev); }); {%- endif %} {%- endfor %} diff --git a/python/test/pymc/test_pymc_ops.py b/python/test/pymc/test_pymc_ops.py index 5f1b834..1c077d7 100644 --- a/python/test/pymc/test_pymc_ops.py +++ b/python/test/pymc/test_pymc_ops.py @@ -234,6 +234,7 @@ def test_general_matmul_lower_fwd(): ) +@pytest.mark.xfail(reason="Numerically unstable") def test_general_matmul_upper_fwd(): x, c, a, U, V, Y, t, U2, V2 = get_matrices(conditional=True) check_basic( @@ -266,9 +267,6 @@ def compare_jax_and_py( if len(fgraph.outputs) > 1: for j, p in zip(jax_res, py_res): - print(np.min(j), np.max(j), np.any(np.isnan(j))) - print(np.min(p), np.max(p), np.any(np.isnan(p))) - assert_fn(j, p) else: assert_fn(jax_res, py_res) diff --git a/setup.py b/setup.py deleted file mode 100644 index e158675..0000000 --- a/setup.py +++ /dev/null @@ -1,144 +0,0 @@ -#!/usr/bin/env python - -# Inspired by: -# https://hynek.me/articles/sharing-your-labor-of-love-pypi-quick-and-dirty/ -import codecs -import os -import re - -from pybind11.setup_helpers import Pybind11Extension, build_ext -from setuptools import find_packages, setup - -# PROJECT SPECIFIC - -NAME = "celerite2" -PACKAGES = find_packages(where="python") -META_PATH = os.path.join("python", "celerite2", "__init__.py") -CLASSIFIERS = [ - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "Intended Audience :: Science/Research", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", - "Programming Language :: Python", - "Programming Language :: Python :: 3", -] -INSTALL_REQUIRES = ["numpy>=1.13.0"] -SETUP_REQUIRES = INSTALL_REQUIRES + [ - "pybind11>=2.4", - "setuptools>=40.6.0", - "setuptools_scm", - "wheel", -] -EXTRA_REQUIRE = { - "style": ["isort", "black", "black_nbconvert"], - "test": [ - "coverage[toml]", - "pytest", - "pytest-cov", - "scipy", - "celerite>=0.3.1", - ], - "pymc3": ["pymc3>=3.9", "numpy<1.22", "xarray<2023.10.0"], - "pymc": ["pymc>=5.9.2"], - "jax": ["jax", "jaxlib"], - "docs": [ - "sphinx", - "sphinx-material", - "sphinx_copybutton", - "breathe", - "myst-nb", - ], - "tutorials": [ - "matplotlib", - "scipy", - "emcee", - "pymc>=5", - "tqdm", - "numpyro", - ], -} -EXTRA_REQUIRE["docs"] += EXTRA_REQUIRE["tutorials"] -EXTRA_REQUIRE["theano"] = EXTRA_REQUIRE["pymc3"] -EXTRA_REQUIRE["dev"] = ( - EXTRA_REQUIRE["style"] - + EXTRA_REQUIRE["test"] - + ["pre-commit", "nbstripout", "flake8"] -) - -include_dirs = [ - "c++/include", - "c++/vendor/eigen", - "python/celerite2", -] -ext_modules = [ - Pybind11Extension( - "celerite2.driver", - ["python/celerite2/driver.cpp"], - include_dirs=include_dirs, - language="c++", - ), - Pybind11Extension( - "celerite2.backprop", - ["python/celerite2/backprop.cpp"], - include_dirs=include_dirs, - language="c++", - ), - Pybind11Extension( - "celerite2.jax.xla_ops", - ["python/celerite2/jax/xla_ops.cpp"], - include_dirs=include_dirs, - language="c++", - ), -] - -# END PROJECT SPECIFIC - - -HERE = os.path.dirname(os.path.realpath(__file__)) - - -def read(*parts): - with codecs.open(os.path.join(HERE, *parts), "rb", "utf-8") as f: - return f.read() - - -def find_meta(meta, meta_file=read(META_PATH)): - meta_match = re.search( - r"^__{meta}__ = ['\"]([^'\"]*)['\"]".format(meta=meta), meta_file, re.M - ) - if meta_match: - return meta_match.group(1) - raise RuntimeError("Unable to find __{meta}__ string.".format(meta=meta)) - - -if __name__ == "__main__": - setup( - name=NAME, - use_scm_version={ - "write_to": os.path.join( - "python", NAME, "{0}_version.py".format(NAME) - ), - "write_to_template": '__version__ = "{version}"\n', - }, - author=find_meta("author"), - author_email=find_meta("email"), - maintainer=find_meta("author"), - maintainer_email=find_meta("email"), - url=find_meta("uri"), - license=find_meta("license"), - description=find_meta("description"), - long_description=read("README.md"), - long_description_content_type="text/markdown", - packages=PACKAGES, - package_dir={"": "python"}, - include_package_data=True, - python_requires=">=3.6", - install_requires=INSTALL_REQUIRES, - setup_requires=SETUP_REQUIRES, - extras_require=EXTRA_REQUIRE, - classifiers=CLASSIFIERS, - zip_safe=False, - ext_modules=ext_modules, - cmdclass={"build_ext": build_ext}, - )