This repository has been archived by the owner on Apr 24, 2024. It is now read-only.

Add a Neural-Network class #13

Open
PicoCentauri opened this issue Feb 22, 2023 · 9 comments
Labels
enhancement (New feature or request), help wanted (Extra attention is needed)

Comments

@PicoCentauri
Collaborator

We haven't thought this through much, but I am opening this issue here so that people who are interested can share their ideas.

PicoCentauri added the enhancement and help wanted labels on Feb 22, 2023
@agoscinski
Collaborator

Some thoughts about implementations that are necessary for basic features like a BPNN (Behler-Parrinello neural network):

Basic modules

  • nn.Linear

weight initialization methods (torch.nn.init, https://pytorch.org/docs/stable/nn.init.html)

  • Where do we put this? Maybe for now in equisolve/module/nn/initializers?
  • How will this work for gradients? Also just random initialization?

Activation functions

These are supported by n2p2:

  • TANH (torch.nn.functional)
  • LOGISTIC
  • SOFTPLUS (torch.nn.functional)
  • RELU (torch.nn.functional)
  • GAUSSIAN
  • COS
  • REVLOGISTIC
  • EXP
  • HARMONIC

https://compphysvienna.github.io/n2p2/doxygen/classnnp_1_1NeuralNetwork.html#a032b3b525f06cd70953aec8e6aeedbedaf565187036a46ee35d905df14542dc01
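
For the ones with a torch counterpart we can simply reuse torch.nn / torch.nn.functional; for the others, a possible sketch (plain torch modules; the exact functional forms should be double-checked against the n2p2 documentation) could look like:

import torch

# sketch of a few n2p2-style activations as torch modules;
# functional forms to be verified against the n2p2 documentation
class Gaussian(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.exp(-0.5 * x * x)

class RevLogistic(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 1.0 - torch.sigmoid(x)

class Harmonic(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * x

class Cos(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.cos(x)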

@Luthaf
Collaborator

Luthaf commented Mar 7, 2023

Instead of manually re-wrapping every possible function in torch, this could be implemented as a generic wrapper taking an existing torch.nn.Module (which would be torch.nn.Sequential for a basic NN) and applying it block-by-block to a TensorMap. Then users are free to build arbitrarily complex NNs and use them with equisolve with a minimal amount of work on our side.

@agoscinski
Collaborator

agoscinski commented Mar 7, 2023

The weight initializers we would have to reimplement, because they take torch tensors as arguments. I am not sure about the torch.nn.Module part; aren't we limiting possible JAX support in the future that way?

@Luthaf
Collaborator

Luthaf commented Mar 8, 2023

I don't think we need to touch weight initializers either. The API I have in mind would look like this:

import torch
import numpy as np
from equistore import TensorBlock, TensorMap, Labels

# defined in equisolve
class EquistoreTorchWrapper(torch.nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, X: TensorMap) -> TensorMap:
        blocks = []
        for _, block in X:
            # apply the wrapped torch module to the values of this block
            new_values = self.module(block.values)

            blocks.append(TensorBlock(
                values=new_values,
                samples=block.samples,
                components=block.components,
                properties=Labels(
                    names=["q"],
                    values=np.arange(new_values.shape[-1], dtype=np.int32).reshape(-1, 1),
                ),
            ))

        return TensorMap(X.keys, blocks)


# code the user writes
class MyFancyModel(torch.nn.Module):
    def __init__(self, n_features, n_output):
        super().__init__()
        self.nn = torch.nn.Sequential(
            torch.nn.Linear(n_features, 80),
            torch.nn.Tanh(),
            torch.nn.Linear(80, 80),
            torch.nn.Tanh(),
            torch.nn.Linear(80, n_output),
        )

        # initialize weights as needed
        torch.nn.init.uniform_(self.nn[0].weight)
        torch.nn.init.uniform_(self.nn[2].weight)
        torch.nn.init.uniform_(self.nn[4].weight)

    def forward(self, X: torch.Tensor):
        return torch.log(self.nn(X))


model = EquistoreTorchWrapper(MyFancyModel(n_features=234, n_output=1))

Here, the user still takes care of weight initialization & defining the model, and we apply it on a TensorMap. There are a couple of issues with the code above that still need to be improved:

  • how to do different models per block (see the sketch below)
  • how to get the right size (n_features above) when creating the models
  • maybe others ^_^
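
For the first point (different models per block), one possible direction could be a torch.nn.ModuleDict keyed by a string version of each block key. This is only a sketch; PerBlockWrapper and _key_to_str are hypothetical names, and it assumes the same block-by-block application as above:

import numpy as np
import torch
from equistore import TensorBlock, TensorMap, Labels

def _key_to_str(key):
    # turn a block key (a tuple of integers) into a ModuleDict-friendly string
    return "_".join(str(int(v)) for v in key)

class PerBlockWrapper(torch.nn.Module):
    def __init__(self, modules_by_key):
        super().__init__()
        # modules_by_key maps a key string to the torch.nn.Module for that block
        self.modules_by_key = torch.nn.ModuleDict(modules_by_key)

    def forward(self, X: TensorMap) -> TensorMap:
        blocks = []
        for key, block in X:
            new_values = self.modules_by_key[_key_to_str(key)](block.values)
            blocks.append(TensorBlock(
                values=new_values,
                samples=block.samples,
                components=block.components,
                properties=Labels(
                    names=["q"],
                    values=np.arange(new_values.shape[-1], dtype=np.int32).reshape(-1, 1),
                ),
            ))
        return TensorMap(X.keys, blocks)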

aren't we limiting a possible JAX support in the future that way?

Yes, although we could also write an EquistoreJaxWrapper class for Jax support. This should be a lot less work than trying to unify torch & jax with a common API.
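
A very rough sketch of such an EquistoreJaxWrapper (the name and details are only an illustration, assuming blocks hold numpy arrays and the wrapped function maps jax arrays to jax arrays):

import jax.numpy as jnp
import numpy as np
from equistore import TensorBlock, TensorMap, Labels

def equistore_jax_wrapper(fn):
    # returns a function applying `fn` block-by-block to a TensorMap
    def apply(X: TensorMap) -> TensorMap:
        blocks = []
        for _, block in X:
            new_values = np.asarray(fn(jnp.asarray(block.values)))
            blocks.append(TensorBlock(
                values=new_values,
                samples=block.samples,
                components=block.components,
                properties=Labels(
                    names=["q"],
                    values=np.arange(new_values.shape[-1], dtype=np.int32).reshape(-1, 1),
                ),
            ))
        return TensorMap(X.keys, blocks)
    return apply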

@agoscinski
Collaborator

I see that we can probably implement all methods with this wrapper approach. But I was thinking about the use case where you mix your model with operations that can only be done on a TensorMap (e.g. something that depends on the metadata of the TensorMap). So I was thinking about something like this.

# equisolve.nn

if HAS_TORCH:
    Module = torch.nn.Module
elif HAS_JAX:
    Module = jax....
else:
    Module = BaseModule  # some basic class having a forward function
# maybe one wants to make this decision differently in the
# future if HAS_TORCH and HAS_JAX are both true, but this is
# not important for now

def equistore_module_factory(nn_Module_class):
    if HAS_TORCH:
        class EquistoreModule(nn_Module_class):
            def forward(self, X: TensorMap):
                # what @Luthaf wrote
    elif HAS_JAX:
        class EquistoreModule(nn_Module_class):
            # ... something analogous to EquistoreTorchWrapper
    return EquistoreModule

if HAS_TORCH:
    Linear = equistore_module_factory(torch.nn.Linear)
elif HAS_JAX:
    # this does not really exist and would have to be done a bit
    # differently, but that is not important for the discussion
    Linear = equistore_module_factory(jax.nn.Linear)
# my_script.py
import torch  # or import jax
import equisolve
from equistore import TensorMap

class MyModelUsingEquistoreMetadata(equisolve.nn.Module):
    def __init__(self, n_features, n_output):
        super().__init__()
        self.linear = equisolve.nn.Linear(n_features, n_output)
        self.output = equisolve.nn.Linear(n_output, 1)
        self.some_module = ...
        self.some_other_module = ...

    def forward(self, X: TensorMap):
        X = self.linear(X)
        for key, block in X:  # or something using the equistore metadata
            if key == "0":
                self.some_module(block)
            else:
                self.some_other_module(block)
        return self.output(X)

So we can use the same script with both Jax and Torch.

@Luthaf
Collaborator

Luthaf commented Mar 9, 2023

So I would have two concerns here:

  1. I don't think jax has anything resembling torch.nn.Module (a root class that needs to be used for everything). They rely more on tracing standard Python functions with pseudo-arrays to extract the operations and build the corresponding GPU/... kernels. So I'm not sure this design could be used as-is with jax.

  2. We would have to check, but I fear that wrapping individual torch modules one by one will introduce a lot of redundant overhead, while wrapping the overall module allows us to minimize the cost.

However, if we go with this wrapper approach, nothing prevents the user from wrapping every individual operation, so it is more a question of what we provide by default in equisolve.
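
For example, wrapping operation by operation with the EquistoreTorchWrapper sketch from above would look roughly like this (sizes are arbitrary):

import torch
from equistore import TensorMap

# sketch: the user wraps each operation individually instead of the whole model,
# reusing the EquistoreTorchWrapper from the earlier snippet
linear_1 = EquistoreTorchWrapper(torch.nn.Linear(234, 80))
tanh = EquistoreTorchWrapper(torch.nn.Tanh())
linear_2 = EquistoreTorchWrapper(torch.nn.Linear(80, 1))

def forward(X: TensorMap) -> TensorMap:
    # every intermediate result is a full TensorMap, hence the overhead concern
    return linear_2(tanh(linear_1(X)))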

@agoscinski
Collaborator

I don't think jax has anything resembling torch.nn.Module (a root class that needs to be used for everything). They rely more on tracing standard Python functions with pseudo-arrays to extract the operations and build the corresponding GPU/... kernels. So I'm not sure this design could be used as-is with jax.

Yes, I also saw that this works differently, but one can probably (NEEDS TO BE CHECKED!) do something analogous. I found this https://rockpool.ai/_modules/nn/modules/jax/linear_jax.html#LinearJax which might be helpful to get an idea of how much work this is. Your second point is a more serious concern to me, and I don't know how to test it at the moment.

However, if we go with this wrapper approach, nothing prevents the user from wrapping every individual operation, so it is more of a question of what do we provide by default in equisolve.

I feel like for the moment we should explore both use cases and provide code for both: making the wrapper available to the user, so one can do what you did in your snippet with MyFancyModel (wrap a whole pytorch model), but also wrapping all the standard modules, so one can do something like MyModelUsingEquistoreMetadata in my snippet. Then we can also check whether this introduces serious overheads that make the approach less useful. For a first prototype we would only implement some basic operations to do benchmarks.

I feel like we can start a first PR for the prototype.
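
To make the overhead question concrete, a rough timing sketch (assuming the EquistoreTorchWrapper and MyFancyModel snippets from above, and that blocks can hold torch tensors; all sizes are arbitrary) could be:

import time
import numpy as np
import torch
from equistore import TensorBlock, TensorMap, Labels

def make_tensormap(n_blocks=50, n_samples=100, n_features=234):
    # toy TensorMap with random values, just for timing
    blocks = []
    for _ in range(n_blocks):
        blocks.append(TensorBlock(
            values=torch.rand(n_samples, n_features),
            samples=Labels(["sample"], np.arange(n_samples, dtype=np.int32).reshape(-1, 1)),
            components=[],
            properties=Labels(["q"], np.arange(n_features, dtype=np.int32).reshape(-1, 1)),
        ))
    keys = Labels(["block"], np.arange(n_blocks, dtype=np.int32).reshape(-1, 1))
    return TensorMap(keys, blocks)

X = make_tensormap()

# one wrapper around the whole model ...
whole = EquistoreTorchWrapper(MyFancyModel(n_features=234, n_output=1))
# ... versus one wrapper per operation
per_op = [
    EquistoreTorchWrapper(torch.nn.Linear(234, 80)),
    EquistoreTorchWrapper(torch.nn.Tanh()),
    EquistoreTorchWrapper(torch.nn.Linear(80, 1)),
]

with torch.no_grad():
    start = time.perf_counter()
    for _ in range(100):
        whole(X)
    print("whole model:  ", time.perf_counter() - start)

    start = time.perf_counter()
    for _ in range(100):
        Y = X
        for module in per_op:
            Y = module(Y)
    print("per operation:", time.perf_counter() - start)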

@Luthaf
Collaborator

Luthaf commented Mar 9, 2023

I found this https://rockpool.ai/_modules/nn/modules/jax/linear_jax.html#LinearJax which might be helpful to get an idea of how much work this is.

Except that's not jax, that's another package (rockpool) building on top of jax.


Building a prototype & benchmarking it works for me!

@agoscinski
Collaborator

Except that's not jax, that's another package (rockpool) building on top of jax.

I know, but looking at how others have wrapped it is not a bad idea, I think.
