diff --git a/xtuner/_lite/datasets/cache.py b/xtuner/_lite/datasets/cache.py
index 9de84eab7..a46ef47ce 100644
--- a/xtuner/_lite/datasets/cache.py
+++ b/xtuner/_lite/datasets/cache.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import hashlib
+import os
 import tempfile
 from abc import ABC, abstractmethod
 from pathlib import Path
@@ -7,6 +8,8 @@
 
 from transformers import PreTrainedTokenizer
 
+from xtuner._lite import get_logger
+
 
 class CacheObj(TypedDict, total=False):
     num_tokens: int
@@ -23,8 +26,15 @@ def hash(self) -> str:
 
 
 def calculate_file_sha256(path):
+    hash_method = os.environ.get("HASH_METHOD", "sha256")
+    assert (
+        hash_method in hashlib.algorithms_guaranteed
+    ), f"hash method {hash_method} not supported"
+    if hash_method != "sha256":
+        get_logger().warning(f"Non-default hash method {hash_method} is used")
+    hash_method = getattr(hashlib, hash_method)
     with open(path, "rb") as f:
-        file_hash = hashlib.sha256()
+        file_hash = hash_method()
         file_hash.update(f.read())
 
     return file_hash.hexdigest()
diff --git a/xtuner/_lite/datasets/jsonl.py b/xtuner/_lite/datasets/jsonl.py
index 1e96e5114..52f78fa95 100644
--- a/xtuner/_lite/datasets/jsonl.py
+++ b/xtuner/_lite/datasets/jsonl.py
@@ -8,6 +8,7 @@
 import signal
 import sys
 from concurrent.futures import ProcessPoolExecutor
+from functools import partial
 from typing import Any, Callable
 
 import numpy as np
@@ -137,11 +138,14 @@ def count_offsets(self, cache_dir=None):
 
         return offsets
 
-    def _tokenize_by_offset(self, offset):
+    def _tokenize_by_offset(self, offset, only_num_tokens=False):
         with open(self.path) as f:
             f.seek(offset)
             data = json.loads(f.readline())
-        return self.tokenize_fn(data)
+        tokenize = self.tokenize_fn(data)
+        if only_num_tokens:
+            tokenize = {"num_tokens": tokenize["num_tokens"]}
+        return tokenize
 
     def count_tokens(self, offsets, cache_dir=None):
         num_samples = len(offsets)
@@ -169,7 +173,9 @@ def count_tokens(self, offsets, cache_dir=None):
             tokenized = list(
                 tqdm(
                     executor.map(
-                        self._tokenize_by_offset, offsets_shard, chunksize=chunk_size
+                        partial(self._tokenize_by_offset, only_num_tokens=True),
+                        offsets_shard,
+                        chunksize=chunk_size,
                     ),
                     desc=desc,
                     total=len(offsets_shard),
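
For context (not part of the patch): a minimal, standalone sketch of the behaviour the cache.py change introduces, namely that the file-fingerprinting algorithm is read from the HASH_METHOD environment variable and defaults to sha256. The helper name calculate_file_hash and the temporary file below are illustrative assumptions; the patch itself modifies calculate_file_sha256 in place.

# Minimal sketch, assuming only standard-library behaviour.
import hashlib
import os
import tempfile


def calculate_file_hash(path):
    # Mirror of the patched logic: pick the algorithm from the environment,
    # restricted to hashlib's guaranteed algorithms, defaulting to sha256.
    hash_method = os.environ.get("HASH_METHOD", "sha256")
    assert hash_method in hashlib.algorithms_guaranteed, f"hash method {hash_method} not supported"
    file_hash = getattr(hashlib, hash_method)()
    with open(path, "rb") as f:
        file_hash.update(f.read())
    return file_hash.hexdigest()


if __name__ == "__main__":
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        tmp.write(b"hello")
        path = tmp.name

    os.environ["HASH_METHOD"] = "md5"  # e.g. a faster, non-default digest for cache keys
    print(calculate_file_hash(path))  # prints the md5 hex digest of b"hello"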