
Update jsonl dataset #1027

Open · wants to merge 2 commits into dev
12 changes: 11 additions & 1 deletion xtuner/_lite/datasets/cache.py
@@ -1,12 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import hashlib
import os
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, TypedDict

from transformers import PreTrainedTokenizer

from xtuner._lite import get_logger


class CacheObj(TypedDict, total=False):
num_tokens: int
@@ -23,8 +26,15 @@ def hash(self) -> str:


def calculate_file_sha256(path):
hash_method = os.environ.get("HASH_METHOD", "sha256")
assert (
hash_method in hashlib.algorithms_guaranteed
), f"hash method {hash_method} not supported"
if hash_method != "sha256":
get_logger().warning(f"Non-default hash method {hash_method} is used")
hash_method = getattr(hashlib, hash_method)
with open(path, "rb") as f:
file_hash = hashlib.sha256()
file_hash = hash_method()
file_hash.update(f.read())
return file_hash.hexdigest()

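For reviewers, a minimal standalone sketch of what the patched helper does after this hunk. It mirrors the diff above, except that the standard logging module stands in for xtuner's get_logger(); the function keeps its original calculate_file_sha256 name even though the digest is now configurable.

# Standalone sketch of the patched helper; standard logging stands in for
# xtuner's get_logger(), everything else mirrors the diff above.
import hashlib
import logging
import os


def calculate_file_sha256(path):
    # The digest is chosen via the HASH_METHOD env var and defaults to sha256,
    # so existing sha256-keyed caches keep working unless it is overridden.
    hash_method = os.environ.get("HASH_METHOD", "sha256")
    assert (
        hash_method in hashlib.algorithms_guaranteed
    ), f"hash method {hash_method} not supported"
    if hash_method != "sha256":
        logging.warning("Non-default hash method %s is used", hash_method)
    file_hash = getattr(hashlib, hash_method)()
    with open(path, "rb") as f:
        file_hash.update(f.read())
    return file_hash.hexdigest()


# Example: HASH_METHOD=md5 trades collision resistance for faster
# fingerprinting of large .jsonl files when building the dataset cache.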
12 changes: 9 additions & 3 deletions xtuner/_lite/datasets/jsonl.py
@@ -8,6 +8,7 @@
import signal
import sys
from concurrent.futures import ProcessPoolExecutor
from functools import partial
from typing import Any, Callable

import numpy as np
@@ -137,11 +138,14 @@ def count_offsets(self, cache_dir=None):

return offsets

def _tokenize_by_offset(self, offset):
def _tokenize_by_offset(self, offset, only_num_tokens=False):
with open(self.path) as f:
f.seek(offset)
data = json.loads(f.readline())
return self.tokenize_fn(data)
tokenize = self.tokenize_fn(data)
if only_num_tokens:
tokenize = {"num_tokens": tokenize["num_tokens"]}
return tokenize

def count_tokens(self, offsets, cache_dir=None):
num_samples = len(offsets)
@@ -169,7 +173,9 @@ def count_tokens(self, offsets, cache_dir=None):
tokenized = list(
tqdm(
executor.map(
self._tokenize_by_offset, offsets_shard, chunksize=chunk_size
partial(self._tokenize_by_offset, only_num_tokens=True),
offsets_shard,
chunksize=chunk_size,
),
desc=desc,
total=len(offsets_shard),
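A hedged sketch of the second change, for context: _tokenize_by_offset can now return only the num_tokens field, so count_tokens ships far less data back from the worker processes, and functools.partial binds only_num_tokens=True so executor.map() still receives a single-argument callable. The names below (tokenize_by_offset, the fake sample) are illustrative stand-ins, not xtuner APIs.

# Illustrative stand-in for the count_tokens() pattern: partial() binds the
# keyword argument, and workers return only the small num_tokens dict instead
# of the full tokenized sample.
from concurrent.futures import ProcessPoolExecutor
from functools import partial


def tokenize_by_offset(offset, only_num_tokens=False):
    # Pretend the sample at this offset tokenizes to `offset` tokens with a
    # large input_ids payload attached.
    tokenized = {"input_ids": list(range(offset)), "num_tokens": offset}
    if only_num_tokens:
        return {"num_tokens": tokenized["num_tokens"]}
    return tokenized


if __name__ == "__main__":
    offsets = [10, 20, 30]
    with ProcessPoolExecutor(max_workers=2) as executor:
        counts = list(
            executor.map(partial(tokenize_by_offset, only_num_tokens=True), offsets)
        )
    print(counts)  # -> [{'num_tokens': 10}, {'num_tokens': 20}, {'num_tokens': 30}]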