Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
akshayballal95 committed Nov 9, 2024
1 parent cd84765 commit 4eb10b5
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 7 deletions.
8 changes: 4 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pdf-extract = "0.7.7"
candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.7.2" }
candle-transformers = { git = "https://github.com/huggingface/candle.git", version = "0.7.2" }
candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.7.2" }
ort = {version = "=2.0.0-rc.6", features = ["cuda"]}
ort = {version = "=2.0.0-rc.8", features = ["cuda"]}
strum = "0.26.1"
strum_macros = "0.26"

Expand Down
4 changes: 3 additions & 1 deletion tests/model_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
ColpaliModel
)

from embed_anything import ONNXModel


@pytest.fixture
def clip_model() -> EmbeddingModel:
Expand Down Expand Up @@ -88,7 +90,7 @@ def openai_model() -> EmbeddingModel:

@pytest.fixture
def onnx_model() -> EmbeddingModel:
model = EmbeddingModel.from_pretrained_onnx(WhichModel.Bert, "BGESmallENV15Q")
model = EmbeddingModel.from_pretrained_onnx(WhichModel.Bert, ONNXModel.AllMiniLML6V2Q)
return model

@pytest.fixture
Expand Down
3 changes: 2 additions & 1 deletion tests/model_tests/test_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
embed_query,
embed_file,
embed_directory,
ONNXModel
)

import os
Expand Down Expand Up @@ -45,7 +46,7 @@ def test_bert_model_creation():
assert model is not None

def test_onnx_model_creation():
model = EmbeddingModel.from_pretrained_onnx(WhichModel.Bert, "BGESmallENV15Q")
model = EmbeddingModel.from_pretrained_onnx(WhichModel.Bert, ONNXModel.AllMiniLML6V2Q)
assert model is not None

@model_fixture_parametrize
Expand Down

0 comments on commit 4eb10b5

Please sign in to comment.