add neural network modules that work with tensor maps #62
Conversation
Force-pushed from a27c00c to d5ec1ae
Force-pushed from f01e400 to 665688d
Force-pushed from 8e8ad63 to b3e8754
This can be done by the user using vanilla tensormap functions, right?
The code looks good; the documentation will still need some work: how to wrap a custom module inside equisolve.nn.Module, how/when to use these, ... but this can come at a later date.
src/equisolve/nn/__init__.py (outdated)
```python
HAS_TORCH = False

if HAS_TORCH:
    from .module_tensor import LinearTensorMap, ModuleTensorMap  # noqa: F401
```
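The guard presumably comes from an optional-import pattern along these lines (a sketch: the try/except itself is outside the shown diff):

```python
# optional torch dependency: only expose the torch-based modules when torch
# is importable (pattern implied by the diff above, not shown in it)
try:
    import torch  # noqa: F401

    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False

if HAS_TORCH:
    from .module_tensor import LinearTensorMap, ModuleTensorMap  # noqa: F401
```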
Naming-wise, I would call these Linear and Module, and users can then use torch.nn.Linear/torch.nn.Module or equisolve.nn.Linear/equisolve.nn.Module for the different behavior.
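For illustration, with that naming the behavior would be selected purely by the import (hypothetical, following this suggestion):

```python
# plain torch behavior, operating on torch.Tensor
from torch.nn import Linear as TorchLinear

# per-block behavior, operating on metatensor TensorMap objects
# (assumes the rename proposed in this comment)
from equisolve.nn import Linear as TensorMapLinear
```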
Okay, let's try this. We might move these to equisolve.torch.nn.Linear, to better distinguish the methods here from the ones that also support the numpy backend. But I would do the refactor in another PR.
.github/workflows/tests.yml (outdated)
```yaml
- name: run Python metatensor-core tests
  run: tox -e metatensor-core-numpy-tests

- name: run Python metatensor-core tests
  run: tox -e metatensor-core-torch-tests

- name: run Python metatensor-torch tests
  run: tox -e metatensor-torch-tests
```
I think you can remove the `metatensor` from the names; we don't have tests without metatensor, I assume.
true
docs/requirements.txt (outdated)
```diff
@@ -10,3 +10,7 @@ furo
 rascaline @ https://github.com/luthaf/rascaline/archive/fb5332f.zip
 sphinx >=4.4
 sphinx-gallery
+# temporary add torch to dependencies until rascaline is on metatensor
+# version 3c5fee40
```
(suggested change on the `# version 3c5fee40` comment line)
I removed the comment, because we will need a torch dependency later for the examples anyway. The dependency is now provided through the extra requirement on metatensor-torch in tox.ini and .readthedocs.yml.
```python
        return TensorMap(tensor.keys, out_blocks)

    def forward_block(self, key: LabelsEntry, block: TensorBlock) -> TensorBlock:
```
I think this can also be a Labels object, right?
If you get it by iterating over the keys, you'll get a LabelsEntry.
The input is one key of the TensorMap's keys Labels object. The structure here is a bit different from metatensor-operations, because here we need the key to figure out which module to apply to the block.
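For illustration, a minimal sketch of this dispatch pattern (the class name and internals here are hypothetical; only the `forward`/`forward_block` split mirrors the snippet above, and it assumes the block values are torch tensors):

```python
import numpy as np
import torch
from metatensor import Labels, LabelsEntry, TensorBlock, TensorMap


class ModuleMapSketch(torch.nn.Module):
    """Hypothetical wrapper: the key of each block selects the module to apply."""

    def __init__(self, module_map: torch.nn.ModuleDict):
        super().__init__()
        self.module_map = module_map

    def forward(self, tensor: TensorMap) -> TensorMap:
        out_blocks = [self.forward_block(key, block) for key, block in tensor.items()]
        return TensorMap(tensor.keys, out_blocks)

    def forward_block(self, key: LabelsEntry, block: TensorBlock) -> TensorBlock:
        # the key (one entry of the TensorMap's keys) picks the module;
        # stringifying it to index the ModuleDict is a sketch simplification
        module = self.module_map[str(key)]
        values = module(block.values)
        return TensorBlock(
            values=values,
            samples=block.samples,
            components=block.components,
            # new properties labels sized to the module output
            properties=Labels(["property"], np.arange(values.shape[-1]).reshape(-1, 1)),
        )
```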
```python
    :param out_tensor: a tensor map that has the output properties of the Labels
    """

    def __init__(self, module_map: ModuleDict, out_tensor: Optional[TensorMap] = None):
```
I am not sure what the out_tensor is doing. Is this the TensorMap that the results will be written into?
Added better documentation. There was also a bug in how it was used in the Linear module.
to wrap arbitrary existing torch modules and apply them on each key of a tensor map
Force-pushed from b3e8754 to 32951b5
> What is missing is functionality that allows merging multiple blocks into a new one (e.g. in species compression), but I also don't think it is a good idea to put this in the same class.

> This can be done by the user using vanilla tensormap functions, right?

Yes. I am not sure one can create another Module for this purpose that is generic enough to be useful for multiple applications without being too abstract and less efficient than doing it "manually". For example, in the species decomposition in alchemical learning you transform the sparse species-neighbor dimension into a new property, the pseudo species. So this combines a move to the properties with a linear combination. I don't have good ideas here.
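To make the "vanilla tensormap functions" route concrete, a minimal sketch; `keys_to_properties` is the standard metatensor operation for merging blocks along a key dimension, while the key name `species_neighbor` and all sizes are assumptions for illustration:

```python
import torch
from metatensor import TensorMap


def merge_neighbor_species(tensor: TensorMap) -> TensorMap:
    # merge blocks that only differ in their neighbor species key by moving
    # that key dimension into the properties of a single block
    return tensor.keys_to_properties("species_neighbor")


# a plain linear layer over the enlarged property axis could then learn the
# combination into a few pseudo species (all sizes are made up)
n_species, n_radial, n_pseudo = 4, 8, 2
compress = torch.nn.Linear(n_species * n_radial, n_pseudo * n_radial)
```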
> The code looks good; the documentation will still need some work: how to wrap a custom module inside equisolve.nn.Module, how/when to use these, ... but this can come at a later date.

I improved the documentation. I will open an issue for an example for the neural network.
EDIT: figuring out how to add PIP_EXTRA_INDEX_URL to readthedocs; otherwise everything passes.
EDIT: seems to work now, but I reached the concurrency limit ^^
the existing tests are moved to core-numpy-tests, and new environments for metatensor-core with torch and for metatensor-torch are added
add tests for from_module and default constructor
the function random_single_block_no_components_tensor_map is now in one global utilities module
add tests for from_module and default constructor
this optional dependency loads metatensor-torch
Force-pushed from 32951b5 to 5f732e7
Force-pushed from 5f732e7 to 5aa8936
This properly references TensorMap types in the doc instead of only showing abstruse ScriptClass types
Force-pushed from 5aa8936 to 5388061
Thanks for the improved docs @agoscinski!
From the discussion in #13. Should replace https://github.com/bananenpampe/H2O/blob/move-rascaline/model/nn/modules.py
The goal is to offer a flexible wrapper class around torch modules that are applied on each block of an input TensorMap, while also offering functions that wrap simpler cases, for example when the same module is used for all blocks (for that there is the from_module constructor).

What is missing is functionality that allows merging multiple blocks into a new one (e.g. in species compression), but I also don't think it is a good idea to put this in the same class.
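For illustration, a usage sketch of that from_module path; the exact signature is not shown in this thread, so passing the keys plus one shared module is an assumption, as is the Module name (per the review discussion):

```python
import torch
from metatensor import TensorMap

from equisolve.nn import Module  # name as proposed in the review above


def apply_shared_mlp(tensor: TensorMap, n_in: int) -> TensorMap:
    # hypothetical: one MLP shared across every block of the input TensorMap
    mlp = torch.nn.Sequential(
        torch.nn.Linear(n_in, 16),
        torch.nn.Tanh(),
        torch.nn.Linear(16, 1),
    )
    wrapped = Module.from_module(tensor.keys, mlp)  # assumed signature
    return wrapped(tensor)
```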
TODO (will add these after first review):
- `__init__` of `Linear`
- `from_module` of `ModuleTensorMap`
- `equisolve.nn` to `equisolve.torch.nn` (needs to be discussed in meeting)

📚 Documentation preview 📚: https://equisolve--62.org.readthedocs.build/en/62/