Skip to content

Commit

Permalink
[feature]: add py3 test (#154)
Browse files Browse the repository at this point in the history
* add py3 test, and fix the py37 test failure

* update git lfs to retry using the accelerate endpoint when a download fails
  • Loading branch information
chengmengli06 authored Apr 18, 2022
1 parent be2bf4e commit f4089e5
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 10 deletions.
1 change: 1 addition & 0 deletions .git_oss_config_pub
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ git_oss_data_dir = data/git_oss_sample_data
host = oss-cn-beijing.aliyuncs.com
git_oss_cache_dir = ${TMPDIR}/${PROJECT_NAME}/.git_oss_cache
git_oss_private_config = ~/.git_oss_config_private
accl_endpoint = oss-accelerate.aliyuncs.com
115 changes: 115 additions & 0 deletions .github/workflows/ci_py3.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# CI workflow for Python 3: runs the test script on every pull request update,
# then labels/comments the pull request with the result and fails the job when
# the tests did not pass.
name: CI Build PY3
on:
  pull_request:
    types: [opened, reopened, synchronize]

jobs:
  ci-test:
    # Runner label; presumably a self-hosted runner -- confirm with the repo
    # settings.
    runs-on: EasyRec-py3-15
    defaults:
      run:
        # Custom shell template: scripts run as plain `bash <file>`, i.e.
        # without the `-e -o pipefail` flags that the built-in `bash` shell
        # adds, so a failing command does not abort a step by itself. The
        # pass/fail decision is made explicitly in the SignalFail step.
        shell: bash {0}
    steps:
      # Check out exactly the head commit of the pull request.
      - name: FetchCommit ${{ github.event.pull_request.head.sha }}
        uses: actions/checkout@v2
        with:
          ref: ${{ github.event.pull_request.head.sha }}
          submodules: recursive
      # Fetch the test data and run the tests. The `ci_test_passed` step
      # output read by the later steps is expected to be set by
      # scripts/ci_test.sh (not shown here -- verify).
      - name: RunCiTest
        id: run_ci_test
        env:
          TEST_DEVICES: ""
          PULL_REQUEST_NUM: ${{ github.event.pull_request.number }}
        run: |
          source activate tf15_py3
          python git-lfs/git_lfs.py pull
          source scripts/ci_test.sh
      # Replace any previous result label on the pull request with the
      # current one and post a comment with the outcome.
      - name: LabelAndComment
        env:
          CI_TEST_PASSED: ${{steps.run_ci_test.outputs.ci_test_passed}}
        uses: actions/github-script@v5
        with:
          script: |
            const { CI_TEST_PASSED } = process.env
            // Coordinates shared by every issues API call below.
            const issue = {
              issue_number: context.issue.number,
              owner: context.repo.owner,
              repo: context.repo.repo
            }
            const response = await github.rest.issues.listLabelsOnIssue(issue)
            console.log('labels.url=' + response.url)
            const label_names = (response.data || []).map(label => label.name)
            console.log(`ci_test_passed=${CI_TEST_PASSED} labels=${label_names}`);
            // Remove stale result labels first and wait for each removal to
            // finish: otherwise removing 'ci_py3_test_passed' can race with
            // re-adding the very same label below, and API errors would be
            // lost as unhandled promise rejections.
            for (const stale of ['ci_py3_test_passed', 'ci_py3_test_failed']) {
              if (label_names.includes(stale)) {
                await github.rest.issues.removeLabel({ ...issue, name: stale })
              }
            }
            // Loose comparison on purpose: the environment value is a string.
            const passed = CI_TEST_PASSED == 1
            await github.rest.issues.addLabels({
              ...issue,
              labels: [passed ? 'ci_py3_test_passed' : 'ci_py3_test_failed']
            })
            await github.rest.issues.createComment({
              ...issue,
              body: passed ? "CI PY3 Test Passed" : "CI PY3 Test Failed"
            })
      # Turn a failed (or missing) test result into a failed job.
      - name: SignalFail
        env:
          CI_TEST_PASSED: ${{steps.run_ci_test.outputs.ci_test_passed}}
        run: |
          echo "CI_TEST_PASSED=${CI_TEST_PASSED}"
          # Quoted string comparison: with an unquoted numeric test an empty
          # value makes `[` itself error out, the `then` branch is skipped and
          # the step would pass even though no result was reported.
          if [ "${CI_TEST_PASSED}" != "1" ]
          then
            echo "ci_py3_test_failed, will exit"
            exit 1
          fi
2 changes: 1 addition & 1 deletion easy_rec/python/model/multi_tower_bst.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def attention_net(self, net, dim, cur_seq_len, seq_size, name):

hist_mask = tf.sequence_mask(
cur_seq_len, maxlen=seq_size - 1) # [B, seq_size-1]
cur_id_mask = tf.ones([tf.shape(hist_mask)[0], 1], dtype=tf.bool) # [B, 1]
cur_id_mask = tf.ones(tf.stack([tf.shape(hist_mask)[0], 1]), dtype=tf.bool) # [B, 1]
mask = tf.concat([hist_mask, cur_id_mask], axis=1) # [B, seq_size]
masks = tf.reshape(tf.tile(mask, [1, seq_size]),
(-1, seq_size, seq_size)) # [B, seq_size, seq_size]
Expand Down
13 changes: 13 additions & 0 deletions easy_rec/python/test/train_eval_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import glob
import logging
import os
import sys
import unittest
from distutils.version import LooseVersion

Expand Down Expand Up @@ -255,24 +256,36 @@ def test_metric_learning(self):
'samples/model_config/metric_learning_on_taobao.config', self._test_dir)
self.assertTrue(self._success)

