Skip to content

Commit

Permalink
Update extrinsic test
Browse files Browse the repository at this point in the history
  • Loading branch information
Tsun Ho Pang committed Aug 17, 2024
1 parent fd33882 commit 03e76dc
Showing 1 changed file with 6 additions and 13 deletions.
19 changes: 6 additions & 13 deletions test/integration/test_extrinsic.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,3 @@
import psutil
p = psutil.Process()
p.cpu_affinity([0])

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from astropy.time import Time

import jax
Expand All @@ -14,7 +7,7 @@
from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior
from jimgw.single_event.detector import H1, L1, V1
from jimgw.single_event.likelihood import ZeroLikelihood
from jimgw.transforms import BoundToUnbound
from jimgw.transforms import BoundToUnbound, SingleSidedUnboundTransform
from jimgw.single_event.transforms import MassRatioToSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, DistanceToSNRWeightedDistanceTransform, GeocentricArrivalTimeToDetectorArrivalTimeTransform, GeocentricArrivalPhaseToDetectorArrivalPhaseTransform
from flowMC.strategy.optimization import optimization_Adam

Expand All @@ -30,7 +23,7 @@
ifos = [H1, L1, V1]

M_c_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"])
dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"])
dL_prior = PowerLawPrior(10.0, 200.0, -2.0, parameter_names=["d_L"])
t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"])
phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"])
iota_prior = SinePrior(parameter_names=["iota"])
Expand Down Expand Up @@ -93,7 +86,7 @@ def calc_R_dets(ra, dec, psi, iota):
BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping = [["phase_det"], ["phase_det_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping = [["t_det"], ["t_det_unbounded"]], original_lower_bound=-0.1, original_upper_bound=0.1),
BoundToUnbound(name_mapping = [["d_hat"], ["d_hat_unbounded"]], original_lower_bound=d_hat_min, original_upper_bound=d_hat_max),
SingleSidedUnboundTransform(name_mapping = [["d_hat"], ["d_hat_unbounded"]], original_lower_bound=float(d_hat_min)),
]

likelihood_transforms = []
Expand All @@ -119,9 +112,9 @@ def calc_R_dets(ra, dec, psi, iota):
likelihood_transforms=likelihood_transforms,
n_loop_training=n_loop_training,
n_loop_production=1,
n_local_steps=5,
n_global_steps=5,
n_chains=4,
n_local_steps=1,
n_global_steps=1,
n_chains=10,
n_epochs=n_epochs,
learning_rate=learning_rate,
n_max_examples=30,
Expand Down

0 comments on commit 03e76dc

Please sign in to comment.