Skip to content

Commit

Permalink
add faiss index pai doc
Browse files Browse the repository at this point in the history
  • Loading branch information
tiankongdeguiji committed Jan 23, 2024
1 parent b2aaee1 commit 009fef5
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 6 deletions.
77 changes: 77 additions & 0 deletions docs/source/export.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,80 @@ pai -name easy_rec_ext -project algo_public
- --export_done_file: 导出完成标志文件名, 导出完成后,在导出目录下创建一个文件表示导出完成了
- --clear_export: 删除旧的导出文件目录
- --place_embedding_on_cpu: 将embedding相关的操作放在cpu上,有助于提升模型在gpu环境下的推理速度

如果是双塔召回模型一般还需要进行模型切分和索引构建:

模型切分

```sql
pai -name easy_rec_ext
-Dcmd='custom'
-DentryFile='easy_rec/python/tools/split_model_pai.py'
-Dversion='{easyrec_version}'
-Dbuckets='oss://{oss_bucket}/'
-Darn='{oss_arn}'
-DossHost='oss-{region}-internal.aliyuncs.com'
-Dcluster='{{
\\"worker\\": {{
\\"count\\": 1,
\\"cpu\\": 100
}}
}}'
-Dextra_params='--model_dir=oss://{oss_bucket}/dssm/export/final --user_model_dir=oss://{oss_bucket}/dssm/export/user --item_model_dir=oss://{oss_bucket}/dssm/export/item --user_fg_json_path=oss://{oss_bucket}/dssm/user_fg.json --item_fg_json_path=oss://{oss_bucket}/dssm/item_fg.json';
```

- -Dextra_params:
- --model_dir: 待切分的saved_model目录
- --user_model_dir: 切分好的用户塔模型目录
- --item_model_dir: 切分好的物品塔模型目录
- --user_fg_json_path: 用户塔的fg json
- --item_fg_json_path: 物品塔的fg json

物品Emebdding预测和索引构建

```sql
pai -name easy_rec_ext
-Dcmd='predict'
-Dsaved_model_dir='oss://{oss_bucket}/dssm/export/item /'
-Dinput_table='odps://{project}/tables/item_feature_t'
-Doutput_table='odps://{project}/tables/dssm_item_embedding'
-Dreserved_cols='item_id'
-Doutput_cols='item_emb string'
-Dmodel_outputs='item_emb'
-Dbuckets='oss://{oss_bucket}/'
-Darn='{oss_arn}'
-DossHost='oss-{region}-internal.aliyuncs.com'
-Dcluster='{{
\\"worker\\": {{
\\"count\\": 16,
\\"cpu\\": 600,
\\"memory\\": 10000
}}
}}';
```

```sql
pai -name easy_rec_ext
-Dcmd='custom'
-DentryFile='easy_rec/python/tools/faiss_index_pai.py'
-Dtables='odps://{project}/tables/dssm_item_embedding'
-Dbuckets='oss://{oss_bucket}/'
-Darn='{oss_arn}'
-DossHost='oss-{region}-internal.aliyuncs.com'
-Dcluster='{{
\\"worker\\": {{
\\"count\\": 1,
\\"cpu\\": 100
}}
}}'
-Dextra_params='--index_output_dir=oss://{oss_bucket}/dssm/export/user';
```

-Dtables: 物品向量表

- -Dextra_params:
- --index_output_dir: 索引输出目录, 一般设置为已切分好的用户塔模型目录,便于用EasyRec Processor部署
- --index_type: 索引类型,可选 IVFFlat | HNSWFlat,默认为 IVFFlat
- --ivf_nlist: 索引类型为IVFFlat是,聚簇的数目
- --hnsw_M: 索引类型为HNSWFlat的索引参数M
- --hnsw_efConstruction: 索引类型为HNSWFlat的索引参数efConstruction
12 changes: 6 additions & 6 deletions easy_rec/python/tools/faiss_index_pai.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
tf.app.flags.DEFINE_string('tables', '', 'tables passed by pai command')
tf.app.flags.DEFINE_integer('batch_size', 1024, 'batch size')
tf.app.flags.DEFINE_integer('embedding_dim', 32, 'embedding dimension')
tf.app.flags.DEFINE_string('user_model_path', '', 'user model path')
tf.app.flags.DEFINE_string('index_output_dir', '', 'index output directory')
tf.app.flags.DEFINE_string('index_type', 'IVFFlat', 'index type')
tf.app.flags.DEFINE_integer('ivf_nlist', 1000, 'nlist')
tf.app.flags.DEFINE_integer('hnsw_M', 32, 'hnsw M')
Expand All @@ -30,7 +30,7 @@ def main(argv):
FLAGS.tables, slice_id=0, slice_count=1, capacity=FLAGS.batch_size * 2)
i = 0
id_map_f = tf.gfile.GFile(
os.path.join(FLAGS.user_model_path, 'id_mapping'), 'w')
os.path.join(FLAGS.index_output_dir, 'id_mapping'), 'w')
embeddings = []
while True:
try:
Expand Down Expand Up @@ -71,8 +71,8 @@ def main(argv):
index.add(embeddings)
faiss.write_index(index, 'faiss_index')

with tf.gfile.GFile(os.path.join(FLAGS.user_model_path, 'faiss_index'),
'wb') as f_out:
with tf.gfile.GFile(
os.path.join(FLAGS.index_output_dir, 'faiss_index'), 'wb') as f_out:
with open('faiss_index', 'rb') as f_in:
f_out.write(f_in.read())

Expand All @@ -87,7 +87,7 @@ def main(argv):
index_name = 'faiss_index_ivfflat_nlist%d' % ivf_nlist
faiss.write_index(index, index_name)
with tf.gfile.GFile(
os.path.join(FLAGS.user_model_path, index_name), 'wb') as f_out:
os.path.join(FLAGS.index_output_dir, index_name), 'wb') as f_out:
with open(index_name, 'rb') as f_in:
f_out.write(f_in.read())

Expand All @@ -103,7 +103,7 @@ def main(argv):
index_name = 'faiss_index_hnsw_M%d_ef%d' % (hnsw_M, hnsw_efConstruction)
faiss.write_index(index, index_name)
with tf.gfile.GFile(
os.path.join(FLAGS.user_model_path, index_name), 'wb') as f_out:
os.path.join(FLAGS.index_output_dir, index_name), 'wb') as f_out:
with open(index_name, 'rb') as f_in:
f_out.write(f_in.read())

Expand Down

0 comments on commit 009fef5

Please sign in to comment.