Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor into proper package #4

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion data/bsds300.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import h5py
import numpy as np
import os
import utils
from nsflow import utils

from matplotlib import pyplot as plt
from torch.utils import data
Expand Down
2 changes: 1 addition & 1 deletion data/gas.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import os
import pandas as pd
import utils
from nsflow import utils

from matplotlib import pyplot as plt
from torch.utils.data import Dataset
Expand Down
2 changes: 1 addition & 1 deletion data/hepmass.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import os
import pandas as pd
import utils
from nsflow import utils

from collections import Counter
from matplotlib import pyplot as plt
Expand Down
2 changes: 1 addition & 1 deletion data/miniboone.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
import os
import utils
from nsflow import utils

from matplotlib import pyplot as plt
from torch.utils.data import Dataset
Expand Down
2 changes: 1 addition & 1 deletion data/omniglot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import torch

import utils
from nsflow import utils

from PIL import Image
from scipy.io import loadmat
Expand Down
2 changes: 1 addition & 1 deletion data/plane.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch import distributions
from torch.utils.data import Dataset

import utils
from nsflow import utils


class PlaneDataset(Dataset):
Expand Down
2 changes: 1 addition & 1 deletion data/power.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
import os
import utils
from nsflow import utils

from matplotlib import pyplot as plt
from torch.utils.data import Dataset
Expand Down
File renamed without changes.
Empty file added nsflow/nde/__init__.py
Empty file.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from torch import nn

import utils
from nsflow import utils


class NoMeanException(Exception):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

from torch.nn import functional as F

import utils
from nsflow import utils

from nde import distributions
from nsflow.nde import distributions


class ConditionalIndependentBernoulli(distributions.Distribution):
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import numpy as np
import torch

import utils
from nsflow import utils

from nde import distributions
from nsflow.nde import distributions


class StandardNormal(distributions.Distribution):
Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
import torchtestcase
import unittest
from nde.distributions import discrete
from nsflow.nde.distributions import discrete


class ConditionalIndependentBernoulliTest(torchtestcase.TorchTestCase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
import torchtestcase
import unittest
from nde.distributions import normal
from nsflow.nde.distributions import normal


class StandardNormalTest(torchtestcase.TorchTestCase):
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

from torch.nn import functional as F

from nde import distributions
from nde import flows
from nde import transforms
from nsflow.nde import distributions
from nsflow.nde import flows
from nsflow.nde import transforms


class MaskedAutoregressiveFlow(flows.Flow):
Expand Down
4 changes: 2 additions & 2 deletions nde/flows/base.py → nsflow/nde/flows/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Basic definitions for the flows module."""

import utils
from nsflow import utils

from nde import distributions
from nsflow.nde import distributions


class Flow(distributions.Distribution):
Expand Down
8 changes: 4 additions & 4 deletions nde/flows/realnvp.py → nsflow/nde/flows/realnvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import torch
from torch.nn import functional as F

from nde import distributions
from nde import flows
from nde import transforms
import nn as nn_
from nsflow.nde import distributions
from nsflow.nde import flows
from nsflow.nde import transforms
from nsflow import nn as nn_


class SimpleRealNVP(flows.Flow):
Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torchtestcase
import unittest

from nde.flows import autoregressive as ar
from nsflow.nde.flows import autoregressive as ar


class MaskedAutoregressiveFlowTest(torchtestcase.TorchTestCase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import torch
import torchtestcase
import unittest
from nde import transforms
from nde import distributions
from nde.flows import base
from nsflow.nde import transforms
from nsflow.nde import distributions
from nsflow.nde.flows import base


class FlowTest(torchtestcase.TorchTestCase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torchtestcase
import unittest

from nde.flows import realnvp
from nsflow.nde.flows import realnvp


class SimpleRealNVPTest(torchtestcase.TorchTestCase):
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import torch
from torch.nn import functional as F

import utils
from nde import transforms
from nde.transforms import made as made_module
from nde.transforms import splines
from nsflow import utils
from nsflow.nde import transforms
from nsflow.nde.transforms import made as made_module
from nsflow.nde.transforms import splines

class AutoregressiveTransform(transforms.Transform):
"""Transforms each input variable with an invertible elementwise transformation.
Expand Down
2 changes: 1 addition & 1 deletion nde/transforms/base.py → nsflow/nde/transforms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from torch import nn

import utils
from nsflow import utils


class InverseNotAvailable(Exception):
Expand Down
5 changes: 3 additions & 2 deletions nde/transforms/conv.py → nsflow/nde/transforms/conv.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import utils
from nde import transforms
from nsflow import utils

from nsflow.nde import transforms


class OneByOneConvolution(transforms.LULinear):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import torch
from torch.nn import functional as F

import utils
from nsflow import utils

from nde import transforms
from nde.transforms import splines
from nsflow.nde import transforms
from nsflow.nde.transforms import splines


class CouplingTransform(transforms.Transform):
Expand Down
8 changes: 4 additions & 4 deletions nde/transforms/linear.py → nsflow/nde/transforms/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from torch import nn
from torch.nn import functional as F, init

import utils
from nsflow import utils

from nde import transforms
from nsflow.nde import transforms


class LinearCache(object):
Expand Down Expand Up @@ -177,7 +177,7 @@ def inverse_no_cache(self, inputs):
"""
batch_size = inputs.shape[0]
outputs = inputs - self.bias
outputs, lu = torch.gesv(outputs.t(), self._weight) # Linear-system solver.
outputs, lu = torch.solve(outputs.t(), self._weight) # Linear-system solver.
outputs = outputs.t()
# The linear-system solver returns the LU decomposition of the weights, which we
# can use to obtain the log absolute determinant directly.
Expand Down Expand Up @@ -210,7 +210,7 @@ def weight_inverse_and_logabsdet(self):
"""
# If both weight inverse and logabsdet are needed, it's cheaper to compute both together.
identity = torch.eye(self.features, self.features)
weight_inv, lu = torch.gesv(identity, self._weight) # Linear-system solver.
weight_inv, lu = torch.solve(identity, self._weight) # Linear-system solver.
logabsdet = torch.sum(torch.log(torch.abs(torch.diag(lu))))
return weight_inv, logabsdet

Expand Down
8 changes: 4 additions & 4 deletions nde/transforms/lu.py → nsflow/nde/transforms/lu.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch import nn
from torch.nn import functional as F, init

from nde.transforms.linear import Linear
from nsflow.nde.transforms.linear import Linear


class LULinear(Linear):
Expand Down Expand Up @@ -103,8 +103,8 @@ def weight_inverse(self):
"""
lower, upper = self._create_lower_upper()
identity = torch.eye(self.features, self.features)
lower_inverse, _ = torch.trtrs(identity, lower, upper=False, unitriangular=True)
weight_inverse, _ = torch.trtrs(lower_inverse, upper, upper=True, unitriangular=False)
lower_inverse, _ = torch.triangular_solve(identity, lower, upper=False, unitriangular=True)
weight_inverse, _ = torch.triangular_solve(lower_inverse, upper, upper=True, unitriangular=False)
return weight_inverse

@property
Expand All @@ -117,4 +117,4 @@ def logabsdet(self):
where:
D = num of features
"""
return torch.sum(torch.log(self.upper_diag))
return torch.sum(torch.log(self.upper_diag))
2 changes: 1 addition & 1 deletion nde/transforms/made.py → nsflow/nde/transforms/made.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import torch

import utils
from nsflow import utils

from torch import nn
from torch.nn import functional as F, init
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from torch import nn
from torch.nn import functional as F

import utils
from nsflow import utils

from nde import transforms
from nde.transforms import splines
from nsflow.nde import transforms
from nsflow.nde.transforms import splines

class Tanh(transforms.Transform):
def forward(self, inputs, context=None):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from torch import nn
from torch.nn import functional as F

import utils
from nsflow import utils

from nde import transforms
from nsflow.nde import transforms


# class BatchNorm(transforms.Transform):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

from torch import nn

import utils
from nsflow import utils

from nde import transforms
from nsflow.nde import transforms


class HouseholderSequence(transforms.Transform):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""Implementations of permutation-like transforms."""

import torch
import utils
from nsflow import utils

from nde import transforms
from nsflow.nde import transforms


class Permutation(transforms.Transform):
Expand Down
8 changes: 4 additions & 4 deletions nde/transforms/qr.py → nsflow/nde/transforms/qr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from torch import nn
from torch.nn import functional as F, init

from nde import transforms
from nde.transforms.linear import Linear
from nsflow.nde import transforms
from nsflow.nde.transforms.linear import Linear


class QRLinear(Linear):
Expand Down Expand Up @@ -70,7 +70,7 @@ def inverse_no_cache(self, inputs):
upper = self._create_upper()
outputs = inputs - self.bias
outputs, _ = self.orthogonal.inverse(outputs) # Ignore logabsdet since we know it's zero.
outputs, _ = torch.trtrs(outputs.t(), upper, upper=True)
outputs, _ = torch.triangular_solve(outputs.t(), upper, upper=True)
outputs = outputs.t()
logabsdet = -self.logabsdet()
logabsdet = logabsdet * torch.ones(outputs.shape[0])
Expand All @@ -96,7 +96,7 @@ def weight_inverse(self):
"""
upper = self._create_upper()
identity = torch.eye(self.features, self.features)
upper_inv, _ = torch.trtrs(identity, upper, upper=True)
upper_inv, _ = torch.triangular_solve(identity, upper, upper=True)
weight_inv, _ = self.orthogonal(upper_inv)
return weight_inv

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import torch

import utils
from nde import transforms
from nsflow import utils

from nsflow.nde import transforms


class SqueezeTransform(transforms.Transform):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import torch
from torch.nn import functional as F

import utils
from nde import transforms
from nsflow import utils
from nsflow.nde import transforms

DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3
Expand Down
Loading