Basic training functionality (original)
torzdf committed Mar 4, 2024
1 parent b5808fe commit 6d49eaa
Showing 20 changed files with 276 additions and 492 deletions.
8 changes: 4 additions & 4 deletions lib/gui/analysis/event_reader.py

@@ -9,10 +9,10 @@
 from dataclasses import dataclass, field

 import numpy as np
-import tensorflow as tf
-from tensorflow.core.util import event_pb2 # pylint:disable=no-name-in-module
-from tensorflow.python.framework import ( # pylint:disable=no-name-in-module
-    errors_impl as tf_errors)
+#import tensorflow as tf
+#from tensorflow.core.util import event_pb2 # pylint:disable=no-name-in-module
+#from tensorflow.python.framework import ( # pylint:disable=no-name-in-module
+#    errors_impl as tf_errors)

 from lib.serializer import get_serializer
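Here the TensorFlow imports are only commented out rather than swapped for torch equivalents: this module parses TensorBoard event logs for the GUI's analysis tab, and that parsing is presumably stubbed out until a torch-era replacement exists. For orientation, a minimal sketch of what the disabled imports are used for — iterating the raw Event records in an event file. This is an assumed illustration (iter_events is not a faceswap function), not the project's actual reader:

import tensorflow as tf
from tensorflow.core.util import event_pb2  # pylint:disable=no-name-in-module


def iter_events(event_file: str):
    """ Yield each Event protobuf record stored in a TensorBoard event file. """
    for record in tf.data.TFRecordDataset(event_file):
        # Each record is a serialised event_pb2.Event
        yield event_pb2.Event.FromString(record.numpy())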
8 changes: 4 additions & 4 deletions lib/model/autoclip.py

@@ -1,6 +1,6 @@
 """ Auto clipper for clipping gradients. """
 import numpy as np
-import tensorflow as tf
+import torch


 class AutoClipper(): # pylint:disable=too-few-public-methods
@@ -25,7 +25,7 @@ def __init__(self, clip_percentile: int, history_size: int = 10000):
         self._index = tf.Variable(0, trainable=False)
         self._history_size = history_size

-    def _percentile(self, grad_history: tf.Tensor) -> tf.Tensor:
+    def _percentile(self, grad_history: torch.Tensor) -> torch.Tensor:
         """ Compute the clip percentile of the gradient history

         Parameters
@@ -69,7 +69,7 @@ def _percentile(self, grad_history: tf.Tensor) -> tf.Tensor:
         return gathered_hist

-    def __call__(self, grads_and_vars: list[tf.Tensor]) -> list[tf.Tensor]:
+    def __call__(self, grads_and_vars: list[torch.Tensor]) -> list[torch.Tensor]:
         """ Call the AutoClip function.

         Parameters
@@ -86,7 +86,7 @@ def __call__(self, grads_and_vars: list[tf.Tensor]) -> list[tf.Tensor]:
         return [(tf.clip_by_norm(g, clip_value), v) for g, v in grads_and_vars]

     @classmethod
-    def _get_grad_norm(cls, gradients: tf.Tensor) -> tf.Tensor:
+    def _get_grad_norm(cls, gradients: torch.Tensor) -> torch.Tensor:
         """ Obtain the L2 Norm for the gradients

         Parameters
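Note that only the type annotations move to torch in this commit; the untouched context lines still call tf.Variable and tf.clip_by_norm, so the body of the class is evidently ported separately. For reference, a self-contained PyTorch sketch of the technique the class implements — AutoClip, which clips each step's gradients to a chosen percentile of the gradient-norm history (Seetharaman et al., 2020). The class name and structure here are assumptions for illustration, not faceswap's implementation:

import torch


class AutoClipSketch:
    """ Clip gradients to a percentile of the observed gradient-norm history. """

    def __init__(self, clip_percentile: float = 10.0, history_size: int = 10000) -> None:
        self._clip_percentile = clip_percentile
        self._history_size = history_size
        self._norm_history: list[float] = []

    def __call__(self, parameters) -> None:
        params = [p for p in parameters if p.grad is not None]
        # Global L2 norm over all gradients for this step
        total_norm = torch.linalg.vector_norm(
            torch.stack([torch.linalg.vector_norm(p.grad) for p in params]))
        self._norm_history = (self._norm_history
                              + [float(total_norm)])[-self._history_size:]
        # The clip threshold adapts to the history rather than being fixed
        clip_value = torch.quantile(torch.tensor(self._norm_history),
                                    self._clip_percentile / 100.0)
        torch.nn.utils.clip_grad_norm_(params, float(clip_value))

Called between loss.backward() and optimizer.step(), this plays the same role as the grads_and_vars hook in the diff above.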
4 changes: 2 additions & 2 deletions lib/model/initializers.py

@@ -6,10 +6,10 @@
 import inspect

 import keras
+from keras.saving import get_custom_objects
 import keras.backend as K

 import numpy as np
-import tensorflow as tf


@@ -307,4 +307,4 @@ def get_config(self):
 # Update initializers into Keras custom objects
 for name, obj in inspect.getmembers(sys.modules[__name__]):
     if inspect.isclass(obj) and obj.__module__ == __name__:
-        keras.utils.get_custom_objects().update({name: obj})
+        get_custom_objects().update({name: obj})
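The second hunk explains the import change: under Keras 3 the custom-object registry is imported from keras.saving rather than being reached via keras.utils. The registration pattern itself is unchanged, as in this toy example (ScaledInitializer is hypothetical, not part of the commit):

import keras
from keras.saving import get_custom_objects


class ScaledInitializer(keras.initializers.Initializer):
    """ Toy initializer: uniform values scaled by a constant factor. """

    def __init__(self, scale: float = 0.5) -> None:
        self.scale = scale

    def __call__(self, shape, dtype=None):
        return keras.random.uniform(shape, dtype=dtype) * self.scale

    def get_config(self):
        return {"scale": self.scale}


# Register so Keras can deserialize models that reference the class by name
get_custom_objects().update({"ScaledInitializer": ScaledInitializer})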
… 17 more changed files not shown