Implementation of the power spectrum inference with GPs #832
base: main
Conversation
Hello @mlefkir! Thanks for updating this PR. We checked the lines you've touched for PEP 8 issues, and found:
Comment last updated at 2024-10-03 11:36:06 UTC
@mlefkir thanks for your contribution to Stingray, and sorry for my late reply! From the point of view of the requirements for PRs to Stingray:
Yes, @mlefkir, do you have any use case example for this method (something like an .ipynb notebook)? Also @dhuppenkothen, can you have a look into the usefulness of this method for Stingray, and whether it should go into the same file as the gpmodeling part?
@Gaurav17Joshi @matteobachetti I made two examples available here: Examples. One uses nested sampling with the jaxns sampler already called in Stingray, and the other uses NUTS with NumPyro.
@mlefkir I'm playing with your PR. Really sorry for the slow progress, but my knowledge of these methods is pretty poor, it takes me a lot of time just to understand how it works, and... my agenda is pretty full 😅.
@matteobachetti This model is an improvement on the available models, which currently have fixed low- and high-frequency slopes. For instance, the DRW/Exponential kernel has a Lorentzian power spectrum with a low-frequency slope of 0 and a high-frequency slope of -2. This method allows modelling spectral shapes with flexible bend frequencies and slopes, which can be between 0 and 4, using a sum of basis functions, as shown in the figure below.

[Figure: bending power-law power spectra built from a sum of basis functions]

This method is designed for Gaussian process regression, so it can be used for any Gaussian time series with (or without) error bars. While the algorithm in tinygp is fast, there are limitations on the number of points in terms of computational cost, so I would use the method only for irregularly sampled data, or data with gaps, with fewer than 10,000 points.
@matteobachetti If you are interested, the preprint of the paper describing the method is available here: https://arxiv.org/pdf/2501.05886
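To make the basis-function idea above concrete, here is a minimal sketch (not code from the PR; the weights `a` and bend frequencies `f_c` are placeholders that would normally come from fitting the target bending power law, and `sho_power_spectrum` mirrors the `SHO_power_spectrum` used later in the diff):

```python
import jax.numpy as jnp

def sho_power_spectrum(f, a, f_c):
    # One SHO basis component: flat below the bend f_c, slope -4 above it.
    return a / (1.0 + (f / f_c) ** 4)

f = jnp.geomspace(1e-3, 1e2, 500)    # frequency grid
f_c = jnp.geomspace(1e-2, 1e1, 20)   # placeholder bend frequencies
a = jnp.ones_like(f_c)               # placeholder component weights

# Summing the components produces a flexible bending power-law shape.
psd_approx = sho_power_spectrum(f[None, :], a[:, None], f_c[:, None]).sum(axis=0)
```

Summing components like these, with bends spread across the frequency range, is what lets the approximated spectra take slopes anywhere between 0 and -4.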
@mlefkir Sorry for the long wait. This method is very interesting and certainly deserves inclusion in Stingray. In my tests, I found that `DefaultNestedSampler` is not recognized; I think `jaxns` has changed its API, but monkeypatching it to point to `NestedSampler` made it work.
Besides this, everything seems to work, and I think the best test is now to put it out there and let people use it. The method itself, in the Python implementation, seems quite slow (the sampling took hours on my laptop). Do we need to do anything specific to make it work with GPUs? My computer does not have one, so I don't know how to test this.
We also need to document it properly. Would you mind adding the Pioran Stingray notebooks you already shared, possibly with a more in-depth explanation of the methods, to our notebook repository?
Minor thing: the branch needs to be merged or rebased with `main`. I also suggest running `black -l 100 stingray/` from the main directory and committing the results, to fix the code style failure.
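On the GPU question above: as far as I know, JAX dispatches to an accelerator automatically once a CUDA-enabled jaxlib is installed, with no changes to the model code. A quick way to check which backend is active:

```python
import jax

print(jax.devices())          # e.g. [CudaDevice(id=0)] on a GPU machine
print(jax.default_backend())  # "cpu", "gpu", or "tpu"
```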
```python
fig1, fig2 = run_prior_checks(
    self.kernel_type, self.kernel_params, self.priors, loglike, 5.5e-2, 0.5
)
plt.fignum_exists(1)
```
Note that this would also work if some other test has opened a figure.
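A more targeted check (a sketch reusing the names from the test above, including the test-class attributes) would assert on the returned objects and clean up after itself:

```python
import matplotlib.pyplot as plt
from matplotlib.figure import Figure

fig1, fig2 = run_prior_checks(
    self.kernel_type, self.kernel_params, self.priors, loglike, 5.5e-2, 0.5
)
# Assert on what the function actually returned, not on global pyplot state.
assert isinstance(fig1, Figure)
assert isinstance(fig2, Figure)
# Close the figures so they cannot leak into other tests.
plt.close(fig1)
plt.close(fig2)
```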
```python
noise_color = "C5"
window_color = "k"

fig, ax = plt.subplots(1, 1, figsize=(6, 4))
```
Probably good to give a unique name to these plots (also good for testing purposes!)
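For example (a sketch; the figure label is made up), `plt.subplots` forwards the `num` argument to `plt.figure`, so the plot can carry a unique name that tests can query unambiguously:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, 1, figsize=(6, 4), num="gpmodeling_diagnostics")

assert plt.fignum_exists("gpmodeling_diagnostics")  # unlike figure number 1
```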
Codecov Report
Attention: Patch coverage is …
Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##             main     #832      +/-   ##
==========================================
- Coverage   96.03%   94.71%    -1.33%
==========================================
  Files          48       48
  Lines        9770     9950     +180
==========================================
+ Hits         9383     9424      +41
- Misses        387      526     +139
```

☔ View full report in Codecov by Sentry.
Hey @mlefkir! This is great, thank you for submitting this, and I'm sorry it's taken me so long to get to it!
Do you perhaps have a Jupyter notebook for inclusion in the documentation? If so, it'd be great if you could share it with me, because it'd also help me understand how the API is supposed to work. :D
Overall, great work! I'm excited that we get to have this in Stingray!
In terms of feedback, I mostly have nitpicky comments about docstrings and code consistency:
- make sure your docstrings are consistent with the code (I've found a few places where the docstring does not match the function definition).
- expand your docstrings. Users will rely on these to understand what a function does, what its inputs are, and what it returns. Make sure there is enough information for users to understand (1) what the function does and how, (2) what they need to pass into the function and what form those inputs take, and (3) what comes out of the function. I've highlighted a number of places where I had trouble figuring out what's happening, but probably not all of them.
- code consistency: make sure your code is consistent across functions. For example, input variables of the same name should ideally always take the same type and kind of information, and variable names should match the kind of information passed in. Document extensively where you break that pattern for whatever reason.
- make sure you implement consistent behaviour across functions. For example, some functions produce plots that are saved to file, while others just return the Figure (or Figure and Axes) objects. Unless necessary, make sure these functions behave consistently, and document extensively where you break the pattern (one possible convention is sketched below).
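One possible convention, as a sketch (the helper name is hypothetical): every plotting helper returns `(fig, ax)` and leaves saving to the caller.

```python
import matplotlib.pyplot as plt

def plot_psd(f, psd, ax=None):
    # Hypothetical helper illustrating the "always return fig, ax" rule.
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=(6, 4))
    else:
        fig = ax.figure
    ax.loglog(f, psd)
    ax.set_xlabel("Frequency")
    ax.set_ylabel("Power")
    return fig, ax
```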
```python
        Log likelihood function.
    n_samples : int
        Number of samples. Default is 3000.
    """
```
I have some nitpicky documentation questions here. From just looking at the docs/function, it's really hard to understand what's supposed to go in here. For example, `kernel_params` is missing from the list of parameters. Is it a list of strings? Of values? What goes into `priors`? A list of what? Functions? Something else? Be aware that these are the kinds of questions a new user will ask themselves.
```python
# get the prior model
prior_dict = dict(zip(kernel_params, priors))
prior_model = get_prior(kernel_params, prior_dict)
```
These lines imply that the priors in `priors` need to be in the correct order according to the order of `kernel_params`, yes? Why not enforce this automatically (and avoid user error) by having the user pass in that dictionary directly, rather than constructing it here from two lists? A sketch of that is below.
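A sketch of the dict-first signature this suggests (the wrapper name is hypothetical; `get_prior` is the PR's existing function):

```python
def get_prior_from_dict(prior_dict):
    # The user supplies {parameter_name: prior} directly, so two parallel
    # lists can never get out of sync.
    kernel_params = list(prior_dict.keys())
    return get_prior(kernel_params, prior_dict)
```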
approximate_with="SHO", | ||
with_normalisation=False, | ||
): # -> tuple[NDArray[Any], NDArray[Any]]: | ||
"""Get the PSD and the approximate PSD for a given set of parameters and samples. |
I'm not entirely sure what this does. Specifically, why does this function need prior samples? Computing the PSD and the approximate PSD should be analytical, no?
Note that there are parameters above that are not in the list of parameters. On the other hand, the docstring below claims there is a parameter `n_samples` that does not appear in the function definition. I think it's worth going through the docstrings and making sure they are consistent. It would also be good to expand the information in them, both in terms of what the function does and especially in terms of what goes into the input arguments, to help new users understand what they need to do.
So, having read the function itself, I'm wondering whether it would make sense to split out the construction of the PSD/approximate PSD for a single set of parameters, and then have a separate function that loops over a set of samples (do they have to be prior samples, by the way, as the function definition insists?). Feel free to tell me I'm wrong, but I could well imagine that someone would like to take a single set of parameters and compute the PSD and approximate PSD from those. A skeleton of that split is sketched below.
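A skeleton of that split (all names hypothetical):

```python
def get_psd_for_params(f, kernel_type, param_dict,
                       n_approx_components=20, approximate_with="SHO"):
    """Compute (psd, psd_approx) for a single parameter set."""
    raise NotImplementedError  # the PR's per-sample body would move here

def get_psd_over_samples(f, kernel_type, samples, **kwargs):
    # Loop the single-set function over any iterable of parameter dicts,
    # prior samples or otherwise.
    return [get_psd_for_params(f, kernel_type, s, **kwargs) for s in samples]
```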
```python
for k in range(n_samples):
    param_dict = {}
    for i, params in enumerate(kernel_params):
        if params[0:4] == "log_":
```
What does this line do? It seems to hardcode some kind of log-value handling? What for? And in what ways could that break, either with a different kernel or through user error?
```python
param_dict = {}
for i, params in enumerate(kernel_params):
    if params[0:4] == "log_":
        param_dict[params[4:]] = jnp.exp(prior_samples[params][k])
```
I'm probably just confused here, but these two lines seem to indicate that if the first four characters of `params` are `log_`, we'll exponentiate the other ones? I.e. not the ones with `log_` in front of the name? Is that what's intended, or am I missing something here?
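For what it's worth, my reading of the loop (restated as a runnable sketch) is that the test is on the parameter's name: parameters sampled in log space carry a `log_` prefix, and exactly those get exponentiated and stored under the name with the prefix stripped. The parameter name and samples below are made up:

```python
import jax.numpy as jnp

kernel_params = ["log_amplitude"]                          # hypothetical name
prior_samples = {"log_amplitude": jnp.array([0.0, 1.0])}   # hypothetical samples
k = 0  # index of the sample being converted

param_dict = {}
for name in kernel_params:
    if name.startswith("log_"):
        # "log_amplitude" -> "amplitude", value exp()'d back to linear space.
        param_dict[name[4:]] = jnp.exp(prior_samples[name][k])

print(param_dict)  # {'amplitude': Array(1., ...)}
```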
```python
def get_psd_approx_samples(
    f, kernel_type, kernel_params, f_min, f_max, n_approx_components=20, approximate_with="SHO"
```
I'm mildly confused why this has `samples` in the name?
```python
    n_approx_components=n_approx_components,
    approximate_with=approximate_with,
)
psd_SHO = SHO_power_spectrum(f, a[..., None], f_c[..., None]).sum(axis=0)
```
Does this function return anything? Right now, it doesn't seem to?
```python
def get_kernel(kernel_type, kernel_params):
```
Needs docstring?
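For reference, a docstring sketch of the kind requested (the exact parameter semantics are the author's to confirm):

```python
def get_kernel(kernel_type, kernel_params):
    """Build the covariance kernel for the Gaussian process.

    Parameters
    ----------
    kernel_type : str
        Name of the kernel to construct.
    kernel_params : dict
        Mapping from parameter names to values for the chosen kernel.

    Returns
    -------
    kernel
        The tinygp kernel object to use in the Gaussian process.
    """
```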
```python
    raise NotImplementedError(f"Approximation {approximate_with} not implemented")


def _psd_model(kernel_type, kernel_params):
```
Curious why this has a `_` in front of it? Do we not expect users to want to get the power spectrum model, or is there another function that fulfils this task?
```diff
@@ -50,11 +650,22 @@ def get_kernel(kernel_type, kernel_params):
     kernel_type: string
```
So, I think your code uses `kernel_type` to mean the power spectral shape, whereas the previous use was for the actual covariance function used in the GP, no? It'd be good to make sure these don't get confused.
This is the Python/JAX implementation of a method to infer the power spectral density of irregular time series using Gaussian process regression. The method is described in a forthcoming paper and in the Julia package Pioran.jl; it relies on approximating a bending power-law model by a sum of scalable kernels implemented in tinygp.
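For readers unfamiliar with tinygp, a generic usage sketch (not this PR's API) of what the underlying likelihood evaluation looks like: a GP with a scalable quasiseparable kernel is conditioned on irregularly sampled data, and the sampler explores its log-likelihood.

```python
import jax
import jax.numpy as jnp
from tinygp import GaussianProcess, kernels

# Irregularly sampled times and a toy signal.
t = jnp.sort(jax.random.uniform(jax.random.PRNGKey(0), (200,)) * 100.0)
y = jnp.sin(2 * jnp.pi * t / 20.0)

# One SHO component; the method sums many of these to shape the PSD.
kernel = kernels.quasisep.SHO(omega=1.0, quality=1.0 / jnp.sqrt(2.0), sigma=1.0)
gp = GaussianProcess(kernel, t, diag=1e-2)  # diag = measurement variance

print(gp.log_probability(y))  # the quantity the sampler explores
```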