Add a Neural-Network class #13
Some thoughts about implementations that are necessary for basic features like a BPNN:
- basic modules
- weight initialization methods (nn.init https://pytorch.org/docs/stable/nn.init.html)
- activation functions (these are supported by n2p2)
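(As a rough illustration of what the first two items correspond to in plain torch, not yet tied to any equisolve API:)

```python
import torch

layer = torch.nn.Linear(64, 32)
# initializers from torch.nn.init modify the parameter tensors in place
torch.nn.init.xavier_uniform_(layer.weight)
torch.nn.init.zeros_(layer.bias)

# tanh is one of the activation functions also available in n2p2
activation = torch.nn.Tanh()
out = activation(layer(torch.rand(8, 64)))
```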
Instead of manually rewrapping every possible function in torch, this could be implemented as a generic wrapper taking an existing torch.nn.Module.
The weight initializers we would have to reimplement, because they take torch tensors as arguments. I am not sure about torch.nn.Module: aren't we limiting possible JAX support in the future that way?
I don't think we need to touch weight initializers either. The API I have in mind would look like this:

```python
import numpy as np
import torch

from equistore import Labels, TensorBlock, TensorMap


# defined in equisolve
class EquistoreTorchWrapper(torch.nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, X: TensorMap):
        blocks = []
        for _, block in X:
            new_values = self.module(block.values)
            blocks.append(TensorBlock(
                values=new_values,
                samples=block.samples,
                components=block.components,
                properties=Labels(
                    names=["q"],
                    values=np.arange(new_values.shape[-1]).reshape(-1, 1),
                ),
            ))
        # rebuild a TensorMap with the transformed blocks
        return TensorMap(X.keys, blocks)


# code the user writes
class MyFancyModel(torch.nn.Module):
    def __init__(self, n_features, n_output):
        super().__init__()
        self.nn = torch.nn.Sequential(
            torch.nn.Linear(n_features, 80),
            torch.nn.Tanh(),
            torch.nn.Linear(80, 80),
            torch.nn.Tanh(),
            torch.nn.Linear(80, n_output),
        )
        # initialize weights as needed
        torch.nn.init.uniform_(self.nn[0].weight)
        torch.nn.init.uniform_(self.nn[2].weight)
        torch.nn.init.uniform_(self.nn[4].weight)

    def forward(self, X: torch.Tensor):
        return torch.log(self.nn(X))


model = EquistoreTorchWrapper(MyFancyModel(n_features=234, n_output=1))
```

Here, the user still takes care of weight initialization & defining the model, and we apply it on a TensorMap. There are a couple of issues with the code above that still need to be improved.
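(For concreteness, a rough usage sketch of the wrapper above; the exact `Labels`/`TensorBlock`/`TensorMap` constructor calls are assumptions and may differ between equistore versions:)

```python
import numpy as np
import torch
from equistore import Labels, TensorBlock, TensorMap

# a toy single-block TensorMap with 10 samples and 234 features
# (Labels values may need to be int32 depending on the equistore version)
block = TensorBlock(
    values=torch.rand(10, 234),
    samples=Labels(names=["structure"], values=np.arange(10).reshape(-1, 1)),
    components=[],
    properties=Labels(names=["q"], values=np.arange(234).reshape(-1, 1)),
)
descriptor = TensorMap(Labels(names=["dummy"], values=np.array([[0]])), [block])

model = EquistoreTorchWrapper(MyFancyModel(n_features=234, n_output=1))
predictions = model(descriptor)  # a TensorMap with one output property per block
```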
Yes, although we could also write an
I see that we can probably implement all methods with this wrapper approach. But I was thinking about the use case where you mix your model with operations that you can only do on a TensorMap (e.g. something that depends on the metadata of a TensorMap). So I was thinking about something like this:

```python
# equisolve.nn
if HAS_TORCH:
    Module = torch.nn.Module
elif HAS_JAX:
    Module = jax....
else:
    Module = BaseModule  # some basic class having a forward function

# maybe one wants to make this decision differently in the future
# if HAS_TORCH and HAS_JAX are both true, but that is not important for now


def equistore_module_factory(nn_Module_class):
    if HAS_TORCH:
        class EquistoreModule(nn_Module_class):
            def forward(self, X: TensorMap):
                ...  # what @Luthaf wrote
    elif HAS_JAX:
        class EquistoreModule(nn_Module_class):
            ...  # something analogous to EquistoreTorchWrapper
    return EquistoreModule


if HAS_TORCH:
    Linear = equistore_module_factory(torch.nn.Linear)
elif HAS_JAX:
    # does not really exist, would have to be done a bit
    # differently, but not important for the discussion
    Linear = equistore_module_factory(jax.nn.Linear)
```

```python
# my_script.py
import torch  # or import jax

import equistore


class MyModelUsingEquistoreMetadata(equistore.nn.Module):
    def __init__(self, n_features, n_output):
        super().__init__()
        self.linear = equistore.nn.Linear(feature_in, feature_out)
        self.output = equistore.nn.Linear(feature_in, 1)
        self.some_module = ...
        self.some_other_module = ...

    def forward(self, X: TensorMap):
        X = self.linear(X)
        for key, block in X:  # or something using the equistore metadata
            if key == "0":
                self.some_module(block)
            else:
                self.some_other_module(block)
        return self.output(X)
```

So we can use the same script with JAX and Torch.
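(Side note: `HAS_TORCH` / `HAS_JAX` are not defined in the sketch above; they could be set with a simple import check, for example:)

```python
# hypothetical flags used by the sketch above, set via import checks
try:
    import torch
    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False

try:
    import jax
    HAS_JAX = True
except ImportError:
    HAS_JAX = False
```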
So I would have two concerns here:
However, if we go with this wrapper approach, nothing prevents the user from wrapping every individual operation, so it is more a question of what we provide by default in equisolve.
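(To make that concrete, a rough sketch of wrapping individual operations with the `EquistoreTorchWrapper` from above and chaining them; `descriptor` is assumed to be a TensorMap of features, and the layer sizes are placeholders:)

```python
# every operation is wrapped individually; each intermediate result
# stays a TensorMap, so block metadata is available along the whole chain
per_block_model = torch.nn.Sequential(
    EquistoreTorchWrapper(torch.nn.Linear(234, 80)),
    EquistoreTorchWrapper(torch.nn.Tanh()),
    EquistoreTorchWrapper(torch.nn.Linear(80, 1)),
)
predictions = per_block_model(descriptor)
```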
Yes, I also saw that this works differently, but one can probably (NEEDS TO BE CHECKED!) do something analogous. I found this https://rockpool.ai/_modules/nn/modules/jax/linear_jax.html#LinearJax which might be helpful to get an idea of how much work this is. Your second point is a more serious concern to me, and I don't know how to test it at the moment.
I feel like for the moment we should explore both use cases and provide code for both: making the wrapper available to the user so one can do what you did in your snippet with MyFancyModel (wrap a whole PyTorch model), but also wrapping all the standard modules so one can do something like in the example above. I feel like we can start a first PR for the prototype.
Except that's not jax, that's another package (rockpool) building on top of jax. Building a prototype & benchmarking it works for me!
I know, but looking at how others have wrapped it is not a bad idea, I think.
We haven't thought this out much, but I'm opening this issue here so that people who are interested can share their ideas.