Skip to content

Commit

Permalink
test(tensorflow-lite): added lite import from tensorflow library too
Browse files Browse the repository at this point in the history
  • Loading branch information
kshitijrajsharma committed Mar 27, 2024
1 parent d9df600 commit 3f6dff1
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions predictor/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
try:
import tflite_runtime.interpreter as tflite
except ImportError:
print(
"TFlite_runtime is not installed , Predictions with .tflite extension won't work"
)
print("TFlite_runtime is not installed.")
try:
from tensorflow import keras
from tensorflow import keras, lite

except ImportError:
print("Tensorflow is not installed , Predictions with .h5 or .tf won't work")

Expand Down Expand Up @@ -66,7 +65,11 @@ def run_prediction(
start = time.time()
print(f"Using : {checkpoint_path}")
if checkpoint_path.endswith(".tflite"):
interpreter = tflite.Interpreter(model_path=checkpoint_path)
try:
interpreter = tflite.Interpreter(model_path=checkpoint_path)
except Exception as ex:
interpreter = lite.Interpreter(model_path=checkpoint_path)

interpreter.resize_tensor_input(
interpreter.get_input_details()[0]["index"], (BATCH_SIZE, 256, 256, 3)
)
Expand Down

0 comments on commit 3f6dff1

Please sign in to comment.