Skip to content

Commit

Permalink
Merge branch 'master' into feature/multi_task_add_dynamic_weight
Browse files Browse the repository at this point in the history
  • Loading branch information
chengaofei committed Sep 10, 2024
2 parents 2bad4c8 + cbec539 commit ccf1862
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 82 deletions.
78 changes: 0 additions & 78 deletions .github/workflows/code_style.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,85 +36,7 @@ jobs:
echo "ci_test_passed=0" >> $GITHUB_OUTPUT
fi
fi
- name: LabelAndComment
env:
CI_TEST_PASSED: ${{steps.run_ci_test.outputs.ci_test_passed}}
uses: actions/github-script@v5
with:
script: |
const { CI_TEST_PASSED } = process.env
labels = await github.rest.issues.listLabelsOnIssue({
issue_number: context.issue.number,
repo:context.repo.repo,
owner:context.repo.owner
})
console.log('labels.url=' + labels.url)
labels = labels.data
var label_names = []
if (labels != null) {
labels.forEach(tmp_lbl => label_names.push(tmp_lbl.name))
}
console.log(`ci_test_passed=${CI_TEST_PASSED} labels=${label_names}`);
var pass_label = null;
if (labels != null) {
pass_label = labels.find(label=>label.name=='code_style_test_passed');
}
var fail_label = null;
if (labels != null) {
fail_label = labels.find(label=>label.name=='code_style_test_failed');
}
if (pass_label) {
github.rest.issues.removeLabel({
issue_number: context.issue.number,
owner: context.repo.owner,
repo: context.repo.repo,
name: 'code_style_test_passed'
})
}
if (fail_label) {
github.rest.issues.removeLabel({
issue_number: context.issue.number,
owner: context.repo.owner,
repo: context.repo.repo,
name: 'code_style_test_failed'
})
}
if (CI_TEST_PASSED == 1) {
github.rest.issues.addLabels({
issue_number: context.issue.number,
owner: context.repo.owner,
repo: context.repo.repo,
labels: ['code_style_test_passed']
})
github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number,
body: "Code Style Test Passed"
})
} else {
github.rest.issues.addLabels({
issue_number: context.issue.number,
owner: context.repo.owner,
repo: context.repo.repo,
labels: ['code_style_test_failed']
})
github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number,
body: "Code Style Test Failed"
})
}
- name: SignalFail
env:
CI_TEST_PASSED: ${{steps.run_ci_test.outputs.ci_test_passed}}
Expand Down
8 changes: 5 additions & 3 deletions easy_rec/python/model/dssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,14 @@ def build_predict_graph(self):
kernel_regularizer=self._l2_reg,
name='item_dnn/dnn_%d' % (num_item_dnn_layer - 1))

if self._loss_type == LossType.CLASSIFICATION:
if self._model_config.simi_func == Similarity.COSINE:
if self._model_config.simi_func == Similarity.COSINE:
user_tower_emb = self.norm(user_tower_emb)
item_tower_emb = self.norm(item_tower_emb)
temperature = self._model_config.temperature
else:
temperature = 1.0

user_item_sim = self.sim(user_tower_emb, item_tower_emb)
user_item_sim = self.sim(user_tower_emb, item_tower_emb) / temperature
if self._model_config.scale_simi:
sim_w = tf.get_variable(
'sim_w',
Expand Down
2 changes: 2 additions & 0 deletions easy_rec/python/protos/dssm.proto
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,6 @@ message DSSM {
optional bool scale_simi = 5 [default = true];
optional string item_id = 9;
required bool ignore_in_batch_neg_sam = 10 [default = false];
// normalize user_tower_embedding and item_tower_embedding
optional float temperature = 11 [default = 1.0];
}
2 changes: 1 addition & 1 deletion easy_rec/python/test/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def main(argv):
(len(all_tests), test_dir))

max_num_port_per_proc = 3
total_port_num = (max_num_port_per_proc + 2) * FLAGS.num_parallel
total_port_num = (max_num_port_per_proc + 2) * FLAGS.num_parallel * 10
all_available_ports = test_utils.get_ports_base(total_port_num).tolist()

procs = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,9 +271,11 @@ model_config:{
}
}
l2_regularization: 1e-6
simi_func: INNER_PRODUCT
}
embedding_regularization: 5e-5
loss_type: L2_LOSS

}

export_config {
Expand Down

0 comments on commit ccf1862

Please sign in to comment.