fix enable_cache #4075

Open · wants to merge 10 commits into main
138 changes: 55 additions & 83 deletions swift/llm/dataset/utils.py
@@ -109,106 +109,64 @@ def __len__(self) -> int:
return len(self.dataset)


-class BasePackingDataset:
+def calculate_matched_group(template, sequences, is_finished: bool = True):
+    if len(sequences) == 0:
+        return [], []
+    # https://arxiv.org/pdf/2404.10830
+    import binpacking
+    sequences = binpacking.to_constant_volume(sequences, template.max_length, weight_pos=1)
+    res = []
+    if sequences and not is_finished:
+        sequences, ret_sequences = sequences[:-1], sequences[-1]
+    else:
+        ret_sequences = []
+    for row in sequences:
+        packed = template.packing_row(row)
+        res.append(packed)
+    return res, ret_sequences
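For reference, a minimal standalone sketch of the bin-packing call this helper wraps. The (item, length) tuples mirror the (data, len(input_ids)) pairs built by the packing loop below; the sample values and bin order are illustrative only:

import binpacking

# (sample, token_count) pairs, shaped like the tuples appended in get_packed_dataset.
sequences = [('a', 60), ('b', 50), ('c', 40), ('d', 30)]
bins = binpacking.to_constant_volume(sequences, 100, weight_pos=1)
# Each bin's summed length stays within the 100-token budget, e.g.:
# [[('a', 60), ('d', 30)], [('b', 50), ('c', 40)]]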


+class PackingDataset(Dataset):

def __init__(self, template, dataset, num_proc: int = 1, *, packing_interval: int = 128, strict: bool = False):
template._packing = True
self.template = template
self.dataset = dataset
self.num_proc = num_proc
self.packing_interval = packing_interval
self.strict = strict
assert num_proc >= 1, f'num_proc: {num_proc}'
self.workers = []

-    @staticmethod
-    def calculate_matched_group(template, sequences, is_finished: bool = True):
-        if len(sequences) == 0:
-            return [], []
-        # https://arxiv.org/pdf/2404.10830
-        import binpacking
-        sequences = binpacking.to_constant_volume(sequences, template.max_length, weight_pos=1)
-        res = []
-        if sequences and not is_finished:
-            sequences, ret_sequences = sequences[:-1], sequences[-1]
-        else:
-            ret_sequences = []
-        for row in sequences:
-            packed = template.packing_row(row)
-            res.append(packed)
-        return res, ret_sequences
-
-    def _encode_data(self, data) -> Dict[str, Any]:
-        res = None
-        try:
-            res = self.template.encode(data)
-        except Exception as e:
-            if self.strict and not isinstance(e, MaxLengthError):
-                raise
-        return res or {}


-class PackingDataset(BasePackingDataset, Dataset):
-
-    def __init__(self, template, dataset, num_proc: int = 1, *, packing_interval: int = 128, strict: bool = False):
-        super().__init__(template, dataset, num_proc, packing_interval=packing_interval, strict=strict)
-        self.prog_bar = tqdm(total=len(dataset), dynamic_ncols=True, desc='Packing')
-        self._queue = mp.Queue()
-        self._terminated_workers = 0
-        for i in range(self.num_proc):
-            shard_dataset = self.dataset.shard(self.num_proc, i)
-            worker = mp.Process(target=self._producer, args=(shard_dataset, ), daemon=True)
-            worker.start()
-            self.workers.append(worker)
-
-        self.packed_dataset = self.get_packed_dataset()
-        self.prog_bar.close()
-        for worker in self.workers:
-            worker.terminate()
-
-    def fetch_packing_data(self, res: Optional[list] = None):
-        res = res or []
-        for _ in range(self.packing_interval):
-            data = self._queue.get()
-            if data is None:
-                self._terminated_workers += 1
-                if self._terminated_workers == self.num_proc:
-                    break
-                continue
-            self.prog_bar.update(1)
-            if data:
-                res.append((data, len(data['input_ids'])))
-        return res
+        dataset = dataset.to_iterable_dataset(num_shards=num_proc)
+        dataset = EncodePreprocessor(template)(dataset, num_proc=num_proc, strict=strict)
+        self.packed_dataset = self.get_packed_dataset(dataset)
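Per the PR title, the motivation appears to be that routing encoding through EncodePreprocessor (a datasets.map-style pass) lets enable_cache take effect: HF datasets fingerprints the mapped function and reloads identical runs from its on-disk cache, which the old mp.Queue producers never could. A minimal sketch of that caching mechanism, assuming a hypothetical JSON file with a 'messages' column:

from datasets import load_dataset

ds = load_dataset('json', data_files='train.jsonl', split='train')  # hypothetical file
# The mapped function and its inputs are fingerprinted; rerunning the same map
# loads the result from the on-disk cache instead of recomputing it.
ds = ds.map(lambda row: {'n_turns': len(row['messages'])}, load_from_cache_file=True)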

-    def get_packed_dataset(self):
-        data = []
+    def get_packed_dataset(self, dataset):
+        data_list = []
        result = []
-        while True:
-            data = self.fetch_packing_data(data)
-            is_finished = self._terminated_workers == self.num_proc
-            res, data = self.calculate_matched_group(self.template, data, is_finished=is_finished)
+        it = iter(dataset)
+        is_finished = False
+        prog_bar = tqdm(total=len(self.dataset), dynamic_ncols=True, desc=f'Packing (num_proc={self.num_proc})')
+
+        while not is_finished:
+            try:
+                for _ in range(self.packing_interval):
+                    data = next(it)
+                    prog_bar.update(1)
+                    data_list.append((data, len(data['input_ids'])))
+            except StopIteration:
+                is_finished = True
+            res, data_list = calculate_matched_group(self.template, data_list, is_finished=is_finished)
            result += res
            if is_finished:
                break
+        prog_bar.close()
        return result

-    def _producer(self, shard_dataset):
-        for data in shard_dataset:
-            encoded_data = self._encode_data(data)
-            self._queue.put(encoded_data)
-        self._queue.put(None)
-        while True:
-            # Wait for the main process to terminate to avoid fd anomalies.
-            time.sleep(0.1)

def __getitem__(self, index):
return self.packed_dataset[index].copy()

def __len__(self):
return len(self.packed_dataset)
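A hypothetical usage sketch of the reworked class (template and dataset construction elided; constructor arguments as in this diff):

# template: a swift Template with max_length set; train_dataset: a datasets.Dataset
packed = PackingDataset(template, train_dataset, num_proc=4)
print(len(packed))       # number of packed rows
print(packed[0].keys())  # e.g. input_ids plus whatever the template emits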


-class IterablePackingDataset(BasePackingDataset, IterableDataset):
+class IterablePackingDataset(IterableDataset):

def __init__(self,
template,
@@ -218,16 +176,30 @@ def __init__(self,
packing_interval: int = 128,
strict: bool = False,
cyclic: bool = False):
-        super().__init__(template, dataset, num_proc, packing_interval=packing_interval, strict=strict)
+        self.template = template
+        self.dataset = dataset
+        self.num_proc = num_proc
+        self.packing_interval = packing_interval
+        self.strict = strict
+        self.cyclic = cyclic
+
        self._in_queue = mp.Queue()
        self._out_queue = mp.Queue()
        self.workers = []
-        self.cyclic = cyclic
for _ in range(self.num_proc):
worker = mp.Process(target=self._processor, daemon=True)
worker.start()
self.workers.append(worker)

+    def _encode_data(self, data) -> Dict[str, Any]:
+        res = None
+        try:
+            res = self.template.encode(data)
+        except Exception as e:
+            if self.strict and not isinstance(e, MaxLengthError):
+                raise
+        return res or {}
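Note the convention here: rows that fail to encode come back as {} (and with strict=True, anything other than MaxLengthError re-raises), and the `if data:` truthiness checks elsewhere in this file drop those empties. A tiny illustration:

encoded = [{'input_ids': [1, 2, 3]}, {}, {'input_ids': [4]}]
# Mirrors the `if data: res.append((data, len(data['input_ids'])))` pattern above.
kept = [(d, len(d['input_ids'])) for d in encoded if d]  # the empty dict is skipped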

def _processor(self):
while True:
data = self._in_queue.get()
@@ -276,7 +248,7 @@ def __iter__(self):
while True:
finished = self._put_data_in_queue(iterator)
data = self._fetch_data_out_queue(data)
-        res, data = self.calculate_matched_group(self.template, data, is_finished=finished)
+        res, data = calculate_matched_group(self.template, data, is_finished=finished)
yield from res
if finished:
break
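For orientation, a self-contained sketch of the in-queue/out-queue worker pattern IterablePackingDataset relies on (simplified: the stand-in worker doubles numbers where the real one calls template.encode):

import multiprocessing as mp

def worker(in_q, out_q):
    while True:
        item = in_q.get()
        if item is None:  # sentinel: no more input
            break
        out_q.put(item * 2)  # stand-in for the real encode step

if __name__ == '__main__':
    in_q, out_q = mp.Queue(), mp.Queue()
    proc = mp.Process(target=worker, args=(in_q, out_q), daemon=True)
    proc.start()
    for i in range(3):
        in_q.put(i)
    in_q.put(None)
    print(sorted(out_q.get() for _ in range(3)))  # [0, 2, 4]
    proc.join()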