how to implement cdist #378
Comments
Interesting problem. The best way to implement this is indeed with a custom CUDA kernel; that way you can drop the for-loop in option 2/3 and avoid the memory blowup of option 1. I think option 3 is okay if you don't want to go down the rabbit hole of custom CUDA.
So I am currently busy implementing a CUDA kernel that computes the Euclidean distance between two sparse matrices, similar to torch.cdist. Unfortunately I ran into a problem while implementing the backward pass when computing the local gradients.

Some definitions: let $x_i$ denote the $i$-th row of $X$ and $y_j$ the $j$-th row of $Y$, and let

$$D_{ij} = \lVert x_i - y_j \rVert_2 = \sqrt{\textstyle\sum_k (x_{ik} - y_{jk})^2}.$$

The derivative:

$$\frac{\partial D_{ij}}{\partial x_{ik}} = \frac{x_{ik} - y_{jk}}{D_{ij}}.$$

Then there are 4 cases that need to be treated differently with sparse matrices (ignoring for now that the Euclidean distance is not differentiable at zero):

1. $x_{ik}$ and $y_{jk}$ are both stored nonzeros: the gradient $(x_{ik} - y_{jk}) / D_{ij}$ can be written to the stored slot of $x_{ik}$.
2. $x_{ik}$ is a stored nonzero, $y_{jk}$ is an implicit zero: the gradient $x_{ik} / D_{ij}$ can still be written to the stored slot.
3. $x_{ik}$ and $y_{jk}$ are both implicit zeros: the gradient is zero, so there is nothing to write.
4. $x_{ik}$ is an implicit zero but $y_{jk}$ is a stored nonzero: the gradient $-y_{jk} / D_{ij}$ is nonzero, yet $X$ has no stored slot at $(i, k)$ to write it to.

The problematic case is case 4! Here is the autograd boilerplate I am currently using:

```cpp
#include <torch/extension.h>  // for torch::Tensor and torch::autograd

using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
// sparse_cdist / sparse_bw_cdist are the custom forward/backward kernels
// operating on CSR data (rowptr, col, value) of both matrices.
class SparseCdist : public torch::autograd::Function<SparseCdist> {
 public:
  static variable_list forward(
      AutogradContext *ctx,
      torch::Tensor a_rowptr_data,
      torch::Tensor a_col_data,
      torch::Tensor a_value_data,
      torch::Tensor b_rowptr_data,
      torch::Tensor b_col_data,
      torch::Tensor b_value_data,
      int dim_a,
      int dim_b) {
    auto out = sparse_cdist(a_rowptr_data, a_col_data, a_value_data,
                            b_rowptr_data, b_col_data, b_value_data,
                            dim_a, dim_b);
    ctx->saved_data["dim_a"] = dim_a;
    ctx->saved_data["dim_b"] = dim_b;
    ctx->save_for_backward({a_rowptr_data, a_col_data, a_value_data,
                            b_rowptr_data, b_col_data, b_value_data, out});
    return {out};
  }

  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
    auto dim_a = ctx->saved_data["dim_a"].toInt();
    auto dim_b = ctx->saved_data["dim_b"].toInt();
    auto grad_out = grad_outs[0];
    auto saved = ctx->get_saved_variables();
    auto a_rowptr_data = saved[0], a_col_data = saved[1], a_value_data = saved[2],
         b_rowptr_data = saved[3], b_col_data = saved[4], b_value_data = saved[5],
         distance = saved[6];  // `out` was saved last; saved[5] is b_value_data

    auto grad_value_a = Variable();
    if (torch::autograd::any_variable_requires_grad({a_value_data})) {
      grad_value_a = sparse_bw_cdist(a_rowptr_data, a_col_data, a_value_data,
                                     b_rowptr_data, b_col_data, b_value_data,
                                     grad_out, distance, dim_a, dim_b);
    }
    auto grad_value_b = Variable();
    if (torch::autograd::any_variable_requires_grad({b_value_data})) {
      // Same kernel with the roles of a and b swapped.
      grad_value_b = sparse_bw_cdist(b_rowptr_data, b_col_data, b_value_data,
                                     a_rowptr_data, a_col_data, a_value_data,
                                     grad_out, distance, dim_b, dim_a);
    }
    // One gradient per forward input: the rowptr/col index tensors and the
    // two int arguments are not differentiable.
    return {Variable(), Variable(), grad_value_a,
            Variable(), Variable(), grad_value_b, Variable(), Variable()};
  }
};
```

What I would want to do is overwrite …
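To see case 4 concretely, here is a tiny dense check with torch.cdist (an illustrative snippet of mine, not part of the kernel): the gradient at a coordinate where $x$ is zero but $y$ is not comes out nonzero.

```python
import torch

x = torch.tensor([[0.0, 2.0]], requires_grad=True)  # x[0, 0] plays the implicit zero
y = torch.tensor([[3.0, 0.0]])

d = torch.cdist(x, y)   # sqrt((0-3)^2 + (2-0)^2) = sqrt(13)
d.sum().backward()

print(x.grad)  # tensor([[-0.8321, 0.5547]]) = (x - y) / d: nonzero at the zero entry
```

A sparse backward kernel has nowhere to store that `-0.8321` unless the entry is materialized in X's sparsity pattern.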
Maybe one workaround is to explicitly introduce zeros to X where X is implicitly zero and Y is nonzero?
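A minimal sketch of that workaround with plain COO tensors (hypothetical; `add_explicit_zeros` is an assumed helper name, and the actual kernel works on CSR data): store an explicit zero in X at every column where Y has a nonzero, so the backward pass has a slot to write into.

```python
import torch

def add_explicit_zeros(x, cols):
    # Add an explicit zero entry at (i, k) for every row i of the sparse COO
    # matrix `x` and every column k in `cols`. coalesce() merges duplicate
    # indices by summing, so existing nonzeros are kept intact and the new
    # explicit zeros stay in the pattern.
    rows = torch.arange(x.size(0)).repeat_interleave(cols.numel())
    extra = torch.sparse_coo_tensor(
        torch.stack([rows, cols.repeat(x.size(0))]),
        torch.zeros(rows.numel()),
        x.shape)
    return (x + extra).coalesce()

x = torch.tensor([[0.0, 2.0], [1.0, 0.0]]).to_sparse()
y = torch.tensor([[3.0, 0.0]]).to_sparse().coalesce()
y_cols = y.indices()[1].unique()      # columns where Y has any nonzero
x = add_explicit_zeros(x, y_cols)     # X now stores an explicit zero at (0, 0)
```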
Hi there, thanks a million for this library.
I am trying to figure out how to compute distances between every row (or column) of two sparse matrices, like torch.cdist. I currently see three ways to do it, built around a small sparse_distance helper (sketched below). The first one has the disadvantage that it creates huge intermediate matrices and eats a lot of memory, while the second one doesn't benefit from GPU parallelism. The third seems to be okay-ish, but maybe there is an even better solution? Ideally one would only load matrices a and b onto the GPU and reserve additional memory only for the result.
Maybe somebody has an idea how to do that with what is currently possible in torch_sparse?
Or would it be necessary to write a specific CUDA kernel for that?
Any suggestions are very welcome.