Problem when dim x not same as dim y #13

Open

monabf opened this issue Oct 25, 2021 · 3 comments

monabf commented Oct 25, 2021

Hi,

Thanks a lot for the implementation. However, I have a problem with the dimensions. For example, x is a 1D grid of 100 points, y is a 2D grid of 100 points, and I want to interpolate the 2D value of y at one point of x, denoted xnew. I should then obtain a 2D value for ynew, namely the interpolation of each axis of y at the query coordinate xnew. However, I obtain a scalar ynew. Conversely, when x is a 2D grid and y is a 1D grid, I obtain a ynew of dimension 2 for an xnew of dimension 2. Am I missing something here?

Below are two small working examples, with respectively:

  • x (1,N), y (2,N), xnew (1,P), which should give ynew (2,P)
  • x (N,), y (2,N), xnew (P,), which should also give ynew (2,P)

Thanks a lot for your help!

import torch
import matplotlib.pyplot as plt
import time
import numpy as np
from torchinterp1d import Interp1d


if __name__ == "__main__":
    # problem dimensions
    Dx = 1
    Dy = 2
    N = 10
    P = 1

    yq_gpu = None
    yq_cpu = None
    x = torch.linspace(0, 10, Dx*N).view(Dx,N)
    y = torch.linspace(0, 10, Dy*N).view(Dy, -1)
    xnew = torch.rand(Dx, P) * 10

    # calling the cpu version
    t0_cpu = time.time()
    yq_cpu = Interp1d()(x, y, xnew, yq_cpu)
    t1_cpu = time.time()

    print(x.shape, y.shape, xnew.shape, yq_cpu.shape)
    print(x, y, xnew, yq_cpu)



    # problem dimensions
    Dy = 2
    N = 10
    P = 1

    yq_gpu = None
    yq_cpu = None
    x = torch.linspace(0, 10, N).view(N,)
    y = torch.linspace(0, 10, Dy*N).view(Dy, -1)
    xnew = torch.rand(P) * 10

    # calling the cpu version
    t0_cpu = time.time()
    yq_cpu = Interp1d()(x, y, xnew, yq_cpu)
    t1_cpu = time.time()

    print(x.shape, y.shape, xnew.shape, yq_cpu.shape)
    print(x, y, xnew, yq_cpu)

Output:

torch.Size([1, 10]) torch.Size([2, 10]) torch.Size([1, 1]) torch.Size([1, 1])
tensor([[ 0.0000,  1.1111,  2.2222,  3.3333,  4.4444,  5.5556,  6.6667,  7.7778,
          8.8889, 10.0000]]) tensor([[ 0.0000,  0.5263,  1.0526,  1.5789,  2.1053,  2.6316,  3.1579,  3.6842,
          4.2105,  4.7368],
        [ 5.2632,  5.7895,  6.3158,  6.8421,  7.3684,  7.8947,  8.4211,  8.9474,
          9.4737, 10.0000]]) tensor([[5.6551]]) tensor([[2.6787]])
torch.Size([10]) torch.Size([2, 10]) torch.Size([1]) torch.Size([1, 1])
tensor([ 0.0000,  1.1111,  2.2222,  3.3333,  4.4444,  5.5556,  6.6667,  7.7778,
         8.8889, 10.0000]) tensor([[ 0.0000,  0.5263,  1.0526,  1.5789,  2.1053,  2.6316,  3.1579,  3.6842,
          4.2105,  4.7368],
        [ 5.2632,  5.7895,  6.3158,  6.8421,  7.3684,  7.8947,  8.4211,  8.9474,
          9.4737, 10.0000]]) tensor([0.9259]) tensor([[0.4386]])

@aliutkus (Owner)

Dear @monabf, maybe you should just duplicate the xnew variable?

You're right, the doc from the README and the docstring of the function don't seem to match. As it is implemented now, it looks like out has the same shape as xnew. I may have a look into this.

monabf (Author) commented Oct 25, 2021

Hi @aliutkus, you're right: using x = x.expand(Dy, -1) seems to solve the problem and avoids slicing away part of y. Thanks for looking into it! It would indeed be great to have this corrected or documented.
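
For reference, a minimal sketch of that workaround applied to the first example above; the (Dy, P) output shape is the expected one, given the broadcasting behaviour reported in this thread:

import torch
from torchinterp1d import Interp1d

Dx, Dy, N, P = 1, 2, 10, 1
x = torch.linspace(0, 10, Dx * N).view(Dx, N)   # (1, N)
y = torch.linspace(0, 10, Dy * N).view(Dy, -1)  # (Dy, N)
xnew = torch.rand(Dx, P) * 10                   # (1, P)

# Workaround from this comment: expand x along y's leading dimension.
ynew = Interp1d()(x.expand(Dy, -1), y, xnew)
print(ynew.shape)  # expected: torch.Size([2, 1]), i.e. (Dy, P)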

zfhxi commented Mar 27, 2023

To explore which of x and xnew should be expanded, I wrote a demo:

import torch
from utils.gpu.torch_interp1d import interp1d
import scipy.interpolate as syinterpolate

if __name__ == "__main__":
    N = 100
    D = 1024
    P = 30
    x = torch.arange(N).view(-1, N)  # -> 1,N
    y = torch.randn([D, N])
    xnew = torch.linspace(0, N, P)
    ynew = interp1d(x, y, xnew)
    ynewv2 = interp1d(x.expand(D, -1), y, xnew)
    ynewv3 = interp1d(x, y, xnew.expand(D, -1))
    ynewv4 = interp1d(x.expand(D, -1), y, xnew.expand(D, -1))

    # scipy interp1d
    x = x.squeeze().cpu().numpy()  # N
    y = y.transpose(-1, -2).cpu().numpy()  # N,D
    xnew = xnew.cpu().numpy()  # P
    fit_func = syinterpolate.interp1d(x, y, kind="linear", axis=0, fill_value="extrapolate")
    ynewv5 = fit_func(xnew)
    ynewv5 = torch.from_numpy(ynewv5).float().transpose(-1, -2)  # D,P

    print(ynew.shape)  # 1,P
    print(ynewv2.shape)  # D,P
    print(ynewv3.shape)  # D,P
    print(ynewv4.shape)  # D,P
    # print((ynewv2 == ynewv3).all())  # False
    # print((ynewv2 == ynewv4).all())  # True
    # print((ynewv2 == ynewv5).all())  # may be False due to floating-point accuracy

    print((ynewv2 - ynewv3).abs().sum())  # > 0
    print((ynewv2 - ynewv4).abs().sum())  # 0
    print((ynewv2 - ynewv5).abs().sum())  # close to 0 (floating-point differences)
    print((ynewv3 - ynewv4).abs().sum())  # > 0
    print((ynewv3 - ynewv5).abs().sum())  # > 0
    print((ynewv4 - ynewv5).abs().sum())  # close to 0 (floating-point differences)

    # conclusion: ynewv2 is equal to ynewv4 and they both are close to ynewv5

So, we should expand x.
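
Based on the demo above, a small convenience wrapper along these lines could encapsulate the fix. This is a hypothetical helper (not part of torchinterp1d), written against the Interp1d callable used earlier in this thread; the (D, P) output shape is the one observed in the demo once x is expanded:

import torch
from torchinterp1d import Interp1d

def interp1d_shared_x(x, y, xnew):
    # Hypothetical helper: interpolate every row of y (D, N) over a single
    # shared x grid of shape (N,) or (1, N), at query points xnew of shape
    # (P,) or (1, P). Expected output shape: (D, P).
    x = x.reshape(1, -1).to(y.dtype)
    D = y.shape[0]
    # As concluded above, expanding x to (D, N) is what yields a (D, P) output.
    return Interp1d()(x.expand(D, -1), y, xnew)

# Example usage with the demo's dimensions (N=100, D=1024, P=30):
y = torch.randn(1024, 100)
x = torch.arange(100.0)
xnew = torch.linspace(0, 100, 30)
print(interp1d_shared_x(x, y, xnew).shape)  # expected: torch.Size([1024, 30])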
