[feat]: add build faiss index for easyrec processor (#445)
tiankongdeguiji authored Jan 24, 2024
1 parent 0b4b465 commit b08d324
Showing 4 changed files with 205 additions and 2 deletions.
77 changes: 77 additions & 0 deletions docs/source/export.md
@@ -98,3 +98,80 @@ pai -name easy_rec_ext -project algo_public
- --export_done_file: name of the export-done flag file; after the export finishes, a file with this name is created in the export directory to signal completion
- --clear_export: delete the old export directory before exporting
- --place_embedding_on_cpu: place embedding-related ops on the CPU, which helps improve inference speed when the model is served on GPU

For a two-tower retrieval model, you usually also need to split the model and build an index:

Model splitting

```sql
pai -name easy_rec_ext
-Dcmd='custom'
-DentryFile='easy_rec/python/tools/split_model_pai.py'
-Dversion='{easyrec_version}'
-Dbuckets='oss://{oss_bucket}/'
-Darn='{oss_arn}'
-DossHost='oss-{region}-internal.aliyuncs.com'
-Dcluster='{{
\\"worker\\": {{
\\"count\\": 1,
\\"cpu\\": 100
}}
}}'
-Dextra_params='--model_dir=oss://{oss_bucket}/dssm/export/final --user_model_dir=oss://{oss_bucket}/dssm/export/user --item_model_dir=oss://{oss_bucket}/dssm/export/item --user_fg_json_path=oss://{oss_bucket}/dssm/user_fg.json --item_fg_json_path=oss://{oss_bucket}/dssm/item_fg.json';
```

- -Dextra_params:
  - --model_dir: directory of the saved_model to be split
  - --user_model_dir: output directory for the split user-tower model
  - --item_model_dir: output directory for the split item-tower model
  - --user_fg_json_path: fg json of the user tower
  - --item_fg_json_path: fg json of the item tower
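
After the split finishes, it can be useful to sanity-check that each tower exports the signature you expect. The sketch below is not part of EasyRec; it is a minimal TF1-style check that assumes the split saved_models have been copied to local paths and use the default `serve` SavedModel tag.

```python
import tensorflow as tf

def inspect_signature(saved_model_dir):
  """Print the input/output tensors of a saved_model's serving_default signature."""
  with tf.Session(graph=tf.Graph()) as sess:
    meta_graph = tf.saved_model.loader.load(sess, ['serve'], saved_model_dir)
    sig = meta_graph.signature_def['serving_default']
    print('inputs :', sorted(sig.inputs.keys()))
    print('outputs:', sorted(sig.outputs.keys()))

# Hypothetical local copies of the split towers.
inspect_signature('./dssm_export/user')
inspect_signature('./dssm_export/item')
```

The item tower should expose the embedding output (e.g. `item_emb`) that the prediction step below writes out.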

Item embedding prediction and index building

```sql
pai -name easy_rec_ext
-Dcmd='predict'
-Dsaved_model_dir='oss://{oss_bucket}/dssm/export/item/'
-Dinput_table='odps://{project}/tables/item_feature_t'
-Doutput_table='odps://{project}/tables/dssm_item_embedding'
-Dreserved_cols='item_id'
-Doutput_cols='item_emb string'
-Dmodel_outputs='item_emb'
-Dbuckets='oss://{oss_bucket}/'
-Darn='{oss_arn}'
-DossHost='oss-{region}-internal.aliyuncs.com'
-Dcluster='{{
\\"worker\\": {{
\\"count\\": 16,
\\"cpu\\": 600,
\\"memory\\": 10000
}}
}}';
```
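
The prediction job writes one row per item: the reserved `item_id` column plus `item_emb`, a comma-separated string of floats. The snippet below only illustrates how such a row is parsed (the values are made up); it mirrors the parsing done by `faiss_index_pai.py` in the next step.

```python
# Hypothetical row from dssm_item_embedding: (item_id, item_emb).
row = ('item_123', '0.12,-0.34,0.56,0.78')

item_id, emb_str = row
embedding = [float(x) for x in emb_str.split(',')]  # same parsing as faiss_index_pai.py
print(item_id, len(embedding))
```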

```sql
pai -name easy_rec_ext
-Dcmd='custom'
-DentryFile='easy_rec/python/tools/faiss_index_pai.py'
-Dtables='odps://{project}/tables/dssm_item_embedding'
-Dbuckets='oss://{oss_bucket}/'
-Darn='{oss_arn}'
-DossHost='oss-{region}-internal.aliyuncs.com'
-Dcluster='{{
\\"worker\\": {{
\\"count\\": 1,
\\"cpu\\": 100
}}
}}'
-Dextra_params='--index_output_dir=oss://{oss_bucket}/dssm/export/user';
```

- -Dtables: item embedding table

- -Dextra_params:
  - --index_output_dir: index output directory; usually set to the split user-tower model directory so the index can be deployed together with EasyRec Processor
  - --index_type: index type, IVFFlat | HNSWFlat, default IVFFlat
  - --ivf_nlist: number of clusters (nlist) when the index type is IVFFlat
  - --hnsw_M: index parameter M when the index type is HNSWFlat
  - --hnsw_efConstruction: index parameter efConstruction when the index type is HNSWFlat
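
Once the index has been written to `--index_output_dir`, it can be loaded offline to spot-check retrieval quality before deployment. The sketch below is not part of EasyRec; it assumes local copies of the `faiss_index` and `id_mapping` files and a user-tower embedding of the same dimension, scored by inner product.

```python
import faiss
import numpy as np

# Load local copies of the files produced by faiss_index_pai.py (paths are assumptions).
index = faiss.read_index('faiss_index')
with open('id_mapping') as f:
  item_ids = [line.strip() for line in f]

# Tune the search-time parameter for whichever index type was built.
if isinstance(index, faiss.IndexIVFFlat):
  index.nprobe = 64            # number of clusters probed at search time
elif isinstance(index, faiss.IndexHNSWFlat):
  index.hnsw.efSearch = 64     # HNSW search-time candidate list size

# A dummy user embedding; in practice this comes from the user-tower saved_model.
user_emb = np.random.rand(1, index.d).astype('float32')
scores, indices = index.search(user_emb, 10)  # top-10 inner-product neighbors
print([(item_ids[i], float(s)) for i, s in zip(indices[0], scores[0]) if i >= 0])
```
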
112 changes: 112 additions & 0 deletions easy_rec/python/tools/faiss_index_pai.py
@@ -0,0 +1,112 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from __future__ import print_function

import logging
import os

import faiss
import numpy as np
import tensorflow as tf

logging.basicConfig(
level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')

tf.app.flags.DEFINE_string('tables', '', 'tables passed by pai command')
tf.app.flags.DEFINE_integer('batch_size', 1024, 'batch size')
tf.app.flags.DEFINE_integer('embedding_dim', 32, 'embedding dimension')
tf.app.flags.DEFINE_string('index_output_dir', '', 'index output directory')
tf.app.flags.DEFINE_string('index_type', 'IVFFlat', 'index type')
tf.app.flags.DEFINE_integer('ivf_nlist', 1000, 'nlist')
tf.app.flags.DEFINE_integer('hnsw_M', 32, 'hnsw M')
tf.app.flags.DEFINE_integer('hnsw_efConstruction', 200, 'hnsw efConstruction')
tf.app.flags.DEFINE_integer('debug', 0, 'debug index')

FLAGS = tf.app.flags.FLAGS


def main(argv):
reader = tf.python_io.TableReader(
FLAGS.tables, slice_id=0, slice_count=1, capacity=FLAGS.batch_size * 2)
i = 0
id_map_f = tf.gfile.GFile(
os.path.join(FLAGS.index_output_dir, 'id_mapping'), 'w')
embeddings = []
while True:
try:
records = reader.read(FLAGS.batch_size)
for j, record in enumerate(records):
if isinstance(record[0], bytes):
eid = record[0].decode('utf-8')
id_map_f.write('%s\n' % eid)

embeddings.extend(
[list(map(float, record[1].split(b','))) for record in records])
i += 1
if i % 100 == 0:
logging.info('read %d embeddings.' % (i * FLAGS.batch_size))
except tf.python_io.OutOfRangeException:
break
reader.close()
id_map_f.close()

logging.info('Building faiss index..')
if FLAGS.index_type == 'IVFFlat':
quantizer = faiss.IndexFlatIP(FLAGS.embedding_dim)
index = faiss.IndexIVFFlat(quantizer, FLAGS.embedding_dim, FLAGS.ivf_nlist,
faiss.METRIC_INNER_PRODUCT)
elif FLAGS.index_type == 'HNSWFlat':
index = faiss.IndexHNSWFlat(FLAGS.embedding_dim, FLAGS.hnsw_M,
faiss.METRIC_INNER_PRODUCT)
index.hnsw.efConstruction = FLAGS.hnsw_efConstruction
else:
raise NotImplementedError

embeddings = np.array(embeddings)
if FLAGS.index_type == 'IVFFlat':
logging.info('train embeddings...')
index.train(embeddings)

logging.info('build embeddings...')
index.add(embeddings)
faiss.write_index(index, 'faiss_index')

with tf.gfile.GFile(
os.path.join(FLAGS.index_output_dir, 'faiss_index'), 'wb') as f_out:
with open('faiss_index', 'rb') as f_in:
f_out.write(f_in.read())

if FLAGS.debug != 0:
# IVFFlat
for ivf_nlist in [100, 500, 1000, 2000]:
quantizer = faiss.IndexFlatIP(FLAGS.embedding_dim)
index = faiss.IndexIVFFlat(quantizer, FLAGS.embedding_dim, ivf_nlist,
faiss.METRIC_INNER_PRODUCT)
index.train(embeddings)
index.add(embeddings)
index_name = 'faiss_index_ivfflat_nlist%d' % ivf_nlist
faiss.write_index(index, index_name)
with tf.gfile.GFile(
os.path.join(FLAGS.index_output_dir, index_name), 'wb') as f_out:
with open(index_name, 'rb') as f_in:
f_out.write(f_in.read())

# HNSWFlat
for hnsw_M in [16, 32, 64, 128]:
for hnsw_efConstruction in [64, 128, 256, 512, 1024, 2048, 4096, 8196]:
if hnsw_efConstruction < hnsw_M * 2:
continue
index = faiss.IndexHNSWFlat(FLAGS.embedding_dim, hnsw_M,
faiss.METRIC_INNER_PRODUCT)
index.hnsw.efConstruction = hnsw_efConstruction
index.add(embeddings)
index_name = 'faiss_index_hnsw_M%d_ef%d' % (hnsw_M, hnsw_efConstruction)
faiss.write_index(index, index_name)
with tf.gfile.GFile(
os.path.join(FLAGS.index_output_dir, index_name), 'wb') as f_out:
with open(index_name, 'rb') as f_in:
f_out.write(f_in.read())


if __name__ == '__main__':
tf.app.run()
16 changes: 15 additions & 1 deletion pai_jobs/deploy_ext.sh
@@ -125,7 +125,21 @@ then
rm -rf kafka.tar.gz
fi

tar -cvzhf $RES_PATH easy_rec datahub lz4 cprotobuf kafka run.py
if [ ! -d "faiss" ]
then
if [ ! -e "faiss.tar.gz" ]
then
wget http://easyrec.oss-cn-beijing.aliyuncs.com/third_party/faiss.tar.gz
if [ $? -ne 0 ]
then
echo "faiss download failed."
fi
fi
tar -zvxf faiss.tar.gz
rm -rf faiss.tar.gz
fi

tar -cvzhf $RES_PATH easy_rec datahub lz4 cprotobuf kafka faiss run.py

# 2 means generate only
if [ $mode -ne 2 ]
2 changes: 1 addition & 1 deletion setup.cfg
@@ -10,7 +10,7 @@ multi_line_output = 7
force_single_line = true
known_standard_library = setuptools
known_first_party = easy_rec
known_third_party = absl,common_io,docutils,eas_prediction,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sparse_operation_kit,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml
known_third_party = absl,common_io,distutils,docutils,eas_prediction,faiss,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sparse_operation_kit,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml
no_lines_before = LOCALFOLDER
default_section = THIRDPARTY
skip = easy_rec/python/protos
