diff --git a/docs/changelogs/v3.4.1.md b/docs/changelogs/v3.4.1.md
index c8e656664..b0e87d378 100644
--- a/docs/changelogs/v3.4.1.md
+++ b/docs/changelogs/v3.4.1.md
@@ -16,6 +16,7 @@
 * change default beta1, beta2 to 0.95 and 0.98 respectively
 * Skip adding `Lookahead` wrapper in case of `Ranger*` optimizers, which already have it in `create_optimizer()`. (#340)
 * Improved optimizer visualization. (#345)
+* Rename `pytorch_optimizer.optimizer.gc` to `pytorch_optimizer.optimizer.gradient_centralization` to avoid a possible conflict with the Python built-in module `gc`. (#349)

 ### Bug
diff --git a/docs/visualizations/rastrigin_ADOPT.png b/docs/visualizations/rastrigin_ADOPT.png
index 42b2c652b..8c4f6e719 100644
Binary files a/docs/visualizations/rastrigin_ADOPT.png and b/docs/visualizations/rastrigin_ADOPT.png differ
diff --git a/docs/visualizations/rosenbrock_ADOPT.png b/docs/visualizations/rosenbrock_ADOPT.png
index 81c9d0974..668ccc723 100644
Binary files a/docs/visualizations/rosenbrock_ADOPT.png and b/docs/visualizations/rosenbrock_ADOPT.png differ
diff --git a/examples/visualize_optimizers.py b/examples/visualize_optimizers.py
index 387d68b5c..4638c635c 100644
--- a/examples/visualize_optimizers.py
+++ b/examples/visualize_optimizers.py
@@ -1,8 +1,8 @@
 import math
-import warnings
 from functools import partial
 from pathlib import Path
-from typing import Callable, Dict, Tuple, Union
+from typing import Callable, Dict, List, Tuple, Union
+from warnings import filterwarnings

 import numpy as np
 import torch
@@ -14,23 +14,23 @@
 from pytorch_optimizer.optimizer import OPTIMIZERS
 from pytorch_optimizer.optimizer.alig import l2_projection

-warnings.filterwarnings('ignore', category=UserWarning)
+filterwarnings('ignore', category=UserWarning)

 OPTIMIZERS_IGNORE = ('lomo', 'adalomo', 'demo', 'a2grad', 'alig')  # BUG: fix `alig`, invalid .__name__
 OPTIMIZERS_MODEL_INPUT_NEEDED = ('lomo', 'adalomo', 'adammini')
 OPTIMIZERS_GRAPH_NEEDED = ('adahessian', 'sophiah')
 OPTIMIZERS_CLOSURE_NEEDED = ('alig', 'bsam')
-EVAL_PER_HYPYPERPARAM = 540
-OPTIMIZATION_STEPS = 300
-TESTING_OPTIMIZATION_STEPS = 650
-DIFFICULT_RASTRIGIN = False
-USE_AVERAGE_LOSS_PENALTY = True
-AVERAGE_LOSS_PENALTY_FACTOR = 1.0
-SEARCH_SEED = 42
-LOSS_MIN_TRESH = 0
-
-default_search_space = {'lr': hp.uniform('lr', 0, 2)}
-special_search_spaces = {
+EVAL_PER_HYPERPARAM: int = 540
+OPTIMIZATION_STEPS: int = 300
+TESTING_OPTIMIZATION_STEPS: int = 650
+DIFFICULT_RASTRIGIN: bool = False
+USE_AVERAGE_LOSS_PENALTY: bool = True
+AVERAGE_LOSS_PENALTY_FACTOR: float = 1.0
+SEARCH_SEED: int = 42
+LOSS_MIN_THRESHOLD: float = 0.0
+
+DEFAULT_SEARCH_SPACES = {'lr': hp.uniform('lr', 0, 2)}
+SPECIAL_SEARCH_SPACES = {
     'adafactor': {'lr': hp.uniform('lr', 0, 10)},
     'adams': {'lr': hp.uniform('lr', 0, 10)},
     'dadaptadagrad': {'lr': hp.uniform('lr', 0, 10)},
@@ -170,7 +170,7 @@
     optimizer_class: torch.optim.Optimizer,
     optimizer_config: Dict,
     num_iters: int = 500,
-) -> torch.Tensor:
+) -> Tuple[torch.Tensor, List[float]]:
     """
     Execute optimization steps for a given configuration.
@@ -201,7 +201,6 @@ def closure() -> float:

         return closure

-    # Initialize the model and optimizer
     model = Model(func, initial_state)
     parameters = list(model.parameters())
     optimizer_name: str = optimizer_class.__name__.lower()
@@ -218,30 +217,25 @@ def closure() -> float:
     elif optimizer_name == 'bsam':
         optimizer_config['num_data'] = 1

-    # Special initialization for memory-efficient optimizers
     if optimizer_name in OPTIMIZERS_MODEL_INPUT_NEEDED:
         optimizer = optimizer_class(model, **optimizer_config)
     else:
         optimizer = optimizer_class(parameters, **optimizer_config)

-    # Track optimization path
-    losses = []
     steps = torch.zeros((2, num_iters + 1), dtype=torch.float32)
     steps[:, 0] = model.x.detach()
+    losses = []

     for i in range(1, num_iters + 1):
         optimizer.zero_grad()
+
         loss = model()
         losses.append(loss.item())

-        # Special handling for second-order optimizers
-        create_graph = optimizer_name in OPTIMIZERS_GRAPH_NEEDED
-        loss.backward(create_graph=create_graph)
+        loss.backward(create_graph=optimizer_name in OPTIMIZERS_GRAPH_NEEDED)

-        # Gradient clipping for stability
         nn.utils.clip_grad_norm_(parameters, 1.0)

-        # Closure required for certain optimizers
         closure = create_closure(loss) if optimizer_name in OPTIMIZERS_CLOSURE_NEEDED else None
         optimizer.step(closure)

@@ -279,25 +273,19 @@ def objective(
     - A penalty for boundary violations.
     - An optional penalty for higher average loss during optimization (if enabled).
     """
-    # Execute optimization steps and get losses
-    steps, losses = execute_steps(  # Modified to unpack losses
-        criterion, initial_state, optimizer_class, params, num_iters
-    )
+    steps, losses = execute_steps(criterion, initial_state, optimizer_class, params, num_iters)

-    # Calculate boundary violations
     x_min_violation = torch.clamp(x_bounds[0] - steps[0], min=0).max()
     x_max_violation = torch.clamp(steps[0] - x_bounds[1], min=0).max()
     y_min_violation = torch.clamp(y_bounds[0] - steps[1], min=0).max()
     y_max_violation = torch.clamp(steps[1] - y_bounds[1], min=0).max()
     total_violation = x_min_violation + x_max_violation + y_min_violation + y_max_violation

-    # Calculate average loss penalty
-    avg_loss = sum(losses) / len(losses) if losses else 0.0
     penalty = 75 * total_violation.item()
     if USE_AVERAGE_LOSS_PENALTY:
+        avg_loss: float = sum(losses) / len(losses) if losses else 0.0
         penalty += avg_loss * AVERAGE_LOSS_PENALTY_FACTOR

-    # Calculate final distance to minimum
     final_position = steps[:, -1]
     final_distance = ((final_position[0] - minimum[0]) ** 2 + (final_position[1] - minimum[1]) ** 2).item()

@@ -309,7 +297,7 @@ def plot_function(
     optimization_steps: torch.Tensor,
     output_path: Path,
     optimizer_name: str,
-    params: dict,
+    params: Dict,
     x_range: Tuple[float, float],
     y_range: Tuple[float, float],
     minimum: Tuple[float, float],
@@ -335,26 +323,21 @@ def plot_function(
     fig = plt.figure(figsize=(8, 8))
     ax = fig.add_subplot(1, 1, 1)

-    # Plot function contours and optimization path
     ax.contour(x_grid.numpy(), y_grid.numpy(), z.numpy(), 20, cmap='jet')
     ax.plot(optimization_steps[0], optimization_steps[1], color='r', marker='x', markersize=3)

-    # Mark global minimum and final position
     plt.plot(*minimum, 'gD', label='Global Minimum')
     plt.plot(optimization_steps[0, -1], optimization_steps[1, -1], 'bD', label='Final Position')

-    ax.set_title(
-        f'{func.__name__} func: {optimizer_name} with {TESTING_OPTIMIZATION_STEPS} iterations\n{
-            ", ".join(f"{k}={round(v, 4)}" for k, v in params.items())
-        }'
-    )
+    config: str = ', '.join(f'{k}={round(v, 4)}' for k, v in params.items())
+    ax.set_title(f'{func.__name__} func: {optimizer_name} with {TESTING_OPTIMIZATION_STEPS} iterations\n{config}')
     plt.legend()
     plt.savefig(str(output_path))
     plt.close()


 def execute_experiments(
-    optimizers: list,
+    optimizers: List,
     func: Callable,
     initial_state: Tuple[float, float],
     output_dir: Path,
@@ -362,7 +345,7 @@
     x_range: Tuple[float, float],
     y_range: Tuple[float, float],
     minimum: Tuple[float, float],
-    seed: int = 42,
+    seed: int = SEARCH_SEED,
 ) -> None:
     """
     Run optimization experiments for multiple optimizers.
@@ -382,15 +365,14 @@
         optimizer_name = optimizer_class.__name__
         output_path = output_dir / f'{experiment_name}_{optimizer_name}.png'
         if output_path.exists():
-            continue  # Skip already generated plots
+            continue

         print(  # noqa: T201
             f'({i}/{len(optimizers)}) Processing {optimizer_name}... (Params to tune: {", ".join(search_space.keys())})'  # noqa: E501
         )

-        # Select hyperparameter search space
-        num_hyperparams = len(search_space)
-        max_evals = EVAL_PER_HYPYPERPARAM * num_hyperparams  # Scale evaluations based on hyperparameter count
+        num_hyperparams: int = len(search_space)
+        max_evals: int = EVAL_PER_HYPERPARAM * num_hyperparams

         objective_fn = partial(
             objective,
@@ -402,43 +384,38 @@
             y_bounds=y_range,
             num_iters=OPTIMIZATION_STEPS,
         )
+
         try:
             best_params = fmin(
                 fn=objective_fn,
                 space=search_space,
                 algo=tpe.suggest,
                 max_evals=max_evals,
-                loss_threshold=LOSS_MIN_TRESH,
+                loss_threshold=LOSS_MIN_THRESHOLD,
                 rstate=np.random.default_rng(seed),
             )
         except AllTrialsFailed:
             print(f'⚠️ {optimizer_name} failed to optimize {func.__name__}')  # noqa: T201
             continue

-        # Run final optimization with best parameters
-        steps, _ = execute_steps(  # Modified to ignore losses
-            func, initial_state, optimizer_class, best_params, TESTING_OPTIMIZATION_STEPS
-        )
+        steps, _ = execute_steps(func, initial_state, optimizer_class, best_params, TESTING_OPTIMIZATION_STEPS)

-        # Generate and save visualization
         plot_function(func, steps, output_path, optimizer_name, best_params, x_range, y_range, minimum)


 def main():
-    """Main execution routine for optimization experiments."""
     np.random.seed(SEARCH_SEED)
     torch.manual_seed(SEARCH_SEED)
+
     output_dir = Path('.') / 'docs' / 'visualizations'
     output_dir.mkdir(parents=True, exist_ok=True)

-    # Prepare the list of optimizers and their search spaces
     optimizers = [
-        (optimizer, special_search_spaces.get(optimizer_name, default_search_space))
+        (optimizer, SPECIAL_SEARCH_SPACES.get(optimizer_name, DEFAULT_SEARCH_SPACES))
         for optimizer_name, optimizer in OPTIMIZERS.items()
         if optimizer_name not in OPTIMIZERS_IGNORE
     ]

-    # Run experiments for the Rastrigin function
     print('Executing Rastrigin experiments...')  # noqa: T201
     execute_experiments(
         optimizers,
@@ -452,7 +429,6 @@ def main():
         seed=SEARCH_SEED,
     )

-    # Run experiments for the Rosenbrock function
     print('Executing Rosenbrock experiments...')  # noqa: T201
     execute_experiments(
         optimizers,
diff --git a/poetry.lock b/poetry.lock
index f52dab7b7..a153c1417 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -776,30 +776,30 @@ testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"]

 [[package]]
 name = "ruff"
-version = "0.9.4"
+version = "0.9.6"
 description = "An extremely fast Python linter and code formatter, written in Rust."
optional = false python-versions = ">=3.7" groups = ["dev"] files = [ - {file = "ruff-0.9.4-py3-none-linux_armv6l.whl", hash = "sha256:64e73d25b954f71ff100bb70f39f1ee09e880728efb4250c632ceed4e4cdf706"}, - {file = "ruff-0.9.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:6ce6743ed64d9afab4fafeaea70d3631b4d4b28b592db21a5c2d1f0ef52934bf"}, - {file = "ruff-0.9.4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:54499fb08408e32b57360f6f9de7157a5fec24ad79cb3f42ef2c3f3f728dfe2b"}, - {file = "ruff-0.9.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:37c892540108314a6f01f105040b5106aeb829fa5fb0561d2dcaf71485021137"}, - {file = "ruff-0.9.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:de9edf2ce4b9ddf43fd93e20ef635a900e25f622f87ed6e3047a664d0e8f810e"}, - {file = "ruff-0.9.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:87c90c32357c74f11deb7fbb065126d91771b207bf9bfaaee01277ca59b574ec"}, - {file = "ruff-0.9.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:56acd6c694da3695a7461cc55775f3a409c3815ac467279dfa126061d84b314b"}, - {file = "ruff-0.9.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e0c93e7d47ed951b9394cf352d6695b31498e68fd5782d6cbc282425655f687a"}, - {file = "ruff-0.9.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1d4c8772670aecf037d1bf7a07c39106574d143b26cfe5ed1787d2f31e800214"}, - {file = "ruff-0.9.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bfc5f1d7afeda8d5d37660eeca6d389b142d7f2b5a1ab659d9214ebd0e025231"}, - {file = "ruff-0.9.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:faa935fc00ae854d8b638c16a5f1ce881bc3f67446957dd6f2af440a5fc8526b"}, - {file = "ruff-0.9.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:a6c634fc6f5a0ceae1ab3e13c58183978185d131a29c425e4eaa9f40afe1e6d6"}, - {file = "ruff-0.9.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:433dedf6ddfdec7f1ac7575ec1eb9844fa60c4c8c2f8887a070672b8d353d34c"}, - {file = "ruff-0.9.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:d612dbd0f3a919a8cc1d12037168bfa536862066808960e0cc901404b77968f0"}, - {file = "ruff-0.9.4-py3-none-win32.whl", hash = "sha256:db1192ddda2200671f9ef61d9597fcef89d934f5d1705e571a93a67fb13a4402"}, - {file = "ruff-0.9.4-py3-none-win_amd64.whl", hash = "sha256:05bebf4cdbe3ef75430d26c375773978950bbf4ee3c95ccb5448940dc092408e"}, - {file = "ruff-0.9.4-py3-none-win_arm64.whl", hash = "sha256:585792f1e81509e38ac5123492f8875fbc36f3ede8185af0a26df348e5154f41"}, - {file = "ruff-0.9.4.tar.gz", hash = "sha256:6907ee3529244bb0ed066683e075f09285b38dd5b4039370df6ff06041ca19e7"}, + {file = "ruff-0.9.6-py3-none-linux_armv6l.whl", hash = "sha256:2f218f356dd2d995839f1941322ff021c72a492c470f0b26a34f844c29cdf5ba"}, + {file = "ruff-0.9.6-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b908ff4df65dad7b251c9968a2e4560836d8f5487c2f0cc238321ed951ea0504"}, + {file = "ruff-0.9.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:b109c0ad2ececf42e75fa99dc4043ff72a357436bb171900714a9ea581ddef83"}, + {file = "ruff-0.9.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1de4367cca3dac99bcbd15c161404e849bb0bfd543664db39232648dc00112dc"}, + {file = "ruff-0.9.6-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ac3ee4d7c2c92ddfdaedf0bf31b2b176fa7aa8950efc454628d477394d35638b"}, + {file = "ruff-0.9.6-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:5dc1edd1775270e6aa2386119aea692039781429f0be1e0949ea5884e011aa8e"}, + {file = "ruff-0.9.6-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:4a091729086dffa4bd070aa5dab7e39cc6b9d62eb2bef8f3d91172d30d599666"}, + {file = "ruff-0.9.6-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d1bbc6808bf7b15796cef0815e1dfb796fbd383e7dbd4334709642649625e7c5"}, + {file = "ruff-0.9.6-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:589d1d9f25b5754ff230dce914a174a7c951a85a4e9270613a2b74231fdac2f5"}, + {file = "ruff-0.9.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dc61dd5131742e21103fbbdcad683a8813be0e3c204472d520d9a5021ca8b217"}, + {file = "ruff-0.9.6-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:5e2d9126161d0357e5c8f30b0bd6168d2c3872372f14481136d13de9937f79b6"}, + {file = "ruff-0.9.6-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:68660eab1a8e65babb5229a1f97b46e3120923757a68b5413d8561f8a85d4897"}, + {file = "ruff-0.9.6-py3-none-musllinux_1_2_i686.whl", hash = "sha256:c4cae6c4cc7b9b4017c71114115db0445b00a16de3bcde0946273e8392856f08"}, + {file = "ruff-0.9.6-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:19f505b643228b417c1111a2a536424ddde0db4ef9023b9e04a46ed8a1cb4656"}, + {file = "ruff-0.9.6-py3-none-win32.whl", hash = "sha256:194d8402bceef1b31164909540a597e0d913c0e4952015a5b40e28c146121b5d"}, + {file = "ruff-0.9.6-py3-none-win_amd64.whl", hash = "sha256:03482d5c09d90d4ee3f40d97578423698ad895c87314c4de39ed2af945633caa"}, + {file = "ruff-0.9.6-py3-none-win_arm64.whl", hash = "sha256:0e2bb706a2be7ddfea4a4af918562fdc1bcb16df255e5fa595bbd800ce322a5a"}, + {file = "ruff-0.9.6.tar.gz", hash = "sha256:81761592f72b620ec8fa1068a6fd00e98a5ebee342a3642efd84454f3031dca9"}, ] [[package]] @@ -1005,4 +1005,4 @@ markers = {dev = "python_version < \"3.11\""} [metadata] lock-version = "2.1" python-versions = ">=3.8" -content-hash = "988a06e75b30be66a6537f89e46b1bca92582533c6f092e97ff38a19c7732a2b" +content-hash = "8a4aff899b0e935e06f9ebafabd0218def8457eff8fe4bc7c4c6938b12667846" diff --git a/pyproject.toml b/pyproject.toml index 3a3c0dd04..f0f0f9628 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pytorch_optimizer" -version = "3.4.0" +version = "3.4.1" description = "optimizer & lr scheduler & objective function collections in PyTorch" license = "Apache-2.0" authors = ["kozistr "] @@ -97,7 +97,7 @@ select = [ "TID", "ARG", "ERA", "RUF", "YTT", "PL", "Q" ] ignore = [ - "A005", "B905", + "B905", "D100", "D102", "D104", "D105", "D107", "D203", "D213", "D413", "PLR0912", "PLR0913", "PLR0915", "PLR2004", "Q003", "ARG002", diff --git a/pytorch_optimizer/base/optimizer.py b/pytorch_optimizer/base/optimizer.py index 201faab36..cfae9ec71 100644 --- a/pytorch_optimizer/base/optimizer.py +++ b/pytorch_optimizer/base/optimizer.py @@ -6,7 +6,7 @@ from torch.optim import Optimizer from pytorch_optimizer.base.exception import NegativeLRError, NegativeStepError -from pytorch_optimizer.base.types import ( +from pytorch_optimizer.base.type import ( BETAS, CLOSURE, DEFAULTS, diff --git a/pytorch_optimizer/base/types.py b/pytorch_optimizer/base/type.py similarity index 100% rename from pytorch_optimizer/base/types.py rename to pytorch_optimizer/base/type.py diff --git a/pytorch_optimizer/loss/dice.py b/pytorch_optimizer/loss/dice.py index 11731aca0..7a364d4e7 100644 --- a/pytorch_optimizer/loss/dice.py +++ b/pytorch_optimizer/loss/dice.py @@ -4,7 +4,7 @@ from torch.nn.functional import 
logsigmoid, one_hot from torch.nn.modules.loss import _Loss -from pytorch_optimizer.base.types import CLASS_MODE +from pytorch_optimizer.base.type import CLASS_MODE def soft_dice_score( diff --git a/pytorch_optimizer/loss/jaccard.py b/pytorch_optimizer/loss/jaccard.py index 292ac6b31..1c1d5a22d 100644 --- a/pytorch_optimizer/loss/jaccard.py +++ b/pytorch_optimizer/loss/jaccard.py @@ -4,7 +4,7 @@ from torch.nn.functional import logsigmoid, one_hot from torch.nn.modules.loss import _Loss -from pytorch_optimizer.base.types import CLASS_MODE +from pytorch_optimizer.base.type import CLASS_MODE def soft_jaccard_score( diff --git a/pytorch_optimizer/lr_scheduler/__init__.py b/pytorch_optimizer/lr_scheduler/__init__.py index de0bf5a40..055bb16f2 100644 --- a/pytorch_optimizer/lr_scheduler/__init__.py +++ b/pytorch_optimizer/lr_scheduler/__init__.py @@ -14,7 +14,7 @@ StepLR, ) -from pytorch_optimizer.base.types import SCHEDULER +from pytorch_optimizer.base.type import SCHEDULER from pytorch_optimizer.lr_scheduler.chebyshev import get_chebyshev_perm_steps, get_chebyshev_schedule from pytorch_optimizer.lr_scheduler.cosine_anealing import CosineAnnealingWarmupRestarts from pytorch_optimizer.lr_scheduler.experimental.deberta_v3_lr_scheduler import deberta_v3_large_lr_scheduler diff --git a/pytorch_optimizer/lr_scheduler/experimental/deberta_v3_lr_scheduler.py b/pytorch_optimizer/lr_scheduler/experimental/deberta_v3_lr_scheduler.py index 85ed14014..26e5b621f 100644 --- a/pytorch_optimizer/lr_scheduler/experimental/deberta_v3_lr_scheduler.py +++ b/pytorch_optimizer/lr_scheduler/experimental/deberta_v3_lr_scheduler.py @@ -1,6 +1,6 @@ from torch import nn -from pytorch_optimizer.base.types import PARAMETERS +from pytorch_optimizer.base.type import PARAMETERS def deberta_v3_large_lr_scheduler( diff --git a/pytorch_optimizer/optimizer/__init__.py b/pytorch_optimizer/optimizer/__init__.py index 056298037..827145fe8 100644 --- a/pytorch_optimizer/optimizer/__init__.py +++ b/pytorch_optimizer/optimizer/__init__.py @@ -7,7 +7,7 @@ from torch import nn from torch.optim import SGD, Adam, AdamW, Optimizer -from pytorch_optimizer.base.types import OPTIMIZER, PARAMETERS +from pytorch_optimizer.base.type import OPTIMIZER, PARAMETERS from pytorch_optimizer.optimizer.a2grad import A2Grad from pytorch_optimizer.optimizer.adabelief import AdaBelief from pytorch_optimizer.optimizer.adabound import AdaBound @@ -49,7 +49,7 @@ from pytorch_optimizer.optimizer.fromage import Fromage from pytorch_optimizer.optimizer.ftrl import FTRL from pytorch_optimizer.optimizer.galore import GaLore -from pytorch_optimizer.optimizer.gc import centralize_gradient +from pytorch_optimizer.optimizer.gradient_centralization import centralize_gradient from pytorch_optimizer.optimizer.grams import Grams from pytorch_optimizer.optimizer.gravity import Gravity from pytorch_optimizer.optimizer.grokfast import GrokFastAdamW diff --git a/pytorch_optimizer/optimizer/a2grad.py b/pytorch_optimizer/optimizer/a2grad.py index 46ac668fe..6ee13fd1c 100644 --- a/pytorch_optimizer/optimizer/a2grad.py +++ b/pytorch_optimizer/optimizer/a2grad.py @@ -5,7 +5,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import CLOSURE, DEFAULTS, LOSS, PARAMETERS VARIANTS = Literal['uni', 'inc', 'exp'] diff --git a/pytorch_optimizer/optimizer/adabelief.py 
b/pytorch_optimizer/optimizer/adabelief.py index f42cc34c8..c3e73fd32 100644 --- a/pytorch_optimizer/optimizer/adabelief.py +++ b/pytorch_optimizer/optimizer/adabelief.py @@ -4,7 +4,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS class AdaBelief(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/adabound.py b/pytorch_optimizer/optimizer/adabound.py index 6685904f5..b361ee61d 100644 --- a/pytorch_optimizer/optimizer/adabound.py +++ b/pytorch_optimizer/optimizer/adabound.py @@ -5,7 +5,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS class AdaBound(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/adadelta.py b/pytorch_optimizer/optimizer/adadelta.py index 4b3dd79a5..6cf5295a3 100644 --- a/pytorch_optimizer/optimizer/adadelta.py +++ b/pytorch_optimizer/optimizer/adadelta.py @@ -2,7 +2,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import CLOSURE, DEFAULTS, LOSS, PARAMETERS class AdaDelta(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/adafactor.py b/pytorch_optimizer/optimizer/adafactor.py index 351af00b7..687e6686b 100644 --- a/pytorch_optimizer/optimizer/adafactor.py +++ b/pytorch_optimizer/optimizer/adafactor.py @@ -5,7 +5,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS class AdaFactor(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/adahessian.py b/pytorch_optimizer/optimizer/adahessian.py index 4cce37ac5..99b1ace6f 100644 --- a/pytorch_optimizer/optimizer/adahessian.py +++ b/pytorch_optimizer/optimizer/adahessian.py @@ -4,7 +4,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, HUTCHINSON_G, LOSS, PARAMETERS +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, HUTCHINSON_G, LOSS, PARAMETERS class AdaHessian(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/adai.py b/pytorch_optimizer/optimizer/adai.py index c5b173bd2..8bb9e6681 100644 --- a/pytorch_optimizer/optimizer/adai.py +++ b/pytorch_optimizer/optimizer/adai.py @@ -4,8 +4,8 @@ from pytorch_optimizer.base.exception import NoSparseGradientError, ZeroParameterSizeError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS -from pytorch_optimizer.optimizer.gc import centralize_gradient +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.optimizer.gradient_centralization import centralize_gradient class Adai(BaseOptimizer): diff --git 
a/pytorch_optimizer/optimizer/adalite.py b/pytorch_optimizer/optimizer/adalite.py index 3249826c1..f7610013e 100644 --- a/pytorch_optimizer/optimizer/adalite.py +++ b/pytorch_optimizer/optimizer/adalite.py @@ -3,7 +3,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS class Adalite(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/adam_mini.py b/pytorch_optimizer/optimizer/adam_mini.py index 28e33484a..2d6ee3641 100644 --- a/pytorch_optimizer/optimizer/adam_mini.py +++ b/pytorch_optimizer/optimizer/adam_mini.py @@ -7,7 +7,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS class AdamMini(BaseOptimizer): # pragma: no cover diff --git a/pytorch_optimizer/optimizer/adamax.py b/pytorch_optimizer/optimizer/adamax.py index b0579929b..5a7c5851d 100644 --- a/pytorch_optimizer/optimizer/adamax.py +++ b/pytorch_optimizer/optimizer/adamax.py @@ -2,7 +2,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS class AdaMax(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/adamg.py b/pytorch_optimizer/optimizer/adamg.py index 4e774107f..a58e45b38 100644 --- a/pytorch_optimizer/optimizer/adamg.py +++ b/pytorch_optimizer/optimizer/adamg.py @@ -4,7 +4,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS class AdamG(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/adamod.py b/pytorch_optimizer/optimizer/adamod.py index 45781ebe9..e5a419fde 100644 --- a/pytorch_optimizer/optimizer/adamod.py +++ b/pytorch_optimizer/optimizer/adamod.py @@ -4,7 +4,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS class AdaMod(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/adamp.py b/pytorch_optimizer/optimizer/adamp.py index 8e6e26ebb..0e2942d05 100644 --- a/pytorch_optimizer/optimizer/adamp.py +++ b/pytorch_optimizer/optimizer/adamp.py @@ -4,8 +4,8 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS -from pytorch_optimizer.optimizer.gc import centralize_gradient +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.optimizer.gradient_centralization import centralize_gradient from pytorch_optimizer.optimizer.utils import projection diff --git a/pytorch_optimizer/optimizer/adams.py 
b/pytorch_optimizer/optimizer/adams.py index d5b907825..65b1428fd 100644 --- a/pytorch_optimizer/optimizer/adams.py +++ b/pytorch_optimizer/optimizer/adams.py @@ -4,7 +4,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError, ZeroParameterSizeError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS class AdamS(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/adamw.py b/pytorch_optimizer/optimizer/adamw.py index fa1f73318..20711257b 100644 --- a/pytorch_optimizer/optimizer/adamw.py +++ b/pytorch_optimizer/optimizer/adamw.py @@ -4,7 +4,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS class StableAdamW(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/adan.py b/pytorch_optimizer/optimizer/adan.py index 7a24c2d93..26d1396d6 100644 --- a/pytorch_optimizer/optimizer/adan.py +++ b/pytorch_optimizer/optimizer/adan.py @@ -5,8 +5,8 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS -from pytorch_optimizer.optimizer.gc import centralize_gradient +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.optimizer.gradient_centralization import centralize_gradient from pytorch_optimizer.optimizer.utils import get_global_gradient_norm diff --git a/pytorch_optimizer/optimizer/adanorm.py b/pytorch_optimizer/optimizer/adanorm.py index 056890b4f..71e7ab296 100644 --- a/pytorch_optimizer/optimizer/adanorm.py +++ b/pytorch_optimizer/optimizer/adanorm.py @@ -4,7 +4,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS class AdaNorm(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/adapnm.py b/pytorch_optimizer/optimizer/adapnm.py index 153984f80..bc1c2669a 100644 --- a/pytorch_optimizer/optimizer/adapnm.py +++ b/pytorch_optimizer/optimizer/adapnm.py @@ -4,7 +4,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS class AdaPNM(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/adashift.py b/pytorch_optimizer/optimizer/adashift.py index d61f383fa..52a8e1a75 100644 --- a/pytorch_optimizer/optimizer/adashift.py +++ b/pytorch_optimizer/optimizer/adashift.py @@ -5,7 +5,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS class AdaShift(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/adasmooth.py 
b/pytorch_optimizer/optimizer/adasmooth.py index 3b28ba91e..2f0b3affa 100644 --- a/pytorch_optimizer/optimizer/adasmooth.py +++ b/pytorch_optimizer/optimizer/adasmooth.py @@ -2,7 +2,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS class AdaSmooth(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/ademamix.py b/pytorch_optimizer/optimizer/ademamix.py index b5a4addf1..1fb843ddb 100644 --- a/pytorch_optimizer/optimizer/ademamix.py +++ b/pytorch_optimizer/optimizer/ademamix.py @@ -5,7 +5,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS class AdEMAMix(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/adopt.py b/pytorch_optimizer/optimizer/adopt.py index 912176dbe..d0b672c34 100644 --- a/pytorch_optimizer/optimizer/adopt.py +++ b/pytorch_optimizer/optimizer/adopt.py @@ -5,7 +5,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS class ADOPT(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/agc.py b/pytorch_optimizer/optimizer/agc.py index cea0a2f15..351e37352 100644 --- a/pytorch_optimizer/optimizer/agc.py +++ b/pytorch_optimizer/optimizer/agc.py @@ -14,11 +14,9 @@ def agc( :param agc_clip_val: float. norm clip. :param eps: float. simple stop from div by zero and no relation to standard optimizer eps. 
""" - p_norm = unit_norm(p).clamp_min_(agc_eps) + max_norm = unit_norm(p).clamp_min_(agc_eps).mul_(agc_clip_val) g_norm = unit_norm(grad).clamp_min_(eps) - max_norm = p_norm * agc_clip_val - clipped_grad = grad * (max_norm / g_norm) return torch.where(g_norm > max_norm, clipped_grad, grad) diff --git a/pytorch_optimizer/optimizer/aggmo.py b/pytorch_optimizer/optimizer/aggmo.py index a3c3b967c..ee861af46 100644 --- a/pytorch_optimizer/optimizer/aggmo.py +++ b/pytorch_optimizer/optimizer/aggmo.py @@ -2,7 +2,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS class AggMo(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/aida.py b/pytorch_optimizer/optimizer/aida.py index 5d790946f..11ee8f8c9 100644 --- a/pytorch_optimizer/optimizer/aida.py +++ b/pytorch_optimizer/optimizer/aida.py @@ -4,7 +4,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS class Aida(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/alig.py b/pytorch_optimizer/optimizer/alig.py index 0240a5299..6ec3b20f6 100644 --- a/pytorch_optimizer/optimizer/alig.py +++ b/pytorch_optimizer/optimizer/alig.py @@ -4,7 +4,7 @@ from pytorch_optimizer.base.exception import NoClosureError, NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import CLOSURE, DEFAULTS, LOSS, PARAMETERS from pytorch_optimizer.optimizer.utils import get_global_gradient_norm diff --git a/pytorch_optimizer/optimizer/amos.py b/pytorch_optimizer/optimizer/amos.py index 1212815fb..32628cdb7 100644 --- a/pytorch_optimizer/optimizer/amos.py +++ b/pytorch_optimizer/optimizer/amos.py @@ -4,7 +4,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import CLOSURE, DEFAULTS, LOSS, PARAMETERS class Amos(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/apollo.py b/pytorch_optimizer/optimizer/apollo.py index ae24af843..0865bc50d 100644 --- a/pytorch_optimizer/optimizer/apollo.py +++ b/pytorch_optimizer/optimizer/apollo.py @@ -6,7 +6,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS from pytorch_optimizer.optimizer.galore_utils import GaLoreProjector SCALE_TYPE = Literal['channel', 'tensor'] diff --git a/pytorch_optimizer/optimizer/avagrad.py b/pytorch_optimizer/optimizer/avagrad.py index f759e6ee1..6a5ef28ac 100644 --- a/pytorch_optimizer/optimizer/avagrad.py +++ b/pytorch_optimizer/optimizer/avagrad.py @@ -4,7 +4,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, 
DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS class AvaGrad(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/came.py b/pytorch_optimizer/optimizer/came.py index ce6308abe..39c1d6b86 100644 --- a/pytorch_optimizer/optimizer/came.py +++ b/pytorch_optimizer/optimizer/came.py @@ -5,7 +5,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS class CAME(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/dadapt.py b/pytorch_optimizer/optimizer/dadapt.py index 5001d8ef4..4c60c1696 100644 --- a/pytorch_optimizer/optimizer/dadapt.py +++ b/pytorch_optimizer/optimizer/dadapt.py @@ -10,7 +10,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS from pytorch_optimizer.optimizer.utils import get_global_gradient_norm, to_real diff --git a/pytorch_optimizer/optimizer/demo.py b/pytorch_optimizer/optimizer/demo.py index 32eaadc1a..c0e48b701 100644 --- a/pytorch_optimizer/optimizer/demo.py +++ b/pytorch_optimizer/optimizer/demo.py @@ -7,7 +7,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import CLOSURE, LOSS, PARAMETERS +from pytorch_optimizer.base.type import CLOSURE, LOSS, PARAMETERS HAS_EINOPS: bool = find_spec('einops') is not None diff --git a/pytorch_optimizer/optimizer/diffgrad.py b/pytorch_optimizer/optimizer/diffgrad.py index da14d98e2..6021258b5 100644 --- a/pytorch_optimizer/optimizer/diffgrad.py +++ b/pytorch_optimizer/optimizer/diffgrad.py @@ -2,7 +2,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS class DiffGrad(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/exadam.py b/pytorch_optimizer/optimizer/exadam.py index f7b8104d1..030d5fec5 100644 --- a/pytorch_optimizer/optimizer/exadam.py +++ b/pytorch_optimizer/optimizer/exadam.py @@ -3,7 +3,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS class EXAdam(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/experimental/ranger25.py b/pytorch_optimizer/optimizer/experimental/ranger25.py index ec657e357..9a7f2ad35 100644 --- a/pytorch_optimizer/optimizer/experimental/ranger25.py +++ b/pytorch_optimizer/optimizer/experimental/ranger25.py @@ -5,7 +5,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS from 
pytorch_optimizer.optimizer.agc import agc diff --git a/pytorch_optimizer/optimizer/fadam.py b/pytorch_optimizer/optimizer/fadam.py index 747399846..6d0b045f8 100644 --- a/pytorch_optimizer/optimizer/fadam.py +++ b/pytorch_optimizer/optimizer/fadam.py @@ -2,7 +2,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS class FAdam(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/focus.py b/pytorch_optimizer/optimizer/focus.py index 74d72f167..c165d6ec8 100644 --- a/pytorch_optimizer/optimizer/focus.py +++ b/pytorch_optimizer/optimizer/focus.py @@ -2,7 +2,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS class FOCUS(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/fp16.py b/pytorch_optimizer/optimizer/fp16.py index f40d15f9e..6a65dd7e7 100644 --- a/pytorch_optimizer/optimizer/fp16.py +++ b/pytorch_optimizer/optimizer/fp16.py @@ -4,7 +4,7 @@ from torch import nn from torch.optim import Optimizer -from pytorch_optimizer.base.types import CLOSURE, PARAMETERS +from pytorch_optimizer.base.type import CLOSURE, PARAMETERS from pytorch_optimizer.optimizer.utils import clip_grad_norm, has_overflow diff --git a/pytorch_optimizer/optimizer/fromage.py b/pytorch_optimizer/optimizer/fromage.py index c14716549..8f89a160a 100644 --- a/pytorch_optimizer/optimizer/fromage.py +++ b/pytorch_optimizer/optimizer/fromage.py @@ -10,7 +10,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import CLOSURE, DEFAULTS, LOSS, PARAMETERS class Fromage(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/ftrl.py b/pytorch_optimizer/optimizer/ftrl.py index 48e86b78a..ed10ad09c 100644 --- a/pytorch_optimizer/optimizer/ftrl.py +++ b/pytorch_optimizer/optimizer/ftrl.py @@ -2,7 +2,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import CLOSURE, DEFAULTS, LOSS, PARAMETERS class FTRL(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/galore.py b/pytorch_optimizer/optimizer/galore.py index 2dfaadc48..500654a25 100644 --- a/pytorch_optimizer/optimizer/galore.py +++ b/pytorch_optimizer/optimizer/galore.py @@ -4,7 +4,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS from pytorch_optimizer.optimizer.galore_utils import GaLoreProjector diff --git a/pytorch_optimizer/optimizer/gc.py b/pytorch_optimizer/optimizer/gradient_centralization.py similarity index 100% rename from pytorch_optimizer/optimizer/gc.py rename to pytorch_optimizer/optimizer/gradient_centralization.py diff --git 
a/pytorch_optimizer/optimizer/grams.py b/pytorch_optimizer/optimizer/grams.py index 596c72183..1fdf64b9d 100644 --- a/pytorch_optimizer/optimizer/grams.py +++ b/pytorch_optimizer/optimizer/grams.py @@ -4,7 +4,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS class Grams(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/gravity.py b/pytorch_optimizer/optimizer/gravity.py index 61e95669a..5c549c1c9 100644 --- a/pytorch_optimizer/optimizer/gravity.py +++ b/pytorch_optimizer/optimizer/gravity.py @@ -2,7 +2,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import CLOSURE, DEFAULTS, LOSS, PARAMETERS class Gravity(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/grokfast.py b/pytorch_optimizer/optimizer/grokfast.py index 8e4ee61dc..7beb73154 100644 --- a/pytorch_optimizer/optimizer/grokfast.py +++ b/pytorch_optimizer/optimizer/grokfast.py @@ -7,7 +7,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS FILTER_TYPE = Literal['mean', 'sum'] diff --git a/pytorch_optimizer/optimizer/kate.py b/pytorch_optimizer/optimizer/kate.py index 49fc1e308..c53f8219b 100644 --- a/pytorch_optimizer/optimizer/kate.py +++ b/pytorch_optimizer/optimizer/kate.py @@ -2,7 +2,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import CLOSURE, DEFAULTS, LOSS, PARAMETERS class Kate(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/lamb.py b/pytorch_optimizer/optimizer/lamb.py index 464218e5b..0ce40949f 100644 --- a/pytorch_optimizer/optimizer/lamb.py +++ b/pytorch_optimizer/optimizer/lamb.py @@ -4,7 +4,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS from pytorch_optimizer.optimizer.utils import get_global_gradient_norm diff --git a/pytorch_optimizer/optimizer/laprop.py b/pytorch_optimizer/optimizer/laprop.py index cd4412422..943111671 100644 --- a/pytorch_optimizer/optimizer/laprop.py +++ b/pytorch_optimizer/optimizer/laprop.py @@ -4,7 +4,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS class LaProp(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/lars.py b/pytorch_optimizer/optimizer/lars.py index c3c979a86..7ea33ee75 100644 --- a/pytorch_optimizer/optimizer/lars.py +++ b/pytorch_optimizer/optimizer/lars.py @@ -2,7 +2,7 @@ from 
pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import CLOSURE, DEFAULTS, LOSS, PARAMETERS class LARS(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/lion.py b/pytorch_optimizer/optimizer/lion.py index 511d14732..afc5cd7c7 100644 --- a/pytorch_optimizer/optimizer/lion.py +++ b/pytorch_optimizer/optimizer/lion.py @@ -2,8 +2,8 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS -from pytorch_optimizer.optimizer.gc import centralize_gradient +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.optimizer.gradient_centralization import centralize_gradient class Lion(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/lomo.py b/pytorch_optimizer/optimizer/lomo.py index a049c3591..99d7b8f74 100644 --- a/pytorch_optimizer/optimizer/lomo.py +++ b/pytorch_optimizer/optimizer/lomo.py @@ -7,7 +7,7 @@ from torch.distributed import ReduceOp, all_reduce from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import DEFAULTS +from pytorch_optimizer.base.type import DEFAULTS from pytorch_optimizer.optimizer.fp16 import DynamicLossScaler from pytorch_optimizer.optimizer.utils import has_overflow, is_deepspeed_zero3_enabled diff --git a/pytorch_optimizer/optimizer/lookahead.py b/pytorch_optimizer/optimizer/lookahead.py index 832c75a0c..42a322b21 100644 --- a/pytorch_optimizer/optimizer/lookahead.py +++ b/pytorch_optimizer/optimizer/lookahead.py @@ -5,7 +5,7 @@ from torch.optim import Optimizer from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, OPTIMIZER_INSTANCE_OR_CLASS, STATE +from pytorch_optimizer.base.type import CLOSURE, DEFAULTS, LOSS, OPTIMIZER_INSTANCE_OR_CLASS, STATE class Lookahead(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/madgrad.py b/pytorch_optimizer/optimizer/madgrad.py index 743489e25..3bb473923 100644 --- a/pytorch_optimizer/optimizer/madgrad.py +++ b/pytorch_optimizer/optimizer/madgrad.py @@ -9,7 +9,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import CLOSURE, DEFAULTS, LOSS, PARAMETERS class MADGRAD(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/mars.py b/pytorch_optimizer/optimizer/mars.py index ac6f4f31f..21e2343f4 100644 --- a/pytorch_optimizer/optimizer/mars.py +++ b/pytorch_optimizer/optimizer/mars.py @@ -5,7 +5,7 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS from pytorch_optimizer.optimizer.shampoo_utils import zero_power_via_newton_schulz_5 MARS_TYPE = Literal['adamw', 'lion', 'shampoo'] diff --git a/pytorch_optimizer/optimizer/msvag.py b/pytorch_optimizer/optimizer/msvag.py index 1cf34efd2..cce4f0c4f 100644 --- a/pytorch_optimizer/optimizer/msvag.py +++ b/pytorch_optimizer/optimizer/msvag.py @@ 
-2,7 +2,7 @@
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.base.type import CLOSURE, DEFAULTS, LOSS, PARAMETERS
 class MSVAG(BaseOptimizer):
diff --git a/pytorch_optimizer/optimizer/muon.py b/pytorch_optimizer/optimizer/muon.py
index 3008df696..2f8c3474f 100644
--- a/pytorch_optimizer/optimizer/muon.py
+++ b/pytorch_optimizer/optimizer/muon.py
@@ -6,7 +6,7 @@
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
 from pytorch_optimizer.optimizer.shampoo_utils import zero_power_via_newton_schulz_5
diff --git a/pytorch_optimizer/optimizer/nero.py b/pytorch_optimizer/optimizer/nero.py
index 48e480a02..d4a4fea03 100644
--- a/pytorch_optimizer/optimizer/nero.py
+++ b/pytorch_optimizer/optimizer/nero.py
@@ -4,7 +4,7 @@
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.base.type import CLOSURE, DEFAULTS, LOSS, PARAMETERS
 from pytorch_optimizer.optimizer.utils import channel_view
diff --git a/pytorch_optimizer/optimizer/novograd.py b/pytorch_optimizer/optimizer/novograd.py
index 9070eab73..51b2e47e5 100644
--- a/pytorch_optimizer/optimizer/novograd.py
+++ b/pytorch_optimizer/optimizer/novograd.py
@@ -4,7 +4,7 @@
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
 class NovoGrad(BaseOptimizer):
diff --git a/pytorch_optimizer/optimizer/orthograd.py b/pytorch_optimizer/optimizer/orthograd.py
index b5574786a..3cd1ae460 100644
--- a/pytorch_optimizer/optimizer/orthograd.py
+++ b/pytorch_optimizer/optimizer/orthograd.py
@@ -4,7 +4,7 @@
 from torch.optim import Optimizer
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, OPTIMIZER_INSTANCE_OR_CLASS, STATE
+from pytorch_optimizer.base.type import CLOSURE, DEFAULTS, LOSS, OPTIMIZER_INSTANCE_OR_CLASS, STATE
 class OrthoGrad(BaseOptimizer):
diff --git a/pytorch_optimizer/optimizer/padam.py b/pytorch_optimizer/optimizer/padam.py
index 011323876..0f45069fa 100644
--- a/pytorch_optimizer/optimizer/padam.py
+++ b/pytorch_optimizer/optimizer/padam.py
@@ -4,7 +4,7 @@
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
 class PAdam(BaseOptimizer):
diff --git a/pytorch_optimizer/optimizer/pid.py b/pytorch_optimizer/optimizer/pid.py
index c537c1b12..a907a735f 100644
--- a/pytorch_optimizer/optimizer/pid.py
+++ b/pytorch_optimizer/optimizer/pid.py
@@ -2,7 +2,7 @@
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.base.type import CLOSURE, DEFAULTS, LOSS, PARAMETERS
 class PID(BaseOptimizer):
diff --git a/pytorch_optimizer/optimizer/pnm.py b/pytorch_optimizer/optimizer/pnm.py
index 4cd0af044..d2dbdd5f5 100644
--- a/pytorch_optimizer/optimizer/pnm.py
+++ b/pytorch_optimizer/optimizer/pnm.py
@@ -4,7 +4,7 @@
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
 class PNM(BaseOptimizer):
diff --git a/pytorch_optimizer/optimizer/prodigy.py b/pytorch_optimizer/optimizer/prodigy.py
index 63235c797..27774bb0e 100644
--- a/pytorch_optimizer/optimizer/prodigy.py
+++ b/pytorch_optimizer/optimizer/prodigy.py
@@ -5,7 +5,7 @@
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
 class Prodigy(BaseOptimizer):
diff --git a/pytorch_optimizer/optimizer/psgd.py b/pytorch_optimizer/optimizer/psgd.py
index 26934985b..0e9cb6e54 100644
--- a/pytorch_optimizer/optimizer/psgd.py
+++ b/pytorch_optimizer/optimizer/psgd.py
@@ -7,7 +7,7 @@
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import CLOSURE, LOSS, PARAMETERS
+from pytorch_optimizer.base.type import CLOSURE, LOSS, PARAMETERS
 from pytorch_optimizer.optimizer.psgd_utils import norm_lower_bound
 MEMORY_SAVE_MODE_TYPE = Literal['one_diag', 'smart_one_diag', 'all_diag']
diff --git a/pytorch_optimizer/optimizer/qhadam.py b/pytorch_optimizer/optimizer/qhadam.py
index 69789c882..524801f51 100644
--- a/pytorch_optimizer/optimizer/qhadam.py
+++ b/pytorch_optimizer/optimizer/qhadam.py
@@ -4,7 +4,7 @@
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
 class QHAdam(BaseOptimizer):
diff --git a/pytorch_optimizer/optimizer/qhm.py b/pytorch_optimizer/optimizer/qhm.py
index 2cf1d21c6..c9a970d48 100644
--- a/pytorch_optimizer/optimizer/qhm.py
+++ b/pytorch_optimizer/optimizer/qhm.py
@@ -2,7 +2,7 @@
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.base.type import CLOSURE, DEFAULTS, LOSS, PARAMETERS
 class QHM(BaseOptimizer):
diff --git a/pytorch_optimizer/optimizer/radam.py b/pytorch_optimizer/optimizer/radam.py
index 3ec09dd5d..bdd3ace9f 100644
--- a/pytorch_optimizer/optimizer/radam.py
+++ b/pytorch_optimizer/optimizer/radam.py
@@ -2,7 +2,7 @@
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
 class RAdam(BaseOptimizer):
diff --git a/pytorch_optimizer/optimizer/ranger.py b/pytorch_optimizer/optimizer/ranger.py
index 525027d1f..a7e5b7981 100644
--- a/pytorch_optimizer/optimizer/ranger.py
+++ b/pytorch_optimizer/optimizer/ranger.py
@@ -2,8 +2,8 @@
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
-from pytorch_optimizer.optimizer.gc import centralize_gradient
+from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.optimizer.gradient_centralization import centralize_gradient
 class Ranger(BaseOptimizer):
diff --git a/pytorch_optimizer/optimizer/ranger21.py b/pytorch_optimizer/optimizer/ranger21.py
index b67d0bd96..90cd25308 100644
--- a/pytorch_optimizer/optimizer/ranger21.py
+++ b/pytorch_optimizer/optimizer/ranger21.py
@@ -6,9 +6,9 @@
 from pytorch_optimizer.base.exception import NoSparseGradientError, ZeroParameterSizeError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
 from pytorch_optimizer.optimizer.agc import agc
-from pytorch_optimizer.optimizer.gc import centralize_gradient
+from pytorch_optimizer.optimizer.gradient_centralization import centralize_gradient
 from pytorch_optimizer.optimizer.utils import normalize_gradient, unit_norm
diff --git a/pytorch_optimizer/optimizer/sam.py b/pytorch_optimizer/optimizer/sam.py
index 67f53dd51..6803c9c0b 100644
--- a/pytorch_optimizer/optimizer/sam.py
+++ b/pytorch_optimizer/optimizer/sam.py
@@ -10,8 +10,8 @@
 from pytorch_optimizer.base.exception import NoClosureError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, OPTIMIZER, PARAMETERS
-from pytorch_optimizer.optimizer.gc import centralize_gradient
+from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, OPTIMIZER, PARAMETERS
+from pytorch_optimizer.optimizer.gradient_centralization import centralize_gradient
 from pytorch_optimizer.optimizer.utils import disable_running_stats, enable_running_stats
diff --git a/pytorch_optimizer/optimizer/schedulefree.py b/pytorch_optimizer/optimizer/schedulefree.py
index 59cbf6e40..427942c2c 100644
--- a/pytorch_optimizer/optimizer/schedulefree.py
+++ b/pytorch_optimizer/optimizer/schedulefree.py
@@ -5,7 +5,7 @@
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
 class ScheduleFreeSGD(BaseOptimizer):
diff --git a/pytorch_optimizer/optimizer/sgd.py b/pytorch_optimizer/optimizer/sgd.py
index de0bbaa71..6c6461c19 100644
--- a/pytorch_optimizer/optimizer/sgd.py
+++ b/pytorch_optimizer/optimizer/sgd.py
@@ -5,7 +5,7 @@
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.base.type import CLOSURE, DEFAULTS, LOSS, PARAMETERS
 class AccSGD(BaseOptimizer):
diff --git a/pytorch_optimizer/optimizer/sgdp.py b/pytorch_optimizer/optimizer/sgdp.py
index 35691dbbb..18f107528 100644
--- a/pytorch_optimizer/optimizer/sgdp.py
+++ b/pytorch_optimizer/optimizer/sgdp.py
@@ -2,7 +2,7 @@
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.base.type import CLOSURE, DEFAULTS, LOSS, PARAMETERS
 from pytorch_optimizer.optimizer.utils import projection
diff --git a/pytorch_optimizer/optimizer/shampoo.py b/pytorch_optimizer/optimizer/shampoo.py
index 3ea27bfc5..bc7e026c0 100644
--- a/pytorch_optimizer/optimizer/shampoo.py
+++ b/pytorch_optimizer/optimizer/shampoo.py
@@ -2,7 +2,7 @@
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
 from pytorch_optimizer.optimizer.shampoo_utils import (
     LayerWiseGrafting,
     PreConditioner,
diff --git a/pytorch_optimizer/optimizer/sm3.py b/pytorch_optimizer/optimizer/sm3.py
index 78718a591..f351db357 100644
--- a/pytorch_optimizer/optimizer/sm3.py
+++ b/pytorch_optimizer/optimizer/sm3.py
@@ -1,7 +1,7 @@
 import torch
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.base.type import CLOSURE, DEFAULTS, LOSS, PARAMETERS
 @torch.no_grad()
diff --git a/pytorch_optimizer/optimizer/soap.py b/pytorch_optimizer/optimizer/soap.py
index 9024c467c..d15dfcebd 100644
--- a/pytorch_optimizer/optimizer/soap.py
+++ b/pytorch_optimizer/optimizer/soap.py
@@ -6,7 +6,7 @@
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import BETAS, CLOSURE, DATA_FORMAT, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.base.type import BETAS, CLOSURE, DATA_FORMAT, DEFAULTS, LOSS, PARAMETERS
 from pytorch_optimizer.optimizer.shampoo_utils import merge_small_dims
diff --git a/pytorch_optimizer/optimizer/sophia.py b/pytorch_optimizer/optimizer/sophia.py
index cc0a5544c..b795e368a 100644
--- a/pytorch_optimizer/optimizer/sophia.py
+++ b/pytorch_optimizer/optimizer/sophia.py
@@ -4,7 +4,7 @@
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, HUTCHINSON_G, LOSS, PARAMETERS
+from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, HUTCHINSON_G, LOSS, PARAMETERS
 class SophiaH(BaseOptimizer):
diff --git a/pytorch_optimizer/optimizer/spam.py b/pytorch_optimizer/optimizer/spam.py
index d943e1620..e55267db2 100644
--- a/pytorch_optimizer/optimizer/spam.py
+++ b/pytorch_optimizer/optimizer/spam.py
@@ -7,7 +7,7 @@
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
 class CosineDecay:
diff --git a/pytorch_optimizer/optimizer/srmm.py b/pytorch_optimizer/optimizer/srmm.py
index 1669eba37..09a167f0a 100644
--- a/pytorch_optimizer/optimizer/srmm.py
+++ b/pytorch_optimizer/optimizer/srmm.py
@@ -4,7 +4,7 @@
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.base.type import CLOSURE, DEFAULTS, LOSS, PARAMETERS
 class SRMM(BaseOptimizer):
diff --git a/pytorch_optimizer/optimizer/swats.py b/pytorch_optimizer/optimizer/swats.py
index 6ec5b16ec..2ea424ff2 100644
--- a/pytorch_optimizer/optimizer/swats.py
+++ b/pytorch_optimizer/optimizer/swats.py
@@ -4,7 +4,7 @@
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
 class SWATS(BaseOptimizer):
diff --git a/pytorch_optimizer/optimizer/tam.py b/pytorch_optimizer/optimizer/tam.py
index 0d2d395b3..c61cf4d88 100644
--- a/pytorch_optimizer/optimizer/tam.py
+++ b/pytorch_optimizer/optimizer/tam.py
@@ -3,7 +3,7 @@
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
 class TAM(BaseOptimizer):
diff --git a/pytorch_optimizer/optimizer/tiger.py b/pytorch_optimizer/optimizer/tiger.py
index 637a4ebc0..07b39981b 100644
--- a/pytorch_optimizer/optimizer/tiger.py
+++ b/pytorch_optimizer/optimizer/tiger.py
@@ -2,7 +2,7 @@
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.base.type import CLOSURE, DEFAULTS, LOSS, PARAMETERS
 class Tiger(BaseOptimizer):
diff --git a/pytorch_optimizer/optimizer/trac.py b/pytorch_optimizer/optimizer/trac.py
index 6ab57256e..0ab157265 100644
--- a/pytorch_optimizer/optimizer/trac.py
+++ b/pytorch_optimizer/optimizer/trac.py
@@ -5,7 +5,7 @@
 from torch.optim import Optimizer
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, OPTIMIZER_INSTANCE_OR_CLASS, STATE
+from pytorch_optimizer.base.type import CLOSURE, DEFAULTS, LOSS, OPTIMIZER_INSTANCE_OR_CLASS, STATE
 def polyval(x: torch.Tensor, coef: torch.Tensor) -> torch.Tensor:
diff --git a/pytorch_optimizer/optimizer/utils.py b/pytorch_optimizer/optimizer/utils.py
index 4c95e347f..2fc0c1d97 100644
--- a/pytorch_optimizer/optimizer/utils.py
+++ b/pytorch_optimizer/optimizer/utils.py
@@ -13,7 +13,7 @@
 from torch.nn.modules.batchnorm import _BatchNorm
 from torch.nn.utils import clip_grad_norm_
-from pytorch_optimizer.base.types import CLOSURE, LOSS, PARAMETERS
+from pytorch_optimizer.base.type import CLOSURE, LOSS, PARAMETERS
 def parse_pytorch_version(version_string: str) -> List[int]:
diff --git a/pytorch_optimizer/optimizer/yogi.py b/pytorch_optimizer/optimizer/yogi.py
index eef7455d8..3c7795717 100644
--- a/pytorch_optimizer/optimizer/yogi.py
+++ b/pytorch_optimizer/optimizer/yogi.py
@@ -4,7 +4,7 @@
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
 class Yogi(BaseOptimizer):
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 477de1203..5d07757ce 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -27,11 +27,11 @@
 platformdirs==4.3.6 ; python_version >= "3.8"
 pluggy==1.5.0 ; python_version >= "3.8"
 pytest-cov==5.0.0 ; python_version >= "3.8"
 pytest==8.3.4 ; python_version >= "3.8"
-ruff==0.9.4 ; python_version >= "3.8"
+ruff==0.9.6 ; python_version >= "3.8"
 setuptools==75.8.0 ; python_version >= "3.12"
 sympy==1.13.1 ; python_version >= "3.9"
-sympy==1.13.3 ; python_version >= "3.8" and python_version < "3.9"
+sympy==1.13.3 ; python_version < "3.9" and python_version >= "3.8"
 tomli==2.2.1 ; python_full_version <= "3.11.0a6" and python_version >= "3.8"
-torch==2.4.1+cpu ; python_version >= "3.8" and python_version < "3.9"
+torch==2.4.1+cpu ; python_version < "3.9" and python_version >= "3.8"
 torch==2.6.0+cpu ; python_version >= "3.9"
 typing-extensions==4.12.2 ; python_version >= "3.8"
diff --git a/requirements.txt b/requirements.txt
index 2da826a8b..08df01680 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,7 +13,7 @@
 numpy==1.24.4 ; python_version < "3.9" and python_version >= "3.8"
 numpy==2.0.2 ; python_version >= "3.9"
 setuptools==75.8.0 ; python_version >= "3.12"
 sympy==1.13.1 ; python_version >= "3.9"
-sympy==1.13.3 ; python_version >= "3.8" and python_version < "3.9"
-torch==2.4.1+cpu ; python_version >= "3.8" and python_version < "3.9"
+sympy==1.13.3 ; python_version < "3.9" and python_version >= "3.8"
+torch==2.4.1+cpu ; python_version < "3.9" and python_version >= "3.8"
 torch==2.6.0+cpu ; python_version >= "3.9"
 typing-extensions==4.12.2 ; python_version >= "3.8"
diff --git a/tests/utils.py b/tests/utils.py
index 1deeebd80..3bcb277e6 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -5,7 +5,7 @@
 from torch import nn
 from torch.nn import functional as f
-from pytorch_optimizer.base.types import LOSS
+from pytorch_optimizer.base.type import LOSS
 from pytorch_optimizer.optimizer import AdamW, Lookahead, OrthoGrad