support saving & loading models and higher versions of TensorFlow
- support saving & loading models
- support higher versions of TensorFlow
shenweichen authored Apr 10, 2020
2 parents 732d8a1 + c36d2c0 commit 10b5256
Showing 28 changed files with 1,284 additions and 84 deletions.
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/question.md
@@ -6,7 +6,7 @@ labels: question
assignees: ''

---
Please refer to the [FAQ](https://deepmatch.readthedocs.io/en/latest/FAQ.html) in the docs and search for [related issues](https://github.com/shenweichen/DeepCTR/issues) before asking your question.
Please refer to the [FAQ](https://deepmatch.readthedocs.io/en/latest/FAQ.html) in the docs and search for [related issues](https://github.com/shenweichen/DeepMatch/issues) before asking your question.

**Describe the question(问题描述)**
A clear and concise description of what the question is.
54 changes: 54 additions & 0 deletions .github/workflows/ci.yml
@@ -0,0 +1,54 @@
name: CI

on:
  push:
    paths:
      - 'deepmatch/*'
      - 'tests/*'
  pull_request:
    paths:
      - 'deepmatch/*'
      - 'tests/*'

jobs:
  build:

    runs-on: ubuntu-latest
    timeout-minutes: 120
    strategy:
      matrix:
        python-version: [3.5, 3.6, 3.7]
        tf-version: [1.4.0, 1.14.0, 2.1.0]
        exclude:
          - python-version: 3.7
            tf-version: 1.4.0

    steps:

      - uses: actions/checkout@v1

      - name: Setup python environment
        uses: actions/setup-python@v1
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install dependencies
        run: |
          pip3 install -q tensorflow==${{ matrix.tf-version }}
          pip install -q requests
          pip install -e .

      - name: Test with pytest
        timeout-minutes: 120
        run: |
          pip install -q pytest
          pip install -q pytest-cov
          pip install -q python-coveralls
          pytest --cov=deepmatch --cov-report=xml

      - name: Upload coverage to Codecov
        uses: codecov/[email protected]
        with:
          token: ${{ secrets.CODECOV_TOKEN }}
          file: ./coverage.xml
          flags: pytest
          name: py${{ matrix.python-version }}-tf${{ matrix.tf-version }}
5 changes: 4 additions & 1 deletion README.md
@@ -9,12 +9,15 @@


[![Documentation Status](https://readthedocs.org/projects/deepmatch/badge/?version=latest)](https://deepmatch.readthedocs.io/)
![CI status](https://github.com/shenweichen/deepmatch/workflows/CI/badge.svg)
[![codecov](https://codecov.io/gh/shenweichen/DeepMatch/branch/master/graph/badge.svg)](https://codecov.io/gh/shenweichen/DeepMatch)
[![Discussion](https://img.shields.io/badge/chat-wechat-brightgreen?style=flat)](./README.md#disscussiongroup)
[![License](https://img.shields.io/github/license/shenweichen/deepmatch.svg)](https://github.com/shenweichen/deepmatch/blob/master/LICENSE)

DeepMatch is a deep matching model library for recommendations & advertising. It's easy to **train models** and to **export representation vectors** of users and items, which can be used for **ANN search**. You can use any complex model with `model.fit()` and `model.predict()`.

Let's [**Get Started!**](https://deepmatch.readthedocs.io/en/latest/Quick-Start.html)
Let's [**Get Started!**](https://deepmatch.readthedocs.io/en/latest/Quick-Start.html) or [**Run examples**](./examples/colab_MovieLen1M_YoutubeDNN.ipynb)!



## Models List
2 changes: 1 addition & 1 deletion deepmatch/__init__.py
@@ -1,4 +1,4 @@
from .utils import check_version

__version__ = '0.1.1'
__version__ = '0.1.2'
check_version(__version__)
17 changes: 17 additions & 0 deletions deepmatch/layers/__init__.py
@@ -0,0 +1,17 @@
from deepctr.layers import custom_objects
from deepctr.layers.utils import reduce_sum

from .core import PoolingLayer, Similarity, LabelAwareAttention, CapsuleLayer, SampledSoftmaxLayer, EmbeddingIndex
from ..utils import sampledsoftmaxloss

_custom_objects = {'PoolingLayer': PoolingLayer,
                   'Similarity': Similarity,
                   'LabelAwareAttention': LabelAwareAttention,
                   'CapsuleLayer': CapsuleLayer,
                   'reduce_sum': reduce_sum,
                   'SampledSoftmaxLayer': SampledSoftmaxLayer,
                   'sampledsoftmaxloss': sampledsoftmaxloss,
                   'EmbeddingIndex': EmbeddingIndex}

custom_objects = dict(custom_objects, **_custom_objects)
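
With every custom layer registered here, a trained DeepMatch model can be saved and restored through the standard Keras APIs. A minimal sketch, assuming `model` is a compiled DeepMatch model and the file path is our own choice:

from tensorflow.python.keras.models import save_model, load_model
from deepmatch.layers import custom_objects

save_model(model, 'deepmatch_model.h5')  # persist architecture + weights

# custom_objects tells Keras how to rebuild DeepMatch's custom layers
model = load_model('deepmatch_model.h5', custom_objects)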
79 changes: 56 additions & 23 deletions deepmatch/layers/core.py
@@ -36,23 +36,24 @@ def call(self, seq_value_len_list, mask=None, **kwargs):
hist = reduce_max(a, axis=-1, )
return hist

def get_config(self, ):
config = {'mode': self.mode, 'supports_masking': self.supports_masking}
base_config = super(PoolingLayer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))


class SampledSoftmaxLayer(Layer):
def __init__(self, item_embedding, num_sampled=5, **kwargs):
def __init__(self, num_sampled=5, **kwargs):
self.num_sampled = num_sampled
self.target_song_size = item_embedding.input_dim
self.item_embedding = item_embedding
super(SampledSoftmaxLayer, self).__init__(**kwargs)

def build(self, input_shape):
self.zero_bias = self.add_weight(shape=[self.target_song_size],
self.size = input_shape[0][0]
self.zero_bias = self.add_weight(shape=[self.size],
initializer=Zeros,
dtype=tf.float32,
trainable=False,
name="bias")
if not self.item_embedding.built:
self.item_embedding.build([])
self.trainable_weights.append(self.item_embedding.embeddings)
super(SampledSoftmaxLayer, self).build(input_shape)

def call(self, inputs_with_label_idx, training=None, **kwargs):
@@ -61,22 +62,22 @@ def call(self, inputs_with_label_idx, training=None, **kwargs):
target (i.e., a repeat of the training data) to compute the labels
argument
"""
inputs, label_idx = inputs_with_label_idx
embeddings, inputs, label_idx = inputs_with_label_idx

loss = tf.nn.sampled_softmax_loss(weights=self.item_embedding.embeddings,
loss = tf.nn.sampled_softmax_loss(weights=embeddings,
biases=self.zero_bias,
labels=label_idx,
inputs=inputs,
num_sampled=self.num_sampled,
num_classes=self.target_song_size
num_classes=self.size,
)
return tf.expand_dims(loss, axis=1)

def compute_output_shape(self, input_shape):
return (None, 1)

def get_config(self, ):
config = {'item_embedding': self.item_embedding, 'num_sampled': self.num_sampled}
config = {'num_sampled': self.num_sampled}
base_config = super(SampledSoftmaxLayer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
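
Because the embedding table now arrives as an ordinary layer input instead of a stored reference to the embedding layer, the layer serializes cleanly. A hypothetical call site, where `item_embedding_weight` is the (vocabulary_size, embedding_dim) table, `user_emb` is the user tower output, and `item_idx` holds the target item ids (all three names are illustrative, not library API):

# Sketch only: mirrors the wiring used by MIND further down.
loss = SampledSoftmaxLayer(num_sampled=5)(
    inputs=(item_embedding_weight, user_emb, item_idx))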

@@ -96,7 +97,7 @@ def build(self, input_shape):
def call(self, inputs, training=None, **kwargs):
keys = inputs[0]
query = inputs[1]
weight = tf.reduce_sum(keys * query, axis=-1, keep_dims=True)
weight = reduce_sum(keys * query, axis=-1, keep_dims=True)
weight = tf.pow(weight, self.pow_p) # [x,k_max,1]

if len(inputs) == 3:
@@ -112,7 +113,7 @@ def call(self, inputs, training=None, **kwargs):
weight = tf.where(seq_mask, weight, padding)

weight = softmax(weight, dim=1, name="weight")
output = tf.reduce_sum(keys * weight, axis=1)
output = reduce_sum(keys * weight, axis=1)

return output

@@ -151,23 +152,29 @@ def call(self, inputs, **kwargs):
def compute_output_shape(self, input_shape):
return (None, 1)

def get_config(self, ):
config = {'gamma': self.gamma, 'axis': self.axis, 'type': self.type}
base_config = super(Similarity, self).get_config()
return dict(list(base_config.items()) + list(config.items()))


class CapsuleLayer(Layer):
def __init__(self, input_units, out_units, max_len, k_max, iteration_times=3,
initializer=RandomNormal(stddev=1.0), **kwargs):
init_std=1.0, **kwargs):
self.input_units = input_units
self.out_units = out_units
self.max_len = max_len
self.k_max = k_max
self.iteration_times = iteration_times
self.initializer = initializer
self.init_std = init_std
super(CapsuleLayer, self).__init__(**kwargs)

def build(self, input_shape):
self.routing_logits = self.add_weight(shape=[1, self.k_max, self.max_len], initializer=self.initializer,
self.routing_logits = self.add_weight(shape=[1, self.k_max, self.max_len],
initializer=RandomNormal(stddev=self.init_std),
trainable=False, name="B", dtype=tf.float32)
self.bilinear_mapping_matrix = self.add_weight(shape=[self.input_units, self.out_units],
initializer=self.initializer,
initializer=RandomNormal(stddev=self.init_std),
name="S", dtype=tf.float32)
super(CapsuleLayer, self).build(input_shape)

@@ -183,21 +190,47 @@ def call(self, inputs, **kwargs):
weight = tf.nn.softmax(routing_logits_with_padding)
behavior_embdding_mapping = tf.tensordot(behavior_embddings, self.bilinear_mapping_matrix, axes=1)
Z = tf.matmul(weight, behavior_embdding_mapping)
interet_capsules = squash(Z)
delta_routing_logits = tf.reduce_sum(
tf.matmul(interet_capsules, tf.transpose(behavior_embdding_mapping, perm=[0, 2, 1])),
interest_capsules = squash(Z)
delta_routing_logits = reduce_sum(
tf.matmul(interest_capsules, tf.transpose(behavior_embdding_mapping, perm=[0, 2, 1])),
axis=0, keep_dims=True
)
self.routing_logits.assign_add(delta_routing_logits)
interet_capsules = tf.reshape(interet_capsules, [-1, self.k_max, self.out_units])
return interet_capsules
interest_capsules = tf.reshape(interest_capsules, [-1, self.k_max, self.out_units])
return interest_capsules

def compute_output_shape(self, input_shape):
return (None, self.k_max, self.out_units)

def get_config(self, ):
config = {'input_units': self.input_units, 'out_units': self.out_units, 'max_len': self.max_len,
'k_max': self.k_max, 'iteration_times': self.iteration_times, "init_std": self.init_std}
base_config = super(CapsuleLayer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
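
For orientation, a hypothetical standalone call with the new `init_std` argument; shapes are illustrative:

# behavior_emb: (batch_size, 50, 16) sequence embeddings, seq_len: (batch_size, 1)
interest_caps = CapsuleLayer(input_units=16, out_units=16, max_len=50,
                             k_max=2, init_std=1.0)((behavior_emb, seq_len))
# interest_caps: (batch_size, 2, 16), i.e. k_max interest capsules per user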


def squash(inputs):
vec_squared_norm = tf.reduce_sum(tf.square(inputs), axis=-1, keep_dims=True)
vec_squared_norm = reduce_sum(tf.square(inputs), axis=-1, keep_dims=True)
scalar_factor = vec_squared_norm / (1 + vec_squared_norm) / tf.sqrt(vec_squared_norm + 1e-8)
vec_squashed = scalar_factor * inputs
return vec_squashed
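
For reference, squash is the CapsNet nonlinearity: it keeps a vector's direction while bounding its length below 1 (the 1e-8 term guards against division by zero):

squash(Z) = ||Z||^2 / (1 + ||Z||^2) * Z / ||Z||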




class EmbeddingIndex(Layer):

    def __init__(self, index, **kwargs):
        self.index = index
        super(EmbeddingIndex, self).__init__(**kwargs)

    def build(self, input_shape):
        super(EmbeddingIndex, self).build(input_shape)  # Be sure to call this somewhere!

    def call(self, x, **kwargs):
        return tf.constant(self.index)

    def get_config(self, ):
        config = {'index': self.index, }
        base_config = super(EmbeddingIndex, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
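
EmbeddingIndex simply materializes a constant index tensor; its input only anchors the layer in the Keras graph so the full item embedding table can be gathered downstream. A toy illustration (the input name is arbitrary):

import tensorflow as tf

item_id = tf.keras.Input(shape=(1,), name='item_id')
# Output is the constant [0, 1, 2, 3], regardless of what item_id carries.
all_indices = EmbeddingIndex(list(range(4)))(item_id)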
4 changes: 2 additions & 2 deletions deepmatch/models/dssm.py
@@ -13,8 +13,8 @@
from ..layers.core import Similarity


def DSSM(user_feature_columns, item_feature_columns, user_dnn_hidden_units=(64, 16),
item_dnn_hidden_units=(64, 16),
def DSSM(user_feature_columns, item_feature_columns, user_dnn_hidden_units=(64, 32),
item_dnn_hidden_units=(64, 32),
dnn_activation='tanh', dnn_use_bn=False,
l2_reg_dnn=0, l2_reg_embedding=1e-6, dnn_dropout=0, init_std=0.0001, seed=1024, metric='cos'):
"""Instantiates the Deep Structured Semantic Model architecture.
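Only the default tower sizes changed here, from (64, 16) to (64, 32). A hedged instantiation sketch with made-up feature columns; note that `SparseFeat` is imported from `deepctr.inputs` on DeepCTR versions contemporary with this release (later versions moved it to `deepctr.feature_column`):

from deepctr.inputs import SparseFeat
from deepmatch.models import DSSM

# Illustrative columns: vocabulary sizes and embedding dims are made up.
user_columns = [SparseFeat('user_id', vocabulary_size=1000, embedding_dim=16)]
item_columns = [SparseFeat('item_id', vocabulary_size=5000, embedding_dim=16)]

model = DSSM(user_columns, item_columns)  # both towers now default to (64, 32)
model.compile(optimizer='adam', loss='binary_crossentropy')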
41 changes: 25 additions & 16 deletions deepmatch/models/mind.py
@@ -17,7 +17,7 @@

from deepmatch.utils import get_item_embedding
from ..inputs import create_embedding_matrix
from ..layers.core import CapsuleLayer, SampledSoftmaxLayer, PoolingLayer, LabelAwareAttention
from ..layers.core import CapsuleLayer, PoolingLayer, LabelAwareAttention, SampledSoftmaxLayer, EmbeddingIndex


def shape_target(target_emb_tmp, target_emb_size):
@@ -57,6 +57,10 @@ def MIND(user_feature_columns, item_feature_columns, num_sampled=5, k_max=2, p=1
raise ValueError("Now MIND only support 1 item feature like item_id")
item_feature_column = item_feature_columns[0]
item_feature_name = item_feature_column.name
item_vocabulary_size = item_feature_columns[0].vocabulary_size
item_embedding_dim = item_feature_columns[0].embedding_dim
#item_index = Input(tensor=tf.constant([list(range(item_vocabulary_size))]))

history_feature_list = [item_feature_name]

features = build_input_features(user_feature_columns)
@@ -75,24 +79,24 @@ def MIND(user_feature_columns, item_feature_columns, num_sampled=5, k_max=2, p=1
history_feature_columns.append(fc)
else:
sparse_varlen_feature_columns.append(fc)

seq_max_len = history_feature_columns[0].maxlen
inputs_list = list(features.values())

embedding_dict = create_embedding_matrix(user_feature_columns + item_feature_columns, l2_reg_embedding, init_std,
embedding_matrix_dict = create_embedding_matrix(user_feature_columns + item_feature_columns, l2_reg_embedding, init_std,
seed, prefix="")

item_features = build_input_features(item_feature_columns)

query_emb_list = embedding_lookup(embedding_dict, item_features, item_feature_columns,
query_emb_list = embedding_lookup(embedding_matrix_dict, item_features, item_feature_columns,
history_feature_list,
history_feature_list, to_list=True)
keys_emb_list = embedding_lookup(embedding_dict, features, history_feature_columns, history_fc_names,
keys_emb_list = embedding_lookup(embedding_matrix_dict, features, history_feature_columns, history_fc_names,
history_fc_names, to_list=True)
dnn_input_emb_list = embedding_lookup(embedding_dict, features, sparse_feature_columns,
dnn_input_emb_list = embedding_lookup(embedding_matrix_dict, features, sparse_feature_columns,
mask_feat_list=history_feature_list, to_list=True)
dense_value_list = get_dense_input(features, dense_feature_columns)

sequence_embed_dict = varlen_embedding_lookup(embedding_dict, features, sparse_varlen_feature_columns)
sequence_embed_dict = varlen_embedding_lookup(embedding_matrix_dict, features, sparse_varlen_feature_columns)
sequence_embed_list = get_varlen_pooling_list(sequence_embed_dict, features, sparse_varlen_feature_columns,
to_list=True)

@@ -104,12 +108,12 @@ def MIND(user_feature_columns, item_feature_columns, num_sampled=5, k_max=2, p=1
history_emb = PoolingLayer()(NoMask()(keys_emb_list))
target_emb = PoolingLayer()(NoMask()(query_emb_list))

target_emb_size = target_emb.get_shape()[-1].value
max_len = history_emb.get_shape()[1].value
#target_emb_size = target_emb.get_shape()[-1].value
#max_len = history_emb.get_shape()[1].value
hist_len = features['hist_len']

high_capsule = CapsuleLayer(input_units=target_emb_size,
out_units=target_emb_size, max_len=max_len,
high_capsule = CapsuleLayer(input_units=item_embedding_dim,
out_units=item_embedding_dim, max_len=seq_max_len,
k_max=k_max)((history_emb, hist_len))

if len(dnn_input_emb_list) > 0 or len(dense_value_list) > 0:
@@ -121,27 +125,32 @@ def MIND(user_feature_columns, item_feature_columns, num_sampled=5, k_max=2, p=1
else:
user_deep_input = high_capsule

# user_deep_input._uses_learning_phase = True # attention_score._uses_learning_phase

user_embeddings = DNN(user_dnn_hidden_units, dnn_activation, l2_reg_dnn,
dnn_dropout, dnn_use_bn, seed, name="user_embedding")(user_deep_input)
item_inputs_list = list(item_features.values())

item_embedding = embedding_dict[item_feature_name]
item_embedding_matrix = embedding_matrix_dict[item_feature_name]

item_index = EmbeddingIndex(list(range(item_vocabulary_size)))(item_features[item_feature_name])

item_embedding_weight = NoMask()(item_embedding_matrix(item_index))

pooling_item_embedding_weight = PoolingLayer()([item_embedding_weight])

if dynamic_k:
user_embedding_final = LabelAwareAttention(k_max=k_max, pow_p=p, )((user_embeddings, target_emb, hist_len))
else:
user_embedding_final = LabelAwareAttention(k_max=k_max, pow_p=p, )((user_embeddings, target_emb))

output = SampledSoftmaxLayer(item_embedding, num_sampled=num_sampled)(
inputs=(user_embedding_final, item_features[item_feature_name]))
output = SampledSoftmaxLayer(num_sampled=num_sampled)(
inputs=(pooling_item_embedding_weight, user_embedding_final, item_features[item_feature_name]))
model = Model(inputs=inputs_list + item_inputs_list, outputs=output)

model.__setattr__("user_input", inputs_list)
model.__setattr__("user_embedding", user_embeddings)

model.__setattr__("item_input", item_inputs_list)
model.__setattr__("item_embedding", get_item_embedding(item_embedding, item_features[item_feature_name]))
model.__setattr__("item_embedding", get_item_embedding(pooling_item_embedding_weight, item_features[item_feature_name]))

return model
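
The attributes attached above are what make vector export straightforward once training is done. A sketch of the intended workflow, assuming `model` has been fit and `user_inputs` / `item_inputs` are dicts of feature arrays matching the declared columns:

from tensorflow.python.keras.models import Model

user_embedding_model = Model(inputs=model.user_input, outputs=model.user_embedding)
item_embedding_model = Model(inputs=model.item_input, outputs=model.item_embedding)

user_vecs = user_embedding_model.predict(user_inputs, batch_size=2 ** 12)
item_vecs = item_embedding_model.predict(item_inputs, batch_size=2 ** 12)
# user_vecs / item_vecs can now be indexed with an ANN library such as faiss.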