Skip to content

Commit

Permalink
Create output dir in training.Loop if it's not there (same as Trainer…
Browse files Browse the repository at this point in the history
…). Also report training loss in Loop.

PiperOrigin-RevId: 319320458
  • Loading branch information
Lukasz Kaiser authored and copybara-github committed Jul 2, 2020
1 parent ef0d18f commit 307c502
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 8 deletions.
1 change: 1 addition & 0 deletions trax/fastmath/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ def _custom_grad(f_vjp, f_original):
'erf': jax_special.erf,
'expit': jax_special.expit,
'grad': jax.grad,
'value_and_grad': jax.value_and_grad,
'jit': jax.jit,
'logsumexp': jax_special.logsumexp,
'lt': lax.lt,
Expand Down
20 changes: 20 additions & 0 deletions trax/fastmath/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,26 @@ def grad(*args, **kwargs):
return backend()['grad'](*args, **kwargs)


def value_and_grad(*args, **kwargs):
  """Computes the gradient of the specified function together with the value.

  Mirrors the interface of ``jax.value_and_grad``: given a function ``f``
  (passed positionally as the first argument), returns a function that
  evaluates both ``f`` and its gradient in one call.

  If the active backend provides a native ``value_and_grad``, it is used
  directly. Otherwise this falls back to calling ``grad`` and the original
  function separately (the function is evaluated twice in that case).

  Args:
    *args: positional arguments forwarded to the backend / ``grad``;
      ``args[0]`` must be the function to differentiate.
    **kwargs: keyword arguments forwarded to the backend / ``grad``;
      ``has_aux=True`` signals that the function returns ``(value, aux)``.

  Returns:
    A function returning ``(value, gradient)``, or, when ``has_aux`` is
    true, ``((value, aux), gradient)`` — matching ``jax.value_and_grad``.
  """
  if 'value_and_grad' in backend():
    return backend()['value_and_grad'](*args, **kwargs)
  # Fallback: build the pair from grad() and the original function.
  grad_fn = grad(*args, **kwargs)
  fn = args[0]
  # Bug fix: the original tested `has_aux in kwargs` (i.e. whether the
  # value False was a key), so the aux-handling branch was unreachable.
  has_aux = kwargs.get('has_aux', False)
  if not has_aux:
    def val_and_grad(*fn_args, **fn_kwargs):
      return fn(*fn_args, **fn_kwargs), grad_fn(*fn_args, **fn_kwargs)
    return val_and_grad
  def val_and_grad_aux(*fn_args, **fn_kwargs):
    # With has_aux=True, grad_fn returns (gradient, aux) and fn returns
    # (value, aux); recombine to ((value, aux), gradient) per JAX contract.
    g, aux = grad_fn(*fn_args, **fn_kwargs)
    res, _ = fn(*fn_args, **fn_kwargs)
    return (res, aux), g
  return val_and_grad_aux


def vjp(*args, **kwargs):
  """Computes the vector-Jacobian product for the specified function.

  Delegates to the active backend's ``vjp`` implementation, forwarding
  all positional and keyword arguments unchanged.
  """
  vjp_fn = backend()['vjp']
  return vjp_fn(*args, **kwargs)
Expand Down
30 changes: 22 additions & 8 deletions trax/supervised/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,11 @@ def __init__(self, model, task, eval_model=None, eval_task=None,
self._model_in_training = tl.Serial(model, task.loss_layer)
self._eval_model = model if eval_model is None else eval_model
self._eval_task = eval_task
self._rjust_len = max([0] + [len(name) for name in eval_task.metric_names])

self._output_dir = os.path.expanduser(output_dir) if output_dir else None
if output_dir is not None:
tf.io.gfile.makedirs(output_dir)
default_fn = _at_step_1_and_periodically_at(task.n_steps_per_checkpoint)
self._checkpoint_at = checkpoint_at or default_fn
self._eval_at = eval_at or default_fn
Expand All @@ -120,9 +124,10 @@ def __init__(self, model, task, eval_model=None, eval_task=None,
_, _ = task.optimizer.tree_init(self._model_in_training.weights)

self._gradients_and_state_fn = (
fastmath.jit(fastmath.grad(self._model_in_training.pure_fn,
argnums=1, # arg1 of pure_fn: weights
has_aux=True))) # return (gradients, state)
fastmath.jit(fastmath.value_and_grad(
self._model_in_training.pure_fn,
argnums=1, # arg1 of pure_fn: weights
has_aux=True))) # return (loss, state), gradients

if eval_task is not None:
model_with_metrics = _model_with_metrics(self._eval_model, eval_task)
Expand All @@ -142,13 +147,23 @@ def run(self, n_steps=1):
weights = self._model_in_training.weights
state = self._model_in_training.state
slots = self._task.optimizer.slots
loss_acc, step_acc = 0.0, 0
for _ in range(n_steps):
self._step += 1
weights, state, slots = self._run_one_step(weights, state, slots)
loss, weights, state, slots = self._run_one_step(weights, state, slots)
loss_acc += loss
step_acc += 1
if self._eval_at(self._step):
self._model_in_training.weights = weights
self._model_in_training.state = state
self._eval_model.weights = self._model.weights
# TODO(lukaszkaiser): move this to a better place with other reporting
loss_name = self._task.loss_layer.name
step_acc = max(1, step_acc)  # only here to avoid potential divide-by-0
self._log_step('%s %s | % .8f' % (
'train'.ljust(5), loss_name.rjust(self._rjust_len),
loss_acc / float(step_acc)))
loss_acc, step_acc = 0.0, 0
self.run_evals(weights, state)
if self._checkpoint_at(self._step):
self.save_checkpoint(weights, state, slots)
Expand Down Expand Up @@ -199,11 +214,11 @@ def _run_one_step(self, weights, state, slots):
opt_params = optimizer._init_opt_params # pylint: disable=protected-access
opt_params.update({'learning_rate': self._task.learning_rate(step)})

gradients, updated_state = (
(loss, updated_state), gradients = (
self._gradients_and_state_fn(batch, weights, state, self.new_rng()))
updated_weights, updated_slots, _ = (
optimizer.tree_update(step, gradients, weights, slots, opt_params))
return updated_weights, updated_state, updated_slots
return loss, updated_weights, updated_state, updated_slots

def run_evals(self, weights=None, state=None):
"""Runs and records evals for this training session.
Expand All @@ -230,10 +245,9 @@ def run_evals(self, weights=None, state=None):
self._metrics_fn(batch, metrics_weights, metrics_state, rng))
sums += metric_values
averages = sums / n_batches
rjust_len = max([0] + [len(name) for name in eval_task.metric_names])
for name, average_value in zip(eval_task.metric_names, averages):
self._log_step('%s %s | % .8f' % (
'eval'.ljust(5), name.rjust(rjust_len), average_value))
'eval'.ljust(5), name.rjust(self._rjust_len), average_value))

def _log_step(self, msg):
"""Logs message, labeled with the current training step number."""
Expand Down

0 comments on commit 307c502

Please sign in to comment.