-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathcolbert_index.py
98 lines (81 loc) · 2.19 KB
/
colbert_index.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
def configure_conda_cuda_env(environ=None):
    """Point the CUDA-related variables at the active conda environment.

    Sets ``CUDA_HOME`` to ``$CONDA_PREFIX`` and places ``$CONDA_PREFIX/lib``
    at the front of ``LIBRARY_PATH`` and ``LD_LIBRARY_PATH``. Directories
    already listed in those two variables are kept after the conda entry
    instead of being discarded, and the conda entry is never duplicated.

    Args:
        environ: Mutable mapping to update; defaults to ``os.environ``.

    Returns:
        True if ``CONDA_PREFIX`` was set and the variables were updated,
        False if it was missing or empty (the mapping is left untouched).
    """
    if environ is None:
        environ = os.environ

    prefix = environ.get('CONDA_PREFIX')
    if not prefix:
        # Not running inside a conda environment: leave any existing CUDA
        # configuration alone instead of failing with a bare KeyError.
        return False

    lib_dir = prefix + "/lib"
    environ['CUDA_HOME'] = prefix
    for var in ('LIBRARY_PATH', 'LD_LIBRARY_PATH'):
        current = environ.get(var)
        entries = current.split(os.pathsep) if current else []
        # Conda's lib directory goes first; drop older copies of it so that
        # running this more than once does not keep growing the variable.
        entries = [lib_dir] + [entry for entry in entries if entry != lib_dir]
        environ[var] = os.pathsep.join(entries)
    return True


# Kept ahead of the colbert imports, as in the original statement order;
# presumably those imports rely on these variables -- TODO confirm.
configure_conda_cuda_env()
from colbert.infra import Run, RunConfig, ColBERTConfig
from colbert.data import Collection
from colbert import Indexer
from argparse import ArgumentParser
if __name__ == "__main__":
    # Script entry point: parse the command line, load the collection, and
    # build a ColBERT index for it inside a Run context.
    cli = ArgumentParser(
        prog='ColBERT Indexing',
        description='Indexing a collection of text.',
    )

    # One row per option: (short flag, long flag, type, default, help text).
    option_table = (
        ('-b', '--nbits', int, 2, 'each dimension encoding bits'),
        ('-l', '--maxlen', int, 300, 'max number of tokens per document'),
        ('-c', '--checkpoint', str, "colbertv2.0", 'trained checkpoint to use'),
        ('-p', '--path', str, "fringe_collection.tsv", 'collection path'),
        ('-i', '--index', str, "index_fringe", 'new index name'),
        ('-e', '--experiment', str, "exp_fringe", 'experiment name'),
        ('-r', '--ranks', int, 1, 'number of GPUs to use'),
        ('-s', '--bsize', int, 16, 'batch size'),
    )
    for short_flag, long_flag, value_type, fallback, description in option_table:
        cli.add_argument(
            short_flag, long_flag,
            type=value_type,
            default=fallback,
            help=description,
        )
    opts = cli.parse_args()

    # The collection is created before the Run context is entered.
    documents = Collection(path=opts.path)

    with Run().context(RunConfig(nranks=opts.ranks, experiment=opts.experiment)):
        index_settings = ColBERTConfig(
            doc_maxlen=opts.maxlen,
            nbits=opts.nbits,
            index_bsize=opts.bsize,
        )
        builder = Indexer(checkpoint=opts.checkpoint, config=index_settings)
        # overwrite=True is always passed; presumably this replaces an
        # existing index of the same name -- confirm against the ColBERT API.
        builder.index(name=opts.index, collection=documents, overwrite=True)
        print("Index created at: " + builder.get_index())