diff --git a/tests/func/test_pytorch.py b/tests/func/test_pytorch.py index 64a32bd86..8782c8c8b 100644 --- a/tests/func/test_pytorch.py +++ b/tests/func/test_pytorch.py @@ -1,3 +1,5 @@ +from pathlib import Path + import open_clip import pytest from torch import Size, Tensor @@ -8,10 +10,10 @@ from datachain.lib.pytorch import PytorchDataset -@pytest.fixture -def fake_dataset(tmp_path, catalog): +@pytest.fixture(scope="module") +def fake_dataset(tmpdir_factory): # Create fake images in labeled dirs - data_path = tmp_path / "data" / "" + data_path = Path(tmpdir_factory.mktemp("data")) for i, (img, label) in enumerate(FakeData()): label = str(label) (data_path / label).mkdir(parents=True, exist_ok=True) @@ -37,11 +39,11 @@ def test_pytorch_dataset(fake_dataset): transform=transform, tokenizer=tokenizer, ) - for img, text, label in pt_dataset: - assert isinstance(img, Tensor) - assert isinstance(text, Tensor) - assert isinstance(label, int) - assert img.size() == Size([3, 64, 64]) + img, text, label = next(iter(pt_dataset)) + assert isinstance(img, Tensor) + assert isinstance(text, Tensor) + assert isinstance(label, int) + assert img.size() == Size([3, 64, 64]) def test_pytorch_dataset_sample(fake_dataset): @@ -62,8 +64,8 @@ def test_to_pytorch(fake_dataset): tokenizer = open_clip.get_tokenizer("ViT-B-32") pt_dataset = fake_dataset.to_pytorch(transform=transform, tokenizer=tokenizer) assert isinstance(pt_dataset, IterableDataset) - for img, text, label in pt_dataset: - assert isinstance(img, Tensor) - assert isinstance(text, Tensor) - assert isinstance(label, int) - assert img.size() == Size([3, 64, 64]) + img, text, label = next(iter(pt_dataset)) + assert isinstance(img, Tensor) + assert isinstance(text, Tensor) + assert isinstance(label, int) + assert img.size() == Size([3, 64, 64]) diff --git a/tests/unit/lib/conftest.py b/tests/unit/lib/conftest.py new file mode 100644 index 000000000..36d691f60 --- /dev/null +++ b/tests/unit/lib/conftest.py @@ -0,0 +1,21 @@ +import pytest +import torch +from torch import float32 +from torchvision.transforms import v2 + + +@pytest.fixture(scope="session") +def fake_clip_model(): + class Model: + def encode_image(self, tensor): + return torch.randn(len(tensor), 512) + + def encode_text(self, tensor): + return torch.randn(len(tensor), 512) + + def tokenizer(tensor, context_length=77): + return torch.randn(len(tensor), context_length) + + model = Model() + preprocess = v2.ToDtype(float32, scale=True) + return model, preprocess, tokenizer diff --git a/tests/unit/lib/test_clip.py b/tests/unit/lib/test_clip.py index b39ac275f..29111467f 100644 --- a/tests/unit/lib/test_clip.py +++ b/tests/unit/lib/test_clip.py @@ -1,4 +1,3 @@ -import open_clip import pytest from PIL import Image from transformers import CLIPModel, CLIPProcessor @@ -7,10 +6,6 @@ IMAGES = [Image.new(mode="RGB", size=(64, 64)), Image.new(mode="RGB", size=(32, 32))] TEXTS = ["text1", "text2"] -MODEL, _, PREPROCESS = open_clip.create_model_and_transforms( - "ViT-B-32", pretrained="laion2b_s34b_b79k" -) -TOKENIZER = open_clip.get_tokenizer("ViT-B-32") @pytest.mark.parametrize( @@ -20,15 +15,16 @@ @pytest.mark.parametrize("text", [None, "text", TEXTS]) @pytest.mark.parametrize("prob", [True, False]) @pytest.mark.parametrize("image_to_text", [True, False]) -def test_similarity_scores(images, text, prob, image_to_text): +def test_similarity_scores(fake_clip_model, images, text, prob, image_to_text): + model, preprocess, tokenizer = fake_clip_model if not (images or text): with pytest.raises(ValueError): scores = similarity_scores( - images, text, MODEL, PREPROCESS, TOKENIZER, prob, image_to_text + images, text, model, preprocess, tokenizer, prob, image_to_text ) else: scores = similarity_scores( - images, text, MODEL, PREPROCESS, TOKENIZER, prob, image_to_text + images, text, model, preprocess, tokenizer, prob, image_to_text ) assert isinstance(scores, list) if not images: diff --git a/tests/unit/lib/test_text.py b/tests/unit/lib/test_text.py index a19a6cebb..6ac7a3313 100644 --- a/tests/unit/lib/test_text.py +++ b/tests/unit/lib/test_text.py @@ -1,4 +1,3 @@ -import open_clip import torch from transformers import CLIPModel, CLIPProcessor @@ -6,10 +5,9 @@ from datachain.lib.text import convert_text -def test_convert_text(): +def test_convert_text(fake_clip_model): text = "thisismytext" - tokenizer_model = "ViT-B-32" - tokenizer = open_clip.get_tokenizer(tokenizer_model) + model, _, tokenizer = fake_clip_model converted_text = convert_text(text, tokenizer=tokenizer) assert isinstance(converted_text, torch.Tensor) @@ -22,7 +20,6 @@ def test_convert_text(): converted_text = convert_text( text, tokenizer=tokenizer, tokenizer_kwargs=tokenizer_kwargs ) - model, _, _ = open_clip.create_model_and_transforms(tokenizer_model) converted_text = convert_text(text, tokenizer=tokenizer, encoder=model.encode_text) assert converted_text.dtype == torch.float32