Skip to content

Commit

Permalink
more review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
scarlehoff committed Dec 6, 2024
1 parent af832de commit 09c426a
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 2 deletions.
4 changes: 2 additions & 2 deletions n3fit/src/n3fit/backends/keras_backend/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import logging
from time import time

# Callbacks need tensorflow installed even if the backend is pytorch
from keras.callbacks import Callback, TensorBoard
import numpy as np

Expand Down Expand Up @@ -196,7 +195,8 @@ def gen_tensorboard_callback(log_dir, profiling=False, histogram_freq=0):
If the profiling flag is set to True, it will also attempt
to save profiling data.
Note the usage of this callback can hurt performance.
Note the usage of this callback can hurt performance
At the moment can only be used with TensorFlow: https://github.com/keras-team/keras/issues/19121
Parameters
----------
Expand Down
4 changes: 4 additions & 0 deletions n3fit/src/n3fit/backends/keras_backend/internal_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ def set_threading(threads, cores):
"Could not set tensorflow parallelism settings from n3fit, maybe tensorflow is already initialized by a third program"
)

else:
# Keras should've failed by now, if it doesn't it could be a new backend that works ootb?
log.warning(f"Backend {K.backend()} not recognized. You are entering uncharted territory")


def set_number_of_cores(max_cores=None, max_threads=None):
"""
Expand Down
8 changes: 8 additions & 0 deletions n3fit/src/n3fit/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,14 @@ def check_dropout(parameters):
def check_tensorboard(tensorboard):
"""Check that the tensorbard callback can be enabled correctly"""
if tensorboard is not None:
# Check that Tensorflow is installed
try:
import tensorflow
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
"The tensorboard callback requires `tensorflow` to be installed"
) from e

weight_freq = tensorboard.get("weight_freq", 0)
if weight_freq < 0:
raise CheckError(
Expand Down

0 comments on commit 09c426a

Please sign in to comment.