[Differentiability, Refactoring] Rethink parameter dictionaries in backends / Introduce hybrid differentiation modes #255

Open
dominikandreasseitz opened this issue Dec 11, 2023 · 4 comments
Labels: differentiability, refactoring (Refactoring of legacy code), to_check (To check in the development of the new expression system)

Comments

@dominikandreasseitz
Collaborator

Issue:

Right now, when we do:

quantum_backend = SomeBackend()
conv = quantum_backend.convert(circuit, obs)
conv_circ, conv_obs, embedding_fn, params = conv

we store all of the following in the initial params dict:

(a) all variational user-facing parameters of the circuit AND observable
(b) all fixed parameters in both circuit AND observable

Issue 1: When using torch, the (torch-based) backend knows which params to compute gradients for via the requires_grad flag. However, this doesn't work for JAX, where arrays carry no such flag and the differentiated arguments have to be named explicitly.

Issue 2: Neither the ADJOINT nor the GPSR diff_mode supports parametric observables.

Possible Solution:

Introduce separate parameter dicts for the fixed and the variational parameters (vparams) of both the circuit and the observable:
initial_params = {'circuit_vparams': ..., 'circuit_fixedparams': ..., 'obs_vparams': ..., 'obs_fixedparams': ...}

  1. This way, we can easily distinguish between variational and fixed params in JAX.
  2. We can start thinking about introducing hybrid diff_modes, e.g. circuit_diffmode = GPSR / ADJOINT and observable_diffmode = AD, and apply a specific diff routine to each subset of the parameters.
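
A minimal sketch of the proposed split, in plain Python. The parameter names, the `split_params` helper, and the way ownership and trainability are passed in are all assumptions made for illustration, not the library's API. The point is only that, once the four groups exist, the set of differentiable parameters is explicit and no longer depends on a per-tensor flag such as requires_grad.

```python
from typing import Dict, Set

# Hypothetical flat dict, similar in spirit to what convert() returns today.
# All names and values are made up for this sketch.
flat_params: Dict[str, float] = {
    "theta_0": 0.1,    # variational circuit parameter
    "phi_fixed": 1.5,  # fixed circuit parameter
    "w_0": 0.7,        # variational observable parameter
    "c_fixed": 2.0,    # fixed observable parameter
}


def split_params(
    flat: Dict[str, float],
    circuit_names: Set[str],
    trainable_names: Set[str],
) -> Dict[str, Dict[str, float]]:
    """Sort a flat parameter dict into the four groups proposed above."""
    groups: Dict[str, Dict[str, float]] = {
        "circuit_vparams": {},
        "circuit_fixedparams": {},
        "obs_vparams": {},
        "obs_fixedparams": {},
    }
    for name, value in flat.items():
        # Assumption: every parameter belongs to either the circuit or the observable.
        owner = "circuit" if name in circuit_names else "obs"
        kind = "vparams" if name in trainable_names else "fixedparams"
        groups[f"{owner}_{kind}"][name] = value
    return groups


initial_params = split_params(
    flat_params,
    circuit_names={"theta_0", "phi_fixed"},
    trainable_names={"theta_0", "w_0"},
)
```

With this layout, a JAX-based backend could in principle pass only the two vparams dicts as the arguments to differentiate and treat the fixedparams dicts as constants.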
@nmheim
Collaborator

nmheim commented Dec 11, 2023

If I understand correctly, the problem is that conv.params returns a dict that contains both fixed and variational parameters, right? Would it be easier/more elegant to introduce conv.vparams and conv.circuit.vparams (plus the same for the observable) that return only the variational parameters? Then we don't have to change all the code that assumes conv.params to be one non-nested dict.

@dominikandreasseitz
Collaborator Author

Yes, great idea, but I would try to avoid changing the low-level interface, so I would be inclined to keep conv.params and let it just return the composition of conv.circuit.vparams,...
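
One way this could look, as a rough sketch only: the class names `ParamGroupSketch` and `ConvertedSketch` and their fields are hypothetical and do not describe the existing converted object. The flat `params` view is kept for code that expects a single non-nested dict, while `vparams` exposes only the variational parameters.

```python
from dataclasses import dataclass, field
from typing import Dict


@dataclass
class ParamGroupSketch:
    """Hypothetical holder for the parameters of one part (circuit or observable)."""

    vparams: Dict[str, float] = field(default_factory=dict)
    fixed: Dict[str, float] = field(default_factory=dict)

    @property
    def params(self) -> Dict[str, float]:
        # Flat view of this part: variational and fixed parameters together.
        return {**self.vparams, **self.fixed}


@dataclass
class ConvertedSketch:
    """Hypothetical stand-in for the object returned by a backend conversion."""

    circuit: ParamGroupSketch
    observable: ParamGroupSketch

    @property
    def vparams(self) -> Dict[str, float]:
        # Only the variational parameters of circuit and observable.
        return {**self.circuit.vparams, **self.observable.vparams}

    @property
    def params(self) -> Dict[str, float]:
        # Unchanged low-level interface: one non-nested dict.
        # Assumption: parameter names are unique across circuit and observable.
        return {**self.circuit.params, **self.observable.params}


conv = ConvertedSketch(
    circuit=ParamGroupSketch(vparams={"theta_0": 0.1}, fixed={"phi_fixed": 1.5}),
    observable=ParamGroupSketch(vparams={"w_0": 0.7}, fixed={"c_fixed": 2.0}),
)
```

Existing callers would keep reading `conv.params`, and a backend that needs the trainable subset would read `conv.vparams` or `conv.circuit.vparams` instead.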

Next question: do we want to give the user the option to choose which diff_mode to use for a particular part of the model?

@GJBoth
Collaborator

GJBoth commented Dec 12, 2023

If I remember correctly, I originally designed this with @awennersteen. Indeed, we didn't consider fixing some parameters or having different grad backends. There are two options, both without changing the API, I think:

  1. Have the parameters in the returned dict always be trainable, and let the embedding_fn take care of adding in the non-trainable params. In my opinion not a great option for various reasons, but possible.
  2. Instead of the params being a basic dict, make it a slightly more involved object with, for example, trainable_params, fixed_params, and the corresponding AD rules for each group. This is very JAX style (have a look at optax). I think this might cover everything we need. This is my preferred option, and I think it is in line with @dominikandreasseitz's idea, if I understand it correctly?
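
Option 2 could be sketched roughly as follows. Everything here is an assumption for illustration: the `ParamGroup` and `ModelParams` containers, the toy loss, and the two differentiation rules. The plain two-term parameter-shift rule stands in for GPSR, and a central finite difference stands in for AD so that the sketch needs no autodiff library.

```python
import math
from typing import Callable, Dict, NamedTuple, Optional

Params = Dict[str, float]
Loss = Callable[[Params], float]
DiffRule = Callable[[Loss, Params, str], float]


def parameter_shift(loss: Loss, params: Params, name: str) -> float:
    """Two-term shift rule; exact when the loss is a + b*cos(x) + c*sin(x) in this parameter."""
    plus = {**params, name: params[name] + math.pi / 2}
    minus = {**params, name: params[name] - math.pi / 2}
    return 0.5 * (loss(plus) - loss(minus))


def central_difference(loss: Loss, params: Params, name: str, eps: float = 1e-6) -> float:
    """Numerical derivative, used here only as a stand-in for AD."""
    plus = {**params, name: params[name] + eps}
    minus = {**params, name: params[name] - eps}
    return (loss(plus) - loss(minus)) / (2 * eps)


class ParamGroup(NamedTuple):
    entries: Params
    diff_rule: Optional[DiffRule]  # None marks a fixed (non-trainable) group


class ModelParams(NamedTuple):
    groups: Dict[str, ParamGroup]

    def flat(self) -> Params:
        # Merge all groups into one non-nested dict (names assumed unique).
        merged: Params = {}
        for group in self.groups.values():
            merged.update(group.entries)
        return merged

    def grad(self, loss: Loss) -> Params:
        # Differentiate each trainable group with its own rule; skip fixed groups.
        flat = self.flat()
        grads: Params = {}
        for group in self.groups.values():
            if group.diff_rule is None:
                continue
            for name in group.entries:
                grads[name] = group.diff_rule(loss, flat, name)
        return grads


def toy_loss(p: Params) -> float:
    # Toy expectation value: observable weight times cos(angle), plus a fixed offset.
    return p["w_0"] * math.cos(p["theta_0"]) + p["c_fixed"]


model_params = ModelParams(
    groups={
        "circuit_vparams": ParamGroup({"theta_0": 0.1}, parameter_shift),
        "obs_vparams": ParamGroup({"w_0": 0.7}, central_difference),
        "obs_fixedparams": ParamGroup({"c_fixed": 2.0}, None),
    }
)
grads = model_params.grad(toy_loss)
```

Here `grads` contains entries for `theta_0` and `w_0` only, each computed with the rule attached to its group, which is one possible reading of "the corresponding AD rules for each group".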

@awennersteen
Member

next question: do we want to give the user the option to choose which diff_mode to use for a particular part of the model?
IMHO, this sounds dangerous. But I guess it makes a lot of sense if we consider, for example, a hybrid model where we have a classical NN composed with a QNN. I think that it should be strictly defined.

I, like @GJBoth, have no recollection of why we did this and what we may have considered or not :p
The one thing I do remember is that, over the month or so after the initial design, there were many hacks and patches to make it actually work...

Since @nmheim was asking about namedtuples the other day, maybe this is another place to use them?
That way we have a more solid object, we keep all the different data in there (Gert-Jax' option number 2), and then go for it?

My only concern is how this might behave together with the idea of different diff-modes for different parts? But maybe this is the best way of achieving that too?
Suppose we end up saying that, in order to use different diff modes, you compose multiple QuantumModels (or maybe DifferentiableBackends, or whatever the current name is). Then, in this namedtuple keeping track of parameters, we could also keep track of which model they belong to. So by using Gert-Jax' idea of "looking up the AD rules for each group", I guess that could be achieved arbitrarily. This quickly becomes overengineering, though, and we should perhaps think about it before implementing that part.
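
As a very small illustration of the bookkeeping described above (not a proposal for the actual design): each parameter record carries the name of the model it belongs to, and the diff mode is looked up per model. The record fields, the model names, and the `MODEL_DIFF_MODE` table are all hypothetical.

```python
from collections import defaultdict
from typing import Dict, List, NamedTuple


class ParamRecord(NamedTuple):
    name: str
    value: float
    trainable: bool
    model: str  # which composed model this parameter belongs to


# Hypothetical mapping from model name to the diff mode used for its parameters.
MODEL_DIFF_MODE: Dict[str, str] = {"qnn": "gpsr", "nn": "ad"}


def trainable_by_diff_mode(records: List[ParamRecord]) -> Dict[str, List[str]]:
    """Group the names of trainable parameters by the diff mode of their model."""
    grouped: Dict[str, List[str]] = defaultdict(list)
    for record in records:
        if record.trainable:
            grouped[MODEL_DIFF_MODE[record.model]].append(record.name)
    return dict(grouped)


records = [
    ParamRecord("theta_0", 0.1, True, "qnn"),
    ParamRecord("phi_fixed", 1.5, False, "qnn"),
    ParamRecord("layer0.weight", 0.3, True, "nn"),
]
by_mode = trainable_by_diff_mode(records)
```

Fixed parameters are simply left out of the grouping, so each diff routine only ever sees the parameters it is responsible for.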

@RolandMacDoland changed the title from "Rethink parameter dictionaries in backends / Introduce hybrid differentiation modes" to "[Differentiability, Refactoring] Rethink parameter dictionaries in backends / Introduce hybrid differentiation modes" on Feb 7, 2024
@RolandMacDoland removed their assignment on Apr 24, 2024
@jpmoutinho added the to_check (To check in the development of the new expression system) label on May 14, 2024