diff --git a/CHANGELOG.md b/CHANGELOG.md index 74ebf69..be05abc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,10 +1,20 @@ ## FuxiCTR Versions -### FuxiCTR v2.2 +### FuxiCTR v2.3 [Doing] Add support for saving pb file, exporting embeddings -[Doing] Add support of NVTabular data +[Doing] Add support of multi-gpu training + +**FuxiCTR v2.3.0, 2024-04-20** ++ [Refactor] Support data format of npz and parquet + +------------------------------- + +### FuxiCTR v2.2 + +**FuxiCTR v2.2.3, 2024-04-20** ++ [Fix] Quick fix to v2.2.2 that miss one line when committing -**FuxiCTR v2.2.2, 2024-04-18** +**FuxiCTR v2.2.2, 2024-04-18 (Deprecated)** + [Feature] Update to use polars instead of pandas for faster feature processing + [Fix] When num_workers > 1, NpzBlockDataLoader cannot keep the reading order of samples ([#86](https://github.com/xue-pai/FuxiCTR/issues/86)) diff --git a/README.md b/README.md index 0b809bc..7fa7733 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@
-Python version +Python version Pytorch version Pytorch version Pypi version @@ -102,7 +102,7 @@ We have benchmarked FuxiCTR models on a set of open datasets as follows: FuxiCTR has the following dependencies: -+ python 3.6+ ++ python 3.9+ + pytorch 1.10+ (required only for Torch models) + tensorflow 2.1+ (required only for TF models) diff --git a/fuxictr/preprocess/feature_processor.py b/fuxictr/preprocess/feature_processor.py index 8072c19..da254e7 100644 --- a/fuxictr/preprocess/feature_processor.py +++ b/fuxictr/preprocess/feature_processor.py @@ -71,10 +71,12 @@ def read_csv(self, data_path, sep=",", n_rows=None, **kwargs): logging.info("Reading file: " + data_path) file_names = sorted(glob.glob(data_path)) assert len(file_names) > 0, f"Invalid data path: {data_path}" - # Require python >= 3.8 for use polars to scan multiple csv files - file_names = file_names[0] - ddf = pl.scan_csv(source=file_names, separator=sep, dtypes=self.dtype_dict, - low_memory=False, n_rows=n_rows) + dfs = [ + pl.scan_csv(source=file_name, separator=sep, dtypes=self.dtype_dict, + low_memory=False, n_rows=n_rows) + for file_name in file_names + ] + ddf = pl.concat(dfs) return ddf def preprocess(self, ddf): diff --git a/fuxictr/preprocess/tokenizer.py b/fuxictr/preprocess/tokenizer.py index c23bd59..f518395 100644 --- a/fuxictr/preprocess/tokenizer.py +++ b/fuxictr/preprocess/tokenizer.py @@ -22,6 +22,7 @@ from keras_preprocessing.sequence import pad_sequences from concurrent.futures import ProcessPoolExecutor, as_completed import multiprocessing as mp +from ..utils import load_pretrain_emb class Tokenizer(object): @@ -125,13 +126,9 @@ def encode_sequence(self, series): return np.array(seqs) def load_pretrained_vocab(self, feature_dtype, pretrain_path, expand_vocab=True): - if pretrain_path.endswith(".h5"): - with h5py.File(pretrain_path, 'r') as hf: - keys = hf["key"][:] - # in case mismatch of dtype between int and str - keys = keys.astype(feature_dtype) - elif pretrain_path.endswith(".npz"): - keys = np.load(pretrain_path)["key"] + keys = load_pretrain_emb(pretrain_path, keys=["key"]) + # in case mismatch of dtype between int and str + keys = keys.astype(feature_dtype) # Update vocab with pretrained keys in case new tokens appear in validation or test set # Do NOT update OOV index here since it is used in PretrainedEmbedding if expand_vocab: diff --git a/fuxictr/pytorch/dataloaders/npz_block_dataloader.py b/fuxictr/pytorch/dataloaders/npz_block_dataloader.py index eb81592..0503099 100644 --- a/fuxictr/pytorch/dataloaders/npz_block_dataloader.py +++ b/fuxictr/pytorch/dataloaders/npz_block_dataloader.py @@ -18,15 +18,15 @@ import numpy as np from itertools import chain -import torch +from torch.utils.data.dataloader import default_collate from torch.utils import data import glob -class BlockDataPipe(data.IterDataPipe): - def __init__(self, block_datapipe, feature_map): +class IterDataPipe(data.IterDataPipe): + def __init__(self, data_blocks, feature_map): self.feature_map = feature_map - self.block_datapipe = block_datapipe + self.data_blocks = data_blocks def load_data(self, data_path): data_dict = np.load(data_path) @@ -38,8 +38,7 @@ def load_data(self, data_path): data_arrays.append(array.reshape(-1, 1)) else: data_arrays.append(array) - data_tensor = torch.from_numpy(np.hstack(data_arrays)) - return data_tensor + return np.hstack(data_arrays) def read_block(self, data_block): darray = self.load_data(data_block) @@ -49,11 +48,11 @@ def read_block(self, data_block): def __iter__(self): worker_info = data.get_worker_info() if worker_info is None: # single-process data loading - block_list = self.block_datapipe + block_list = self.data_blocks else: # in a worker process block_list = [ block - for idx, block in enumerate(self.block_datapipe) + for idx, block in enumerate(self.data_blocks) if idx % worker_info.num_workers == worker_info.id ] return chain.from_iterable(map(self.read_block, block_list)) @@ -71,13 +70,15 @@ def __init__(self, feature_map, data_path, batch_size=32, shuffle=False, self.feature_map = feature_map self.batch_size = batch_size self.num_batches, self.num_samples = self.count_batches_and_samples() - datapipe = BlockDataPipe(data_blocks, feature_map) + datapipe = IterDataPipe(self.data_blocks, feature_map) if shuffle: datapipe = datapipe.shuffle(buffer_size=buffer_size) else: - num_workers = 1 # multiple workers cannot keep the order of data reading - super(NpzBlockDataLoader, self).__init__(dataset=datapipe, batch_size=batch_size, - num_workers=num_workers) + num_workers = 1 # multiple workers cannot keep the order of data reading + super(NpzBlockDataLoader, self).__init__(dataset=datapipe, + batch_size=batch_size, + num_workers=num_workers, + collate_fn=BatchCollator(feature_map)) def __len__(self): return self.num_batches @@ -89,3 +90,16 @@ def count_batches_and_samples(self): num_samples += block_size num_batches = int(np.ceil(num_samples / self.batch_size)) return num_batches, num_samples + + +class BatchCollator(object): + def __init__(self, feature_map): + self.feature_map = feature_map + + def __call__(self, batch): + batch_tensor = default_collate(batch) + all_cols = list(self.feature_map.features.keys()) + self.feature_map.labels + batch_dict = dict() + for col in all_cols: + batch_dict[col] = batch_tensor[:, self.feature_map.get_column_index(col)] + return batch_dict diff --git a/fuxictr/pytorch/dataloaders/npz_dataloader.py b/fuxictr/pytorch/dataloaders/npz_dataloader.py index e5740a1..3822e76 100644 --- a/fuxictr/pytorch/dataloaders/npz_dataloader.py +++ b/fuxictr/pytorch/dataloaders/npz_dataloader.py @@ -17,7 +17,7 @@ import numpy as np from torch.utils import data -import torch +from torch.utils.data.dataloader import default_collate class Dataset(data.Dataset): @@ -41,8 +41,7 @@ def load_data(self, data_path): data_arrays.append(array.reshape(-1, 1)) else: data_arrays.append(array) - data_tensor = torch.from_numpy(np.hstack(data_arrays)) - return data_tensor + return np.hstack(data_arrays) class NpzDataLoader(data.DataLoader): @@ -51,10 +50,24 @@ def __init__(self, feature_map, data_path, batch_size=32, shuffle=False, num_wor data_path += ".npz" self.dataset = Dataset(feature_map, data_path) super(NpzDataLoader, self).__init__(dataset=self.dataset, batch_size=batch_size, - shuffle=shuffle, num_workers=num_workers) + shuffle=shuffle, num_workers=num_workers, + collate_fn=BatchCollator(feature_map)) self.num_samples = len(self.dataset) self.num_blocks = 1 self.num_batches = int(np.ceil(self.num_samples * 1.0 / self.batch_size)) def __len__(self): return self.num_batches + + +class BatchCollator(object): + def __init__(self, feature_map): + self.feature_map = feature_map + + def __call__(self, batch): + batch_tensor = default_collate(batch) + all_cols = list(self.feature_map.features.keys()) + self.feature_map.labels + batch_dict = dict() + for col in all_cols: + batch_dict[col] = batch_tensor[:, self.feature_map.get_column_index(col)] + return batch_dict diff --git a/fuxictr/pytorch/dataloaders/parquet_block_dataloader.py b/fuxictr/pytorch/dataloaders/parquet_block_dataloader.py new file mode 100644 index 0000000..8abb234 --- /dev/null +++ b/fuxictr/pytorch/dataloaders/parquet_block_dataloader.py @@ -0,0 +1,79 @@ +# ========================================================================= +# Copyright (C) 2023-2024. FuxiCTR Authors. All rights reserved. +# Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + + +import numpy as np +from itertools import chain +from torch.utils import data +import glob +import polars as pl +import pandas as pd + + +class IterDataPipe(data.IterDataPipe): + def __init__(self, data_blocks, feature_map): + self.feature_map = feature_map + self.data_blocks = data_blocks + + def read_block(self, data_block): + data_df = pd.read_parquet(data_block) + for idx in range(len(data_df)): + yield data_df.iloc[idx].to_dict() + + def __iter__(self): + worker_info = data.get_worker_info() + if worker_info is None: # single-process data loading + block_list = self.data_blocks + else: # in a worker process + block_list = [ + block + for idx, block in enumerate(self.data_blocks) + if idx % worker_info.num_workers == worker_info.id + ] + return chain.from_iterable(map(self.read_block, block_list)) + + +class ParquetBlockDataLoader(data.DataLoader): + def __init__(self, feature_map, data_path, batch_size=32, shuffle=False, + num_workers=1, buffer_size=100000, **kwargs): + data_blocks = glob.glob(data_path + "/*.parquet") + assert len(data_blocks) > 0, f"invalid data_path: {data_path}" + if len(data_blocks) > 1: + data_blocks.sort() # sort by part name + self.data_blocks = data_blocks + self.num_blocks = len(self.data_blocks) + self.feature_map = feature_map + self.batch_size = batch_size + self.num_batches, self.num_samples = self.count_batches_and_samples() + datapipe = IterDataPipe(self.data_blocks, feature_map) + if shuffle: + datapipe = datapipe.shuffle(buffer_size=buffer_size) + else: + num_workers = 1 # multiple workers cannot keep the order of data reading + super().__init__(dataset=datapipe, batch_size=batch_size, + num_workers=num_workers) + + def __len__(self): + return self.num_batches + + def count_batches_and_samples(self): + num_samples = 0 + for data_block in self.data_blocks: + df = pl.scan_parquet(data_block) + num_samples += df.select(pl.count()).collect().item() + num_batches = int(np.ceil(num_samples / self.batch_size)) + return num_batches, num_samples diff --git a/fuxictr/pytorch/dataloaders/parquet_dataloader.py b/fuxictr/pytorch/dataloaders/parquet_dataloader.py new file mode 100644 index 0000000..a4612fb --- /dev/null +++ b/fuxictr/pytorch/dataloaders/parquet_dataloader.py @@ -0,0 +1,47 @@ +# ========================================================================= +# Copyright (C) 2024. FuxiCTR Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + + +import numpy as np +from torch.utils import data +import pandas as pd + + +class Dataset(data.Dataset): + def __init__(self, feature_map, data_path): + self.feature_map = feature_map + self.data_df = pd.read_parquet(data_path) + + def __getitem__(self, index): + return self.data_df.iloc[index].to_dict() + + def __len__(self): + return len(self.data_df) + + +class ParquetDataLoader(data.DataLoader): + def __init__(self, feature_map, data_path, batch_size=32, shuffle=False, num_workers=1, **kwargs): + if not data_path.endswith(".parquet"): + data_path += ".parquet" + self.dataset = Dataset(feature_map, data_path) + super().__init__(dataset=self.dataset, batch_size=batch_size, + shuffle=shuffle, num_workers=num_workers) + self.num_samples = len(self.dataset) + self.num_blocks = 1 + self.num_batches = int(np.ceil(self.num_samples / self.batch_size)) + + def __len__(self): + return self.num_batches diff --git a/fuxictr/pytorch/dataloaders/rank_dataloader.py b/fuxictr/pytorch/dataloaders/rank_dataloader.py index 1ca830f..ea2c425 100644 --- a/fuxictr/pytorch/dataloaders/rank_dataloader.py +++ b/fuxictr/pytorch/dataloaders/rank_dataloader.py @@ -17,17 +17,24 @@ from .npz_block_dataloader import NpzBlockDataLoader from .npz_dataloader import NpzDataLoader +from .parquet_block_dataloader import ParquetBlockDataLoader +from .parquet_dataloader import ParquetDataLoader import logging class RankDataLoader(object): def __init__(self, feature_map, stage="both", train_data=None, valid_data=None, test_data=None, - batch_size=32, shuffle=True, streaming=False, **kwargs): + batch_size=32, shuffle=True, streaming=False, data_format="npz", **kwargs): logging.info("Loading datasets...") train_gen = None valid_gen = None test_gen = None - DataLoader = NpzBlockDataLoader if streaming else NpzDataLoader + if data_format == "npz": + DataLoader = NpzBlockDataLoader if streaming else NpzDataLoader + elif data_format == "parquet": + DataLoader = ParquetBlockDataLoader if streaming else ParquetDataLoader + else: + raise ValueError(f"data_format={data_format} not supported.") self.stage = stage if stage in ["both", "train"]: train_gen = DataLoader(feature_map, train_data, batch_size=batch_size, shuffle=shuffle, **kwargs) diff --git a/fuxictr/pytorch/layers/embeddings/pretrained_embedding.py b/fuxictr/pytorch/layers/embeddings/pretrained_embedding.py index 488333f..9ed7bb8 100644 --- a/fuxictr/pytorch/layers/embeddings/pretrained_embedding.py +++ b/fuxictr/pytorch/layers/embeddings/pretrained_embedding.py @@ -17,12 +17,12 @@ import torch from torch import nn -import h5py import os import io import json import numpy as np import logging +from ...utils import load_pretrain_emb class PretrainedEmbedding(nn.Module): @@ -66,17 +66,6 @@ def reset_parameters(self, embedding_initializer): nn.init.zeros_(self.id_embedding.weight) # set oov token embeddings to zeros embedding_initializer(self.id_embedding.weight[1:self.oov_idx, :]) - def get_pretrained_embedding(self, pretrain_path): - logging.info("Loading pretrained_emb: {}".format(pretrain_path)) - if pretrain_path.endswith("h5"): - with h5py.File(pretrain_path, 'r') as hf: - keys = hf["key"][:] - embeddings = hf["value"][:] - elif pretrain_path.endswith("npz"): - npz = np.load(pretrain_path) - keys, embeddings = npz["key"], npz["value"] - return keys, embeddings - def load_feature_vocab(self, vocab_path, feature_name): with io.open(vocab_path, "r", encoding="utf-8") as fd: vocab = json.load(fd) @@ -94,7 +83,8 @@ def load_pretrained_embedding(self, vocab_size, pretrain_dim, pretrain_path, voc embedding_matrix = np.random.normal(loc=0, scale=1.e-4, size=(vocab_size, pretrain_dim)) if padding_idx: embedding_matrix[padding_idx, :] = np.zeros(pretrain_dim) # set as zero for PAD - keys, embeddings = self.get_pretrained_embedding(pretrain_path) + logging.info("Loading pretrained_emb: {}".format(pretrain_path)) + keys, embeddings = load_pretrain_emb(pretrain_path, keys=["key", "value"]) assert embeddings.shape[-1] == pretrain_dim, f"pretrain_dim={pretrain_dim} not correct." vocab, vocab_type = self.load_feature_vocab(vocab_path, feature_name) keys = keys.astype(vocab_type) # ensure the same dtype between pretrained keys and vocab keys diff --git a/fuxictr/pytorch/models/rank_model.py b/fuxictr/pytorch/models/rank_model.py index 2a5da95..88b5b76 100644 --- a/fuxictr/pytorch/models/rank_model.py +++ b/fuxictr/pytorch/models/rank_model.py @@ -1,4 +1,5 @@ # ========================================================================= +# Copyright (C) 2023. FuxiCTR Authors. All rights reserved. # Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -110,18 +111,18 @@ def get_inputs(self, inputs, feature_source=None): continue if spec["type"] == "meta": continue - X_dict[feature] = inputs[:, self.feature_map.get_column_index(feature)].to(self.device) + X_dict[feature] = inputs[feature].to(self.device) return X_dict def get_labels(self, inputs): """ Please override get_labels() when using multiple labels! """ labels = self.feature_map.labels - y = inputs[:, self.feature_map.get_column_index(labels[0])].to(self.device) + y = inputs[labels[0]].to(self.device) return y.float().view(-1, 1) def get_group_id(self, inputs): - return inputs[:, self.feature_map.get_column_index(self.feature_map.group_id)] + return inputs[self.feature_map.group_id] def model_to_device(self): self.to(device=self.device) diff --git a/fuxictr/utils.py b/fuxictr/utils.py index 0011ff6..7659327 100644 --- a/fuxictr/utils.py +++ b/fuxictr/utils.py @@ -20,6 +20,9 @@ import yaml import glob import json +import h5py +import numpy as np +import pandas as pd from collections import OrderedDict @@ -90,6 +93,7 @@ def print_to_json(data, sort_keys=True): def print_to_list(data): return ' - '.join('{}: {:.6f}'.format(k, v) for k, v in data.items()) + class Monitor(object): def __init__(self, kv): if isinstance(kv, str): @@ -104,3 +108,20 @@ def get_value(self, logs): def get_metrics(self): return list(self.kv_pairs.keys()) + + +def load_pretrain_emb(pretrain_path, keys=["key", "value"]): + if type(keys) != list: + keys = [keys] + if pretrain_path.endswith("h5"): + with h5py.File(pretrain_path, 'r') as hf: + values = [hf[k][:] for k in keys] + elif pretrain_path.endswith("npz"): + npz = np.load(pretrain_path) + values = [npz[k] for k in keys] + elif pretrain_path.endswith("parquet"): + df = pd.read_parquet(pretrain_path) + values = [df[k].values for k in keys] + else: + raise ValueError(f"Embedding format not supported: {pretrain_path}") + return tuple(values) diff --git a/fuxictr/version.py b/fuxictr/version.py index 6f43348..1108fcc 100644 --- a/fuxictr/version.py +++ b/fuxictr/version.py @@ -1 +1 @@ -__version__="2.2.3" +__version__="2.3.0" diff --git a/requirements.txt b/requirements.txt index 6711007..eb1bba3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,6 @@ -torch keras_preprocessing -PyYAML pandas +PyYAML scikit-learn numpy h5py diff --git a/setup.py b/setup.py index 89bab60..9f141d1 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name="fuxictr", - version="2.2.3", + version="2.3.0", author="RECZOO", author_email="reczoo@users.noreply.github.com", description="A configurable, tunable, and reproducible library for CTR prediction", @@ -17,7 +17,8 @@ exclude=["model_zoo", "tests", "data", "docs", "demo"]), include_package_data=True, python_requires=">=3.6", - install_requires=["pandas", "numpy", "h5py", "PyYAML>=5.1", "scikit-learn", "tqdm"], + install_requires=["keras_preprocessing", "pandas", "PyYAML>=5.1", "scikit-learn", + "numpy", "h5py", "tqdm", "pyarrow", "polars"], classifiers=( "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", @@ -25,7 +26,6 @@ 'Intended Audience :: Education', 'Intended Audience :: Science/Research', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.7', 'Topic :: Scientific/Engineering', 'Topic :: Scientific/Engineering :: Artificial Intelligence', 'Topic :: Software Development',