Skip to content

Conversation

jtgrasb
Copy link
Collaborator

@jtgrasb jtgrasb commented Aug 11, 2025

Description

Convert autograd to jax.

Wavebot tutorial is working for the first optimization.

@jtgrasb
Copy link
Collaborator Author

jtgrasb commented Aug 12, 2025

Wavebot is working now when converted to jax. Working on updating the code so that only the necessary functions use jax.numpy while the rest use numpy.

cmichelenstrofer and others added 28 commits August 20, 2025 13:00
* bug bix : DC and Nyquist frequency should not be devided by two before ifft

* Changed td_to_fd to scale single sided frequency components rather than TD signal

* minor bug fix from issue332 sandialabs#332
* added initial file changes based on sphinx_multiversion docs and WEC-Sim implementation

* removed sphinx-multiversion since it is no longer supported and made manual multiversion

* now uses absolute paths, commented out linkcheck for debugging

* fixed docstring errors in utilities module

* updating files again that somehow got reverted

* fixing path in conf.py

* don't run tutorials (will revert later)

* handle file moves correctly, fixed if statement to make other versions appear

* fixed two bugs in versions template

* reverted temp changes, changes latest to main

* switched latest to main

* main branch now in root directory of pages

* fixed URLs with change from last commit

* make other branches visible before building

* switched main branch tag for more testing

* fixed typo

* switched dev branch to an existing branch

* renamed main to latest, changed version.html file name to avoid confusion

* added prints about moving files so Sphinx output isn't misleading

* fixed typo with quotations

* changed versions.html name back because that broke things I guess

* modified contributing documentation to reflect changes

* add logic to remove duplicate 'latest' branch

* Fixed pathing when already on latest

* remove typo

* Troubleshooting complete, switching back to correct branches for deployment

* Removed extra word in docstring

* removed redundant function

* fixed pathing so returns to same file (and fixes tutorial/API docs)

* changed latest branch for demonstration

* switched back latest branch for deployment
* removed conda environment from workflows since newer capytaine/wavespectra work with Windows

* fixed unnecessary capitalization

* still create CI conda environment to fix Mac environment failures

* added conda env fully back in, push workflow deploys docs, split PR workflow

* conda environment activates again

* mambaforge instead of miniforge

* manual cache reset

* reset to older version of setup-miniconda to troubleshoot
* Try specifying subversion

* Test new cache

* revert to 3.12

* Revert comment back to normal
@jtgrasb
Copy link
Collaborator Author

jtgrasb commented Aug 21, 2025

To do:

  • Add in functions such as vmap, grad, and jit to increase code performance.
  • Resolve MacOS issues on GitHub actions.

@jtgrasb
Copy link
Collaborator Author

jtgrasb commented Sep 11, 2025

I added just-in-time compilation to the optimization (objective function, constraints, and relative gradients) using jax.jit which should speed up the code. I also had to change the call to block_diag() in the mimo_transfer_mat() function to a revised function that is now jittable. There were a couple of assert statements that I commented out for now because assertions related to the dynamic value (some were able to be left in because they were only checking static properties, which is jittable).

Here is the computation time with and without jit for the AquaHarmonics parameter sweep cell. For some reason, I'm not seeing that much of a speedup so need to look into this further.
Autograd: 42.98 s, jax with jit: 38.69 s

To do:

  • Figure out why speed is not improved that much with jit.
    • Based on this issue thread, it seems like its because scipy does computation on numpy arrays which means the data type keeps getting converted back and forth between jax and numpy arrays.
    • Compare time to jax without jit
  • investigate jax.vmap
    • Would need to write functions differently to vectorize. Would be good but for future work.
  • Resolve MacOS issues on GitHub actions.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants