Skip to content

Commit

Permalink
Finish debugging basisFlatND
Browse files Browse the repository at this point in the history
  • Loading branch information
jloveric committed Jun 18, 2024
1 parent ae91ef5 commit 591e5bd
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 10 deletions.
7 changes: 4 additions & 3 deletions high_order_layers_torch/Basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ def __init__(
self.basis = basis
self.dimensions = dimensions
a = torch.arange(n)
self.indexes = torch.stack(torch.meshgrid([a]*dimensions)).reshape(dimensions, -1).T
self.indexes = torch.stack(torch.meshgrid([a]*dimensions)).reshape(dimensions, -1).T.long()

def interpolate(self, x: Tensor, w: Tensor) -> Tensor:
    """
    Interpolate at the points x using the flattened tensor-product basis.

    :param x: sample positions, shape [batch, inputs, dimensions]
    :param w: weights, shape [inputs, outputs, num_basis]
        where num_basis == n ** dimensions
    :returns: interpolated values, shape [batch, outputs]
    """
    basis = []
    # self.indexes enumerates every grid index combination, one row per
    # basis function, so this loop produces num_basis evaluations.
    for index in self.indexes:
        basis_j = self.basis(x, index=index)  # [batch, inputs]
        basis.append(basis_j)
    basis = torch.stack(basis)  # [num_basis, batch, inputs]

    # i: basis, j: batch, k: inputs, l: outputs — contract over the
    # basis and input axes, leaving [batch, outputs].
    out_sum = torch.einsum("ijk,kli->jl", basis, w)

    return out_sum

Expand Down
18 changes: 12 additions & 6 deletions high_order_layers_torch/LagrangePolynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,26 +80,32 @@ def __init__(self, n: int, length: float = 2.0, dimensions: int = 2):
self.dimensions = dimensions
self.X = (length / 2.0) * chebyshevLobatto(n)
self.denominators = self._compute_denominators()
self.num_basis = int(math.pow(n, dimensions))

def _compute_denominators(self):
    """
    Precompute the Lagrange denominators X[j] - X[m] for every node pair.

    The diagonal (m == j) is left at 1.0 so it acts as a no-op factor
    when the corresponding numerator term is masked out in __call__.

    :returns: tensor of shape [n, n], dtype float32
    """
    denom = torch.ones([self.n, self.n], dtype=torch.float32)

    for j in range(self.n):
        for m in range(self.n):
            if m != j:
                denom[j, m] = self.X[j] - self.X[m]
    return denom

def __call__(self, x, index: list[int]):
    """
    Evaluate one tensor-product Lagrange basis function.

    The result is the product over dimensions of the 1d Lagrange basis
    polynomial selected per dimension by `index`.

    :param x: sample positions, shape [batch, inputs, dimensions]
    :param index: node index per dimension, length [dimensions]
    :returns: basis value, shape [batch, inputs]
    """
    x_diff = x.unsqueeze(-1) - self.X  # [batch, inputs, dimensions, basis]
    r = 1.0
    for i, basis_i in enumerate(index):
        # Lagrange basis: product over nodes m != basis_i of
        # (x - X[m]) / (X[basis_i] - X[m]); the m == basis_i factor
        # is replaced by 1.0 via the where mask.
        b = torch.where(
            torch.arange(self.n) != basis_i,
            x_diff[:, :, i, :] / self.denominators[basis_i],
            torch.tensor(1.0),
        )
        r *= torch.prod(b, dim=-1)

    return r

Expand Down
14 changes: 13 additions & 1 deletion tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,23 @@ def test_variable_dimension_input(n, in_features, out_features, segments):
def test_basis_nd():
    """Smoke-test BasisFlatND interpolation over an N-d Lagrange basis."""
    dimensions = 3
    n = 5

    inputs = 2
    outputs = 7
    batch = 13

    lb = LagrangeBasisND(n=n, dimensions=dimensions)
    basis = BasisFlatND(n=n, dimensions=dimensions, basis=lb)

    num_basis = int(math.pow(n, dimensions))

    # The indexes should be unique so we cover all n**dimensions grid
    # points. Convert each row to a tuple first: tensors hash by object
    # identity, so set(basis.indexes) would count every row as distinct
    # even if values repeated.
    unique_indexes = {tuple(idx.tolist()) for idx in basis.indexes}
    assert len(unique_indexes) == num_basis

    x = torch.rand((batch, inputs, dimensions))
    weights = torch.rand((inputs, outputs, num_basis))
    result = basis.interpolate(x, weights)
    assert result.shape == (batch, outputs)


@pytest.mark.parametrize("dimensions", [1, 2, 3, 4])
Expand Down

0 comments on commit 591e5bd

Please sign in to comment.