Added device option to Wasserstein #3

Open
wants to merge 1 commit into base: master
Conversation


@sAbhay commented on Jun 20, 2022

Updated the wasserstein submodule to include a device option

@dfdazac (Owner) commented on Aug 25, 2022

@sAbhay thank you for your contribution! Sorry for the long delay in my response.
To keep compatibility with distributed training, where computations could run on different devices, I think it would be better to grab the device during the forward pass, rather than fixing it during initialization. For example,

def forward(self, x, y):
    device = x.device     
    ...

    mu = torch.empty(batch_size, x_points, dtype=torch.float,
                     requires_grad=False, device=device).fill_(1.0 / x_points).squeeze()
    ...

This way, computations will run on whatever device x happens to be on. What do you think?
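For context, here is a minimal, self-contained sketch of the pattern being suggested. The class name DeviceAwareWasserstein, the input shapes, and the cost computation are illustrative assumptions, not the repository's actual code; the point is only that every tensor created inside forward() follows the device of the input.

import torch
from torch import nn

class DeviceAwareWasserstein(nn.Module):
    # Illustrative module (hypothetical name): no device is stored at
    # initialization; each call reads the device from its input.
    def forward(self, x, y):
        device = x.device                      # taken per call, not fixed at __init__
        batch_size, x_points, _ = x.shape
        _, y_points, _ = y.shape

        # Uniform marginals created directly on x's device.
        mu = torch.full((batch_size, x_points), 1.0 / x_points,
                        dtype=torch.float, device=device)
        nu = torch.full((batch_size, y_points), 1.0 / y_points,
                        dtype=torch.float, device=device)

        # Pairwise squared Euclidean cost matrix, also on x's device.
        cost = torch.cdist(x, y, p=2) ** 2
        return mu, nu, cost

# The same instance then works for CPU or GPU inputs without changes:
dist = DeviceAwareWasserstein()
mu, nu, cost = dist(torch.rand(4, 10, 2), torch.rand(4, 12, 2))
if torch.cuda.is_available():
    mu, nu, cost = dist(torch.rand(4, 10, 2).cuda(), torch.rand(4, 12, 2).cuda())

Using torch.full on the input's device also avoids the separate .fill_ call and keeps the module free of any device state.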
