Skip to content

Commit

Permalink
minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
soldni committed Oct 29, 2024
1 parent 2d92936 commit 46e65a6
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 33 deletions.
16 changes: 4 additions & 12 deletions classifiers/src/dolma_classifiers/inference.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import argparse
import multiprocessing as mp
import time
from collections import defaultdict
from functools import partial
from itertools import zip_longest
from multiprocessing import Event, Process
from queue import Empty
from queue import Queue as QueueType
from threading import Event as EventType
from typing import Any, Generator, NamedTuple
from urllib.parse import urlparse

Expand All @@ -25,7 +25,7 @@
from transformers import BatchEncoding, PreTrainedTokenizer

from .loggers import ProgressLogger, WandbLogger, get_logger
from .models import Prediction, Registry
from .models import Registry
from .utils import cleanup, get_local_gpu_rank, sanitize_model_name, setup


Expand Down Expand Up @@ -99,13 +99,11 @@ def __iter__(self) -> Generator[Batch, None, None]:
self.output_paths_queue.put(OutputPath(source=path, count=count))



def collate_batch(batch: list[Batch], pad_token_id: int) -> Batch:
max_lengths = [len(b.encoding['input_ids'][0]) for b in batch] # pyright: ignore
padded_encodings = {
key: pad_sequence(
# assuming first dimension is batch size
[b.encoding[key][-1,:] for b in batch], # pyright: ignore
[b.encoding[key][-1, :] for b in batch], # pyright: ignore
batch_first=True,
padding_value=pad_token_id,
)
Expand All @@ -119,14 +117,13 @@ def collate_batch(batch: list[Batch], pad_token_id: int) -> Batch:
)



class AttributeRow(NamedTuple):
sources: list[str]
attributes: list[dict[str, Any]]


