Skip to content

Commit

Update metatensor
Luthaf committed Sep 6, 2023
1 parent e1ab54b commit 913e698
Showing 12 changed files with 40 additions and 38 deletions.
2 changes: 1 addition & 1 deletion docs/requirements.txt
@@ -10,7 +10,7 @@ myst-parser # markdown => rst translation, used in extensions/rascaline_json

# dependencies for the tutorials
--extra-index-url https://download.pytorch.org/whl/cpu
-metatensor[torch] @ https://github.com/lab-cosmo/metatensor/archive/32ad5bb.zip
+metatensor[torch] @ https://github.com/lab-cosmo/metatensor/archive/d97ea65.zip
torch
chemfiles
matplotlib
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -26,7 +26,7 @@ classifiers = [
]

dependencies = [
"metatensor-core @ https://github.com/lab-cosmo/metatensor/archive/32ad5bb.zip#subdirectory=python/metatensor-core",
"metatensor-core @ https://github.com/lab-cosmo/metatensor/archive/d97ea65.zip#subdirectory=python/metatensor-core",
]

[project.urls]
2 changes: 1 addition & 1 deletion python/rascaline-torch/pyproject.toml
@@ -38,7 +38,7 @@ requires = [
"wheel >=0.38",
"cmake",
"torch >= 1.11",
"metatensor-torch @ https://github.com/lab-cosmo/metatensor/archive/32ad5bb.zip#subdirectory=python/metatensor-torch",
"metatensor-torch @ https://github.com/lab-cosmo/metatensor/archive/d97ea65.zip#subdirectory=python/metatensor-torch",
]

# use a custom build backend to add a dependency on the right version of rascaline
8 changes: 3 additions & 5 deletions python/rascaline/rascaline/utils/power_spectrum/calculator.py
@@ -176,10 +176,8 @@ def compute(
"species_neighbor",
]

-# TODO: re-enable once we update metatensor with
-# https://github.com/lab-cosmo/metatensor/pull/322
-# assert spherical_expansion_1.keys.names == expected_key_names
-# assert spherical_expansion_1.property_names == ["n"]
+assert spherical_expansion_1.keys.names == expected_key_names
+assert spherical_expansion_1.properties_names == ["n"]

# Fill blocks with `species_neighbor` from ALL blocks. If we don't do this,
# merging blocks along the ``sample`` direction might not be possible.
@@ -200,7 +198,7 @@ def compute(
use_native_system=use_native_system,
)
assert spherical_expansion_2.keys.names == expected_key_names
-assert spherical_expansion_2.property_names == ["n"]
+assert spherical_expansion_2.properties_names == ["n"]

array = spherical_expansion_2.keys.column("species_neighbor")
keys_to_move = Labels(
4 changes: 2 additions & 2 deletions rascaline-c-api/CMakeLists.txt
@@ -217,7 +217,7 @@ endif()
# ============================================================================ #
# Setup metatensor

-set(METATENSOR_GIT_VERSION "32ad5bb")
+set(METATENSOR_GIT_VERSION "d97ea65")
set(METATENSOR_REQUIRED_VERSION "0.1")
if (RASCALINE_FETCH_METATENSOR)
message(STATUS "Fetching metatensor @ ${METATENSOR_GIT_VERSION} from github")
@@ -226,7 +226,7 @@ if (RASCALINE_FETCH_METATENSOR)
FetchContent_Declare(
metatensor
URL https://github.com/lab-cosmo/metatensor/archive/${METATENSOR_GIT_VERSION}.zip
-URL_HASH MD5=cbc7bd27e9e2307638405d1613fa7f89
+URL_HASH MD5=6a6899779591ae15861bb3547c7c354b
SOURCE_SUBDIR metatensor-core
VERBOSE
)
4 changes: 2 additions & 2 deletions rascaline-c-api/Cargo.toml
@@ -18,7 +18,7 @@ chemfiles = ["rascaline/chemfiles"]

[dependencies]
rascaline = {path = "../rascaline", version = "0.1.0", default-features = false}
-metatensor = {git = "https://github.com/lab-cosmo/metatensor", rev = "32ad5bb"}
+metatensor = {git = "https://github.com/lab-cosmo/metatensor", rev = "d97ea65"}

ndarray = "0.15"
log = { version = "0.4", features = ["std"] }
@@ -29,7 +29,7 @@ libc = "0.2"
[build-dependencies]
cbindgen = { version = "0.24", default-features = false }
fs_extra = "1"
-metatensor = {git = "https://github.com/lab-cosmo/metatensor", rev = "32ad5bb"}
+metatensor = {git = "https://github.com/lab-cosmo/metatensor", rev = "d97ea65"}

[dev-dependencies]
which = "4"
4 changes: 2 additions & 2 deletions rascaline-torch/CMakeLists.txt
@@ -58,7 +58,7 @@ find_package(Torch 1.11 REQUIRED)
# ============================================================================ #
# Setup metatensor_torch

-set(METATENSOR_GIT_VERSION "32ad5bb")
+set(METATENSOR_GIT_VERSION "d97ea65")
set(REQUIRED_METATENSOR_TORCH_VERSION "0.1")
if (RASCALINE_TORCH_FETCH_METATENSOR_TORCH)
message(STATUS "Fetching metatensor_torch @ ${METATENSOR_GIT_VERSION} from github")
@@ -67,7 +67,7 @@ if (RASCALINE_TORCH_FETCH_METATENSOR_TORCH)
FetchContent_Declare(
metatensor_torch
URL https://github.com/lab-cosmo/metatensor/archive/${METATENSOR_GIT_VERSION}.zip
-URL_HASH MD5=cbc7bd27e9e2307638405d1613fa7f89
+URL_HASH MD5=6a6899779591ae15861bb3547c7c354b
SOURCE_SUBDIR metatensor-torch
VERBOSE
)
9 changes: 5 additions & 4 deletions rascaline-torch/src/autograd.cpp
@@ -1,6 +1,7 @@
#include <atomic>
#include <algorithm>

#include "metatensor/torch/tensor.hpp"
#include "rascaline/torch/autograd.hpp"

using namespace metatensor_torch;
@@ -73,8 +74,8 @@ static std::vector<TorchTensorBlock> extract_gradient_blocks(
) {
auto gradients = std::vector<TorchTensorBlock>();
for (int64_t i=0; i<tensor->keys()->count(); i++) {
-auto block = tensor->block_by_id(i);
-auto gradient = block->gradient(parameter);
+auto block = TensorMapHolder::block_by_id(tensor, i);
+auto gradient = TensorBlockHolder::gradient(block, parameter);

gradients.push_back(torch::make_intrusive<TensorBlockHolder>(
gradient->values(),
@@ -103,7 +104,7 @@ std::vector<torch::Tensor> RascalineAutograd::forward(
if (all_positions.requires_grad()) {
ctx->saved_data.emplace("structures_start", structures_start);

auto gradient = block->gradient("positions");
auto gradient = TensorBlockHolder::gradient(block, "positions");
ctx->saved_data["positions_gradients"] = torch::make_intrusive<TensorBlockHolder>(
gradient->values(),
gradient->samples(),
@@ -115,7 +116,7 @@
if (all_cells.requires_grad()) {
ctx->saved_data["samples"] = block->samples();

auto gradient = block->gradient("cell");
auto gradient = TensorBlockHolder::gradient(block, "cell");
ctx->saved_data["cell_gradients"] = torch::make_intrusive<TensorBlockHolder>(
gradient->values(),
gradient->samples(),
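The recurring C++ change in this file (and the ones below) swaps member-function accessors for the static helpers exposed by the updated metatensor-torch API. A minimal sketch of the before/after pattern, using only names that appear in the hunks above; the `walk_positions_gradients` wrapper is illustrative and not part of the codebase:

    #include "metatensor/torch/block.hpp"
    #include "metatensor/torch/tensor.hpp"

    using namespace metatensor_torch;

    // Sketch: visit every block of a TensorMap and read its "positions" gradient,
    // assuming each block carries that gradient (as in the forward pass above).
    void walk_positions_gradients(TorchTensorMap tensor) {
        for (int64_t i = 0; i < tensor->keys()->count(); i++) {
            // before this commit: auto block = tensor->block_by_id(i);
            auto block = TensorMapHolder::block_by_id(tensor, i);

            // before this commit: auto gradient = block->gradient("positions");
            auto gradient = TensorBlockHolder::gradient(block, "positions");

            // data and metadata accessors are unchanged
            auto values = gradient->values();
            auto samples = gradient->samples();
        }
    }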
11 changes: 6 additions & 5 deletions rascaline-torch/src/calculator.cpp
@@ -1,4 +1,5 @@
#include "rascaline/torch/calculator.hpp"
#include "metatensor/torch/block.hpp"
#include "metatensor/torch/tensor.hpp"
#include "rascaline/torch/autograd.hpp"
#include <c10/util/Exception.h>
@@ -86,7 +87,7 @@ static TorchTensorMap remove_other_gradients(
) {
auto new_blocks = std::vector<TorchTensorBlock>();
for (int64_t i=0; i<tensor->keys()->count(); i++) {
-auto block = tensor->block_by_id(i);
+auto block = TensorMapHolder::block_by_id(tensor, i);
auto new_block = torch::make_intrusive<TensorBlockHolder>(
block->values(),
block->samples(),
@@ -95,7 +96,7 @@
);

for (const auto& parameter: gradients_to_keep) {
-auto gradient = block->gradient(parameter);
+auto gradient = TensorBlockHolder::gradient(block, parameter);
new_block->add_gradient(parameter, gradient);
}

@@ -196,7 +197,7 @@ metatensor_torch::TorchTensorMap CalculatorHolder::compute(
}

for (int64_t block_i=0; block_i<torch_descriptor->keys()->count(); block_i++) {
-auto block = torch_descriptor->block_by_id(block_i);
+auto block = TensorMapHolder::block_by_id(torch_descriptor, block_i);
// see `RascalineAutograd::forward` for an explanation of what's happening
auto _ = RascalineAutograd::apply(
all_positions,
@@ -228,7 +229,7 @@ metatensor_torch::TorchTensorMap rascaline_torch::register_autograd(
auto all_cells = stack_all_cells(systems);
auto structures_start_ivalue = torch::IValue();

-auto precomputed_gradients = precomputed->block_by_id(0)->gradients_list();
+auto precomputed_gradients = TensorMapHolder::block_by_id(precomputed, 0)->gradients_list();

if (all_positions.requires_grad()) {
if (!contains(precomputed_gradients, "positions")) {
@@ -271,7 +272,7 @@
}

for (int64_t block_i=0; block_i<precomputed->keys()->count(); block_i++) {
-auto block = precomputed->block_by_id(block_i);
+auto block = TensorMapHolder::block_by_id(precomputed, block_i);
auto _ = RascalineAutograd::apply(
all_positions,
all_cells,
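The `remove_other_gradients` helper above applies the same static-helper pattern while rebuilding blocks. A self-contained sketch of that idea under the new API; the function name is illustrative, and the `components()`/`properties()` constructor arguments are assumed from the block constructor that the hunk truncates:

    #include <string>
    #include <vector>

    #include <torch/torch.h>

    #include "metatensor/torch/block.hpp"

    using namespace metatensor_torch;

    // Sketch: copy a block, keeping only the gradients listed in `gradients_to_keep`.
    TorchTensorBlock keep_selected_gradients(
        TorchTensorBlock block,
        const std::vector<std::string>& gradients_to_keep
    ) {
        // assumed constructor arguments: values, samples, components, properties
        auto new_block = torch::make_intrusive<TensorBlockHolder>(
            block->values(),
            block->samples(),
            block->components(),
            block->properties()
        );

        for (const auto& parameter: gradients_to_keep) {
            // static accessor introduced with this metatensor update
            auto gradient = TensorBlockHolder::gradient(block, parameter);
            new_block->add_gradient(parameter, gradient);
        }

        return new_block;
    }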
26 changes: 14 additions & 12 deletions rascaline-torch/tests/calculator.cpp
@@ -1,4 +1,5 @@

#include "metatensor/torch/block.hpp"
#include <torch/torch.h>

#include <rascaline.hpp>
@@ -7,6 +8,7 @@
#include <catch.hpp>

using namespace rascaline_torch;
+using namespace metatensor_torch;

static TorchSystem test_system(bool positions_grad, bool cell_grad);

@@ -35,7 +37,7 @@ TEST_CASE("Calculator") {
));

// H block
-auto block = descriptor->block_by_id(0);
+auto block = TensorMapHolder::block_by_id(descriptor, 0);
CHECK(*block->samples() == metatensor::Labels(
{"structure", "center"},
{{0, 1}, {0, 2}, {0, 3}}
@@ -56,7 +58,7 @@
CHECK(block->gradients_list().empty());

// C block
-block = descriptor->block_by_id(1);
+block = TensorMapHolder::block_by_id(descriptor, 1);
CHECK(*block->samples() == metatensor::Labels(
{"structure", "center"},
{{0, 0}}
@@ -85,7 +87,7 @@
));

// H block
-auto block = descriptor->block_by_id(0);
+auto block = TensorMapHolder::block_by_id(descriptor, 0);

auto values = block->values();
CHECK(values.requires_grad() == true);
@@ -95,7 +97,7 @@
CHECK_THAT(grad_fn->name(), Catch::Matchers::Contains("rascaline_torch::RascalineAutograd"));

// forward gradients
auto gradient = block->gradient("positions");
auto gradient = TensorBlockHolder::gradient(block, "positions");
CHECK(*gradient->samples() == metatensor::Labels(
{"sample", "structure", "atom"},
{
@@ -117,7 +119,7 @@
CHECK(torch::all(gradient->values() == expected).item<bool>());

// C block
-block = descriptor->block_by_id(1);
+block = TensorMapHolder::block_by_id(descriptor, 1);

values = block->values();
CHECK(values.requires_grad() == true);
@@ -127,7 +129,7 @@
CHECK_THAT(grad_fn->name(), Catch::Matchers::Contains("rascaline_torch::RascalineAutograd"));

// forward gradients
gradient = block->gradient("positions");
gradient = TensorBlockHolder::gradient(block, "positions");
CHECK(*gradient->samples() == metatensor::Labels(
{"sample", "structure", "atom"},
{{0, 0, 0}, {0, 0, 1}}
@@ -149,7 +151,7 @@
));

// H block
-auto block = descriptor->block_by_id(0);
+auto block = TensorMapHolder::block_by_id(descriptor, 0);

auto values = block->values();
CHECK(values.requires_grad() == true);
@@ -162,7 +164,7 @@
CHECK(block->gradients_list().empty());

// C block
-block = descriptor->block_by_id(1);
+block = TensorMapHolder::block_by_id(descriptor, 1);

values = block->values();
CHECK(values.requires_grad() == true);
@@ -185,25 +187,25 @@
));

// H block
-auto block = descriptor->block_by_id(0);
+auto block = TensorMapHolder::block_by_id(descriptor, 0);

auto values = block->values();
CHECK(values.requires_grad() == false);
CHECK(values.grad_fn() == nullptr);

// forward gradients
auto gradient = block->gradient("positions");
auto gradient = TensorBlockHolder::gradient(block, "positions");
CHECK(gradient->samples()->count() == 8);

// C block
-block = descriptor->block_by_id(1);
+block = TensorMapHolder::block_by_id(descriptor, 1);

values = block->values();
CHECK(values.requires_grad() == false);
CHECK(values.grad_fn() == nullptr);

// forward gradients
gradient = block->gradient("positions");
gradient = TensorBlockHolder::gradient(block, "positions");
CHECK(gradient->samples()->count() == 2);
}
}
2 changes: 1 addition & 1 deletion rascaline/Cargo.toml
@@ -36,7 +36,7 @@ name = "soap-power-spectrum"
harness = false

[dependencies]
-metatensor = {git = "https://github.com/lab-cosmo/metatensor", rev = "32ad5bb", features = ["rayon"]}
+metatensor = {git = "https://github.com/lab-cosmo/metatensor", rev = "d97ea65", features = ["rayon"]}

ndarray = {version = "0.15", features = ["approx-0_5", "rayon", "serde"]}
num-traits = "0.2"
4 changes: 2 additions & 2 deletions tox.ini
@@ -19,10 +19,10 @@ lint-folders = "{toxinidir}/python" "{toxinidir}/setup.py"
# we need to manually install dependencies for rascaline, since tox will install
# the fresh wheel with `--no-deps` after building it.
metatensor-core-requirement =
-metatensor-core @ https://github.com/lab-cosmo/metatensor/archive/32ad5bb.zip\#subdirectory=python/metatensor-core
+metatensor-core @ https://github.com/lab-cosmo/metatensor/archive/d97ea65.zip\#subdirectory=python/metatensor-core

metatensor-torch-requirement =
-metatensor-torch @ https://github.com/lab-cosmo/metatensor/archive/32ad5bb.zip\#subdirectory=python/metatensor-torch
+metatensor-torch @ https://github.com/lab-cosmo/metatensor/archive/d97ea65.zip\#subdirectory=python/metatensor-torch

build-single-wheel = --no-deps --no-build-isolation --check-build-dependencies

