Skip to content

Commit

Permalink
added support to pass cfg dict from cmd to pytorch loadgen
Browse files Browse the repository at this point in the history
  • Loading branch information
gfursin committed Apr 8, 2024
1 parent 292bfcf commit a8a5f65
Show file tree
Hide file tree
Showing 10 changed files with 27 additions and 0 deletions.
1 change: 1 addition & 0 deletions cm-mlops/script/app-loadgen-generic-python/_cm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ input_mapping:
modelpath: CM_ML_MODEL_FILE_WITH_PATH
modelcodepath: CM_ML_MODEL_CODE_WITH_PATH
modelcfgpath: CM_ML_MODEL_CFG_WITH_PATH
modelcfg: CM_ML_MODEL_CFG
modelsamplepath: CM_ML_MODEL_SAMPLE_WITH_PATH
output_dir: CM_MLPERF_OUTPUT_DIR
scenario: CM_MLPERF_LOADGEN_SCENARIO
Expand Down
26 changes: 26 additions & 0 deletions cm-mlops/script/app-loadgen-generic-python/customize.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def preprocess(i):

env = i['env']


if 'CM_ML_MODEL_FILE_WITH_PATH' not in env:
return {'return': 1, 'error': 'Please select a variation specifying the model to run'}

Expand Down Expand Up @@ -51,8 +52,27 @@ def preprocess(i):
if env.get('CM_ML_MODEL_CODE_WITH_PATH', '') != '':
run_opts +=" --model_code "+env['CM_ML_MODEL_CODE_WITH_PATH']


if env.get('CM_ML_MODEL_CFG_WITH_PATH', '') != '':
run_opts +=" --model_cfg "+env['CM_ML_MODEL_CFG_WITH_PATH']
else:
# Check cfg from command line
cfg = env.get('CM_ML_MODEL_CFG', {})
if len(cfg)>0:
del (env['CM_ML_MODEL_CFG'])

import json, tempfile
tfile = tempfile.NamedTemporaryFile(mode="w+", suffix='.json')

fd, tfile = tempfile.mkstemp(suffix='.json', prefix='cm-cfg-')
os.close(fd)

with open(tfile, 'w') as fd:
json.dump(cfg, fd)

env['CM_APP_LOADGEN_GENERIC_PYTHON_TMP_CFG_FILE'] = tfile

run_opts +=" --model_cfg " + tfile

if env.get('CM_ML_MODEL_SAMPLE_WITH_PATH', '') != '':
run_opts +=" --model_sample_pickle "+env['CM_ML_MODEL_SAMPLE_WITH_PATH']
Expand All @@ -72,4 +92,10 @@ def preprocess(i):
def postprocess(i):

env = i['env']

tfile = env.get('CM_APP_LOADGEN_GENERIC_PYTHON_TMP_CFG_FILE', '')

if tfile!='' and os.path.isfile(tfile):
os.remove(tfile)

return {'return':0}

0 comments on commit a8a5f65

Please sign in to comment.