Skip to content

Commit

Permalink
optimize clip tests (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dave Berenbaum authored Jul 12, 2024
1 parent 4190e08 commit ee3d16c
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 26 deletions.
28 changes: 15 additions & 13 deletions tests/func/test_pytorch.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path

import open_clip
import pytest
from torch import Size, Tensor
Expand All @@ -8,10 +10,10 @@
from datachain.lib.pytorch import PytorchDataset


@pytest.fixture
def fake_dataset(tmp_path, catalog):
@pytest.fixture(scope="module")
def fake_dataset(tmpdir_factory):
# Create fake images in labeled dirs
data_path = tmp_path / "data" / ""
data_path = Path(tmpdir_factory.mktemp("data"))
for i, (img, label) in enumerate(FakeData()):
label = str(label)
(data_path / label).mkdir(parents=True, exist_ok=True)
Expand All @@ -37,11 +39,11 @@ def test_pytorch_dataset(fake_dataset):
transform=transform,
tokenizer=tokenizer,
)
for img, text, label in pt_dataset:
assert isinstance(img, Tensor)
assert isinstance(text, Tensor)
assert isinstance(label, int)
assert img.size() == Size([3, 64, 64])
img, text, label = next(iter(pt_dataset))
assert isinstance(img, Tensor)
assert isinstance(text, Tensor)
assert isinstance(label, int)
assert img.size() == Size([3, 64, 64])


def test_pytorch_dataset_sample(fake_dataset):
Expand All @@ -62,8 +64,8 @@ def test_to_pytorch(fake_dataset):
tokenizer = open_clip.get_tokenizer("ViT-B-32")
pt_dataset = fake_dataset.to_pytorch(transform=transform, tokenizer=tokenizer)
assert isinstance(pt_dataset, IterableDataset)
for img, text, label in pt_dataset:
assert isinstance(img, Tensor)
assert isinstance(text, Tensor)
assert isinstance(label, int)
assert img.size() == Size([3, 64, 64])
img, text, label = next(iter(pt_dataset))
assert isinstance(img, Tensor)
assert isinstance(text, Tensor)
assert isinstance(label, int)
assert img.size() == Size([3, 64, 64])
21 changes: 21 additions & 0 deletions tests/unit/lib/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import pytest
import torch
from torch import float32
from torchvision.transforms import v2


@pytest.fixture(scope="session")
def fake_clip_model():
class Model:
def encode_image(self, tensor):
return torch.randn(len(tensor), 512)

def encode_text(self, tensor):
return torch.randn(len(tensor), 512)

def tokenizer(tensor, context_length=77):
return torch.randn(len(tensor), context_length)

model = Model()
preprocess = v2.ToDtype(float32, scale=True)
return model, preprocess, tokenizer
12 changes: 4 additions & 8 deletions tests/unit/lib/test_clip.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import open_clip
import pytest
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
Expand All @@ -7,10 +6,6 @@

IMAGES = [Image.new(mode="RGB", size=(64, 64)), Image.new(mode="RGB", size=(32, 32))]
TEXTS = ["text1", "text2"]
MODEL, _, PREPROCESS = open_clip.create_model_and_transforms(
"ViT-B-32", pretrained="laion2b_s34b_b79k"
)
TOKENIZER = open_clip.get_tokenizer("ViT-B-32")


@pytest.mark.parametrize(
Expand All @@ -20,15 +15,16 @@
@pytest.mark.parametrize("text", [None, "text", TEXTS])
@pytest.mark.parametrize("prob", [True, False])
@pytest.mark.parametrize("image_to_text", [True, False])
def test_similarity_scores(images, text, prob, image_to_text):
def test_similarity_scores(fake_clip_model, images, text, prob, image_to_text):
model, preprocess, tokenizer = fake_clip_model
if not (images or text):
with pytest.raises(ValueError):
scores = similarity_scores(
images, text, MODEL, PREPROCESS, TOKENIZER, prob, image_to_text
images, text, model, preprocess, tokenizer, prob, image_to_text
)
else:
scores = similarity_scores(
images, text, MODEL, PREPROCESS, TOKENIZER, prob, image_to_text
images, text, model, preprocess, tokenizer, prob, image_to_text
)
assert isinstance(scores, list)
if not images:
Expand Down
7 changes: 2 additions & 5 deletions tests/unit/lib/test_text.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import open_clip
import torch
from transformers import CLIPModel, CLIPProcessor

from datachain.lib.file import TextFile
from datachain.lib.text import convert_text


def test_convert_text():
def test_convert_text(fake_clip_model):
text = "thisismytext"
tokenizer_model = "ViT-B-32"
tokenizer = open_clip.get_tokenizer(tokenizer_model)
model, _, tokenizer = fake_clip_model
converted_text = convert_text(text, tokenizer=tokenizer)
assert isinstance(converted_text, torch.Tensor)

Expand All @@ -22,7 +20,6 @@ def test_convert_text():
converted_text = convert_text(
text, tokenizer=tokenizer, tokenizer_kwargs=tokenizer_kwargs
)
model, _, _ = open_clip.create_model_and_transforms(tokenizer_model)
converted_text = convert_text(text, tokenizer=tokenizer, encoder=model.encode_text)
assert converted_text.dtype == torch.float32

Expand Down

0 comments on commit ee3d16c

Please sign in to comment.