Skip to content

Commit

Permalink
Work on extending nd to arbitrary n per dimension
Browse files Browse the repository at this point in the history
  • Loading branch information
jloveric committed Jun 22, 2024
1 parent f6ad178 commit 5e955f1
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 26 deletions.
2 changes: 1 addition & 1 deletion examples/block_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(self, cfg: DictConfig):

layer1 = high_order_fc_layers(
layer_type=cfg.layer_type,
n=n,
n=[3,n,n],
in_features=1,
out_features=10,
intialization="constant_random",
Expand Down
2 changes: 1 addition & 1 deletion examples/xor.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(
if layer_type == "polynomial_2d":
layer1 = high_order_fc_layers(
layer_type=layer_type,
n=n,
n=[n,n],
in_features=2,
out_features=1,
segments=segments,
Expand Down
116 changes: 92 additions & 24 deletions high_order_layers_torch/LagrangePolynomial.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import math

from typing import List, Union
import torch
from torch import Tensor

Expand All @@ -22,31 +22,94 @@ def chebyshevLobatto(n: int):
return -torch.cos(torch.pi * torch.arange(n) / (n - 1))


# class LagrangeBasisND:
# """
# Single N dimensional element with Lagrange basis interpolation.
# """
# def __init__(
# self, n: int, length: float = 2.0, dimensions: int = 2, device: str = "cpu", **kwargs
# ):
# self.n = n
# self.dimensions = dimensions
# self.X = (length / 2.0) * chebyshevLobatto(n).to(device)
# self.device = device
# self.denominators = self._compute_denominators()
# self.num_basis = int(math.pow(n, dimensions))

# a = torch.arange(n)
# self.indexes = (
# torch.stack(torch.meshgrid([a] * dimensions, indexing="ij"))
# .reshape(dimensions, -1)
# .T.long().to(self.device)
# )

# def _compute_denominators(self):
# X_diff = self.X.unsqueeze(0) - self.X.unsqueeze(1) # [n, n]
# denom = torch.where(X_diff == 0, torch.tensor(1.0, device=self.device), X_diff)
# return denom

# def _compute_basis(self, x, indexes):
# """
# Computes the basis values for all index combinations.
# :param x: [batch, inputs, dimensions]
# :param indexes: [num_basis, dimensions]
# :returns: basis values [num_basis, batch, inputs]
# """
# x_diff = x.unsqueeze(-1) - self.X # [batch, inputs, dimensions, n]
# mask = (indexes.unsqueeze(1).unsqueeze(2).unsqueeze(4) != torch.arange(self.n, device=self.device).view(1, 1, 1, 1, self.n))
# denominators = self.denominators[indexes] # [num_basis, dimensions, n]

# b = torch.where(mask, x_diff.unsqueeze(0) / denominators.unsqueeze(1).unsqueeze(2), torch.tensor(1.0, device=self.device))
# #print('b.shape', b.shape)
# r = torch.prod(b, dim=-1) # [num_basis, batch, inputs, dimensions]

# return r.prod(dim=-1) # [num_basis, batch, inputs]

# def interpolate(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
# """
# Interpolates the input using the Lagrange basis.
# :param x: size[batch, inputs, dimensions]
# :param w: size[output, inputs, num_basis]
# :returns: size[batch, output]
# """
# basis = self._compute_basis(x, self.indexes) # [num_basis, batch, inputs]
# #print('basis.shape', basis.shape, 'w.shape', w.shape)
# out_sum = torch.einsum("ibk,oki->bo", basis, w) # [batch, output]

# return out_sum


import torch
import math
from typing import List

class LagrangeBasisND:
    """Single N-dimensional element with tensor-product Lagrange basis interpolation.

    Supports a different number of interpolation points per dimension:
    pass ``n`` as a list, e.g. ``n=[3, 5]`` for a 2D element with 3 points
    in the first dimension and 5 in the second.  A plain ``int`` ``n`` is
    also accepted for backward compatibility and is broadcast to every
    dimension (the old ``dimensions`` keyword, default 2, selects how many).
    """

    def __init__(
        self, n: Union[List[int], int], length: float = 2.0, device: str = "cpu", **kwargs
    ):
        """
        :param n: number of Chebyshev-Lobatto points per dimension (list of
            ints), or a single int applied to every dimension.
        :param length: size of the element; nodes span [-length/2, length/2].
        :param device: torch device on which all tensors are created.
        :param kwargs: may contain ``dimensions`` (int) when ``n`` is an int.
        """
        # Backward compatibility: the previous API took an int n plus a
        # separate "dimensions" argument. Without this, len(n)/math.prod(n)
        # below would raise TypeError for an int n despite the annotation.
        if isinstance(n, int):
            n = [n] * int(kwargs.get("dimensions", 2))
        self.n = list(n)
        self.dimensions = len(self.n)
        # One node tensor per dimension, scaled to the element size.
        self.X = [(length / 2.0) * chebyshevLobatto(ni).to(device) for ni in self.n]
        self.device = device
        self.denominators = self._compute_denominators()
        self.num_basis = math.prod(self.n)
        self.indexes = self._compute_indexes()

    def _compute_denominators(self):
        """Per-dimension denominator tables D[i, j] = X[i] - X[j].

        The Lagrange basis l_i(x) = prod_{j != i} (x - X[j]) / (X[i] - X[j]).
        NOTE(review): the previous code built X[j] - X[i] (operands swapped),
        which scales every basis function by (-1)**(n-1) per dimension and
        makes l_i(X_i) = -1 for even n instead of 1; corrected here.
        The diagonal (i == j) is replaced by 1.0 so the masked division in
        _compute_basis never divides by zero.
        """
        result = []
        for Xi in self.X:
            X_diff = Xi.unsqueeze(1) - Xi.unsqueeze(0)  # [n, n], entry [i, j] = X[i] - X[j]
            result.append(
                torch.where(X_diff == 0, torch.tensor(1.0, device=self.device), X_diff)
            )
        return result

    def _compute_indexes(self):
        """Cartesian product of per-dimension node indices.

        :returns: LongTensor of shape [num_basis, dimensions]; row k holds
            the node index used in each dimension by basis function k.
        """
        ranges = [torch.arange(ni) for ni in self.n]
        grid = torch.stack(torch.meshgrid(*ranges, indexing="ij"))
        return grid.reshape(self.dimensions, -1).T.long().to(self.device)

    def _compute_basis(self, x, indexes):
        """Evaluate every tensor-product Lagrange basis function at x.

        :param x: [batch, inputs, dimensions]
        :param indexes: [num_basis, dimensions]
        :returns: basis values [num_basis, batch, inputs]
        """
        per_dim = []
        for d in range(self.dimensions):
            # (x - X[j]) for every node j of this dimension.
            x_diff = x[..., d].unsqueeze(-1) - self.X[d]  # [batch, inputs, n[d]]
            # For basis k only the factors with j != indexes[k, d] participate.
            mask = (
                indexes[:, d].unsqueeze(1).unsqueeze(2)
                != torch.arange(self.n[d], device=self.device)
            )  # [num_basis, 1, n[d]]
            denominators = self.denominators[d][indexes[:, d]]  # [num_basis, n[d]]

            # Reshape for broadcasting against [num_basis, batch, inputs, n[d]].
            x_diff_expanded = x_diff.unsqueeze(0)  # [1, batch, inputs, n[d]]
            denominators_expanded = denominators.unsqueeze(1).unsqueeze(2)  # [num_basis, 1, 1, n[d]]
            mask = mask.unsqueeze(1)  # [num_basis, 1, 1, n[d]]

            # Masked-out factors (j == i) contribute 1 to the product.
            b = torch.where(
                mask,
                x_diff_expanded / denominators_expanded,
                torch.tensor(1.0, device=self.device),
            )
            per_dim.append(torch.prod(b, dim=-1))  # [num_basis, batch, inputs]

        # Tensor-product basis: multiply the 1D factors across dimensions.
        return torch.prod(torch.stack(per_dim), dim=0)  # [num_basis, batch, inputs]

    def interpolate(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        """
        Interpolates the input using the Lagrange basis.
        :param x: size [batch, inputs, dimensions]
        :param w: size [output, inputs, num_basis]
        :returns: size [batch, output]
        """
        basis = self._compute_basis(x, self.indexes)  # [num_basis, batch, inputs]
        return torch.einsum("ibk,oki->bo", basis, w)  # [batch, output]



class FourierBasis:
def __init__(self, length: float):
"""
Expand Down

0 comments on commit 5e955f1

Please sign in to comment.