Can odeint_adjoint solve parametric ODEs? #227
Comments
Can someone please answer this question?
Hi! I'm not one of the developers, but I think you can do it this way:

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torchdiffeq import odeint, odeint_adjoint


class ODEfunc(nn.Module):
    def __init__(self, params):
        super().__init__()
        self.params = params  # a, b, c stored as a plain tensor attribute

    def forward(self, t, y):
        a, b, c = self.params
        dydt = torch.zeros_like(y)
        dydt[0] = a*y[0] - b*y[1]
        dydt[1] = b*y[0] - c*y[1]
        return dydt


time = torch.linspace(0.0, 10.0, 100)
params = torch.tensor([1.0, 2.0, 3.0])
y0 = torch.tensor([1.5, 0.25])
func = ODEfunc(params)

# Solve the same problem with both solvers so the results can be compared.
result = odeint(func, y0, time)
result_adjoint = odeint_adjoint(func, y0, time)

# Lines show odeint, markers show odeint_adjoint.
plt.plot(time, result[:, 0], color='tab:blue', zorder=0, label="odeint")
plt.scatter(time, result_adjoint[:, 0], color='tab:blue', label="odeint_adjoint")
plt.plot(time, result[:, 1], color='tab:red', zorder=0)
plt.scatter(time, result_adjoint[:, 1], color='tab:red')
plt.xlabel('Time')
plt.ylabel('Y')
plt.legend()
plt.show()
Does this approach work when the parameters are not constant, but differ between the samples we want to train on?
Let's say I have some linear system of parametric ODEs:
How do I pass the parameters to odeint/odeint_adjoint? In scipy.integrate.odeint, this would look like this: