diff --git a/src/pipelines/utils.py b/src/pipelines/utils.py index 3cd5076..3d2771a 100644 --- a/src/pipelines/utils.py +++ b/src/pipelines/utils.py @@ -1,29 +1,22 @@ import torch -tensor_interpolation = None - - -def get_tensor_interpolation_method(): - return tensor_interpolation - - -def set_tensor_interpolation_method(is_slerp): - global tensor_interpolation - tensor_interpolation = slerp if is_slerp else linear - - def linear(v1, v2, t): return (1.0 - t) * v1 + t * v2 -def slerp( - v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995 -) -> torch.Tensor: +def slerp(v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995) -> torch.Tensor: u0 = v0 / v0.norm() u1 = v1 / v1.norm() dot = (u0 * u1).sum() if dot.abs() > DOT_THRESHOLD: - # logger.info(f'warning: v0 and v1 close to parallel, using linear interpolation instead.') return (1.0 - t) * v0 + t * v1 omega = dot.acos() return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin() + + +tensor_interpolation = slerp + + +def get_tensor_interpolation_method(): + return tensor_interpolation +