Commit: Add parquet dataloader

xpai committed Apr 21, 2024
1 parent 776268d commit 281ed5d
Showing 15 changed files with 236 additions and 56 deletions.
16 changes: 13 additions & 3 deletions CHANGELOG.md
@@ -1,10 +1,20 @@
## FuxiCTR Versions

### FuxiCTR v2.2
### FuxiCTR v2.3
[Doing] Add support for saving pb file, exporting embeddings
[Doing] Add support of NVTabular data
[Doing] Add support of multi-gpu training

**FuxiCTR v2.3.0, 2024-04-20**
+ [Refactor] Support data format of npz and parquet

-------------------------------

### FuxiCTR v2.2

**FuxiCTR v2.2.3, 2024-04-20**
+ [Fix] Quick fix to v2.2.2, which missed one line when committing

**FuxiCTR v2.2.2, 2024-04-18**
**FuxiCTR v2.2.2, 2024-04-18 (Deprecated)**
+ [Feature] Update to use polars instead of pandas for faster feature processing
+ [Fix] When num_workers > 1, NpzBlockDataLoader cannot keep the reading order of samples ([#86](https://github.com/xue-pai/FuxiCTR/issues/86))

4 changes: 2 additions & 2 deletions README.md
@@ -3,7 +3,7 @@
</div>

<div align="center">
<a href="https://pypi.org/project/fuxictr"><img src="https://img.shields.io/badge/python-3.6+-blue" style="max-width: 100%;" alt="Python version"></a>
<a href="https://pypi.org/project/fuxictr"><img src="https://img.shields.io/badge/python-3.9+-blue" style="max-width: 100%;" alt="Python version"></a>
<a href="https://pypi.org/project/fuxictr"><img src="https://img.shields.io/badge/pytorch-1.10+-blue" style="max-width: 100%;" alt="Pytorch version"></a>
<a href="https://pypi.org/project/fuxictr"><img src="https://img.shields.io/badge/tensorflow-2.1+-blue" style="max-width: 100%;" alt="Pytorch version"></a>
<a href="https://pypi.org/project/fuxictr"><img src="https://img.shields.io/pypi/v/fuxictr.svg" style="max-width: 100%;" alt="Pypi version"></a>
@@ -102,7 +102,7 @@ We have benchmarked FuxiCTR models on a set of open datasets as follows:

FuxiCTR has the following dependencies:

+ python 3.6+
+ python 3.9+
+ pytorch 1.10+ (required only for Torch models)
+ tensorflow 2.1+ (required only for TF models)

10 changes: 6 additions & 4 deletions fuxictr/preprocess/feature_processor.py
@@ -71,10 +71,12 @@ def read_csv(self, data_path, sep=",", n_rows=None, **kwargs):
logging.info("Reading file: " + data_path)
file_names = sorted(glob.glob(data_path))
assert len(file_names) > 0, f"Invalid data path: {data_path}"
# Requires python >= 3.8 to use polars to scan multiple csv files
file_names = file_names[0]
ddf = pl.scan_csv(source=file_names, separator=sep, dtypes=self.dtype_dict,
low_memory=False, n_rows=n_rows)
dfs = [
pl.scan_csv(source=file_name, separator=sep, dtypes=self.dtype_dict,
low_memory=False, n_rows=n_rows)
for file_name in file_names
]
ddf = pl.concat(dfs)
return ddf

def preprocess(self, ddf):
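
The refactored read_csv() above no longer scans only the first CSV part: it opens one lazy scan per file and concatenates them into a single LazyFrame. A minimal standalone sketch of that pattern, with hypothetical toy files and columns that are not part of this commit:

import glob
import polars as pl

# Two toy CSV parts standing in for multi-part training data (hypothetical paths).
pl.DataFrame({"user_id": ["u1", "u2"], "click": [1, 0]}).write_csv("part_0.csv")
pl.DataFrame({"user_id": ["u3"], "click": [1]}).write_csv("part_1.csv")

file_names = sorted(glob.glob("part_*.csv"))
dfs = [
    pl.scan_csv(source=f, separator=",", low_memory=False)  # lazy: nothing is read yet
    for f in file_names
]
ddf = pl.concat(dfs)        # one LazyFrame spanning all parts
print(ddf.collect().shape)  # (3, 2): rows from every part, read only at collect() time
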
11 changes: 4 additions & 7 deletions fuxictr/preprocess/tokenizer.py
@@ -22,6 +22,7 @@
from keras_preprocessing.sequence import pad_sequences
from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing as mp
from ..utils import load_pretrain_emb


class Tokenizer(object):
@@ -125,13 +126,9 @@ def encode_sequence(self, series):
return np.array(seqs)

def load_pretrained_vocab(self, feature_dtype, pretrain_path, expand_vocab=True):
if pretrain_path.endswith(".h5"):
with h5py.File(pretrain_path, 'r') as hf:
keys = hf["key"][:]
# in case mismatch of dtype between int and str
keys = keys.astype(feature_dtype)
elif pretrain_path.endswith(".npz"):
keys = np.load(pretrain_path)["key"]
keys = load_pretrain_emb(pretrain_path, keys=["key"])
# in case mismatch of dtype between int and str
keys = keys.astype(feature_dtype)
# Update vocab with pretrained keys in case new tokens appear in validation or test set
# Do NOT update OOV index here since it is used in PretrainedEmbedding
if expand_vocab:
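
The format-specific branches removed above are now delegated to load_pretrain_emb imported from fuxictr.utils. Purely as an illustration, a helper with the call signature used here could look roughly like the sketch below, which mirrors the removed .h5/.npz branches; the actual implementation in fuxictr.utils may differ and may support additional formats.

# Illustrative reconstruction only -- not the code added by this commit.
import h5py
import numpy as np

def load_pretrain_emb(pretrain_path, keys=["key"]):
    values = []
    if pretrain_path.endswith(".h5"):
        with h5py.File(pretrain_path, "r") as hf:
            for k in keys:
                values.append(hf[k][:])
    elif pretrain_path.endswith(".npz"):
        npz = np.load(pretrain_path)
        for k in keys:
            values.append(npz[k])
    else:
        raise ValueError(f"Unsupported pretrain file: {pretrain_path}")
    return values[0] if len(values) == 1 else values
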
38 changes: 26 additions & 12 deletions fuxictr/pytorch/dataloaders/npz_block_dataloader.py
@@ -18,15 +18,15 @@

import numpy as np
from itertools import chain
import torch
from torch.utils.data.dataloader import default_collate
from torch.utils import data
import glob


class BlockDataPipe(data.IterDataPipe):
def __init__(self, block_datapipe, feature_map):
class IterDataPipe(data.IterDataPipe):
def __init__(self, data_blocks, feature_map):
self.feature_map = feature_map
self.block_datapipe = block_datapipe
self.data_blocks = data_blocks

def load_data(self, data_path):
data_dict = np.load(data_path)
@@ -38,8 +38,7 @@ def load_data(self, data_path):
data_arrays.append(array.reshape(-1, 1))
else:
data_arrays.append(array)
data_tensor = torch.from_numpy(np.hstack(data_arrays))
return data_tensor
return np.hstack(data_arrays)

def read_block(self, data_block):
darray = self.load_data(data_block)
@@ -49,11 +48,11 @@ def __iter__(self):
def __iter__(self):
worker_info = data.get_worker_info()
if worker_info is None: # single-process data loading
block_list = self.block_datapipe
block_list = self.data_blocks
else: # in a worker process
block_list = [
block
for idx, block in enumerate(self.block_datapipe)
for idx, block in enumerate(self.data_blocks)
if idx % worker_info.num_workers == worker_info.id
]
return chain.from_iterable(map(self.read_block, block_list))
@@ -71,13 +70,15 @@ def __init__(self, feature_map, data_path, batch_size=32, shuffle=False,
self.feature_map = feature_map
self.batch_size = batch_size
self.num_batches, self.num_samples = self.count_batches_and_samples()
datapipe = BlockDataPipe(data_blocks, feature_map)
datapipe = IterDataPipe(self.data_blocks, feature_map)
if shuffle:
datapipe = datapipe.shuffle(buffer_size=buffer_size)
else:
num_workers = 1 # multiple workers cannot keep the order of data reading
super(NpzBlockDataLoader, self).__init__(dataset=datapipe, batch_size=batch_size,
num_workers=num_workers)
num_workers = 1 # multiple workers cannot keep the order of data reading
super(NpzBlockDataLoader, self).__init__(dataset=datapipe,
batch_size=batch_size,
num_workers=num_workers,
collate_fn=BatchCollator(feature_map))

def __len__(self):
return self.num_batches
@@ -89,3 +90,16 @@ def count_batches_and_samples(self):
num_samples += block_size
num_batches = int(np.ceil(num_samples / self.batch_size))
return num_batches, num_samples


class BatchCollator(object):
    def __init__(self, feature_map):
        self.feature_map = feature_map

    def __call__(self, batch):
        batch_tensor = default_collate(batch)
        all_cols = list(self.feature_map.features.keys()) + self.feature_map.labels
        batch_dict = dict()
        for col in all_cols:
            batch_dict[col] = batch_tensor[:, self.feature_map.get_column_index(col)]
        return batch_dict
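
With the block loader now yielding raw NumPy rows, BatchCollator splits the default-collated 2-D tensor back into per-column tensors keyed by feature name. A toy run of that collation logic, using a hypothetical stand-in for FuxiCTR's FeatureMap:

import numpy as np
from torch.utils.data.dataloader import default_collate

class FakeFeatureMap:  # stand-in for fuxictr's FeatureMap, for demonstration only
    features = {"user_id": {}, "item_id": {}}
    labels = ["click"]
    def get_column_index(self, col):
        return ["user_id", "item_id", "click"].index(col)

fm = FakeFeatureMap()
rows = [np.array([1, 10, 0]), np.array([2, 20, 1])]   # two samples, three columns each
batch_tensor = default_collate(rows)                   # tensor of shape (2, 3)
batch_dict = {col: batch_tensor[:, fm.get_column_index(col)]
              for col in list(fm.features.keys()) + fm.labels}
print(batch_dict["click"])                             # tensor([0, 1])
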
21 changes: 17 additions & 4 deletions fuxictr/pytorch/dataloaders/npz_dataloader.py
@@ -17,7 +17,7 @@

import numpy as np
from torch.utils import data
import torch
from torch.utils.data.dataloader import default_collate


class Dataset(data.Dataset):
@@ -41,8 +41,7 @@ def load_data(self, data_path):
data_arrays.append(array.reshape(-1, 1))
else:
data_arrays.append(array)
data_tensor = torch.from_numpy(np.hstack(data_arrays))
return data_tensor
return np.hstack(data_arrays)


class NpzDataLoader(data.DataLoader):
@@ -51,10 +50,24 @@ def __init__(self, feature_map, data_path, batch_size=32, shuffle=False, num_wor
data_path += ".npz"
self.dataset = Dataset(feature_map, data_path)
super(NpzDataLoader, self).__init__(dataset=self.dataset, batch_size=batch_size,
shuffle=shuffle, num_workers=num_workers)
shuffle=shuffle, num_workers=num_workers,
collate_fn=BatchCollator(feature_map))
self.num_samples = len(self.dataset)
self.num_blocks = 1
self.num_batches = int(np.ceil(self.num_samples * 1.0 / self.batch_size))

def __len__(self):
return self.num_batches


class BatchCollator(object):
    def __init__(self, feature_map):
        self.feature_map = feature_map

    def __call__(self, batch):
        batch_tensor = default_collate(batch)
        all_cols = list(self.feature_map.features.keys()) + self.feature_map.labels
        batch_dict = dict()
        for col in all_cols:
            batch_dict[col] = batch_tensor[:, self.feature_map.get_column_index(col)]
        return batch_dict
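
Both npz loaders rely on load_data() hstacking each feature array (1-D arrays reshaped to columns) into one row-major matrix, so __getitem__ only has to index a row. A small illustrative sketch of that layout with hypothetical feature names:

import numpy as np

data_dict = {
    "user_id": np.array([1, 2, 3]),                     # 1-D feature -> one column
    "hist_items": np.array([[4, 5], [6, 7], [8, 9]]),   # pre-padded sequence feature
    "click": np.array([0, 1, 0]),                       # label column
}
data_arrays = [arr.reshape(-1, 1) if arr.ndim == 1 else arr for arr in data_dict.values()]
darray = np.hstack(data_arrays)  # shape (3, 4): one row per sample
print(darray[1])                 # [2 6 7 1], the second sample across all columns
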
79 changes: 79 additions & 0 deletions fuxictr/pytorch/dataloaders/parquet_block_dataloader.py
@@ -0,0 +1,79 @@
# =========================================================================
# Copyright (C) 2023-2024. FuxiCTR Authors. All rights reserved.
# Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================


import numpy as np
from itertools import chain
from torch.utils import data
import glob
import polars as pl
import pandas as pd


class IterDataPipe(data.IterDataPipe):
    def __init__(self, data_blocks, feature_map):
        self.feature_map = feature_map
        self.data_blocks = data_blocks

    def read_block(self, data_block):
        data_df = pd.read_parquet(data_block)
        for idx in range(len(data_df)):
            yield data_df.iloc[idx].to_dict()

    def __iter__(self):
        worker_info = data.get_worker_info()
        if worker_info is None:  # single-process data loading
            block_list = self.data_blocks
        else:  # in a worker process
            block_list = [
                block
                for idx, block in enumerate(self.data_blocks)
                if idx % worker_info.num_workers == worker_info.id
            ]
        return chain.from_iterable(map(self.read_block, block_list))


class ParquetBlockDataLoader(data.DataLoader):
    def __init__(self, feature_map, data_path, batch_size=32, shuffle=False,
                 num_workers=1, buffer_size=100000, **kwargs):
        data_blocks = glob.glob(data_path + "/*.parquet")
        assert len(data_blocks) > 0, f"invalid data_path: {data_path}"
        if len(data_blocks) > 1:
            data_blocks.sort()  # sort by part name
        self.data_blocks = data_blocks
        self.num_blocks = len(self.data_blocks)
        self.feature_map = feature_map
        self.batch_size = batch_size
        self.num_batches, self.num_samples = self.count_batches_and_samples()
        datapipe = IterDataPipe(self.data_blocks, feature_map)
        if shuffle:
            datapipe = datapipe.shuffle(buffer_size=buffer_size)
        else:
            num_workers = 1  # multiple workers cannot keep the order of data reading
        super().__init__(dataset=datapipe, batch_size=batch_size,
                         num_workers=num_workers)

    def __len__(self):
        return self.num_batches

    def count_batches_and_samples(self):
        num_samples = 0
        for data_block in self.data_blocks:
            df = pl.scan_parquet(data_block)
            num_samples += df.select(pl.count()).collect().item()
        num_batches = int(np.ceil(num_samples / self.batch_size))
        return num_batches, num_samples
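
Note that ParquetBlockDataLoader passes no custom collate_fn, so the per-row dicts yielded by read_block() are batched by PyTorch's default collation into a dict of column tensors. A runnable toy sketch of the block mechanics (the paths and columns below are hypothetical, not shipped with FuxiCTR):

import glob
import os
import numpy as np
import pandas as pd
import polars as pl

# Two toy parquet parts standing in for a preprocessed data-block directory.
os.makedirs("toy_parquet", exist_ok=True)
pd.DataFrame({"user_id": [1, 2], "click": [0, 1]}).to_parquet("toy_parquet/part_0.parquet")
pd.DataFrame({"user_id": [3], "click": [1]}).to_parquet("toy_parquet/part_1.parquet")
data_blocks = sorted(glob.glob("toy_parquet/*.parquet"))

# Same counting strategy as count_batches_and_samples(): lazy scan, then row count.
num_samples = sum(pl.scan_parquet(b).select(pl.count()).collect().item() for b in data_blocks)
num_batches = int(np.ceil(num_samples / 2))   # with batch_size = 2
print(num_samples, num_batches)               # 3 2

# Same per-row dicts that IterDataPipe.read_block() yields before collation.
first_block = pd.read_parquet(data_blocks[0])
print(first_block.iloc[0].to_dict())          # {'user_id': 1, 'click': 0}
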
47 changes: 47 additions & 0 deletions fuxictr/pytorch/dataloaders/parquet_dataloader.py
@@ -0,0 +1,47 @@
# =========================================================================
# Copyright (C) 2024. FuxiCTR Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================


import numpy as np
from torch.utils import data
import pandas as pd


class Dataset(data.Dataset):
    def __init__(self, feature_map, data_path):
        self.feature_map = feature_map
        self.data_df = pd.read_parquet(data_path)

    def __getitem__(self, index):
        return self.data_df.iloc[index].to_dict()

    def __len__(self):
        return len(self.data_df)


class ParquetDataLoader(data.DataLoader):
    def __init__(self, feature_map, data_path, batch_size=32, shuffle=False, num_workers=1, **kwargs):
        if not data_path.endswith(".parquet"):
            data_path += ".parquet"
        self.dataset = Dataset(feature_map, data_path)
        super().__init__(dataset=self.dataset, batch_size=batch_size,
                         shuffle=shuffle, num_workers=num_workers)
        self.num_samples = len(self.dataset)
        self.num_blocks = 1
        self.num_batches = int(np.ceil(self.num_samples / self.batch_size))

    def __len__(self):
        return self.num_batches
11 changes: 9 additions & 2 deletions fuxictr/pytorch/dataloaders/rank_dataloader.py
@@ -17,17 +17,24 @@

from .npz_block_dataloader import NpzBlockDataLoader
from .npz_dataloader import NpzDataLoader
from .parquet_block_dataloader import ParquetBlockDataLoader
from .parquet_dataloader import ParquetDataLoader
import logging


class RankDataLoader(object):
def __init__(self, feature_map, stage="both", train_data=None, valid_data=None, test_data=None,
batch_size=32, shuffle=True, streaming=False, **kwargs):
batch_size=32, shuffle=True, streaming=False, data_format="npz", **kwargs):
logging.info("Loading datasets...")
train_gen = None
valid_gen = None
test_gen = None
DataLoader = NpzBlockDataLoader if streaming else NpzDataLoader
if data_format == "npz":
DataLoader = NpzBlockDataLoader if streaming else NpzDataLoader
elif data_format == "parquet":
DataLoader = ParquetBlockDataLoader if streaming else ParquetDataLoader
else:
raise ValueError(f"data_format={data_format} not supported.")
self.stage = stage
if stage in ["both", "train"]:
train_gen = DataLoader(feature_map, train_data, batch_size=batch_size, shuffle=shuffle, **kwargs)
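
RankDataLoader now selects the loader class from data_format ("npz" or "parquet") together with the streaming flag. A hedged construction example; the dataset paths and the feature_map below are assumptions, not files provided by this commit:

from fuxictr.pytorch.dataloaders.rank_dataloader import RankDataLoader

# feature_map is assumed to be a FeatureMap built by FuxiCTR's preprocessing step.
loader = RankDataLoader(feature_map,
                        stage="train",
                        train_data="data/my_dataset/train",  # hypothetical directory of parquet parts
                        valid_data="data/my_dataset/valid",
                        batch_size=8192,
                        shuffle=True,
                        data_format="parquet",
                        streaming=True)  # streaming=True -> ParquetBlockDataLoader; False -> ParquetDataLoader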