forked from milvus-io/bootcamp
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_main.py
78 lines (63 loc) · 2.68 KB
/
test_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from milvus_helpers import MilvusHelper
from load import insert_data, create_index
from performance_test import performance, percentile_test
from recall_test import recall
import gdown
import os
import tarfile
# Shared fixtures for every test below: one collection name and one Milvus
# helper instance. The tests run in file order and build on each other's
# state (create -> insert -> search -> index -> drop).
collection_name = 'pytest'
client = MilvusHelper()
def test_create_collection():
    """Create the test collection and check that Milvus reports it as existing."""
    # Truthiness instead of `== True` (PEP 8 / E712).
    assert client.create_collection(collection_name)
    assert client.has_collection(collection_name)
def _fetch_sift_archive(archive, url, dest='sift_data'):
    """Download *archive* from Google Drive if it is missing, then unpack it into *dest*.

    :param archive: local path of the ``.tar.gz`` file. It is also passed to
        gdown as the output name, so the existence check and the download
        always refer to the same file.
    :param url: Google Drive download URL for the archive.
    :param dest: directory the archive contents are extracted into.
    """
    if not os.path.exists(archive):
        # Pin the output name. Without it, gdown names the file after whatever
        # Drive reports, which need not match the path checked above.
        gdown.download(url, output=archive)
    os.makedirs(dest, exist_ok=True)
    # Use a context manager so the archive handle is always closed.
    with tarfile.open(archive) as tar:
        # Where supported (3.12+, plus security backports), use the 'data'
        # filter. It rejects absolute paths and '..' traversal, and it avoids
        # the DeprecationWarning for unfiltered extraction.
        if hasattr(tarfile, 'data_filter'):
            tar.extractall(dest, filter='data')
        else:
            tar.extractall(dest)


def test_insert():
    """Fetch the SIFT1M archive, insert it and check the collection holds 1,000,000 rows."""
    _fetch_sift_archive('milvus_sift1m.tar.gz',
                        "https://drive.google.com/uc?id=1lwS6iAN-ic3s6LLMlUhyxzPlIk7yV7et")
    # NOTE(review): insert_data presumably reads the files unpacked into sift_data — confirm.
    insert_data(client, collection_name)
    assert client.count(collection_name) == 1000000


def test_search_performance():
    """Load the collection, run the performance benchmark and check its CSV report."""
    _fetch_sift_archive('query_data.tar.gz',
                        "https://drive.google.com/uc?id=17jPDk93PQsB5yGh1J1YD9N7X8jvPEUQL")
    client.load_data(collection_name)
    load_progress = client.get_loading_progress(collection_name)
    assert load_progress['num_loaded_entities'] == load_progress['num_total_entities']
    performance(client, collection_name, 0)
    # NOTE(review): the name appears to be derived from the collection name
    # and the 0 argument passed above — confirm against performance().
    result_file = 'performance/pytest_0_performance.csv'
    assert os.path.exists(result_file)
    # NOTE(review): 13 lines is presumably a header plus one row per benchmark
    # configuration — TODO confirm.
    with open(result_file, 'r') as report:
        assert len(report.readlines()) == 13


def test_search_recall():
    """Check recall against the ground-truth set, then release the collection."""
    _fetch_sift_archive('gnd.tar.gz',
                        "https://drive.google.com/uc?id=1vBP9mKu5oxyognHtOBBRtLvyPvo8cCp0")
    # NOTE(review): the four values are presumably recall at several top-k
    # settings — confirm against recall().
    assert recall(client, collection_name, 0) == [1, 1, 1, 1]
    client.release_data(collection_name)
    load_progress = client.get_loading_progress(collection_name)
    assert load_progress['num_loaded_entities'] == 0
def test_create_index():
    """Build an IVF_FLAT index on the collection and check every row is indexed."""
    wanted = "IVF_FLAT"
    create_index(client, collection_name, wanted)
    params = client.get_index_params(collection_name)
    assert params[0]['index_type'] == wanted
    # The index is complete once the indexed-row count catches up with the total.
    progress = client.get_index_progress(collection_name)
    assert progress['indexed_rows'] == progress['total_rows']
def test_drop_index():
    """Drop the collection's index and check that no index parameters remain."""
    client.delete_index(collection_name)
    assert client.get_index_params(collection_name) == []
def test_drop_collection():
    """Delete the test collection and check that Milvus no longer reports it."""
    client.delete_collection(collection_name)
    # `not` instead of `== False` (PEP 8 / E712).
    assert not client.has_collection(collection_name)