@unittest.skipIf(sys.version_info[:2] > (3, 6),
                 'Currently graph-learn not support python3.7')
def test_dssm_neg_sampler(self):
  """Run train/eval with the dssm_neg_sampler_on_taobao config.

  The test is skipped when the interpreter's (major, minor) version is
  greater than (3, 6); otherwise the run must report success.
  """
  config_path = 'samples/model_config/dssm_neg_sampler_on_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path,
                                                    self._test_dir)
  self.assertTrue(self._success)

@unittest.skipIf(sys.version_info[:2] > (3, 6),
                 'Currently graph-learn not support python3.7')
def test_dssm_neg_sampler_v2(self):
  """Run train/eval with the dssm_neg_sampler_v2_on_taobao config.

  The test is skipped when the interpreter's (major, minor) version is
  greater than (3, 6); otherwise the run must report success.
  """
  config_path = 'samples/model_config/dssm_neg_sampler_v2_on_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path,
                                                    self._test_dir)
  self.assertTrue(self._success)

@unittest.skipIf(sys.version_info[:2] > (3, 6),
                 'Currently graph-learn not support python3.7')
def test_dssm_hard_neg_sampler(self):
  """Run train/eval with the dssm_hard_neg_sampler_on_taobao config.

  The test is skipped when the interpreter's (major, minor) version is
  greater than (3, 6); otherwise the run must report success.
  """
  config_path = 'samples/model_config/dssm_hard_neg_sampler_on_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path,
                                                    self._test_dir)
  self.assertTrue(self._success)

@unittest.skipIf((sys.version_info.major, sys.version_info.minor) > (3,6),
'Currently graph-learn not support python3.7'
)
def test_dssm_hard_neg_sampler_v2(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/dssm_hard_neg_sampler_v2_on_taobao.config',
Expand Down
50 changes: 41 additions & 9 deletions git-lfs/git_lfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@ def get_yes_no(msg):
host = None
bucket_name = None
git_oss_private_path = None
enable_accelerate = 0
accl_endpoint = None
for line_str in fin:
line_str = line_str.strip()
if len(line_str) == 0:
Expand All @@ -248,6 +250,8 @@ def get_yes_no(msg):
git_oss_private_path = os.path.join(os.environ['HOME'], git_oss_private_path[2:])
elif line_tok[0] == 'git_oss_cache_dir':
git_oss_cache_dir = line_tok[1]
elif line_tok[0] == 'accl_endpoint':
accl_endpoint = line_tok[1]

logging.info('git_oss_data_dir=%s, host=%s, bucket_name=%s' % (
git_oss_data_dir, host, bucket_name))
Expand Down Expand Up @@ -353,15 +357,43 @@ def get_yes_no(msg):
remote_path = git_bin_url[leaf_path][1]
_, file_name_with_sig = os.path.split(remote_path)
tar_tmp_path = '%s/%s.tar.gz' % (git_oss_cache_dir, file_name_with_sig)
if not os.path.exists(tar_tmp_path):
if oss_bucket:
oss_bucket.get_object_to_file(remote_path, tar_tmp_path)
else:
url = 'http://%s.%s/%s' % (bucket_name, host, remote_path)
subprocess.check_output(['wget', url, '-O', tar_tmp_path])
else:
logging.info('%s is in cache' % file_name_with_sig)
subprocess.check_output(['tar', '-zxf', tar_tmp_path])

max_retry = 5
while max_retry > 0:
try:
if not os.path.exists(tar_tmp_path):
in_cache = False
if oss_bucket:
oss_bucket.get_object_to_file(remote_path, tar_tmp_path)
else:
url = 'http://%s.%s/%s' % (bucket_name, host, remote_path)
subprocess.check_output(['wget', url, '-O', tar_tmp_path])
else:
in_cache = True
logging.info('%s is in cache' % file_name_with_sig)
subprocess.check_output(['tar', '-zxf', tar_tmp_path])
local_sig = get_local_sig(leaf_files)
if local_sig == remote_sig:
break
if in_cache:
logging.warning('cache invalid, will download from remote')
os.remove(tar_tmp_path)
continue
logging.warning('download failed, local_sig(%s) != remote_sig(%s)' % (
local_sig, remote_sig))
except subprocess.CalledProcessError as ex:
logging.error("exception: %s" % str(ex))
except oss2.exceptions.RequestError as ex:
logging.error("exception: %s" % str(ex))

os.remove(tar_tmp_path)
if accl_endpoint is not None and host != accl_endpoint:
logging.info('will try accelerate endpoint: %s' % accl_endpoint)
host = accl_endpoint
if oss_auth:
oss_bucket = oss2.Bucket(oss_auth, host, bucket_name)
max_retry -= 1

logging.info('%s updated' % leaf_path)
any_update = True
if not any_update:
Expand Down

0 comments on commit f4089e5

Please sign in to comment.