diff --git a/soft_dtw_cuda.py b/soft_dtw_cuda.py index 4e4fef0..5e7bf57 100755 --- a/soft_dtw_cuda.py +++ b/soft_dtw_cuda.py @@ -61,8 +61,12 @@ def compute_softdtw_cuda(D, gamma, bandwidth, max_i, max_j, n_passes, R): # Only compute if element[i, j] is on the current anti-diagonal, and also is within bounds if I + J == p and (I < max_i and J < max_j): + # Don't compute if outside bandwidth - if not (abs(i - j) > bandwidth > 0): + i_sc, j_sc = i, j + if max_j > max_i: i_sc = i * max_j / max_i + if max_j < max_i: j_sc = j * max_i / max_j + if not (abs(i_sc - j_sc) > bandwidth > 0): r0 = -R[b, i - 1, j - 1] * inv_gamma r1 = -R[b, i - 1, j] * inv_gamma r2 = -R[b, i, j - 1] * inv_gamma @@ -101,7 +105,10 @@ def compute_softdtw_backward_cuda(D, R, inv_gamma, bandwidth, max_i, max_j, n_pa R[k, i, j] = -math.inf # Don't compute if outside bandwidth - if not (abs(i - j) > bandwidth > 0): + i_sc, j_sc = i, j + if max_j > max_i: i_sc = i * max_j / max_i + if max_j < max_i: j_sc = j * max_i / max_j + if not (abs(i_sc - j_sc) > bandwidth > 0): a = math.exp((R[k, i + 1, j] - R[k, i, j] - D[k, i + 1, j]) * inv_gamma) b = math.exp((R[k, i, j + 1] - R[k, i, j] - D[k, i, j + 1]) * inv_gamma) c = math.exp((R[k, i + 1, j + 1] - R[k, i, j] - D[k, i + 1, j + 1]) * inv_gamma) @@ -193,7 +200,10 @@ def compute_softdtw(D, gamma, bandwidth): for i in range(1, N + 1): # Check the pruning condition - if 0 < bandwidth < np.abs(i - j): + i_sc, j_sc = i, j + if M > N: i_sc = i * M / N + if M < N: j_sc = j * N / M + if 0 < bandwidth < np.abs(i_sc - j_sc): continue r0 = -R[b, i - 1, j - 1] / gamma @@ -226,7 +236,10 @@ def compute_softdtw_backward(D_, R, gamma, bandwidth): R[k, i, j] = -np.inf # Check the pruning condition - if 0 < bandwidth < np.abs(i - j): + i_sc, j_sc = i, j + if M > N: i_sc = i * M / N + if M < N: j_sc = j * N / M + if 0 < bandwidth < np.abs(i_sc - j_sc): continue a0 = (R[k, i + 1, j] - R[k, i, j] - D[k, i + 1, j]) / gamma