Skip to content

Commit

Permalink
feat(model/training): add load_with_strategy function for loading m…
Browse files Browse the repository at this point in the history
…odels using a specific distributed strategy
  • Loading branch information
ruancomelli committed Jun 18, 2022
1 parent 58b4e22 commit ec8fe3b
Showing 1 changed file with 21 additions and 2 deletions.
23 changes: 21 additions & 2 deletions boiling_learning/model/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
from contextlib import contextmanager, nullcontext
from datetime import timedelta
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar, Union

import tensorflow as tf
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.losses import Loss
from tensorflow.keras.metrics import Metric
from tensorflow.keras.optimizers import Optimizer
from typing_extensions import TypedDict
from typing_extensions import ParamSpec, TypedDict

from boiling_learning.datasets.datasets import DatasetTriplet
from boiling_learning.io import json
Expand All @@ -24,6 +24,9 @@
from boiling_learning.utils.typeutils import typename
from boiling_learning.utils.utils import resolve

_P = ParamSpec('_P')
_T = TypeVar('_T')


@describe.instance(Metric)
def _describe_typename(instance: Metric) -> str:
Expand Down Expand Up @@ -132,3 +135,19 @@ def strategy_scope(strategy: Optional[Described[tf.distribute.Strategy, Any]]) -

with context:
yield


def _wrap_with_strategy(
func: Callable[_P, _T]
) -> Callable[[Optional[Described[tf.distribute.Strategy, Any]]], Callable[_P, _T]]:
def _wrapper(strategy: Optional[Described[tf.distribute.Strategy, Any]]) -> Callable[_P, _T]:
def _wrapped(*args: _P.args, **kwargs: _P.kwargs) -> _T:
with strategy_scope(strategy):
return func(*args, **kwargs)

return _wrapped

return _wrapper


load_with_strategy = _wrap_with_strategy(load)

0 comments on commit ec8fe3b

Please sign in to comment.