Skip to content

Commit

Permalink
[feature] add support for export_done_file as export done signal (#447)
Browse files Browse the repository at this point in the history
* add support for export_done_file as export done signal
* add test cases and docs for export done
* fix predictor bug
  • Loading branch information
chengmengli06 committed Jan 11, 2024
1 parent 15e983c commit d443bca
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 16 deletions.
8 changes: 7 additions & 1 deletion docs/source/export.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,14 @@ export_config {
#### Local

```bash
python -m easy_rec.python.export --pipeline_config_path dwd_avazu_ctr_deepmodel.config --export_dir ./export
python -m easy_rec.python.export --pipeline_config_path dwd_avazu_ctr_deepmodel.config --export_dir ./export --export_done_file EXPORT_DONE
```

- --pipeline_config_path: config文件路径
- --model_dir: 如果指定了model_dir将会覆盖config里面的model_dir,一般在周期性调度的时候使用
- --export_dir: 导出的目录
- --export_done_file: 导出完成标志文件名, 导出完成后,在导出目录下创建一个文件表示导出完成了
- --clear_export: 删除旧的导出文件目录

#### PAI

Expand All @@ -92,3 +94,7 @@ pai -name easy_rec_ext -project algo_public
- -Dbuckets: config所在的bucket和保存模型的bucket; 如果有多个bucket,逗号分割
- 如果是pai内部版,则不需要指定arn和ossHost, arn和ossHost放在-Dbuckets里面
- -Dbuckets=oss://easyrec/?role_arn=acs:ram::xxx:role/ev-ext-test-oss&host=oss-cn-beijing-internal.aliyuncs.com
- -Dextra_params: 其它参数, 没有在pai -name easy_rec_ext中定义的参数, 可以通过extra_params传入, 如:
- --export_done_file: 导出完成标志文件名, 导出完成后,在导出目录下创建一个文件表示导出完成了
- --clear_export: 删除旧的导出文件目录
- --place_embedding_on_cpu: 将embedding相关的操作放在cpu上,有助于提升模型在gpu环境下的推理速度
21 changes: 19 additions & 2 deletions easy_rec/python/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import tensorflow as tf
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import gfile

from easy_rec.python.main import export
from easy_rec.python.protos.train_pb2 import DistributionStrategy
Expand Down Expand Up @@ -57,6 +58,10 @@

tf.app.flags.DEFINE_string('model_dir', None, help='will update the model_dir')
tf.app.flags.mark_flag_as_required('export_dir')

tf.app.flags.DEFINE_bool('clear_export', False, 'remove export_dir if exists')
tf.app.flags.DEFINE_string('export_done_file', '',
'a flag file to signal that export model is done')
FLAGS = tf.app.flags.FLAGS


Expand Down Expand Up @@ -121,8 +126,20 @@ def main(argv):
estimator_utils.init_hvd()
estimator_utils.init_sok()

export(FLAGS.export_dir, pipeline_config_path, FLAGS.checkpoint_path,
FLAGS.asset_files, FLAGS.verbose, **extra_params)
if FLAGS.clear_export:
logging.info('will clear export_dir=%s' % FLAGS.export_dir)
if gfile.IsDirectory(FLAGS.export_dir):
gfile.DeleteRecursively(FLAGS.export_dir)

export_out_dir = export(FLAGS.export_dir, pipeline_config_path,
FLAGS.checkpoint_path, FLAGS.asset_files,
FLAGS.verbose, **extra_params)

if FLAGS.export_done_file:
flag_file = os.path.join(export_out_dir, FLAGS.export_done_file)
logging.info('create export done file: %s' % flag_file)
with gfile.GFile(flag_file, 'w') as fout:
fout.write('ExportDone')


if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion easy_rec/python/inference/odps_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import tensorflow as tf

from easy_rec.python.inference import Predictor
from easy_rec.python.inference.predictor import Predictor


class ODPSPredictor(Predictor):
Expand Down
3 changes: 1 addition & 2 deletions easy_rec/python/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import six
import tensorflow as tf
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.platform import gfile

import easy_rec
from easy_rec.python.builders import strategy_builder
Expand Down Expand Up @@ -41,15 +42,13 @@
hvd = None

if tf.__version__ >= '2.0':
gfile = tf.compat.v1.gfile
from tensorflow.core.protobuf import config_pb2

ConfigProto = config_pb2.ConfigProto
GPUOptions = config_pb2.GPUOptions

tf = tf.compat.v1
else:
gfile = tf.gfile
GPUOptions = tf.GPUOptions
ConfigProto = tf.ConfigProto

Expand Down
8 changes: 3 additions & 5 deletions easy_rec/python/test/export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,14 @@

import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile

import easy_rec
from easy_rec.python.inference.predictor import Predictor
from easy_rec.python.utils import config_util
from easy_rec.python.utils import test_utils
from easy_rec.python.utils.test_utils import RunAsSubprocess

if tf.__version__ >= '2.0':
gfile = tf.compat.v1.gfile
else:
gfile = tf.gfile


class ExportTest(tf.test.TestCase):

Expand Down Expand Up @@ -119,6 +115,7 @@ def test_export_with_asset(self):
--pipeline_config_path %s
--export_dir %s
--asset_files fg.json:samples/model_config/taobao_fg.json
--export_done_file ExportDone
""" % (
config_path,
export_dir,
Expand All @@ -131,6 +128,7 @@ def test_export_with_asset(self):
export_dir = files[0]
assert gfile.Exists(export_dir + '/assets/taobao_fg.json')
assert gfile.Exists(export_dir + '/assets/pipeline.config')
assert gfile.Exists(export_dir + '/ExportDone')

def test_export_with_out_in_ckpt_config(self):
test_dir = test_utils.get_tmp_dir()
Expand Down
6 changes: 4 additions & 2 deletions easy_rec/python/utils/export_big_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ def export_big_model(export_dir, pipeline_config, redis_params,
export_dir = os.path.join(export_dir,
meta_graph_def.meta_info_def.meta_graph_version)
export_dir = io_util.fix_oss_dir(export_dir)
logging.info('export_dir=%s' % export_dir)
if Exists(export_dir):
logging.info('will delete old dir: %s' % export_dir)
DeleteRecursively(export_dir)
Expand Down Expand Up @@ -304,7 +305,7 @@ def export_big_model(export_dir, pipeline_config, redis_params,

# remove temporary files
Remove(embed_name_to_id_file)
return
return export_dir


def export_big_model_to_oss(export_dir, pipeline_config, oss_params,
Expand Down Expand Up @@ -553,6 +554,7 @@ def export_big_model_to_oss(export_dir, pipeline_config, oss_params,
export_dir = os.path.join(export_dir,
meta_graph_def.meta_info_def.meta_graph_version)
export_dir = io_util.fix_oss_dir(export_dir)
logging.info('export_dir=%s' % export_dir)
if Exists(export_dir):
logging.info('will delete old dir: %s' % export_dir)
DeleteRecursively(export_dir)
Expand Down Expand Up @@ -625,4 +627,4 @@ def export_big_model_to_oss(export_dir, pipeline_config, oss_params,

# remove temporary files
Remove(embed_name_to_id_file)
return
return export_dir
14 changes: 11 additions & 3 deletions pai_jobs/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from tensorflow.python.platform import gfile

import easy_rec
from easy_rec.python.inference.predictor import ODPSPredictor
from easy_rec.python.inference.odps_predictor import ODPSPredictor
from easy_rec.python.inference.vector_retrieve import VectorRetrieve
from easy_rec.python.tools.pre_check import run_check
from easy_rec.python.utils import config_util
Expand Down Expand Up @@ -110,6 +110,8 @@
tf.app.flags.DEFINE_string('export_dir', '',
'directory where model should be exported to')
tf.app.flags.DEFINE_bool('clear_export', False, 'remove export_dir if exists')
tf.app.flags.DEFINE_string('export_done_file', '',
'a flag file to signal that export model is done')
tf.app.flags.DEFINE_integer('max_wait_ckpt_ts', 0,
'max wait time in seconds for checkpoints')
tf.app.flags.DEFINE_boolean('continue_train', True,
Expand Down Expand Up @@ -495,8 +497,14 @@ def main(argv):

extra_params = redis_params
extra_params.update(oss_params)
easy_rec.export(export_dir, pipeline_config, FLAGS.checkpoint_path,
FLAGS.asset_files, FLAGS.verbose, **extra_params)
export_out_dir = easy_rec.export(export_dir, pipeline_config,
FLAGS.checkpoint_path, FLAGS.asset_files,
FLAGS.verbose, **extra_params)
if FLAGS.export_done_file:
flag_file = os.path.join(export_out_dir, FLAGS.export_done_file)
logging.info('create export done file: %s' % flag_file)
with gfile.GFile(flag_file, 'w') as fout:
fout.write('ExportDone')
elif FLAGS.cmd == 'predict':
check_param('tables')
check_param('saved_model_dir')
Expand Down

0 comments on commit d443bca

Please sign in to comment.