-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathinstructor.py
executable file
·47 lines (31 loc) · 1.22 KB
/
instructor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# Export the INSTRUCTOR embedding model from PyTorch to a TensorFlow Lite
# flatbuffer: nobuco converts the PyTorch module to a Keras model, then the TF2
# TFLite converter serializes that Keras model to `<model_path>.tflite`.
import os

# Hide every CUDA device so the conversion runs on CPU. This has to be set
# before torch / tensorflow are imported, which is why it sits between imports.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

import nobuco
from nobuco import ChannelOrder, ChannelOrderingStrategy
from nobuco.layers.weight import WeightLayer
import tensorflow as tf
from tensorflow.lite.python.lite import TFLiteConverterV2
import keras
from InstructorEmbedding import INSTRUCTOR
import torch
from torch import nn

# Conversion settings: edit these to export another checkpoint or to trace the
# graph with a different dummy input size.
MODEL_NAME = 'hkunlp/instructor-large'
BATCH_SIZE = 2
SEQ_LEN = 16

device = 'cpu'
instructor = INSTRUCTOR(MODEL_NAME).eval().to(device)

# Dummy token inputs used only to trace the model graph (values are irrelevant;
# zeros for ids, ones for the mask). NOTE(review): the exported model may be
# specialized to (BATCH_SIZE, SEQ_LEN) — confirm whether nobuco keeps these
# dimensions dynamic before running the .tflite on other sizes.
input_ids = torch.zeros(size=(BATCH_SIZE, SEQ_LEN), dtype=torch.int64)
attention_mask = torch.ones(size=(BATCH_SIZE, SEQ_LEN), dtype=torch.int64)

keras_model = nobuco.pytorch_to_keras(
    instructor,
    # The model's forward receives a single positional argument: a dict of
    # tensors keyed by input name, so `args` holds exactly one dict.
    args=[{'input_ids': input_ids, 'attention_mask': attention_mask}],
    inputs_channel_order=ChannelOrder.TENSORFLOW,
)

model_path = 'instructor'

# Optional: persist and reload the intermediate Keras model. Reloading needs
# nobuco's WeightLayer passed via custom_objects.
# keras_model.save(model_path + '.keras')
# print('Model saved')
# custom_objects = {'WeightLayer': WeightLayer}
# keras_model_restored = keras.models.load_model(model_path + '.keras', custom_objects=custom_objects)
# print('Model loaded')

converter = TFLiteConverterV2.from_keras_model(keras_model)
# The TF2 converter reads allowed op sets from `target_spec.supported_ops`.
# Assigning the TF1-era `converter.target_ops` attribute has no effect, which
# left SELECT_TF_OPS disabled. SELECT_TF_OPS lets ops without a TFLite builtin
# fall back to TensorFlow kernels (the Flex delegate is then needed at runtime).
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
tflite_model = converter.convert()

# convert() returns the serialized flatbuffer as bytes; write it out verbatim.
with open(model_path + '.tflite', 'wb') as f:
    f.write(tflite_model)