Skip to content

Commit

Permalink
nflows only in nnanalysis (#376)
Browse files Browse the repository at this point in the history
* nflows only in nnanalysis

* restrict p-tqdm
  • Loading branch information
sahiljhawar authored Aug 5, 2024
1 parent c61097a commit cdd222f
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 21 deletions.
40 changes: 20 additions & 20 deletions nmma/em/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,26 +24,6 @@
from .utils import getFilteredMag, dataProcess
from .io import loadEvent

# import functions
from ..mlmodel.dataprocessing import gen_prepend_filler, gen_append_filler, pad_the_data
from ..mlmodel.resnet import ResNet
from ..mlmodel.embedding import SimilarityEmbedding
from ..mlmodel.normalizingflows import normflow_params
from ..mlmodel.inference import cast_as_bilby_result

# need to add these packages:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, TensorDataset, random_split
import torch.nn.functional as F
from nflows.nn.nets.resnet import ResidualNet
from nflows import transforms, distributions, flows
from nflows.distributions import StandardNormal
from nflows.flows import Flow
from nflows.transforms.autoregressive import MaskedAffineAutoregressiveTransform
from nflows.transforms import CompositeTransform, RandomPermutation
import nflows.utils as torchutils

matplotlib.use("agg")


Expand Down Expand Up @@ -1175,6 +1155,26 @@ def analysis(args):

def nnanalysis(args):

# import functions
from ..mlmodel.dataprocessing import gen_prepend_filler, gen_append_filler, pad_the_data
from ..mlmodel.resnet import ResNet
from ..mlmodel.embedding import SimilarityEmbedding
from ..mlmodel.normalizingflows import normflow_params
from ..mlmodel.inference import cast_as_bilby_result

# need to add these packages:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, TensorDataset, random_split
import torch.nn.functional as F
from nflows.nn.nets.resnet import ResidualNet
from nflows import transforms, distributions, flows
from nflows.distributions import StandardNormal
from nflows.flows import Flow
from nflows.transforms.autoregressive import MaskedAffineAutoregressiveTransform
from nflows.transforms import CompositeTransform, RandomPermutation
import nflows.utils as torchutils

# only continue if the Kasen model is selected
if args.model != "Ka2017":
print(
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pymultinest
sncosmo
dust_extinction
arviz
p_tqdm
p_tqdm<1.4.1
tornado
notebook
ligo.skymap
Expand Down

0 comments on commit cdd222f

Please sign in to comment.