Example for regression of cde #573
You might find this issue helpful as a starting point: #519
Hello, many thanks for your reply. I wrote up a simple regression example, which works. However, the training loss only decreases to about MSE = 0.01 and then flattens out, whereas I need an error of around 1e-6. Could you please suggest a solution for this issue?

import os
import time

import diffrax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
# Here we wrap up the entire ODE solve into a model.
# Toy dataset of nonlinear oscillators. Sample paths look like deformed sines and cosines.
def get_data(dataset_size, *, key):
    ...  # body not included in the post


# Generate data
key = jr.PRNGKey(42)
dataset_size = 100  # number of trajectories to generate and visualize
ts, yts, coeffs = get_data(dataset_size, key=key)
ys = yts[:, :, 1:]  # drop channel 0, keep the two state components

# Create figure with subplots
fig = plt.figure(figsize=(15, 10))

# Plot individual components over time
ax1 = plt.subplot(2, 2, 1)
ax2 = plt.subplot(2, 2, 2)
ax3 = plt.subplot(2, 2, (3, 4))  # Phase space plot takes bottom half

# Plot first component
for i in range(dataset_size):
    ax1.plot(ts[i, :], ys[i, :, 0], label=f'Trajectory {i+1}')
ax1.grid(True)
ax1.set_xlabel('Time')
ax1.set_ylabel('y₁(t)')
ax1.set_title('First Component')
ax1.legend()

# Plot second component
for i in range(dataset_size):
    ax2.plot(ts[i, :], ys[i, :, 1], label=f'Trajectory {i+1}')
ax2.grid(True)
ax2.set_xlabel('Time')
ax2.set_ylabel('y₂(t)')
ax2.set_title('Second Component')
ax2.legend()

# Phase space plot
for i in range(dataset_size):
    ax3.plot(ys[i, :, 0], ys[i, :, 1], label=f'Trajectory {i+1}')
    # Mark start point
    ax3.plot(ys[i, 0, 0], ys[i, 0, 1], 'go', markersize=8, label='Start' if i == 0 else "")
    # Mark end point
    ax3.plot(ys[i, -1, 0], ys[i, -1, 1], 'ro', markersize=8, label='End' if i == 0 else "")
ax3.grid(True)
ax3.set_xlabel('y₁')
ax3.set_ylabel('y₂')
ax3.set_title('Phase Space')
ax3.legend()

# Add arrows to show direction of trajectories.
# ts has shape (dataset_size, n_times), so len(ts) would be dataset_size, not the number of time points.
n_times = ts.shape[1]
n_arrows = 5
idx = n_times // n_arrows
for i in range(dataset_size):
    # Add arrows at regular intervals
    for j in range(n_arrows):
        k = j * idx
        if k + idx < n_times:
            dx = ys[i, k + idx, 0] - ys[i, k, 0]
            dy = ys[i, k + idx, 1] - ys[i, k, 1]
            ax3.arrow(ys[i, k, 0], ys[i, k, 1], dx / 2, dy / 2,
                      head_width=0.05, head_length=0.08, fc='k', ec='k', alpha=0.5)

plt.tight_layout()
plt.show()

# Print some statistical information
print("\nTrajectory Statistics:")
print("-" * 50)
for i in range(dataset_size):
    print(f"\nTrajectory {i+1}:")
    print(f"Initial position: ({ys[i,0,0]:.3f}, {ys[i,0,1]:.3f})")
    print(f"Final position: ({ys[i,-1,0]:.3f}, {ys[i,-1,1]:.3f})")
    print(f"Maximum y₁: {jnp.max(ys[i,:,0]):.3f}")
    print(f"Maximum y₂: {jnp.max(ys[i,:,1]):.3f}")
    print(f"Minimum y₁: {jnp.min(ys[i,:,0]):.3f}")
    print(f"Minimum y₂: {jnp.min(ys[i,:,1]):.3f}")


def dataloader(arrays, batch_size, *, key):
    ...  # body not included in the post
ts, ys, model, train_losses = main()
data_key, model_key, loader_key = jr.split(key, 3)
trajNo = 0
print(ts.shape)
plt.figure(figsize=(5, 4))
plt.figure(figsize=(5, 4))
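The body of `get_data` is not included in the snippet above. For anyone who wants to run the plotting code, here is a hypothetical stand-in in plain JAX (no diffrax): it integrates a deformed harmonic oscillator with fixed-step RK4 and prepends time as channel 0, so that `yts[:, :, 1:]` recovers the two state components. The vector field, the name `make_toy_data`, the initial-condition range and the grid size are all assumptions, not the poster's code, and unlike the original it returns only `(ts, yts)` with no interpolation coefficients.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def oscillator_field(y):
    # Assumed deformed harmonic oscillator: the map y / (1 + y) bends circular orbits.
    x = y / (1.0 + y)
    return jnp.stack([x[1], -x[0]])


def rk4_step(y, dt):
    # One classical fourth-order Runge-Kutta step of dy/dt = oscillator_field(y).
    k1 = oscillator_field(y)
    k2 = oscillator_field(y + 0.5 * dt * k1)
    k3 = oscillator_field(y + 0.5 * dt * k2)
    k4 = oscillator_field(y + dt * k3)
    return y + dt / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)


def make_toy_data(dataset_size, *, key, length=100, t1=10.0):
    """Hypothetical generator: returns ts (N, length) and yts (N, length, 3)."""
    ts = jnp.linspace(0.0, t1, length)
    dt = ts[1] - ts[0]
    y0s = jr.uniform(key, (dataset_size, 2), minval=-0.6, maxval=1.0)

    def solve(y0):
        def body(y, _):
            y_next = rk4_step(y, dt)
            return y_next, y_next

        _, ys = jax.lax.scan(body, y0, None, length=length - 1)
        return jnp.concatenate([y0[None, :], ys], axis=0)  # (length, 2)

    ys = jax.vmap(solve)(y0s)  # (dataset_size, length, 2)
    ts_batch = jnp.broadcast_to(ts, (dataset_size, length))
    # Prepend time as channel 0, so yts[:, :, 1:] recovers the two state components.
    yts = jnp.concatenate([ts_batch[:, :, None], ys], axis=-1)
    return ts_batch, yts


ts, yts = make_toy_data(8, key=jr.PRNGKey(0))
ys = yts[:, :, 1:]
print(ts.shape, yts.shape, ys.shape)  # (8, 100) (8, 100, 3) (8, 100, 2)
```

Keeping `ts` two-dimensional matches the `ts[i, :]` indexing used in the plots; it is also why the arrow loop has to use `ts.shape[1]` rather than `len(ts)` to count time points.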
That’s not clearly an issue with diffrax itself; it’s more of an open-ended research question. And since I don’t know much about neural CDEs, I doubt I can help.
Hello,
Could you please provide a simple example of regression with a neural CDE?
I see there is a neural CDE example for classification. It would be really useful to have a regression one as well.
Many thanks
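As a rough illustration of what such a regression example might look like, here is a minimal sketch in plain JAX. It swaps diffrax's solver for a fixed-step explicit Euler discretisation of the CDE `dh = f(h) dX` over a piecewise-linear control `X`, and reads out a prediction at every sample time, so the loss is an MSE over whole sequences rather than a classification loss on the final state. All names (`init_params`, `predict`, `sgd_step`, ...), the sin-to-cos toy task, the layer sizes and the learning rate are assumptions for illustration. With diffrax, one would instead build a `diffrax.ControlTerm` over an interpolation of the data and save the solution at every time with `diffrax.SaveAt(ts=...)`.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def dense_init(key, fan_in, fan_out):
    # Weights scaled by 1/sqrt(fan_in); biases start at zero.
    return {"w": jr.normal(key, (fan_out, fan_in)) / jnp.sqrt(fan_in),
            "b": jnp.zeros(fan_out)}


def dense(p, x):
    return p["w"] @ x + p["b"]


def init_params(key, data_size, hidden_size, width, out_size):
    # Hypothetical parameter layout; all names are illustrative.
    k1, k2, k3, k4 = jr.split(key, 4)
    return {
        "initial": dense_init(k1, data_size, hidden_size),         # h(t0) from X(t0)
        "vf_in": dense_init(k2, hidden_size, width),                # vector field, layer 1
        "vf_out": dense_init(k3, width, hidden_size * data_size),   # vector field, layer 2
        "readout": dense_init(k4, hidden_size, out_size),           # prediction from h(t)
    }


def vector_field(params, h, data_size):
    # f(h) is a (hidden_size, data_size) matrix; tanh keeps its entries bounded.
    z = jax.nn.softplus(dense(params["vf_in"], h))
    return jnp.tanh(dense(params["vf_out"], z)).reshape(h.shape[0], data_size)


def predict(params, xs):
    """xs: (length, data_size) samples of the control path. Returns (length, out_size)."""
    data_size = xs.shape[-1]
    h0 = dense(params["initial"], xs[0])
    dxs = jnp.diff(xs, axis=0)  # increments of the piecewise-linear control

    def step(h, dx):
        # Explicit Euler step of dh = f(h) dX across one linear segment.
        h_next = h + vector_field(params, h, data_size) @ dx
        return h_next, h_next

    _, hs = jax.lax.scan(step, h0, dxs)
    hs = jnp.concatenate([h0[None, :], hs], axis=0)  # hidden state at every sample time
    return jax.vmap(lambda h: dense(params["readout"], h))(hs)


def mse(params, xs_batch, ys_batch):
    preds = jax.vmap(lambda xs: predict(params, xs))(xs_batch)
    return jnp.mean((preds - ys_batch) ** 2)


@jax.jit
def sgd_step(params, xs_batch, ys_batch, lr):
    # Plain gradient descent; returns the loss evaluated before the update.
    loss, grads = jax.value_and_grad(mse)(params, xs_batch, ys_batch)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss


# Toy regression task: from the path (normalised time, sin(t + phase)),
# predict cos(t + phase) at every sample time.
data_key, model_key = jr.split(jr.PRNGKey(0))
ts = jnp.linspace(0.0, 2 * jnp.pi, 50)
phases = jr.uniform(data_key, (64, 1), maxval=2 * jnp.pi)
xs_batch = jnp.stack(
    [jnp.broadcast_to(ts / (2 * jnp.pi), (64, 50)), jnp.sin(ts + phases)], axis=-1
)  # (64, 50, 2)
ys_batch = jnp.cos(ts + phases)[..., None]  # (64, 50, 1)

params = init_params(model_key, data_size=2, hidden_size=16, width=32, out_size=1)
losses = []
for _ in range(200):
    params, loss = sgd_step(params, xs_batch, ys_batch, 3e-3)
    losses.append(float(loss))
print(f"MSE: first step {losses[0]:.4f}, last step {losses[-1]:.4f}")
```

Plain gradient descent with a small fixed learning rate only keeps the sketch dependency-free; in a real setup an adaptive optimiser (for example Adam from optax) would typically replace `sgd_step`.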