diff --git a/wsingular/__init__.py b/wsingular/__init__.py index 0cde53e..3d569a7 100644 --- a/wsingular/__init__.py +++ b/wsingular/__init__.py @@ -1,5 +1,4 @@ # Imports. -from torch.utils.tensorboard import SummaryWriter import torch import numpy as np from typing import Callable, Tuple @@ -14,7 +13,7 @@ def wasserstein_singular_vectors( n_iter: int, tau: float = 1e-3, p: int = 1, - writer: SummaryWriter = None, + writer = None, small_value: float = 1e-6, normalization_steps: int = 1, C_ref: torch.Tensor = None, @@ -146,7 +145,7 @@ def sinkhorn_singular_vectors( tau: float = 1e-3, eps: float = 5e-2, p: int = 1, - writer: SummaryWriter = None, + writer = None, small_value: float = 1e-6, normalization_steps: int = 1, C_ref: torch.Tensor = None, @@ -283,7 +282,7 @@ def stochastic_wasserstein_singular_vectors( p: int = 1, step_fn: Callable = lambda k: 1 / np.sqrt(k), mult_update: bool = False, - writer: SummaryWriter = None, + writer = None, small_value: float = 1e-6, normalization_steps: int = 1, C_ref: torch.Tensor = None, @@ -485,7 +484,7 @@ def stochastic_sinkhorn_singular_vectors( p: int = 1, step_fn: Callable = lambda k: 1 / np.sqrt(k), mult_update: bool = False, - writer: SummaryWriter = None, + writer = None, small_value: float = 1e-6, normalization_steps: int = 1, C_ref: torch.Tensor = None,