For stability, you compute max_val = torch.max((exp_term_tmp).clamp(min=0), dim=1, keepdim=True)[0], which seems to compute the max value only for the first sample (max_val has size 1 because of the [0] operation). But I think the max value should be computed for each sample separately, without the [0] operation, so that max_val has size N*1. Is that right?
The size of max_val is actually N*1, even with the [0] operation, because:
torch.max(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor)
Returns the maximum value of each row of the input tensor in the given dimension dim. The second return value is the index location of each maximum value found (argmax).
@thbupt sorry for the late reply. @muchuanyun is correct. torch.max returns a tuple where the first element is the max values and the second element is the indices.
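For reference, a minimal sketch (the tensor name and shapes here are just placeholders, not taken from the actual code) showing that [0] selects the values from the (values, indices) tuple, so max_val keeps one max per sample:

```python
import torch

# Hypothetical example: N samples, C classes
N, C = 4, 3
exp_term_tmp = torch.randn(N, C)

# torch.max along dim=1 returns a (values, indices) tuple;
# [0] picks the values, not the first sample.
max_val = torch.max(exp_term_tmp.clamp(min=0), dim=1, keepdim=True)[0]
print(max_val.shape)  # torch.Size([4, 1]) -> one max per sample (N x 1)
```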