diff --git a/CHANGELOG.md b/CHANGELOG.md
index 74ebf69..be05abc 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,10 +1,20 @@
## FuxiCTR Versions
-### FuxiCTR v2.2
+### FuxiCTR v2.3
[Doing] Add support for saving pb file, exporting embeddings
-[Doing] Add support of NVTabular data
+[Doing] Add support of multi-gpu training
+
+**FuxiCTR v2.3.0, 2024-04-20**
++ [Refactor] Support data format of npz and parquet
+
+-------------------------------
+
+### FuxiCTR v2.2
+
+**FuxiCTR v2.2.3, 2024-04-20**
++ [Fix] Quick fix to v2.2.2 that miss one line when committing
-**FuxiCTR v2.2.2, 2024-04-18**
+**FuxiCTR v2.2.2, 2024-04-18 (Deprecated)**
+ [Feature] Update to use polars instead of pandas for faster feature processing
+ [Fix] When num_workers > 1, NpzBlockDataLoader cannot keep the reading order of samples ([#86](https://github.com/xue-pai/FuxiCTR/issues/86))
diff --git a/README.md b/README.md
index 0b809bc..7fa7733 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@
-
+
@@ -102,7 +102,7 @@ We have benchmarked FuxiCTR models on a set of open datasets as follows:
FuxiCTR has the following dependencies:
-+ python 3.6+
++ python 3.9+
+ pytorch 1.10+ (required only for Torch models)
+ tensorflow 2.1+ (required only for TF models)
diff --git a/fuxictr/preprocess/feature_processor.py b/fuxictr/preprocess/feature_processor.py
index 8072c19..da254e7 100644
--- a/fuxictr/preprocess/feature_processor.py
+++ b/fuxictr/preprocess/feature_processor.py
@@ -71,10 +71,12 @@ def read_csv(self, data_path, sep=",", n_rows=None, **kwargs):
logging.info("Reading file: " + data_path)
file_names = sorted(glob.glob(data_path))
assert len(file_names) > 0, f"Invalid data path: {data_path}"
- # Require python >= 3.8 for use polars to scan multiple csv files
- file_names = file_names[0]
- ddf = pl.scan_csv(source=file_names, separator=sep, dtypes=self.dtype_dict,
- low_memory=False, n_rows=n_rows)
+ dfs = [
+ pl.scan_csv(source=file_name, separator=sep, dtypes=self.dtype_dict,
+ low_memory=False, n_rows=n_rows)
+ for file_name in file_names
+ ]
+ ddf = pl.concat(dfs)
return ddf
def preprocess(self, ddf):
diff --git a/fuxictr/preprocess/tokenizer.py b/fuxictr/preprocess/tokenizer.py
index c23bd59..f518395 100644
--- a/fuxictr/preprocess/tokenizer.py
+++ b/fuxictr/preprocess/tokenizer.py
@@ -22,6 +22,7 @@
from keras_preprocessing.sequence import pad_sequences
from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing as mp
+from ..utils import load_pretrain_emb
class Tokenizer(object):
@@ -125,13 +126,9 @@ def encode_sequence(self, series):
return np.array(seqs)
def load_pretrained_vocab(self, feature_dtype, pretrain_path, expand_vocab=True):
- if pretrain_path.endswith(".h5"):
- with h5py.File(pretrain_path, 'r') as hf:
- keys = hf["key"][:]
- # in case mismatch of dtype between int and str
- keys = keys.astype(feature_dtype)
- elif pretrain_path.endswith(".npz"):
- keys = np.load(pretrain_path)["key"]
+ keys = load_pretrain_emb(pretrain_path, keys=["key"])
+ # in case mismatch of dtype between int and str
+ keys = keys.astype(feature_dtype)
# Update vocab with pretrained keys in case new tokens appear in validation or test set
# Do NOT update OOV index here since it is used in PretrainedEmbedding
if expand_vocab:
diff --git a/fuxictr/pytorch/dataloaders/npz_block_dataloader.py b/fuxictr/pytorch/dataloaders/npz_block_dataloader.py
index eb81592..0503099 100644
--- a/fuxictr/pytorch/dataloaders/npz_block_dataloader.py
+++ b/fuxictr/pytorch/dataloaders/npz_block_dataloader.py
@@ -18,15 +18,15 @@
import numpy as np
from itertools import chain
-import torch
+from torch.utils.data.dataloader import default_collate
from torch.utils import data
import glob
-class BlockDataPipe(data.IterDataPipe):
- def __init__(self, block_datapipe, feature_map):
+class IterDataPipe(data.IterDataPipe):
+ def __init__(self, data_blocks, feature_map):
self.feature_map = feature_map
- self.block_datapipe = block_datapipe
+ self.data_blocks = data_blocks
def load_data(self, data_path):
data_dict = np.load(data_path)
@@ -38,8 +38,7 @@ def load_data(self, data_path):
data_arrays.append(array.reshape(-1, 1))
else:
data_arrays.append(array)
- data_tensor = torch.from_numpy(np.hstack(data_arrays))
- return data_tensor
+ return np.hstack(data_arrays)
def read_block(self, data_block):
darray = self.load_data(data_block)
@@ -49,11 +48,11 @@ def read_block(self, data_block):
def __iter__(self):
worker_info = data.get_worker_info()
if worker_info is None: # single-process data loading
- block_list = self.block_datapipe
+ block_list = self.data_blocks
else: # in a worker process
block_list = [
block
- for idx, block in enumerate(self.block_datapipe)
+ for idx, block in enumerate(self.data_blocks)
if idx % worker_info.num_workers == worker_info.id
]
return chain.from_iterable(map(self.read_block, block_list))
@@ -71,13 +70,15 @@ def __init__(self, feature_map, data_path, batch_size=32, shuffle=False,
self.feature_map = feature_map
self.batch_size = batch_size
self.num_batches, self.num_samples = self.count_batches_and_samples()
- datapipe = BlockDataPipe(data_blocks, feature_map)
+ datapipe = IterDataPipe(self.data_blocks, feature_map)
if shuffle:
datapipe = datapipe.shuffle(buffer_size=buffer_size)
else:
- num_workers = 1 # multiple workers cannot keep the order of data reading
- super(NpzBlockDataLoader, self).__init__(dataset=datapipe, batch_size=batch_size,
- num_workers=num_workers)
+ num_workers = 1 # multiple workers cannot keep the order of data reading
+ super(NpzBlockDataLoader, self).__init__(dataset=datapipe,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ collate_fn=BatchCollator(feature_map))
def __len__(self):
return self.num_batches
@@ -89,3 +90,16 @@ def count_batches_and_samples(self):
num_samples += block_size
num_batches = int(np.ceil(num_samples / self.batch_size))
return num_batches, num_samples
+
+
+class BatchCollator(object):
+ def __init__(self, feature_map):
+ self.feature_map = feature_map
+
+ def __call__(self, batch):
+ batch_tensor = default_collate(batch)
+ all_cols = list(self.feature_map.features.keys()) + self.feature_map.labels
+ batch_dict = dict()
+ for col in all_cols:
+ batch_dict[col] = batch_tensor[:, self.feature_map.get_column_index(col)]
+ return batch_dict
diff --git a/fuxictr/pytorch/dataloaders/npz_dataloader.py b/fuxictr/pytorch/dataloaders/npz_dataloader.py
index e5740a1..3822e76 100644
--- a/fuxictr/pytorch/dataloaders/npz_dataloader.py
+++ b/fuxictr/pytorch/dataloaders/npz_dataloader.py
@@ -17,7 +17,7 @@
import numpy as np
from torch.utils import data
-import torch
+from torch.utils.data.dataloader import default_collate
class Dataset(data.Dataset):
@@ -41,8 +41,7 @@ def load_data(self, data_path):
data_arrays.append(array.reshape(-1, 1))
else:
data_arrays.append(array)
- data_tensor = torch.from_numpy(np.hstack(data_arrays))
- return data_tensor
+ return np.hstack(data_arrays)
class NpzDataLoader(data.DataLoader):
@@ -51,10 +50,24 @@ def __init__(self, feature_map, data_path, batch_size=32, shuffle=False, num_wor
data_path += ".npz"
self.dataset = Dataset(feature_map, data_path)
super(NpzDataLoader, self).__init__(dataset=self.dataset, batch_size=batch_size,
- shuffle=shuffle, num_workers=num_workers)
+ shuffle=shuffle, num_workers=num_workers,
+ collate_fn=BatchCollator(feature_map))
self.num_samples = len(self.dataset)
self.num_blocks = 1
self.num_batches = int(np.ceil(self.num_samples * 1.0 / self.batch_size))
def __len__(self):
return self.num_batches
+
+
+class BatchCollator(object):
+ def __init__(self, feature_map):
+ self.feature_map = feature_map
+
+ def __call__(self, batch):
+ batch_tensor = default_collate(batch)
+ all_cols = list(self.feature_map.features.keys()) + self.feature_map.labels
+ batch_dict = dict()
+ for col in all_cols:
+ batch_dict[col] = batch_tensor[:, self.feature_map.get_column_index(col)]
+ return batch_dict
diff --git a/fuxictr/pytorch/dataloaders/parquet_block_dataloader.py b/fuxictr/pytorch/dataloaders/parquet_block_dataloader.py
new file mode 100644
index 0000000..8abb234
--- /dev/null
+++ b/fuxictr/pytorch/dataloaders/parquet_block_dataloader.py
@@ -0,0 +1,79 @@
+# =========================================================================
+# Copyright (C) 2023-2024. FuxiCTR Authors. All rights reserved.
+# Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========================================================================
+
+
+import numpy as np
+from itertools import chain
+from torch.utils import data
+import glob
+import polars as pl
+import pandas as pd
+
+
+class IterDataPipe(data.IterDataPipe):
+ def __init__(self, data_blocks, feature_map):
+ self.feature_map = feature_map
+ self.data_blocks = data_blocks
+
+ def read_block(self, data_block):
+ data_df = pd.read_parquet(data_block)
+ for idx in range(len(data_df)):
+ yield data_df.iloc[idx].to_dict()
+
+ def __iter__(self):
+ worker_info = data.get_worker_info()
+ if worker_info is None: # single-process data loading
+ block_list = self.data_blocks
+ else: # in a worker process
+ block_list = [
+ block
+ for idx, block in enumerate(self.data_blocks)
+ if idx % worker_info.num_workers == worker_info.id
+ ]
+ return chain.from_iterable(map(self.read_block, block_list))
+
+
+class ParquetBlockDataLoader(data.DataLoader):
+ def __init__(self, feature_map, data_path, batch_size=32, shuffle=False,
+ num_workers=1, buffer_size=100000, **kwargs):
+ data_blocks = glob.glob(data_path + "/*.parquet")
+ assert len(data_blocks) > 0, f"invalid data_path: {data_path}"
+ if len(data_blocks) > 1:
+ data_blocks.sort() # sort by part name
+ self.data_blocks = data_blocks
+ self.num_blocks = len(self.data_blocks)
+ self.feature_map = feature_map
+ self.batch_size = batch_size
+ self.num_batches, self.num_samples = self.count_batches_and_samples()
+ datapipe = IterDataPipe(self.data_blocks, feature_map)
+ if shuffle:
+ datapipe = datapipe.shuffle(buffer_size=buffer_size)
+ else:
+ num_workers = 1 # multiple workers cannot keep the order of data reading
+ super().__init__(dataset=datapipe, batch_size=batch_size,
+ num_workers=num_workers)
+
+ def __len__(self):
+ return self.num_batches
+
+ def count_batches_and_samples(self):
+ num_samples = 0
+ for data_block in self.data_blocks:
+ df = pl.scan_parquet(data_block)
+ num_samples += df.select(pl.count()).collect().item()
+ num_batches = int(np.ceil(num_samples / self.batch_size))
+ return num_batches, num_samples
diff --git a/fuxictr/pytorch/dataloaders/parquet_dataloader.py b/fuxictr/pytorch/dataloaders/parquet_dataloader.py
new file mode 100644
index 0000000..a4612fb
--- /dev/null
+++ b/fuxictr/pytorch/dataloaders/parquet_dataloader.py
@@ -0,0 +1,47 @@
+# =========================================================================
+# Copyright (C) 2024. FuxiCTR Authors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========================================================================
+
+
+import numpy as np
+from torch.utils import data
+import pandas as pd
+
+
+class Dataset(data.Dataset):
+ def __init__(self, feature_map, data_path):
+ self.feature_map = feature_map
+ self.data_df = pd.read_parquet(data_path)
+
+ def __getitem__(self, index):
+ return self.data_df.iloc[index].to_dict()
+
+ def __len__(self):
+ return len(self.data_df)
+
+
+class ParquetDataLoader(data.DataLoader):
+ def __init__(self, feature_map, data_path, batch_size=32, shuffle=False, num_workers=1, **kwargs):
+ if not data_path.endswith(".parquet"):
+ data_path += ".parquet"
+ self.dataset = Dataset(feature_map, data_path)
+ super().__init__(dataset=self.dataset, batch_size=batch_size,
+ shuffle=shuffle, num_workers=num_workers)
+ self.num_samples = len(self.dataset)
+ self.num_blocks = 1
+ self.num_batches = int(np.ceil(self.num_samples / self.batch_size))
+
+ def __len__(self):
+ return self.num_batches
diff --git a/fuxictr/pytorch/dataloaders/rank_dataloader.py b/fuxictr/pytorch/dataloaders/rank_dataloader.py
index 1ca830f..ea2c425 100644
--- a/fuxictr/pytorch/dataloaders/rank_dataloader.py
+++ b/fuxictr/pytorch/dataloaders/rank_dataloader.py
@@ -17,17 +17,24 @@
from .npz_block_dataloader import NpzBlockDataLoader
from .npz_dataloader import NpzDataLoader
+from .parquet_block_dataloader import ParquetBlockDataLoader
+from .parquet_dataloader import ParquetDataLoader
import logging
class RankDataLoader(object):
def __init__(self, feature_map, stage="both", train_data=None, valid_data=None, test_data=None,
- batch_size=32, shuffle=True, streaming=False, **kwargs):
+ batch_size=32, shuffle=True, streaming=False, data_format="npz", **kwargs):
logging.info("Loading datasets...")
train_gen = None
valid_gen = None
test_gen = None
- DataLoader = NpzBlockDataLoader if streaming else NpzDataLoader
+ if data_format == "npz":
+ DataLoader = NpzBlockDataLoader if streaming else NpzDataLoader
+ elif data_format == "parquet":
+ DataLoader = ParquetBlockDataLoader if streaming else ParquetDataLoader
+ else:
+ raise ValueError(f"data_format={data_format} not supported.")
self.stage = stage
if stage in ["both", "train"]:
train_gen = DataLoader(feature_map, train_data, batch_size=batch_size, shuffle=shuffle, **kwargs)
diff --git a/fuxictr/pytorch/layers/embeddings/pretrained_embedding.py b/fuxictr/pytorch/layers/embeddings/pretrained_embedding.py
index 488333f..9ed7bb8 100644
--- a/fuxictr/pytorch/layers/embeddings/pretrained_embedding.py
+++ b/fuxictr/pytorch/layers/embeddings/pretrained_embedding.py
@@ -17,12 +17,12 @@
import torch
from torch import nn
-import h5py
import os
import io
import json
import numpy as np
import logging
+from ...utils import load_pretrain_emb
class PretrainedEmbedding(nn.Module):
@@ -66,17 +66,6 @@ def reset_parameters(self, embedding_initializer):
nn.init.zeros_(self.id_embedding.weight) # set oov token embeddings to zeros
embedding_initializer(self.id_embedding.weight[1:self.oov_idx, :])
- def get_pretrained_embedding(self, pretrain_path):
- logging.info("Loading pretrained_emb: {}".format(pretrain_path))
- if pretrain_path.endswith("h5"):
- with h5py.File(pretrain_path, 'r') as hf:
- keys = hf["key"][:]
- embeddings = hf["value"][:]
- elif pretrain_path.endswith("npz"):
- npz = np.load(pretrain_path)
- keys, embeddings = npz["key"], npz["value"]
- return keys, embeddings
-
def load_feature_vocab(self, vocab_path, feature_name):
with io.open(vocab_path, "r", encoding="utf-8") as fd:
vocab = json.load(fd)
@@ -94,7 +83,8 @@ def load_pretrained_embedding(self, vocab_size, pretrain_dim, pretrain_path, voc
embedding_matrix = np.random.normal(loc=0, scale=1.e-4, size=(vocab_size, pretrain_dim))
if padding_idx:
embedding_matrix[padding_idx, :] = np.zeros(pretrain_dim) # set as zero for PAD
- keys, embeddings = self.get_pretrained_embedding(pretrain_path)
+ logging.info("Loading pretrained_emb: {}".format(pretrain_path))
+ keys, embeddings = load_pretrain_emb(pretrain_path, keys=["key", "value"])
assert embeddings.shape[-1] == pretrain_dim, f"pretrain_dim={pretrain_dim} not correct."
vocab, vocab_type = self.load_feature_vocab(vocab_path, feature_name)
keys = keys.astype(vocab_type) # ensure the same dtype between pretrained keys and vocab keys
diff --git a/fuxictr/pytorch/models/rank_model.py b/fuxictr/pytorch/models/rank_model.py
index 2a5da95..88b5b76 100644
--- a/fuxictr/pytorch/models/rank_model.py
+++ b/fuxictr/pytorch/models/rank_model.py
@@ -1,4 +1,5 @@
# =========================================================================
+# Copyright (C) 2023. FuxiCTR Authors. All rights reserved.
# Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -110,18 +111,18 @@ def get_inputs(self, inputs, feature_source=None):
continue
if spec["type"] == "meta":
continue
- X_dict[feature] = inputs[:, self.feature_map.get_column_index(feature)].to(self.device)
+ X_dict[feature] = inputs[feature].to(self.device)
return X_dict
def get_labels(self, inputs):
""" Please override get_labels() when using multiple labels!
"""
labels = self.feature_map.labels
- y = inputs[:, self.feature_map.get_column_index(labels[0])].to(self.device)
+ y = inputs[labels[0]].to(self.device)
return y.float().view(-1, 1)
def get_group_id(self, inputs):
- return inputs[:, self.feature_map.get_column_index(self.feature_map.group_id)]
+ return inputs[self.feature_map.group_id]
def model_to_device(self):
self.to(device=self.device)
diff --git a/fuxictr/utils.py b/fuxictr/utils.py
index 0011ff6..7659327 100644
--- a/fuxictr/utils.py
+++ b/fuxictr/utils.py
@@ -20,6 +20,9 @@
import yaml
import glob
import json
+import h5py
+import numpy as np
+import pandas as pd
from collections import OrderedDict
@@ -90,6 +93,7 @@ def print_to_json(data, sort_keys=True):
def print_to_list(data):
return ' - '.join('{}: {:.6f}'.format(k, v) for k, v in data.items())
+
class Monitor(object):
def __init__(self, kv):
if isinstance(kv, str):
@@ -104,3 +108,20 @@ def get_value(self, logs):
def get_metrics(self):
return list(self.kv_pairs.keys())
+
+
+def load_pretrain_emb(pretrain_path, keys=["key", "value"]):
+ if type(keys) != list:
+ keys = [keys]
+ if pretrain_path.endswith("h5"):
+ with h5py.File(pretrain_path, 'r') as hf:
+ values = [hf[k][:] for k in keys]
+ elif pretrain_path.endswith("npz"):
+ npz = np.load(pretrain_path)
+ values = [npz[k] for k in keys]
+ elif pretrain_path.endswith("parquet"):
+ df = pd.read_parquet(pretrain_path)
+ values = [df[k].values for k in keys]
+ else:
+ raise ValueError(f"Embedding format not supported: {pretrain_path}")
+ return tuple(values)
diff --git a/fuxictr/version.py b/fuxictr/version.py
index 6f43348..1108fcc 100644
--- a/fuxictr/version.py
+++ b/fuxictr/version.py
@@ -1 +1 @@
-__version__="2.2.3"
+__version__="2.3.0"
diff --git a/requirements.txt b/requirements.txt
index 6711007..eb1bba3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,6 @@
-torch
keras_preprocessing
-PyYAML
pandas
+PyYAML
scikit-learn
numpy
h5py
diff --git a/setup.py b/setup.py
index 89bab60..9f141d1 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@
setuptools.setup(
name="fuxictr",
- version="2.2.3",
+ version="2.3.0",
author="RECZOO",
author_email="reczoo@users.noreply.github.com",
description="A configurable, tunable, and reproducible library for CTR prediction",
@@ -17,7 +17,8 @@
exclude=["model_zoo", "tests", "data", "docs", "demo"]),
include_package_data=True,
python_requires=">=3.6",
- install_requires=["pandas", "numpy", "h5py", "PyYAML>=5.1", "scikit-learn", "tqdm"],
+ install_requires=["keras_preprocessing", "pandas", "PyYAML>=5.1", "scikit-learn",
+ "numpy", "h5py", "tqdm", "pyarrow", "polars"],
classifiers=(
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
@@ -25,7 +26,6 @@
'Intended Audience :: Education',
'Intended Audience :: Science/Research',
'Programming Language :: Python :: 3',
- 'Programming Language :: Python :: 3.7',
'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Software Development',