def writer_worker(
error_event: Event,
error_event: EventType,
scores_queue: QueueType[AttributeRow | None],
output_paths_queue: QueueType[OutputPath],
source_destination_mapping: dict[str, str],
Expand Down Expand Up @@ -218,8 +215,6 @@ def process_documents(
suffix: str | None = None
):
"""Processes a batch of files using distributed processing."""
console_logger = get_logger("process_documents")


classifier = Registry.get(
model_name=model_name,
Expand All @@ -232,9 +227,6 @@ def process_documents(
# to check if destination path exists (file already processed)
fs = fsspec.get_filesystem_class(urlparse(source_paths[0]).scheme)()

# this encoder will be used to write the attributes to the destination file
encoder = msgspec.json.Encoder()

source_destination_mapping = {
source_path: destination_path
for source_path, destination_path in zip(source_paths, destination_paths)
Expand Down
12 changes: 6 additions & 6 deletions classifiers/src/dolma_classifiers/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from functools import partial
from typing import NamedTuple, Type

import torch
Expand All @@ -16,7 +15,7 @@
from transformers.modeling_outputs import SequenceClassifierOutput

from .loggers import get_logger
from .utils import get_local_gpu_rank, sanitize_model_name
from .utils import sanitize_model_name


class Prediction(NamedTuple):
Expand All @@ -43,12 +42,14 @@ def __init__(
compile=compile,
trust_remote_code=trust_remote_code,
)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.tokenizer = AutoTokenizer.from_pretrained(model_name) # pyright: ignore

if len(self.model.config.id2label) > 1:
label_name_fn = lambda label: f"{sanitize_model_name(model_name)}_{sanitize_model_name(label)}"
def label_name_fn(label: str):
return f"{sanitize_model_name(model_name)}_{sanitize_model_name(label)}"
else:
label_name_fn = lambda label: sanitize_model_name(model_name)
def label_name_fn(label: str):
return sanitize_model_name(model_name)

self.labels_map = {
id_: label_name_fn(label)
Expand Down Expand Up @@ -137,7 +138,6 @@ def forward(self, input_ids, attention_mask, **kwargs):
return SequenceClassifierOutput(logits=outputs[:, 0, :])



@Registry.add("nvidia/quality-classifier-deberta")
class DebertaQualityClassifier(BaseQualityClassifier):
def _make_model(
Expand Down
38 changes: 23 additions & 15 deletions classifiers/src/dolma_classifiers/train.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import multiprocessing
from dataclasses import dataclass
from typing import Callable, Generator, NamedTuple
from functools import partial
from typing import Callable
from urllib.parse import urlparse

import fsspec
import jq
import smart_open
import torch
from msgspec.json import Decoder
from torch.utils.data import Dataset
from tqdm import tqdm
Expand All @@ -18,15 +18,18 @@ class Document:
label: str


def read_file(path: str, label: str | None = None, selector: str | None = None) -> list[Document]:
def _label_selector_fn(row: dict, selector: Callable | None, label: str | None) -> str:
if selector is not None:
compiled_selector = jq.compile(selector)
label_fn = lambda row: str(compiled_selector.input(row).first())
return str(selector(row).first())
elif label is not None:
label_fn = lambda row: str(label)
return str(label)
else:
raise ValueError("Either `label` or `selector` must be provided")


def read_file(path: str, label: str | None = None, selector: str | None = None) -> list[Document]:
label_fn = partial(_label_selector_fn, label=label, selector=(jq.compile(selector) if selector else None))

decoder = Decoder()
documents = []

Expand All @@ -45,10 +48,12 @@ class DataConfig:
label: str | None = None
selector: str | None = None

def expand(self, fs: fsspec.AbstractFileSystem | None = None) -> list["DataConfig"]:
fs = fs or fsspec.get_filesystem_class(urlparse(self.path).scheme)()
paths = [str(p) for p in fs.glob(self.path)] if "*" in self.path else [self.path]
return [DataConfig(path=path, label=self.label, selector=self.selector) for path in paths]
@staticmethod
def expand(data_config: "DataConfig", fs: fsspec.AbstractFileSystem | None = None) -> list["DataConfig"]:
    """Turn one config into a list of configs, one per concrete path.

    When ``data_config.path`` contains a ``*`` it is globbed on ``fs`` and one
    config is produced per match; otherwise a single-element list holding a
    copy of the config is returned. ``label`` and ``selector`` are carried
    over unchanged to every produced config.

    Written as a static method taking the config explicitly; it is handed to
    ``pool.imap_unordered`` as ``DataConfig.expand`` elsewhere in this file.

    Args:
        data_config: The config whose path should be expanded.
        fs: Filesystem used for globbing. When falsy, one is instantiated from
            the URL scheme of ``data_config.path``.
    """
    pattern = data_config.path

    if not fs:
        # Pick the fsspec filesystem class registered for the path's scheme.
        fs = fsspec.get_filesystem_class(urlparse(pattern).scheme)()
    assert fs is not None, f"Could not determine filesystem for {data_config.path}"

    if "*" in pattern:
        matched_paths = [str(match) for match in fs.glob(pattern)]
    else:
        matched_paths = [pattern]

    return [
        DataConfig(path=matched, label=data_config.label, selector=data_config.selector)
        for matched in matched_paths
    ]


class ClassifierDataset(Dataset):
Expand All @@ -58,19 +63,22 @@ def __init__(
workers: int = 1,
):
with multiprocessing.Pool(workers) as pool:
expanded_configs = list(
tqdm(
pool.imap_unordered(lambda c: c.expand(), configs),
expanded_configs: list[DataConfig] = [
data_config
for data_configs in tqdm(
pool.imap_unordered(DataConfig.expand, configs),
total=len(configs),
desc="Expanding configs",
)
)
for data_config in data_configs
]

with multiprocessing.Pool(workers) as pool:
self.documents = list(
tqdm(
pool.imap_unordered(
lambda c: read_file(path=c.path, label=c.label, selector=c.selector), expanded_configs
lambda c: read_file(path=c.path, label=c.label, selector=c.selector),
expanded_configs
),
total=len(expanded_configs),
desc="Reading files",
Expand Down
3 changes: 3 additions & 0 deletions classifiers/src/dolma_classifiers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ def get_local_gpu_rank() -> int:
# Prepare the distributed runtime for this process and return the result of
# get_rank_and_world_size() (annotated as a tuple of two ints).
# NOTE(review): indentation is lost in this diff view, so it is not visible here
# whether the CUDA_VISIBLE_DEVICES assignment below sits inside the `if` -- confirm
# against the actual file.
def setup() -> tuple[int, int]:
# Join an NCCL process group only when both RANK and WORLD_SIZE are set to
# non-empty values in the environment.
if (rank := os.environ.get("RANK")) and (world_size := os.environ.get("WORLD_SIZE")):
dist.init_process_group("nccl", rank=int(rank), world_size=int(world_size))

# Presumably limits this process to the single GPU whose index is returned by
# get_local_gpu_rank().
# NOTE(review): CUDA_VISIBLE_DEVICES is normally only honoured when set before
# CUDA is initialised in the process -- confirm nothing has touched CUDA yet.
os.environ["CUDA_VISIBLE_DEVICES"] = str(get_local_gpu_rank())

return get_rank_and_world_size()


Expand Down

0 comments on commit 46e65a6

Please sign in to comment.